From 340957fec147f4429796413f27bbd9b84ba6f141 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 14 Mar 2024 17:01:59 +0100 Subject: net: add support for websocket protocol --- net/Cargo.toml | 2 + net/src/error.rs | 3 ++ net/src/lib.rs | 2 +- net/src/transports/mod.rs | 1 + net/src/transports/ws.rs | 112 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 net/src/transports/ws.rs (limited to 'net') diff --git a/net/Cargo.toml b/net/Cargo.toml index 0b6534c..fe209cd 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -15,3 +15,5 @@ bincode = { version="2.0.0-rc.3", features = ["derive"]} thiserror = "1.0.58" url = "2.5.0" futures-rustls = "0.25.1" +async-tungstenite = "0.25.0" +ws_stream_tungstenite = "0.13.0" diff --git a/net/src/error.rs b/net/src/error.rs index be90a03..6e04a12 100644 --- a/net/src/error.rs +++ b/net/src/error.rs @@ -28,6 +28,9 @@ pub enum Error { #[error(transparent)] ChannelRecv(#[from] smol::channel::RecvError), + #[error("Ws Error: {0}")] + WsError(#[from] async_tungstenite::tungstenite::Error), + #[error("Tls Error: {0}")] Rustls(#[from] futures_rustls::rustls::Error), diff --git a/net/src/lib.rs b/net/src/lib.rs index 5f1c8a6..c1d72b2 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -8,7 +8,7 @@ pub use { connection::{dial, Conn, Connection, ToConn}, endpoint::{Addr, Endpoint, Port}, listener::{listen, ConnListener, Listener, ToListener}, - transports::{tcp, tls, udp, unix}, + transports::{tcp, tls, udp, unix, ws}, }; use error::{Error, Result}; diff --git a/net/src/transports/mod.rs b/net/src/transports/mod.rs index ac23021..14ef6f3 100644 --- a/net/src/transports/mod.rs +++ b/net/src/transports/mod.rs @@ -2,3 +2,4 @@ pub mod tcp; pub mod tls; pub mod udp; pub mod unix; +pub mod ws; diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs new file mode 100644 index 0000000..eaf3b9b --- /dev/null +++ b/net/src/transports/ws.rs @@ -0,0 +1,112 @@ +use std::net::SocketAddr; + +use async_trait::async_trait; +use smol::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + lock::Mutex, + net::{TcpListener, TcpStream}, +}; + +use ws_stream_tungstenite::WsStream; + +use crate::{ + connection::{Connection, ToConn}, + endpoint::Endpoint, + listener::{ConnListener, ToListener}, + Error, Result, +}; + +/// WS network connection implementation of the [`Connection`] trait. +pub struct WsConn { + inner: TcpStream, + read: Mutex>>, + write: Mutex>>, +} + +impl WsConn { + /// Creates a new WsConn + pub fn new(inner: TcpStream, conn: WsStream) -> Self { + let (read, write) = split(conn); + Self { + inner, + read: Mutex::new(read), + write: Mutex::new(write), + } + } +} + +#[async_trait] +impl Connection for WsConn { + fn peer_endpoint(&self) -> Result { + Ok(Endpoint::new_ws_addr(&self.inner.peer_addr()?)) + } + + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + } + + async fn read(&self, buf: &mut [u8]) -> Result { + self.read.lock().await.read(buf).await.map_err(Error::from) + } + + async fn write(&self, buf: &[u8]) -> Result { + self.write + .lock() + .await + .write(buf) + .await + .map_err(Error::from) + } +} + +/// Ws network listener implementation of the `Listener` [`ConnListener`] trait. +pub struct WsListener { + inner: TcpListener, +} + +#[async_trait] +impl ConnListener for WsListener { + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + } + + async fn accept(&self) -> Result> { + let (stream, _) = self.inner.accept().await?; + let conn = async_tungstenite::accept_async(stream.clone()).await?; + Ok(Box::new(WsConn::new(stream, WsStream::new(conn)))) + } +} + +/// Connects to the given WS address and port. +pub async fn dial(endpoint: &Endpoint) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; + let stream = TcpStream::connect(addr).await?; + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), stream.clone()).await?; + Ok(WsConn::new(stream, WsStream::new(conn))) +} + +/// Listens on the given WS address and port. +pub async fn listen(endpoint: &Endpoint) -> Result { + let addr = SocketAddr::try_from(endpoint.clone())?; + let listener = TcpListener::bind(addr).await?; + Ok(WsListener { inner: listener }) +} + +impl From for Box { + fn from(listener: WsListener) -> Self { + Box::new(listener) + } +} + +impl ToConn for WsConn { + fn to_conn(self) -> Box { + Box::new(self) + } +} + +impl ToListener for WsListener { + fn to_listener(self) -> Box { + self.into() + } +} -- cgit v1.2.3