aboutsummaryrefslogtreecommitdiff
path: root/p2p/src/io_codec.rs
blob: ea62666472d06c8002cafeae4aa3385b392d0738 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use std::time::Duration;

use bincode::{Decode, Encode};

use karyons_core::{
    async_utils::timeout,
    utils::{decode, encode, encode_into_slice},
};

use karyons_net::{Connection, NetError};

use crate::{
    message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE},
    Error, Result,
};

/// Marker trait for message types accepted by `IOCodec::write`.
///
/// It declares no methods; it only groups the `Decode`, `Encode` and `Debug`
/// bounds under a single name.
pub trait CodecMsg: Decode + Encode + std::fmt::Debug {}

// Blanket implementation: any type that satisfies the bounds is a `CodecMsg`.
impl<T> CodecMsg for T where T: Decode + Encode + std::fmt::Debug {}

/// I/O codec working with generic network connections.
///
/// It is responsible for both decoding data received from the network and
/// encoding data before sending it. On the wire, every message is a
/// fixed-size header (`MSG_HEADER_SIZE` bytes) followed by its payload.
pub struct IOCodec {
    /// The connection that every read and write of this codec goes through.
    conn: Box<dyn Connection>,
}

impl IOCodec {
    /// Creates a new IOCodec.
    pub fn new(conn: Box<dyn Connection>) -> Self {
        Self { conn }
    }

    /// Reads a message of type `NetMsg` from the connection.
    ///
    /// It reads the first 6 bytes as the header of the message, then reads
    /// and decodes the remaining message data based on the determined header.
    pub async fn read(&self) -> Result<NetMsg> {
        // Read 6 bytes to get the header of the incoming message
        let mut buf = [0; MSG_HEADER_SIZE];
        self.read_exact(&mut buf).await?;

        // Decode the header from bytes to NetMsgHeader
        let (header, _) = decode::<NetMsgHeader>(&buf)?;

        if header.payload_size > MAX_ALLOWED_MSG_SIZE {
            return Err(Error::InvalidMsg(
                "Message exceeds the maximum allowed size".to_string(),
            ));
        }

        // Create a buffer to hold the message based on its length
        let mut payload = vec![0; header.payload_size as usize];
        self.read_exact(&mut payload).await?;

        Ok(NetMsg { header, payload })
    }

    /// Writes a message of type `T` to the connection.
    ///
    /// Before appending the actual message payload, it calculates the length of
    /// the encoded message in bytes and appends this length to the message header.
    pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> {
        let payload = encode(msg)?;

        // Create a buffer to hold the message header (6 bytes)
        let header_buf = &mut [0; MSG_HEADER_SIZE];
        let header = NetMsgHeader {
            command,
            payload_size: payload.len() as u32,
        };
        encode_into_slice(&header, header_buf)?;

        let mut buffer = vec![];
        // Append the header bytes to the buffer
        buffer.extend_from_slice(header_buf);
        // Append the message payload to the buffer
        buffer.extend_from_slice(&payload);

        self.write_all(&buffer).await?;
        Ok(())
    }

    /// Runs `read` under `timeout(duration, ..)`.
    ///
    /// # Errors
    ///
    /// If the timeout helper reports an error, it is replaced by
    /// `NetError::Timeout` (converted into this crate's `Error`); otherwise
    /// the result of `read` is returned unchanged.
    pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> {
        match timeout(duration, self.read()).await {
            Ok(outcome) => outcome,
            Err(_) => Err(NetError::Timeout.into()),
        }
    }

    /// Runs `write(command, msg)` under `timeout(duration, ..)`.
    ///
    /// # Errors
    ///
    /// If the timeout helper reports an error, it is replaced by
    /// `NetError::Timeout` (converted into this crate's `Error`); otherwise
    /// the result of `write` is returned unchanged.
    pub async fn write_timeout<T: CodecMsg>(
        &self,
        command: NetMsgCmd,
        msg: &T,
        duration: Duration,
    ) -> Result<()> {
        match timeout(duration, self.write(command, msg)).await {
            Ok(outcome) => outcome,
            Err(_) => Err(NetError::Timeout.into()),
        }
    }

    /// Fills `buf` completely, reading from the connection as often as needed.
    ///
    /// # Errors
    ///
    /// Returns an `Error::IO` of kind `UnexpectedEof` if a read yields zero
    /// bytes while `buf` still has room, and propagates any read error.
    async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> {
        while !buf.is_empty() {
            let count = self.conn.read(buf).await?;

            // A zero-length read means no more data will arrive.
            if count == 0 {
                return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
            }

            // Move the slice out of `buf` so that its still-unfilled tail can
            // be stored back with the original lifetime.
            let whole = std::mem::take(&mut buf);
            buf = &mut whole[count..];
        }

        Ok(())
    }

    /// Sends all of `buf` over the connection, writing as often as needed.
    ///
    /// # Errors
    ///
    /// Returns an `Error::IO` of kind `UnexpectedEof` if a write accepts zero
    /// bytes while data remains, and propagates any write error.
    async fn write_all(&self, mut buf: &[u8]) -> Result<()> {
        while !buf.is_empty() {
            let sent = self.conn.write(buf).await?;

            // A zero-length write means the connection takes no more data.
            if sent == 0 {
                return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
            }

            // Continue with the part that has not been written yet.
            buf = &buf[sent..];
        }

        Ok(())
    }
}