diff options
author | hozan23 <hozan23@proton.me> | 2023-11-28 22:41:33 +0300 |
---|---|---|
committer | hozan23 <hozan23@proton.me> | 2023-11-28 22:41:33 +0300 |
commit | 98a1de91a2dae06323558422c239e5a45fc86e7b (patch) | |
tree | 38c640248824fcb3b4ca5ba12df47c13ef26ccda /p2p/src/codec.rs | |
parent | ca2a5f8bbb6983d9555abd10eaaf86950b794957 (diff) |
implement TLS for inbound and outbound connections
Diffstat (limited to 'p2p/src/codec.rs')
-rw-r--r-- | p2p/src/codec.rs | 120 |
1 files changed, 120 insertions, 0 deletions
diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs new file mode 100644 index 0000000..e521824 --- /dev/null +++ b/p2p/src/codec.rs @@ -0,0 +1,120 @@ +use std::time::Duration; + +use bincode::{Decode, Encode}; + +use karyons_core::{ + async_util::timeout, + util::{decode, encode, encode_into_slice}, +}; + +use karyons_net::{Connection, NetError}; + +use crate::{ + message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, + Error, Result, +}; + +pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} +impl<T: Encode + Decode + std::fmt::Debug> CodecMsg for T {} + +/// A Codec working with generic network connections. +/// +/// It is responsible for both decoding data received from the network and +/// encoding data before sending it. +pub struct Codec { + conn: Box<dyn Connection>, +} + +impl Codec { + /// Creates a new Codec. + pub fn new(conn: Box<dyn Connection>) -> Self { + Self { conn } + } + + /// Reads a message of type `NetMsg` from the connection. + /// + /// It reads the first 6 bytes as the header of the message, then reads + /// and decodes the remaining message data based on the determined header. + pub async fn read(&self) -> Result<NetMsg> { + // Read 6 bytes to get the header of the incoming message + let mut buf = [0; MSG_HEADER_SIZE]; + self.read_exact(&mut buf).await?; + + // Decode the header from bytes to NetMsgHeader + let (header, _) = decode::<NetMsgHeader>(&buf)?; + + if header.payload_size > MAX_ALLOWED_MSG_SIZE { + return Err(Error::InvalidMsg( + "Message exceeds the maximum allowed size".to_string(), + )); + } + + // Create a buffer to hold the message based on its length + let mut payload = vec![0; header.payload_size as usize]; + self.read_exact(&mut payload).await?; + + Ok(NetMsg { header, payload }) + } + + /// Writes a message of type `T` to the connection. + /// + /// Before appending the actual message payload, it calculates the length of + /// the encoded message in bytes and appends this length to the message header. + pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> { + let payload = encode(msg)?; + + // Create a buffer to hold the message header (6 bytes) + let header_buf = &mut [0; MSG_HEADER_SIZE]; + let header = NetMsgHeader { + command, + payload_size: payload.len() as u32, + }; + encode_into_slice(&header, header_buf)?; + + let mut buffer = vec![]; + // Append the header bytes to the buffer + buffer.extend_from_slice(header_buf); + // Append the message payload to the buffer + buffer.extend_from_slice(&payload); + + self.write_all(&buffer).await?; + Ok(()) + } + + /// Reads a message of type `NetMsg` with the given timeout. + pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> { + timeout(duration, self.read()) + .await + .map_err(|_| NetError::Timeout)? + } + + /// Reads the exact number of bytes required to fill `buf`. + async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.read(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } + + /// Writes an entire buffer into the connection. + async fn write_all(&self, mut buf: &[u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.write(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } +} |