From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- net/src/stream/websocket.rs | 107 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 net/src/stream/websocket.rs (limited to 'net/src/stream/websocket.rs') diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs new file mode 100644 index 0000000..2552eaf --- /dev/null +++ b/net/src/stream/websocket.rs @@ -0,0 +1,107 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_tungstenite::tungstenite::Message; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; + +#[cfg(feature = "smol")] +use futures_rustls::TlsStream; +#[cfg(feature = "tokio")] +use tokio_rustls::TlsStream; + +use karyon_core::async_runtime::net::TcpStream; + +use crate::{codec::WebSocketCodec, Error, Result}; + +#[cfg(feature = "tokio")] +type WebSocketStream = + async_tungstenite::WebSocketStream>; +#[cfg(feature = "smol")] +use async_tungstenite::WebSocketStream; + +pub struct WsStream { + inner: InnerWSConn, + codec: C, +} + +impl WsStream +where + C: WebSocketCodec, +{ + pub fn new_ws(conn: WebSocketStream, codec: C) -> Self { + Self { + inner: InnerWSConn::Plain(conn), + codec, + } + } + + pub fn new_wss(conn: WebSocketStream>, codec: C) -> Self { + Self { + inner: InnerWSConn::Tls(conn), + codec, + } + } + + pub async fn recv(&mut self) -> Result { + match self.inner.next().await { + Some(msg) => self.codec.decode(&msg?), + None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), + } + } + + pub async fn send(&mut self, msg: C::Item) -> Result<()> { + let ws_msg = self.codec.encode(&msg)?; + self.inner.send(ws_msg).await + } +} + +enum InnerWSConn { + Plain(WebSocketStream), + Tls(WebSocketStream>), +} + +impl Sink for InnerWSConn { + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + } + .map_err(Error::from) + } +} + +impl Stream for InnerWSConn { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + } + } +} -- cgit v1.2.3