From 76e952830302271e07a4be9df6dfaa1c11e3e675 Mon Sep 17 00:00:00 2001
From: hozan23 <hozan23@proton.me>
Date: Wed, 13 Mar 2024 03:35:09 +0100
Subject: net: pass `Endpoint` to dial and listen functions

---
 net/src/connection.rs     |  4 ++--
 net/src/endpoint.rs       | 33 +++++++++++++++++++++++++--------
 net/src/listener.rs       |  2 +-
 net/src/transports/tcp.rs | 17 +++++++++--------
 net/src/transports/tls.rs | 39 ++++++++++-----------------------------
 net/src/transports/udp.rs | 14 +++++++-------
 6 files changed, 54 insertions(+), 55 deletions(-)

(limited to 'net/src')

diff --git a/net/src/connection.rs b/net/src/connection.rs
index 73606a2..ea73a39 100644
--- a/net/src/connection.rs
+++ b/net/src/connection.rs
@@ -57,8 +57,8 @@ pub trait Connection: Send + Sync {
 ///
 pub async fn dial(endpoint: &Endpoint) -> Result<Conn> {
     match endpoint {
-        Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::dial_tcp(addr, port).await?)),
-        Endpoint::Udp(addr, port) => Ok(Box::new(udp::dial_udp(addr, port).await?)),
+        Endpoint::Tcp(_, _) => Ok(Box::new(tcp::dial_tcp(endpoint).await?)),
+        Endpoint::Udp(_, _) => Ok(Box::new(udp::dial_udp(endpoint).await?)),
         Endpoint::Unix(addr) => Ok(Box::new(unix::dial_unix(addr).await?)),
         _ => Err(Error::InvalidEndpoint(endpoint.to_string())),
     }
diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs
index fdb2735..9193628 100644
--- a/net/src/endpoint.rs
+++ b/net/src/endpoint.rs
@@ -34,6 +34,7 @@ pub enum Endpoint {
     Udp(Addr, Port),
     Tcp(Addr, Port),
     Tls(Addr, Port),
+    Ws(Addr, Port),
     Unix(String),
 }
 
