aboutsummaryrefslogtreecommitdiff
path: root/net/src/transports
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-04-11 10:19:20 +0200
committerhozan23 <hozan23@karyontech.net>2024-05-19 13:51:30 +0200
commit0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch)
tree961d73218af672797d49f899289bef295bc56493 /net/src/transports
parenta69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff)
add support for tokio & improve net crate api
Diffstat (limited to 'net/src/transports')
-rw-r--r--net/src/transports/tcp.rs188
-rw-r--r--net/src/transports/tls.rs220
-rw-r--r--net/src/transports/udp.rs114
-rw-r--r--net/src/transports/unix.rs193
-rw-r--r--net/src/transports/ws.rs242
5 files changed, 673 insertions, 284 deletions
diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs
index 21fce3d..03c8ab2 100644
--- a/net/src/transports/tcp.rs
+++ b/net/src/transports/tcp.rs
@@ -1,116 +1,184 @@
use std::net::SocketAddr;
use async_trait::async_trait;
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+use futures_util::SinkExt;
+
+use karyon_core::async_runtime::{
+ io::{split, ReadHalf, WriteHalf},
lock::Mutex,
- net::{TcpListener, TcpStream},
+ net::{TcpListener as AsyncTcpListener, TcpStream},
};
use crate::{
- connection::{Connection, ToConn},
+ codec::Codec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
- listener::{ConnListener, ToListener},
- Error, Result,
+ listener::{ConnListener, Listener, ToListener},
+ stream::{ReadStream, WriteStream},
+ Result,
};
-/// TCP network connection implementation of the [`Connection`] trait.
-pub struct TcpConn {
- inner: TcpStream,
- read: Mutex<ReadHalf<TcpStream>>,
- write: Mutex<WriteHalf<TcpStream>>,
+/// TCP configuration
+#[derive(Clone)]
+pub struct TcpConfig {
+ pub nodelay: bool,
+}
+
+impl Default for TcpConfig {
+ fn default() -> Self {
+ Self { nodelay: true }
+ }
+}
+
+/// TCP connection implementation of the [`Connection`] trait.
+pub struct TcpConn<C> {
+ read_stream: Mutex<ReadStream<ReadHalf<TcpStream>, C>>,
+ write_stream: Mutex<WriteStream<WriteHalf<TcpStream>, C>>,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
}
-impl TcpConn {
+impl<C> TcpConn<C>
+where
+ C: Codec + Clone,
+{
/// Creates a new TcpConn
- pub fn new(conn: TcpStream) -> Self {
- let (read, write) = split(conn.clone());
+ pub fn new(
+ socket: TcpStream,
+ codec: C,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
+ ) -> Self {
+ let (read, write) = split(socket);
+ let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
+ let write_stream = Mutex::new(WriteStream::new(write, codec));
Self {
- inner: conn,
- read: Mutex::new(read),
- write: Mutex::new(write),
+ read_stream,
+ write_stream,
+ peer_endpoint,
+ local_endpoint,
}
}
}
#[async_trait]
-impl Connection for TcpConn {
+impl<C> Connection for TcpConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tcp_addr(&self.inner.peer_addr()?))
+ Ok(self.peer_endpoint.clone())
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tcp_addr(&self.inner.local_addr()?))
+ Ok(self.local_endpoint.clone())
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ self.read_stream.lock().await.recv().await
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.write
- .lock()
- .await
- .write(buf)
- .await
- .map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ self.write_stream.lock().await.send(msg).await
+ }
+}
+
+pub struct TcpListener<C> {
+ inner: AsyncTcpListener,
+ config: TcpConfig,
+ codec: C,
+}
+
+impl<C> TcpListener<C>
+where
+ C: Codec,
+{
+ pub fn new(listener: AsyncTcpListener, config: TcpConfig, codec: C) -> Self {
+ Self {
+ inner: listener,
+ config: config.clone(),
+ codec,
+ }
}
}
#[async_trait]
-impl ConnListener for TcpListener {
+impl<C> ConnListener for TcpListener<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tcp_addr(&self.local_addr()?))
+ Ok(Endpoint::new_tcp_addr(self.inner.local_addr()?))
}
- async fn accept(&self) -> Result<Box<dyn Connection>> {
- let (conn, _) = self.accept().await?;
- conn.set_nodelay(true)?;
- Ok(Box::new(TcpConn::new(conn)))
+ async fn accept(&self) -> Result<Conn<C::Item>> {
+ let (socket, _) = self.inner.accept().await?;
+ socket.set_nodelay(self.config.nodelay)?;
+
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?;
+
+ Ok(Box::new(TcpConn::new(
+ socket,
+ self.codec.clone(),
+ peer_endpoint,
+ local_endpoint,
+ )))
}
}
/// Connects to the given TCP address and port.
-pub async fn dial(endpoint: &Endpoint) -> Result<TcpConn> {
+pub async fn dial<C>(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result<TcpConn<C>>
+where
+ C: Codec + Clone,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let conn = TcpStream::connect(addr).await?;
- conn.set_nodelay(true)?;
- Ok(TcpConn::new(conn))
+ let socket = TcpStream::connect(addr).await?;
+ socket.set_nodelay(config.nodelay)?;
+
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_tcp_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_tcp_addr)?;
+
+ Ok(TcpConn::new(socket, codec, peer_endpoint, local_endpoint))
}
/// Listens on the given TCP address and port.
-pub async fn listen(endpoint: &Endpoint) -> Result<TcpListener> {
+pub async fn listen<C>(endpoint: &Endpoint, config: TcpConfig, codec: C) -> Result<TcpListener<C>>
+where
+ C: Codec,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let listener = TcpListener::bind(addr).await?;
- Ok(listener)
-}
-
-impl From<TcpStream> for Box<dyn Connection> {
- fn from(conn: TcpStream) -> Self {
- Box::new(TcpConn::new(conn))
- }
+ let listener = AsyncTcpListener::bind(addr).await?;
+ Ok(TcpListener::new(listener, config, codec))
}
-impl From<TcpListener> for Box<dyn ConnListener> {
- fn from(listener: TcpListener) -> Self {
+impl<C> From<TcpListener<C>> for Box<dyn ConnListener<Item = C::Item>>
+where
+ C: Clone + Codec,
+{
+ fn from(listener: TcpListener<C>) -> Self {
Box::new(listener)
}
}
-impl ToConn for TcpStream {
- fn to_conn(self) -> Box<dyn Connection> {
- self.into()
- }
-}
-
-impl ToConn for TcpConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for TcpConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
-impl ToListener for TcpListener {
- fn to_listener(self) -> Box<dyn ConnListener> {
+impl<C> ToListener for TcpListener<C>
+where
+ C: Clone + Codec,
+{
+ type Item = C::Item;
+ fn to_listener(self) -> Listener<Self::Item> {
self.into()
}
}
diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs
index 476f495..c972f63 100644
--- a/net/src/transports/tls.rs
+++ b/net/src/transports/tls.rs
@@ -1,138 +1,218 @@
use std::{net::SocketAddr, sync::Arc};
use async_trait::async_trait;
-use futures_rustls::{pki_types, rustls, TlsAcceptor, TlsConnector, TlsStream};
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+use futures_util::SinkExt;
+use rustls_pki_types as pki_types;
+
+#[cfg(feature = "smol")]
+use futures_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
+#[cfg(feature = "tokio")]
+use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
+
+use karyon_core::async_runtime::{
+ io::{split, ReadHalf, WriteHalf},
lock::Mutex,
net::{TcpListener, TcpStream},
};
use crate::{
- connection::{Connection, ToConn},
+ codec::Codec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
- listener::{ConnListener, ToListener},
- Error, Result,
+ listener::{ConnListener, Listener, ToListener},
+ stream::{ReadStream, WriteStream},
+ Result,
};
+use super::tcp::TcpConfig;
+
+/// TLS configuration
+#[derive(Clone)]
+pub struct ServerTlsConfig {
+ pub tcp_config: TcpConfig,
+ pub server_config: rustls::ServerConfig,
+}
+
+#[derive(Clone)]
+pub struct ClientTlsConfig {
+ pub tcp_config: TcpConfig,
+ pub client_config: rustls::ClientConfig,
+ pub dns_name: String,
+}
+
/// TLS network connection implementation of the [`Connection`] trait.
-pub struct TlsConn {
- inner: TcpStream,
- read: Mutex<ReadHalf<TlsStream<TcpStream>>>,
- write: Mutex<WriteHalf<TlsStream<TcpStream>>>,
+pub struct TlsConn<C> {
+ read_stream: Mutex<ReadStream<ReadHalf<TlsStream<TcpStream>>, C>>,
+ write_stream: Mutex<WriteStream<WriteHalf<TlsStream<TcpStream>>, C>>,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
}
-impl TlsConn {
+impl<C> TlsConn<C>
+where
+ C: Codec + Clone,
+{
/// Creates a new TlsConn
- pub fn new(sock: TcpStream, conn: TlsStream<TcpStream>) -> Self {
+ pub fn new(
+ conn: TlsStream<TcpStream>,
+ codec: C,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
+ ) -> Self {
let (read, write) = split(conn);
+ let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
+ let write_stream = Mutex::new(WriteStream::new(write, codec));
Self {
- inner: sock,
- read: Mutex::new(read),
- write: Mutex::new(write),
+ read_stream,
+ write_stream,
+ peer_endpoint,
+ local_endpoint,
}
}
}
#[async_trait]
-impl Connection for TlsConn {
+impl<C> Connection for TlsConn<C>
+where
+ C: Clone + Codec,
+{
+ type Item = C::Item;
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?))
+ Ok(self.peer_endpoint.clone())
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?))
+ Ok(self.local_endpoint.clone())
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ self.read_stream.lock().await.recv().await
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.write
- .lock()
- .await
- .write(buf)
- .await
- .map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ self.write_stream.lock().await.send(msg).await
}
}
/// Connects to the given TLS address and port.
-pub async fn dial(
- endpoint: &Endpoint,
- config: rustls::ClientConfig,
- dns_name: &'static str,
-) -> Result<TlsConn> {
+pub async fn dial<C>(endpoint: &Endpoint, config: ClientTlsConfig, codec: C) -> Result<TlsConn<C>>
+where
+ C: Codec + Clone,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let connector = TlsConnector::from(Arc::new(config));
+ let connector = TlsConnector::from(Arc::new(config.client_config.clone()));
+
+ let socket = TcpStream::connect(addr).await?;
+ socket.set_nodelay(config.tcp_config.nodelay)?;
- let sock = TcpStream::connect(addr).await?;
- sock.set_nodelay(true)?;
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?;
- let altname = pki_types::ServerName::try_from(dns_name)?;
- let conn = connector.connect(altname, sock.clone()).await?;
- Ok(TlsConn::new(sock, TlsStream::Client(conn)))
+ let altname = pki_types::ServerName::try_from(config.dns_name.clone())?;
+ let conn = connector.connect(altname, socket).await?;
+ Ok(TlsConn::new(
+ TlsStream::Client(conn),
+ codec,
+ peer_endpoint,
+ local_endpoint,
+ ))
}
/// Tls network listener implementation of the `Listener` [`ConnListener`] trait.
-pub struct TlsListener {
+pub struct TlsListener<C> {
inner: TcpListener,
acceptor: TlsAcceptor,
+ config: ServerTlsConfig,
+ codec: C,
+}
+
+impl<C> TlsListener<C>
+where
+ C: Codec + Clone,
+{
+ pub fn new(
+ acceptor: TlsAcceptor,
+ listener: TcpListener,
+ config: ServerTlsConfig,
+ codec: C,
+ ) -> Self {
+ Self {
+ inner: listener,
+ acceptor,
+ config: config.clone(),
+ codec,
+ }
+ }
}
#[async_trait]
-impl ConnListener for TlsListener {
+impl<C> ConnListener for TlsListener<C>
+where
+ C: Clone + Codec,
+{
+ type Item = C::Item;
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?))
+ Ok(Endpoint::new_tls_addr(self.inner.local_addr()?))
}
- async fn accept(&self) -> Result<Box<dyn Connection>> {
- let (sock, _) = self.inner.accept().await?;
- sock.set_nodelay(true)?;
- let conn = self.acceptor.accept(sock.clone()).await?;
- Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn))))
+ async fn accept(&self) -> Result<Conn<C::Item>> {
+ let (socket, _) = self.inner.accept().await?;
+ socket.set_nodelay(self.config.tcp_config.nodelay)?;
+
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_tls_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_tls_addr)?;
+
+ let conn = self.acceptor.accept(socket).await?;
+ Ok(Box::new(TlsConn::new(
+ TlsStream::Server(conn),
+ self.codec.clone(),
+ peer_endpoint,
+ local_endpoint,
+ )))
}
}
/// Listens on the given TLS address and port.
-pub async fn listen(endpoint: &Endpoint, config: rustls::ServerConfig) -> Result<TlsListener> {
+pub async fn listen<C>(
+ endpoint: &Endpoint,
+ config: ServerTlsConfig,
+ codec: C,
+) -> Result<TlsListener<C>>
+where
+ C: Clone + Codec,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let acceptor = TlsAcceptor::from(Arc::new(config));
+ let acceptor = TlsAcceptor::from(Arc::new(config.server_config.clone()));
let listener = TcpListener::bind(addr).await?;
- Ok(TlsListener {
- acceptor,
- inner: listener,
- })
-}
-
-impl From<TlsStream<TcpStream>> for Box<dyn Connection> {
- fn from(conn: TlsStream<TcpStream>) -> Self {
- Box::new(TlsConn::new(conn.get_ref().0.clone(), conn))
- }
+ Ok(TlsListener::new(acceptor, listener, config, codec))
}
-impl From<TlsListener> for Box<dyn ConnListener> {
- fn from(listener: TlsListener) -> Self {
+impl<C> From<TlsListener<C>> for Listener<C::Item>
+where
+ C: Codec + Clone,
+{
+ fn from(listener: TlsListener<C>) -> Self {
Box::new(listener)
}
}
-impl ToConn for TlsStream<TcpStream> {
- fn to_conn(self) -> Box<dyn Connection> {
- self.into()
- }
-}
-
-impl ToConn for TlsConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for TlsConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
-impl ToListener for TlsListener {
- fn to_listener(self) -> Box<dyn ConnListener> {
+impl<C> ToListener for TlsListener<C>
+where
+ C: Clone + Codec,
+{
+ type Item = C::Item;
+ fn to_listener(self) -> Listener<Self::Item> {
self.into()
}
}
diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs
index bd1fe83..950537c 100644
--- a/net/src/transports/udp.rs
+++ b/net/src/transports/udp.rs
@@ -1,93 +1,111 @@
use std::net::SocketAddr;
use async_trait::async_trait;
-use smol::net::UdpSocket;
+use karyon_core::async_runtime::net::UdpSocket;
use crate::{
- connection::{Connection, ToConn},
+ codec::Codec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
Error, Result,
};
+const BUFFER_SIZE: usize = 64 * 1024;
+
+/// UDP configuration
+#[derive(Default)]
+pub struct UdpConfig {}
+
/// UDP network connection implementation of the [`Connection`] trait.
-pub struct UdpConn {
+#[allow(dead_code)]
+pub struct UdpConn<C> {
inner: UdpSocket,
+ codec: C,
+ config: UdpConfig,
}
-impl UdpConn {
+impl<C> UdpConn<C>
+where
+ C: Codec + Clone,
+{
/// Creates a new UdpConn
- pub fn new(conn: UdpSocket) -> Self {
- Self { inner: conn }
- }
-}
-
-impl UdpConn {
- /// Receives a single datagram message. Returns the number of bytes read
- /// and the origin endpoint.
- pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, Endpoint)> {
- let (size, addr) = self.inner.recv_from(buf).await?;
- Ok((size, Endpoint::new_udp_addr(&addr)))
- }
-
- /// Sends data to the given address. Returns the number of bytes written.
- pub async fn send_to(&self, buf: &[u8], addr: &Endpoint) -> Result<usize> {
- let addr: SocketAddr = addr.clone().try_into()?;
- let size = self.inner.send_to(buf, addr).await?;
- Ok(size)
+ fn new(socket: UdpSocket, config: UdpConfig, codec: C) -> Self {
+ Self {
+ inner: socket,
+ codec,
+ config,
+ }
}
}
#[async_trait]
-impl Connection for UdpConn {
+impl<C> Connection for UdpConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = (C::Item, Endpoint);
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_udp_addr(&self.inner.peer_addr()?))
+ self.inner
+ .peer_addr()
+ .map(Endpoint::new_udp_addr)
+ .map_err(Error::from)
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_udp_addr(&self.inner.local_addr()?))
+ self.inner
+ .local_addr()
+ .map(Endpoint::new_udp_addr)
+ .map_err(Error::from)
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.inner.recv(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ let mut buf = [0u8; BUFFER_SIZE];
+ let (_, addr) = self.inner.recv_from(&mut buf).await?;
+ match self.codec.decode(&mut buf)? {
+ Some((_, item)) => Ok((item, Endpoint::new_udp_addr(addr))),
+ None => Err(Error::Decode("Unable to decode".into())),
+ }
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.inner.send(buf).await.map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ let (msg, out_addr) = msg;
+ let mut buf = [0u8; BUFFER_SIZE];
+ self.codec.encode(&msg, &mut buf)?;
+ let addr: SocketAddr = out_addr.try_into()?;
+ self.inner.send_to(&buf, addr).await?;
+ Ok(())
}
}
/// Connects to the given UDP address and port.
-pub async fn dial(endpoint: &Endpoint) -> Result<UdpConn> {
+pub async fn dial<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>>
+where
+ C: Codec + Clone,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
// Let the operating system assign an available port to this socket
let conn = UdpSocket::bind("[::]:0").await?;
conn.connect(addr).await?;
- Ok(UdpConn::new(conn))
+ Ok(UdpConn::new(conn, config, codec))
}
/// Listens on the given UDP address and port.
-pub async fn listen(endpoint: &Endpoint) -> Result<UdpConn> {
+pub async fn listen<C>(endpoint: &Endpoint, config: UdpConfig, codec: C) -> Result<UdpConn<C>>
+where
+ C: Codec + Clone,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
let conn = UdpSocket::bind(addr).await?;
- let udp_conn = UdpConn::new(conn);
- Ok(udp_conn)
-}
-
-impl From<UdpSocket> for Box<dyn Connection> {
- fn from(conn: UdpSocket) -> Self {
- Box::new(UdpConn::new(conn))
- }
-}
-
-impl ToConn for UdpSocket {
- fn to_conn(self) -> Box<dyn Connection> {
- self.into()
- }
+ Ok(UdpConn::new(conn, config, codec))
}
-impl ToConn for UdpConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for UdpConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = (C::Item, Endpoint);
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs
index 494e104..bafebaf 100644
--- a/net/src/transports/unix.rs
+++ b/net/src/transports/unix.rs
@@ -1,111 +1,192 @@
use async_trait::async_trait;
+use futures_util::SinkExt;
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+use karyon_core::async_runtime::{
+ io::{split, ReadHalf, WriteHalf},
lock::Mutex,
- net::unix::{UnixListener, UnixStream},
+ net::{UnixListener as AsyncUnixListener, UnixStream},
};
use crate::{
- connection::{Connection, ToConn},
+ codec::Codec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
- listener::{ConnListener, ToListener},
+ listener::{ConnListener, Listener, ToListener},
+ stream::{ReadStream, WriteStream},
Error, Result,
};
+/// Unix Conn config
+#[derive(Clone, Default)]
+pub struct UnixConfig {}
+
/// Unix domain socket implementation of the [`Connection`] trait.
-pub struct UnixConn {
- inner: UnixStream,
- read: Mutex<ReadHalf<UnixStream>>,
- write: Mutex<WriteHalf<UnixStream>>,
+pub struct UnixConn<C> {
+ read_stream: Mutex<ReadStream<ReadHalf<UnixStream>, C>>,
+ write_stream: Mutex<WriteStream<WriteHalf<UnixStream>, C>>,
+ peer_endpoint: Option<Endpoint>,
+ local_endpoint: Option<Endpoint>,
}
-impl UnixConn {
- /// Creates a new UnixConn
- pub fn new(conn: UnixStream) -> Self {
- let (read, write) = split(conn.clone());
+impl<C> UnixConn<C>
+where
+ C: Codec + Clone,
+{
+ /// Creates a new TcpConn
+ pub fn new(conn: UnixStream, codec: C) -> Self {
+ let peer_endpoint = conn
+ .peer_addr()
+ .and_then(|a| {
+ Ok(Endpoint::new_unix_addr(
+ a.as_pathname()
+ .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
+ ))
+ })
+ .ok();
+ let local_endpoint = conn
+ .local_addr()
+ .and_then(|a| {
+ Ok(Endpoint::new_unix_addr(
+ a.as_pathname()
+ .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
+ ))
+ })
+ .ok();
+
+ let (read, write) = split(conn);
+ let read_stream = Mutex::new(ReadStream::new(read, codec.clone()));
+ let write_stream = Mutex::new(WriteStream::new(write, codec));
Self {
- inner: conn,
- read: Mutex::new(read),
- write: Mutex::new(write),
+ read_stream,
+ write_stream,
+ peer_endpoint,
+ local_endpoint,
}
}
}
#[async_trait]
-impl Connection for UnixConn {
+impl<C> Connection for UnixConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_unix_addr(&self.inner.peer_addr()?))
+ self.peer_endpoint
+ .clone()
+ .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into()))
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_unix_addr(&self.inner.local_addr()?))
+ self.local_endpoint
+ .clone()
+ .ok_or(Error::IO(std::io::ErrorKind::AddrNotAvailable.into()))
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ self.read_stream.lock().await.recv().await
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.write
- .lock()
- .await
- .write(buf)
- .await
- .map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ self.write_stream.lock().await.send(msg).await
+ }
+}
+
+#[allow(dead_code)]
+pub struct UnixListener<C> {
+ inner: AsyncUnixListener,
+ config: UnixConfig,
+ codec: C,
+}
+
+impl<C> UnixListener<C>
+where
+ C: Codec + Clone,
+{
+ pub fn new(listener: AsyncUnixListener, config: UnixConfig, codec: C) -> Self {
+ Self {
+ inner: listener,
+ config,
+ codec,
+ }
}
}
#[async_trait]
-impl ConnListener for UnixListener {
+impl<C> ConnListener for UnixListener<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_unix_addr(&self.local_addr()?))
+ self.inner
+ .local_addr()
+ .and_then(|a| {
+ Ok(Endpoint::new_unix_addr(
+ a.as_pathname()
+ .ok_or(std::io::ErrorKind::AddrNotAvailable)?,
+ ))
+ })
+ .map_err(Error::from)
}
- async fn accept(&self) -> Result<Box<dyn Connection>> {
- let (conn, _) = self.accept().await?;
- Ok(Box::new(UnixConn::new(conn)))
+ async fn accept(&self) -> Result<Conn<C::Item>> {
+ let (conn, _) = self.inner.accept().await?;
+ Ok(Box::new(UnixConn::new(conn, self.codec.clone())))
}
}
/// Connects to the given Unix socket path.
-pub async fn dial(path: &String) -> Result<UnixConn> {
+pub async fn dial<C>(endpoint: &Endpoint, _config: UnixConfig, codec: C) -> Result<UnixConn<C>>
+where
+ C: Codec + Clone,
+{
+ let path: std::path::PathBuf = endpoint.clone().try_into()?;
let conn = UnixStream::connect(path).await?;
- Ok(UnixConn::new(conn))
+ Ok(UnixConn::new(conn, codec))
}
/// Listens on the given Unix socket path.
-pub fn listen(path: &String) -> Result<UnixListener> {
- let listener = UnixListener::bind(path)?;
- Ok(listener)
-}
-
-impl From<UnixStream> for Box<dyn Connection> {
- fn from(conn: UnixStream) -> Self {
- Box::new(UnixConn::new(conn))
- }
+pub fn listen<C>(endpoint: &Endpoint, config: UnixConfig, codec: C) -> Result<UnixListener<C>>
+where
+ C: Codec + Clone,
+{
+ let path: std::path::PathBuf = endpoint.clone().try_into()?;
+ let listener = AsyncUnixListener::bind(path)?;
+ Ok(UnixListener::new(listener, config, codec))
}
-impl From<UnixListener> for Box<dyn ConnListener> {
- fn from(listener: UnixListener) -> Self {
+// impl From<UnixStream> for Box<dyn Connection> {
+// fn from(conn: UnixStream) -> Self {
+// Box::new(UnixConn::new(conn))
+// }
+// }
+
+impl<C> From<UnixListener<C>> for Listener<C::Item>
+where
+ C: Codec + Clone,
+{
+ fn from(listener: UnixListener<C>) -> Self {
Box::new(listener)
}
}
-impl ToConn for UnixStream {
- fn to_conn(self) -> Box<dyn Connection> {
- self.into()
- }
-}
-
-impl ToConn for UnixConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for UnixConn<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
-impl ToListener for UnixListener {
- fn to_listener(self) -> Box<dyn ConnListener> {
+impl<C> ToListener for UnixListener<C>
+where
+ C: Codec + Clone,
+{
+ type Item = C::Item;
+ fn to_listener(self) -> Listener<Self::Item> {
self.into()
}
}
diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs
index eaf3b9b..17fe924 100644
--- a/net/src/transports/ws.rs
+++ b/net/src/transports/ws.rs
@@ -1,112 +1,254 @@
-use std::net::SocketAddr;
+use std::{net::SocketAddr, sync::Arc};
use async_trait::async_trait;
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+use rustls_pki_types as pki_types;
+
+#[cfg(feature = "tokio")]
+use async_tungstenite::tokio as async_tungstenite;
+
+#[cfg(feature = "smol")]
+use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
+#[cfg(feature = "tokio")]
+use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
+
+use karyon_core::async_runtime::{
lock::Mutex,
net::{TcpListener, TcpStream},
};
-use ws_stream_tungstenite::WsStream;
-
use crate::{
- connection::{Connection, ToConn},
+ codec::WebSocketCodec,
+ connection::{Conn, Connection, ToConn},
endpoint::Endpoint,
- listener::{ConnListener, ToListener},
- Error, Result,
+ listener::{ConnListener, Listener, ToListener},
+ stream::WsStream,
+ Result,
};
+use super::tcp::TcpConfig;
+
+/// WSS configuration
+#[derive(Clone)]
+pub struct ServerWssConfig {
+ pub server_config: rustls::ServerConfig,
+}
+
+/// WSS configuration
+#[derive(Clone)]
+pub struct ClientWssConfig {
+ pub client_config: rustls::ClientConfig,
+ pub dns_name: String,
+}
+
+/// WS configuration
+#[derive(Clone, Default)]
+pub struct ServerWsConfig {
+ pub tcp_config: TcpConfig,
+ pub wss_config: Option<ServerWssConfig>,
+}
+
+/// WS configuration
+#[derive(Clone, Default)]
+pub struct ClientWsConfig {
+ pub tcp_config: TcpConfig,
+ pub wss_config: Option<ClientWssConfig>,
+}
+
/// WS network connection implementation of the [`Connection`] trait.
-pub struct WsConn {
- inner: TcpStream,
- read: Mutex<ReadHalf<WsStream<TcpStream>>>,
- write: Mutex<WriteHalf<WsStream<TcpStream>>>,
+pub struct WsConn<C> {
+ // XXX: remove mutex
+ inner: Mutex<WsStream<C>>,
+ peer_endpoint: Endpoint,
+ local_endpoint: Endpoint,
}
-impl WsConn {
+impl<C> WsConn<C>
+where
+ C: WebSocketCodec,
+{
/// Creates a new WsConn
- pub fn new(inner: TcpStream, conn: WsStream<TcpStream>) -> Self {
- let (read, write) = split(conn);
+ pub fn new(ws: WsStream<C>, peer_endpoint: Endpoint, local_endpoint: Endpoint) -> Self {
Self {
- inner,
- read: Mutex::new(read),
- write: Mutex::new(write),
+ inner: Mutex::new(ws),
+ peer_endpoint,
+ local_endpoint,
}
}
}
#[async_trait]
-impl Connection for WsConn {
+impl<C> Connection for WsConn<C>
+where
+ C: WebSocketCodec,
+{
+ type Item = C::Item;
fn peer_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_ws_addr(&self.inner.peer_addr()?))
+ Ok(self.peer_endpoint.clone())
}
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?))
+ Ok(self.local_endpoint.clone())
}
- async fn read(&self, buf: &mut [u8]) -> Result<usize> {
- self.read.lock().await.read(buf).await.map_err(Error::from)
+ async fn recv(&self) -> Result<Self::Item> {
+ self.inner.lock().await.recv().await
}
- async fn write(&self, buf: &[u8]) -> Result<usize> {
- self.write
- .lock()
- .await
- .write(buf)
- .await
- .map_err(Error::from)
+ async fn send(&self, msg: Self::Item) -> Result<()> {
+ self.inner.lock().await.send(msg).await
}
}
/// Ws network listener implementation of the `Listener` [`ConnListener`] trait.
-pub struct WsListener {
+pub struct WsListener<C> {
inner: TcpListener,
+ config: ServerWsConfig,
+ codec: C,
+ tls_acceptor: Option<TlsAcceptor>,
}
#[async_trait]
-impl ConnListener for WsListener {
+impl<C> ConnListener for WsListener<C>
+where
+ C: WebSocketCodec + Clone,
+{
+ type Item = C::Item;
fn local_endpoint(&self) -> Result<Endpoint> {
- Ok(Endpoint::new_ws_addr(&self.inner.local_addr()?))
+ match self.config.wss_config {
+ Some(_) => Ok(Endpoint::new_wss_addr(self.inner.local_addr()?)),
+ None => Ok(Endpoint::new_ws_addr(self.inner.local_addr()?)),
+ }
}
- async fn accept(&self) -> Result<Box<dyn Connection>> {
- let (stream, _) = self.inner.accept().await?;
- let conn = async_tungstenite::accept_async(stream.clone()).await?;
- Ok(Box::new(WsConn::new(stream, WsStream::new(conn))))
+ async fn accept(&self) -> Result<Conn<Self::Item>> {
+ let (socket, _) = self.inner.accept().await?;
+ socket.set_nodelay(self.config.tcp_config.nodelay)?;
+
+ match &self.config.wss_config {
+ Some(_) => match &self.tls_acceptor {
+ Some(acceptor) => {
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
+
+ let tls_conn = acceptor.accept(socket).await?.into();
+ let conn = async_tungstenite::accept_async(tls_conn).await?;
+ Ok(Box::new(WsConn::new(
+ WsStream::new_wss(conn, self.codec.clone()),
+ peer_endpoint,
+ local_endpoint,
+ )))
+ }
+ None => unreachable!(),
+ },
+ None => {
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?;
+
+ let conn = async_tungstenite::accept_async(socket).await?;
+
+ Ok(Box::new(WsConn::new(
+ WsStream::new_ws(conn, self.codec.clone()),
+ peer_endpoint,
+ local_endpoint,
+ )))
+ }
+ }
}
}
/// Connects to the given WS address and port.
-pub async fn dial(endpoint: &Endpoint) -> Result<WsConn> {
+pub async fn dial<C>(endpoint: &Endpoint, config: ClientWsConfig, codec: C) -> Result<WsConn<C>>
+where
+ C: WebSocketCodec,
+{
let addr = SocketAddr::try_from(endpoint.clone())?;
- let stream = TcpStream::connect(addr).await?;
- let (conn, _resp) =
- async_tungstenite::client_async(endpoint.to_string(), stream.clone()).await?;
- Ok(WsConn::new(stream, WsStream::new(conn)))
+ let socket = TcpStream::connect(addr).await?;
+ socket.set_nodelay(config.tcp_config.nodelay)?;
+
+ match &config.wss_config {
+ Some(conf) => {
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
+
+ let connector = TlsConnector::from(Arc::new(conf.client_config.clone()));
+
+ let altname = pki_types::ServerName::try_from(conf.dns_name.clone())?;
+ let tls_conn = connector.connect(altname, socket).await?.into();
+ let (conn, _resp) =
+ async_tungstenite::client_async(endpoint.to_string(), tls_conn).await?;
+ Ok(WsConn::new(
+ WsStream::new_wss(conn, codec),
+ peer_endpoint,
+ local_endpoint,
+ ))
+ }
+ None => {
+ let peer_endpoint = socket.peer_addr().map(Endpoint::new_ws_addr)?;
+ let local_endpoint = socket.local_addr().map(Endpoint::new_ws_addr)?;
+ let (conn, _resp) =
+ async_tungstenite::client_async(endpoint.to_string(), socket).await?;
+ Ok(WsConn::new(
+ WsStream::new_ws(conn, codec),
+ peer_endpoint,
+ local_endpoint,
+ ))
+ }
+ }
}
/// Listens on the given WS address and port.
-pub async fn listen(endpoint: &Endpoint) -> Result<WsListener> {
+pub async fn listen<C>(
+ endpoint: &Endpoint,
+ config: ServerWsConfig,
+ codec: C,
+) -> Result<WsListener<C>> {
let addr = SocketAddr::try_from(endpoint.clone())?;
+
let listener = TcpListener::bind(addr).await?;
- Ok(WsListener { inner: listener })
+ match &config.wss_config {
+ Some(conf) => {
+ let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone()));
+ Ok(WsListener {
+ inner: listener,
+ config,
+ codec,
+ tls_acceptor: Some(acceptor),
+ })
+ }
+ None => Ok(WsListener {
+ inner: listener,
+ config,
+ codec,
+ tls_acceptor: None,
+ }),
+ }
}
-impl From<WsListener> for Box<dyn ConnListener> {
- fn from(listener: WsListener) -> Self {
+impl<C> From<WsListener<C>> for Listener<C::Item>
+where
+ C: WebSocketCodec + Clone,
+{
+ fn from(listener: WsListener<C>) -> Self {
Box::new(listener)
}
}
-impl ToConn for WsConn {
- fn to_conn(self) -> Box<dyn Connection> {
+impl<C> ToConn for WsConn<C>
+where
+ C: WebSocketCodec,
+{
+ type Item = C::Item;
+ fn to_conn(self) -> Conn<Self::Item> {
Box::new(self)
}
}
-impl ToListener for WsListener {
- fn to_listener(self) -> Box<dyn ConnListener> {
+impl<C> ToListener for WsListener<C>
+where
+ C: WebSocketCodec + Clone,
+{
+ type Item = C::Item;
+ fn to_listener(self) -> Listener<Self::Item> {
self.into()
}
}