From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- net/src/transports/ws.rs | 242 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 192 insertions(+), 50 deletions(-) (limited to 'net/src/transports/ws.rs') diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs index eaf3b9b..17fe924 100644 --- a/net/src/transports/ws.rs +++ b/net/src/transports/ws.rs @@ -1,112 +1,254 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use rustls_pki_types as pki_types; + +#[cfg(feature = "tokio")] +use async_tungstenite::tokio as async_tungstenite; + +#[cfg(feature = "smol")] +use futures_rustls::{rustls, TlsAcceptor, TlsConnector}; +#[cfg(feature = "tokio")] +use tokio_rustls::{rustls, TlsAcceptor, TlsConnector}; + +use karyon_core::async_runtime::{ lock::Mutex, net::{TcpListener, TcpStream}, }; -use ws_stream_tungstenite::WsStream; - use crate::{ - connection::{Connection, ToConn}, + codec::WebSocketCodec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::WsStream, + Result, }; +use super::tcp::TcpConfig; + +/// WSS configuration +#[derive(Clone)] +pub struct ServerWssConfig { + pub server_config: rustls::ServerConfig, +} + +/// WSS configuration +#[derive(Clone)] +pub struct ClientWssConfig { + pub client_config: rustls::ClientConfig, + pub dns_name: String, +} + +/// WS configuration +#[derive(Clone, Default)] +pub struct ServerWsConfig { + pub tcp_config: TcpConfig, + pub wss_config: Option, +} + +/// WS configuration +#[derive(Clone, Default)] +pub struct ClientWsConfig { + pub tcp_config: TcpConfig, + pub wss_config: Option, +} + /// WS network connection implementation of the [`Connection`] trait. -pub struct WsConn { - inner: TcpStream, - read: Mutex>>, - write: Mutex>>, +pub struct WsConn { + // XXX: remove mutex + inner: Mutex>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl WsConn { +impl WsConn +where + C: WebSocketCodec, +{ /// Creates a new WsConn - pub fn new(inner: TcpStream, conn: WsStream) -> Self { - let (read, write) = split(conn); + pub fn new(ws: WsStream, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self { Self { - inner, - read: Mutex::new(read), - write: Mutex::new(write), + inner: Mutex::new(ws), + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for WsConn { +impl Connection for WsConn +where + C: WebSocketCodec, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.inner.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.inner.lock().await.send(msg).await } } /// Ws network listener implementation of the `Listener` [`ConnListener`] trait. -pub struct WsListener { +pub struct WsListener { inner: TcpListener, + config: ServerWsConfig, + codec: C, + tls_acceptor: Option, } #[async_trait] -impl ConnListener for WsListener { +impl ConnListener for WsListener +where + C: WebSocketCodec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + match self.config.wss_config { + Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)), + None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)), + } } - async fn accept(&self) -> Result> { - let (stream, _) = self.inner.accept().await?; - let conn = async_tungstenite::accept_async(stream.clone()).await?; - Ok(Box::new(WsConn::new(stream, WsStream::new(conn)))) + async fn accept(&self) -> Result> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.tcp_config.nodelay)?; + + match &self.config.wss_config { + Some(_) => match &self.tls_acceptor { + Some(acceptor) => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; + + let tls_conn = acceptor.accept(socket).await?.into(); + let conn = async_tungstenite::accept_async(tls_conn).await?; + Ok(Box::new(WsConn::new( + WsStream::new_wss(conn, self.codec.clone()), + peer_endpoint, + local_endpoint, + ))) + } + None => unreachable!(), + }, + None => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; + + let conn = async_tungstenite::accept_async(socket).await?; + + Ok(Box::new(WsConn::new( + WsStream::new_ws(conn, self.codec.clone()), + peer_endpoint, + local_endpoint, + ))) + } + } } } /// Connects to the given WS address and port. -pub async fn dial(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result> +where + C: WebSocketCodec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let stream = TcpStream::connect(addr).await?; - let (conn, _resp) = - async_tungstenite::client_async(endpoint.to_string(), stream.clone()).await?; - Ok(WsConn::new(stream, WsStream::new(conn))) + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.tcp_config.nodelay)?; + + match &config.wss_config { + Some(conf) => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; + + let connector = TlsConnector::from(Arc::new(conf.client_config.clone())); + + let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?; + let tls_conn = connector.connect(altname, socket).await?.into(); + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), tls_conn).await?; + Ok(WsConn::new( + WsStream::new_wss(conn, codec), + peer_endpoint, + local_endpoint, + )) + } + None => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), socket).await?; + Ok(WsConn::new( + WsStream::new_ws(conn, codec), + peer_endpoint, + local_endpoint, + )) + } + } } /// Listens on the given WS address and port. -pub async fn listen(endpoint: &Endpoint) -> Result { +pub async fn listen( + endpoint: &Endpoint, + config: ServerWsConfig, + codec: C, +) -> Result> { let addr = SocketAddr::try_from(endpoint.clone())?; + let listener = TcpListener::bind(addr).await?; - Ok(WsListener { inner: listener }) + match &config.wss_config { + Some(conf) => { + let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone())); + Ok(WsListener { + inner: listener, + config, + codec, + tls_acceptor: Some(acceptor), + }) + } + None => Ok(WsListener { + inner: listener, + config, + codec, + tls_acceptor: None, + }), + } } -impl From for Box { - fn from(listener: WsListener) -> Self { +impl From> for Listener +where + C: WebSocketCodec + Clone, +{ + fn from(listener: WsListener) -> Self { Box::new(listener) } } -impl ToConn for WsConn { - fn to_conn(self) -> Box { +impl ToConn for WsConn +where + C: WebSocketCodec, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for WsListener { - fn to_listener(self) -> Box { +impl ToListener for WsListener +where + C: WebSocketCodec + Clone, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } -- cgit v1.2.3