diff options
author | hozan23 <hozan23@karyontech.net> | 2024-04-11 10:19:20 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-19 13:51:30 +0200 |
commit | 0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch) | |
tree | 961d73218af672797d49f899289bef295bc56493 /p2p/src/codec.rs | |
parent | a69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff) |
add support for tokio & improve net crate api
Diffstat (limited to 'p2p/src/codec.rs')
-rw-r--r-- | p2p/src/codec.rs | 149 |
1 files changed, 49 insertions, 100 deletions
diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs index 726a2f7..3d0f323 100644 --- a/p2p/src/codec.rs +++ b/p2p/src/codec.rs @@ -1,120 +1,69 @@ -use std::time::Duration; +use karyon_core::util::{decode, encode, encode_into_slice}; -use bincode::{Decode, Encode}; - -use karyon_core::{ - async_util::timeout, - util::{decode, encode, encode_into_slice}, -}; - -use karyon_net::{Connection, NetError}; - -use crate::{ - message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, - Error, Result, +use karyon_net::{ + codec::{Codec, Decoder, Encoder, LengthCodec}, + Result, }; -pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} -impl<T: Encode + Decode + std::fmt::Debug> CodecMsg for T {} +use crate::message::{NetMsg, RefreshMsg}; -/// A Codec working with generic network connections. -/// -/// It is responsible for both decoding data received from the network and -/// encoding data before sending it. -pub struct Codec { - conn: Box<dyn Connection>, +#[derive(Clone)] +pub struct NetMsgCodec { + inner_codec: LengthCodec, } -impl Codec { - /// Creates a new Codec. - pub fn new(conn: Box<dyn Connection>) -> Self { - Self { conn } - } - - /// Reads a message of type `NetMsg` from the connection. - /// - /// It reads the first 6 bytes as the header of the message, then reads - /// and decodes the remaining message data based on the determined header. - pub async fn read(&self) -> Result<NetMsg> { - // Read 6 bytes to get the header of the incoming message - let mut buf = [0; MSG_HEADER_SIZE]; - self.read_exact(&mut buf).await?; - - // Decode the header from bytes to NetMsgHeader - let (header, _) = decode::<NetMsgHeader>(&buf)?; - - if header.payload_size > MAX_ALLOWED_MSG_SIZE { - return Err(Error::InvalidMsg( - "Message exceeds the maximum allowed size".to_string(), - )); +impl NetMsgCodec { + pub fn new() -> Self { + Self { + inner_codec: LengthCodec {}, } - - // Create a buffer to hold the message based on its length - let mut payload = vec![0; header.payload_size as usize]; - self.read_exact(&mut payload).await?; - - Ok(NetMsg { header, payload }) } +} - /// Writes a message of type `T` to the connection. - /// - /// Before appending the actual message payload, it calculates the length of - /// the encoded message in bytes and appends this length to the message header. - pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> { - let payload = encode(msg)?; - - // Create a buffer to hold the message header (6 bytes) - let header_buf = &mut [0; MSG_HEADER_SIZE]; - let header = NetMsgHeader { - command, - payload_size: payload.len() as u32, - }; - encode_into_slice(&header, header_buf)?; - - let mut buffer = vec![]; - // Append the header bytes to the buffer - buffer.extend_from_slice(header_buf); - // Append the message payload to the buffer - buffer.extend_from_slice(&payload); - - self.write_all(&buffer).await?; - Ok(()) - } +impl Codec for NetMsgCodec { + type Item = NetMsg; +} - /// Reads a message of type `NetMsg` with the given timeout. - pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> { - timeout(duration, self.read()) - .await - .map_err(|_| NetError::Timeout)? +impl Encoder for NetMsgCodec { + type EnItem = NetMsg; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> { + let src = encode(src)?; + self.inner_codec.encode(&src, dst) } +} - /// Reads the exact number of bytes required to fill `buf`. - async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.read(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); +impl Decoder for NetMsgCodec { + type DeItem = NetMsg; + fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> { + match self.inner_codec.decode(src)? { + Some((n, s)) => { + let (m, _) = decode::<Self::DeItem>(&s)?; + Ok(Some((n, m))) } + None => Ok(None), } - - Ok(()) } +} - /// Writes an entire buffer into the connection. - async fn write_all(&self, mut buf: &[u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.write(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at(n); - buf = rest; +#[derive(Clone)] +pub struct RefreshMsgCodec {} - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } +impl Codec for RefreshMsgCodec { + type Item = RefreshMsg; +} + +impl Encoder for RefreshMsgCodec { + type EnItem = RefreshMsg; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> { + let n = encode_into_slice(src, dst)?; + Ok(n) + } +} - Ok(()) +impl Decoder for RefreshMsgCodec { + type DeItem = RefreshMsg; + fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> { + let (m, n) = decode::<Self::DeItem>(src)?; + Ok(Some((n, m))) } } |