aboutsummaryrefslogtreecommitdiff
path: root/p2p/src/codec.rs
blob: e52182423192b27f2bf3a6a1c68def30af766d44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use std::time::Duration;

use bincode::{Decode, Encode};

use karyons_core::{
    async_util::timeout,
    util::{decode, encode, encode_into_slice},
};

use karyons_net::{Connection, NetError};

use crate::{
    message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE},
    Error, Result,
};

/// Bound on the message types accepted by [`Codec::write`].
///
/// Any type that is bincode `Encode` + `Decode` and also `Debug` gets this
/// trait automatically through the blanket impl below, so it never has to be
/// implemented by hand.
pub trait CodecMsg: Encode + Decode + std::fmt::Debug {}

impl<T> CodecMsg for T where T: Decode + Encode + std::fmt::Debug {}

/// A Codec working with generic network connections.
///
/// It is responsible for both decoding data received from the network and
/// encoding data before sending it.
pub struct Codec {
    conn: Box<dyn Connection>,
}

impl Codec {
    /// Creates a new Codec.
    pub fn new(conn: Box<dyn Connection>) -> Self {
        Self { conn }
    }

    /// Reads a message of type `NetMsg` from the connection.
    ///
    /// It reads the first 6 bytes as the header of the message, then reads
    /// and decodes the remaining message data based on the determined header.
    pub async fn read(&self) -> Result<NetMsg> {
        // Read 6 bytes to get the header of the incoming message
        let mut buf = [0; MSG_HEADER_SIZE];
        self.read_exact(&mut buf).await?;

        // Decode the header from bytes to NetMsgHeader
        let (header, _) = decode::<NetMsgHeader>(&buf)?;

        if header.payload_size > MAX_ALLOWED_MSG_SIZE {
            return Err(Error::InvalidMsg(
                "Message exceeds the maximum allowed size".to_string(),
            ));
        }

        // Create a buffer to hold the message based on its length
        let mut payload = vec![0; header.payload_size as usize];
        self.read_exact(&mut payload).await?;

        Ok(NetMsg { header, payload })
    }

    /// Writes a message of type `T` to the connection.
    ///
    /// Before appending the actual message payload, it calculates the length of
    /// the encoded message in bytes and appends this length to the message header.
    pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> {
        let payload = encode(msg)?;

        // Create a buffer to hold the message header (6 bytes)
        let header_buf = &mut [0; MSG_HEADER_SIZE];
        let header = NetMsgHeader {
            command,
            payload_size: payload.len() as u32,
        };
        encode_into_slice(&header, header_buf)?;

        let mut buffer = vec![];
        // Append the header bytes to the buffer
        buffer.extend_from_slice(header_buf);
        // Append the message payload to the buffer
        buffer.extend_from_slice(&payload);

        self.write_all(&buffer).await?;
        Ok(())
    }

    /// Reads a message of type `NetMsg` with the given timeout.
    pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> {
        timeout(duration, self.read())
            .await
            .map_err(|_| NetError::Timeout)?
    }

    /// Reads the exact number of bytes required to fill `buf`.
    async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> {
        while !buf.is_empty() {
            let n = self.conn.read(buf).await?;
            let (_, rest) = std::mem::take(&mut buf).split_at_mut(n);
            buf = rest;

            if n == 0 {
                return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
            }
        }

        Ok(())
    }

    /// Writes an entire buffer into the connection.
    async fn write_all(&self, mut buf: &[u8]) -> Result<()> {
        while !buf.is_empty() {
            let n = self.conn.write(buf).await?;
            let (_, rest) = std::mem::take(&mut buf).split_at(n);
            buf = rest;

            if n == 0 {
                return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
            }
        }

        Ok(())
    }
}