diff options
author | hozan23 <hozan23@karyontech.net> | 2024-04-11 10:19:20 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-19 13:51:30 +0200 |
commit | 0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch) | |
tree | 961d73218af672797d49f899289bef295bc56493 /net/src/transports | |
parent | a69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff) |
add support for tokio & improve net crate api
Diffstat (limited to 'net/src/transports')
-rw-r--r-- | net/src/transports/tcp.rs | 188 | ||||
-rw-r--r-- | net/src/transports/tls.rs | 220 | ||||
-rw-r--r-- | net/src/transports/udp.rs | 114 | ||||
-rw-r--r-- | net/src/transports/unix.rs | 193 | ||||
-rw-r--r-- | net/src/transports/ws.rs | 242 |
5 files changed, 673 insertions, 284 deletions
diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs index 21fce3d..03c8ab2 100644 --- a/net/src/transports/tcp.rs +++ b/net/src/transports/tcp.rs @@ -1,116 +1,184 @@ use std::net::SocketAddr; use async_trait::async_trait; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use futures_util::SinkExt; + +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, - net::{TcpListener, TcpStream}, + net::{TcpListener as AsyncTcpListener, TcpStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, + Result, }; -/// TCP network connection implementation of the [`Connection`] trait. -pub struct TcpConn { - inner: TcpStream, - read: Mutex<ReadHalf<TcpStream>>, - write: Mutex<WriteHalf<TcpStream>>, +/// TCP configuration +#[derive(Clone)] +pub struct TcpConfig { + pub nodelay: bool, +} + +impl Default for TcpConfig { + fn default() -> Self { + Self { nodelay: true } + } +} + +/// TCP connection implementation of the [`Connection`] trait. +pub struct TcpConn<C> { + read_stream: Mutex<ReadStream<ReadHalf<TcpStream>, C>>, + write_stream: Mutex<WriteStream<WriteHalf<TcpStream>, C>>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl TcpConn { +impl<C> TcpConn<C> +where + C: Codec + Clone, +{ /// Creates a new TcpConn - pub fn new(conn: TcpStream) -> Self { - let (read, write) = split(conn.clone()); + pub fn new( + socket: TcpStream, + codec: C, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, + ) -> Self { + let (read, write) = split(socket); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: conn, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for TcpConn { +impl<C> Connection for TcpConn<C> +where + C: Codec + Clone, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_tcp_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_tcp_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result<usize> { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result<Self::Item> { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result<usize> { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await + } +} + +pub struct TcpListener<C> { + inner: AsyncTcpListener, + config: TcpConfig, + codec: C, +} + +impl<C> TcpListener<C> +where + C: Codec, +{ + pub fn new(listener: AsyncTcpListener, config: TcpConfig, codec: C) -> Self { + Self { + inner: listener, + config: config.clone(), + codec, + } } } #[async_trait] -impl ConnListener for TcpListener { +impl<C> ConnListener for TcpListener<C> +where + C: Codec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_tcp_addr(&self.local_addr()?)) + Ok(Endpoint::new_tcp_addr(self.inner.local_addr()?)) } - async fn accept(&self) -> Result<Box<dyn Connection>> { - let (conn, _) = self.accept().await?; - conn.set_nodelay(true)?; - Ok(Box::new(TcpConn::new(conn))) + async fn accept(&self) -> Result<Conn<C::Item>> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?; + + Ok(Box::new(TcpConn::new( + socket, + self.codec.clone(), + peer_endpoint, + local_endpoint, + ))) } } /// Connects to the given TCP address and port. -pub async fn dial(endpoint: &Endpoint) -> Result<TcpConn> { +pub async fn dial<C>(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result<TcpConn<C>> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let conn = TcpStream::connect(addr).await?; - conn.set_nodelay(true)?; - Ok(TcpConn::new(conn)) + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?; + + Ok(TcpConn::new(socket, codec, peer_endpoint, local_endpoint)) } /// Listens on the given TCP address and port. -pub async fn listen(endpoint: &Endpoint) -> Result<TcpListener> { +pub async fn listen<C>(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result<TcpListener<C>> +where + C: Codec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let listener = TcpListener::bind(addr).await?; - Ok(listener) -} - -impl From<TcpStream> for Box<dyn Connection> { - fn from(conn: TcpStream) -> Self { - Box::new(TcpConn::new(conn)) - } + let listener = AsyncTcpListener::bind(addr).await?; + Ok(TcpListener::new(listener, config, codec)) } -impl From<TcpListener> for Box<dyn ConnListener> { - fn from(listener: TcpListener) -> Self { +impl<C> From<TcpListener<C>> for Box<dyn ConnListener<Item = C::Item>> +where + C: Clone + Codec, +{ + fn from(listener: TcpListener<C>) -> Self { Box::new(listener) } } -impl ToConn for TcpStream { - fn to_conn(self) -> Box<dyn Connection> { - self.into() - } -} - -impl ToConn for TcpConn { - fn to_conn(self) -> Box<dyn Connection> { +impl<C> ToConn for TcpConn<C> +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn<Self::Item> { Box::new(self) } } -impl ToListener for TcpListener { - fn to_listener(self) -> Box<dyn ConnListener> { +impl<C> ToListener for TcpListener<C> +where + C: Clone + Codec, +{ + type Item = C::Item; + fn to_listener(self) -> Listener<Self::Item> { self.into() } } diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs index 476f495..c972f63 100644 --- a/net/src/transports/tls.rs +++ b/net/src/transports/tls.rs @@ -1,138 +1,218 @@ use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; -use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream}; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use futures_util::SinkExt; +use rustls_pki_types as pki_types; + +#[cfg(feature = "smol")] +use futures_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; +#[cfg(feature = "tokio")] +use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; + +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, net::{TcpListener, TcpStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, + Result, }; +use super::tcp::TcpConfig; + +/// TLS configuration +#[derive(Clone)] +pub struct ServerTlsConfig { + pub tcp_config: TcpConfig, + pub server_config: rustls::ServerConfig, +} + +#[derive(Clone)] +pub struct ClientTlsConfig { + pub tcp_config: TcpConfig, + pub client_config: rustls::ClientConfig, + pub dns_name: String, +} + /// TLS network connection implementation of the [`Connection`] trait. -pub struct TlsConn { - inner: TcpStream, - read: Mutex<ReadHalf<TlsStream<TcpStream>>>, - write: Mutex<WriteHalf<TlsStream<TcpStream>>>, +pub struct TlsConn<C> { + read_stream: Mutex<ReadStream<ReadHalf<TlsStream<TcpStream>>, C>>, + write_stream: Mutex<WriteStream<WriteHalf<TlsStream<TcpStream>>, C>>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl TlsConn { +impl<C> TlsConn<C> +where + C: Codec + Clone, +{ /// Creates a new TlsConn - pub fn new(sock: TcpStream, conn: TlsStream<TcpStream>) -> Self { + pub fn new( + conn: TlsStream<TcpStream>, + codec: C, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, + ) -> Self { let (read, write) = split(conn); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: sock, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for TlsConn { +impl<C> Connection for TlsConn<C> +where + C: Clone + Codec, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result<usize> { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result<Self::Item> { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result<usize> { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await } } /// Connects to the given TLS address and port. -pub async fn dial( - endpoint: &Endpoint, - config: rustls::ClientConfig, - dns_name: &'static str, -) -> Result<TlsConn> { +pub async fn dial<C>(endpoint: &Endpoint, config: ClientTlsConfig, codec: C) -> Result<TlsConn<C>> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let connector = TlsConnector::from(Arc::new(config)); + let connector = TlsConnector::from(Arc::new(config.client_config.clone())); + + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.tcp_config.nodelay)?; - let sock = TcpStream::connect(addr).await?; - sock.set_nodelay(true)?; + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?; - let altname = pki_types::ServerName::try_from(dns_name)?; - let conn = connector.connect(altname, sock.clone()).await?; - Ok(TlsConn::new(sock, TlsStream::Client(conn))) + let altname = pki_types::ServerName::try_from(config.dns_name.clone())?; + let conn = connector.connect(altname, socket).await?; + Ok(TlsConn::new( + TlsStream::Client(conn), + codec, + peer_endpoint, + local_endpoint, + )) } /// Tls network listener implementation of the `Listener` [`ConnListener`] trait. -pub struct TlsListener { +pub struct TlsListener<C> { inner: TcpListener, acceptor: TlsAcceptor, + config: ServerTlsConfig, + codec: C, +} + +impl<C> TlsListener<C> +where + C: Codec + Clone, +{ + pub fn new( + acceptor: TlsAcceptor, + listener: TcpListener, + config: ServerTlsConfig, + codec: C, + ) -> Self { + Self { + inner: listener, + acceptor, + config: config.clone(), + codec, + } + } } #[async_trait] -impl ConnListener for TlsListener { +impl<C> ConnListener for TlsListener<C> +where + C: Clone + Codec, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + Ok(Endpoint::new_tls_addr(self.inner.local_addr()?)) } - async fn accept(&self) -> Result<Box<dyn Connection>> { - let (sock, _) = self.inner.accept().await?; - sock.set_nodelay(true)?; - let conn = self.acceptor.accept(sock.clone()).await?; - Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) + async fn accept(&self) -> Result<Conn<C::Item>> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.tcp_config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?; + + let conn = self.acceptor.accept(socket).await?; + Ok(Box::new(TlsConn::new( + TlsStream::Server(conn), + self.codec.clone(), + peer_endpoint, + local_endpoint, + ))) } } /// Listens on the given TLS address and port. -pub async fn listen(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result<TlsListener> { +pub async fn listen<C>( + endpoint: &Endpoint, + config: ServerTlsConfig, + codec: C, +) -> Result<TlsListener<C>> +where + C: Clone + Codec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let acceptor = TlsAcceptor::from(Arc::new(config)); + let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone())); let listener = TcpListener::bind(addr).await?; - Ok(TlsListener { - acceptor, - inner: listener, - }) -} - -impl From<TlsStream<TcpStream>> for Box<dyn Connection> { - fn from(conn: TlsStream<TcpStream>) -> Self { - Box::new(TlsConn::new(conn.get_ref().0.clone(), conn)) - } + Ok(TlsListener::new(acceptor, listener, config, codec)) } -impl From<TlsListener> for Box<dyn ConnListener> { - fn from(listener: TlsListener) -> Self { +impl<C> From<TlsListener<C>> for Listener<C::Item> +where + C: Codec + Clone, +{ + fn from(listener: TlsListener<C>) -> Self { Box::new(listener) } } -impl ToConn for TlsStream<TcpStream> { - fn to_conn(self) -> Box<dyn Connection> { - self.into() - } -} - -impl ToConn for TlsConn { - fn to_conn(self) -> Box<dyn Connection> { +impl<C> ToConn for TlsConn<C> +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn<Self::Item> { Box::new(self) } } -impl ToListener for TlsListener { - fn to_listener(self) -> Box<dyn ConnListener> { +impl<C> ToListener for TlsListener<C> +where + C: Clone + Codec, +{ + type Item = C::Item; + fn to_listener(self) -> Listener<Self::Item> { self.into() } } diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs index bd1fe83..950537c 100644 --- a/net/src/transports/udp.rs +++ b/net/src/transports/udp.rs @@ -1,93 +1,111 @@ use std::net::SocketAddr; use async_trait::async_trait; -use smol::net::UdpSocket; +use karyon_core::async_runtime::net::UdpSocket; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, Error, Result, }; +const BUFFER_SIZE: usize = 64 * 1024; + +/// UDP configuration +#[derive(Default)] +pub struct UdpConfig {} + /// UDP network connection implementation of the [`Connection`] trait. -pub struct UdpConn { +#[allow(dead_code)] +pub struct UdpConn<C> { inner: UdpSocket, + codec: C, + config: UdpConfig, } -impl UdpConn { +impl<C> UdpConn<C> +where + C: Codec + Clone, +{ /// Creates a new UdpConn - pub fn new(conn: UdpSocket) -> Self { - Self { inner: conn } - } -} - -impl UdpConn { - /// Receives a single datagram message. Returns the number of bytes read - /// and the origin endpoint. - pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, Endpoint)> { - let (size, addr) = self.inner.recv_from(buf).await?; - Ok((size, Endpoint::new_udp_addr(&addr))) - } - - /// Sends data to the given address. Returns the number of bytes written. - pub async fn send_to(&self, buf: &[u8], addr: &Endpoint) -> Result<usize> { - let addr: SocketAddr = addr.clone().try_into()?; - let size = self.inner.send_to(buf, addr).await?; - Ok(size) + fn new(socket: UdpSocket, config: UdpConfig, codec: C) -> Self { + Self { + inner: socket, + codec, + config, + } } } #[async_trait] -impl Connection for UdpConn { +impl<C> Connection for UdpConn<C> +where + C: Codec + Clone, +{ + type Item = (C::Item, Endpoint); fn peer_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_udp_addr(&self.inner.peer_addr()?)) + self.inner + .peer_addr() + .map(Endpoint::new_udp_addr) + .map_err(Error::from) } fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_udp_addr(&self.inner.local_addr()?)) + self.inner + .local_addr() + .map(Endpoint::new_udp_addr) + .map_err(Error::from) } - async fn read(&self, buf: &mut [u8]) -> Result<usize> { - self.inner.recv(buf).await.map_err(Error::from) + async fn recv(&self) -> Result<Self::Item> { + let mut buf = [0u8; BUFFER_SIZE]; + let (_, addr) = self.inner.recv_from(&mut buf).await?; + match self.codec.decode(&mut buf)? { + Some((_, item)) => Ok((item, Endpoint::new_udp_addr(addr))), + None => Err(Error::Decode("Unable to decode".into())), + } } - async fn write(&self, buf: &[u8]) -> Result<usize> { - self.inner.send(buf).await.map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + let (msg, out_addr) = msg; + let mut buf = [0u8; BUFFER_SIZE]; + self.codec.encode(&msg, &mut buf)?; + let addr: SocketAddr = out_addr.try_into()?; + self.inner.send_to(&buf, addr).await?; + Ok(()) } } /// Connects to the given UDP address and port. -pub async fn dial(endpoint: &Endpoint) -> Result<UdpConn> { +pub async fn dial<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; // Let the operating system assign an available port to this socket let conn = UdpSocket::bind("[::]:0").await?; conn.connect(addr).await?; - Ok(UdpConn::new(conn)) + Ok(UdpConn::new(conn, config, codec)) } /// Listens on the given UDP address and port. -pub async fn listen(endpoint: &Endpoint) -> Result<UdpConn> { +pub async fn listen<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; let conn = UdpSocket::bind(addr).await?; - let udp_conn = UdpConn::new(conn); - Ok(udp_conn) -} - -impl From<UdpSocket> for Box<dyn Connection> { - fn from(conn: UdpSocket) -> Self { - Box::new(UdpConn::new(conn)) - } -} - -impl ToConn for UdpSocket { - fn to_conn(self) -> Box<dyn Connection> { - self.into() - } + Ok(UdpConn::new(conn, config, codec)) } -impl ToConn for UdpConn { - fn to_conn(self) -> Box<dyn Connection> { +impl<C> ToConn for UdpConn<C> +where + C: Codec + Clone, +{ + type Item = (C::Item, Endpoint); + fn to_conn(self) -> Conn<Self::Item> { Box::new(self) } } diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs index 494e104..bafebaf 100644 --- a/net/src/transports/unix.rs +++ b/net/src/transports/unix.rs @@ -1,111 +1,192 @@ use async_trait::async_trait; +use futures_util::SinkExt; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, - net::unix::{UnixListener, UnixStream}, + net::{UnixListener as AsyncUnixListener, UnixStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, Error, Result, }; +/// Unix Conn config +#[derive(Clone, Default)] +pub struct UnixConfig {} + /// Unix domain socket implementation of the [`Connection`] trait. -pub struct UnixConn { - inner: UnixStream, - read: Mutex<ReadHalf<UnixStream>>, - write: Mutex<WriteHalf<UnixStream>>, +pub struct UnixConn<C> { + read_stream: Mutex<ReadStream<ReadHalf<UnixStream>, C>>, + write_stream: Mutex<WriteStream<WriteHalf<UnixStream>, C>>, + peer_endpoint: Option<Endpoint>, + local_endpoint: Option<Endpoint>, } -impl UnixConn { - /// Creates a new UnixConn - pub fn new(conn: UnixStream) -> Self { - let (read, write) = split(conn.clone()); +impl<C> UnixConn<C> +where + C: Codec + Clone, +{ + /// Creates a new TcpConn + pub fn new(conn: UnixStream, codec: C) -> Self { + let peer_endpoint = conn + .peer_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .ok(); + let local_endpoint = conn + .local_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .ok(); + + let (read, write) = split(conn); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: conn, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for UnixConn { +impl<C> Connection for UnixConn<C> +where + C: Codec + Clone, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_unix_addr(&self.inner.peer_addr()?)) + self.peer_endpoint + .clone() + .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into())) } fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_unix_addr(&self.inner.local_addr()?)) + self.local_endpoint + .clone() + .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into())) } - async fn read(&self, buf: &mut [u8]) -> Result<usize> { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result<Self::Item> { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result<usize> { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await + } +} + +#[allow(dead_code)] +pub struct UnixListener<C> { + inner: AsyncUnixListener, + config: UnixConfig, + codec: C, +} + +impl<C> UnixListener<C> +where + C: Codec + Clone, +{ + pub fn new(listener: AsyncUnixListener, config: UnixConfig, codec: C) -> Self { + Self { + inner: listener, + config, + codec, + } } } #[async_trait] -impl ConnListener for UnixListener { +impl<C> ConnListener for UnixListener<C> +where + C: Codec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_unix_addr(&self.local_addr()?)) + self.inner + .local_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .map_err(Error::from) } - async fn accept(&self) -> Result<Box<dyn Connection>> { - let (conn, _) = self.accept().await?; - Ok(Box::new(UnixConn::new(conn))) + async fn accept(&self) -> Result<Conn<C::Item>> { + let (conn, _) = self.inner.accept().await?; + Ok(Box::new(UnixConn::new(conn, self.codec.clone()))) } } /// Connects to the given Unix socket path. -pub async fn dial(path: &String) -> Result<UnixConn> { +pub async fn dial<C>(endpoint: &Endpoint, _config: UnixConfig, codec: C) -> Result<UnixConn<C>> +where + C: Codec + Clone, +{ + let path: std::path::PathBuf = endpoint.clone().try_into()?; let conn = UnixStream::connect(path).await?; - Ok(UnixConn::new(conn)) + Ok(UnixConn::new(conn, codec)) } /// Listens on the given Unix socket path. -pub fn listen(path: &String) -> Result<UnixListener> { - let listener = UnixListener::bind(path)?; - Ok(listener) -} - -impl From<UnixStream> for Box<dyn Connection> { - fn from(conn: UnixStream) -> Self { - Box::new(UnixConn::new(conn)) - } +pub fn listen<C>(endpoint: &Endpoint, config: UnixConfig, codec: C) -> Result<UnixListener<C>> +where + C: Codec + Clone, +{ + let path: std::path::PathBuf = endpoint.clone().try_into()?; + let listener = AsyncUnixListener::bind(path)?; + Ok(UnixListener::new(listener, config, codec)) } -impl From<UnixListener> for Box<dyn ConnListener> { - fn from(listener: UnixListener) -> Self { +// impl From<UnixStream> for Box<dyn Connection> { +// fn from(conn: UnixStream) -> Self { +// Box::new(UnixConn::new(conn)) +// } +// } + +impl<C> From<UnixListener<C>> for Listener<C::Item> +where + C: Codec + Clone, +{ + fn from(listener: UnixListener<C>) -> Self { Box::new(listener) } } -impl ToConn for UnixStream { - fn to_conn(self) -> Box<dyn Connection> { - self.into() - } -} - -impl ToConn for UnixConn { - fn to_conn(self) -> Box<dyn Connection> { +impl<C> ToConn for UnixConn<C> +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn<Self::Item> { Box::new(self) } } -impl ToListener for UnixListener { - fn to_listener(self) -> Box<dyn ConnListener> { +impl<C> ToListener for UnixListener<C> +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_listener(self) -> Listener<Self::Item> { self.into() } } diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs index eaf3b9b..17fe924 100644 --- a/net/src/transports/ws.rs +++ b/net/src/transports/ws.rs @@ -1,112 +1,254 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use rustls_pki_types as pki_types; + +#[cfg(feature = "tokio")] +use async_tungstenite::tokio as async_tungstenite; + +#[cfg(feature = "smol")] +use futures_rustls::{rustls, TlsAcceptor, TlsConnector}; +#[cfg(feature = "tokio")] +use tokio_rustls::{rustls, TlsAcceptor, TlsConnector}; + +use karyon_core::async_runtime::{ lock::Mutex, net::{TcpListener, TcpStream}, }; -use ws_stream_tungstenite::WsStream; - use crate::{ - connection::{Connection, ToConn}, + codec::WebSocketCodec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::WsStream, + Result, }; +use super::tcp::TcpConfig; + +/// WSS configuration +#[derive(Clone)] +pub struct ServerWssConfig { + pub server_config: rustls::ServerConfig, +} + +/// WSS configuration +#[derive(Clone)] +pub struct ClientWssConfig { + pub client_config: rustls::ClientConfig, + pub dns_name: String, +} + +/// WS configuration +#[derive(Clone, Default)] +pub struct ServerWsConfig { + pub tcp_config: TcpConfig, + pub wss_config: Option<ServerWssConfig>, +} + +/// WS configuration +#[derive(Clone, Default)] +pub struct ClientWsConfig { + pub tcp_config: TcpConfig, + pub wss_config: Option<ClientWssConfig>, +} + /// WS network connection implementation of the [`Connection`] trait. -pub struct WsConn { - inner: TcpStream, - read: Mutex<ReadHalf<WsStream<TcpStream>>>, - write: Mutex<WriteHalf<WsStream<TcpStream>>>, +pub struct WsConn<C> { + // XXX: remove mutex + inner: Mutex<WsStream<C>>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl WsConn { +impl<C> WsConn<C> +where + C: WebSocketCodec, +{ /// Creates a new WsConn - pub fn new(inner: TcpStream, conn: WsStream<TcpStream>) -> Self { - let (read, write) = split(conn); + pub fn new(ws: WsStream<C>, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self { Self { - inner, - read: Mutex::new(read), - write: Mutex::new(write), + inner: Mutex::new(ws), + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for WsConn { +impl<C> Connection for WsConn<C> +where + C: WebSocketCodec, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_ws_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result<usize> { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result<Self::Item> { + self.inner.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result<usize> { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.inner.lock().await.send(msg).await } } /// Ws network listener implementation of the `Listener` [`ConnListener`] trait. -pub struct WsListener { +pub struct WsListener<C> { inner: TcpListener, + config: ServerWsConfig, + codec: C, + tls_acceptor: Option<TlsAcceptor>, } #[async_trait] -impl ConnListener for WsListener { +impl<C> ConnListener for WsListener<C> +where + C: WebSocketCodec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result<Endpoint> { - Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + match self.config.wss_config { + Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)), + None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)), + } } - async fn accept(&self) -> Result<Box<dyn Connection>> { - let (stream, _) = self.inner.accept().await?; - let conn = async_tungstenite::accept_async(stream.clone()).await?; - Ok(Box::new(WsConn::new(stream, WsStream::new(conn)))) + async fn accept(&self) -> Result<Conn<Self::Item>> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.tcp_config.nodelay)?; + + match &self.config.wss_config { + Some(_) => match &self.tls_acceptor { + Some(acceptor) => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; + + let tls_conn = acceptor.accept(socket).await?.into(); + let conn = async_tungstenite::accept_async(tls_conn).await?; + Ok(Box::new(WsConn::new( + WsStream::new_wss(conn, self.codec.clone()), + peer_endpoint, + local_endpoint, + ))) + } + None => unreachable!(), + }, + None => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; + + let conn = async_tungstenite::accept_async(socket).await?; + + Ok(Box::new(WsConn::new( + WsStream::new_ws(conn, self.codec.clone()), + peer_endpoint, + local_endpoint, + ))) + } + } } } /// Connects to the given WS address and port. -pub async fn dial(endpoint: &Endpoint) -> Result<WsConn> { +pub async fn dial<C>(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result<WsConn<C>> +where + C: WebSocketCodec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let stream = TcpStream::connect(addr).await?; - let (conn, _resp) = - async_tungstenite::client_async(endpoint.to_string(), stream.clone()).await?; - Ok(WsConn::new(stream, WsStream::new(conn))) + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.tcp_config.nodelay)?; + + match &config.wss_config { + Some(conf) => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; + + let connector = TlsConnector::from(Arc::new(conf.client_config.clone())); + + let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?; + let tls_conn = connector.connect(altname, socket).await?.into(); + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), tls_conn).await?; + Ok(WsConn::new( + WsStream::new_wss(conn, codec), + peer_endpoint, + local_endpoint, + )) + } + None => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), socket).await?; + Ok(WsConn::new( + WsStream::new_ws(conn, codec), + peer_endpoint, + local_endpoint, + )) + } + } } /// Listens on the given WS address and port. -pub async fn listen(endpoint: &Endpoint) -> Result<WsListener> { +pub async fn listen<C>( + endpoint: &Endpoint, + config: ServerWsConfig, + codec: C, +) -> Result<WsListener<C>> { let addr = SocketAddr::try_from(endpoint.clone())?; + let listener = TcpListener::bind(addr).await?; - Ok(WsListener { inner: listener }) + match &config.wss_config { + Some(conf) => { + let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone())); + Ok(WsListener { + inner: listener, + config, + codec, + tls_acceptor: Some(acceptor), + }) + } + None => Ok(WsListener { + inner: listener, + config, + codec, + tls_acceptor: None, + }), + } } -impl From<WsListener> for Box<dyn ConnListener> { - fn from(listener: WsListener) -> Self { +impl<C> From<WsListener<C>> for Listener<C::Item> +where + C: WebSocketCodec + Clone, +{ + fn from(listener: WsListener<C>) -> Self { Box::new(listener) } } -impl ToConn for WsConn { - fn to_conn(self) -> Box<dyn Connection> { +impl<C> ToConn for WsConn<C> +where + C: WebSocketCodec, +{ + type Item = C::Item; + fn to_conn(self) -> Conn<Self::Item> { Box::new(self) } } -impl ToListener for WsListener { - fn to_listener(self) -> Box<dyn ConnListener> { +impl<C> ToListener for WsListener<C> +where + C: WebSocketCodec + Clone, +{ + type Item = C::Item; + fn to_listener(self) -> Listener<Self::Item> { self.into() } } |