use std::sync::Arc; use async_trait::async_trait; use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream}; use smol::{ io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, lock::Mutex, net::{TcpListener, TcpStream}, }; use crate::{ connection::{Connection, ToConn}, endpoint::{Addr, Endpoint, Port}, listener::{ConnListener, ToListener}, Error, Result, }; /// TLS network connection implementation of the [`Connection`] trait. pub struct TlsConn { inner: TcpStream, read: Mutex>>, write: Mutex>>, } impl TlsConn { /// Creates a new TlsConn pub fn new(sock: TcpStream, conn: TlsStream) -> Self { let (read, write) = split(conn); Self { inner: sock, read: Mutex::new(read), write: Mutex::new(write), } } } #[async_trait] impl Connection for TlsConn { fn peer_endpoint(&self) -> Result { Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?)) } fn local_endpoint(&self) -> Result { Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) } async fn read(&self, buf: &mut [u8]) -> Result { self.read.lock().await.read(buf).await.map_err(Error::from) } async fn write(&self, buf: &[u8]) -> Result { self.write .lock() .await .write(buf) .await .map_err(Error::from) } } /// Connects to the given TLS address and port. pub async fn dial_tls( addr: &Addr, port: &Port, config: rustls::ClientConfig, dns_name: &'static str, ) -> Result { let address = format!("{}:{}", addr, port); let connector = TlsConnector::from(Arc::new(config)); let sock = TcpStream::connect(&address).await?; sock.set_nodelay(true)?; let altname = pki_types::ServerName::try_from(dns_name)?; let conn = connector.connect(altname, sock.clone()).await?; Ok(TlsConn::new(sock, TlsStream::Client(conn))) } /// Connects to the given TLS endpoint, returns `Conn` ([`Connection`]). pub async fn dial( endpoint: &Endpoint, config: rustls::ClientConfig, dns_name: &'static str, ) -> Result> { match endpoint { Endpoint::Tcp(..) | Endpoint::Tls(..) => {} _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), } dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name) .await .map(|c| Box::new(c) as Box) } /// Tls network listener implementation of the `Listener` [`ConnListener`] trait. pub struct TlsListener { acceptor: TlsAcceptor, listener: TcpListener, } #[async_trait] impl ConnListener for TlsListener { fn local_endpoint(&self) -> Result { Ok(Endpoint::new_tls_addr(&self.listener.local_addr()?)) } async fn accept(&self) -> Result> { let (sock, _) = self.listener.accept().await?; sock.set_nodelay(true)?; let conn = self.acceptor.accept(sock.clone()).await?; Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) } } /// Listens on the given TLS address and port. pub async fn listen_tls( addr: &Addr, port: &Port, config: rustls::ServerConfig, ) -> Result { let address = format!("{}:{}", addr, port); let acceptor = TlsAcceptor::from(Arc::new(config)); let listener = TcpListener::bind(&address).await?; Ok(TlsListener { acceptor, listener }) } /// Listens on the given TLS endpoint, returns `Listener` [`ConnListener`]. pub async fn listen( endpoint: &Endpoint, config: rustls::ServerConfig, ) -> Result> { match endpoint { Endpoint::Tcp(..) | Endpoint::Tls(..) => {} _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), } listen_tls(endpoint.addr()?, endpoint.port()?, config) .await .map(|l| Box::new(l) as Box) } impl From> for Box { fn from(conn: TlsStream) -> Self { Box::new(TlsConn::new(conn.get_ref().0.clone(), conn)) } } impl From for Box { fn from(listener: TlsListener) -> Self { Box::new(listener) } } impl ToConn for TlsStream { fn to_conn(self) -> Box { self.into() } } impl ToConn for TlsConn { fn to_conn(self) -> Box { Box::new(self) } } impl ToListener for TlsListener { fn to_listener(self) -> Box { self.into() } }