From 340957fec147f4429796413f27bbd9b84ba6f141 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 14 Mar 2024 17:01:59 +0100 Subject: net: add support for websocket protocol --- net/src/transports/ws.rs | 112 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 net/src/transports/ws.rs (limited to 'net/src/transports/ws.rs') diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs new file mode 100644 index 0000000..eaf3b9b --- /dev/null +++ b/net/src/transports/ws.rs @@ -0,0 +1,112 @@ +use std::net::SocketAddr; + +use async_trait::async_trait; +use smol::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + lock::Mutex, + net::{TcpListener, TcpStream}, +}; + +use ws_stream_tungstenite::WsStream; + +use crate::{ + connection::{Connection, ToConn}, + endpoint::Endpoint, + listener::{ConnListener, ToListener}, + Error, Result, +}; + +/// WS network connection implementation of the [`Connection`] trait. +pub struct WsConn { + inner: TcpStream, + read: Mutex>>, + write: Mutex>>, +} + +impl WsConn { + /// Creates a new WsConn + pub fn new(inner: TcpStream, conn: WsStream) -> Self { + let (read, write) = split(conn); + Self { + inner, + read: Mutex::new(read), + write: Mutex::new(write), + } + } +} + +#[async_trait] +impl Connection for WsConn { + fn peer_endpoint(&self) -> Result { + Ok(Endpoint::new_ws_addr(&self.inner.peer_addr()?)) + } + + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + } + + async fn read(&self, buf: &mut [u8]) -> Result { + self.read.lock().await.read(buf).await.map_err(Error::from) + } + + async fn write(&self, buf: &[u8]) -> Result { + self.write + .lock() + .await + .write(buf) + .await + .map_err(Error::from) + } +} + +/// Ws network listener implementation of the `Listener` [`ConnListener`] trait. +pub struct WsListener { + inner: TcpListener, +} + +#[async_trait] +impl ConnListener for WsListener { + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + } + + async fn accept(&self) -> Result> { + let (stream, _) = self.inner.accept().await?; + let conn = async_tungstenite::accept_async(stream.clone()).await?; + Ok(Box::new(WsConn::new(stream, WsStream::new(conn)))) + } +} + +/// Connects to the given WS address and port. +pub async fn dial(endpoint: &Endpoint) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; + let stream = TcpStream::connect(addr).await?; + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), stream.clone()).await?; + Ok(WsConn::new(stream, WsStream::new(conn))) +} + +/// Listens on the given WS address and port. +pub async fn listen(endpoint: &Endpoint) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; + let listener = TcpListener::bind(addr).await?; + Ok(WsListener { inner: listener }) +} + +impl From for Box { + fn from(listener: WsListener) -> Self { + Box::new(listener) + } +} + +impl ToConn for WsConn { + fn to_conn(self) -> Box { + Box::new(self) + } +} + +impl ToListener for WsListener { + fn to_listener(self) -> Box { + self.into() + } +} -- cgit v1.2.3