From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- net/src/transports/tcp.rs | 188 ++++++++++++++++++++++++----------- net/src/transports/tls.rs | 220 ++++++++++++++++++++++++++++------------- net/src/transports/udp.rs | 114 ++++++++++++--------- net/src/transports/unix.rs | 193 +++++++++++++++++++++++++----------- net/src/transports/ws.rs | 242 +++++++++++++++++++++++++++++++++++---------- 5 files changed, 673 insertions(+), 284 deletions(-) (limited to 'net/src/transports') diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs index 21fce3d..03c8ab2 100644 --- a/net/src/transports/tcp.rs +++ b/net/src/transports/tcp.rs @@ -1,116 +1,184 @@ use std::net::SocketAddr; use async_trait::async_trait; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use futures_util::SinkExt; + +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, - net::{TcpListener, TcpStream}, + net::{TcpListener as AsyncTcpListener, TcpStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, + Result, }; -/// TCP network connection implementation of the [`Connection`] trait. -pub struct TcpConn { - inner: TcpStream, - read: Mutex>, - write: Mutex>, +/// TCP configuration +#[derive(Clone)] +pub struct TcpConfig { + pub nodelay: bool, +} + +impl Default for TcpConfig { + fn default() -> Self { + Self { nodelay: true } + } +} + +/// TCP connection implementation of the [`Connection`] trait. +pub struct TcpConn { + read_stream: Mutex, C>>, + write_stream: Mutex, C>>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl TcpConn { +impl TcpConn +where + C: Codec + Clone, +{ /// Creates a new TcpConn - pub fn new(conn: TcpStream) -> Self { - let (read, write) = split(conn.clone()); + pub fn new( + socket: TcpStream, + codec: C, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, + ) -> Self { + let (read, write) = split(socket); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: conn, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for TcpConn { +impl Connection for TcpConn +where + C: Codec + Clone, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_tcp_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tcp_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await + } +} + +pub struct TcpListener { + inner: AsyncTcpListener, + config: TcpConfig, + codec: C, +} + +impl TcpListener +where + C: Codec, +{ + pub fn new(listener: AsyncTcpListener, config: TcpConfig, codec: C) -> Self { + Self { + inner: listener, + config: config.clone(), + codec, + } } } #[async_trait] -impl ConnListener for TcpListener { +impl ConnListener for TcpListener +where + C: Codec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tcp_addr(&self.local_addr()?)) + Ok(Endpoint::new_tcp_addr(self.inner.local_addr()?)) } - async fn accept(&self) -> Result> { - let (conn, _) = self.accept().await?; - conn.set_nodelay(true)?; - Ok(Box::new(TcpConn::new(conn))) + async fn accept(&self) -> Result> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?; + + Ok(Box::new(TcpConn::new( + socket, + self.codec.clone(), + peer_endpoint, + local_endpoint, + ))) } } /// Connects to the given TCP address and port. -pub async fn dial(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let conn = TcpStream::connect(addr).await?; - conn.set_nodelay(true)?; - Ok(TcpConn::new(conn)) + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?; + + Ok(TcpConn::new(socket, codec, peer_endpoint, local_endpoint)) } /// Listens on the given TCP address and port. -pub async fn listen(endpoint: &Endpoint) -> Result { +pub async fn listen(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result> +where + C: Codec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let listener = TcpListener::bind(addr).await?; - Ok(listener) -} - -impl From for Box { - fn from(conn: TcpStream) -> Self { - Box::new(TcpConn::new(conn)) - } + let listener = AsyncTcpListener::bind(addr).await?; + Ok(TcpListener::new(listener, config, codec)) } -impl From for Box { - fn from(listener: TcpListener) -> Self { +impl From> for Box> +where + C: Clone + Codec, +{ + fn from(listener: TcpListener) -> Self { Box::new(listener) } } -impl ToConn for TcpStream { - fn to_conn(self) -> Box { - self.into() - } -} - -impl ToConn for TcpConn { - fn to_conn(self) -> Box { +impl ToConn for TcpConn +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for TcpListener { - fn to_listener(self) -> Box { +impl ToListener for TcpListener +where + C: Clone + Codec, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs index 476f495..c972f63 100644 --- a/net/src/transports/tls.rs +++ b/net/src/transports/tls.rs @@ -1,138 +1,218 @@ use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; -use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream}; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use futures_util::SinkExt; +use rustls_pki_types as pki_types; + +#[cfg(feature = "smol")] +use futures_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; +#[cfg(feature = "tokio")] +use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; + +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, net::{TcpListener, TcpStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, + Result, }; +use super::tcp::TcpConfig; + +/// TLS configuration +#[derive(Clone)] +pub struct ServerTlsConfig { + pub tcp_config: TcpConfig, + pub server_config: rustls::ServerConfig, +} + +#[derive(Clone)] +pub struct ClientTlsConfig { + pub tcp_config: TcpConfig, + pub client_config: rustls::ClientConfig, + pub dns_name: String, +} + /// TLS network connection implementation of the [`Connection`] trait. -pub struct TlsConn { - inner: TcpStream, - read: Mutex>>, - write: Mutex>>, +pub struct TlsConn { + read_stream: Mutex>, C>>, + write_stream: Mutex>, C>>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl TlsConn { +impl TlsConn +where + C: Codec + Clone, +{ /// Creates a new TlsConn - pub fn new(sock: TcpStream, conn: TlsStream) -> Self { + pub fn new( + conn: TlsStream, + codec: C, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, + ) -> Self { let (read, write) = split(conn); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: sock, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for TlsConn { +impl Connection for TlsConn +where + C: Clone + Codec, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await } } /// Connects to the given TLS address and port. -pub async fn dial( - endpoint: &Endpoint, - config: rustls::ClientConfig, - dns_name: &'static str, -) -> Result { +pub async fn dial(endpoint: &Endpoint, config: ClientTlsConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let connector = TlsConnector::from(Arc::new(config)); + let connector = TlsConnector::from(Arc::new(config.client_config.clone())); + + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.tcp_config.nodelay)?; - let sock = TcpStream::connect(addr).await?; - sock.set_nodelay(true)?; + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?; - let altname = pki_types::ServerName::try_from(dns_name)?; - let conn = connector.connect(altname, sock.clone()).await?; - Ok(TlsConn::new(sock, TlsStream::Client(conn))) + let altname = pki_types::ServerName::try_from(config.dns_name.clone())?; + let conn = connector.connect(altname, socket).await?; + Ok(TlsConn::new( + TlsStream::Client(conn), + codec, + peer_endpoint, + local_endpoint, + )) } /// Tls network listener implementation of the `Listener` [`ConnListener`] trait. -pub struct TlsListener { +pub struct TlsListener { inner: TcpListener, acceptor: TlsAcceptor, + config: ServerTlsConfig, + codec: C, +} + +impl TlsListener +where + C: Codec + Clone, +{ + pub fn new( + acceptor: TlsAcceptor, + listener: TcpListener, + config: ServerTlsConfig, + codec: C, + ) -> Self { + Self { + inner: listener, + acceptor, + config: config.clone(), + codec, + } + } } #[async_trait] -impl ConnListener for TlsListener { +impl ConnListener for TlsListener +where + C: Clone + Codec, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + Ok(Endpoint::new_tls_addr(self.inner.local_addr()?)) } - async fn accept(&self) -> Result> { - let (sock, _) = self.inner.accept().await?; - sock.set_nodelay(true)?; - let conn = self.acceptor.accept(sock.clone()).await?; - Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) + async fn accept(&self) -> Result> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.tcp_config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?; + + let conn = self.acceptor.accept(socket).await?; + Ok(Box::new(TlsConn::new( + TlsStream::Server(conn), + self.codec.clone(), + peer_endpoint, + local_endpoint, + ))) } } /// Listens on the given TLS address and port. -pub async fn listen(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result { +pub async fn listen( + endpoint: &Endpoint, + config: ServerTlsConfig, + codec: C, +) -> Result> +where + C: Clone + Codec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let acceptor = TlsAcceptor::from(Arc::new(config)); + let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone())); let listener = TcpListener::bind(addr).await?; - Ok(TlsListener { - acceptor, - inner: listener, - }) -} - -impl From> for Box { - fn from(conn: TlsStream) -> Self { - Box::new(TlsConn::new(conn.get_ref().0.clone(), conn)) - } + Ok(TlsListener::new(acceptor, listener, config, codec)) } -impl From for Box { - fn from(listener: TlsListener) -> Self { +impl From> for Listener +where + C: Codec + Clone, +{ + fn from(listener: TlsListener) -> Self { Box::new(listener) } } -impl ToConn for TlsStream { - fn to_conn(self) -> Box { - self.into() - } -} - -impl ToConn for TlsConn { - fn to_conn(self) -> Box { +impl ToConn for TlsConn +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for TlsListener { - fn to_listener(self) -> Box { +impl ToListener for TlsListener +where + C: Clone + Codec, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs index bd1fe83..950537c 100644 --- a/net/src/transports/udp.rs +++ b/net/src/transports/udp.rs @@ -1,93 +1,111 @@ use std::net::SocketAddr; use async_trait::async_trait; -use smol::net::UdpSocket; +use karyon_core::async_runtime::net::UdpSocket; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, Error, Result, }; +const BUFFER_SIZE: usize = 64 * 1024; + +/// UDP configuration +#[derive(Default)] +pub struct UdpConfig {} + /// UDP network connection implementation of the [`Connection`] trait. -pub struct UdpConn { +#[allow(dead_code)] +pub struct UdpConn { inner: UdpSocket, + codec: C, + config: UdpConfig, } -impl UdpConn { +impl UdpConn +where + C: Codec + Clone, +{ /// Creates a new UdpConn - pub fn new(conn: UdpSocket) -> Self { - Self { inner: conn } - } -} - -impl UdpConn { - /// Receives a single datagram message. Returns the number of bytes read - /// and the origin endpoint. - pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, Endpoint)> { - let (size, addr) = self.inner.recv_from(buf).await?; - Ok((size, Endpoint::new_udp_addr(&addr))) - } - - /// Sends data to the given address. Returns the number of bytes written. - pub async fn send_to(&self, buf: &[u8], addr: &Endpoint) -> Result { - let addr: SocketAddr = addr.clone().try_into()?; - let size = self.inner.send_to(buf, addr).await?; - Ok(size) + fn new(socket: UdpSocket, config: UdpConfig, codec: C) -> Self { + Self { + inner: socket, + codec, + config, + } } } #[async_trait] -impl Connection for UdpConn { +impl Connection for UdpConn +where + C: Codec + Clone, +{ + type Item = (C::Item, Endpoint); fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_udp_addr(&self.inner.peer_addr()?)) + self.inner + .peer_addr() + .map(Endpoint::new_udp_addr) + .map_err(Error::from) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_udp_addr(&self.inner.local_addr()?)) + self.inner + .local_addr() + .map(Endpoint::new_udp_addr) + .map_err(Error::from) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.inner.recv(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + let mut buf = [0u8; BUFFER_SIZE]; + let (_, addr) = self.inner.recv_from(&mut buf).await?; + match self.codec.decode(&mut buf)? { + Some((_, item)) => Ok((item, Endpoint::new_udp_addr(addr))), + None => Err(Error::Decode("Unable to decode".into())), + } } - async fn write(&self, buf: &[u8]) -> Result { - self.inner.send(buf).await.map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + let (msg, out_addr) = msg; + let mut buf = [0u8; BUFFER_SIZE]; + self.codec.encode(&msg, &mut buf)?; + let addr: SocketAddr = out_addr.try_into()?; + self.inner.send_to(&buf, addr).await?; + Ok(()) } } /// Connects to the given UDP address and port. -pub async fn dial(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; // Let the operating system assign an available port to this socket let conn = UdpSocket::bind("[::]:0").await?; conn.connect(addr).await?; - Ok(UdpConn::new(conn)) + Ok(UdpConn::new(conn, config, codec)) } /// Listens on the given UDP address and port. -pub async fn listen(endpoint: &Endpoint) -> Result { +pub async fn listen(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; let conn = UdpSocket::bind(addr).await?; - let udp_conn = UdpConn::new(conn); - Ok(udp_conn) -} - -impl From for Box { - fn from(conn: UdpSocket) -> Self { - Box::new(UdpConn::new(conn)) - } -} - -impl ToConn for UdpSocket { - fn to_conn(self) -> Box { - self.into() - } + Ok(UdpConn::new(conn, config, codec)) } -impl ToConn for UdpConn { - fn to_conn(self) -> Box { +impl ToConn for UdpConn +where + C: Codec + Clone, +{ + type Item = (C::Item, Endpoint); + fn to_conn(self) -> Conn { Box::new(self) } } diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs index 494e104..bafebaf 100644 --- a/net/src/transports/unix.rs +++ b/net/src/transports/unix.rs @@ -1,111 +1,192 @@ use async_trait::async_trait; +use futures_util::SinkExt; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, - net::unix::{UnixListener, UnixStream}, + net::{UnixListener as AsyncUnixListener, UnixStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, Error, Result, }; +/// Unix Conn config +#[derive(Clone, Default)] +pub struct UnixConfig {} + /// Unix domain socket implementation of the [`Connection`] trait. -pub struct UnixConn { - inner: UnixStream, - read: Mutex>, - write: Mutex>, +pub struct UnixConn { + read_stream: Mutex, C>>, + write_stream: Mutex, C>>, + peer_endpoint: Option, + local_endpoint: Option, } -impl UnixConn { - /// Creates a new UnixConn - pub fn new(conn: UnixStream) -> Self { - let (read, write) = split(conn.clone()); +impl UnixConn +where + C: Codec + Clone, +{ + /// Creates a new TcpConn + pub fn new(conn: UnixStream, codec: C) -> Self { + let peer_endpoint = conn + .peer_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .ok(); + let local_endpoint = conn + .local_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .ok(); + + let (read, write) = split(conn); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: conn, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for UnixConn { +impl Connection for UnixConn +where + C: Codec + Clone, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_unix_addr(&self.inner.peer_addr()?)) + self.peer_endpoint + .clone() + .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into())) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_unix_addr(&self.inner.local_addr()?)) + self.local_endpoint + .clone() + .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into())) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await + } +} + +#[allow(dead_code)] +pub struct UnixListener { + inner: AsyncUnixListener, + config: UnixConfig, + codec: C, +} + +impl UnixListener +where + C: Codec + Clone, +{ + pub fn new(listener: AsyncUnixListener, config: UnixConfig, codec: C) -> Self { + Self { + inner: listener, + config, + codec, + } } } #[async_trait] -impl ConnListener for UnixListener { +impl ConnListener for UnixListener +where + C: Codec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_unix_addr(&self.local_addr()?)) + self.inner + .local_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .map_err(Error::from) } - async fn accept(&self) -> Result> { - let (conn, _) = self.accept().await?; - Ok(Box::new(UnixConn::new(conn))) + async fn accept(&self) -> Result> { + let (conn, _) = self.inner.accept().await?; + Ok(Box::new(UnixConn::new(conn, self.codec.clone()))) } } /// Connects to the given Unix socket path. -pub async fn dial(path: &String) -> Result { +pub async fn dial(endpoint: &Endpoint, _config: UnixConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ + let path: std::path::PathBuf = endpoint.clone().try_into()?; let conn = UnixStream::connect(path).await?; - Ok(UnixConn::new(conn)) + Ok(UnixConn::new(conn, codec)) } /// Listens on the given Unix socket path. -pub fn listen(path: &String) -> Result { - let listener = UnixListener::bind(path)?; - Ok(listener) -} - -impl From for Box { - fn from(conn: UnixStream) -> Self { - Box::new(UnixConn::new(conn)) - } +pub fn listen(endpoint: &Endpoint, config: UnixConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ + let path: std::path::PathBuf = endpoint.clone().try_into()?; + let listener = AsyncUnixListener::bind(path)?; + Ok(UnixListener::new(listener, config, codec)) } -impl From for Box { - fn from(listener: UnixListener) -> Self { +// impl From for Box { +// fn from(conn: UnixStream) -> Self { +// Box::new(UnixConn::new(conn)) +// } +// } + +impl From> for Listener +where + C: Codec + Clone, +{ + fn from(listener: UnixListener) -> Self { Box::new(listener) } } -impl ToConn for UnixStream { - fn to_conn(self) -> Box { - self.into() - } -} - -impl ToConn for UnixConn { - fn to_conn(self) -> Box { +impl ToConn for UnixConn +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for UnixListener { - fn to_listener(self) -> Box { +impl ToListener for UnixListener +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs index eaf3b9b..17fe924 100644 --- a/net/src/transports/ws.rs +++ b/net/src/transports/ws.rs @@ -1,112 +1,254 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use rustls_pki_types as pki_types; + +#[cfg(feature = "tokio")] +use async_tungstenite::tokio as async_tungstenite; + +#[cfg(feature = "smol")] +use futures_rustls::{rustls, TlsAcceptor, TlsConnector}; +#[cfg(feature = "tokio")] +use tokio_rustls::{rustls, TlsAcceptor, TlsConnector}; + +use karyon_core::async_runtime::{ lock::Mutex, net::{TcpListener, TcpStream}, }; -use ws_stream_tungstenite::WsStream; - use crate::{ - connection::{Connection, ToConn}, + codec::WebSocketCodec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::WsStream, + Result, }; +use super::tcp::TcpConfig; + +/// WSS configuration +#[derive(Clone)] +pub struct ServerWssConfig { + pub server_config: rustls::ServerConfig, +} + +/// WSS configuration +#[derive(Clone)] +pub struct ClientWssConfig { + pub client_config: rustls::ClientConfig, + pub dns_name: String, +} + +/// WS configuration +#[derive(Clone, Default)] +pub struct ServerWsConfig { + pub tcp_config: TcpConfig, + pub wss_config: Option, +} + +/// WS configuration +#[derive(Clone, Default)] +pub struct ClientWsConfig { + pub tcp_config: TcpConfig, + pub wss_config: Option, +} + /// WS network connection implementation of the [`Connection`] trait. -pub struct WsConn { - inner: TcpStream, - read: Mutex>>, - write: Mutex>>, +pub struct WsConn { + // XXX: remove mutex + inner: Mutex>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl WsConn { +impl WsConn +where + C: WebSocketCodec, +{ /// Creates a new WsConn - pub fn new(inner: TcpStream, conn: WsStream) -> Self { - let (read, write) = split(conn); + pub fn new(ws: WsStream, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self { Self { - inner, - read: Mutex::new(read), - write: Mutex::new(write), + inner: Mutex::new(ws), + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for WsConn { +impl Connection for WsConn +where + C: WebSocketCodec, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.inner.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.inner.lock().await.send(msg).await } } /// Ws network listener implementation of the `Listener` [`ConnListener`] trait. -pub struct WsListener { +pub struct WsListener { inner: TcpListener, + config: ServerWsConfig, + codec: C, + tls_acceptor: Option, } #[async_trait] -impl ConnListener for WsListener { +impl ConnListener for WsListener +where + C: WebSocketCodec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + match self.config.wss_config { + Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)), + None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)), + } } - async fn accept(&self) -> Result> { - let (stream, _) = self.inner.accept().await?; - let conn = async_tungstenite::accept_async(stream.clone()).await?; - Ok(Box::new(WsConn::new(stream, WsStream::new(conn)))) + async fn accept(&self) -> Result> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.tcp_config.nodelay)?; + + match &self.config.wss_config { + Some(_) => match &self.tls_acceptor { + Some(acceptor) => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; + + let tls_conn = acceptor.accept(socket).await?.into(); + let conn = async_tungstenite::accept_async(tls_conn).await?; + Ok(Box::new(WsConn::new( + WsStream::new_wss(conn, self.codec.clone()), + peer_endpoint, + local_endpoint, + ))) + } + None => unreachable!(), + }, + None => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; + + let conn = async_tungstenite::accept_async(socket).await?; + + Ok(Box::new(WsConn::new( + WsStream::new_ws(conn, self.codec.clone()), + peer_endpoint, + local_endpoint, + ))) + } + } } } /// Connects to the given WS address and port. -pub async fn dial(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result> +where + C: WebSocketCodec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let stream = TcpStream::connect(addr).await?; - let (conn, _resp) = - async_tungstenite::client_async(endpoint.to_string(), stream.clone()).await?; - Ok(WsConn::new(stream, WsStream::new(conn))) + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.tcp_config.nodelay)?; + + match &config.wss_config { + Some(conf) => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; + + let connector = TlsConnector::from(Arc::new(conf.client_config.clone())); + + let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?; + let tls_conn = connector.connect(altname, socket).await?.into(); + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), tls_conn).await?; + Ok(WsConn::new( + WsStream::new_wss(conn, codec), + peer_endpoint, + local_endpoint, + )) + } + None => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), socket).await?; + Ok(WsConn::new( + WsStream::new_ws(conn, codec), + peer_endpoint, + local_endpoint, + )) + } + } } /// Listens on the given WS address and port. -pub async fn listen(endpoint: &Endpoint) -> Result { +pub async fn listen( + endpoint: &Endpoint, + config: ServerWsConfig, + codec: C, +) -> Result> { let addr = SocketAddr::try_from(endpoint.clone())?; + let listener = TcpListener::bind(addr).await?; - Ok(WsListener { inner: listener }) + match &config.wss_config { + Some(conf) => { + let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone())); + Ok(WsListener { + inner: listener, + config, + codec, + tls_acceptor: Some(acceptor), + }) + } + None => Ok(WsListener { + inner: listener, + config, + codec, + tls_acceptor: None, + }), + } } -impl From for Box { - fn from(listener: WsListener) -> Self { +impl From> for Listener +where + C: WebSocketCodec + Clone, +{ + fn from(listener: WsListener) -> Self { Box::new(listener) } } -impl ToConn for WsConn { - fn to_conn(self) -> Box { +impl ToConn for WsConn +where + C: WebSocketCodec, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for WsListener { - fn to_listener(self) -> Box { +impl ToListener for WsListener +where + C: WebSocketCodec + Clone, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } -- cgit v1.2.3