From 98a1de91a2dae06323558422c239e5a45fc86e7b Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 28 Nov 2023 22:41:33 +0300 Subject: implement TLS for inbound and outbound connections --- net/src/connection.rs | 9 ++- net/src/endpoint.rs | 70 ++++++++++++++++++----- net/src/error.rs | 8 ++- net/src/lib.rs | 1 + net/src/listener.rs | 5 +- net/src/transports/mod.rs | 1 + net/src/transports/tcp.rs | 2 +- net/src/transports/tls.rs | 140 +++++++++++++++++++++++++++++++++++++++++++++ net/src/transports/udp.rs | 2 +- net/src/transports/unix.rs | 2 +- 10 files changed, 217 insertions(+), 23 deletions(-) create mode 100644 net/src/transports/tls.rs (limited to 'net/src') diff --git a/net/src/connection.rs b/net/src/connection.rs index d8ec0a3..b1d7550 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -1,7 +1,9 @@ -use crate::{Endpoint, Result}; use async_trait::async_trait; -use crate::transports::{tcp, udp, unix}; +use crate::{ + transports::{tcp, udp, unix}, + Endpoint, Error, Result, +}; /// Alias for `Box` pub type Conn = Box; @@ -28,7 +30,7 @@ pub trait Connection: Send + Sync { /// Connects to the provided endpoint. /// -/// it only supports `tcp4/6`, `udp4/6` and `unix`. +/// it only supports `tcp4/6`, `udp4/6`, and `unix`. /// /// #Example /// @@ -53,5 +55,6 @@ pub async fn dial(endpoint: &Endpoint) -> Result { Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::dial_tcp(addr, port).await?)), Endpoint::Udp(addr, port) => Ok(Box::new(udp::dial_udp(addr, port).await?)), Endpoint::Unix(addr) => Ok(Box::new(unix::dial_unix(addr).await?)), + _ => Err(Error::InvalidEndpoint(endpoint.to_string())), } } diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs index 50dfe6b..720eea3 100644 --- a/net/src/endpoint.rs +++ b/net/src/endpoint.rs @@ -5,7 +5,7 @@ use std::{ str::FromStr, }; -use bincode::{Decode, Encode}; +use bincode::{impl_borrow_decode, Decode, Encode}; use url::Url; use crate::{Error, Result}; @@ -33,6 +33,7 @@ pub type Port = u16; pub enum Endpoint { Udp(Addr, Port), Tcp(Addr, Port), + Tls(Addr, Port), Unix(String), } @@ -45,6 +46,9 @@ impl std::fmt::Display for Endpoint { Endpoint::Tcp(ip, port) => { write!(f, "tcp://{}:{}", ip, port) } + Endpoint::Tls(ip, port) => { + write!(f, "tls://{}:{}", ip, port) + } Endpoint::Unix(path) => { if path.is_empty() { write!(f, "unix:/UNNAMED") @@ -60,9 +64,10 @@ impl TryFrom for SocketAddr { type Error = Error; fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { - Endpoint::Udp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), - Endpoint::Tcp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), - Endpoint::Unix(_) => Err(Error::TryFromEndpointError), + Endpoint::Udp(ip, port) | Endpoint::Tcp(ip, port) | Endpoint::Tls(ip, port) => { + Ok(SocketAddr::new(ip.try_into()?, port)) + } + Endpoint::Unix(_) => Err(Error::TryFromEndpoint), } } } @@ -72,7 +77,7 @@ impl TryFrom for PathBuf { fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { Endpoint::Unix(path) => Ok(PathBuf::from(&path)), - _ => Err(Error::TryFromEndpointError), + _ => Err(Error::TryFromEndpoint), } } } @@ -82,7 +87,7 @@ impl TryFrom for UnixSocketAddress { fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { Endpoint::Unix(a) => Ok(UnixSocketAddress::from_pathname(a)?), - _ => Err(Error::TryFromEndpointError), + _ => Err(Error::TryFromEndpoint), } } } @@ -112,6 +117,7 @@ impl FromStr for Endpoint { match url.scheme() { "tcp" => Ok(Endpoint::Tcp(addr, port)), "udp" => Ok(Endpoint::Udp(addr, port)), + "tls" => Ok(Endpoint::Tls(addr, port)), _ => Err(Error::InvalidEndpoint(s.to_string())), } } else { @@ -133,6 +139,11 @@ impl Endpoint { Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port()) } + /// Creates a new TLS endpoint from a `SocketAddr`. + pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint { + Endpoint::Tls(Addr::Ip(addr.ip()), addr.port()) + } + /// Creates a new UDP endpoint from a `SocketAddr`. pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint { Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) @@ -151,29 +162,62 @@ impl Endpoint { /// Returns the `Port` of the endpoint. pub fn port(&self) -> Result<&Port> { match self { - Endpoint::Tcp(_, port) => Ok(port), - Endpoint::Udp(_, port) => Ok(port), - _ => Err(Error::TryFromEndpointError), + Endpoint::Udp(_, port) | Endpoint::Tcp(_, port) | Endpoint::Tls(_, port) => Ok(port), + _ => Err(Error::TryFromEndpoint), } } /// Returns the `Addr` of the endpoint. pub fn addr(&self) -> Result<&Addr> { match self { - Endpoint::Tcp(addr, _) => Ok(addr), - Endpoint::Udp(addr, _) => Ok(addr), - _ => Err(Error::TryFromEndpointError), + Endpoint::Udp(addr, _) | Endpoint::Tcp(addr, _) | Endpoint::Tls(addr, _) => Ok(addr), + _ => Err(Error::TryFromEndpoint), } } } /// Addr defines a type for an address, either IP or domain. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Addr { Ip(IpAddr), Domain(String), } +impl Encode for Addr { + fn encode( + &self, + encoder: &mut E, + ) -> std::result::Result<(), bincode::error::EncodeError> { + match self { + Addr::Ip(addr) => { + 0u32.encode(encoder)?; + addr.encode(encoder) + } + Addr::Domain(domain) => { + 1u32.encode(encoder)?; + domain.encode(encoder) + } + } + } +} + +impl Decode for Addr { + fn decode( + decoder: &mut D, + ) -> std::result::Result { + match u32::decode(decoder)? { + 0 => Ok(Addr::Ip(IpAddr::decode(decoder)?)), + 1 => Ok(Addr::Domain(String::decode(decoder)?)), + found => Err(bincode::error::DecodeError::UnexpectedVariant { + allowed: &bincode::error::AllowedEnumVariants::Range { min: 0, max: 1 }, + found, + type_name: core::any::type_name::(), + }), + } + } +} +impl_borrow_decode!(Addr); + impl TryFrom for IpAddr { type Error = Error; fn try_from(addr: Addr) -> std::result::Result { diff --git a/net/src/error.rs b/net/src/error.rs index 346184a..5dd6348 100644 --- a/net/src/error.rs +++ b/net/src/error.rs @@ -8,7 +8,7 @@ pub enum Error { IO(#[from] std::io::Error), #[error("Try from endpoint Error")] - TryFromEndpointError, + TryFromEndpoint, #[error("invalid address {0}")] InvalidAddress(String), @@ -28,6 +28,12 @@ pub enum Error { #[error(transparent)] ChannelRecv(#[from] smol::channel::RecvError), + #[error("Tls Error: {0}")] + Rustls(#[from] async_rustls::rustls::Error), + + #[error("Invalid DNS Name: {0}")] + InvalidDnsNameError(#[from] async_rustls::rustls::client::InvalidDnsNameError), + #[error(transparent)] KaryonsCore(#[from] karyons_core::error::Error), } diff --git a/net/src/lib.rs b/net/src/lib.rs index 0e4c361..61069ef 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -10,6 +10,7 @@ pub use { listener::{listen, Listener}, transports::{ tcp::{dial_tcp, listen_tcp, TcpConn}, + tls, udp::{dial_udp, listen_udp, UdpConn}, unix::{dial_unix, listen_unix, UnixConn}, }, diff --git a/net/src/listener.rs b/net/src/listener.rs index 31a63ae..c6c3d94 100644 --- a/net/src/listener.rs +++ b/net/src/listener.rs @@ -1,9 +1,8 @@ -use crate::{Endpoint, Error, Result}; use async_trait::async_trait; use crate::{ transports::{tcp, unix}, - Conn, + Conn, Endpoint, Error, Result, }; /// Listener is a generic network listener. @@ -15,7 +14,7 @@ pub trait Listener: Send + Sync { /// Listens to the provided endpoint. /// -/// it only supports `tcp4/6` and `unix`. +/// it only supports `tcp4/6`, and `unix`. /// /// #Example /// diff --git a/net/src/transports/mod.rs b/net/src/transports/mod.rs index f399133..ac23021 100644 --- a/net/src/transports/mod.rs +++ b/net/src/transports/mod.rs @@ -1,3 +1,4 @@ pub mod tcp; +pub mod tls; pub mod udp; pub mod unix; diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs index 84aa980..37f00a7 100644 --- a/net/src/transports/tcp.rs +++ b/net/src/transports/tcp.rs @@ -13,7 +13,7 @@ use crate::{ Error, Result, }; -/// TCP network connection implementations of the [`Connection`] trait. +/// TCP network connection implementation of the [`Connection`] trait. pub struct TcpConn { inner: TcpStream, read: Mutex>, diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs new file mode 100644 index 0000000..01bb5aa --- /dev/null +++ b/net/src/transports/tls.rs @@ -0,0 +1,140 @@ +use std::sync::Arc; + +use async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; +use async_trait::async_trait; +use smol::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + lock::Mutex, + net::{TcpListener, TcpStream}, +}; + +use crate::{ + connection::Connection, + endpoint::{Addr, Endpoint, Port}, + listener::Listener, + Error, Result, +}; + +/// TLS network connection implementation of the [`Connection`] trait. +pub struct TlsConn { + inner: TcpStream, + read: Mutex>>, + write: Mutex>>, +} + +impl TlsConn { + /// Creates a new TlsConn + pub fn new(sock: TcpStream, conn: TlsStream) -> Self { + let (read, write) = split(conn); + Self { + inner: sock, + read: Mutex::new(read), + write: Mutex::new(write), + } + } +} + +#[async_trait] +impl Connection for TlsConn { + fn peer_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?)) + } + + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + } + + async fn read(&self, buf: &mut [u8]) -> Result { + self.read.lock().await.read(buf).await.map_err(Error::from) + } + + async fn write(&self, buf: &[u8]) -> Result { + self.write + .lock() + .await + .write(buf) + .await + .map_err(Error::from) + } +} + +/// Connects to the given TLS address and port. +pub async fn dial_tls( + addr: &Addr, + port: &Port, + config: rustls::ClientConfig, + dns_name: &str, +) -> Result { + let address = format!("{}:{}", addr, port); + + let connector = TlsConnector::from(Arc::new(config)); + + let sock = TcpStream::connect(&address).await?; + sock.set_nodelay(true)?; + + let altname = rustls::ServerName::try_from(dns_name)?; + let conn = connector.connect(altname, sock.clone()).await?; + Ok(TlsConn::new(sock, TlsStream::Client(conn))) +} + +/// Connects to the given TLS endpoint, returns `Conn` ([`Connection`]). +pub async fn dial( + endpoint: &Endpoint, + config: rustls::ClientConfig, + dns_name: &str, +) -> Result> { + match endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => {} + _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), + } + + dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name) + .await + .map(|c| Box::new(c) as Box) +} +/// Tls network listener implementation of the [`Listener`] trait. +pub struct TlsListener { + acceptor: TlsAcceptor, + listener: TcpListener, +} + +#[async_trait] +impl Listener for TlsListener { + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.listener.local_addr()?)) + } + + async fn accept(&self) -> Result> { + let (sock, _) = self.listener.accept().await?; + sock.set_nodelay(true)?; + let conn = self.acceptor.accept(sock.clone()).await?; + Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) + } +} + +/// Listens on the given TLS address and port. +pub async fn listen_tls( + addr: &Addr, + port: &Port, + config: rustls::ServerConfig, +) -> Result { + let address = format!("{}:{}", addr, port); + let acceptor = TlsAcceptor::from(Arc::new(config)); + let listener = TcpListener::bind(&address).await?; + Ok(TlsListener { acceptor, listener }) +} + +/// Listens on the given TLS endpoint, returns [`Listener`]. +pub async fn listen( + endpoint: &Endpoint, + config: rustls::ServerConfig, +) -> Result> { + match endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => {} + _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), + } + + listen_tls(endpoint.addr()?, endpoint.port()?, config) + .await + .map(|l| Box::new(l) as Box) +} diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs index ca5b94d..8a2fbec 100644 --- a/net/src/transports/udp.rs +++ b/net/src/transports/udp.rs @@ -9,7 +9,7 @@ use crate::{ Error, Result, }; -/// UDP network connection implementations of the [`Connection`] trait. +/// UDP network connection implementation of the [`Connection`] trait. pub struct UdpConn { inner: UdpSocket, } diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs index a720d91..e504934 100644 --- a/net/src/transports/unix.rs +++ b/net/src/transports/unix.rs @@ -8,7 +8,7 @@ use smol::{ use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Error, Result}; -/// Unix domain socket implementations of the [`Connection`] trait. +/// Unix domain socket implementation of the [`Connection`] trait. pub struct UnixConn { inner: UnixStream, read: Mutex>, -- cgit v1.2.3