From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- net/Cargo.toml | 38 +++++-- net/examples/tcp_codec.rs | 59 ++++++++++ net/src/codec/bytes_codec.rs | 29 +++++ net/src/codec/length_codec.rs | 49 +++++++++ net/src/codec/mod.rs | 25 +++++ net/src/codec/websocket.rs | 23 ++++ net/src/connection.rs | 53 ++------- net/src/endpoint.rs | 78 +++++++++----- net/src/error.rs | 24 ++++- net/src/lib.rs | 14 +-- net/src/listener.rs | 41 ++----- net/src/stream/buffer.rs | 82 ++++++++++++++ net/src/stream/mod.rs | 191 +++++++++++++++++++++++++++++++++ net/src/stream/websocket.rs | 107 +++++++++++++++++++ net/src/transports/tcp.rs | 188 +++++++++++++++++++++----------- net/src/transports/tls.rs | 220 ++++++++++++++++++++++++++------------ net/src/transports/udp.rs | 114 +++++++++++--------- net/src/transports/unix.rs | 193 +++++++++++++++++++++++---------- net/src/transports/ws.rs | 242 +++++++++++++++++++++++++++++++++--------- 19 files changed, 1364 insertions(+), 406 deletions(-) create mode 100644 net/examples/tcp_codec.rs create mode 100644 net/src/codec/bytes_codec.rs create mode 100644 net/src/codec/length_codec.rs create mode 100644 net/src/codec/mod.rs create mode 100644 net/src/codec/websocket.rs create mode 100644 net/src/stream/buffer.rs create mode 100644 net/src/stream/mod.rs create mode 100644 net/src/stream/websocket.rs (limited to 'net') diff --git a/net/Cargo.toml b/net/Cargo.toml index fe209cd..304cbb2 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -1,19 +1,43 @@ [package] name = "karyon_net" -version.workspace = true +version.workspace = true edition.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["smol"] +smol = [ + "karyon_core/smol", + "async-tungstenite/async-std-runtime", + "dep:futures-rustls", +] +tokio = [ + "karyon_core/tokio", + "async-tungstenite/tokio-runtime", + "dep:tokio", + "dep:tokio-rustls", +] + + [dependencies] -karyon_core.workspace = true +karyon_core = { workspace = true, default-features = false } -smol = "2.0.0" +pin-project-lite = "0.2.13" async-trait = "0.1.77" log = "0.4.21" -bincode = { version="2.0.0-rc.3", features = ["derive"]} +bincode = { version = "2.0.0-rc.3", features = ["derive"] } thiserror = "1.0.58" url = "2.5.0" -futures-rustls = "0.25.1" -async-tungstenite = "0.25.0" -ws_stream_tungstenite = "0.13.0" +async-tungstenite = { version = "0.25.0", default-features = false } +asynchronous-codec = "0.7.0" +futures-util = "0.3.30" +async-channel = "2.3.0" +rustls-pki-types = "1.7.0" + +futures-rustls = { version = "0.25.1", optional = true } +tokio-rustls = { version = "0.26.0", optional = true } +tokio = { version = "1.37.0", features = ["io-util"], optional = true } + +[dev-dependencies] +smol = "2.0.0" diff --git a/net/examples/tcp_codec.rs b/net/examples/tcp_codec.rs new file mode 100644 index 0000000..93deaae --- /dev/null +++ b/net/examples/tcp_codec.rs @@ -0,0 +1,59 @@ +use std::time::Duration; + +use karyon_core::async_util::sleep; + +use karyon_net::{ + codec::{Codec, Decoder, Encoder}, + tcp, ConnListener, Connection, Endpoint, Result, +}; + +#[derive(Clone)] +struct NewLineCodec {} + +impl Codec for NewLineCodec { + type Item = String; +} + +impl Encoder for NewLineCodec { + type EnItem = String; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result { + dst[..src.len()].copy_from_slice(src.as_bytes()); + Ok(src.len()) + } +} + +impl Decoder for NewLineCodec { + type DeItem = String; + fn decode(&self, src: &mut [u8]) -> Result> { + match src.iter().position(|&b| b == b'\n') { + Some(i) => Ok(Some((i + 1, String::from_utf8(src[..i].to_vec()).unwrap()))), + None => Ok(None), + } + } +} + +fn main() { + smol::block_on(async { + let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap(); + + let config = tcp::TcpConfig::default(); + + let listener = tcp::listen(&endpoint, config.clone(), NewLineCodec {}) + .await + .unwrap(); + smol::spawn(async move { + if let Ok(conn) = listener.accept().await { + loop { + let msg = conn.recv().await.unwrap(); + println!("Receive a message: {:?}", msg); + } + }; + }) + .detach(); + + let conn = tcp::dial(&endpoint, config, NewLineCodec {}).await.unwrap(); + conn.send("hello".to_string()).await.unwrap(); + conn.send(" world\n".to_string()).await.unwrap(); + sleep(Duration::from_secs(1)).await; + }); +} diff --git a/net/src/codec/bytes_codec.rs b/net/src/codec/bytes_codec.rs new file mode 100644 index 0000000..b319e53 --- /dev/null +++ b/net/src/codec/bytes_codec.rs @@ -0,0 +1,29 @@ +use crate::{ + codec::{Codec, Decoder, Encoder}, + Result, +}; + +#[derive(Clone)] +pub struct BytesCodec {} +impl Codec for BytesCodec { + type Item = Vec; +} + +impl Encoder for BytesCodec { + type EnItem = Vec; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result { + dst[..src.len()].copy_from_slice(src); + Ok(src.len()) + } +} + +impl Decoder for BytesCodec { + type DeItem = Vec; + fn decode(&self, src: &mut [u8]) -> Result> { + if src.is_empty() { + Ok(None) + } else { + Ok(Some((src.len(), src.to_vec()))) + } + } +} diff --git a/net/src/codec/length_codec.rs b/net/src/codec/length_codec.rs new file mode 100644 index 0000000..76a1679 --- /dev/null +++ b/net/src/codec/length_codec.rs @@ -0,0 +1,49 @@ +use karyon_core::util::{decode, encode_into_slice}; + +use crate::{ + codec::{Codec, Decoder, Encoder}, + Result, +}; + +/// The size of the message length. +const MSG_LENGTH_SIZE: usize = std::mem::size_of::(); + +#[derive(Clone)] +pub struct LengthCodec {} +impl Codec for LengthCodec { + type Item = Vec; +} + +impl Encoder for LengthCodec { + type EnItem = Vec; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result { + let length_buf = &mut [0; MSG_LENGTH_SIZE]; + encode_into_slice(&(src.len() as u32), length_buf)?; + dst[..MSG_LENGTH_SIZE].copy_from_slice(length_buf); + dst[MSG_LENGTH_SIZE..src.len() + MSG_LENGTH_SIZE].copy_from_slice(src); + Ok(src.len() + MSG_LENGTH_SIZE) + } +} + +impl Decoder for LengthCodec { + type DeItem = Vec; + fn decode(&self, src: &mut [u8]) -> Result> { + if src.len() < MSG_LENGTH_SIZE { + return Ok(None); + } + + let mut length = [0; MSG_LENGTH_SIZE]; + length.copy_from_slice(&src[..MSG_LENGTH_SIZE]); + let (length, _) = decode::(&length)?; + let length = length as usize; + + if src.len() - MSG_LENGTH_SIZE >= length { + Ok(Some(( + length + MSG_LENGTH_SIZE, + src[MSG_LENGTH_SIZE..length + MSG_LENGTH_SIZE].to_vec(), + ))) + } else { + Ok(None) + } + } +} diff --git a/net/src/codec/mod.rs b/net/src/codec/mod.rs new file mode 100644 index 0000000..565cb07 --- /dev/null +++ b/net/src/codec/mod.rs @@ -0,0 +1,25 @@ +mod bytes_codec; +mod length_codec; +mod websocket; + +pub use bytes_codec::BytesCodec; +pub use length_codec::LengthCodec; +pub use websocket::{WebSocketCodec, WebSocketDecoder, WebSocketEncoder}; + +use crate::Result; + +pub trait Codec: + Decoder + Encoder + Send + Sync + 'static + Unpin +{ + type Item: Send + Sync; +} + +pub trait Encoder { + type EnItem; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result; +} + +pub trait Decoder { + type DeItem; + fn decode(&self, src: &mut [u8]) -> Result>; +} diff --git a/net/src/codec/websocket.rs b/net/src/codec/websocket.rs new file mode 100644 index 0000000..b59a55c --- /dev/null +++ b/net/src/codec/websocket.rs @@ -0,0 +1,23 @@ +use crate::Result; +use async_tungstenite::tungstenite::Message; + +pub trait WebSocketCodec: + WebSocketDecoder + + WebSocketEncoder + + Send + + Sync + + 'static + + Unpin +{ + type Item: Send + Sync; +} + +pub trait WebSocketEncoder { + type EnItem; + fn encode(&self, src: &Self::EnItem) -> Result; +} + +pub trait WebSocketDecoder { + type DeItem; + fn decode(&self, src: &Message) -> Result; +} diff --git a/net/src/connection.rs b/net/src/connection.rs index fa4640f..bbd21de 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -1,65 +1,34 @@ use async_trait::async_trait; -use crate::{ - transports::{tcp, udp, unix}, - Endpoint, Error, Result, -}; +use crate::{Endpoint, Result}; /// Alias for `Box` -pub type Conn = Box; +pub type Conn = Box>; /// A trait for objects which can be converted to [`Conn`]. pub trait ToConn { - fn to_conn(self) -> Conn; + type Item; + fn to_conn(self) -> Conn; } /// Connection is a generic network connection interface for -/// [`udp::UdpConn`], [`tcp::TcpConn`], and [`unix::UnixConn`]. +/// [`udp::UdpConn`], [`tcp::TcpConn`], [`tls::TlsConn`], [`ws::WsConn`], +/// and [`unix::UnixConn`]. /// /// If you are familiar with the Go language, this is similar to the /// [Conn](https://pkg.go.dev/net#Conn) interface #[async_trait] pub trait Connection: Send + Sync { + type Item; /// Returns the remote peer endpoint of this connection fn peer_endpoint(&self) -> Result; /// Returns the local socket endpoint of this connection fn local_endpoint(&self) -> Result; - /// Reads data from this connection. - async fn read(&self, buf: &mut [u8]) -> Result; + /// Recvs data from this connection. + async fn recv(&self) -> Result; - /// Writes data to this connection - async fn write(&self, buf: &[u8]) -> Result; -} - -/// Connects to the provided endpoint. -/// -/// it only supports `tcp4/6`, `udp4/6`, and `unix`. -/// -/// #Example -/// -/// ``` -/// use karyon_net::{Endpoint, dial}; -/// -/// async { -/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap(); -/// -/// let conn = dial(&endpoint).await.unwrap(); -/// -/// conn.write(b"MSG").await.unwrap(); -/// -/// let mut buffer = [0;32]; -/// conn.read(&mut buffer).await.unwrap(); -/// }; -/// -/// ``` -/// -pub async fn dial(endpoint: &Endpoint) -> Result { - match endpoint { - Endpoint::Tcp(_, _) => Ok(Box::new(tcp::dial(endpoint).await?)), - Endpoint::Udp(_, _) => Ok(Box::new(udp::dial(endpoint).await?)), - Endpoint::Unix(addr) => Ok(Box::new(unix::dial(addr).await?)), - _ => Err(Error::InvalidEndpoint(endpoint.to_string())), - } + /// Sends data to this connection + async fn send(&self, msg: Self::Item) -> Result<()>; } diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs index 9193628..0c7ecd1 100644 --- a/net/src/endpoint.rs +++ b/net/src/endpoint.rs @@ -1,10 +1,11 @@ use std::{ net::{IpAddr, SocketAddr}, - os::unix::net::SocketAddr as UnixSocketAddress, path::PathBuf, str::FromStr, }; +use std::os::unix::net::SocketAddr as UnixSocketAddr; + use bincode::{Decode, Encode}; use url::Url; @@ -25,7 +26,7 @@ pub type Port = u16; /// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap(); /// /// let socketaddr: SocketAddr = "127.0.0.1:3000".parse().unwrap(); -/// let endpoint = Endpoint::new_udp_addr(&socketaddr); +/// let endpoint = Endpoint::new_udp_addr(socketaddr); /// /// ``` /// @@ -35,7 +36,8 @@ pub enum Endpoint { Tcp(Addr, Port), Tls(Addr, Port), Ws(Addr, Port), - Unix(String), + Wss(Addr, Port), + Unix(PathBuf), } impl std::fmt::Display for Endpoint { @@ -53,12 +55,11 @@ impl std::fmt::Display for Endpoint { Endpoint::Ws(ip, port) => { write!(f, "ws://{}:{}", ip, port) } + Endpoint::Wss(ip, port) => { + write!(f, "wss://{}:{}", ip, port) + } Endpoint::Unix(path) => { - if path.is_empty() { - write!(f, "unix:/UNNAMED") - } else { - write!(f, "unix:/{}", path) - } + write!(f, "unix:/{}", path.to_string_lossy()) } } } @@ -71,7 +72,8 @@ impl TryFrom for SocketAddr { Endpoint::Udp(ip, port) | Endpoint::Tcp(ip, port) | Endpoint::Tls(ip, port) - | Endpoint::Ws(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), + | Endpoint::Ws(ip, port) + | Endpoint::Wss(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), Endpoint::Unix(_) => Err(Error::TryFromEndpoint), } } @@ -87,11 +89,11 @@ impl TryFrom for PathBuf { } } -impl TryFrom for UnixSocketAddress { +impl TryFrom for UnixSocketAddr { type Error = Error; - fn try_from(endpoint: Endpoint) -> std::result::Result { + fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { - Endpoint::Unix(a) => Ok(UnixSocketAddress::from_pathname(a)?), + Endpoint::Unix(a) => Ok(UnixSocketAddr::from_pathname(a)?), _ => Err(Error::TryFromEndpoint), } } @@ -124,6 +126,7 @@ impl FromStr for Endpoint { "udp" => Ok(Endpoint::Udp(addr, port)), "tls" => Ok(Endpoint::Tls(addr, port)), "ws" => Ok(Endpoint::Ws(addr, port)), + "wss" => Ok(Endpoint::Wss(addr, port)), _ => Err(Error::InvalidEndpoint(s.to_string())), } } else { @@ -132,7 +135,7 @@ impl FromStr for Endpoint { } match url.scheme() { - "unix" => Ok(Endpoint::Unix(url.path().to_string())), + "unix" => Ok(Endpoint::Unix(url.path().into())), _ => Err(Error::InvalidEndpoint(s.to_string())), } } @@ -141,33 +144,33 @@ impl FromStr for Endpoint { impl Endpoint { /// Creates a new TCP endpoint from a `SocketAddr`. - pub fn new_tcp_addr(addr: &SocketAddr) -> Endpoint { + pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint { Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port()) } /// Creates a new UDP endpoint from a `SocketAddr`. - pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint { + pub fn new_udp_addr(addr: SocketAddr) -> Endpoint { Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) } /// Creates a new TLS endpoint from a `SocketAddr`. - pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint { + pub fn new_tls_addr(addr: SocketAddr) -> Endpoint { Endpoint::Tls(Addr::Ip(addr.ip()), addr.port()) } /// Creates a new WS endpoint from a `SocketAddr`. - pub fn new_ws_addr(addr: &SocketAddr) -> Endpoint { + pub fn new_ws_addr(addr: SocketAddr) -> Endpoint { Endpoint::Ws(Addr::Ip(addr.ip()), addr.port()) } - /// Creates a new Unix endpoint from a `UnixSocketAddress`. - pub fn new_unix_addr(addr: &UnixSocketAddress) -> Endpoint { - Endpoint::Unix( - addr.as_pathname() - .and_then(|a| a.to_str()) - .unwrap_or("") - .to_string(), - ) + /// Creates a new WSS endpoint from a `SocketAddr`. + pub fn new_wss_addr(addr: SocketAddr) -> Endpoint { + Endpoint::Wss(Addr::Ip(addr.ip()), addr.port()) + } + + /// Creates a new Unix endpoint from a `UnixSocketAddr`. + pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint { + Endpoint::Unix(addr.to_path_buf()) } /// Returns the `Port` of the endpoint. @@ -176,7 +179,8 @@ impl Endpoint { Endpoint::Tcp(_, port) | Endpoint::Udp(_, port) | Endpoint::Tls(_, port) - | Endpoint::Ws(_, port) => Ok(port), + | Endpoint::Ws(_, port) + | Endpoint::Wss(_, port) => Ok(port), _ => Err(Error::TryFromEndpoint), } } @@ -187,7 +191,8 @@ impl Endpoint { Endpoint::Tcp(addr, _) | Endpoint::Udp(addr, _) | Endpoint::Tls(addr, _) - | Endpoint::Ws(addr, _) => Ok(addr), + | Endpoint::Ws(addr, _) + | Endpoint::Wss(addr, _) => Ok(addr), _ => Err(Error::TryFromEndpoint), } } @@ -223,10 +228,27 @@ impl std::fmt::Display for Addr { } } +pub trait ToEndpoint { + fn to_endpoint(&self) -> Result; +} + +impl ToEndpoint for String { + fn to_endpoint(&self) -> Result { + Endpoint::from_str(self) + } +} + +impl ToEndpoint for &str { + fn to_endpoint(&self) -> Result { + Endpoint::from_str(self) + } +} + #[cfg(test)] mod tests { use super::*; use std::net::Ipv4Addr; + use std::path::PathBuf; #[test] fn test_endpoint_from_str() { @@ -243,7 +265,7 @@ mod tests { assert_eq!(endpoint_str, endpoint); let endpoint_str = "unix:/home/x/s.socket".parse::().unwrap(); - let endpoint = Endpoint::Unix("/home/x/s.socket".to_string()); + let endpoint = Endpoint::Unix(PathBuf::from_str("/home/x/s.socket").unwrap()); assert_eq!(endpoint_str, endpoint); } } diff --git a/net/src/error.rs b/net/src/error.rs index 6e04a12..ee93168 100644 --- a/net/src/error.rs +++ b/net/src/error.rs @@ -13,9 +13,18 @@ pub enum Error { #[error("invalid address {0}")] InvalidAddress(String), + #[error("invalid path {0}")] + InvalidPath(String), + #[error("invalid endpoint {0}")] InvalidEndpoint(String), + #[error("Encode error: {0}")] + Encode(String), + + #[error("Decode error: {0}")] + Decode(String), + #[error("Parse endpoint error {0}")] ParseEndpoint(String), @@ -26,23 +35,28 @@ pub enum Error { ChannelSend(String), #[error(transparent)] - ChannelRecv(#[from] smol::channel::RecvError), + ChannelRecv(#[from] async_channel::RecvError), #[error("Ws Error: {0}")] WsError(#[from] async_tungstenite::tungstenite::Error), + #[cfg(feature = "smol")] #[error("Tls Error: {0}")] Rustls(#[from] futures_rustls::rustls::Error), + #[cfg(feature = "tokio")] + #[error("Tls Error: {0}")] + Rustls(#[from] tokio_rustls::rustls::Error), + #[error("Invalid DNS Name: {0}")] - InvalidDnsNameError(#[from] futures_rustls::pki_types::InvalidDnsNameError), + InvalidDnsNameError(#[from] rustls_pki_types::InvalidDnsNameError), #[error(transparent)] - KaryonCore(#[from] karyon_core::error::Error), + KaryonCore(#[from] karyon_core::Error), } -impl From> for Error { - fn from(error: smol::channel::SendError) -> Self { +impl From> for Error { + fn from(error: async_channel::SendError) -> Self { Error::ChannelSend(error.to_string()) } } diff --git a/net/src/lib.rs b/net/src/lib.rs index c1d72b2..ddb53cf 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -1,20 +1,20 @@ +pub mod codec; mod connection; mod endpoint; mod error; mod listener; +mod stream; mod transports; pub use { - connection::{dial, Conn, Connection, ToConn}, - endpoint::{Addr, Endpoint, Port}, - listener::{listen, ConnListener, Listener, ToListener}, + connection::{Conn, Connection, ToConn}, + endpoint::{Addr, Endpoint, Port, ToEndpoint}, + listener::{ConnListener, Listener, ToListener}, transports::{tcp, tls, udp, unix, ws}, }; -use error::{Error, Result}; - /// Represents karyon's Net Error -pub use error::Error as NetError; +pub use error::Error; /// Represents karyon's Net Result -pub use error::Result as NetResult; +pub use error::Result; diff --git a/net/src/listener.rs b/net/src/listener.rs index 4511212..469f5e9 100644 --- a/net/src/listener.rs +++ b/net/src/listener.rs @@ -1,46 +1,21 @@ use async_trait::async_trait; -use crate::{ - transports::{tcp, unix}, - Conn, Endpoint, Error, Result, -}; +use crate::{Conn, Endpoint, Result}; /// Alias for `Box` -pub type Listener = Box; +pub type Listener = Box>; /// A trait for objects which can be converted to [`Listener`]. pub trait ToListener { - fn to_listener(self) -> Listener; + type Item; + fn to_listener(self) -> Listener; } -/// ConnListener is a generic network listener. +/// ConnListener is a generic network listener interface for +/// [`tcp::TcpConn`], [`tls::TlsConn`], [`ws::WsConn`], and [`unix::UnixConn`]. #[async_trait] pub trait ConnListener: Send + Sync { + type Item; fn local_endpoint(&self) -> Result; - async fn accept(&self) -> Result; -} - -/// Listens to the provided endpoint. -/// -/// it only supports `tcp4/6`, and `unix`. -/// -/// #Example -/// -/// ``` -/// use karyon_net::{Endpoint, listen}; -/// -/// async { -/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap(); -/// -/// let listener = listen(&endpoint).await.unwrap(); -/// let conn = listener.accept().await.unwrap(); -/// }; -/// -/// ``` -pub async fn listen(endpoint: &Endpoint) -> Result> { - match endpoint { - Endpoint::Tcp(_, _) => Ok(Box::new(tcp::listen(endpoint).await?)), - Endpoint::Unix(addr) => Ok(Box::new(unix::listen(addr)?)), - _ => Err(Error::InvalidEndpoint(endpoint.to_string())), - } + async fn accept(&self) -> Result>; } diff --git a/net/src/stream/buffer.rs b/net/src/stream/buffer.rs new file mode 100644 index 0000000..f211600 --- /dev/null +++ b/net/src/stream/buffer.rs @@ -0,0 +1,82 @@ +#[derive(Debug)] +pub struct Buffer { + inner: B, + len: usize, + cap: usize, +} + +impl Buffer +where + B: AsMut<[u8]> + AsRef<[u8]>, +{ + /// Constructs a new, empty Buffer. + pub fn new(b: B) -> Self { + Self { + cap: b.as_ref().len(), + inner: b, + len: 0, + } + } + + /// Returns the number of elements in the buffer. + #[allow(dead_code)] + pub fn len(&self) -> usize { + self.len + } + + /// Resizes the buffer in-place so that `len` is equal to `new_size`. + pub fn resize(&mut self, new_size: usize) { + assert!(self.cap > new_size); + self.len = new_size; + } + + /// Appends all elements in a slice to the buffer. + pub fn extend_from_slice(&mut self, bytes: &[u8]) { + let old_len = self.len; + self.resize(self.len + bytes.len()); + self.inner.as_mut()[old_len..bytes.len() + old_len].copy_from_slice(bytes); + } + + /// Shortens the buffer, dropping the first `cnt` bytes and keeping the + /// rest. + pub fn advance(&mut self, cnt: usize) { + assert!(self.len >= cnt); + self.inner.as_mut().rotate_left(cnt); + self.resize(self.len - cnt); + } + + /// Returns `true` if the buffer contains no elements. + pub fn is_empty(&self) -> bool { + self.len == 0 + } +} + +impl AsMut<[u8]> for Buffer +where + B: AsMut<[u8]> + AsRef<[u8]>, +{ + fn as_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[..self.len] + } +} + +impl AsRef<[u8]> for Buffer +where + B: AsMut<[u8]> + AsRef<[u8]>, +{ + fn as_ref(&self) -> &[u8] { + &self.inner.as_ref()[..self.len] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_buffer_advance() { + let mut buf = Buffer::new([0u8; 32]); + buf.extend_from_slice(&[1, 2, 3]); + assert_eq!([1, 2, 3], buf.as_ref()); + } +} diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs new file mode 100644 index 0000000..9493b29 --- /dev/null +++ b/net/src/stream/mod.rs @@ -0,0 +1,191 @@ +mod buffer; +mod websocket; + +pub use websocket::WsStream; + +use std::{ + io::ErrorKind, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{ + ready, + stream::{Stream, StreamExt}, + Sink, +}; +use pin_project_lite::pin_project; + +use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite}; + +use crate::{ + codec::{Decoder, Encoder}, + Error, Result, +}; + +use buffer::Buffer; + +const BUFFER_SIZE: usize = 2048 * 2024; // 4MB +const INITIAL_BUFFER_SIZE: usize = 1024 * 1024; // 1MB + +pub struct ReadStream { + inner: T, + decoder: C, + buffer: Buffer<[u8; BUFFER_SIZE]>, +} + +impl ReadStream +where + T: AsyncRead + Unpin, + C: Decoder + Unpin, +{ + pub fn new(inner: T, decoder: C) -> Self { + Self { + inner, + decoder, + buffer: Buffer::new([0u8; BUFFER_SIZE]), + } + } + + pub async fn recv(&mut self) -> Result { + match self.next().await { + Some(m) => m, + None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), + } + } +} + +pin_project! { + pub struct WriteStream { + #[pin] + inner: T, + encoder: C, + high_water_mark: usize, + buffer: Buffer<[u8; BUFFER_SIZE]>, + } +} + +impl WriteStream +where + T: AsyncWrite + Unpin, + C: Encoder + Unpin, +{ + pub fn new(inner: T, encoder: C) -> Self { + Self { + inner, + encoder, + high_water_mark: 131072, + buffer: Buffer::new([0u8; BUFFER_SIZE]), + } + } +} + +impl Stream for ReadStream +where + T: AsyncRead + Unpin, + C: Decoder + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + + if let Some((n, item)) = this.decoder.decode(this.buffer.as_mut())? { + this.buffer.advance(n); + return Poll::Ready(Some(Ok(item))); + } + + let mut buf = [0u8; INITIAL_BUFFER_SIZE]; + #[cfg(feature = "tokio")] + let mut buf = tokio::io::ReadBuf::new(&mut buf); + + loop { + #[cfg(feature = "smol")] + let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?; + #[cfg(feature = "smol")] + let bytes = &buf[..n]; + + #[cfg(feature = "tokio")] + ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?; + #[cfg(feature = "tokio")] + let bytes = buf.filled(); + #[cfg(feature = "tokio")] + let n = bytes.len(); + + this.buffer.extend_from_slice(bytes); + + match this.decoder.decode(this.buffer.as_mut())? { + Some((cn, item)) => { + this.buffer.advance(cn); + return Poll::Ready(Some(Ok(item))); + } + None if n == 0 => { + if this.buffer.is_empty() { + return Poll::Ready(None); + } else { + return Poll::Ready(Some(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "bytes remaining in read stream", + ) + .into()))); + } + } + _ => continue, + } + } + } +} + +impl Sink for WriteStream +where + T: AsyncWrite + Unpin, + C: Encoder + Unpin, +{ + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = &mut *self; + while !this.buffer.is_empty() { + let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?; + + if n == 0 { + return Poll::Ready(Err(std::io::Error::new( + ErrorKind::UnexpectedEof, + "End of file", + ) + .into())); + } + + this.buffer.advance(n); + } + + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: C::EnItem) -> Result<()> { + let this = &mut *self; + let mut buf = [0u8; INITIAL_BUFFER_SIZE]; + let n = this.encoder.encode(&item, &mut buf)?; + this.buffer.extend_from_slice(&buf[..n]); + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + ready!(self.as_mut().poll_ready(cx))?; + self.project().inner.poll_flush(cx).map_err(Into::into) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + #[cfg(feature = "smol")] + return self.project().inner.poll_close(cx).map_err(Error::from); + #[cfg(feature = "tokio")] + return self.project().inner.poll_shutdown(cx).map_err(Error::from); + } +} diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs new file mode 100644 index 0000000..2552eaf --- /dev/null +++ b/net/src/stream/websocket.rs @@ -0,0 +1,107 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_tungstenite::tungstenite::Message; +use futures_util::{Sink, SinkExt, Stream, StreamExt}; + +#[cfg(feature = "smol")] +use futures_rustls::TlsStream; +#[cfg(feature = "tokio")] +use tokio_rustls::TlsStream; + +use karyon_core::async_runtime::net::TcpStream; + +use crate::{codec::WebSocketCodec, Error, Result}; + +#[cfg(feature = "tokio")] +type WebSocketStream = + async_tungstenite::WebSocketStream>; +#[cfg(feature = "smol")] +use async_tungstenite::WebSocketStream; + +pub struct WsStream { + inner: InnerWSConn, + codec: C, +} + +impl WsStream +where + C: WebSocketCodec, +{ + pub fn new_ws(conn: WebSocketStream, codec: C) -> Self { + Self { + inner: InnerWSConn::Plain(conn), + codec, + } + } + + pub fn new_wss(conn: WebSocketStream>, codec: C) -> Self { + Self { + inner: InnerWSConn::Tls(conn), + codec, + } + } + + pub async fn recv(&mut self) -> Result { + match self.inner.next().await { + Some(msg) => self.codec.decode(&msg?), + None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), + } + } + + pub async fn send(&mut self, msg: C::Item) -> Result<()> { + let ws_msg = self.codec.encode(&msg)?; + self.inner.send(ws_msg).await + } +} + +enum InnerWSConn { + Plain(WebSocketStream), + Tls(WebSocketStream>), +} + +impl Sink for InnerWSConn { + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + } + .map_err(Error::from) + } +} + +impl Stream for InnerWSConn { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + } + } +} diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs index 21fce3d..03c8ab2 100644 --- a/net/src/transports/tcp.rs +++ b/net/src/transports/tcp.rs @@ -1,116 +1,184 @@ use std::net::SocketAddr; use async_trait::async_trait; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use futures_util::SinkExt; + +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, - net::{TcpListener, TcpStream}, + net::{TcpListener as AsyncTcpListener, TcpStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, + Result, }; -/// TCP network connection implementation of the [`Connection`] trait. -pub struct TcpConn { - inner: TcpStream, - read: Mutex>, - write: Mutex>, +/// TCP configuration +#[derive(Clone)] +pub struct TcpConfig { + pub nodelay: bool, +} + +impl Default for TcpConfig { + fn default() -> Self { + Self { nodelay: true } + } +} + +/// TCP connection implementation of the [`Connection`] trait. +pub struct TcpConn { + read_stream: Mutex, C>>, + write_stream: Mutex, C>>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl TcpConn { +impl TcpConn +where + C: Codec + Clone, +{ /// Creates a new TcpConn - pub fn new(conn: TcpStream) -> Self { - let (read, write) = split(conn.clone()); + pub fn new( + socket: TcpStream, + codec: C, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, + ) -> Self { + let (read, write) = split(socket); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: conn, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for TcpConn { +impl Connection for TcpConn +where + C: Codec + Clone, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_tcp_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tcp_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await + } +} + +pub struct TcpListener { + inner: AsyncTcpListener, + config: TcpConfig, + codec: C, +} + +impl TcpListener +where + C: Codec, +{ + pub fn new(listener: AsyncTcpListener, config: TcpConfig, codec: C) -> Self { + Self { + inner: listener, + config: config.clone(), + codec, + } } } #[async_trait] -impl ConnListener for TcpListener { +impl ConnListener for TcpListener +where + C: Codec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tcp_addr(&self.local_addr()?)) + Ok(Endpoint::new_tcp_addr(self.inner.local_addr()?)) } - async fn accept(&self) -> Result> { - let (conn, _) = self.accept().await?; - conn.set_nodelay(true)?; - Ok(Box::new(TcpConn::new(conn))) + async fn accept(&self) -> Result> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?; + + Ok(Box::new(TcpConn::new( + socket, + self.codec.clone(), + peer_endpoint, + local_endpoint, + ))) } } /// Connects to the given TCP address and port. -pub async fn dial(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let conn = TcpStream::connect(addr).await?; - conn.set_nodelay(true)?; - Ok(TcpConn::new(conn)) + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?; + + Ok(TcpConn::new(socket, codec, peer_endpoint, local_endpoint)) } /// Listens on the given TCP address and port. -pub async fn listen(endpoint: &Endpoint) -> Result { +pub async fn listen(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result> +where + C: Codec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let listener = TcpListener::bind(addr).await?; - Ok(listener) -} - -impl From for Box { - fn from(conn: TcpStream) -> Self { - Box::new(TcpConn::new(conn)) - } + let listener = AsyncTcpListener::bind(addr).await?; + Ok(TcpListener::new(listener, config, codec)) } -impl From for Box { - fn from(listener: TcpListener) -> Self { +impl From> for Box> +where + C: Clone + Codec, +{ + fn from(listener: TcpListener) -> Self { Box::new(listener) } } -impl ToConn for TcpStream { - fn to_conn(self) -> Box { - self.into() - } -} - -impl ToConn for TcpConn { - fn to_conn(self) -> Box { +impl ToConn for TcpConn +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for TcpListener { - fn to_listener(self) -> Box { +impl ToListener for TcpListener +where + C: Clone + Codec, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs index 476f495..c972f63 100644 --- a/net/src/transports/tls.rs +++ b/net/src/transports/tls.rs @@ -1,138 +1,218 @@ use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; -use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream}; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use futures_util::SinkExt; +use rustls_pki_types as pki_types; + +#[cfg(feature = "smol")] +use futures_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; +#[cfg(feature = "tokio")] +use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; + +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, net::{TcpListener, TcpStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, + Result, }; +use super::tcp::TcpConfig; + +/// TLS configuration +#[derive(Clone)] +pub struct ServerTlsConfig { + pub tcp_config: TcpConfig, + pub server_config: rustls::ServerConfig, +} + +#[derive(Clone)] +pub struct ClientTlsConfig { + pub tcp_config: TcpConfig, + pub client_config: rustls::ClientConfig, + pub dns_name: String, +} + /// TLS network connection implementation of the [`Connection`] trait. -pub struct TlsConn { - inner: TcpStream, - read: Mutex>>, - write: Mutex>>, +pub struct TlsConn { + read_stream: Mutex>, C>>, + write_stream: Mutex>, C>>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl TlsConn { +impl TlsConn +where + C: Codec + Clone, +{ /// Creates a new TlsConn - pub fn new(sock: TcpStream, conn: TlsStream) -> Self { + pub fn new( + conn: TlsStream, + codec: C, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, + ) -> Self { let (read, write) = split(conn); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: sock, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for TlsConn { +impl Connection for TlsConn +where + C: Clone + Codec, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await } } /// Connects to the given TLS address and port. -pub async fn dial( - endpoint: &Endpoint, - config: rustls::ClientConfig, - dns_name: &'static str, -) -> Result { +pub async fn dial(endpoint: &Endpoint, config: ClientTlsConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let connector = TlsConnector::from(Arc::new(config)); + let connector = TlsConnector::from(Arc::new(config.client_config.clone())); + + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.tcp_config.nodelay)?; - let sock = TcpStream::connect(addr).await?; - sock.set_nodelay(true)?; + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?; - let altname = pki_types::ServerName::try_from(dns_name)?; - let conn = connector.connect(altname, sock.clone()).await?; - Ok(TlsConn::new(sock, TlsStream::Client(conn))) + let altname = pki_types::ServerName::try_from(config.dns_name.clone())?; + let conn = connector.connect(altname, socket).await?; + Ok(TlsConn::new( + TlsStream::Client(conn), + codec, + peer_endpoint, + local_endpoint, + )) } /// Tls network listener implementation of the `Listener` [`ConnListener`] trait. -pub struct TlsListener { +pub struct TlsListener { inner: TcpListener, acceptor: TlsAcceptor, + config: ServerTlsConfig, + codec: C, +} + +impl TlsListener +where + C: Codec + Clone, +{ + pub fn new( + acceptor: TlsAcceptor, + listener: TcpListener, + config: ServerTlsConfig, + codec: C, + ) -> Self { + Self { + inner: listener, + acceptor, + config: config.clone(), + codec, + } + } } #[async_trait] -impl ConnListener for TlsListener { +impl ConnListener for TlsListener +where + C: Clone + Codec, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + Ok(Endpoint::new_tls_addr(self.inner.local_addr()?)) } - async fn accept(&self) -> Result> { - let (sock, _) = self.inner.accept().await?; - sock.set_nodelay(true)?; - let conn = self.acceptor.accept(sock.clone()).await?; - Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) + async fn accept(&self) -> Result> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.tcp_config.nodelay)?; + + let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?; + + let conn = self.acceptor.accept(socket).await?; + Ok(Box::new(TlsConn::new( + TlsStream::Server(conn), + self.codec.clone(), + peer_endpoint, + local_endpoint, + ))) } } /// Listens on the given TLS address and port. -pub async fn listen(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result { +pub async fn listen( + endpoint: &Endpoint, + config: ServerTlsConfig, + codec: C, +) -> Result> +where + C: Clone + Codec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let acceptor = TlsAcceptor::from(Arc::new(config)); + let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone())); let listener = TcpListener::bind(addr).await?; - Ok(TlsListener { - acceptor, - inner: listener, - }) -} - -impl From> for Box { - fn from(conn: TlsStream) -> Self { - Box::new(TlsConn::new(conn.get_ref().0.clone(), conn)) - } + Ok(TlsListener::new(acceptor, listener, config, codec)) } -impl From for Box { - fn from(listener: TlsListener) -> Self { +impl From> for Listener +where + C: Codec + Clone, +{ + fn from(listener: TlsListener) -> Self { Box::new(listener) } } -impl ToConn for TlsStream { - fn to_conn(self) -> Box { - self.into() - } -} - -impl ToConn for TlsConn { - fn to_conn(self) -> Box { +impl ToConn for TlsConn +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for TlsListener { - fn to_listener(self) -> Box { +impl ToListener for TlsListener +where + C: Clone + Codec, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs index bd1fe83..950537c 100644 --- a/net/src/transports/udp.rs +++ b/net/src/transports/udp.rs @@ -1,93 +1,111 @@ use std::net::SocketAddr; use async_trait::async_trait; -use smol::net::UdpSocket; +use karyon_core::async_runtime::net::UdpSocket; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, Error, Result, }; +const BUFFER_SIZE: usize = 64 * 1024; + +/// UDP configuration +#[derive(Default)] +pub struct UdpConfig {} + /// UDP network connection implementation of the [`Connection`] trait. -pub struct UdpConn { +#[allow(dead_code)] +pub struct UdpConn { inner: UdpSocket, + codec: C, + config: UdpConfig, } -impl UdpConn { +impl UdpConn +where + C: Codec + Clone, +{ /// Creates a new UdpConn - pub fn new(conn: UdpSocket) -> Self { - Self { inner: conn } - } -} - -impl UdpConn { - /// Receives a single datagram message. Returns the number of bytes read - /// and the origin endpoint. - pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, Endpoint)> { - let (size, addr) = self.inner.recv_from(buf).await?; - Ok((size, Endpoint::new_udp_addr(&addr))) - } - - /// Sends data to the given address. Returns the number of bytes written. - pub async fn send_to(&self, buf: &[u8], addr: &Endpoint) -> Result { - let addr: SocketAddr = addr.clone().try_into()?; - let size = self.inner.send_to(buf, addr).await?; - Ok(size) + fn new(socket: UdpSocket, config: UdpConfig, codec: C) -> Self { + Self { + inner: socket, + codec, + config, + } } } #[async_trait] -impl Connection for UdpConn { +impl Connection for UdpConn +where + C: Codec + Clone, +{ + type Item = (C::Item, Endpoint); fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_udp_addr(&self.inner.peer_addr()?)) + self.inner + .peer_addr() + .map(Endpoint::new_udp_addr) + .map_err(Error::from) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_udp_addr(&self.inner.local_addr()?)) + self.inner + .local_addr() + .map(Endpoint::new_udp_addr) + .map_err(Error::from) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.inner.recv(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + let mut buf = [0u8; BUFFER_SIZE]; + let (_, addr) = self.inner.recv_from(&mut buf).await?; + match self.codec.decode(&mut buf)? { + Some((_, item)) => Ok((item, Endpoint::new_udp_addr(addr))), + None => Err(Error::Decode("Unable to decode".into())), + } } - async fn write(&self, buf: &[u8]) -> Result { - self.inner.send(buf).await.map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + let (msg, out_addr) = msg; + let mut buf = [0u8; BUFFER_SIZE]; + self.codec.encode(&msg, &mut buf)?; + let addr: SocketAddr = out_addr.try_into()?; + self.inner.send_to(&buf, addr).await?; + Ok(()) } } /// Connects to the given UDP address and port. -pub async fn dial(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; // Let the operating system assign an available port to this socket let conn = UdpSocket::bind("[::]:0").await?; conn.connect(addr).await?; - Ok(UdpConn::new(conn)) + Ok(UdpConn::new(conn, config, codec)) } /// Listens on the given UDP address and port. -pub async fn listen(endpoint: &Endpoint) -> Result { +pub async fn listen(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ let addr = SocketAddr::try_from(endpoint.clone())?; let conn = UdpSocket::bind(addr).await?; - let udp_conn = UdpConn::new(conn); - Ok(udp_conn) -} - -impl From for Box { - fn from(conn: UdpSocket) -> Self { - Box::new(UdpConn::new(conn)) - } -} - -impl ToConn for UdpSocket { - fn to_conn(self) -> Box { - self.into() - } + Ok(UdpConn::new(conn, config, codec)) } -impl ToConn for UdpConn { - fn to_conn(self) -> Box { +impl ToConn for UdpConn +where + C: Codec + Clone, +{ + type Item = (C::Item, Endpoint); + fn to_conn(self) -> Conn { Box::new(self) } } diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs index 494e104..bafebaf 100644 --- a/net/src/transports/unix.rs +++ b/net/src/transports/unix.rs @@ -1,111 +1,192 @@ use async_trait::async_trait; +use futures_util::SinkExt; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use karyon_core::async_runtime::{ + io::{split, ReadHalf, WriteHalf}, lock::Mutex, - net::unix::{UnixListener, UnixStream}, + net::{UnixListener as AsyncUnixListener, UnixStream}, }; use crate::{ - connection::{Connection, ToConn}, + codec::Codec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, + listener::{ConnListener, Listener, ToListener}, + stream::{ReadStream, WriteStream}, Error, Result, }; +/// Unix Conn config +#[derive(Clone, Default)] +pub struct UnixConfig {} + /// Unix domain socket implementation of the [`Connection`] trait. -pub struct UnixConn { - inner: UnixStream, - read: Mutex>, - write: Mutex>, +pub struct UnixConn { + read_stream: Mutex, C>>, + write_stream: Mutex, C>>, + peer_endpoint: Option, + local_endpoint: Option, } -impl UnixConn { - /// Creates a new UnixConn - pub fn new(conn: UnixStream) -> Self { - let (read, write) = split(conn.clone()); +impl UnixConn +where + C: Codec + Clone, +{ + /// Creates a new TcpConn + pub fn new(conn: UnixStream, codec: C) -> Self { + let peer_endpoint = conn + .peer_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .ok(); + let local_endpoint = conn + .local_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .ok(); + + let (read, write) = split(conn); + let read_stream = Mutex::new(ReadStream::new(read, codec.clone())); + let write_stream = Mutex::new(WriteStream::new(write, codec)); Self { - inner: conn, - read: Mutex::new(read), - write: Mutex::new(write), + read_stream, + write_stream, + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for UnixConn { +impl Connection for UnixConn +where + C: Codec + Clone, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_unix_addr(&self.inner.peer_addr()?)) + self.peer_endpoint + .clone() + .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into())) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_unix_addr(&self.inner.local_addr()?)) + self.local_endpoint + .clone() + .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into())) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.read_stream.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.write_stream.lock().await.send(msg).await + } +} + +#[allow(dead_code)] +pub struct UnixListener { + inner: AsyncUnixListener, + config: UnixConfig, + codec: C, +} + +impl UnixListener +where + C: Codec + Clone, +{ + pub fn new(listener: AsyncUnixListener, config: UnixConfig, codec: C) -> Self { + Self { + inner: listener, + config, + codec, + } } } #[async_trait] -impl ConnListener for UnixListener { +impl ConnListener for UnixListener +where + C: Codec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_unix_addr(&self.local_addr()?)) + self.inner + .local_addr() + .and_then(|a| { + Ok(Endpoint::new_unix_addr( + a.as_pathname() + .ok_or(std::io::ErrorKind::AddrNotAvailable)?, + )) + }) + .map_err(Error::from) } - async fn accept(&self) -> Result> { - let (conn, _) = self.accept().await?; - Ok(Box::new(UnixConn::new(conn))) + async fn accept(&self) -> Result> { + let (conn, _) = self.inner.accept().await?; + Ok(Box::new(UnixConn::new(conn, self.codec.clone()))) } } /// Connects to the given Unix socket path. -pub async fn dial(path: &String) -> Result { +pub async fn dial(endpoint: &Endpoint, _config: UnixConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ + let path: std::path::PathBuf = endpoint.clone().try_into()?; let conn = UnixStream::connect(path).await?; - Ok(UnixConn::new(conn)) + Ok(UnixConn::new(conn, codec)) } /// Listens on the given Unix socket path. -pub fn listen(path: &String) -> Result { - let listener = UnixListener::bind(path)?; - Ok(listener) -} - -impl From for Box { - fn from(conn: UnixStream) -> Self { - Box::new(UnixConn::new(conn)) - } +pub fn listen(endpoint: &Endpoint, config: UnixConfig, codec: C) -> Result> +where + C: Codec + Clone, +{ + let path: std::path::PathBuf = endpoint.clone().try_into()?; + let listener = AsyncUnixListener::bind(path)?; + Ok(UnixListener::new(listener, config, codec)) } -impl From for Box { - fn from(listener: UnixListener) -> Self { +// impl From for Box { +// fn from(conn: UnixStream) -> Self { +// Box::new(UnixConn::new(conn)) +// } +// } + +impl From> for Listener +where + C: Codec + Clone, +{ + fn from(listener: UnixListener) -> Self { Box::new(listener) } } -impl ToConn for UnixStream { - fn to_conn(self) -> Box { - self.into() - } -} - -impl ToConn for UnixConn { - fn to_conn(self) -> Box { +impl ToConn for UnixConn +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for UnixListener { - fn to_listener(self) -> Box { +impl ToListener for UnixListener +where + C: Codec + Clone, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs index eaf3b9b..17fe924 100644 --- a/net/src/transports/ws.rs +++ b/net/src/transports/ws.rs @@ -1,112 +1,254 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use async_trait::async_trait; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, +use rustls_pki_types as pki_types; + +#[cfg(feature = "tokio")] +use async_tungstenite::tokio as async_tungstenite; + +#[cfg(feature = "smol")] +use futures_rustls::{rustls, TlsAcceptor, TlsConnector}; +#[cfg(feature = "tokio")] +use tokio_rustls::{rustls, TlsAcceptor, TlsConnector}; + +use karyon_core::async_runtime::{ lock::Mutex, net::{TcpListener, TcpStream}, }; -use ws_stream_tungstenite::WsStream; - use crate::{ - connection::{Connection, ToConn}, + codec::WebSocketCodec, + connection::{Conn, Connection, ToConn}, endpoint::Endpoint, - listener::{ConnListener, ToListener}, - Error, Result, + listener::{ConnListener, Listener, ToListener}, + stream::WsStream, + Result, }; +use super::tcp::TcpConfig; + +/// WSS configuration +#[derive(Clone)] +pub struct ServerWssConfig { + pub server_config: rustls::ServerConfig, +} + +/// WSS configuration +#[derive(Clone)] +pub struct ClientWssConfig { + pub client_config: rustls::ClientConfig, + pub dns_name: String, +} + +/// WS configuration +#[derive(Clone, Default)] +pub struct ServerWsConfig { + pub tcp_config: TcpConfig, + pub wss_config: Option, +} + +/// WS configuration +#[derive(Clone, Default)] +pub struct ClientWsConfig { + pub tcp_config: TcpConfig, + pub wss_config: Option, +} + /// WS network connection implementation of the [`Connection`] trait. -pub struct WsConn { - inner: TcpStream, - read: Mutex>>, - write: Mutex>>, +pub struct WsConn { + // XXX: remove mutex + inner: Mutex>, + peer_endpoint: Endpoint, + local_endpoint: Endpoint, } -impl WsConn { +impl WsConn +where + C: WebSocketCodec, +{ /// Creates a new WsConn - pub fn new(inner: TcpStream, conn: WsStream) -> Self { - let (read, write) = split(conn); + pub fn new(ws: WsStream, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self { Self { - inner, - read: Mutex::new(read), - write: Mutex::new(write), + inner: Mutex::new(ws), + peer_endpoint, + local_endpoint, } } } #[async_trait] -impl Connection for WsConn { +impl Connection for WsConn +where + C: WebSocketCodec, +{ + type Item = C::Item; fn peer_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.peer_addr()?)) + Ok(self.peer_endpoint.clone()) } fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + Ok(self.local_endpoint.clone()) } - async fn read(&self, buf: &mut [u8]) -> Result { - self.read.lock().await.read(buf).await.map_err(Error::from) + async fn recv(&self) -> Result { + self.inner.lock().await.recv().await } - async fn write(&self, buf: &[u8]) -> Result { - self.write - .lock() - .await - .write(buf) - .await - .map_err(Error::from) + async fn send(&self, msg: Self::Item) -> Result<()> { + self.inner.lock().await.send(msg).await } } /// Ws network listener implementation of the `Listener` [`ConnListener`] trait. -pub struct WsListener { +pub struct WsListener { inner: TcpListener, + config: ServerWsConfig, + codec: C, + tls_acceptor: Option, } #[async_trait] -impl ConnListener for WsListener { +impl ConnListener for WsListener +where + C: WebSocketCodec + Clone, +{ + type Item = C::Item; fn local_endpoint(&self) -> Result { - Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?)) + match self.config.wss_config { + Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)), + None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)), + } } - async fn accept(&self) -> Result> { - let (stream, _) = self.inner.accept().await?; - let conn = async_tungstenite::accept_async(stream.clone()).await?; - Ok(Box::new(WsConn::new(stream, WsStream::new(conn)))) + async fn accept(&self) -> Result> { + let (socket, _) = self.inner.accept().await?; + socket.set_nodelay(self.config.tcp_config.nodelay)?; + + match &self.config.wss_config { + Some(_) => match &self.tls_acceptor { + Some(acceptor) => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; + + let tls_conn = acceptor.accept(socket).await?.into(); + let conn = async_tungstenite::accept_async(tls_conn).await?; + Ok(Box::new(WsConn::new( + WsStream::new_wss(conn, self.codec.clone()), + peer_endpoint, + local_endpoint, + ))) + } + None => unreachable!(), + }, + None => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; + + let conn = async_tungstenite::accept_async(socket).await?; + + Ok(Box::new(WsConn::new( + WsStream::new_ws(conn, self.codec.clone()), + peer_endpoint, + local_endpoint, + ))) + } + } } } /// Connects to the given WS address and port. -pub async fn dial(endpoint: &Endpoint) -> Result { +pub async fn dial(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result> +where + C: WebSocketCodec, +{ let addr = SocketAddr::try_from(endpoint.clone())?; - let stream = TcpStream::connect(addr).await?; - let (conn, _resp) = - async_tungstenite::client_async(endpoint.to_string(), stream.clone()).await?; - Ok(WsConn::new(stream, WsStream::new(conn))) + let socket = TcpStream::connect(addr).await?; + socket.set_nodelay(config.tcp_config.nodelay)?; + + match &config.wss_config { + Some(conf) => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; + + let connector = TlsConnector::from(Arc::new(conf.client_config.clone())); + + let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?; + let tls_conn = connector.connect(altname, socket).await?.into(); + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), tls_conn).await?; + Ok(WsConn::new( + WsStream::new_wss(conn, codec), + peer_endpoint, + local_endpoint, + )) + } + None => { + let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?; + let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?; + let (conn, _resp) = + async_tungstenite::client_async(endpoint.to_string(), socket).await?; + Ok(WsConn::new( + WsStream::new_ws(conn, codec), + peer_endpoint, + local_endpoint, + )) + } + } } /// Listens on the given WS address and port. -pub async fn listen(endpoint: &Endpoint) -> Result { +pub async fn listen( + endpoint: &Endpoint, + config: ServerWsConfig, + codec: C, +) -> Result> { let addr = SocketAddr::try_from(endpoint.clone())?; + let listener = TcpListener::bind(addr).await?; - Ok(WsListener { inner: listener }) + match &config.wss_config { + Some(conf) => { + let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone())); + Ok(WsListener { + inner: listener, + config, + codec, + tls_acceptor: Some(acceptor), + }) + } + None => Ok(WsListener { + inner: listener, + config, + codec, + tls_acceptor: None, + }), + } } -impl From for Box { - fn from(listener: WsListener) -> Self { +impl From> for Listener +where + C: WebSocketCodec + Clone, +{ + fn from(listener: WsListener) -> Self { Box::new(listener) } } -impl ToConn for WsConn { - fn to_conn(self) -> Box { +impl ToConn for WsConn +where + C: WebSocketCodec, +{ + type Item = C::Item; + fn to_conn(self) -> Conn { Box::new(self) } } -impl ToListener for WsListener { - fn to_listener(self) -> Box { +impl ToListener for WsListener +where + C: WebSocketCodec + Clone, +{ + type Item = C::Item; + fn to_listener(self) -> Listener { self.into() } } -- cgit v1.2.3