aboutsummaryrefslogtreecommitdiff
path: root/net/src/stream/websocket.rs
diff options
context:
space:
mode:
Diffstat (limited to 'net/src/stream/websocket.rs')
-rw-r--r--net/src/stream/websocket.rs77
1 files changed, 75 insertions, 2 deletions
diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs
index 2626d2f..eb4de21 100644
--- a/net/src/stream/websocket.rs
+++ b/net/src/stream/websocket.rs
@@ -4,7 +4,11 @@ use std::{
};
use async_tungstenite::tungstenite::Message;
-use futures_util::{Sink, SinkExt, Stream, StreamExt};
+use futures_util::{
+ stream::{SplitSink, SplitStream},
+ Sink, SinkExt, Stream, StreamExt, TryStreamExt,
+};
+use pin_project_lite::pin_project;
#[cfg(all(feature = "smol", feature = "tls"))]
use futures_rustls::TlsStream;
@@ -28,7 +32,7 @@ pub struct WsStream<C> {
impl<C> WsStream<C>
where
- C: WebSocketCodec,
+ C: WebSocketCodec + Clone,
{
pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self {
Self {
@@ -45,6 +49,42 @@ where
}
}
+ pub fn split(self) -> (ReadWsStream<C>, WriteWsStream<C>) {
+ let (write, read) = self.inner.split();
+
+ (
+ ReadWsStream {
+ codec: self.codec.clone(),
+ inner: read,
+ },
+ WriteWsStream {
+ inner: write,
+ codec: self.codec,
+ },
+ )
+ }
+}
+
+pin_project! {
+ pub struct ReadWsStream<C> {
+ #[pin]
+ inner: SplitStream<InnerWSConn>,
+ codec: C,
+ }
+}
+
+pin_project! {
+ pub struct WriteWsStream<C> {
+ #[pin]
+ inner: SplitSink<InnerWSConn, Message>,
+ codec: C,
+ }
+}
+
+impl<C> ReadWsStream<C>
+where
+ C: WebSocketCodec,
+{
pub async fn recv(&mut self) -> Result<C::Item> {
match self.inner.next().await {
Some(msg) => match self.codec.decode(&msg?)? {
@@ -54,13 +94,46 @@ where
None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())),
}
}
+}
+impl<C> WriteWsStream<C>
+where
+ C: WebSocketCodec,
+{
pub async fn send(&mut self, msg: C::Item) -> Result<()> {
let ws_msg = self.codec.encode(&msg)?;
self.inner.send(ws_msg).await
}
}
+impl<C> Sink<Message> for WriteWsStream<C> {
+ type Error = Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ self.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: Message) -> Result<()> {
+ self.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ self.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ self.project().inner.poll_close(cx)
+ }
+}
+
+impl<C> Stream for ReadWsStream<C> {
+ type Item = Result<Message>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.inner.try_poll_next_unpin(cx)
+ }
+}
+
enum InnerWSConn {
Plain(WebSocketStream<TcpStream>),
#[cfg(feature = "tls")]