From 938b29d418a9df2f93ee273a394f34adc99ea25d Mon Sep 17 00:00:00 2001 From: hozan23 Date: Sat, 18 Nov 2023 13:36:19 +0300 Subject: net: improve Conn API --- net/src/connection.rs | 10 +++++----- net/src/transports/tcp.rs | 17 ++++++++++------- net/src/transports/udp.rs | 12 +++++------- net/src/transports/unix.rs | 17 ++++++++++------- p2p/src/discovery/refresh.rs | 4 ++-- p2p/src/io_codec.rs | 36 +++++++++++++++++++++++++++++++++--- p2p/src/message.rs | 2 +- 7 files changed, 66 insertions(+), 32 deletions(-) diff --git a/net/src/connection.rs b/net/src/connection.rs index 518ccfd..53bcdeb 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -20,10 +20,10 @@ pub trait Connection: Send + Sync { fn local_endpoint(&self) -> Result; /// Reads data from this connection. - async fn recv(&self, buf: &mut [u8]) -> Result; + async fn read(&self, buf: &mut [u8]) -> Result; - /// Sends data to this connection - async fn send(&self, buf: &[u8]) -> Result; + /// Writes data to this connection + async fn write(&self, buf: &[u8]) -> Result; } /// Connects to the provided endpoint. @@ -40,10 +40,10 @@ pub trait Connection: Send + Sync { /// /// let conn = dial(&endpoint).await.unwrap(); /// -/// conn.send(b"MSG").await.unwrap(); +/// conn.write(b"MSG").await.unwrap(); /// /// let mut buffer = [0;32]; -/// conn.recv(&mut buffer).await.unwrap(); +/// conn.read(&mut buffer).await.unwrap(); /// }; /// /// ``` diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs index 5ff7b28..9b2b928 100644 --- a/net/src/transports/tcp.rs +++ b/net/src/transports/tcp.rs @@ -10,7 +10,7 @@ use crate::{ connection::Connection, endpoint::{Addr, Endpoint, Port}, listener::Listener, - Result, + Error, Result, }; /// TCP network connection implementations of the `Connection` trait. @@ -42,14 +42,17 @@ impl Connection for TcpConn { Ok(Endpoint::new_tcp_addr(&self.inner.local_addr()?)) } - async fn recv(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read_exact(buf).await?; - Ok(buf.len()) + async fn read(&self, buf: &mut [u8]) -> Result { + self.read.lock().await.read(buf).await.map_err(Error::from) } - async fn send(&self, buf: &[u8]) -> Result { - self.write.lock().await.write_all(buf).await?; - Ok(buf.len()) + async fn write(&self, buf: &[u8]) -> Result { + self.write + .lock() + .await + .write(buf) + .await + .map_err(Error::from) } } diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs index 27fb9ae..2050226 100644 --- a/net/src/transports/udp.rs +++ b/net/src/transports/udp.rs @@ -6,7 +6,7 @@ use smol::net::UdpSocket; use crate::{ connection::Connection, endpoint::{Addr, Endpoint, Port}, - Result, + Error, Result, }; /// UDP network connection implementations of the `Connection` trait. @@ -47,14 +47,12 @@ impl Connection for UdpConn { Ok(Endpoint::new_udp_addr(&self.inner.local_addr()?)) } - async fn recv(&self, buf: &mut [u8]) -> Result { - let size = self.inner.recv(buf).await?; - Ok(size) + async fn read(&self, buf: &mut [u8]) -> Result { + self.inner.recv(buf).await.map_err(Error::from) } - async fn send(&self, buf: &[u8]) -> Result { - let size = self.inner.send(buf).await?; - Ok(size) + async fn write(&self, buf: &[u8]) -> Result { + self.inner.send(buf).await.map_err(Error::from) } } diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs index c89832e..1a32311 100644 --- a/net/src/transports/unix.rs +++ b/net/src/transports/unix.rs @@ -6,7 +6,7 @@ use smol::{ net::unix::{UnixListener, UnixStream}, }; -use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Result}; +use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Error, Result}; /// Unix domain socket implementations of the `Connection` trait. pub struct UnixConn { @@ -37,14 +37,17 @@ impl Connection for UnixConn { Ok(Endpoint::new_unix_addr(&self.inner.local_addr()?)) } - async fn recv(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read_exact(buf).await?; - Ok(buf.len()) + async fn read(&self, buf: &mut [u8]) -> Result { + self.read.lock().await.read(buf).await.map_err(Error::from) } - async fn send(&self, buf: &[u8]) -> Result { - self.write.lock().await.write_all(buf).await?; - Ok(buf.len()) + async fn write(&self, buf: &[u8]) -> Result { + self.write + .lock() + .await + .write(buf) + .await + .map_err(Error::from) } } diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index 1ced266..b9b7bae 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -270,11 +270,11 @@ impl RefreshService { let ping_msg = PingMsg(nonce); let buffer = encode(&ping_msg)?; - conn.send(&buffer).await?; + conn.write(&buffer).await?; let buf = &mut [0; PINGMSG_SIZE]; let t = Duration::from_secs(self.config.refresh_response_timeout); - timeout(t, conn.recv(buf)).await??; + timeout(t, conn.read(buf)).await??; let (pong_msg, _) = decode::(buf)?; diff --git a/p2p/src/io_codec.rs b/p2p/src/io_codec.rs index 4515832..ea62666 100644 --- a/p2p/src/io_codec.rs +++ b/p2p/src/io_codec.rs @@ -38,7 +38,7 @@ impl IOCodec { pub async fn read(&self) -> Result { // Read 6 bytes to get the header of the incoming message let mut buf = [0; MSG_HEADER_SIZE]; - self.conn.recv(&mut buf).await?; + self.read_exact(&mut buf).await?; // Decode the header from bytes to NetMsgHeader let (header, _) = decode::(&buf)?; @@ -51,7 +51,7 @@ impl IOCodec { // Create a buffer to hold the message based on its length let mut payload = vec![0; header.payload_size as usize]; - self.conn.recv(&mut payload).await?; + self.read_exact(&mut payload).await?; Ok(NetMsg { header, payload }) } @@ -77,7 +77,7 @@ impl IOCodec { // Append the message payload to the buffer buffer.extend_from_slice(&payload); - self.conn.send(&buffer).await?; + self.write_all(&buffer).await?; Ok(()) } @@ -99,4 +99,34 @@ impl IOCodec { .await .map_err(|_| NetError::Timeout)? } + + /// Reads the exact number of bytes required to fill `buf`. + async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.read(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } + + /// Writes an entire buffer into the connection. + async fn write_all(&self, mut buf: &[u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.write(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } } diff --git a/p2p/src/message.rs b/p2p/src/message.rs index d3691c2..9e73809 100644 --- a/p2p/src/message.rs +++ b/p2p/src/message.rs @@ -10,7 +10,7 @@ use crate::{protocol::ProtocolID, routing_table::Entry, utils::VersionInt, PeerI pub const MSG_HEADER_SIZE: usize = 6; /// The maximum allowed size for a message in bytes. -pub const MAX_ALLOWED_MSG_SIZE: u32 = 1000000; +pub const MAX_ALLOWED_MSG_SIZE: u32 = 1024 * 1024; // 1MB /// Defines the main message in the karyon p2p network. /// -- cgit v1.2.3