aboutsummaryrefslogtreecommitdiff
path: root/p2p/src/codec.rs
diff options
context:
space:
mode:
Diffstat (limited to 'p2p/src/codec.rs')
-rw-r--r--p2p/src/codec.rs149
1 files changed, 49 insertions, 100 deletions
diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs
index 726a2f7..3d0f323 100644
--- a/p2p/src/codec.rs
+++ b/p2p/src/codec.rs
@@ -1,120 +1,69 @@
-use std::time::Duration;
+use karyon_core::util::{decode, encode, encode_into_slice};
-use bincode::{Decode, Encode};
-
-use karyon_core::{
- async_util::timeout,
- util::{decode, encode, encode_into_slice},
-};
-
-use karyon_net::{Connection, NetError};
-
-use crate::{
- message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE},
- Error, Result,
+use karyon_net::{
+ codec::{Codec, Decoder, Encoder, LengthCodec},
+ Result,
};
-pub trait CodecMsg: Decode + Encode + std::fmt::Debug {}
-impl<T: Encode + Decode + std::fmt::Debug> CodecMsg for T {}
+use crate::message::{NetMsg, RefreshMsg};
-/// A Codec working with generic network connections.
-///
-/// It is responsible for both decoding data received from the network and
-/// encoding data before sending it.
-pub struct Codec {
- conn: Box<dyn Connection>,
+#[derive(Clone)]
+pub struct NetMsgCodec {
+ inner_codec: LengthCodec,
}
-impl Codec {
- /// Creates a new Codec.
- pub fn new(conn: Box<dyn Connection>) -> Self {
- Self { conn }
- }
-
- /// Reads a message of type `NetMsg` from the connection.
- ///
- /// It reads the first 6 bytes as the header of the message, then reads
- /// and decodes the remaining message data based on the determined header.
- pub async fn read(&self) -> Result<NetMsg> {
- // Read 6 bytes to get the header of the incoming message
- let mut buf = [0; MSG_HEADER_SIZE];
- self.read_exact(&mut buf).await?;
-
- // Decode the header from bytes to NetMsgHeader
- let (header, _) = decode::<NetMsgHeader>(&buf)?;
-
- if header.payload_size > MAX_ALLOWED_MSG_SIZE {
- return Err(Error::InvalidMsg(
- "Message exceeds the maximum allowed size".to_string(),
- ));
+impl NetMsgCodec {
+ pub fn new() -> Self {
+ Self {
+ inner_codec: LengthCodec {},
}
-
- // Create a buffer to hold the message based on its length
- let mut payload = vec![0; header.payload_size as usize];
- self.read_exact(&mut payload).await?;
-
- Ok(NetMsg { header, payload })
}
+}
- /// Writes a message of type `T` to the connection.
- ///
- /// Before appending the actual message payload, it calculates the length of
- /// the encoded message in bytes and appends this length to the message header.
- pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> {
- let payload = encode(msg)?;
-
- // Create a buffer to hold the message header (6 bytes)
- let header_buf = &mut [0; MSG_HEADER_SIZE];
- let header = NetMsgHeader {
- command,
- payload_size: payload.len() as u32,
- };
- encode_into_slice(&header, header_buf)?;
-
- let mut buffer = vec![];
- // Append the header bytes to the buffer
- buffer.extend_from_slice(header_buf);
- // Append the message payload to the buffer
- buffer.extend_from_slice(&payload);
-
- self.write_all(&buffer).await?;
- Ok(())
- }
+impl Codec for NetMsgCodec {
+ type Item = NetMsg;
+}
- /// Reads a message of type `NetMsg` with the given timeout.
- pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> {
- timeout(duration, self.read())
- .await
- .map_err(|_| NetError::Timeout)?
+impl Encoder for NetMsgCodec {
+ type EnItem = NetMsg;
+ fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> {
+ let src = encode(src)?;
+ self.inner_codec.encode(&src, dst)
}
+}
- /// Reads the exact number of bytes required to fill `buf`.
- async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> {
- while !buf.is_empty() {
- let n = self.conn.read(buf).await?;
- let (_, rest) = std::mem::take(&mut buf).split_at_mut(n);
- buf = rest;
-
- if n == 0 {
- return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
+impl Decoder for NetMsgCodec {
+ type DeItem = NetMsg;
+ fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> {
+ match self.inner_codec.decode(src)? {
+ Some((n, s)) => {
+ let (m, _) = decode::<Self::DeItem>(&s)?;
+ Ok(Some((n, m)))
}
+ None => Ok(None),
}
-
- Ok(())
}
+}
- /// Writes an entire buffer into the connection.
- async fn write_all(&self, mut buf: &[u8]) -> Result<()> {
- while !buf.is_empty() {
- let n = self.conn.write(buf).await?;
- let (_, rest) = std::mem::take(&mut buf).split_at(n);
- buf = rest;
+#[derive(Clone)]
+pub struct RefreshMsgCodec {}
- if n == 0 {
- return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
- }
- }
+impl Codec for RefreshMsgCodec {
+ type Item = RefreshMsg;
+}
+
+impl Encoder for RefreshMsgCodec {
+ type EnItem = RefreshMsg;
+ fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> {
+ let n = encode_into_slice(src, dst)?;
+ Ok(n)
+ }
+}
- Ok(())
+impl Decoder for RefreshMsgCodec {
+ type DeItem = RefreshMsg;
+ fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> {
+ let (m, n) = decode::<Self::DeItem>(src)?;
+ Ok(Some((n, m)))
}
}