@@ -49,6 +50,9 @@ impl std::fmt::Display for Endpoint {
             Endpoint::Tls(ip, port) => {
                 write!(f, "tls://{}:{}", ip, port)
             }
+            Endpoint::Ws(ip, port) => {
+                write!(f, "ws://{}:{}", ip, port)
+            }
             Endpoint::Unix(path) => {
                 if path.is_empty() {
                     write!(f, "unix:/UNNAMED")
@@ -64,9 +68,10 @@ impl TryFrom<Endpoint> for SocketAddr {
     type Error = Error;
     fn try_from(endpoint: Endpoint) -> std::result::Result<SocketAddr, Self::Error> {
         match endpoint {
-            Endpoint::Udp(ip, port) | Endpoint::Tcp(ip, port) | Endpoint::Tls(ip, port) => {
-                Ok(SocketAddr::new(ip.try_into()?, port))
-            }
+            Endpoint::Udp(ip, port)
+            | Endpoint::Tcp(ip, port)
+            | Endpoint::Tls(ip, port)
+            | Endpoint::Ws(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
             Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
         }
     }
@@ -118,6 +123,7 @@ impl FromStr for Endpoint {
                 "tcp" => Ok(Endpoint::Tcp(addr, port)),
                 "udp" => Ok(Endpoint::Udp(addr, port)),
                 "tls" => Ok(Endpoint::Tls(addr, port)),
+                "ws" => Ok(Endpoint::Ws(addr, port)),
                 _ => Err(Error::InvalidEndpoint(s.to_string())),
             }
         } else {
@@ -139,14 +145,19 @@ impl Endpoint {
         Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
     }
 
+    /// Creates a new UDP endpoint from a `SocketAddr`.
+    pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint {
+        Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
+    }
+
     /// Creates a new TLS endpoint from a `SocketAddr`.
     pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint {
         Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
     }
 
-    /// Creates a new UDP endpoint from a `SocketAddr`.
-    pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint {
-        Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
+    /// Creates a new WS endpoint from a `SocketAddr`.
+    pub fn new_ws_addr(addr: &SocketAddr) -> Endpoint {
+        Endpoint::Ws(Addr::Ip(addr.ip()), addr.port())
     }
 
     /// Creates a new Unix endpoint from a `UnixSocketAddress`.
@@ -162,7 +173,10 @@ impl Endpoint {
     /// Returns the `Port` of the endpoint.
     pub fn port(&self) -> Result<&Port> {
         match self {
-            Endpoint::Udp(_, port) | Endpoint::Tcp(_, port) | Endpoint::Tls(_, port) => Ok(port),
+            Endpoint::Tcp(_, port)
+            | Endpoint::Udp(_, port)
+            | Endpoint::Tls(_, port)
+            | Endpoint::Ws(_, port) => Ok(port),
             _ => Err(Error::TryFromEndpoint),
         }
     }
@@ -170,7 +184,10 @@ impl Endpoint {
     /// Returns the `Addr` of the endpoint.
     pub fn addr(&self) -> Result<&Addr> {
         match self {
-            Endpoint::Udp(addr, _) | Endpoint::Tcp(addr, _) | Endpoint::Tls(addr, _) => Ok(addr),
+            Endpoint::Tcp(addr, _)
+            | Endpoint::Udp(addr, _)
+            | Endpoint::Tls(addr, _)
+            | Endpoint::Ws(addr, _) => Ok(addr),
             _ => Err(Error::TryFromEndpoint),
         }
     }
diff --git a/net/src/listener.rs b/net/src/listener.rs
index f12f33e..7f6709a 100644
--- a/net/src/listener.rs
+++ b/net/src/listener.rs
@@ -39,7 +39,7 @@ pub trait ConnListener: Send + Sync {
 /// ```
 pub async fn listen(endpoint: &Endpoint) -> Result<Box<dyn ConnListener>> {
     match endpoint {
-        Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::listen_tcp(addr, port).await?)),
+        Endpoint::Tcp(_, _) => Ok(Box::new(tcp::listen_tcp(endpoint).await?)),
         Endpoint::Unix(addr) => Ok(Box::new(unix::listen_unix(addr)?)),
         _ => Err(Error::InvalidEndpoint(endpoint.to_string())),
     }
diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs
index 99243b5..af50c10 100644
--- a/net/src/transports/tcp.rs
+++ b/net/src/transports/tcp.rs
@@ -1,5 +1,6 @@
-use async_trait::async_trait;
+use std::net::SocketAddr;
 
+use async_trait::async_trait;
 use smol::{
     io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
     lock::Mutex,
@@ -8,7 +9,7 @@ use smol::{
 
 use crate::{
     connection::{Connection, ToConn},
-    endpoint::{Addr, Endpoint, Port},
+    endpoint::Endpoint,
     listener::{ConnListener, ToListener},
     Error, Result,
 };
@@ -70,17 +71,17 @@ impl ConnListener for TcpListener {
 }
 
 /// Connects to the given TCP address and port.
-pub async fn dial_tcp(addr: &Addr, port: &Port) -> Result<TcpConn> {
-    let address = format!("{}:{}", addr, port);
-    let conn = TcpStream::connect(address).await?;
+pub async fn dial_tcp(endpoint: &Endpoint) -> Result<TcpConn> {
+    let addr = SocketAddr::try_from(endpoint.clone())?;
+    let conn = TcpStream::connect(addr).await?;
     conn.set_nodelay(true)?;
     Ok(TcpConn::new(conn))
 }
 
 /// Listens on the given TCP address and port.
-pub async fn listen_tcp(addr: &Addr, port: &Port) -> Result<TcpListener> {
-    let address = format!("{}:{}", addr, port);
-    let listener = TcpListener::bind(address).await?;
+pub async fn listen_tcp(endpoint: &Endpoint) -> Result<TcpListener> {
+    let addr = SocketAddr::try_from(endpoint.clone())?;
+    let listener = TcpListener::bind(addr).await?;
     Ok(listener)
 }
 
diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs
index 83f7a11..53b4566 100644
--- a/net/src/transports/tls.rs
+++ b/net/src/transports/tls.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::{net::SocketAddr, sync::Arc};
 
 use async_trait::async_trait;
 use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream};
@@ -10,7 +10,7 @@ use smol::{
 
 use crate::{
     connection::{Connection, ToConn},
-    endpoint::{Addr, Endpoint, Port},
+    endpoint::Endpoint,
     listener::{ConnListener, ToListener},
     Error, Result,
 };
@@ -60,16 +60,15 @@ impl Connection for TlsConn {
 
 /// Connects to the given TLS address and port.
 pub async fn dial_tls(
-    addr: &Addr,
-    port: &Port,
+    endpoint: &Endpoint,
     config: rustls::ClientConfig,
     dns_name: &'static str,
 ) -> Result<TlsConn> {
-    let address = format!("{}:{}", addr, port);
+    let addr = SocketAddr::try_from(endpoint.clone())?;
 
     let connector = TlsConnector::from(Arc::new(config));
 
-    let sock = TcpStream::connect(&address).await?;
+    let sock = TcpStream::connect(addr).await?;
     sock.set_nodelay(true)?;
 
     let altname = pki_types::ServerName::try_from(dns_name)?;
@@ -88,10 +87,11 @@ pub async fn dial(
         _ => return Err(Error::InvalidEndpoint(endpoint.to_string())),
     }
 
-    dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name)
+    dial_tls(endpoint, config, dns_name)
         .await
         .map(|c| Box::new(c) as Box<dyn Connection>)
 }
+
 /// Tls network listener implementation of the `Listener` [`ConnListener`] trait.
 pub struct TlsListener {
     acceptor: TlsAcceptor,
@@ -113,32 +113,13 @@ impl ConnListener for TlsListener {
 }
 
 /// Listens on the given TLS address and port.
-pub async fn listen_tls(
-    addr: &Addr,
-    port: &Port,
-    config: rustls::ServerConfig,
-) -> Result<TlsListener> {
-    let address = format!("{}:{}", addr, port);
+pub async fn listen_tls(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result<TlsListener> {
+    let addr = SocketAddr::try_from(endpoint.clone())?;
     let acceptor = TlsAcceptor::from(Arc::new(config));
-    let listener = TcpListener::bind(&address).await?;
+    let listener = TcpListener::bind(addr).await?;
     Ok(TlsListener { acceptor, listener })
 }
 
-/// Listens on the given TLS endpoint, returns `Listener` [`ConnListener`].
-pub async fn listen(
-    endpoint: &Endpoint,
-    config: rustls::ServerConfig,
-) -> Result<Box<dyn ConnListener>> {
-    match endpoint {
-        Endpoint::Tcp(..) | Endpoint::Tls(..) => {}
-        _ => return Err(Error::InvalidEndpoint(endpoint.to_string())),
-    }
-
-    listen_tls(endpoint.addr()?, endpoint.port()?, config)
-        .await
-        .map(|l| Box::new(l) as Box<dyn ConnListener>)
-}
-
 impl From<TlsStream<TcpStream>> for Box<dyn Connection> {
     fn from(conn: TlsStream<TcpStream>) -> Self {
         Box::new(TlsConn::new(conn.get_ref().0.clone(), conn))
diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs
index a8b505c..991b1fd 100644
--- a/net/src/transports/udp.rs
+++ b/net/src/transports/udp.rs
@@ -5,7 +5,7 @@ use smol::net::UdpSocket;
 
 use crate::{
     connection::{Connection, ToConn},
-    endpoint::{Addr, Endpoint, Port},
+    endpoint::Endpoint,
     Error, Result,
 };
 
@@ -57,19 +57,19 @@ impl Connection for UdpConn {
 }
 
 /// Connects to the given UDP address and port.
-pub async fn dial_udp(addr: &Addr, port: &Port) -> Result<UdpConn> {
-    let address = format!("{}:{}", addr, port);
+pub async fn dial_udp(endpoint: &Endpoint) -> Result<UdpConn> {
+    let addr = SocketAddr::try_from(endpoint.clone())?;
 
     // Let the operating system assign an available port to this socket
     let conn = UdpSocket::bind("[::]:0").await?;
-    conn.connect(address).await?;
+    conn.connect(addr).await?;
     Ok(UdpConn::new(conn))
 }
 
 /// Listens on the given UDP address and port.
-pub async fn listen_udp(addr: &Addr, port: &Port) -> Result<UdpConn> {
-    let address = format!("{}:{}", addr, port);
-    let conn = UdpSocket::bind(address).await?;
+pub async fn listen_udp(endpoint: &Endpoint) -> Result<UdpConn> {
+    let addr = SocketAddr::try_from(endpoint.clone())?;
+    let conn = UdpSocket::bind(addr).await?;
     let udp_conn = UdpConn::new(conn);
     Ok(udp_conn)
 }
-- 
cgit v1.2.3