aboutsummaryrefslogtreecommitdiff
path: root/net/src/stream
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-04-11 10:19:20 +0200
committerhozan23 <hozan23@karyontech.net>2024-05-19 13:51:30 +0200
commit0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch)
tree961d73218af672797d49f899289bef295bc56493 /net/src/stream
parenta69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff)
add support for tokio & improve net crate api
Diffstat (limited to 'net/src/stream')
-rw-r--r--net/src/stream/buffer.rs82
-rw-r--r--net/src/stream/mod.rs191
-rw-r--r--net/src/stream/websocket.rs107
3 files changed, 380 insertions, 0 deletions
diff --git a/net/src/stream/buffer.rs b/net/src/stream/buffer.rs
new file mode 100644
index 0000000..f211600
--- /dev/null
+++ b/net/src/stream/buffer.rs
@@ -0,0 +1,82 @@
+#[derive(Debug)]
+pub struct Buffer<B> {
+ inner: B,
+ len: usize,
+ cap: usize,
+}
+
+impl<B> Buffer<B>
+where
+ B: AsMut<[u8]> + AsRef<[u8]>,
+{
+ /// Constructs a new, empty Buffer<B>.
+ pub fn new(b: B) -> Self {
+ Self {
+ cap: b.as_ref().len(),
+ inner: b,
+ len: 0,
+ }
+ }
+
+ /// Returns the number of elements in the buffer.
+ #[allow(dead_code)]
+ pub fn len(&self) -> usize {
+ self.len
+ }
+
+ /// Resizes the buffer in-place so that `len` is equal to `new_size`.
+ pub fn resize(&mut self, new_size: usize) {
+ assert!(self.cap > new_size);
+ self.len = new_size;
+ }
+
+ /// Appends all elements in a slice to the buffer.
+ pub fn extend_from_slice(&mut self, bytes: &[u8]) {
+ let old_len = self.len;
+ self.resize(self.len + bytes.len());
+ self.inner.as_mut()[old_len..bytes.len() + old_len].copy_from_slice(bytes);
+ }
+
+ /// Shortens the buffer, dropping the first `cnt` bytes and keeping the
+ /// rest.
+ pub fn advance(&mut self, cnt: usize) {
+ assert!(self.len >= cnt);
+ self.inner.as_mut().rotate_left(cnt);
+ self.resize(self.len - cnt);
+ }
+
+ /// Returns `true` if the buffer contains no elements.
+ pub fn is_empty(&self) -> bool {
+ self.len == 0
+ }
+}
+
+impl<B> AsMut<[u8]> for Buffer<B>
+where
+ B: AsMut<[u8]> + AsRef<[u8]>,
+{
+ fn as_mut(&mut self) -> &mut [u8] {
+ &mut self.inner.as_mut()[..self.len]
+ }
+}
+
+impl<B> AsRef<[u8]> for Buffer<B>
+where
+ B: AsMut<[u8]> + AsRef<[u8]>,
+{
+ fn as_ref(&self) -> &[u8] {
+ &self.inner.as_ref()[..self.len]
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_buffer_advance() {
+ let mut buf = Buffer::new([0u8; 32]);
+ buf.extend_from_slice(&[1, 2, 3]);
+ assert_eq!([1, 2, 3], buf.as_ref());
+ }
+}
diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs
new file mode 100644
index 0000000..9493b29
--- /dev/null
+++ b/net/src/stream/mod.rs
@@ -0,0 +1,191 @@
+mod buffer;
+mod websocket;
+
+pub use websocket::WsStream;
+
+use std::{
+ io::ErrorKind,
+ pin::Pin,
+ task::{Context, Poll},
+};
+
+use futures_util::{
+ ready,
+ stream::{Stream, StreamExt},
+ Sink,
+};
+use pin_project_lite::pin_project;
+
+use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
+
+use crate::{
+ codec::{Decoder, Encoder},
+ Error, Result,
+};
+
+use buffer::Buffer;
+
+const BUFFER_SIZE: usize = 2048 * 2024; // 4MB
+const INITIAL_BUFFER_SIZE: usize = 1024 * 1024; // 1MB
+
+pub struct ReadStream<T, C> {
+ inner: T,
+ decoder: C,
+ buffer: Buffer<[u8; BUFFER_SIZE]>,
+}
+
+impl<T, C> ReadStream<T, C>
+where
+ T: AsyncRead + Unpin,
+ C: Decoder + Unpin,
+{
+ pub fn new(inner: T, decoder: C) -> Self {
+ Self {
+ inner,
+ decoder,
+ buffer: Buffer::new([0u8; BUFFER_SIZE]),
+ }
+ }
+
+ pub async fn recv(&mut self) -> Result<C::DeItem> {
+ match self.next().await {
+ Some(m) => m,
+ None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())),
+ }
+ }
+}
+
+pin_project! {
+ pub struct WriteStream<T, C> {
+ #[pin]
+ inner: T,
+ encoder: C,
+ high_water_mark: usize,
+ buffer: Buffer<[u8; BUFFER_SIZE]>,
+ }
+}
+
+impl<T, C> WriteStream<T, C>
+where
+ T: AsyncWrite + Unpin,
+ C: Encoder + Unpin,
+{
+ pub fn new(inner: T, encoder: C) -> Self {
+ Self {
+ inner,
+ encoder,
+ high_water_mark: 131072,
+ buffer: Buffer::new([0u8; BUFFER_SIZE]),
+ }
+ }
+}
+
+impl<T, C> Stream for ReadStream<T, C>
+where
+ T: AsyncRead + Unpin,
+ C: Decoder + Unpin,
+{
+ type Item = Result<C::DeItem>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let this = &mut *self;
+
+ if let Some((n, item)) = this.decoder.decode(this.buffer.as_mut())? {
+ this.buffer.advance(n);
+ return Poll::Ready(Some(Ok(item)));
+ }
+
+ let mut buf = [0u8; INITIAL_BUFFER_SIZE];
+ #[cfg(feature = "tokio")]
+ let mut buf = tokio::io::ReadBuf::new(&mut buf);
+
+ loop {
+ #[cfg(feature = "smol")]
+ let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
+ #[cfg(feature = "smol")]
+ let bytes = &buf[..n];
+
+ #[cfg(feature = "tokio")]
+ ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
+ #[cfg(feature = "tokio")]
+ let bytes = buf.filled();
+ #[cfg(feature = "tokio")]
+ let n = bytes.len();
+
+ this.buffer.extend_from_slice(bytes);
+
+ match this.decoder.decode(this.buffer.as_mut())? {
+ Some((cn, item)) => {
+ this.buffer.advance(cn);
+ return Poll::Ready(Some(Ok(item)));
+ }
+ None if n == 0 => {
+ if this.buffer.is_empty() {
+ return Poll::Ready(None);
+ } else {
+ return Poll::Ready(Some(Err(std::io::Error::new(
+ std::io::ErrorKind::UnexpectedEof,
+ "bytes remaining in read stream",
+ )
+ .into())));
+ }
+ }
+ _ => continue,
+ }
+ }
+ }
+}
+
+impl<T, C> Sink<C::EnItem> for WriteStream<T, C>
+where
+ T: AsyncWrite + Unpin,
+ C: Encoder + Unpin,
+{
+ type Error = Error;
+
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
+ let this = &mut *self;
+ while !this.buffer.is_empty() {
+ let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?;
+
+ if n == 0 {
+ return Poll::Ready(Err(std::io::Error::new(
+ ErrorKind::UnexpectedEof,
+ "End of file",
+ )
+ .into()));
+ }
+
+ this.buffer.advance(n);
+ }
+
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(mut self: Pin<&mut Self>, item: C::EnItem) -> Result<()> {
+ let this = &mut *self;
+ let mut buf = [0u8; INITIAL_BUFFER_SIZE];
+ let n = this.encoder.encode(&item, &mut buf)?;
+ this.buffer.extend_from_slice(&buf[..n]);
+ Ok(())
+ }
+
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context,
+ ) -> Poll<std::result::Result<(), Self::Error>> {
+ ready!(self.as_mut().poll_ready(cx))?;
+ self.project().inner.poll_flush(cx).map_err(Into::into)
+ }
+
+ fn poll_close(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context,
+ ) -> Poll<std::result::Result<(), Self::Error>> {
+ ready!(self.as_mut().poll_flush(cx))?;
+ #[cfg(feature = "smol")]
+ return self.project().inner.poll_close(cx).map_err(Error::from);
+ #[cfg(feature = "tokio")]
+ return self.project().inner.poll_shutdown(cx).map_err(Error::from);
+ }
+}
diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs
new file mode 100644
index 0000000..2552eaf
--- /dev/null
+++ b/net/src/stream/websocket.rs
@@ -0,0 +1,107 @@
+use std::{
+ pin::Pin,
+ task::{Context, Poll},
+};
+
+use async_tungstenite::tungstenite::Message;
+use futures_util::{Sink, SinkExt, Stream, StreamExt};
+
+#[cfg(feature = "smol")]
+use futures_rustls::TlsStream;
+#[cfg(feature = "tokio")]
+use tokio_rustls::TlsStream;
+
+use karyon_core::async_runtime::net::TcpStream;
+
+use crate::{codec::WebSocketCodec, Error, Result};
+
+#[cfg(feature = "tokio")]
+type WebSocketStream<T> =
+ async_tungstenite::WebSocketStream<async_tungstenite::tokio::TokioAdapter<T>>;
+#[cfg(feature = "smol")]
+use async_tungstenite::WebSocketStream;
+
+pub struct WsStream<C> {
+ inner: InnerWSConn,
+ codec: C,
+}
+
+impl<C> WsStream<C>
+where
+ C: WebSocketCodec,
+{
+ pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self {
+ Self {
+ inner: InnerWSConn::Plain(conn),
+ codec,
+ }
+ }
+
+ pub fn new_wss(conn: WebSocketStream<TlsStream<TcpStream>>, codec: C) -> Self {
+ Self {
+ inner: InnerWSConn::Tls(conn),
+ codec,
+ }
+ }
+
+ pub async fn recv(&mut self) -> Result<C::Item> {
+ match self.inner.next().await {
+ Some(msg) => self.codec.decode(&msg?),
+ None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())),
+ }
+ }
+
+ pub async fn send(&mut self, msg: C::Item) -> Result<()> {
+ let ws_msg = self.codec.encode(&msg)?;
+ self.inner.send(ws_msg).await
+ }
+}
+
+enum InnerWSConn {
+ Plain(WebSocketStream<TcpStream>),
+ Tls(WebSocketStream<TlsStream<TcpStream>>),
+}
+
+impl Sink<Message> for InnerWSConn {
+ type Error = Error;
+
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from),
+ }
+ }
+
+ fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from),
+ }
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from),
+ }
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from),
+ }
+ .map_err(Error::from)
+ }
+}
+
+impl Stream for InnerWSConn {
+ type Item = Result<Message>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from),
+ }
+ }
+}