aboutsummaryrefslogtreecommitdiff
path: root/net/src
diff options
context:
space:
mode:
Diffstat (limited to 'net/src')
-rw-r--r--net/src/stream/mod.rs2
-rw-r--r--net/src/stream/websocket.rs77
-rw-r--r--net/src/transports/ws.rs18
3 files changed, 86 insertions, 11 deletions
diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs
index ce48a77..a9aa1ef 100644
--- a/net/src/stream/mod.rs
+++ b/net/src/stream/mod.rs
@@ -3,7 +3,7 @@ mod buffer;
mod websocket;
#[cfg(feature = "ws")]
-pub use websocket::WsStream;
+pub use websocket::{WsStream, WriteWsStream, ReadWsStream};
use std::{
io::ErrorKind,
diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs
index 2626d2f..eb4de21 100644
--- a/net/src/stream/websocket.rs
+++ b/net/src/stream/websocket.rs
@@ -4,7 +4,11 @@ use std::{
};
use async_tungstenite::tungstenite::Message;
-use futures_util::{Sink, SinkExt, Stream, StreamExt};
+use futures_util::{
+ stream::{SplitSink, SplitStream},
+ Sink, SinkExt, Stream, StreamExt, TryStreamExt,
+};
+use pin_project_lite::pin_project;
#[cfg(all(feature = "smol", feature = "tls"))]
use futures_rustls::TlsStream;
@@ -28,7 +32,7 @@ pub struct WsStream<C> {
impl<C> WsStream<C>
where
- C: WebSocketCodec,
+ C: WebSocketCodec + Clone,
{
pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self {
Self {
@@ -45,6 +49,42 @@ where
}
}
+ pub fn split(self) -> (ReadWsStream<C>, WriteWsStream<C>) {
+ let (write, read) = self.inner.split();
+
+ (
+ ReadWsStream {
+ codec: self.codec.clone(),
+ inner: read,
+ },
+ WriteWsStream {
+ inner: write,
+ codec: self.codec,
+ },
+ )
+ }
+}
+
+pin_project! {
+ pub struct ReadWsStream<C> {
+ #[pin]
+ inner: SplitStream<InnerWSConn>,
+ codec: C,
+ }
+}
+
+pin_project! {
+ pub struct WriteWsStream<C> {
+ #[pin]
+ inner: SplitSink<InnerWSConn, Message>,
+ codec: C,
+ }
+}
+
+impl<C> ReadWsStream<C>
+where
+ C: WebSocketCodec,
+{
pub async fn recv(&mut self) -> Result<C::Item> {
match self.inner.next().await {
Some(msg) => match self.codec.decode(&msg?)? {
@@ -54,13 +94,46 @@ where
None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())),
}
}
+}
+impl<C> WriteWsStream<C>
+where
+ C: WebSocketCodec,
+{
pub async fn send(&mut self, msg: C::Item) -> Result<()> {
let ws_msg = self.codec.encode(&msg)?;
self.inner.send(ws_msg).await
}
}
+impl<C> Sink<Message> for WriteWsStream<C> {
+ type Error = Error;
+
+ fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ self.project().inner.poll_ready(cx)
+ }
+
+ fn start_send(self: Pin<&mut Self>, item: Message) -> Result<()> {
+ self.project().inner.start_send(item)
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ self.project().inner.poll_flush(cx)
+ }
+
+ fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ self.project().inner.poll_close(cx)
+ }
+}
+
+impl<C> Stream for ReadWsStream<C> {
+ type Item = Result<Message>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ self.inner.try_poll_next_unpin(cx)
+ }
+}
+
enum InnerWSConn {
Plain(WebSocketStream<TcpStream>),
#[cfg(feature = "tls")]
diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs
index 6107999..f2fde90 100644
--- a/net/src/transports/ws.rs
+++ b/net/src/transports/ws.rs
@@ -25,7 +25,7 @@ use crate::{
connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
listener::{ConnListener, Listener, ToListener},
- stream::WsStream,
+ stream::{ReadWsStream, WriteWsStream, WsStream},
Result,
};
@@ -62,20 +62,22 @@ pub struct ClientWsConfig {
/// WS network connection implementation of the [`Connection`] trait.
pub struct WsConn<C> {
- // XXX: remove mutex
- inner: Mutex<WsStream<C>>,
+ read_stream: Mutex<ReadWsStream<C>>,
+ write_stream: Mutex<WriteWsStream<C>>,
peer_endpoint: Endpoint,
local_endpoint: Endpoint,
}
impl<C> WsConn<C>
where
- C: WebSocketCodec,
+ C: WebSocketCodec + Clone,
{
/// Creates a new WsConn
pub fn new(ws: WsStream<C>, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self {
+ let (read, write) = ws.split();
Self {
- inner: Mutex::new(ws),
+ read_stream: Mutex::new(read),
+ write_stream: Mutex::new(write),
peer_endpoint,
local_endpoint,
}
@@ -97,11 +99,11 @@ where
}
async fn recv(&self) -> Result<Self::Item> {
- self.inner.lock().await.recv().await
+ self.read_stream.lock().await.recv().await
}
async fn send(&self, msg: Self::Item) -> Result<()> {
- self.inner.lock().await.send(msg).await
+ self.write_stream.lock().await.send(msg).await
}
}
@@ -169,7 +171,7 @@ where
/// Connects to the given WS address and port.
pub async fn dial<C>(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result<WsConn<C>>
where
- C: WebSocketCodec,
+ C: WebSocketCodec + Clone,
{
let addr = SocketAddr::try_from(endpoint.clone())?;
let socket = TcpStream::connect(addr).await?;