diff options
author | hozan23 <hozan23@karyontech.net> | 2024-04-11 10:19:20 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-19 13:51:30 +0200 |
commit | 0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch) | |
tree | 961d73218af672797d49f899289bef295bc56493 /net/src/stream/websocket.rs | |
parent | a69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff) |
add support for tokio & improve net crate api
Diffstat (limited to 'net/src/stream/websocket.rs')
-rw-r--r-- | net/src/stream/websocket.rs | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs new file mode 100644 index 0000000..2552eaf --- /dev/null +++ b/net/src/stream/websocket.rs @@ -0,0 +1,107 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_tungstenite::tungstenite::Message; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; + +#[cfg(feature = "smol")] +use futures_rustls::TlsStream; +#[cfg(feature = "tokio")] +use tokio_rustls::TlsStream; + +use karyon_core::async_runtime::net::TcpStream; + +use crate::{codec::WebSocketCodec, Error, Result}; + +#[cfg(feature = "tokio")] +type WebSocketStream<T> = + async_tungstenite::WebSocketStream<async_tungstenite::tokio::TokioAdapter<T>>; +#[cfg(feature = "smol")] +use async_tungstenite::WebSocketStream; + +pub struct WsStream<C> { + inner: InnerWSConn, + codec: C, +} + +impl<C> WsStream<C> +where + C: WebSocketCodec, +{ + pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self { + Self { + inner: InnerWSConn::Plain(conn), + codec, + } + } + + pub fn new_wss(conn: WebSocketStream<TlsStream<TcpStream>>, codec: C) -> Self { + Self { + inner: InnerWSConn::Tls(conn), + codec, + } + } + + pub async fn recv(&mut self) -> Result<C::Item> { + match self.inner.next().await { + Some(msg) => self.codec.decode(&msg?), + None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), + } + } + + pub async fn send(&mut self, msg: C::Item) -> Result<()> { + let ws_msg = self.codec.encode(&msg)?; + self.inner.send(ws_msg).await + } +} + +enum InnerWSConn { + Plain(WebSocketStream<TcpStream>), + Tls(WebSocketStream<TlsStream<TcpStream>>), +} + +impl Sink<Message> for InnerWSConn { + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + } + .map_err(Error::from) + } +} + +impl Stream for InnerWSConn { + type Item = Result<Message>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + } + } +} |