From df6aad2be4c6c5d11483f20e62d41e71f0ac989e Mon Sep 17 00:00:00 2001 From: hozan23 Date: Wed, 13 Mar 2024 12:33:34 +0100 Subject: net: major cleanup and improvement of the crate api --- net/src/connection.rs | 6 +++--- net/src/lib.rs | 7 +------ net/src/listener.rs | 4 ++-- net/src/transports/tcp.rs | 4 ++-- net/src/transports/tls.rs | 31 +++++++++---------------------- net/src/transports/udp.rs | 4 ++-- net/src/transports/unix.rs | 4 ++-- p2p/src/connector.rs | 8 +++++--- p2p/src/discovery/refresh.rs | 19 ++++++++----------- p2p/src/listener.rs | 14 ++++++++------ 10 files changed, 42 insertions(+), 59 deletions(-) diff --git a/net/src/connection.rs b/net/src/connection.rs index ea73a39..fa4640f 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -57,9 +57,9 @@ pub trait Connection: Send + Sync { /// pub async fn dial(endpoint: &Endpoint) -> Result { match endpoint { - Endpoint::Tcp(_, _) => Ok(Box::new(tcp::dial_tcp(endpoint).await?)), - Endpoint::Udp(_, _) => Ok(Box::new(udp::dial_udp(endpoint).await?)), - Endpoint::Unix(addr) => Ok(Box::new(unix::dial_unix(addr).await?)), + Endpoint::Tcp(_, _) => Ok(Box::new(tcp::dial(endpoint).await?)), + Endpoint::Udp(_, _) => Ok(Box::new(udp::dial(endpoint).await?)), + Endpoint::Unix(addr) => Ok(Box::new(unix::dial(addr).await?)), _ => Err(Error::InvalidEndpoint(endpoint.to_string())), } } diff --git a/net/src/lib.rs b/net/src/lib.rs index 5b9bdd7..5f1c8a6 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -8,12 +8,7 @@ pub use { connection::{dial, Conn, Connection, ToConn}, endpoint::{Addr, Endpoint, Port}, listener::{listen, ConnListener, Listener, ToListener}, - transports::{ - tcp::{dial_tcp, listen_tcp, TcpConn}, - tls, - udp::{dial_udp, listen_udp, UdpConn}, - unix::{dial_unix, listen_unix, UnixConn}, - }, + transports::{tcp, tls, udp, unix}, }; use error::{Error, Result}; diff --git a/net/src/listener.rs b/net/src/listener.rs index 7f6709a..4511212 100644 --- a/net/src/listener.rs +++ b/net/src/listener.rs @@ -39,8 +39,8 @@ pub trait ConnListener: Send + Sync { /// ``` pub async fn listen(endpoint: &Endpoint) -> Result> { match endpoint { - Endpoint::Tcp(_, _) => Ok(Box::new(tcp::listen_tcp(endpoint).await?)), - Endpoint::Unix(addr) => Ok(Box::new(unix::listen_unix(addr)?)), + Endpoint::Tcp(_, _) => Ok(Box::new(tcp::listen(endpoint).await?)), + Endpoint::Unix(addr) => Ok(Box::new(unix::listen(addr)?)), _ => Err(Error::InvalidEndpoint(endpoint.to_string())), } } diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs index af50c10..21fce3d 100644 --- a/net/src/transports/tcp.rs +++ b/net/src/transports/tcp.rs @@ -71,7 +71,7 @@ impl ConnListener for TcpListener { } /// Connects to the given TCP address and port. -pub async fn dial_tcp(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint) -> Result { let addr = SocketAddr::try_from(endpoint.clone())?; let conn = TcpStream::connect(addr).await?; conn.set_nodelay(true)?; @@ -79,7 +79,7 @@ pub async fn dial_tcp(endpoint: &Endpoint) -> Result { } /// Listens on the given TCP address and port. -pub async fn listen_tcp(endpoint: &Endpoint) -> Result { +pub async fn listen(endpoint: &Endpoint) -> Result { let addr = SocketAddr::try_from(endpoint.clone())?; let listener = TcpListener::bind(addr).await?; Ok(listener) diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs index 53b4566..476f495 100644 --- a/net/src/transports/tls.rs +++ b/net/src/transports/tls.rs @@ -59,7 +59,7 @@ impl Connection for TlsConn { } /// Connects to the given TLS address and port. -pub async fn dial_tls( +pub async fn dial( endpoint: &Endpoint, config: rustls::ClientConfig, dns_name: &'static str, @@ -76,36 +76,20 @@ pub async fn dial_tls( Ok(TlsConn::new(sock, TlsStream::Client(conn))) } -/// Connects to the given TLS endpoint, returns `Conn` ([`Connection`]). -pub async fn dial( - endpoint: &Endpoint, - config: rustls::ClientConfig, - dns_name: &'static str, -) -> Result> { - match endpoint { - Endpoint::Tcp(..) | Endpoint::Tls(..) => {} - _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), - } - - dial_tls(endpoint, config, dns_name) - .await - .map(|c| Box::new(c) as Box) -} - /// Tls network listener implementation of the `Listener` [`ConnListener`] trait. pub struct TlsListener { + inner: TcpListener, acceptor: TlsAcceptor, - listener: TcpListener, } #[async_trait] impl ConnListener for TlsListener { fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.listener.local_addr()?)) + Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) } async fn accept(&self) -> Result> { - let (sock, _) = self.listener.accept().await?; + let (sock, _) = self.inner.accept().await?; sock.set_nodelay(true)?; let conn = self.acceptor.accept(sock.clone()).await?; Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) @@ -113,11 +97,14 @@ impl ConnListener for TlsListener { } /// Listens on the given TLS address and port. -pub async fn listen_tls(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result { +pub async fn listen(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result { let addr = SocketAddr::try_from(endpoint.clone())?; let acceptor = TlsAcceptor::from(Arc::new(config)); let listener = TcpListener::bind(addr).await?; - Ok(TlsListener { acceptor, listener }) + Ok(TlsListener { + acceptor, + inner: listener, + }) } impl From> for Box { diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs index 991b1fd..bd1fe83 100644 --- a/net/src/transports/udp.rs +++ b/net/src/transports/udp.rs @@ -57,7 +57,7 @@ impl Connection for UdpConn { } /// Connects to the given UDP address and port. -pub async fn dial_udp(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint) -> Result { let addr = SocketAddr::try_from(endpoint.clone())?; // Let the operating system assign an available port to this socket @@ -67,7 +67,7 @@ pub async fn dial_udp(endpoint: &Endpoint) -> Result { } /// Listens on the given UDP address and port. -pub async fn listen_udp(endpoint: &Endpoint) -> Result { +pub async fn listen(endpoint: &Endpoint) -> Result { let addr = SocketAddr::try_from(endpoint.clone())?; let conn = UdpSocket::bind(addr).await?; let udp_conn = UdpConn::new(conn); diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs index 3867040..494e104 100644 --- a/net/src/transports/unix.rs +++ b/net/src/transports/unix.rs @@ -69,13 +69,13 @@ impl ConnListener for UnixListener { } /// Connects to the given Unix socket path. -pub async fn dial_unix(path: &String) -> Result { +pub async fn dial(path: &String) -> Result { let conn = UnixStream::connect(path).await?; Ok(UnixConn::new(conn)) } /// Listens on the given Unix socket path. -pub fn listen_unix(path: &String) -> Result { +pub fn listen(path: &String) -> Result { let listener = UnixListener::bind(path)?; Ok(listener) } diff --git a/p2p/src/connector.rs b/p2p/src/connector.rs index e83d8da..41839ab 100644 --- a/p2p/src/connector.rs +++ b/p2p/src/connector.rs @@ -7,7 +7,7 @@ use karyon_core::{ crypto::KeyPair, GlobalExecutor, }; -use karyon_net::{dial, tls, Conn, Endpoint, NetError}; +use karyon_net::{tcp, tls, Conn, Endpoint, NetError}; use crate::{ monitor::{ConnEvent, Monitor}, @@ -142,9 +142,11 @@ impl Connector { async fn dial(&self, endpoint: &Endpoint, peer_id: &Option) -> Result { if self.enable_tls { let tls_config = tls_client_config(&self.key_pair, peer_id.clone())?; - tls::dial(endpoint, tls_config, DNS_NAME).await + tls::dial(endpoint, tls_config, DNS_NAME) + .await + .map(|l| Box::new(l) as Conn) } else { - dial(endpoint).await + tcp::dial(endpoint).await.map(|l| Box::new(l) as Conn) } .map_err(Error::KaryonNet) } diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index 882a93e..bfcab56 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -15,7 +15,7 @@ use karyon_core::{ GlobalExecutor, }; -use karyon_net::{dial_udp, listen_udp, Addr, Connection, Endpoint, NetError, Port, UdpConn}; +use karyon_net::{udp, Connection, Endpoint, NetError}; /// Maximum failures for an entry before removing it from the routing table. pub const MAX_FAILURES: u32 = 3; @@ -82,12 +82,10 @@ impl RefreshService { pub async fn start(self: &Arc) -> Result<()> { if let Some(endpoint) = &self.listen_endpoint { let endpoint = endpoint.read().await; - let addr = endpoint.addr()?; - let port = self.config.discovery_port; let selfc = self.clone(); self.task_group - .spawn(selfc.listen_loop(addr.clone(), port), |res| async move { + .spawn(selfc.listen_loop(endpoint.clone()), |res| async move { if let TaskResult::Completed(Err(err)) = res { error!("Listen loop stopped: {err}"); } @@ -195,8 +193,8 @@ impl RefreshService { /// specified in the Config, with backoff between each retry. async fn connect(&self, entry: &Entry) -> Result<()> { let mut retry = 0; - let endpoint = Endpoint::Ws(entry.addr.clone(), entry.discovery_port); - let conn = dial_udp(&endpoint).await?; + let endpoint = Endpoint::Udp(entry.addr.clone(), entry.discovery_port); + let conn = udp::dial(&endpoint).await?; let backoff = Backoff::new(100, 5000); while retry < self.config.refresh_connect_retries { match self.send_ping_msg(&conn).await { @@ -216,9 +214,8 @@ impl RefreshService { /// Set up a UDP listener and start listening for Ping messages from other /// peers. - async fn listen_loop(self: Arc, addr: Addr, port: Port) -> Result<()> { - let endpoint = Endpoint::Udp(addr.clone(), port); - let conn = match listen_udp(&endpoint).await { + async fn listen_loop(self: Arc, endpoint: Endpoint) -> Result<()> { + let conn = match udp::listen(&endpoint).await { Ok(c) => { self.monitor .notify(&ConnEvent::Listening(endpoint.clone()).into()) @@ -244,7 +241,7 @@ impl RefreshService { } /// Listen to receive a Ping message and respond with a Pong message. - async fn listen_to_ping_msg(&self, conn: &UdpConn) -> Result<()> { + async fn listen_to_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> { let mut buf = [0; PINGMSG_SIZE]; let (_, endpoint) = conn.recv_from(&mut buf).await?; @@ -266,7 +263,7 @@ impl RefreshService { } /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, conn: &UdpConn) -> Result<()> { + async fn send_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> { let mut nonce: [u8; 32] = [0; 32]; RngCore::fill_bytes(&mut OsRng, &mut nonce); diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs index 254e4e6..17aa187 100644 --- a/p2p/src/listener.rs +++ b/p2p/src/listener.rs @@ -8,7 +8,7 @@ use karyon_core::{ GlobalExecutor, }; -use karyon_net::{listen, tls, Conn, ConnListener, Endpoint}; +use karyon_net::{tcp, tls, Conn, ConnListener, Endpoint}; use crate::{ monitor::{ConnEvent, Monitor}, @@ -67,7 +67,7 @@ impl Listener { where Fut: Future> + Send + 'static, { - let listener = match self.listend(&endpoint).await { + let listener = match self.listen(&endpoint).await { Ok(listener) => { self.monitor .notify(&ConnEvent::Listening(endpoint.clone()).into()) @@ -152,14 +152,16 @@ impl Listener { } } - async fn listend(&self, endpoint: &Endpoint) -> Result> { + async fn listen(&self, endpoint: &Endpoint) -> Result { if self.enable_tls { let tls_config = tls_server_config(&self.key_pair)?; - tls::listen_tls(endpoint, tls_config) + tls::listen(endpoint, tls_config) .await - .map(|l| Box::new(l) as Box) + .map(|l| Box::new(l) as karyon_net::Listener) } else { - listen(endpoint).await + tcp::listen(endpoint) + .await + .map(|l| Box::new(l) as karyon_net::Listener) } .map_err(Error::KaryonNet) } -- cgit v1.2.3