From 98a1de91a2dae06323558422c239e5a45fc86e7b Mon Sep 17 00:00:00 2001
From: hozan23 <hozan23@proton.me>
Date: Tue, 28 Nov 2023 22:41:33 +0300
Subject: implement TLS for inbound and outbound connections

---
 net/Cargo.toml             |   3 +-
 net/src/connection.rs      |   9 ++-
 net/src/endpoint.rs        |  70 ++++++++++++++++++-----
 net/src/error.rs           |   8 ++-
 net/src/lib.rs             |   1 +
 net/src/listener.rs        |   5 +-
 net/src/transports/mod.rs  |   1 +
 net/src/transports/tcp.rs  |   2 +-
 net/src/transports/tls.rs  | 140 +++++++++++++++++++++++++++++++++++++++++++++
 net/src/transports/udp.rs  |   2 +-
 net/src/transports/unix.rs |   2 +-
 11 files changed, 219 insertions(+), 24 deletions(-)
 create mode 100644 net/src/transports/tls.rs

(limited to 'net')

diff --git a/net/Cargo.toml b/net/Cargo.toml
index de9b33b..863a250 100644
--- a/net/Cargo.toml
+++ b/net/Cargo.toml
@@ -11,6 +11,7 @@ karyons_core.workspace = true
 smol = "1.3.0"
 async-trait = "0.1.73"
 log = "0.4.20"
-bincode = { version="2.0.0-rc.3", features = ["derive"]}
+bincode = "2.0.0-rc.3"
 thiserror = "1.0.47"
 url = "2.4.1"
+async-rustls = { version = "0.4.1", features = ["dangerous_configuration"] }
diff --git a/net/src/connection.rs b/net/src/connection.rs
index d8ec0a3..b1d7550 100644
--- a/net/src/connection.rs
+++ b/net/src/connection.rs
@@ -1,7 +1,9 @@
-use crate::{Endpoint, Result};
 use async_trait::async_trait;
 
-use crate::transports::{tcp, udp, unix};
+use crate::{
+    transports::{tcp, udp, unix},
+    Endpoint, Error, Result,
+};
 
 /// Alias for `Box<dyn Connection>`
 pub type Conn = Box<dyn Connection>;
@@ -28,7 +30,7 @@ pub trait Connection: Send + Sync {
 
 /// Connects to the provided endpoint.
 ///
-/// it only supports `tcp4/6`, `udp4/6` and `unix`.
+/// it only supports `tcp4/6`, `udp4/6`, and `unix`.
 ///
 /// #Example
 ///
@@ -53,5 +55,6 @@ pub async fn dial(endpoint: &Endpoint) -> Result<Conn> {
         Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::dial_tcp(addr, port).await?)),
         Endpoint::Udp(addr, port) => Ok(Box::new(udp::dial_udp(addr, port).await?)),
         Endpoint::Unix(addr) => Ok(Box::new(unix::dial_unix(addr).await?)),
+        _ => Err(Error::InvalidEndpoint(endpoint.to_string())),
     }
 }
diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs
index 50dfe6b..720eea3 100644
--- a/net/src/endpoint.rs
+++ b/net/src/endpoint.rs
@@ -5,7 +5,7 @@ use std::{
     str::FromStr,
 };
 
-use bincode::{Decode, Encode};
+use bincode::{impl_borrow_decode, Decode, Encode};
 use url::Url;
 
 use crate::{Error, Result};
@@ -33,6 +33,7 @@ pub type Port = u16;
 pub enum Endpoint {
     Udp(Addr, Port),
     Tcp(Addr, Port),
+    Tls(Addr, Port),
     Unix(String),
 }
 
