aboutsummaryrefslogtreecommitdiff
path: root/net/src/stream/websocket.rs
diff options
context:
space:
mode:
Diffstat (limited to 'net/src/stream/websocket.rs')
-rw-r--r--net/src/stream/websocket.rs107
1 files changed, 107 insertions, 0 deletions
diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs
new file mode 100644
index 0000000..2552eaf
--- /dev/null
+++ b/net/src/stream/websocket.rs
@@ -0,0 +1,107 @@
+use std::{
+ pin::Pin,
+ task::{Context, Poll},
+};
+
+use async_tungstenite::tungstenite::Message;
+use futures_util::{Sink, SinkExt, Stream, StreamExt};
+
+#[cfg(feature = "smol")]
+use futures_rustls::TlsStream;
+#[cfg(feature = "tokio")]
+use tokio_rustls::TlsStream;
+
+use karyon_core::async_runtime::net::TcpStream;
+
+use crate::{codec::WebSocketCodec, Error, Result};
+
+#[cfg(feature = "tokio")]
+type WebSocketStream<T> =
+ async_tungstenite::WebSocketStream<async_tungstenite::tokio::TokioAdapter<T>>;
+#[cfg(feature = "smol")]
+use async_tungstenite::WebSocketStream;
+
+pub struct WsStream<C> {
+ inner: InnerWSConn,
+ codec: C,
+}
+
+impl<C> WsStream<C>
+where
+ C: WebSocketCodec,
+{
+ pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self {
+ Self {
+ inner: InnerWSConn::Plain(conn),
+ codec,
+ }
+ }
+
+ pub fn new_wss(conn: WebSocketStream<TlsStream<TcpStream>>, codec: C) -> Self {
+ Self {
+ inner: InnerWSConn::Tls(conn),
+ codec,
+ }
+ }
+
+ pub async fn recv(&mut self) -> Result<C::Item> {
+ match self.inner.next().await {
+ Some(msg) => self.codec.decode(&msg?),
+ None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())),
+ }
+ }
+
+ pub async fn send(&mut self, msg: C::Item) -> Result<()> {
+ let ws_msg = self.codec.encode(&msg)?;
+ self.inner.send(ws_msg).await
+ }
+}
+
+enum InnerWSConn {
+ Plain(WebSocketStream<TcpStream>),
+ Tls(WebSocketStream<TlsStream<TcpStream>>),
+}
+
+impl Sink<Message> for InnerWSConn {
+ type Error = Error;
+
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from),
+ }
+ }
+
+ fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from),
+ }
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from),
+ }
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from),
+ }
+ .map_err(Error::from)
+ }
+}
+
+impl Stream for InnerWSConn {
+ type Item = Result<Message>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from),
+ }
+ }
+}