From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- net/src/stream/buffer.rs | 82 +++++++++++++++++++ net/src/stream/mod.rs | 191 ++++++++++++++++++++++++++++++++++++++++++++ net/src/stream/websocket.rs | 107 +++++++++++++++++++++++++ 3 files changed, 380 insertions(+) create mode 100644 net/src/stream/buffer.rs create mode 100644 net/src/stream/mod.rs create mode 100644 net/src/stream/websocket.rs (limited to 'net/src/stream') diff --git a/net/src/stream/buffer.rs b/net/src/stream/buffer.rs new file mode 100644 index 0000000..f211600 --- /dev/null +++ b/net/src/stream/buffer.rs @@ -0,0 +1,82 @@ +#[derive(Debug)] +pub struct Buffer { + inner: B, + len: usize, + cap: usize, +} + +impl Buffer +where + B: AsMut<[u8]> + AsRef<[u8]>, +{ + /// Constructs a new, empty Buffer. + pub fn new(b: B) -> Self { + Self { + cap: b.as_ref().len(), + inner: b, + len: 0, + } + } + + /// Returns the number of elements in the buffer. + #[allow(dead_code)] + pub fn len(&self) -> usize { + self.len + } + + /// Resizes the buffer in-place so that `len` is equal to `new_size`. + pub fn resize(&mut self, new_size: usize) { + assert!(self.cap > new_size); + self.len = new_size; + } + + /// Appends all elements in a slice to the buffer. + pub fn extend_from_slice(&mut self, bytes: &[u8]) { + let old_len = self.len; + self.resize(self.len + bytes.len()); + self.inner.as_mut()[old_len..bytes.len() + old_len].copy_from_slice(bytes); + } + + /// Shortens the buffer, dropping the first `cnt` bytes and keeping the + /// rest. + pub fn advance(&mut self, cnt: usize) { + assert!(self.len >= cnt); + self.inner.as_mut().rotate_left(cnt); + self.resize(self.len - cnt); + } + + /// Returns `true` if the buffer contains no elements. + pub fn is_empty(&self) -> bool { + self.len == 0 + } +} + +impl AsMut<[u8]> for Buffer +where + B: AsMut<[u8]> + AsRef<[u8]>, +{ + fn as_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[..self.len] + } +} + +impl AsRef<[u8]> for Buffer +where + B: AsMut<[u8]> + AsRef<[u8]>, +{ + fn as_ref(&self) -> &[u8] { + &self.inner.as_ref()[..self.len] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_buffer_advance() { + let mut buf = Buffer::new([0u8; 32]); + buf.extend_from_slice(&[1, 2, 3]); + assert_eq!([1, 2, 3], buf.as_ref()); + } +} diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs new file mode 100644 index 0000000..9493b29 --- /dev/null +++ b/net/src/stream/mod.rs @@ -0,0 +1,191 @@ +mod buffer; +mod websocket; + +pub use websocket::WsStream; + +use std::{ + io::ErrorKind, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{ + ready, + stream::{Stream, StreamExt}, + Sink, +}; +use pin_project_lite::pin_project; + +use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite}; + +use crate::{ + codec::{Decoder, Encoder}, + Error, Result, +}; + +use buffer::Buffer; + +const BUFFER_SIZE: usize = 2048 * 2024; // 4MB +const INITIAL_BUFFER_SIZE: usize = 1024 * 1024; // 1MB + +pub struct ReadStream { + inner: T, + decoder: C, + buffer: Buffer<[u8; BUFFER_SIZE]>, +} + +impl ReadStream +where + T: AsyncRead + Unpin, + C: Decoder + Unpin, +{ + pub fn new(inner: T, decoder: C) -> Self { + Self { + inner, + decoder, + buffer: Buffer::new([0u8; BUFFER_SIZE]), + } + } + + pub async fn recv(&mut self) -> Result { + match self.next().await { + Some(m) => m, + None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), + } + } +} + +pin_project! { + pub struct WriteStream { + #[pin] + inner: T, + encoder: C, + high_water_mark: usize, + buffer: Buffer<[u8; BUFFER_SIZE]>, + } +} + +impl WriteStream +where + T: AsyncWrite + Unpin, + C: Encoder + Unpin, +{ + pub fn new(inner: T, encoder: C) -> Self { + Self { + inner, + encoder, + high_water_mark: 131072, + buffer: Buffer::new([0u8; BUFFER_SIZE]), + } + } +} + +impl Stream for ReadStream +where + T: AsyncRead + Unpin, + C: Decoder + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + + if let Some((n, item)) = this.decoder.decode(this.buffer.as_mut())? { + this.buffer.advance(n); + return Poll::Ready(Some(Ok(item))); + } + + let mut buf = [0u8; INITIAL_BUFFER_SIZE]; + #[cfg(feature = "tokio")] + let mut buf = tokio::io::ReadBuf::new(&mut buf); + + loop { + #[cfg(feature = "smol")] + let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?; + #[cfg(feature = "smol")] + let bytes = &buf[..n]; + + #[cfg(feature = "tokio")] + ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?; + #[cfg(feature = "tokio")] + let bytes = buf.filled(); + #[cfg(feature = "tokio")] + let n = bytes.len(); + + this.buffer.extend_from_slice(bytes); + + match this.decoder.decode(this.buffer.as_mut())? { + Some((cn, item)) => { + this.buffer.advance(cn); + return Poll::Ready(Some(Ok(item))); + } + None if n == 0 => { + if this.buffer.is_empty() { + return Poll::Ready(None); + } else { + return Poll::Ready(Some(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "bytes remaining in read stream", + ) + .into()))); + } + } + _ => continue, + } + } + } +} + +impl Sink for WriteStream +where + T: AsyncWrite + Unpin, + C: Encoder + Unpin, +{ + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = &mut *self; + while !this.buffer.is_empty() { + let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?; + + if n == 0 { + return Poll::Ready(Err(std::io::Error::new( + ErrorKind::UnexpectedEof, + "End of file", + ) + .into())); + } + + this.buffer.advance(n); + } + + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: C::EnItem) -> Result<()> { + let this = &mut *self; + let mut buf = [0u8; INITIAL_BUFFER_SIZE]; + let n = this.encoder.encode(&item, &mut buf)?; + this.buffer.extend_from_slice(&buf[..n]); + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + ready!(self.as_mut().poll_ready(cx))?; + self.project().inner.poll_flush(cx).map_err(Into::into) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + #[cfg(feature = "smol")] + return self.project().inner.poll_close(cx).map_err(Error::from); + #[cfg(feature = "tokio")] + return self.project().inner.poll_shutdown(cx).map_err(Error::from); + } +} diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs new file mode 100644 index 0000000..2552eaf --- /dev/null +++ b/net/src/stream/websocket.rs @@ -0,0 +1,107 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_tungstenite::tungstenite::Message; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; + +#[cfg(feature = "smol")] +use futures_rustls::TlsStream; +#[cfg(feature = "tokio")] +use tokio_rustls::TlsStream; + +use karyon_core::async_runtime::net::TcpStream; + +use crate::{codec::WebSocketCodec, Error, Result}; + +#[cfg(feature = "tokio")] +type WebSocketStream = + async_tungstenite::WebSocketStream>; +#[cfg(feature = "smol")] +use async_tungstenite::WebSocketStream; + +pub struct WsStream { + inner: InnerWSConn, + codec: C, +} + +impl WsStream +where + C: WebSocketCodec, +{ + pub fn new_ws(conn: WebSocketStream, codec: C) -> Self { + Self { + inner: InnerWSConn::Plain(conn), + codec, + } + } + + pub fn new_wss(conn: WebSocketStream>, codec: C) -> Self { + Self { + inner: InnerWSConn::Tls(conn), + codec, + } + } + + pub async fn recv(&mut self) -> Result { + match self.inner.next().await { + Some(msg) => self.codec.decode(&msg?), + None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), + } + } + + pub async fn send(&mut self, msg: C::Item) -> Result<()> { + let ws_msg = self.codec.encode(&msg)?; + self.inner.send(ws_msg).await + } +} + +enum InnerWSConn { + Plain(WebSocketStream), + Tls(WebSocketStream>), +} + +impl Sink for InnerWSConn { + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + } + .map_err(Error::from) + } +} + +impl Stream for InnerWSConn { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + } + } +} -- cgit v1.2.3