@@ -45,6 +46,9 @@ impl std::fmt::Display for Endpoint {
             Endpoint::Tcp(ip, port) => {
                 write!(f, "tcp://{}:{}", ip, port)
             }
+            Endpoint::Tls(ip, port) => {
+                write!(f, "tls://{}:{}", ip, port)
+            }
             Endpoint::Unix(path) => {
                 if path.is_empty() {
                     write!(f, "unix:/UNNAMED")
@@ -60,9 +64,10 @@ impl TryFrom<Endpoint> for SocketAddr {
     type Error = Error;
     fn try_from(endpoint: Endpoint) -> std::result::Result<SocketAddr, Self::Error> {
         match endpoint {
-            Endpoint::Udp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
-            Endpoint::Tcp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
-            Endpoint::Unix(_) => Err(Error::TryFromEndpointError),
+            Endpoint::Udp(ip, port) | Endpoint::Tcp(ip, port) | Endpoint::Tls(ip, port) => {
+                Ok(SocketAddr::new(ip.try_into()?, port))
+            }
+            Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
         }
     }
 }
@@ -72,7 +77,7 @@ impl TryFrom<Endpoint> for PathBuf {
     fn try_from(endpoint: Endpoint) -> std::result::Result<PathBuf, Self::Error> {
         match endpoint {
             Endpoint::Unix(path) => Ok(PathBuf::from(&path)),
-            _ => Err(Error::TryFromEndpointError),
+            _ => Err(Error::TryFromEndpoint),
         }
     }
 }
@@ -82,7 +87,7 @@ impl TryFrom<Endpoint> for UnixSocketAddress {
     fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddress, Self::Error> {
         match endpoint {
             Endpoint::Unix(a) => Ok(UnixSocketAddress::from_pathname(a)?),
-            _ => Err(Error::TryFromEndpointError),
+            _ => Err(Error::TryFromEndpoint),
         }
     }
 }
@@ -112,6 +117,7 @@ impl FromStr for Endpoint {
             match url.scheme() {
                 "tcp" => Ok(Endpoint::Tcp(addr, port)),
                 "udp" => Ok(Endpoint::Udp(addr, port)),
+                "tls" => Ok(Endpoint::Tls(addr, port)),
                 _ => Err(Error::InvalidEndpoint(s.to_string())),
             }
         } else {
@@ -133,6 +139,11 @@ impl Endpoint {
         Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
     }
 
+    /// Creates a new TLS endpoint from a `SocketAddr`.
+    pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint {
+        Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
+    }
+
     /// Creates a new UDP endpoint from a `SocketAddr`.
     pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint {
         Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
@@ -151,29 +162,62 @@ impl Endpoint {
     /// Returns the `Port` of the endpoint.
     pub fn port(&self) -> Result<&Port> {
         match self {
-            Endpoint::Tcp(_, port) => Ok(port),
-            Endpoint::Udp(_, port) => Ok(port),
-            _ => Err(Error::TryFromEndpointError),
+            Endpoint::Udp(_, port) | Endpoint::Tcp(_, port) | Endpoint::Tls(_, port) => Ok(port),
+            _ => Err(Error::TryFromEndpoint),
         }
     }
 
     /// Returns the `Addr` of the endpoint.
     pub fn addr(&self) -> Result<&Addr> {
         match self {
-            Endpoint::Tcp(addr, _) => Ok(addr),
-            Endpoint::Udp(addr, _) => Ok(addr),
-            _ => Err(Error::TryFromEndpointError),
+            Endpoint::Udp(addr, _) | Endpoint::Tcp(addr, _) | Endpoint::Tls(addr, _) => Ok(addr),
+            _ => Err(Error::TryFromEndpoint),
         }
     }
 }
 
 /// Addr defines a type for an address, either IP or domain.
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum Addr {
     Ip(IpAddr),
     Domain(String),
 }
 
+impl Encode for Addr {
+    fn encode<E: bincode::enc::Encoder>(
+        &self,
+        encoder: &mut E,
+    ) -> std::result::Result<(), bincode::error::EncodeError> {
+        match self {
+            Addr::Ip(addr) => {
+                0u32.encode(encoder)?;
+                addr.encode(encoder)
+            }
+            Addr::Domain(domain) => {
+                1u32.encode(encoder)?;
+                domain.encode(encoder)
+            }
+        }
+    }
+}
+
+impl Decode for Addr {
+    fn decode<D: bincode::de::Decoder>(
+        decoder: &mut D,
+    ) -> std::result::Result<Self, bincode::error::DecodeError> {
+        match u32::decode(decoder)? {
+            0 => Ok(Addr::Ip(IpAddr::decode(decoder)?)),
+            1 => Ok(Addr::Domain(String::decode(decoder)?)),
+            found => Err(bincode::error::DecodeError::UnexpectedVariant {
+                allowed: &bincode::error::AllowedEnumVariants::Range { min: 0, max: 1 },
+                found,
+                type_name: core::any::type_name::<Addr>(),
+            }),
+        }
+    }
+}
+impl_borrow_decode!(Addr);
+
 impl TryFrom<Addr> for IpAddr {
     type Error = Error;
     fn try_from(addr: Addr) -> std::result::Result<IpAddr, Self::Error> {
diff --git a/net/src/error.rs b/net/src/error.rs
index 346184a..5dd6348 100644
--- a/net/src/error.rs
+++ b/net/src/error.rs
@@ -8,7 +8,7 @@ pub enum Error {
     IO(#[from] std::io::Error),
 
     #[error("Try from endpoint Error")]
-    TryFromEndpointError,
+    TryFromEndpoint,
 
     #[error("invalid address {0}")]
     InvalidAddress(String),
@@ -28,6 +28,12 @@ pub enum Error {
     #[error(transparent)]
     ChannelRecv(#[from] smol::channel::RecvError),
 
+    #[error("Tls Error: {0}")]
+    Rustls(#[from] async_rustls::rustls::Error),
+
+    #[error("Invalid DNS Name: {0}")]
+    InvalidDnsNameError(#[from] async_rustls::rustls::client::InvalidDnsNameError),
+
     #[error(transparent)]
     KaryonsCore(#[from] karyons_core::error::Error),
 }
diff --git a/net/src/lib.rs b/net/src/lib.rs
index 0e4c361..61069ef 100644
--- a/net/src/lib.rs
+++ b/net/src/lib.rs
@@ -10,6 +10,7 @@ pub use {
     listener::{listen, Listener},
     transports::{
         tcp::{dial_tcp, listen_tcp, TcpConn},
+        tls,
         udp::{dial_udp, listen_udp, UdpConn},
         unix::{dial_unix, listen_unix, UnixConn},
     },
diff --git a/net/src/listener.rs b/net/src/listener.rs
index 31a63ae..c6c3d94 100644
--- a/net/src/listener.rs
+++ b/net/src/listener.rs
@@ -1,9 +1,8 @@
-use crate::{Endpoint, Error, Result};
 use async_trait::async_trait;
 
 use crate::{
     transports::{tcp, unix},
-    Conn,
+    Conn, Endpoint, Error, Result,
 };
 
 /// Listener is a generic network listener.
@@ -15,7 +14,7 @@ pub trait Listener: Send + Sync {
 
 /// Listens to the provided endpoint.
 ///
-/// it only supports `tcp4/6` and `unix`.
+/// it only supports `tcp4/6`, and `unix`.
 ///
 /// #Example
 ///
diff --git a/net/src/transports/mod.rs b/net/src/transports/mod.rs
index f399133..ac23021 100644
--- a/net/src/transports/mod.rs
+++ b/net/src/transports/mod.rs
@@ -1,3 +1,4 @@
 pub mod tcp;
+pub mod tls;
 pub mod udp;
 pub mod unix;
diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs
index 84aa980..37f00a7 100644
--- a/net/src/transports/tcp.rs
+++ b/net/src/transports/tcp.rs
@@ -13,7 +13,7 @@ use crate::{
     Error, Result,
 };
 
-/// TCP network connection implementations of the [`Connection`] trait.
+/// TCP network connection implementation of the [`Connection`] trait.
 pub struct TcpConn {
     inner: TcpStream,
     read: Mutex<ReadHalf<TcpStream>>,
diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs
new file mode 100644
index 0000000..01bb5aa
--- /dev/null
+++ b/net/src/transports/tls.rs
@@ -0,0 +1,140 @@
+use std::sync::Arc;
+
+use async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
+use async_trait::async_trait;
+use smol::{
+    io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+    lock::Mutex,
+    net::{TcpListener, TcpStream},
+};
+
+use crate::{
+    connection::Connection,
+    endpoint::{Addr, Endpoint, Port},
+    listener::Listener,
+    Error, Result,
+};
+
+/// TLS network connection implementation of the [`Connection`] trait.
+pub struct TlsConn {
+    inner: TcpStream,
+    read: Mutex<ReadHalf<TlsStream<TcpStream>>>,
+    write: Mutex<WriteHalf<TlsStream<TcpStream>>>,
+}
+
+impl TlsConn {
+    /// Creates a new TlsConn
+    pub fn new(sock: TcpStream, conn: TlsStream<TcpStream>) -> Self {
+        let (read, write) = split(conn);
+        Self {
+            inner: sock,
+            read: Mutex::new(read),
+            write: Mutex::new(write),
+        }
+    }
+}
+
+#[async_trait]
+impl Connection for TlsConn {
+    fn peer_endpoint(&self) -> Result<Endpoint> {
+        Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?))
+    }
+
+    fn local_endpoint(&self) -> Result<Endpoint> {
+        Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?))
+    }
+
+    async fn read(&self, buf: &mut [u8]) -> Result<usize> {
+        self.read.lock().await.read(buf).await.map_err(Error::from)
+    }
+
+    async fn write(&self, buf: &[u8]) -> Result<usize> {
+        self.write
+            .lock()
+            .await
+            .write(buf)
+            .await
+            .map_err(Error::from)
+    }
+}
+
+/// Connects to the given TLS address and port.
+pub async fn dial_tls(
+    addr: &Addr,
+    port: &Port,
+    config: rustls::ClientConfig,
+    dns_name: &str,
+) -> Result<TlsConn> {
+    let address = format!("{}:{}", addr, port);
+
+    let connector = TlsConnector::from(Arc::new(config));
+
+    let sock = TcpStream::connect(&address).await?;
+    sock.set_nodelay(true)?;
+
+    let altname = rustls::ServerName::try_from(dns_name)?;
+    let conn = connector.connect(altname, sock.clone()).await?;
+    Ok(TlsConn::new(sock, TlsStream::Client(conn)))
+}
+
+/// Connects to the given TLS endpoint, returns `Conn` ([`Connection`]).
+pub async fn dial(
+    endpoint: &Endpoint,
+    config: rustls::ClientConfig,
+    dns_name: &str,
+) -> Result<Box<dyn Connection>> {
+    match endpoint {
+        Endpoint::Tcp(..) | Endpoint::Tls(..) => {}
+        _ => return Err(Error::InvalidEndpoint(endpoint.to_string())),
+    }
+
+    dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name)
+        .await
+        .map(|c| Box::new(c) as Box<dyn Connection>)
+}
+/// Tls network listener implementation of the [`Listener`] trait.
+pub struct TlsListener {
+    acceptor: TlsAcceptor,
+    listener: TcpListener,
+}
+
+#[async_trait]
+impl Listener for TlsListener {
+    fn local_endpoint(&self) -> Result<Endpoint> {
+        Ok(Endpoint::new_tls_addr(&self.listener.local_addr()?))
+    }
+
+    async fn accept(&self) -> Result<Box<dyn Connection>> {
+        let (sock, _) = self.listener.accept().await?;
+        sock.set_nodelay(true)?;
+        let conn = self.acceptor.accept(sock.clone()).await?;
+        Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn))))
+    }
+}
+
+/// Listens on the given TLS address and port.
+pub async fn listen_tls(
+    addr: &Addr,
+    port: &Port,
+    config: rustls::ServerConfig,
+) -> Result<TlsListener> {
+    let address = format!("{}:{}", addr, port);
+    let acceptor = TlsAcceptor::from(Arc::new(config));
+    let listener = TcpListener::bind(&address).await?;
+    Ok(TlsListener { acceptor, listener })
+}
+
+/// Listens on the given TLS endpoint, returns [`Listener`].
+pub async fn listen(
+    endpoint: &Endpoint,
+    config: rustls::ServerConfig,
+) -> Result<Box<dyn Listener>> {
+    match endpoint {
+        Endpoint::Tcp(..) | Endpoint::Tls(..) => {}
+        _ => return Err(Error::InvalidEndpoint(endpoint.to_string())),
+    }
+
+    listen_tls(endpoint.addr()?, endpoint.port()?, config)
+        .await
+        .map(|l| Box::new(l) as Box<dyn Listener>)
+}
diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs
index ca5b94d..8a2fbec 100644
--- a/net/src/transports/udp.rs
+++ b/net/src/transports/udp.rs
@@ -9,7 +9,7 @@ use crate::{
     Error, Result,
 };
 
-/// UDP network connection implementations of the [`Connection`] trait.
+/// UDP network connection implementation of the [`Connection`] trait.
 pub struct UdpConn {
     inner: UdpSocket,
 }
diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs
index a720d91..e504934 100644
--- a/net/src/transports/unix.rs
+++ b/net/src/transports/unix.rs
@@ -8,7 +8,7 @@ use smol::{
 
 use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Error, Result};
 
-/// Unix domain socket implementations of the [`Connection`] trait.
+/// Unix domain socket implementation of the [`Connection`] trait.
 pub struct UnixConn {
     inner: UnixStream,
     read: Mutex<ReadHalf<UnixStream>>,
-- 
cgit v1.2.3