1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
use std::time::Duration;
use bincode::{Decode, Encode};
use karyons_core::{
async_util::timeout,
util::{decode, encode, encode_into_slice},
};
use karyons_net::{Connection, NetError};
use crate::{
message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE},
Error, Result,
};
/// Marker trait for values that can be written through a [`Codec`].
///
/// It declares no methods; it only bundles the `Encode`, `Decode` and
/// `Debug` bounds under a single name.
pub trait CodecMsg: std::fmt::Debug + Encode + Decode {}

// Blanket implementation: any type that satisfies the bounds is a `CodecMsg`.
impl<T> CodecMsg for T where T: std::fmt::Debug + Encode + Decode {}
/// Message codec layered over a generic network [`Connection`].
///
/// It handles both directions of traffic: decoding data received from the
/// network and encoding data before it is sent.
pub struct Codec {
    /// The underlying connection that bytes are read from and written to.
    conn: Box<dyn Connection>,
}
impl Codec {
    /// Creates a new Codec that reads from and writes to `conn`.
    pub fn new(conn: Box<dyn Connection>) -> Self {
        Self { conn }
    }

    /// Reads a message of type `NetMsg` from the connection.
    ///
    /// It first reads `MSG_HEADER_SIZE` bytes and decodes them into a
    /// `NetMsgHeader`, then reads `header.payload_size` more bytes as the
    /// payload. The payload is returned as raw bytes; it is not decoded here.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMsg` if the announced payload size exceeds
    /// `MAX_ALLOWED_MSG_SIZE`. Header decoding errors and connection errors
    /// (including the stream ending before the message is complete) are
    /// propagated.
    pub async fn read(&self) -> Result<NetMsg> {
        // Read exactly one header worth of bytes.
        let mut buf = [0; MSG_HEADER_SIZE];
        self.read_exact(&mut buf).await?;

        // Decode the header from bytes to NetMsgHeader.
        let (header, _) = decode::<NetMsgHeader>(&buf)?;

        // Reject oversized messages before allocating a buffer whose size
        // is dictated by the remote peer.
        if header.payload_size > MAX_ALLOWED_MSG_SIZE {
            return Err(Error::InvalidMsg(
                "Message exceeds the maximum allowed size".to_string(),
            ));
        }

        // Create a buffer to hold the payload based on the announced length.
        let mut payload = vec![0; header.payload_size as usize];
        self.read_exact(&mut payload).await?;

        Ok(NetMsg { header, payload })
    }

    /// Writes a message of type `T` to the connection.
    ///
    /// The message is encoded first; its encoded length is stored in the
    /// header, and header and payload are then sent as a single buffer.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMsg` if the encoded payload is larger than
    /// `MAX_ALLOWED_MSG_SIZE`, the same limit `read` enforces on incoming
    /// messages. Encoding and connection errors are propagated.
    pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> {
        let payload = encode(msg)?;

        // Refuse to send what the receiving side would reject anyway. This
        // guard also guarantees the `as u32` cast below cannot truncate.
        if payload.len() > MAX_ALLOWED_MSG_SIZE as usize {
            return Err(Error::InvalidMsg(
                "Message exceeds the maximum allowed size".to_string(),
            ));
        }

        let header = NetMsgHeader {
            command,
            payload_size: payload.len() as u32,
        };

        // Encode the header into a fixed-size buffer.
        let mut header_buf = [0u8; MSG_HEADER_SIZE];
        encode_into_slice(&header, &mut header_buf)?;

        // Assemble header + payload in one allocation of the final size.
        let mut buffer = Vec::with_capacity(MSG_HEADER_SIZE + payload.len());
        buffer.extend_from_slice(&header_buf);
        buffer.extend_from_slice(&payload);

        self.write_all(&buffer).await?;
        Ok(())
    }

    /// Reads a message of type `NetMsg` with the given timeout.
    ///
    /// # Errors
    ///
    /// An error reported by `timeout` is mapped to `NetError::Timeout`;
    /// otherwise the result of `read` is returned unchanged.
    pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> {
        timeout(duration, self.read())
            .await
            .map_err(|_| NetError::Timeout)?
    }

    /// Reads the exact number of bytes required to fill `buf`.
    ///
    /// # Errors
    ///
    /// Returns an `UnexpectedEof` I/O error if the connection reports a
    /// zero-length read before `buf` is full; connection errors are
    /// propagated.
    async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> {
        while !buf.is_empty() {
            let n = self.conn.read(buf).await?;
            // A zero-length read means no further progress is possible.
            if n == 0 {
                return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
            }
            // Advance past the bytes just read. `mem::take` moves the mutable
            // slice out of `buf` so its tail can be stored back into it.
            let (_, rest) = std::mem::take(&mut buf).split_at_mut(n);
            buf = rest;
        }
        Ok(())
    }

    /// Writes an entire buffer into the connection.
    ///
    /// # Errors
    ///
    /// Returns an `UnexpectedEof` I/O error if the connection reports a
    /// zero-length write before all of `buf` is sent; connection errors are
    /// propagated.
    async fn write_all(&self, mut buf: &[u8]) -> Result<()> {
        while !buf.is_empty() {
            let n = self.conn.write(buf).await?;
            // A zero-length write means no further progress is possible.
            if n == 0 {
                return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
            }
            // Drop the bytes that were accepted and retry with the rest.
            buf = &buf[n..];
        }
        Ok(())
    }
}
|