From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- net/src/transports/tls.rs | 220 +++++++++++++++++++++++++++++++--------------- 1 file changed, 150 insertions(+), 70 deletions(-) (limited to 'net/src/transports/tls.rs') diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs index 476f495..c972f63 100644 --- a/net/src/transports/tls.rs +++ b/net/src/transports/tls.rs @@ -1,138 +1,218 @@ use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; -use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream}; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use futures_util::SinkExt; +use rustls_pki_types as pki_types; + +#[cfg(feature = "smol")] +use futures_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; +#[cfg(feature = "tokio")] +use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; + +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, net::{TcpListener, TcpStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, + Result, }; +use super::tcp::TcpConfig; + +/// TLS configuration +#[derive(Clone)] +pub struct ServerTlsConfig { + pub tcp_config: TcpConfig, + pub server_config: rustls::ServerConfig, +} + +#[derive(Clone)] +pub struct ClientTlsConfig { + pub tcp_config: TcpConfig, + pub client_config: rustls::ClientConfig, + pub dns_name: String, +} + /// TLS network connection implementation of the [`Connection`] trait. -pub struct TlsConn { - inner: TcpStream, - read: Mutex>>, - write: Mutex>>, +pub struct TlsConn { + read_stream: Mutex>, C>>, + write_stream: Mutex>, C>>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl TlsConn { +impl TlsConn +where + C: Codec + Clone, +{ /// Creates a new TlsConn - pub fn new(sock: TcpStream, conn: TlsStream) -> Self { + pub fn new( + conn: TlsStream, + codec: C, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, + ) -> Self { let (read, write) = split(conn); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: sock, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for TlsConn { +impl Connection for TlsConn +where + C: Clone + Codec, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await } } /// Connects to the given TLS address and port. -pub async fn dial( - endpoint: &Endpoint, - config: rustls::ClientConfig, - dns_name: &'static str, -) -> Result { +pub async fn dial(endpoint: &Endpoint, config: ClientTlsConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let connector = TlsConnector::from(Arc::new(config)); + let connector = TlsConnector::from(Arc::new(config.client_config.clone())); + + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.tcp_config.nodelay)?; - let sock = TcpStream::connect(addr).await?; - sock.set_nodelay(true)?; + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?; - let altname = pki_types::ServerName::try_from(dns_name)?; - let conn = connector.connect(altname, sock.clone()).await?; - Ok(TlsConn::new(sock, TlsStream::Client(conn))) + let altname = pki_types::ServerName::try_from(config.dns_name.clone())?; + let conn = connector.connect(altname, socket).await?; + Ok(TlsConn::new( + TlsStream::Client(conn), + codec, + peer_endpoint, + local_endpoint, + )) } /// Tls network listener implementation of the `Listener` [`ConnListener`] trait. -pub struct TlsListener { +pub struct TlsListener { inner: TcpListener, acceptor: TlsAcceptor, + config: ServerTlsConfig, + codec: C, +} + +impl TlsListener +where + C: Codec + Clone, +{ + pub fn new( + acceptor: TlsAcceptor, + listener: TcpListener, + config: ServerTlsConfig, + codec: C, + ) -> Self { + Self { + inner: listener, + acceptor, + config: config.clone(), + codec, + } + } } #[async_trait] -impl ConnListener for TlsListener { +impl ConnListener for TlsListener +where + C: Clone + Codec, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + Ok(Endpoint::new_tls_addr(self.inner.local_addr()?)) } - async fn accept(&self) -> Result> { - let (sock, _) = self.inner.accept().await?; - sock.set_nodelay(true)?; - let conn = self.acceptor.accept(sock.clone()).await?; - Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) + async fn accept(&self) -> Result> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.tcp_config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?; + + let conn = self.acceptor.accept(socket).await?; + Ok(Box::new(TlsConn::new( + TlsStream::Server(conn), + self.codec.clone(), + peer_endpoint, + local_endpoint, + ))) } } /// Listens on the given TLS address and port. -pub async fn listen(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result { +pub async fn listen( + endpoint: &Endpoint, + config: ServerTlsConfig, + codec: C, +) -> Result> +where + C: Clone + Codec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let acceptor = TlsAcceptor::from(Arc::new(config)); + let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone())); let listener = TcpListener::bind(addr).await?; - Ok(TlsListener { - acceptor, - inner: listener, - }) -} - -impl From> for Box { - fn from(conn: TlsStream) -> Self { - Box::new(TlsConn::new(conn.get_ref().0.clone(), conn)) - } + Ok(TlsListener::new(acceptor, listener, config, codec)) } -impl From for Box { - fn from(listener: TlsListener) -> Self { +impl From> for Listener +where + C: Codec + Clone, +{ + fn from(listener: TlsListener) -> Self { Box::new(listener) } } -impl ToConn for TlsStream { - fn to_conn(self) -> Box { - self.into() - } -} - -impl ToConn for TlsConn { - fn to_conn(self) -> Box { +impl ToConn for TlsConn +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for TlsListener { - fn to_listener(self) -> Box { +impl ToListener for TlsListener +where + C: Clone + Codec, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } -- cgit v1.2.3