aboutsummaryrefslogtreecommitdiff
path: root/p2p/src/io_codec.rs
diff options
context:
space:
mode:
authorhozan23 <hozan23@proton.me>2023-11-18 13:36:19 +0300
committerhozan23 <hozan23@proton.me>2023-11-19 04:37:50 +0300
commit938b29d418a9df2f93ee273a394f34adc99ea25d (patch)
treef8adfeede7c6f56091ef6a018820fa0b52f38bf3 /p2p/src/io_codec.rs
parent0d6c8ad2ed66ff7bd1078be9ea7b582262a12d86 (diff)
net: improve Conn API
Diffstat (limited to 'p2p/src/io_codec.rs')
-rw-r--r--p2p/src/io_codec.rs36
1 files changed, 33 insertions, 3 deletions
diff --git a/p2p/src/io_codec.rs b/p2p/src/io_codec.rs
index 4515832..ea62666 100644
--- a/p2p/src/io_codec.rs
+++ b/p2p/src/io_codec.rs
@@ -38,7 +38,7 @@ impl IOCodec {
pub async fn read(&self) -> Result<NetMsg> {
// Read 6 bytes to get the header of the incoming message
let mut buf = [0; MSG_HEADER_SIZE];
- self.conn.recv(&mut buf).await?;
+ self.read_exact(&mut buf).await?;
// Decode the header from bytes to NetMsgHeader
let (header, _) = decode::<NetMsgHeader>(&buf)?;
@@ -51,7 +51,7 @@ impl IOCodec {
// Create a buffer to hold the message based on its length
let mut payload = vec![0; header.payload_size as usize];
- self.conn.recv(&mut payload).await?;
+ self.read_exact(&mut payload).await?;
Ok(NetMsg { header, payload })
}
@@ -77,7 +77,7 @@ impl IOCodec {
// Append the message payload to the buffer
buffer.extend_from_slice(&payload);
- self.conn.send(&buffer).await?;
+ self.write_all(&buffer).await?;
Ok(())
}
@@ -99,4 +99,34 @@ impl IOCodec {
.await
.map_err(|_| NetError::Timeout)?
}
+
+ /// Reads the exact number of bytes required to fill `buf`.
+ async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> {
+ while !buf.is_empty() {
+ let n = self.conn.read(buf).await?;
+ let (_, rest) = std::mem::take(&mut buf).split_at_mut(n);
+ buf = rest;
+
+ if n == 0 {
+ return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Writes an entire buffer into the connection.
+ async fn write_all(&self, mut buf: &[u8]) -> Result<()> {
+ while !buf.is_empty() {
+ let n = self.conn.write(buf).await?;
+ let (_, rest) = std::mem::take(&mut buf).split_at(n);
+ buf = rest;
+
+ if n == 0 {
+ return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
+ }
+ }
+
+ Ok(())
+ }
}