diff options
| author | hozan23 <hozan23@karyontech.net> | 2024-04-11 10:19:20 +0200 | 
|---|---|---|
| committer | hozan23 <hozan23@karyontech.net> | 2024-05-19 13:51:30 +0200 | 
| commit | 0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch) | |
| tree | 961d73218af672797d49f899289bef295bc56493 /p2p/src/codec.rs | |
| parent | a69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff) | |
add support for tokio & improve net crate api
Diffstat (limited to 'p2p/src/codec.rs')
| -rw-r--r-- | p2p/src/codec.rs | 149 | 
1 files changed, 49 insertions, 100 deletions
| diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs index 726a2f7..3d0f323 100644 --- a/p2p/src/codec.rs +++ b/p2p/src/codec.rs @@ -1,120 +1,69 @@ -use std::time::Duration; +use karyon_core::util::{decode, encode, encode_into_slice}; -use bincode::{Decode, Encode}; - -use karyon_core::{ -    async_util::timeout, -    util::{decode, encode, encode_into_slice}, -}; - -use karyon_net::{Connection, NetError}; - -use crate::{ -    message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, -    Error, Result, +use karyon_net::{ +    codec::{Codec, Decoder, Encoder, LengthCodec}, +    Result,  }; -pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} -impl<T: Encode + Decode + std::fmt::Debug> CodecMsg for T {} +use crate::message::{NetMsg, RefreshMsg}; -/// A Codec working with generic network connections. -/// -/// It is responsible for both decoding data received from the network and -/// encoding data before sending it. -pub struct Codec { -    conn: Box<dyn Connection>, +#[derive(Clone)] +pub struct NetMsgCodec { +    inner_codec: LengthCodec,  } -impl Codec { -    /// Creates a new Codec. -    pub fn new(conn: Box<dyn Connection>) -> Self { -        Self { conn } -    } - -    /// Reads a message of type `NetMsg` from the connection. -    /// -    /// It reads the first 6 bytes as the header of the message, then reads -    /// and decodes the remaining message data based on the determined header. -    pub async fn read(&self) -> Result<NetMsg> { -        // Read 6 bytes to get the header of the incoming message -        let mut buf = [0; MSG_HEADER_SIZE]; -        self.read_exact(&mut buf).await?; - -        // Decode the header from bytes to NetMsgHeader -        let (header, _) = decode::<NetMsgHeader>(&buf)?; - -        if header.payload_size > MAX_ALLOWED_MSG_SIZE { -            return Err(Error::InvalidMsg( -                "Message exceeds the maximum allowed size".to_string(), -            )); +impl NetMsgCodec { +    pub fn new() -> Self { +        Self { +            inner_codec: LengthCodec {},          } - -        // Create a buffer to hold the message based on its length -        let mut payload = vec![0; header.payload_size as usize]; -        self.read_exact(&mut payload).await?; - -        Ok(NetMsg { header, payload })      } +} -    /// Writes a message of type `T` to the connection. -    /// -    /// Before appending the actual message payload, it calculates the length of -    /// the encoded message in bytes and appends this length to the message header. -    pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> { -        let payload = encode(msg)?; - -        // Create a buffer to hold the message header (6 bytes) -        let header_buf = &mut [0; MSG_HEADER_SIZE]; -        let header = NetMsgHeader { -            command, -            payload_size: payload.len() as u32, -        }; -        encode_into_slice(&header, header_buf)?; - -        let mut buffer = vec![]; -        // Append the header bytes to the buffer -        buffer.extend_from_slice(header_buf); -        // Append the message payload to the buffer -        buffer.extend_from_slice(&payload); - -        self.write_all(&buffer).await?; -        Ok(()) -    } +impl Codec for NetMsgCodec { +    type Item = NetMsg; +} -    /// Reads a message of type `NetMsg` with the given timeout. -    pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> { -        timeout(duration, self.read()) -            .await -            .map_err(|_| NetError::Timeout)? +impl Encoder for NetMsgCodec { +    type EnItem = NetMsg; +    fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> { +        let src = encode(src)?; +        self.inner_codec.encode(&src, dst)      } +} -    /// Reads the exact number of bytes required to fill `buf`. -    async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { -        while !buf.is_empty() { -            let n = self.conn.read(buf).await?; -            let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); -            buf = rest; - -            if n == 0 { -                return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); +impl Decoder for NetMsgCodec { +    type DeItem = NetMsg; +    fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> { +        match self.inner_codec.decode(src)? { +            Some((n, s)) => { +                let (m, _) = decode::<Self::DeItem>(&s)?; +                Ok(Some((n, m)))              } +            None => Ok(None),          } - -        Ok(())      } +} -    /// Writes an entire buffer into the connection. -    async fn write_all(&self, mut buf: &[u8]) -> Result<()> { -        while !buf.is_empty() { -            let n = self.conn.write(buf).await?; -            let (_, rest) = std::mem::take(&mut buf).split_at(n); -            buf = rest; +#[derive(Clone)] +pub struct RefreshMsgCodec {} -            if n == 0 { -                return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); -            } -        } +impl Codec for RefreshMsgCodec { +    type Item = RefreshMsg; +} + +impl Encoder for RefreshMsgCodec { +    type EnItem = RefreshMsg; +    fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> { +        let n = encode_into_slice(src, dst)?; +        Ok(n) +    } +} -        Ok(()) +impl Decoder for RefreshMsgCodec { +    type DeItem = RefreshMsg; +    fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> { +        let (m, n) = decode::<Self::DeItem>(src)?; +        Ok(Some((n, m)))      }  } | 
