From 76e952830302271e07a4be9df6dfaa1c11e3e675 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Wed, 13 Mar 2024 03:35:09 +0100 Subject: net: pass `Endpoint` to dial and listen functions --- net/src/connection.rs | 4 ++-- net/src/endpoint.rs | 33 +++++++++++++++++++++++++-------- net/src/listener.rs | 2 +- net/src/transports/tcp.rs | 17 +++++++++-------- net/src/transports/tls.rs | 39 ++++++++++----------------------------- net/src/transports/udp.rs | 14 +++++++------- p2p/src/discovery/refresh.rs | 5 +++-- p2p/src/listener.rs | 4 +++- 8 files changed, 60 insertions(+), 58 deletions(-) diff --git a/net/src/connection.rs b/net/src/connection.rs index 73606a2..ea73a39 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -57,8 +57,8 @@ pub trait Connection: Send + Sync { /// pub async fn dial(endpoint: &Endpoint) -> Result { match endpoint { - Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::dial_tcp(addr, port).await?)), - Endpoint::Udp(addr, port) => Ok(Box::new(udp::dial_udp(addr, port).await?)), + Endpoint::Tcp(_, _) => Ok(Box::new(tcp::dial_tcp(endpoint).await?)), + Endpoint::Udp(_, _) => Ok(Box::new(udp::dial_udp(endpoint).await?)), Endpoint::Unix(addr) => Ok(Box::new(unix::dial_unix(addr).await?)), _ => Err(Error::InvalidEndpoint(endpoint.to_string())), } diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs index fdb2735..9193628 100644 --- a/net/src/endpoint.rs +++ b/net/src/endpoint.rs @@ -34,6 +34,7 @@ pub enum Endpoint { Udp(Addr, Port), Tcp(Addr, Port), Tls(Addr, Port), + Ws(Addr, Port), Unix(String), } @@ -49,6 +50,9 @@ impl std::fmt::Display for Endpoint { Endpoint::Tls(ip, port) => { write!(f, "tls://{}:{}", ip, port) } + Endpoint::Ws(ip, port) => { + write!(f, "ws://{}:{}", ip, port) + } Endpoint::Unix(path) => { if path.is_empty() { write!(f, "unix:/UNNAMED") @@ -64,9 +68,10 @@ impl TryFrom for SocketAddr { type Error = Error; fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { - Endpoint::Udp(ip, port) | Endpoint::Tcp(ip, port) | Endpoint::Tls(ip, port) => { - Ok(SocketAddr::new(ip.try_into()?, port)) - } + Endpoint::Udp(ip, port) + | Endpoint::Tcp(ip, port) + | Endpoint::Tls(ip, port) + | Endpoint::Ws(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), Endpoint::Unix(_) => Err(Error::TryFromEndpoint), } } @@ -118,6 +123,7 @@ impl FromStr for Endpoint { "tcp" => Ok(Endpoint::Tcp(addr, port)), "udp" => Ok(Endpoint::Udp(addr, port)), "tls" => Ok(Endpoint::Tls(addr, port)), + "ws" => Ok(Endpoint::Ws(addr, port)), _ => Err(Error::InvalidEndpoint(s.to_string())), } } else { @@ -139,14 +145,19 @@ impl Endpoint { Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port()) } + /// Creates a new UDP endpoint from a `SocketAddr`. + pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint { + Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) + } + /// Creates a new TLS endpoint from a `SocketAddr`. pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint { Endpoint::Tls(Addr::Ip(addr.ip()), addr.port()) } - /// Creates a new UDP endpoint from a `SocketAddr`. - pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint { - Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) + /// Creates a new WS endpoint from a `SocketAddr`. + pub fn new_ws_addr(addr: &SocketAddr) -> Endpoint { + Endpoint::Ws(Addr::Ip(addr.ip()), addr.port()) } /// Creates a new Unix endpoint from a `UnixSocketAddress`. @@ -162,7 +173,10 @@ impl Endpoint { /// Returns the `Port` of the endpoint. pub fn port(&self) -> Result<&Port> { match self { - Endpoint::Udp(_, port) | Endpoint::Tcp(_, port) | Endpoint::Tls(_, port) => Ok(port), + Endpoint::Tcp(_, port) + | Endpoint::Udp(_, port) + | Endpoint::Tls(_, port) + | Endpoint::Ws(_, port) => Ok(port), _ => Err(Error::TryFromEndpoint), } } @@ -170,7 +184,10 @@ impl Endpoint { /// Returns the `Addr` of the endpoint. pub fn addr(&self) -> Result<&Addr> { match self { - Endpoint::Udp(addr, _) | Endpoint::Tcp(addr, _) | Endpoint::Tls(addr, _) => Ok(addr), + Endpoint::Tcp(addr, _) + | Endpoint::Udp(addr, _) + | Endpoint::Tls(addr, _) + | Endpoint::Ws(addr, _) => Ok(addr), _ => Err(Error::TryFromEndpoint), } } diff --git a/net/src/listener.rs b/net/src/listener.rs index f12f33e..7f6709a 100644 --- a/net/src/listener.rs +++ b/net/src/listener.rs @@ -39,7 +39,7 @@ pub trait ConnListener: Send + Sync { /// ``` pub async fn listen(endpoint: &Endpoint) -> Result> { match endpoint { - Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::listen_tcp(addr, port).await?)), + Endpoint::Tcp(_, _) => Ok(Box::new(tcp::listen_tcp(endpoint).await?)), Endpoint::Unix(addr) => Ok(Box::new(unix::listen_unix(addr)?)), _ => Err(Error::InvalidEndpoint(endpoint.to_string())), } diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs index 99243b5..af50c10 100644 --- a/net/src/transports/tcp.rs +++ b/net/src/transports/tcp.rs @@ -1,5 +1,6 @@ -use async_trait::async_trait; +use std::net::SocketAddr; +use async_trait::async_trait; use smol::{ io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, lock::Mutex, @@ -8,7 +9,7 @@ use smol::{ use crate::{ connection::{Connection, ToConn}, - endpoint::{Addr, Endpoint, Port}, + endpoint::Endpoint, listener::{ConnListener, ToListener}, Error, Result, }; @@ -70,17 +71,17 @@ impl ConnListener for TcpListener { } /// Connects to the given TCP address and port. -pub async fn dial_tcp(addr: &Addr, port: &Port) -> Result { - let address = format!("{}:{}", addr, port); - let conn = TcpStream::connect(address).await?; +pub async fn dial_tcp(endpoint: &Endpoint) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; + let conn = TcpStream::connect(addr).await?; conn.set_nodelay(true)?; Ok(TcpConn::new(conn)) } /// Listens on the given TCP address and port. -pub async fn listen_tcp(addr: &Addr, port: &Port) -> Result { - let address = format!("{}:{}", addr, port); - let listener = TcpListener::bind(address).await?; +pub async fn listen_tcp(endpoint: &Endpoint) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; + let listener = TcpListener::bind(addr).await?; Ok(listener) } diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs index 83f7a11..53b4566 100644 --- a/net/src/transports/tls.rs +++ b/net/src/transports/tls.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream}; @@ -10,7 +10,7 @@ use smol::{ use crate::{ connection::{Connection, ToConn}, - endpoint::{Addr, Endpoint, Port}, + endpoint::Endpoint, listener::{ConnListener, ToListener}, Error, Result, }; @@ -60,16 +60,15 @@ impl Connection for TlsConn { /// Connects to the given TLS address and port. pub async fn dial_tls( - addr: &Addr, - port: &Port, + endpoint: &Endpoint, config: rustls::ClientConfig, dns_name: &'static str, ) -> Result { - let address = format!("{}:{}", addr, port); + let addr = SocketAddr::try_from(endpoint.clone())?; let connector = TlsConnector::from(Arc::new(config)); - let sock = TcpStream::connect(&address).await?; + let sock = TcpStream::connect(addr).await?; sock.set_nodelay(true)?; let altname = pki_types::ServerName::try_from(dns_name)?; @@ -88,10 +87,11 @@ pub async fn dial( _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), } - dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name) + dial_tls(endpoint, config, dns_name) .await .map(|c| Box::new(c) as Box) } + /// Tls network listener implementation of the `Listener` [`ConnListener`] trait. pub struct TlsListener { acceptor: TlsAcceptor, @@ -113,32 +113,13 @@ impl ConnListener for TlsListener { } /// Listens on the given TLS address and port. -pub async fn listen_tls( - addr: &Addr, - port: &Port, - config: rustls::ServerConfig, -) -> Result { - let address = format!("{}:{}", addr, port); +pub async fn listen_tls(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; let acceptor = TlsAcceptor::from(Arc::new(config)); - let listener = TcpListener::bind(&address).await?; + let listener = TcpListener::bind(addr).await?; Ok(TlsListener { acceptor, listener }) } -/// Listens on the given TLS endpoint, returns `Listener` [`ConnListener`]. -pub async fn listen( - endpoint: &Endpoint, - config: rustls::ServerConfig, -) -> Result> { - match endpoint { - Endpoint::Tcp(..) | Endpoint::Tls(..) => {} - _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), - } - - listen_tls(endpoint.addr()?, endpoint.port()?, config) - .await - .map(|l| Box::new(l) as Box) -} - impl From> for Box { fn from(conn: TlsStream) -> Self { Box::new(TlsConn::new(conn.get_ref().0.clone(), conn)) diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs index a8b505c..991b1fd 100644 --- a/net/src/transports/udp.rs +++ b/net/src/transports/udp.rs @@ -5,7 +5,7 @@ use smol::net::UdpSocket; use crate::{ connection::{Connection, ToConn}, - endpoint::{Addr, Endpoint, Port}, + endpoint::Endpoint, Error, Result, }; @@ -57,19 +57,19 @@ impl Connection for UdpConn { } /// Connects to the given UDP address and port. -pub async fn dial_udp(addr: &Addr, port: &Port) -> Result { - let address = format!("{}:{}", addr, port); +pub async fn dial_udp(endpoint: &Endpoint) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; // Let the operating system assign an available port to this socket let conn = UdpSocket::bind("[::]:0").await?; - conn.connect(address).await?; + conn.connect(addr).await?; Ok(UdpConn::new(conn)) } /// Listens on the given UDP address and port. -pub async fn listen_udp(addr: &Addr, port: &Port) -> Result { - let address = format!("{}:{}", addr, port); - let conn = UdpSocket::bind(address).await?; +pub async fn listen_udp(endpoint: &Endpoint) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; + let conn = UdpSocket::bind(addr).await?; let udp_conn = UdpConn::new(conn); Ok(udp_conn) } diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index ed111fb..882a93e 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -195,7 +195,8 @@ impl RefreshService { /// specified in the Config, with backoff between each retry. async fn connect(&self, entry: &Entry) -> Result<()> { let mut retry = 0; - let conn = dial_udp(&entry.addr, &entry.discovery_port).await?; + let endpoint = Endpoint::Ws(entry.addr.clone(), entry.discovery_port); + let conn = dial_udp(&endpoint).await?; let backoff = Backoff::new(100, 5000); while retry < self.config.refresh_connect_retries { match self.send_ping_msg(&conn).await { @@ -217,7 +218,7 @@ impl RefreshService { /// peers. async fn listen_loop(self: Arc, addr: Addr, port: Port) -> Result<()> { let endpoint = Endpoint::Udp(addr.clone(), port); - let conn = match listen_udp(&addr, &port).await { + let conn = match listen_udp(&endpoint).await { Ok(c) => { self.monitor .notify(&ConnEvent::Listening(endpoint.clone()).into()) diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs index ab6f7c1..254e4e6 100644 --- a/p2p/src/listener.rs +++ b/p2p/src/listener.rs @@ -155,7 +155,9 @@ impl Listener { async fn listend(&self, endpoint: &Endpoint) -> Result> { if self.enable_tls { let tls_config = tls_server_config(&self.key_pair)?; - tls::listen(endpoint, tls_config).await + tls::listen_tls(endpoint, tls_config) + .await + .map(|l| Box::new(l) as Box) } else { listen(endpoint).await } -- cgit v1.2.3