aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhozan23 <hozan23@proton.me>2023-11-18 13:36:19 +0300
committerhozan23 <hozan23@proton.me>2023-11-19 04:37:50 +0300
commit938b29d418a9df2f93ee273a394f34adc99ea25d (patch)
treef8adfeede7c6f56091ef6a018820fa0b52f38bf3
parent0d6c8ad2ed66ff7bd1078be9ea7b582262a12d86 (diff)
net: improve Conn API
-rw-r--r--net/src/connection.rs10
-rw-r--r--net/src/transports/tcp.rs17
-rw-r--r--net/src/transports/udp.rs12
-rw-r--r--net/src/transports/unix.rs17
-rw-r--r--p2p/src/discovery/refresh.rs4
-rw-r--r--p2p/src/io_codec.rs36
-rw-r--r--p2p/src/message.rs2
7 files changed, 66 insertions, 32 deletions
diff --git a/net/src/connection.rs b/net/src/connection.rs
index 518ccfd..53bcdeb 100644
--- a/net/src/connection.rs
+++ b/net/src/connection.rs
@@ -20,10 +20,10 @@ pub trait Connection: Send + Sync {
fn local_endpoint(&self) -> Result<Endpoint>;
/// Reads data from this connection.
- async fn recv(&self, buf: &mut [u8]) -> Result<usize>;
+ async fn read(&self, buf: &mut [u8]) -> Result<usize>;
- /// Sends data to this connection
- async fn send(&self, buf: &[u8]) -> Result<usize>;
+ /// Writes data to this connection
+ async fn write(&self, buf: &[u8]) -> Result<usize>;
}
/// Connects to the provided endpoint.
@@ -40,10 +40,10 @@ pub trait Connection: Send + Sync {
///
/// let conn = dial(&endpoint).await.unwrap();
///
-/// conn.send(b"MSG").await.unwrap();
+/// conn.write(b"MSG").await.unwrap();
///
/// let mut buffer = [0;32];
-/// conn.recv(&mut buffer).await.unwrap();
+/// conn.read(&mut buffer).await.unwrap();
/// };
///
/// ```
diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs
index 5ff7b28..9b2b928 100644
--- a/net/src/transports/tcp.rs
+++ b/net/src/transports/tcp.rs
@@ -10,7 +10,7 @@ use crate::{
connection::Connection,
endpoint::{Addr, Endpoint, Port},
listener::Listener,
- Result,
+ Error, Result,
};
/// TCP network connection implementations of the `Connection` trait.
@@ -42,14 +42,17 @@ impl Connection for TcpConn {
Ok(Endpoint::new_tcp_addr(&self.inner.local_addr()?))
}
- async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read_exact(buf).await?;
- Ok(buf.len())
+ async fn read(&self, buf: &mut [u8]) -> Result<usize> {
+ self.read.lock().await.read(buf).await.map_err(Error::from)
}
- async fn send(&self, buf: &[u8]) -> Result<usize> {
- self.write.lock().await.write_all(buf).await?;
- Ok(buf.len())
+ async fn write(&self, buf: &[u8]) -> Result<usize> {
+ self.write
+ .lock()
+ .await
+ .write(buf)
+ .await
+ .map_err(Error::from)
}
}
diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs
index 27fb9ae..2050226 100644
--- a/net/src/transports/udp.rs
+++ b/net/src/transports/udp.rs
@@ -6,7 +6,7 @@ use smol::net::UdpSocket;
use crate::{
connection::Connection,
endpoint::{Addr, Endpoint, Port},
- Result,
+ Error, Result,
};
/// UDP network connection implementations of the `Connection` trait.
@@ -47,14 +47,12 @@ impl Connection for UdpConn {
Ok(Endpoint::new_udp_addr(&self.inner.local_addr()?))
}
- async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
- let size = self.inner.recv(buf).await?;
- Ok(size)
+ async fn read(&self, buf: &mut [u8]) -> Result<usize> {
+ self.inner.recv(buf).await.map_err(Error::from)
}
- async fn send(&self, buf: &[u8]) -> Result<usize> {
- let size = self.inner.send(buf).await?;
- Ok(size)
+ async fn write(&self, buf: &[u8]) -> Result<usize> {
+ self.inner.send(buf).await.map_err(Error::from)
}
}
diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs
index c89832e..1a32311 100644
--- a/net/src/transports/unix.rs
+++ b/net/src/transports/unix.rs
@@ -6,7 +6,7 @@ use smol::{
net::unix::{UnixListener, UnixStream},
};
-use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Result};
+use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Error, Result};
/// Unix domain socket implementations of the `Connection` trait.
pub struct UnixConn {
@@ -37,14 +37,17 @@ impl Connection for UnixConn {
Ok(Endpoint::new_unix_addr(&self.inner.local_addr()?))
}
- async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read_exact(buf).await?;
- Ok(buf.len())
+ async fn read(&self, buf: &mut [u8]) -> Result<usize> {
+ self.read.lock().await.read(buf).await.map_err(Error::from)
}
- async fn send(&self, buf: &[u8]) -> Result<usize> {
- self.write.lock().await.write_all(buf).await?;
- Ok(buf.len())
+ async fn write(&self, buf: &[u8]) -> Result<usize> {
+ self.write
+ .lock()
+ .await
+ .write(buf)
+ .await
+ .map_err(Error::from)
}
}
diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs
index 1ced266..b9b7bae 100644
--- a/p2p/src/discovery/refresh.rs
+++ b/p2p/src/discovery/refresh.rs
@@ -270,11 +270,11 @@ impl RefreshService {
let ping_msg = PingMsg(nonce);
let buffer = encode(&ping_msg)?;
- conn.send(&buffer).await?;
+ conn.write(&buffer).await?;
let buf = &mut [0; PINGMSG_SIZE];
let t = Duration::from_secs(self.config.refresh_response_timeout);
- timeout(t, conn.recv(buf)).await??;
+ timeout(t, conn.read(buf)).await??;
let (pong_msg, _) = decode::<PongMsg>(buf)?;
diff --git a/p2p/src/io_codec.rs b/p2p/src/io_codec.rs
index 4515832..ea62666 100644
--- a/p2p/src/io_codec.rs
+++ b/p2p/src/io_codec.rs
@@ -38,7 +38,7 @@ impl IOCodec {
pub async fn read(&self) -> Result<NetMsg> {
// Read 6 bytes to get the header of the incoming message
let mut buf = [0; MSG_HEADER_SIZE];
- self.conn.recv(&mut buf).await?;
+ self.read_exact(&mut buf).await?;
// Decode the header from bytes to NetMsgHeader
let (header, _) = decode::<NetMsgHeader>(&buf)?;
@@ -51,7 +51,7 @@ impl IOCodec {
// Create a buffer to hold the message based on its length
let mut payload = vec![0; header.payload_size as usize];
- self.conn.recv(&mut payload).await?;
+ self.read_exact(&mut payload).await?;
Ok(NetMsg { header, payload })
}
@@ -77,7 +77,7 @@ impl IOCodec {
// Append the message payload to the buffer
buffer.extend_from_slice(&payload);
- self.conn.send(&buffer).await?;
+ self.write_all(&buffer).await?;
Ok(())
}
@@ -99,4 +99,34 @@ impl IOCodec {
.await
.map_err(|_| NetError::Timeout)?
}
+
+ /// Reads the exact number of bytes required to fill `buf`.
+ async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> {
+ while !buf.is_empty() {
+ let n = self.conn.read(buf).await?;
+ let (_, rest) = std::mem::take(&mut buf).split_at_mut(n);
+ buf = rest;
+
+ if n == 0 {
+ return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Writes an entire buffer into the connection.
+ async fn write_all(&self, mut buf: &[u8]) -> Result<()> {
+ while !buf.is_empty() {
+ let n = self.conn.write(buf).await?;
+ let (_, rest) = std::mem::take(&mut buf).split_at(n);
+ buf = rest;
+
+ if n == 0 {
+ return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
+ }
+ }
+
+ Ok(())
+ }
}
diff --git a/p2p/src/message.rs b/p2p/src/message.rs
index d3691c2..9e73809 100644
--- a/p2p/src/message.rs
+++ b/p2p/src/message.rs
@@ -10,7 +10,7 @@ use crate::{protocol::ProtocolID, routing_table::Entry, utils::VersionInt, PeerI
pub const MSG_HEADER_SIZE: usize = 6;
/// The maximum allowed size for a message in bytes.
-pub const MAX_ALLOWED_MSG_SIZE: u32 = 1000000;
+pub const MAX_ALLOWED_MSG_SIZE: u32 = 1024 * 1024; // 1MB
/// Defines the main message in the karyon p2p network.
///