diff options
author | hozan23 <hozan23@karyontech.net> | 2024-05-27 00:49:25 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-27 00:49:25 +0200 |
commit | 385d53ec53e750e342cce78edb793958edf5133e (patch) | |
tree | 897861166ed9f3d6add3c6ea4bd89da48985ea11 /net | |
parent | 3acb724ba1aafeaf37e24dada7c769bb4066444a (diff) |
net: finish TODOs in websocket implemention & clean up
Diffstat (limited to 'net')
-rw-r--r-- | net/src/stream/mod.rs | 2 | ||||
-rw-r--r-- | net/src/stream/websocket.rs | 77 | ||||
-rw-r--r-- | net/src/transports/ws.rs | 18 |
3 files changed, 86 insertions, 11 deletions
diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs index ce48a77..a9aa1ef 100644 --- a/net/src/stream/mod.rs +++ b/net/src/stream/mod.rs @@ -3,7 +3,7 @@ mod buffer; mod websocket; #[cfg(feature = "ws")] -pub use websocket::WsStream; +pub use websocket::{WsStream, WriteWsStream, ReadWsStream}; use std::{ io::ErrorKind, diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs index 2626d2f..eb4de21 100644 --- a/net/src/stream/websocket.rs +++ b/net/src/stream/websocket.rs @@ -4,7 +4,11 @@ use std::{ }; use async_tungstenite::tungstenite::Message; -use futures_util::{Sink, SinkExt, Stream, StreamExt}; +use futures_util::{ + stream::{SplitSink, SplitStream}, + Sink, SinkExt, Stream, StreamExt, TryStreamExt, +}; +use pin_project_lite::pin_project; #[cfg(all(feature = "smol", feature = "tls"))] use futures_rustls::TlsStream; @@ -28,7 +32,7 @@ pub struct WsStream<C> { impl<C> WsStream<C> where - C: WebSocketCodec, + C: WebSocketCodec + Clone, { pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self { Self { @@ -45,6 +49,42 @@ where } } + pub fn split(self) -> (ReadWsStream<C>, WriteWsStream<C>) { + let (write, read) = self.inner.split(); + + ( + ReadWsStream { + codec: self.codec.clone(), + inner: read, + }, + WriteWsStream { + inner: write, + codec: self.codec, + }, + ) + } +} + +pin_project! { + pub struct ReadWsStream<C> { + #[pin] + inner: SplitStream<InnerWSConn>, + codec: C, + } +} + +pin_project! { + pub struct WriteWsStream<C> { + #[pin] + inner: SplitSink<InnerWSConn, Message>, + codec: C, + } +} + +impl<C> ReadWsStream<C> +where + C: WebSocketCodec, +{ pub async fn recv(&mut self) -> Result<C::Item> { match self.inner.next().await { Some(msg) => match self.codec.decode(&msg?)? { @@ -54,13 +94,46 @@ where None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), } } +} +impl<C> WriteWsStream<C> +where + C: WebSocketCodec, +{ pub async fn send(&mut self, msg: C::Item) -> Result<()> { let ws_msg = self.codec.encode(&msg)?; self.inner.send(ws_msg).await } } +impl<C> Sink<Message> for WriteWsStream<C> { + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Message) -> Result<()> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { + self.project().inner.poll_close(cx) + } +} + +impl<C> Stream for ReadWsStream<C> { + type Item = Result<Message>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + self.inner.try_poll_next_unpin(cx) + } +} + enum InnerWSConn { Plain(WebSocketStream<TcpStream>), #[cfg(feature = "tls")] diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs index 6107999..f2fde90 100644 --- a/net/src/transports/ws.rs +++ b/net/src/transports/ws.rs @@ -25,7 +25,7 @@ use crate::{ connection::{Conn, Connection, ToConn}, endpoint::Endpoint, listener::{ConnListener, Listener, ToListener}, - stream::WsStream, + stream::{ReadWsStream, WriteWsStream, WsStream}, Result, }; @@ -62,20 +62,22 @@ pub struct ClientWsConfig { /// WS network connection implementation of the [`Connection`] trait. pub struct WsConn<C> { - // XXX: remove mutex - inner: Mutex<WsStream<C>>, + read_stream: Mutex<ReadWsStream<C>>, + write_stream: Mutex<WriteWsStream<C>>, peer_endpoint: Endpoint, local_endpoint: Endpoint, } impl<C> WsConn<C> where - C: WebSocketCodec, + C: WebSocketCodec + Clone, { /// Creates a new WsConn pub fn new(ws: WsStream<C>, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self { + let (read, write) = ws.split(); Self { - inner: Mutex::new(ws), + read_stream: Mutex::new(read), + write_stream: Mutex::new(write), peer_endpoint, local_endpoint, } @@ -97,11 +99,11 @@ where } async fn recv(&self) -> Result<Self::Item> { - self.inner.lock().await.recv().await + self.read_stream.lock().await.recv().await } async fn send(&self, msg: Self::Item) -> Result<()> { - self.inner.lock().await.send(msg).await + self.write_stream.lock().await.send(msg).await } } @@ -169,7 +171,7 @@ where /// Connects to the given WS address and port. pub async fn dial<C>(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result<WsConn<C>> where - C: WebSocketCodec, + C: WebSocketCodec + Clone, { let addr = SocketAddr::try_from(endpoint.clone())?; let socket = TcpStream::connect(addr).await?; |