From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- net/src/stream/mod.rs | 191 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 net/src/stream/mod.rs (limited to 'net/src/stream/mod.rs') diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs new file mode 100644 index 0000000..9493b29 --- /dev/null +++ b/net/src/stream/mod.rs @@ -0,0 +1,191 @@ +mod buffer; +mod websocket; + +pub use websocket::WsStream; + +use std::{ + io::ErrorKind, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{ + ready, + stream::{Stream, StreamExt}, + Sink, +}; +use pin_project_lite::pin_project; + +use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite}; + +use crate::{ + codec::{Decoder, Encoder}, + Error, Result, +}; + +use buffer::Buffer; + +const BUFFER_SIZE: usize = 2048 * 2024; // 4MB +const INITIAL_BUFFER_SIZE: usize = 1024 * 1024; // 1MB + +pub struct ReadStream { + inner: T, + decoder: C, + buffer: Buffer<[u8; BUFFER_SIZE]>, +} + +impl ReadStream +where + T: AsyncRead + Unpin, + C: Decoder + Unpin, +{ + pub fn new(inner: T, decoder: C) -> Self { + Self { + inner, + decoder, + buffer: Buffer::new([0u8; BUFFER_SIZE]), + } + } + + pub async fn recv(&mut self) -> Result { + match self.next().await { + Some(m) => m, + None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), + } + } +} + +pin_project! { + pub struct WriteStream { + #[pin] + inner: T, + encoder: C, + high_water_mark: usize, + buffer: Buffer<[u8; BUFFER_SIZE]>, + } +} + +impl WriteStream +where + T: AsyncWrite + Unpin, + C: Encoder + Unpin, +{ + pub fn new(inner: T, encoder: C) -> Self { + Self { + inner, + encoder, + high_water_mark: 131072, + buffer: Buffer::new([0u8; BUFFER_SIZE]), + } + } +} + +impl Stream for ReadStream +where + T: AsyncRead + Unpin, + C: Decoder + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + + if let Some((n, item)) = this.decoder.decode(this.buffer.as_mut())? { + this.buffer.advance(n); + return Poll::Ready(Some(Ok(item))); + } + + let mut buf = [0u8; INITIAL_BUFFER_SIZE]; + #[cfg(feature = "tokio")] + let mut buf = tokio::io::ReadBuf::new(&mut buf); + + loop { + #[cfg(feature = "smol")] + let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?; + #[cfg(feature = "smol")] + let bytes = &buf[..n]; + + #[cfg(feature = "tokio")] + ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?; + #[cfg(feature = "tokio")] + let bytes = buf.filled(); + #[cfg(feature = "tokio")] + let n = bytes.len(); + + this.buffer.extend_from_slice(bytes); + + match this.decoder.decode(this.buffer.as_mut())? { + Some((cn, item)) => { + this.buffer.advance(cn); + return Poll::Ready(Some(Ok(item))); + } + None if n == 0 => { + if this.buffer.is_empty() { + return Poll::Ready(None); + } else { + return Poll::Ready(Some(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "bytes remaining in read stream", + ) + .into()))); + } + } + _ => continue, + } + } + } +} + +impl Sink for WriteStream +where + T: AsyncWrite + Unpin, + C: Encoder + Unpin, +{ + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = &mut *self; + while !this.buffer.is_empty() { + let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?; + + if n == 0 { + return Poll::Ready(Err(std::io::Error::new( + ErrorKind::UnexpectedEof, + "End of file", + ) + .into())); + } + + this.buffer.advance(n); + } + + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: C::EnItem) -> Result<()> { + let this = &mut *self; + let mut buf = [0u8; INITIAL_BUFFER_SIZE]; + let n = this.encoder.encode(&item, &mut buf)?; + this.buffer.extend_from_slice(&buf[..n]); + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + ready!(self.as_mut().poll_ready(cx))?; + self.project().inner.poll_flush(cx).map_err(Into::into) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + #[cfg(feature = "smol")] + return self.project().inner.poll_close(cx).map_err(Error::from); + #[cfg(feature = "tokio")] + return self.project().inner.poll_shutdown(cx).map_err(Error::from); + } +} -- cgit v1.2.3