use std::net::SocketAddr;

#[cfg(feature = "tls")]
use std::sync::Arc;

use async_trait::async_trait;
#[cfg(feature = "tls")]
use rustls_pki_types as pki_types;

#[cfg(feature = "tokio")]
use async_tungstenite::tokio as async_tungstenite;

#[cfg(all(feature = "smol", feature = "tls"))]
use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
#[cfg(all(feature = "tokio", feature = "tls"))]
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};

use karyon_core::async_runtime::{
    lock::Mutex,
    net::{TcpListener, TcpStream},
};

use crate::{
    codec::WebSocketCodec,
    connection::{Conn, Connection, ToConn},
    endpoint::Endpoint,
    listener::{ConnListener, Listener, ToListener},
    stream::{ReadWsStream, WriteWsStream, WsStream},
    Result,
};

use super::tcp::TcpConfig;

/// Server-side WSS (WebSocket over TLS) configuration.
///
/// The TLS settings only exist when the crate is built with the `tls`
/// feature; without it this struct carries no TLS data.
#[derive(Clone)]
pub struct ServerWssConfig {
    /// rustls settings used to build the TLS acceptor for incoming connections.
    #[cfg(feature = "tls")]
    pub server_config: rustls::ServerConfig,
}

/// Client-side WSS (WebSocket over TLS) configuration.
///
/// The TLS settings only exist when the crate is built with the `tls`
/// feature.
#[derive(Clone)]
pub struct ClientWssConfig {
    /// rustls settings used to build the TLS connector when dialing.
    #[cfg(feature = "tls")]
    pub client_config: rustls::ClientConfig,
    /// Server name handed to the TLS connector when dialing; it must parse as
    /// a `rustls_pki_types::ServerName`.
    pub dns_name: String,
}

/// Server-side WS configuration.
///
/// Setting `wss_config` makes the listener serve WSS instead of plain WS.
#[derive(Clone, Default)]
pub struct ServerWsConfig {
    /// TCP options (e.g. `nodelay`) applied to every accepted socket.
    pub tcp_config: TcpConfig,
    /// TLS configuration; `None` means plain, unencrypted WebSocket.
    pub wss_config: Option<ServerWssConfig>,
}

/// Client-side WS configuration.
///
/// Setting `wss_config` makes `dial` connect over WSS instead of plain WS.
#[derive(Clone, Default)]
pub struct ClientWsConfig {
    /// TCP options (e.g. `nodelay`) applied to the dialed socket.
    pub tcp_config: TcpConfig,
    /// TLS configuration; `None` means plain, unencrypted WebSocket.
    pub wss_config: Option<ClientWssConfig>,
}

/// WS network connection implementation of the [`Connection`] trait.
pub struct WsConn { read_stream: Mutex>, write_stream: Mutex>, peer_endpoint: Endpoint, local_endpoint: Endpoint, } impl WsConn where C: WebSocketCodec + Clone, { /// Creates a new WsConn pub fn new(ws: WsStream, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self { let (read, write) = ws.split(); Self { read_stream: Mutex::new(read), write_stream: Mutex::new(write), peer_endpoint, local_endpoint, } } } #[async_trait] impl Connection for WsConn where C: WebSocketCodec, { type Item = C::Item; fn peer_endpoint(&self) -> Result { Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { Ok(self.local_endpoint.clone()) } async fn recv(&self) -> Result { self.read_stream.lock().await.recv().await } async fn send(&self, msg: Self::Item) -> Result<()> { self.write_stream.lock().await.send(msg).await } } /// Ws network listener implementation of the `Listener` [`ConnListener`] trait. pub struct WsListener { inner: TcpListener, config: ServerWsConfig, codec: C, #[cfg(feature = "tls")] tls_acceptor: Option, } #[async_trait] impl ConnListener for WsListener where C: WebSocketCodec + Clone, { type Item = C::Item; fn local_endpoint(&self) -> Result { match self.config.wss_config { Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)), None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)), } } async fn accept(&self) -> Result> { let (socket, _) = self.inner.accept().await?; socket.set_nodelay(self.config.tcp_config.nodelay)?; match &self.config.wss_config { #[cfg(feature = "tls")] Some(_) => match &self.tls_acceptor { Some(acceptor) => { let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; let tls_conn = acceptor.accept(socket).await?.into(); let conn = async_tungstenite::accept_async(tls_conn).await?; Ok(Box::new(WsConn::new( WsStream::new_wss(conn, self.codec.clone()), peer_endpoint, local_endpoint, ))) } None => unreachable!(), }, None => { let 
peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; let conn = async_tungstenite::accept_async(socket).await?; Ok(Box::new(WsConn::new( WsStream::new_ws(conn, self.codec.clone()), peer_endpoint, local_endpoint, ))) } #[cfg(not(feature = "tls"))] _ => unreachable!(), } } } /// Connects to the given WS address and port. pub async fn dial(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result> where C: WebSocketCodec + Clone, { let addr = SocketAddr::try_from(endpoint.clone())?; let socket = TcpStream::connect(addr).await?; socket.set_nodelay(config.tcp_config.nodelay)?; match &config.wss_config { #[cfg(feature = "tls")] Some(conf) => { let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; let connector = TlsConnector::from(Arc::new(conf.client_config.clone())); let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?; let tls_conn = connector.connect(altname, socket).await?.into(); let (conn, _resp) = async_tungstenite::client_async(endpoint.to_string(), tls_conn).await?; Ok(WsConn::new( WsStream::new_wss(conn, codec), peer_endpoint, local_endpoint, )) } None => { let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; let (conn, _resp) = async_tungstenite::client_async(endpoint.to_string(), socket).await?; Ok(WsConn::new( WsStream::new_ws(conn, codec), peer_endpoint, local_endpoint, )) } #[cfg(not(feature = "tls"))] _ => unreachable!(), } } /// Listens on the given WS address and port. 
pub async fn listen( endpoint: &Endpoint, config: ServerWsConfig, codec: C, ) -> Result> { let addr = SocketAddr::try_from(endpoint.clone())?; let listener = TcpListener::bind(addr).await?; match &config.wss_config { #[cfg(feature = "tls")] Some(conf) => { let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone())); Ok(WsListener { inner: listener, config, codec, tls_acceptor: Some(acceptor), }) } None => Ok(WsListener { inner: listener, config, codec, #[cfg(feature = "tls")] tls_acceptor: None, }), #[cfg(not(feature = "tls"))] _ => unreachable!(), } } impl From> for Listener where C: WebSocketCodec + Clone, { fn from(listener: WsListener) -> Self { Box::new(listener) } } impl ToConn for WsConn where C: WebSocketCodec, { type Item = C::Item; fn to_conn(self) -> Conn { Box::new(self) } } impl ToListener for WsListener where C: WebSocketCodec + Clone, { type Item = C::Item; fn to_listener(self) -> Listener { self.into() } }