From 98a1de91a2dae06323558422c239e5a45fc86e7b Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 28 Nov 2023 22:41:33 +0300 Subject: implement TLS for inbound and outbound connections --- net/src/transports/tls.rs | 140 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 net/src/transports/tls.rs (limited to 'net/src/transports/tls.rs') diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs new file mode 100644 index 0000000..01bb5aa --- /dev/null +++ b/net/src/transports/tls.rs @@ -0,0 +1,140 @@ +use std::sync::Arc; + +use async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; +use async_trait::async_trait; +use smol::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + lock::Mutex, + net::{TcpListener, TcpStream}, +}; + +use crate::{ + connection::Connection, + endpoint::{Addr, Endpoint, Port}, + listener::Listener, + Error, Result, +}; + +/// TLS network connection implementation of the [`Connection`] trait. +pub struct TlsConn { + inner: TcpStream, + read: Mutex>>, + write: Mutex>>, +} + +impl TlsConn { + /// Creates a new TlsConn + pub fn new(sock: TcpStream, conn: TlsStream) -> Self { + let (read, write) = split(conn); + Self { + inner: sock, + read: Mutex::new(read), + write: Mutex::new(write), + } + } +} + +#[async_trait] +impl Connection for TlsConn { + fn peer_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?)) + } + + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + } + + async fn read(&self, buf: &mut [u8]) -> Result { + self.read.lock().await.read(buf).await.map_err(Error::from) + } + + async fn write(&self, buf: &[u8]) -> Result { + self.write + .lock() + .await + .write(buf) + .await + .map_err(Error::from) + } +} + +/// Connects to the given TLS address and port. +pub async fn dial_tls( + addr: &Addr, + port: &Port, + config: rustls::ClientConfig, + dns_name: &str, +) -> Result { + let address = format!("{}:{}", addr, port); + + let connector = TlsConnector::from(Arc::new(config)); + + let sock = TcpStream::connect(&address).await?; + sock.set_nodelay(true)?; + + let altname = rustls::ServerName::try_from(dns_name)?; + let conn = connector.connect(altname, sock.clone()).await?; + Ok(TlsConn::new(sock, TlsStream::Client(conn))) +} + +/// Connects to the given TLS endpoint, returns `Conn` ([`Connection`]). +pub async fn dial( + endpoint: &Endpoint, + config: rustls::ClientConfig, + dns_name: &str, +) -> Result> { + match endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => {} + _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), + } + + dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name) + .await + .map(|c| Box::new(c) as Box) +} +/// Tls network listener implementation of the [`Listener`] trait. +pub struct TlsListener { + acceptor: TlsAcceptor, + listener: TcpListener, +} + +#[async_trait] +impl Listener for TlsListener { + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.listener.local_addr()?)) + } + + async fn accept(&self) -> Result> { + let (sock, _) = self.listener.accept().await?; + sock.set_nodelay(true)?; + let conn = self.acceptor.accept(sock.clone()).await?; + Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) + } +} + +/// Listens on the given TLS address and port. +pub async fn listen_tls( + addr: &Addr, + port: &Port, + config: rustls::ServerConfig, +) -> Result { + let address = format!("{}:{}", addr, port); + let acceptor = TlsAcceptor::from(Arc::new(config)); + let listener = TcpListener::bind(&address).await?; + Ok(TlsListener { acceptor, listener }) +} + +/// Listens on the given TLS endpoint, returns [`Listener`]. +pub async fn listen( + endpoint: &Endpoint, + config: rustls::ServerConfig, +) -> Result> { + match endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => {} + _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), + } + + listen_tls(endpoint.addr()?, endpoint.port()?, config) + .await + .map(|l| Box::new(l) as Box) +} -- cgit v1.2.3