From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- p2p/src/codec.rs | 149 ++++++++++++++++++------------------------------------- 1 file changed, 49 insertions(+), 100 deletions(-) (limited to 'p2p/src/codec.rs') diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs index 726a2f7..3d0f323 100644 --- a/p2p/src/codec.rs +++ b/p2p/src/codec.rs @@ -1,120 +1,69 @@ -use std::time::Duration; +use karyon_core::util::{decode, encode, encode_into_slice}; -use bincode::{Decode, Encode}; - -use karyon_core::{ - async_util::timeout, - util::{decode, encode, encode_into_slice}, -}; - -use karyon_net::{Connection, NetError}; - -use crate::{ - message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, - Error, Result, +use karyon_net::{ + codec::{Codec, Decoder, Encoder, LengthCodec}, + Result, }; -pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} -impl CodecMsg for T {} +use crate::message::{NetMsg, RefreshMsg}; -/// A Codec working with generic network connections. -/// -/// It is responsible for both decoding data received from the network and -/// encoding data before sending it. -pub struct Codec { - conn: Box, +#[derive(Clone)] +pub struct NetMsgCodec { + inner_codec: LengthCodec, } -impl Codec { - /// Creates a new Codec. - pub fn new(conn: Box) -> Self { - Self { conn } - } - - /// Reads a message of type `NetMsg` from the connection. - /// - /// It reads the first 6 bytes as the header of the message, then reads - /// and decodes the remaining message data based on the determined header. - pub async fn read(&self) -> Result { - // Read 6 bytes to get the header of the incoming message - let mut buf = [0; MSG_HEADER_SIZE]; - self.read_exact(&mut buf).await?; - - // Decode the header from bytes to NetMsgHeader - let (header, _) = decode::(&buf)?; - - if header.payload_size > MAX_ALLOWED_MSG_SIZE { - return Err(Error::InvalidMsg( - "Message exceeds the maximum allowed size".to_string(), - )); +impl NetMsgCodec { + pub fn new() -> Self { + Self { + inner_codec: LengthCodec {}, } - - // Create a buffer to hold the message based on its length - let mut payload = vec![0; header.payload_size as usize]; - self.read_exact(&mut payload).await?; - - Ok(NetMsg { header, payload }) } +} - /// Writes a message of type `T` to the connection. - /// - /// Before appending the actual message payload, it calculates the length of - /// the encoded message in bytes and appends this length to the message header. - pub async fn write(&self, command: NetMsgCmd, msg: &T) -> Result<()> { - let payload = encode(msg)?; - - // Create a buffer to hold the message header (6 bytes) - let header_buf = &mut [0; MSG_HEADER_SIZE]; - let header = NetMsgHeader { - command, - payload_size: payload.len() as u32, - }; - encode_into_slice(&header, header_buf)?; - - let mut buffer = vec![]; - // Append the header bytes to the buffer - buffer.extend_from_slice(header_buf); - // Append the message payload to the buffer - buffer.extend_from_slice(&payload); - - self.write_all(&buffer).await?; - Ok(()) - } +impl Codec for NetMsgCodec { + type Item = NetMsg; +} - /// Reads a message of type `NetMsg` with the given timeout. - pub async fn read_timeout(&self, duration: Duration) -> Result { - timeout(duration, self.read()) - .await - .map_err(|_| NetError::Timeout)? +impl Encoder for NetMsgCodec { + type EnItem = NetMsg; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result { + let src = encode(src)?; + self.inner_codec.encode(&src, dst) } +} - /// Reads the exact number of bytes required to fill `buf`. - async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.read(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); +impl Decoder for NetMsgCodec { + type DeItem = NetMsg; + fn decode(&self, src: &mut [u8]) -> Result> { + match self.inner_codec.decode(src)? { + Some((n, s)) => { + let (m, _) = decode::(&s)?; + Ok(Some((n, m))) } + None => Ok(None), } - - Ok(()) } +} - /// Writes an entire buffer into the connection. - async fn write_all(&self, mut buf: &[u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.write(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at(n); - buf = rest; +#[derive(Clone)] +pub struct RefreshMsgCodec {} - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } +impl Codec for RefreshMsgCodec { + type Item = RefreshMsg; +} + +impl Encoder for RefreshMsgCodec { + type EnItem = RefreshMsg; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result { + let n = encode_into_slice(src, dst)?; + Ok(n) + } +} - Ok(()) +impl Decoder for RefreshMsgCodec { + type DeItem = RefreshMsg; + fn decode(&self, src: &mut [u8]) -> Result> { + let (m, n) = decode::(src)?; + Ok(Some((n, m))) } } -- cgit v1.2.3