aboutsummaryrefslogtreecommitdiff
path: root/net
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-04-11 10:19:20 +0200
committerhozan23 <hozan23@karyontech.net>2024-05-19 13:51:30 +0200
commit0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch)
tree961d73218af672797d49f899289bef295bc56493 /net
parenta69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff)
add support for tokio & improve net crate api
Diffstat (limited to 'net')
-rw-r--r--net/Cargo.toml38
-rw-r--r--net/examples/tcp_codec.rs59
-rw-r--r--net/src/codec/bytes_codec.rs29
-rw-r--r--net/src/codec/length_codec.rs49
-rw-r--r--net/src/codec/mod.rs25
-rw-r--r--net/src/codec/websocket.rs23
-rw-r--r--net/src/connection.rs53
-rw-r--r--net/src/endpoint.rs78
-rw-r--r--net/src/error.rs24
-rw-r--r--net/src/lib.rs14
-rw-r--r--net/src/listener.rs41
-rw-r--r--net/src/stream/buffer.rs82
-rw-r--r--net/src/stream/mod.rs191
-rw-r--r--net/src/stream/websocket.rs107
-rw-r--r--net/src/transports/tcp.rs188
-rw-r--r--net/src/transports/tls.rs220
-rw-r--r--net/src/transports/udp.rs114
-rw-r--r--net/src/transports/unix.rs193
-rw-r--r--net/src/transports/ws.rs242
19 files changed, 1364 insertions, 406 deletions
diff --git a/net/Cargo.toml b/net/Cargo.toml
index fe209cd..304cbb2 100644
--- a/net/Cargo.toml
+++ b/net/Cargo.toml
@@ -1,19 +1,43 @@
[package]
name = "karyon_net"
-version.workspace = true
+version.workspace = true
edition.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+[features]
+default = ["smol"]
+smol = [
+ "karyon_core/smol",
+ "async-tungstenite/async-std-runtime",
+ "dep:futures-rustls",
+]
+tokio = [
+ "karyon_core/tokio",
+ "async-tungstenite/tokio-runtime",
+ "dep:tokio",
+ "dep:tokio-rustls",
+]
+
+
[dependencies]
-karyon_core.workspace = true
+karyon_core = { workspace = true, default-features = false }
-smol = "2.0.0"
+pin-project-lite = "0.2.13"
async-trait = "0.1.77"
log = "0.4.21"
-bincode = { version="2.0.0-rc.3", features = ["derive"]}
+bincode = { version = "2.0.0-rc.3", features = ["derive"] }
thiserror = "1.0.58"
url = "2.5.0"
-futures-rustls = "0.25.1"
-async-tungstenite = "0.25.0"
-ws_stream_tungstenite = "0.13.0"
+async-tungstenite = { version = "0.25.0", default-features = false }
+asynchronous-codec = "0.7.0"
+futures-util = "0.3.30"
+async-channel = "2.3.0"
+rustls-pki-types = "1.7.0"
+
+futures-rustls = { version = "0.25.1", optional = true }
+tokio-rustls = { version = "0.26.0", optional = true }
+tokio = { version = "1.37.0", features = ["io-util"], optional = true }
+
+[dev-dependencies]
+smol = "2.0.0"
diff --git a/net/examples/tcp_codec.rs b/net/examples/tcp_codec.rs
new file mode 100644
index 0000000..93deaae
--- /dev/null
+++ b/net/examples/tcp_codec.rs
@@ -0,0 +1,59 @@
+use std::time::Duration;
+
+use karyon_core::async_util::sleep;
+
+use karyon_net::{
+ codec::{Codec, Decoder, Encoder},
+ tcp, ConnListener, Connection, Endpoint, Result,
+};
+
+#[derive(Clone)]
+struct NewLineCodec {}
+
+impl Codec for NewLineCodec {
+ type Item = String;
+}
+
+impl Encoder for NewLineCodec {
+ type EnItem = String;
+ fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> {
+ dst[..src.len()].copy_from_slice(src.as_bytes());
+ Ok(src.len())
+ }
+}
+
+impl Decoder for NewLineCodec {
+ type DeItem = String;
+ fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> {
+ match src.iter().position(|&b| b == b'\n') {
+ Some(i) => Ok(Some((i + 1, String::from_utf8(src[..i].to_vec()).unwrap()))),
+ None => Ok(None),
+ }
+ }
+}
+
+fn main() {
+ smol::block_on(async {
+ let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap();
+
+ let config = tcp::TcpConfig::default();
+
+ let listener = tcp::listen(&endpoint, config.clone(), NewLineCodec {})
+ .await
+ .unwrap();
+ smol::spawn(async move {
+ if let Ok(conn) = listener.accept().await {
+ loop {
+ let msg = conn.recv().await.unwrap();
+ println!("Receive a message: {:?}", msg);
+ }
+ };
+ })
+ .detach();
+
+ let conn = tcp::dial(&endpoint, config, NewLineCodec {}).await.unwrap();
+ conn.send("hello".to_string()).await.unwrap();
+ conn.send(" world\n".to_string()).await.unwrap();
+ sleep(Duration::from_secs(1)).await;
+ });
+}
diff --git a/net/src/codec/bytes_codec.rs b/net/src/codec/bytes_codec.rs
new file mode 100644
index 0000000..b319e53
--- /dev/null
+++ b/net/src/codec/bytes_codec.rs
@@ -0,0 +1,29 @@
+use crate::{
+ codec::{Codec, Decoder, Encoder},
+ Result,
+};
+
+#[derive(Clone)]
+pub struct BytesCodec {}
+impl Codec for BytesCodec {
+ type Item = Vec<u8>;
+}
+
+impl Encoder for BytesCodec {
+ type EnItem = Vec<u8>;
+ fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> {
+ dst[..src.len()].copy_from_slice(src);
+ Ok(src.len())
+ }
+}
+
+impl Decoder for BytesCodec {
+ type DeItem = Vec<u8>;
+ fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> {
+ if src.is_empty() {
+ Ok(None)
+ } else {
+ Ok(Some((src.len(), src.to_vec())))
+ }
+ }
+}
diff --git a/net/src/codec/length_codec.rs b/net/src/codec/length_codec.rs
new file mode 100644
index 0000000..76a1679
--- /dev/null
+++ b/net/src/codec/length_codec.rs
@@ -0,0 +1,49 @@
+use karyon_core::util::{decode, encode_into_slice};
+
+use crate::{
+ codec::{Codec, Decoder, Encoder},
+ Result,
+};
+
+/// The size of the message length.
+const MSG_LENGTH_SIZE: usize = std::mem::size_of::<u32>();
+
+#[derive(Clone)]
+pub struct LengthCodec {}
+impl Codec for LengthCodec {
+ type Item = Vec<u8>;
+}
+
+impl Encoder for LengthCodec {
+ type EnItem = Vec<u8>;
+ fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> {
+ let length_buf = &mut [0; MSG_LENGTH_SIZE];
+ encode_into_slice(&(src.len() as u32), length_buf)?;
+ dst[..MSG_LENGTH_SIZE].copy_from_slice(length_buf);
+ dst[MSG_LENGTH_SIZE..src.len() + MSG_LENGTH_SIZE].copy_from_slice(src);
+ Ok(src.len() + MSG_LENGTH_SIZE)
+ }
+}
+
+impl Decoder for LengthCodec {
+ type DeItem = Vec<u8>;
+ fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> {
+ if src.len() < MSG_LENGTH_SIZE {
+ return Ok(None);
+ }
+
+ let mut length = [0; MSG_LENGTH_SIZE];
+ length.copy_from_slice(&src[..MSG_LENGTH_SIZE]);
+ let (length, _) = decode::<u32>(&length)?;
+ let length = length as usize;
+
+ if src.len() - MSG_LENGTH_SIZE >= length {
+ Ok(Some((
+ length + MSG_LENGTH_SIZE,
+ src[MSG_LENGTH_SIZE..length + MSG_LENGTH_SIZE].to_vec(),
+ )))
+ } else {
+ Ok(None)
+ }
+ }
+}
diff --git a/net/src/codec/mod.rs b/net/src/codec/mod.rs
new file mode 100644
index 0000000..565cb07
--- /dev/null
+++ b/net/src/codec/mod.rs
@@ -0,0 +1,25 @@
+mod bytes_codec;
+mod length_codec;
+mod websocket;
+
+pub use bytes_codec::BytesCodec;
+pub use length_codec::LengthCodec;
+pub use websocket::{WebSocketCodec, WebSocketDecoder, WebSocketEncoder};
+
+use crate::Result;
+
+pub trait Codec:
+ Decoder<DeItem = Self::Item> + Encoder<EnItem = Self::Item> + Send + Sync + 'static + Unpin
+{
+ type Item: Send + Sync;
+}
+
+pub trait Encoder {
+ type EnItem;
+ fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize>;
+}
+
+pub trait Decoder {
+ type DeItem;
+ fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>>;
+}
diff --git a/net/src/codec/websocket.rs b/net/src/codec/websocket.rs
new file mode 100644
index 0000000..b59a55c
--- /dev/null
+++ b/net/src/codec/websocket.rs
@@ -0,0 +1,23 @@
+use crate::Result;
+use async_tungstenite::tungstenite::Message;
+
+pub trait WebSocketCodec:
+ WebSocketDecoder<DeItem = Self::Item>
+ + WebSocketEncoder<EnItem = Self::Item>
+ + Send
+ + Sync
+ + 'static
+ + Unpin
+{
+ type Item: Send + Sync;
+}
+
+pub trait WebSocketEncoder {
+ type EnItem;
+ fn encode(&self, src: &Self::EnItem) -> Result<Message>;
+}
+
+pub trait WebSocketDecoder {
+ type DeItem;
+ fn decode(&self, src: &Message) -> Result<Self::DeItem>;
+}
diff --git a/net/src/connection.rs b/net/src/connection.rs
index fa4640f..bbd21de 100644
--- a/net/src/connection.rs
+++ b/net/src/connection.rs
@@ -1,65 +1,34 @@
use async_trait::async_trait;
-use crate::{
- transports::{tcp, udp, unix},
- Endpoint, Error, Result,
-};
+use crate::{Endpoint, Result};
/// Alias for `Box<dyn Connection>`
-pub type Conn = Box<dyn Connection>;
+pub type Conn<T> = Box<dyn Connection<Item = T>>;
/// A trait for objects which can be converted to [`Conn`].
pub trait ToConn {
- fn to_conn(self) -> Conn;
+ type Item;
+ fn to_conn(self) -> Conn<Self::Item>;
}
/// Connection is a generic network connection interface for
-/// [`udp::UdpConn`], [`tcp::TcpConn`], and [`unix::UnixConn`].
+/// [`udp::UdpConn`], [`tcp::TcpConn`], [`tls::TlsConn`], [`ws::WsConn`],
+/// and [`unix::UnixConn`].
///
/// If you are familiar with the Go language, this is similar to the
/// [Conn](https://pkg.go.dev/net#Conn) interface
#[async_trait]
pub trait Connection: Send + Sync {
+ type Item;
/// Returns the remote peer endpoint of this connection
fn peer_endpoint(&self) -> Result<Endpoint>;
/// Returns the local socket endpoint of this connection
fn local_endpoint(&self) -> Result<Endpoint>;
- /// Reads data from this connection.
- async fn read(&self, buf: &mut [u8]) -> Result<usize>;
+ /// Recvs data from this connection.
+ async fn recv(&self) -> Result<Self::Item>;
- /// Writes data to this connection
- async fn write(&self, buf: &[u8]) -> Result<usize>;
-}
-
-/// Connects to the provided endpoint.
-///
-/// it only supports `tcp4/6`, `udp4/6`, and `unix`.
-///
-/// #Example
-///
-/// ```
-/// use karyon_net::{Endpoint, dial};
-///
-/// async {
-/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap();
-///
-/// let conn = dial(&endpoint).await.unwrap();
-///
-/// conn.write(b"MSG").await.unwrap();
-///
-/// let mut buffer = [0;32];
-/// conn.read(&mut buffer).await.unwrap();
-/// };
-///
-/// ```
-///
-pub async fn dial(endpoint: &Endpoint) -> Result<Conn> {
- match endpoint {
- Endpoint::Tcp(_, _) => Ok(Box::new(tcp::dial(endpoint).await?)),
- Endpoint::Udp(_, _) => Ok(Box::new(udp::dial(endpoint).await?)),
- Endpoint::Unix(addr) => Ok(Box::new(unix::dial(addr).await?)),
- _ => Err(Error::InvalidEndpoint(endpoint.to_string())),
- }
+ /// Sends data to this connection
+ async fn send(&self, msg: Self::Item) -> Result<()>;
}
diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs
index 9193628..0c7ecd1 100644
--- a/net/src/endpoint.rs
+++ b/net/src/endpoint.rs
@@ -1,10 +1,11 @@
use std::{
net::{IpAddr, SocketAddr},
- os::unix::net::SocketAddr as UnixSocketAddress,
path::PathBuf,
str::FromStr,
};
+use std::os::unix::net::SocketAddr as UnixSocketAddr;
+
use bincode::{Decode, Encode};
use url::Url;
@@ -25,7 +26,7 @@ pub type Port = u16;
/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap();
///
/// let socketaddr: SocketAddr = "127.0.0.1:3000".parse().unwrap();
-/// let endpoint = Endpoint::new_udp_addr(&socketaddr);
+/// let endpoint = Endpoint::new_udp_addr(socketaddr);
///
/// ```
///
@@ -35,7 +36,8 @@ pub enum Endpoint {
Tcp(Addr, Port),
Tls(Addr, Port),
Ws(Addr, Port),
- Unix(String),
+ Wss(Addr, Port),
+ Unix(PathBuf),
}
impl std::fmt::Display for Endpoint {
@@ -53,12 +55,11 @@ impl std::fmt::Display for Endpoint {
Endpoint::Ws(ip, port) => {
write!(f, "ws://{}:{}", ip, port)
}
+ Endpoint::Wss(ip, port) => {
+ write!(f, "wss://{}:{}", ip, port)
+ }
Endpoint::Unix(path) => {
- if path.is_empty() {
- write!(f, "unix:/UNNAMED")
- } else {
- write!(f, "unix:/{}", path)
- }
+ write!(f, "unix:/{}", path.to_string_lossy())
}
}
}
@@ -71,7 +72,8 @@ impl TryFrom<Endpoint> for SocketAddr {
Endpoint::Udp(ip, port)
| Endpoint::Tcp(ip, port)
| Endpoint::Tls(ip, port)
- | Endpoint::Ws(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
+ | Endpoint::Ws(ip, port)
+ | Endpoint::Wss(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
}
}
@@ -87,11 +89,11 @@ impl TryFrom<Endpoint> for PathBuf {
}
}
-impl TryFrom<Endpoint> for UnixSocketAddress {
+impl TryFrom<Endpoint> for UnixSocketAddr {
type Error = Error;
- fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddress, Self::Error> {
+ fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddr, Self::Error> {
match endpoint {
- Endpoint::Unix(a) => Ok(UnixSocketAddress::from_pathname(a)?),
+ Endpoint::Unix(a) => Ok(UnixSocketAddr::from_pathname(a)?),
_ => Err(Error::TryFromEndpoint),
}
}
@@ -124,6 +126,7 @@ impl FromStr for Endpoint {
"udp" => Ok(Endpoint::Udp(addr, port)),
"tls" => Ok(Endpoint::Tls(addr, port)),
"ws" => Ok(Endpoint::Ws(addr, port)),
+ "wss" => Ok(Endpoint::Wss(addr, port)),
_ => Err(Error::InvalidEndpoint(s.to_string())),
}
} else {
@@ -132,7 +135,7 @@ impl FromStr for Endpoint {
}
match url.scheme() {
- "unix" => Ok(Endpoint::Unix(url.path().to_string())),
+ "unix" => Ok(Endpoint::Unix(url.path().into())),
_ => Err(Error::InvalidEndpoint(s.to_string())),
}
}
@@ -141,33 +144,33 @@ impl FromStr for Endpoint {
impl Endpoint {
/// Creates a new TCP endpoint from a `SocketAddr`.
- pub fn new_tcp_addr(addr: &SocketAddr) -> Endpoint {
+ pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint {
Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
}
/// Creates a new UDP endpoint from a `SocketAddr`.
- pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint {
+ pub fn new_udp_addr(addr: SocketAddr) -> Endpoint {
Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
}
/// Creates a new TLS endpoint from a `SocketAddr`.
- pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint {
+ pub fn new_tls_addr(addr: SocketAddr) -> Endpoint {
Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
}
/// Creates a new WS endpoint from a `SocketAddr`.
- pub fn new_ws_addr(addr: &SocketAddr) -> Endpoint {
+ pub fn new_ws_addr(addr: SocketAddr) -> Endpoint {
Endpoint::Ws(Addr::Ip(addr.ip()), addr.port())
}
- /// Creates a new Unix endpoint from a `UnixSocketAddress`.
- pub fn new_unix_addr(addr: &UnixSocketAddress) -> Endpoint {
- Endpoint::Unix(
- addr.as_pathname()
- .and_then(|a| a.to_str())
- .unwrap_or("")
- .to_string(),
- )
+ /// Creates a new WSS endpoint from a `SocketAddr`.
+ pub fn new_wss_addr(addr: SocketAddr) -> Endpoint {
+ Endpoint::Wss(Addr::Ip(addr.ip()), addr.port())
+ }
+
+ /// Creates a new Unix endpoint from a `UnixSocketAddr`.
+ pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint {
+ Endpoint::Unix(addr.to_path_buf())
}
/// Returns the `Port` of the endpoint.
@@ -176,7 +179,8 @@ impl Endpoint {
Endpoint::Tcp(_, port)
| Endpoint::Udp(_, port)
| Endpoint::Tls(_, port)
- | Endpoint::Ws(_, port) => Ok(port),
+ | Endpoint::Ws(_, port)
+ | Endpoint::Wss(_, port) => Ok(port),
_ => Err(Error::TryFromEndpoint),
}
}
@@ -187,7 +191,8 @@ impl Endpoint {
Endpoint::Tcp(addr, _)
| Endpoint::Udp(addr, _)
| Endpoint::Tls(addr, _)
- | Endpoint::Ws(addr, _) => Ok(addr),
+ | Endpoint::Ws(addr, _)
+ | Endpoint::Wss(addr, _) => Ok(addr),
_ => Err(Error::TryFromEndpoint),
}
}
@@ -223,10 +228,27 @@ impl std::fmt::Display for Addr {
}
}
+pub trait ToEndpoint {
+ fn to_endpoint(&self) -> Result<Endpoint>;
+}
+
+impl ToEndpoint for String {
+ fn to_endpoint(&self) -> Result<Endpoint> {
+ Endpoint::from_str(self)
+ }
+}
+
+impl ToEndpoint for &str {
+ fn to_endpoint(&self) -> Result<Endpoint> {
+ Endpoint::from_str(self)
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
use std::net::Ipv4Addr;
+ use std::path::PathBuf;
#[test]
fn test_endpoint_from_str() {
@@ -243,7 +265,7 @@ mod tests {
assert_eq!(endpoint_str, endpoint);
let endpoint_str = "unix:/home/x/s.socket".parse::<Endpoint>().unwrap();
- let endpoint = Endpoint::Unix("/home/x/s.socket".to_string());
+ let endpoint = Endpoint::Unix(PathBuf::from_str("/home/x/s.socket").unwrap());
assert_eq!(endpoint_str, endpoint);
}
}
diff --git a/net/src/error.rs b/net/src/error.rs
index 6e04a12..ee93168 100644
--- a/net/src/error.rs
+++ b/net/src/error.rs
@@ -13,9 +13,18 @@ pub enum Error {
#[error("invalid address {0}")]
InvalidAddress(String),
+ #[error("invalid path {0}")]
+ InvalidPath(String),
+
#[error("invalid endpoint {0}")]
InvalidEndpoint(String),
+ #[error("Encode error: {0}")]
+ Encode(String),
+
+ #[error("Decode error: {0}")]
+ Decode(String),
+
#[error("Parse endpoint error {0}")]
ParseEndpoint(String),
@@ -26,23 +35,28 @@ pub enum Error {
ChannelSend(String),
#[error(transparent)]
- ChannelRecv(#[from] smol::channel::RecvError),
+ ChannelRecv(#[from] async_channel::RecvError),
#[error("Ws Error: {0}")]
WsError(#[from] async_tungstenite::tungstenite::Error),
+ #[cfg(feature = "smol")]
#[error("Tls Error: {0}")]
Rustls(#[from] futures_rustls::rustls::Error),
+ #[cfg(feature = "tokio")]
+ #[error("Tls Error: {0}")]
+ Rustls(#[from] tokio_rustls::rustls::Error),
+
#[error("Invalid DNS Name: {0}")]
- InvalidDnsNameError(#[from] futures_rustls::pki_types::InvalidDnsNameError),
+ InvalidDnsNameError(#[from] rustls_pki_types::InvalidDnsNameError),
#[error(transparent)]
- KaryonCore(#[from] karyon_core::error::Error),
+ KaryonCore(#[from] karyon_core::Error),
}
-impl<T> From<smol::channel::SendError<T>> for Error {
- fn from(error: smol::channel::SendError<T>) -> Self {
+impl<T> From<async_channel::SendError<T>> for Error {
+ fn from(error: async_channel::SendError<T>) -> Self {
Error::ChannelSend(error.to_string())
}
}
diff --git a/net/src/lib.rs b/net/src/lib.rs
index c1d72b2..ddb53cf 100644
--- a/net/src/lib.rs
+++ b/net/src/lib.rs
@@ -1,20 +1,20 @@
+pub mod codec;
mod connection;
mod endpoint;
mod error;
mod listener;
+mod stream;
mod transports;
pub use {
- connection::{dial, Conn, Connection, ToConn},
- endpoint::{Addr, Endpoint, Port},
- listener::{listen, ConnListener, Listener, ToListener},
+ connection::{Conn, Connection, ToConn},
+ endpoint::{Addr, Endpoint, Port, ToEndpoint},
+ listener::{ConnListener, Listener, ToListener},
transports::{tcp, tls, udp, unix, ws},
};
-use error::{Error, Result};
-
/// Represents karyon's Net Error
-pub use error::Error as NetError;
+pub use error::Error;
/// Represents karyon's Net Result
-pub use error::Result as NetResult;
+pub use error::Result;
diff --git a/net/src/listener.rs b/net/src/listener.rs
index 4511212..469f5e9 100644
--- a/net/src/listener.rs
+++ b/net/src/listener.rs
@@ -1,46 +1,21 @@
use async_trait::async_trait;
-use crate::{
- transports::{tcp, unix},
- Conn, Endpoint, Error, Result,
-};
+use crate::{Conn, Endpoint, Result};
/// Alias for `Box<dyn ConnListener>`
-pub type Listener = Box<dyn ConnListener>;
+pub type Listener<T> = Box<dyn ConnListener<Item = T>>;
/// A trait for objects which can be converted to [`Listener`].
pub trait ToListener {
- fn to_listener(self) -> Listener;
+ type Item;
+ fn to_listener(self) -> Listener<Self::Item>;
}
-/// ConnListener is a generic network listener.
+/// ConnListener is a generic network listener interface for
+/// [`tcp::TcpConn`], [`tls::TlsConn`], [`ws::WsConn`], and [`unix::UnixConn`].
#[async_trait]
pub trait ConnListener: Send + Sync {
+ type Item;
fn local_endpoint(&self) -> Result<Endpoint>;
- async fn accept(&self) -> Result<Conn>;
-}
-
-/// Listens to the provided endpoint.
-///
-/// it only supports `tcp4/6`, and `unix`.
-///
-/// #Example
-///
-/// ```
-/// use karyon_net::{Endpoint, listen};
-///
-/// async {
-/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap();
-///
-/// let listener = listen(&endpoint).await.unwrap();
-/// let conn = listener.accept().await.unwrap();
-/// };
-///
-/// ```
-pub async fn listen(endpoint: &Endpoint) -> Result<Box<dyn ConnListener>> {
- match endpoint {
- Endpoint::Tcp(_, _) => Ok(Box::new(tcp::listen(endpoint).await?)),
- Endpoint::Unix(addr) => Ok(Box::new(unix::listen(addr)?)),
- _ => Err(Error::InvalidEndpoint(endpoint.to_string())),
- }
+ async fn accept(&self) -> Result<Conn<Self::Item>>;
}
diff --git a/net/src/stream/buffer.rs b/net/src/stream/buffer.rs
new file mode 100644
index 0000000..f211600
--- /dev/null
+++ b/net/src/stream/buffer.rs
@@ -0,0 +1,82 @@
+#[derive(Debug)]
+pub struct Buffer<B> {
+ inner: B,
+ len: usize,
+ cap: usize,
+}
+
+impl<B> Buffer<B>
+where
+ B: AsMut<[u8]> + AsRef<[u8]>,
+{
+ /// Constructs a new, empty Buffer<B>.
+ pub fn new(b: B) -> Self {
+ Self {
+ cap: b.as_ref().len(),
+ inner: b,
+ len: 0,
+ }
+ }
+
+ /// Returns the number of elements in the buffer.
+ #[allow(dead_code)]
+ pub fn len(&self) -> usize {
+ self.len
+ }
+
+ /// Resizes the buffer in-place so that `len` is equal to `new_size`.
+ pub fn resize(&mut self, new_size: usize) {
+ assert!(self.cap > new_size);
+ self.len = new_size;
+ }
+
+ /// Appends all elements in a slice to the buffer.
+ pub fn extend_from_slice(&mut self, bytes: &[u8]) {
+ let old_len = self.len;
+ self.resize(self.len + bytes.len());
+ self.inner.as_mut()[old_len..bytes.len() + old_len].copy_from_slice(bytes);
+ }
+
+ /// Shortens the buffer, dropping the first `cnt` bytes and keeping the
+ /// rest.
+ pub fn advance(&mut self, cnt: usize) {
+ assert!(self.len >= cnt);
+ self.inner.as_mut().rotate_left(cnt);
+ self.resize(self.len - cnt);
+ }
+
+ /// Returns `true` if the buffer contains no elements.
+ pub fn is_empty(&self) -> bool {
+ self.len == 0
+ }
+}
+
+impl<B> AsMut<[u8]> for Buffer<B>
+where
+ B: AsMut<[u8]> + AsRef<[u8]>,
+{
+ fn as_mut(&mut self) -> &mut [u8] {
+ &mut self.inner.as_mut()[..self.len]
+ }
+}
+
+impl<B> AsRef<[u8]> for Buffer<B>
+where
+ B: AsMut<[u8]> + AsRef<[u8]>,
+{
+ fn as_ref(&self) -> &[u8] {
+ &self.inner.as_ref()[..self.len]
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_buffer_advance() {
+ let mut buf = Buffer::new([0u8; 32]);
+ buf.extend_from_slice(&[1, 2, 3]);
+ assert_eq!([1, 2, 3], buf.as_ref());
+ }
+}
diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs
new file mode 100644
index 0000000..9493b29
--- /dev/null
+++ b/net/src/stream/mod.rs
@@ -0,0 +1,191 @@
+mod buffer;
+mod websocket;
+
+pub use websocket::WsStream;
+
+use std::{
+ io::ErrorKind,
+ pin::Pin,
+ task::{Context, Poll},
+};
+
+use futures_util::{
+ ready,
+ stream::{Stream, StreamExt},
+ Sink,
+};
+use pin_project_lite::pin_project;
+
+use karyon_core::async_runtime::io::{AsyncRead, AsyncWrite};
+
+use crate::{
+ codec::{Decoder, Encoder},
+ Error, Result,
+};
+
+use buffer::Buffer;
+
+const BUFFER_SIZE: usize = 2048 * 2024; // 4MB
+const INITIAL_BUFFER_SIZE: usize = 1024 * 1024; // 1MB
+
+pub struct ReadStream<T, C> {
+ inner: T,
+ decoder: C,
+ buffer: Buffer<[u8; BUFFER_SIZE]>,
+}
+
+impl<T, C> ReadStream<T, C>
+where
+ T: AsyncRead + Unpin,
+ C: Decoder + Unpin,
+{
+ pub fn new(inner: T, decoder: C) -> Self {
+ Self {
+ inner,
+ decoder,
+ buffer: Buffer::new([0u8; BUFFER_SIZE]),
+ }
+ }
+
+ pub async fn recv(&mut self) -> Result<C::DeItem> {
+ match self.next().await {
+ Some(m) => m,
+ None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())),
+ }
+ }
+}
+
+pin_project! {
+ pub struct WriteStream<T, C> {
+ #[pin]
+ inner: T,
+ encoder: C,
+ high_water_mark: usize,
+ buffer: Buffer<[u8; BUFFER_SIZE]>,
+ }
+}
+
+impl<T, C> WriteStream<T, C>
+where
+ T: AsyncWrite + Unpin,
+ C: Encoder + Unpin,
+{
+ pub fn new(inner: T, encoder: C) -> Self {
+ Self {
+ inner,
+ encoder,
+ high_water_mark: 131072,
+ buffer: Buffer::new([0u8; BUFFER_SIZE]),
+ }
+ }
+}
+
+impl<T, C> Stream for ReadStream<T, C>
+where
+ T: AsyncRead + Unpin,
+ C: Decoder + Unpin,
+{
+ type Item = Result<C::DeItem>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let this = &mut *self;
+
+ if let Some((n, item)) = this.decoder.decode(this.buffer.as_mut())? {
+ this.buffer.advance(n);
+ return Poll::Ready(Some(Ok(item)));
+ }
+
+ let mut buf = [0u8; INITIAL_BUFFER_SIZE];
+ #[cfg(feature = "tokio")]
+ let mut buf = tokio::io::ReadBuf::new(&mut buf);
+
+ loop {
+ #[cfg(feature = "smol")]
+ let n = ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
+ #[cfg(feature = "smol")]
+ let bytes = &buf[..n];
+
+ #[cfg(feature = "tokio")]
+ ready!(Pin::new(&mut this.inner).poll_read(cx, &mut buf))?;
+ #[cfg(feature = "tokio")]
+ let bytes = buf.filled();
+ #[cfg(feature = "tokio")]
+ let n = bytes.len();
+
+ this.buffer.extend_from_slice(bytes);
+
+ match this.decoder.decode(this.buffer.as_mut())? {
+ Some((cn, item)) => {
+ this.buffer.advance(cn);
+ return Poll::Ready(Some(Ok(item)));
+ }
+ None if n == 0 => {
+ if this.buffer.is_empty() {
+ return Poll::Ready(None);
+ } else {
+ return Poll::Ready(Some(Err(std::io::Error::new(
+ std::io::ErrorKind::UnexpectedEof,
+ "bytes remaining in read stream",
+ )
+ .into())));
+ }
+ }
+ _ => continue,
+ }
+ }
+ }
+}
+
+impl<T, C> Sink<C::EnItem> for WriteStream<T, C>
+where
+ T: AsyncWrite + Unpin,
+ C: Encoder + Unpin,
+{
+ type Error = Error;
+
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<()>> {
+ let this = &mut *self;
+ while !this.buffer.is_empty() {
+ let n = ready!(Pin::new(&mut this.inner).poll_write(cx, this.buffer.as_ref()))?;
+
+ if n == 0 {
+ return Poll::Ready(Err(std::io::Error::new(
+ ErrorKind::UnexpectedEof,
+ "End of file",
+ )
+ .into()));
+ }
+
+ this.buffer.advance(n);
+ }
+
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(mut self: Pin<&mut Self>, item: C::EnItem) -> Result<()> {
+ let this = &mut *self;
+ let mut buf = [0u8; INITIAL_BUFFER_SIZE];
+ let n = this.encoder.encode(&item, &mut buf)?;
+ this.buffer.extend_from_slice(&buf[..n]);
+ Ok(())
+ }
+
+ fn poll_flush(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context,
+ ) -> Poll<std::result::Result<(), Self::Error>> {
+ ready!(self.as_mut().poll_ready(cx))?;
+ self.project().inner.poll_flush(cx).map_err(Into::into)
+ }
+
+ fn poll_close(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context,
+ ) -> Poll<std::result::Result<(), Self::Error>> {
+ ready!(self.as_mut().poll_flush(cx))?;
+ #[cfg(feature = "smol")]
+ return self.project().inner.poll_close(cx).map_err(Error::from);
+ #[cfg(feature = "tokio")]
+ return self.project().inner.poll_shutdown(cx).map_err(Error::from);
+ }
+}
diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs
new file mode 100644
index 0000000..2552eaf
--- /dev/null
+++ b/net/src/stream/websocket.rs
@@ -0,0 +1,107 @@
+use std::{
+ pin::Pin,
+ task::{Context, Poll},
+};
+
+use async_tungstenite::tungstenite::Message;
+use futures_util::{Sink, SinkExt, Stream, StreamExt};
+
+#[cfg(feature = "smol")]
+use futures_rustls::TlsStream;
+#[cfg(feature = "tokio")]
+use tokio_rustls::TlsStream;
+
+use karyon_core::async_runtime::net::TcpStream;
+
+use crate::{codec::WebSocketCodec, Error, Result};
+
+#[cfg(feature = "tokio")]
+type WebSocketStream<T> =
+ async_tungstenite::WebSocketStream<async_tungstenite::tokio::TokioAdapter<T>>;
+#[cfg(feature = "smol")]
+use async_tungstenite::WebSocketStream;
+
+pub struct WsStream<C> {
+ inner: InnerWSConn,
+ codec: C,
+}
+
+impl<C> WsStream<C>
+where
+ C: WebSocketCodec,
+{
+ pub fn new_ws(conn: WebSocketStream<TcpStream>, codec: C) -> Self {
+ Self {
+ inner: InnerWSConn::Plain(conn),
+ codec,
+ }
+ }
+
+ pub fn new_wss(conn: WebSocketStream<TlsStream<TcpStream>>, codec: C) -> Self {
+ Self {
+ inner: InnerWSConn::Tls(conn),
+ codec,
+ }
+ }
+
+ pub async fn recv(&mut self) -> Result<C::Item> {
+ match self.inner.next().await {
+ Some(msg) => self.codec.decode(&msg?),
+ None => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())),
+ }
+ }
+
+ pub async fn send(&mut self, msg: C::Item) -> Result<()> {
+ let ws_msg = self.codec.encode(&msg)?;
+ self.inner.send(ws_msg).await
+ }
+}
+
+enum InnerWSConn {
+ Plain(WebSocketStream<TcpStream>),
+ Tls(WebSocketStream<TlsStream<TcpStream>>),
+}
+
+impl Sink<Message> for InnerWSConn {
+ type Error = Error;
+
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from),
+ }
+ }
+
+ fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from),
+ }
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from),
+ }
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from),
+ }
+ .map_err(Error::from)
+ }
+}
+
+impl Stream for InnerWSConn {
+ type Item = Result<Message>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ match &mut *self {
+ InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from),
+ InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from),
+ }
+ }
+}
diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs
index 21fce3d..03c8ab2 100644
--- a/net/src/transports/tcp.rs
+++ b/net/src/transports/tcp.rs
@@ -1,116 +1,184 @@
use std::net::SocketAddr;
use async_trait::async_trait;
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+use futures_util::SinkExt;
+
+use karyon_core::async_runtime::{
+ io::{split, ReadHalf, WriteHalf},
lock::Mutex,
- net::{TcpListener, TcpStream},
+ net::{TcpListener as AsyncTcpListener, TcpStream},
};
use crate::{
- connection::{Connection, ToConn},
+ codec::Codec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
- listener::{ConnListener, ToListener},
- Error, Result,
+ listener::{ConnListener, Listener, ToListener},
+ stream::{ReadStream, WriteStream},
+ Result,
};
-/// TCP network connection implementation of the [`Connection`] trait.
-pub struct TcpConn {
- inner: TcpStream,
- read: Mutex<ReadHalf<TcpStream>>,
- write: Mutex<WriteHalf<TcpStream>>,
+/// TCP configuration
+#[derive(Clone)]
+pub struct TcpConfig {
+ pub nodelay: bool,
+}
+
+impl Default for TcpConfig {
+ fn default() -> Self {
+ Self { nodelay: true }
+ }
+}
+
+/// TCP connection implementation of the [`Connection`] trait.
+pub struct TcpConn<C> {
+ read_stream: Mutex<ReadStream<ReadHalf<TcpStream>, C>>,
+ write_stream: Mutex<WriteStream<WriteHalf<TcpStream>, C>>,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
}
-impl TcpConn {
+impl<C> TcpConn<C>
+where
+ C: Codec + Clone,
+{
/// Creates a new TcpConn
- pub fn new(conn: TcpStream) -> Self {
- let (read, write) = split(conn.clone());
+ pub fn new(
+ socket: TcpStream,
+ codec: C,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
+ ) -> Self {
+ let (read, write) = split(socket);
+ let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
+ let write_stream = Mutex::new(WriteStream::new(write, codec));
Self {
- inner: conn,
- read: Mutex::new(read),
- write: Mutex::new(write),
+ read_stream,
+ write_stream,
+ peer_endpoint,
+ local_endpoint,
}
}
}
#[async_trait]
-impl Connection for TcpConn {
+impl<C> Connection for TcpConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tcp_addr(&self.inner.peer_addr()?))
+ Ok(self.peer_endpoint.clone())
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tcp_addr(&self.inner.local_addr()?))
+ Ok(self.local_endpoint.clone())
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ self.read_stream.lock().await.recv().await
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.write
- .lock()
- .await
- .write(buf)
- .await
- .map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ self.write_stream.lock().await.send(msg).await
+ }
+}
+
+pub struct TcpListener<C> {
+ inner: AsyncTcpListener,
+ config: TcpConfig,
+ codec: C,
+}
+
+impl<C> TcpListener<C>
+where
+ C: Codec,
+{
+ pub fn new(listener: AsyncTcpListener, config: TcpConfig, codec: C) -> Self {
+ Self {
+ inner: listener,
+ config: config.clone(),
+ codec,
+ }
}
}
#[async_trait]
-impl ConnListener for TcpListener {
+impl<C> ConnListener for TcpListener<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tcp_addr(&self.local_addr()?))
+ Ok(Endpoint::new_tcp_addr(self.inner.local_addr()?))
}
- async fn accept(&self) -> Result<Box<dyn Connection>> {
- let (conn, _) = self.accept().await?;
- conn.set_nodelay(true)?;
- Ok(Box::new(TcpConn::new(conn)))
+ async fn accept(&self) -> Result<Conn<C::Item>> {
+ let (socket, _) = self.inner.accept().await?;
+ socket.set_nodelay(self.config.nodelay)?;
+
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?;
+
+ Ok(Box::new(TcpConn::new(
+ socket,
+ self.codec.clone(),
+ peer_endpoint,
+ local_endpoint,
+ )))
}
}
/// Connects to the given TCP address and port.
-pub async fn dial(endpoint: &Endpoint) -> Result<TcpConn> {
+pub async fn dial<C>(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result<TcpConn<C>>
+where
+ C: Codec + Clone,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let conn = TcpStream::connect(addr).await?;
- conn.set_nodelay(true)?;
- Ok(TcpConn::new(conn))
+ let socket = TcpStream::connect(addr).await?;
+ socket.set_nodelay(config.nodelay)?;
+
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?;
+
+ Ok(TcpConn::new(socket, codec, peer_endpoint, local_endpoint))
}
/// Listens on the given TCP address and port.
-pub async fn listen(endpoint: &Endpoint) -> Result<TcpListener> {
+pub async fn listen<C>(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result<TcpListener<C>>
+where
+ C: Codec,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let listener = TcpListener::bind(addr).await?;
- Ok(listener)
-}
-
-impl From<TcpStream> for Box<dyn Connection> {
- fn from(conn: TcpStream) -> Self {
- Box::new(TcpConn::new(conn))
- }
+ let listener = AsyncTcpListener::bind(addr).await?;
+ Ok(TcpListener::new(listener, config, codec))
}
-impl From<TcpListener> for Box<dyn ConnListener> {
- fn from(listener: TcpListener) -> Self {
+impl<C> From<TcpListener<C>> for Box<dyn ConnListener<Item = C::Item>>
+where
+ C: Clone + Codec,
+{
+ fn from(listener: TcpListener<C>) -> Self {
Box::new(listener)
}
}
-impl ToConn for TcpStream {
- fn to_conn(self) -> Box<dyn Connection> {
- self.into()
- }
-}
-
-impl ToConn for TcpConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for TcpConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
-impl ToListener for TcpListener {
- fn to_listener(self) -> Box<dyn ConnListener> {
+impl<C> ToListener for TcpListener<C>
+where
+ C: Clone + Codec,
+{
+ type Item = C::Item;
+ fn to_listener(self) -> Listener<Self::Item> {
self.into()
}
}
diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs
index 476f495..c972f63 100644
--- a/net/src/transports/tls.rs
+++ b/net/src/transports/tls.rs
@@ -1,138 +1,218 @@
use std::{net::SocketAddr, sync::Arc};
use async_trait::async_trait;
-use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream};
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+use futures_util::SinkExt;
+use rustls_pki_types as pki_types;
+
+#[cfg(feature = "smol")]
+use futures_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
+#[cfg(feature = "tokio")]
+use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
+
+use karyon_core::async_runtime::{
+ io::{split, ReadHalf, WriteHalf},
lock::Mutex,
net::{TcpListener, TcpStream},
};
use crate::{
- connection::{Connection, ToConn},
+ codec::Codec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
- listener::{ConnListener, ToListener},
- Error, Result,
+ listener::{ConnListener, Listener, ToListener},
+ stream::{ReadStream, WriteStream},
+ Result,
};
+use super::tcp::TcpConfig;
+
+/// TLS configuration
+#[derive(Clone)]
+pub struct ServerTlsConfig {
+ pub tcp_config: TcpConfig,
+ pub server_config: rustls::ServerConfig,
+}
+
+#[derive(Clone)]
+pub struct ClientTlsConfig {
+ pub tcp_config: TcpConfig,
+ pub client_config: rustls::ClientConfig,
+ pub dns_name: String,
+}
+
/// TLS network connection implementation of the [`Connection`] trait.
-pub struct TlsConn {
- inner: TcpStream,
- read: Mutex<ReadHalf<TlsStream<TcpStream>>>,
- write: Mutex<WriteHalf<TlsStream<TcpStream>>>,
+pub struct TlsConn<C> {
+ read_stream: Mutex<ReadStream<ReadHalf<TlsStream<TcpStream>>, C>>,
+ write_stream: Mutex<WriteStream<WriteHalf<TlsStream<TcpStream>>, C>>,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
}
-impl TlsConn {
+impl<C> TlsConn<C>
+where
+ C: Codec + Clone,
+{
/// Creates a new TlsConn
- pub fn new(sock: TcpStream, conn: TlsStream<TcpStream>) -> Self {
+ pub fn new(
+ conn: TlsStream<TcpStream>,
+ codec: C,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
+ ) -> Self {
let (read, write) = split(conn);
+ let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
+ let write_stream = Mutex::new(WriteStream::new(write, codec));
Self {
- inner: sock,
- read: Mutex::new(read),
- write: Mutex::new(write),
+ read_stream,
+ write_stream,
+ peer_endpoint,
+ local_endpoint,
}
}
}
#[async_trait]
-impl Connection for TlsConn {
+impl<C> Connection for TlsConn<C>
+where
+ C: Clone + Codec,
+{
+ type Item = C::Item;
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?))
+ Ok(self.peer_endpoint.clone())
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?))
+ Ok(self.local_endpoint.clone())
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ self.read_stream.lock().await.recv().await
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.write
- .lock()
- .await
- .write(buf)
- .await
- .map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ self.write_stream.lock().await.send(msg).await
}
}
/// Connects to the given TLS address and port.
-pub async fn dial(
- endpoint: &Endpoint,
- config: rustls::ClientConfig,
- dns_name: &'static str,
-) -> Result<TlsConn> {
+pub async fn dial<C>(endpoint: &Endpoint, config: ClientTlsConfig, codec: C) -> Result<TlsConn<C>>
+where
+ C: Codec + Clone,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let connector = TlsConnector::from(Arc::new(config));
+ let connector = TlsConnector::from(Arc::new(config.client_config.clone()));
+
+ let socket = TcpStream::connect(addr).await?;
+ socket.set_nodelay(config.tcp_config.nodelay)?;
- let sock = TcpStream::connect(addr).await?;
- sock.set_nodelay(true)?;
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?;
- let altname = pki_types::ServerName::try_from(dns_name)?;
- let conn = connector.connect(altname, sock.clone()).await?;
- Ok(TlsConn::new(sock, TlsStream::Client(conn)))
+ let altname = pki_types::ServerName::try_from(config.dns_name.clone())?;
+ let conn = connector.connect(altname, socket).await?;
+ Ok(TlsConn::new(
+ TlsStream::Client(conn),
+ codec,
+ peer_endpoint,
+ local_endpoint,
+ ))
}
/// Tls network listener implementation of the `Listener` [`ConnListener`] trait.
-pub struct TlsListener {
+pub struct TlsListener<C> {
inner: TcpListener,
acceptor: TlsAcceptor,
+ config: ServerTlsConfig,
+ codec: C,
+}
+
+impl<C> TlsListener<C>
+where
+ C: Codec + Clone,
+{
+ pub fn new(
+ acceptor: TlsAcceptor,
+ listener: TcpListener,
+ config: ServerTlsConfig,
+ codec: C,
+ ) -> Self {
+ Self {
+ inner: listener,
+ acceptor,
+ config: config.clone(),
+ codec,
+ }
+ }
}
#[async_trait]
-impl ConnListener for TlsListener {
+impl<C> ConnListener for TlsListener<C>
+where
+ C: Clone + Codec,
+{
+ type Item = C::Item;
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?))
+ Ok(Endpoint::new_tls_addr(self.inner.local_addr()?))
}
- async fn accept(&self) -> Result<Box<dyn Connection>> {
- let (sock, _) = self.inner.accept().await?;
- sock.set_nodelay(true)?;
- let conn = self.acceptor.accept(sock.clone()).await?;
- Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn))))
+ async fn accept(&self) -> Result<Conn<C::Item>> {
+ let (socket, _) = self.inner.accept().await?;
+ socket.set_nodelay(self.config.tcp_config.nodelay)?;
+
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?;
+
+ let conn = self.acceptor.accept(socket).await?;
+ Ok(Box::new(TlsConn::new(
+ TlsStream::Server(conn),
+ self.codec.clone(),
+ peer_endpoint,
+ local_endpoint,
+ )))
}
}
/// Listens on the given TLS address and port.
-pub async fn listen(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result<TlsListener> {
+pub async fn listen<C>(
+ endpoint: &Endpoint,
+ config: ServerTlsConfig,
+ codec: C,
+) -> Result<TlsListener<C>>
+where
+ C: Clone + Codec,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let acceptor = TlsAcceptor::from(Arc::new(config));
+ let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone()));
let listener = TcpListener::bind(addr).await?;
- Ok(TlsListener {
- acceptor,
- inner: listener,
- })
-}
-
-impl From<TlsStream<TcpStream>> for Box<dyn Connection> {
- fn from(conn: TlsStream<TcpStream>) -> Self {
- Box::new(TlsConn::new(conn.get_ref().0.clone(), conn))
- }
+ Ok(TlsListener::new(acceptor, listener, config, codec))
}
-impl From<TlsListener> for Box<dyn ConnListener> {
- fn from(listener: TlsListener) -> Self {
+impl<C> From<TlsListener<C>> for Listener<C::Item>
+where
+ C: Codec + Clone,
+{
+ fn from(listener: TlsListener<C>) -> Self {
Box::new(listener)
}
}
-impl ToConn for TlsStream<TcpStream> {
- fn to_conn(self) -> Box<dyn Connection> {
- self.into()
- }
-}
-
-impl ToConn for TlsConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for TlsConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
-impl ToListener for TlsListener {
- fn to_listener(self) -> Box<dyn ConnListener> {
+impl<C> ToListener for TlsListener<C>
+where
+ C: Clone + Codec,
+{
+ type Item = C::Item;
+ fn to_listener(self) -> Listener<Self::Item> {
self.into()
}
}
diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs
index bd1fe83..950537c 100644
--- a/net/src/transports/udp.rs
+++ b/net/src/transports/udp.rs
@@ -1,93 +1,111 @@
use std::net::SocketAddr;
use async_trait::async_trait;
-use smol::net::UdpSocket;
+use karyon_core::async_runtime::net::UdpSocket;
use crate::{
- connection::{Connection, ToConn},
+ codec::Codec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
Error, Result,
};
+const BUFFER_SIZE: usize = 64 * 1024;
+
+/// UDP configuration
+#[derive(Default)]
+pub struct UdpConfig {}
+
/// UDP network connection implementation of the [`Connection`] trait.
-pub struct UdpConn {
+#[allow(dead_code)]
+pub struct UdpConn<C> {
inner: UdpSocket,
+ codec: C,
+ config: UdpConfig,
}
-impl UdpConn {
+impl<C> UdpConn<C>
+where
+ C: Codec + Clone,
+{
/// Creates a new UdpConn
- pub fn new(conn: UdpSocket) -> Self {
- Self { inner: conn }
- }
-}
-
-impl UdpConn {
- /// Receives a single datagram message. Returns the number of bytes read
- /// and the origin endpoint.
- pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, Endpoint)> {
- let (size, addr) = self.inner.recv_from(buf).await?;
- Ok((size, Endpoint::new_udp_addr(&addr)))
- }
-
- /// Sends data to the given address. Returns the number of bytes written.
- pub async fn send_to(&self, buf: &[u8], addr: &Endpoint) -> Result<usize> {
- let addr: SocketAddr = addr.clone().try_into()?;
- let size = self.inner.send_to(buf, addr).await?;
- Ok(size)
+ fn new(socket: UdpSocket, config: UdpConfig, codec: C) -> Self {
+ Self {
+ inner: socket,
+ codec,
+ config,
+ }
}
}
#[async_trait]
-impl Connection for UdpConn {
+impl<C> Connection for UdpConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = (C::Item, Endpoint);
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_udp_addr(&self.inner.peer_addr()?))
+ self.inner
+ .peer_addr()
+ .map(Endpoint::new_udp_addr)
+ .map_err(Error::from)
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_udp_addr(&self.inner.local_addr()?))
+ self.inner
+ .local_addr()
+ .map(Endpoint::new_udp_addr)
+ .map_err(Error::from)
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.inner.recv(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ let mut buf = [0u8; BUFFER_SIZE];
+ let (_, addr) = self.inner.recv_from(&mut buf).await?;
+ match self.codec.decode(&mut buf)? {
+ Some((_, item)) => Ok((item, Endpoint::new_udp_addr(addr))),
+ None => Err(Error::Decode("Unable to decode".into())),
+ }
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.inner.send(buf).await.map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ let (msg, out_addr) = msg;
+ let mut buf = [0u8; BUFFER_SIZE];
+ self.codec.encode(&msg, &mut buf)?;
+ let addr: SocketAddr = out_addr.try_into()?;
+ self.inner.send_to(&buf, addr).await?;
+ Ok(())
}
}
/// Connects to the given UDP address and port.
-pub async fn dial(endpoint: &Endpoint) -> Result<UdpConn> {
+pub async fn dial<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>>
+where
+ C: Codec + Clone,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
// Let the operating system assign an available port to this socket
let conn = UdpSocket::bind("[::]:0").await?;
conn.connect(addr).await?;
- Ok(UdpConn::new(conn))
+ Ok(UdpConn::new(conn, config, codec))
}
/// Listens on the given UDP address and port.
-pub async fn listen(endpoint: &Endpoint) -> Result<UdpConn> {
+pub async fn listen<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>>
+where
+ C: Codec + Clone,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
let conn = UdpSocket::bind(addr).await?;
- let udp_conn = UdpConn::new(conn);
- Ok(udp_conn)
-}
-
-impl From<UdpSocket> for Box<dyn Connection> {
- fn from(conn: UdpSocket) -> Self {
- Box::new(UdpConn::new(conn))
- }
-}
-
-impl ToConn for UdpSocket {
- fn to_conn(self) -> Box<dyn Connection> {
- self.into()
- }
+ Ok(UdpConn::new(conn, config, codec))
}
-impl ToConn for UdpConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for UdpConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = (C::Item, Endpoint);
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs
index 494e104..bafebaf 100644
--- a/net/src/transports/unix.rs
+++ b/net/src/transports/unix.rs
@@ -1,111 +1,192 @@
use async_trait::async_trait;
+use futures_util::SinkExt;
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+use karyon_core::async_runtime::{
+ io::{split, ReadHalf, WriteHalf},
lock::Mutex,
- net::unix::{UnixListener, UnixStream},
+ net::{UnixListener as AsyncUnixListener, UnixStream},
};
use crate::{
- connection::{Connection, ToConn},
+ codec::Codec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
- listener::{ConnListener, ToListener},
+ listener::{ConnListener, Listener, ToListener},
+ stream::{ReadStream, WriteStream},
Error, Result,
};
+/// Unix Conn config
+#[derive(Clone, Default)]
+pub struct UnixConfig {}
+
/// Unix domain socket implementation of the [`Connection`] trait.
-pub struct UnixConn {
- inner: UnixStream,
- read: Mutex<ReadHalf<UnixStream>>,
- write: Mutex<WriteHalf<UnixStream>>,
+pub struct UnixConn<C> {
+ read_stream: Mutex<ReadStream<ReadHalf<UnixStream>, C>>,
+ write_stream: Mutex<WriteStream<WriteHalf<UnixStream>, C>>,
+ peer_endpoint: Option<Endpoint>,
+ local_endpoint: Option<Endpoint>,
}
-impl UnixConn {
- /// Creates a new UnixConn
- pub fn new(conn: UnixStream) -> Self {
- let (read, write) = split(conn.clone());
+impl<C> UnixConn<C>
+where
+ C: Codec + Clone,
+{
+ /// Creates a new TcpConn
+ pub fn new(conn: UnixStream, codec: C) -> Self {
+ let peer_endpoint = conn
+ .peer_addr()
+ .and_then(|a| {
+ Ok(Endpoint::new_unix_addr(
+ a.as_pathname()
+ .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
+ ))
+ })
+ .ok();
+ let local_endpoint = conn
+ .local_addr()
+ .and_then(|a| {
+ Ok(Endpoint::new_unix_addr(
+ a.as_pathname()
+ .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
+ ))
+ })
+ .ok();
+
+ let (read, write) = split(conn);
+ let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
+ let write_stream = Mutex::new(WriteStream::new(write, codec));
Self {
- inner: conn,
- read: Mutex::new(read),
- write: Mutex::new(write),
+ read_stream,
+ write_stream,
+ peer_endpoint,
+ local_endpoint,
}
}
}
#[async_trait]
-impl Connection for UnixConn {
+impl<C> Connection for UnixConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_unix_addr(&self.inner.peer_addr()?))
+ self.peer_endpoint
+ .clone()
+ .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into()))
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_unix_addr(&self.inner.local_addr()?))
+ self.local_endpoint
+ .clone()
+ .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into()))
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ self.read_stream.lock().await.recv().await
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.write
- .lock()
- .await
- .write(buf)
- .await
- .map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ self.write_stream.lock().await.send(msg).await
+ }
+}
+
+#[allow(dead_code)]
+pub struct UnixListener<C> {
+ inner: AsyncUnixListener,
+ config: UnixConfig,
+ codec: C,
+}
+
+impl<C> UnixListener<C>
+where
+ C: Codec + Clone,
+{
+ pub fn new(listener: AsyncUnixListener, config: UnixConfig, codec: C) -> Self {
+ Self {
+ inner: listener,
+ config,
+ codec,
+ }
}
}
#[async_trait]
-impl ConnListener for UnixListener {
+impl<C> ConnListener for UnixListener<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_unix_addr(&self.local_addr()?))
+ self.inner
+ .local_addr()
+ .and_then(|a| {
+ Ok(Endpoint::new_unix_addr(
+ a.as_pathname()
+ .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
+ ))
+ })
+ .map_err(Error::from)
}
- async fn accept(&self) -> Result<Box<dyn Connection>> {
- let (conn, _) = self.accept().await?;
- Ok(Box::new(UnixConn::new(conn)))
+ async fn accept(&self) -> Result<Conn<C::Item>> {
+ let (conn, _) = self.inner.accept().await?;
+ Ok(Box::new(UnixConn::new(conn, self.codec.clone())))
}
}
/// Connects to the given Unix socket path.
-pub async fn dial(path: &String) -> Result<UnixConn> {
+pub async fn dial<C>(endpoint: &Endpoint, _config: UnixConfig, codec: C) -> Result<UnixConn<C>>
+where
+ C: Codec + Clone,
+{
+ let path: std::path::PathBuf = endpoint.clone().try_into()?;
let conn = UnixStream::connect(path).await?;
- Ok(UnixConn::new(conn))
+ Ok(UnixConn::new(conn, codec))
}
/// Listens on the given Unix socket path.
-pub fn listen(path: &String) -> Result<UnixListener> {
- let listener = UnixListener::bind(path)?;
- Ok(listener)
-}
-
-impl From<UnixStream> for Box<dyn Connection> {
- fn from(conn: UnixStream) -> Self {
- Box::new(UnixConn::new(conn))
- }
+pub fn listen<C>(endpoint: &Endpoint, config: UnixConfig, codec: C) -> Result<UnixListener<C>>
+where
+ C: Codec + Clone,
+{
+ let path: std::path::PathBuf = endpoint.clone().try_into()?;
+ let listener = AsyncUnixListener::bind(path)?;
+ Ok(UnixListener::new(listener, config, codec))
}
-impl From<UnixListener> for Box<dyn ConnListener> {
- fn from(listener: UnixListener) -> Self {
+// impl From<UnixStream> for Box<dyn Connection> {
+// fn from(conn: UnixStream) -> Self {
+// Box::new(UnixConn::new(conn))
+// }
+// }
+
+impl<C> From<UnixListener<C>> for Listener<C::Item>
+where
+ C: Codec + Clone,
+{
+ fn from(listener: UnixListener<C>) -> Self {
Box::new(listener)
}
}
-impl ToConn for UnixStream {
- fn to_conn(self) -> Box<dyn Connection> {
- self.into()
- }
-}
-
-impl ToConn for UnixConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for UnixConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
-impl ToListener for UnixListener {
- fn to_listener(self) -> Box<dyn ConnListener> {
+impl<C> ToListener for UnixListener<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
+ fn to_listener(self) -> Listener<Self::Item> {
self.into()
}
}
diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs
index eaf3b9b..17fe924 100644
--- a/net/src/transports/ws.rs
+++ b/net/src/transports/ws.rs
@@ -1,112 +1,254 @@
-use std::net::SocketAddr;
+use std::{net::SocketAddr, sync::Arc};
use async_trait::async_trait;
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+use rustls_pki_types as pki_types;
+
+#[cfg(feature = "tokio")]
+use async_tungstenite::tokio as async_tungstenite;
+
+#[cfg(feature = "smol")]
+use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
+#[cfg(feature = "tokio")]
+use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
+
+use karyon_core::async_runtime::{
lock::Mutex,
net::{TcpListener, TcpStream},
};
-use ws_stream_tungstenite::WsStream;
-
use crate::{
- connection::{Connection, ToConn},
+ codec::WebSocketCodec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
- listener::{ConnListener, ToListener},
- Error, Result,
+ listener::{ConnListener, Listener, ToListener},
+ stream::WsStream,
+ Result,
};
+use super::tcp::TcpConfig;
+
+/// WSS configuration
+#[derive(Clone)]
+pub struct ServerWssConfig {
+ pub server_config: rustls::ServerConfig,
+}
+
+/// WSS configuration
+#[derive(Clone)]
+pub struct ClientWssConfig {
+ pub client_config: rustls::ClientConfig,
+ pub dns_name: String,
+}
+
+/// WS configuration
+#[derive(Clone, Default)]
+pub struct ServerWsConfig {
+ pub tcp_config: TcpConfig,
+ pub wss_config: Option<ServerWssConfig>,
+}
+
+/// WS configuration
+#[derive(Clone, Default)]
+pub struct ClientWsConfig {
+ pub tcp_config: TcpConfig,
+ pub wss_config: Option<ClientWssConfig>,
+}
+
/// WS network connection implementation of the [`Connection`] trait.
-pub struct WsConn {
- inner: TcpStream,
- read: Mutex<ReadHalf<WsStream<TcpStream>>>,
- write: Mutex<WriteHalf<WsStream<TcpStream>>>,
+pub struct WsConn<C> {
+ // XXX: remove mutex
+ inner: Mutex<WsStream<C>>,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
}
-impl WsConn {
+impl<C> WsConn<C>
+where
+ C: WebSocketCodec,
+{
/// Creates a new WsConn
- pub fn new(inner: TcpStream, conn: WsStream<TcpStream>) -> Self {
- let (read, write) = split(conn);
+ pub fn new(ws: WsStream<C>, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self {
Self {
- inner,
- read: Mutex::new(read),
- write: Mutex::new(write),
+ inner: Mutex::new(ws),
+ peer_endpoint,
+ local_endpoint,
}
}
}
#[async_trait]
-impl Connection for WsConn {
+impl<C> Connection for WsConn<C>
+where
+ C: WebSocketCodec,
+{
+ type Item = C::Item;
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_ws_addr(&self.inner.peer_addr()?))
+ Ok(self.peer_endpoint.clone())
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?))
+ Ok(self.local_endpoint.clone())
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ self.inner.lock().await.recv().await
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.write
- .lock()
- .await
- .write(buf)
- .await
- .map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ self.inner.lock().await.send(msg).await
}
}
/// Ws network listener implementation of the `Listener` [`ConnListener`] trait.
-pub struct WsListener {
+pub struct WsListener<C> {
inner: TcpListener,
+ config: ServerWsConfig,
+ codec: C,
+ tls_acceptor: Option<TlsAcceptor>,
}
#[async_trait]
-impl ConnListener for WsListener {
+impl<C> ConnListener for WsListener<C>
+where
+ C: WebSocketCodec + Clone,
+{
+ type Item = C::Item;
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?))
+ match self.config.wss_config {
+ Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)),
+ None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)),
+ }
}
- async fn accept(&self) -> Result<Box<dyn Connection>> {
- let (stream, _) = self.inner.accept().await?;
- let conn = async_tungstenite::accept_async(stream.clone()).await?;
- Ok(Box::new(WsConn::new(stream, WsStream::new(conn))))
+ async fn accept(&self) -> Result<Conn<Self::Item>> {
+ let (socket, _) = self.inner.accept().await?;
+ socket.set_nodelay(self.config.tcp_config.nodelay)?;
+
+ match &self.config.wss_config {
+ Some(_) => match &self.tls_acceptor {
+ Some(acceptor) => {
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
+
+ let tls_conn = acceptor.accept(socket).await?.into();
+ let conn = async_tungstenite::accept_async(tls_conn).await?;
+ Ok(Box::new(WsConn::new(
+ WsStream::new_wss(conn, self.codec.clone()),
+ peer_endpoint,
+ local_endpoint,
+ )))
+ }
+ None => unreachable!(),
+ },
+ None => {
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?;
+
+ let conn = async_tungstenite::accept_async(socket).await?;
+
+ Ok(Box::new(WsConn::new(
+ WsStream::new_ws(conn, self.codec.clone()),
+ peer_endpoint,
+ local_endpoint,
+ )))
+ }
+ }
}
}
/// Connects to the given WS address and port.
-pub async fn dial(endpoint: &Endpoint) -> Result<WsConn> {
+pub async fn dial<C>(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result<WsConn<C>>
+where
+ C: WebSocketCodec,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let stream = TcpStream::connect(addr).await?;
- let (conn, _resp) =
- async_tungstenite::client_async(endpoint.to_string(), stream.clone()).await?;
- Ok(WsConn::new(stream, WsStream::new(conn)))
+ let socket = TcpStream::connect(addr).await?;
+ socket.set_nodelay(config.tcp_config.nodelay)?;
+
+ match &config.wss_config {
+ Some(conf) => {
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
+
+ let connector = TlsConnector::from(Arc::new(conf.client_config.clone()));
+
+ let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?;
+ let tls_conn = connector.connect(altname, socket).await?.into();
+ let (conn, _resp) =
+ async_tungstenite::client_async(endpoint.to_string(), tls_conn).await?;
+ Ok(WsConn::new(
+ WsStream::new_wss(conn, codec),
+ peer_endpoint,
+ local_endpoint,
+ ))
+ }
+ None => {
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?;
+ let (conn, _resp) =
+ async_tungstenite::client_async(endpoint.to_string(), socket).await?;
+ Ok(WsConn::new(
+ WsStream::new_ws(conn, codec),
+ peer_endpoint,
+ local_endpoint,
+ ))
+ }
+ }
}
/// Listens on the given WS address and port.
-pub async fn listen(endpoint: &Endpoint) -> Result<WsListener> {
+pub async fn listen<C>(
+ endpoint: &Endpoint,
+ config: ServerWsConfig,
+ codec: C,
+) -> Result<WsListener<C>> {
let addr = SocketAddr::try_from(endpoint.clone())?;
+
let listener = TcpListener::bind(addr).await?;
- Ok(WsListener { inner: listener })
+ match &config.wss_config {
+ Some(conf) => {
+ let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone()));
+ Ok(WsListener {
+ inner: listener,
+ config,
+ codec,
+ tls_acceptor: Some(acceptor),
+ })
+ }
+ None => Ok(WsListener {
+ inner: listener,
+ config,
+ codec,
+ tls_acceptor: None,
+ }),
+ }
}
-impl From<WsListener> for Box<dyn ConnListener> {
- fn from(listener: WsListener) -> Self {
+impl<C> From<WsListener<C>> for Listener<C::Item>
+where
+ C: WebSocketCodec + Clone,
+{
+ fn from(listener: WsListener<C>) -> Self {
Box::new(listener)
}
}
-impl ToConn for WsConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for WsConn<C>
+where
+ C: WebSocketCodec,
+{
+ type Item = C::Item;
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
-impl ToListener for WsListener {
- fn to_listener(self) -> Box<dyn ConnListener> {
+impl<C> ToListener for WsListener<C>
+where
+ C: WebSocketCodec + Clone,
+{
+ type Item = C::Item;
+ fn to_listener(self) -> Listener<Self::Item> {
self.into()
}
}