aboutsummaryrefslogtreecommitdiff
path: root/net/src
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-05-19 23:41:31 +0200
committerhozan23 <hozan23@karyontech.net>2024-05-19 23:41:31 +0200
commitf6f44784fff5488bb59d563ee7ff7b94c08a48c1 (patch)
tree63fa6fa0d620748a92d819f4773773ea9d53afc5 /net/src
parenta6016c7eeb11fc8aeaa1a3b160b970b15362695d (diff)
use cargo features to enable/disable protocols for net crate
Diffstat (limited to 'net/src')
-rw-r--r--net/src/codec/mod.rs3
-rw-r--r--net/src/error.rs2
-rw-r--r--net/src/lib.rs17
-rw-r--r--net/src/stream/mod.rs2
-rw-r--r--net/src/stream/websocket.rs11
-rw-r--r--net/src/transports/mod.rs5
-rw-r--r--net/src/transports/ws.rs23
7 files changed, 57 insertions, 6 deletions
diff --git a/net/src/codec/mod.rs b/net/src/codec/mod.rs
index 565cb07..43a02f3 100644
--- a/net/src/codec/mod.rs
+++ b/net/src/codec/mod.rs
@@ -1,9 +1,12 @@
mod bytes_codec;
mod length_codec;
+#[cfg(feature = "ws")]
mod websocket;
pub use bytes_codec::BytesCodec;
pub use length_codec::LengthCodec;
+
+#[cfg(feature = "ws")]
pub use websocket::{WebSocketCodec, WebSocketDecoder, WebSocketEncoder};
use crate::Result;
diff --git a/net/src/error.rs b/net/src/error.rs
index ee93168..102a343 100644
--- a/net/src/error.rs
+++ b/net/src/error.rs
@@ -37,6 +37,7 @@ pub enum Error {
#[error(transparent)]
ChannelRecv(#[from] async_channel::RecvError),
+ #[cfg(feature = "ws")]
#[error("Ws Error: {0}")]
WsError(#[from] async_tungstenite::tungstenite::Error),
@@ -48,6 +49,7 @@ pub enum Error {
#[error("Tls Error: {0}")]
Rustls(#[from] tokio_rustls::rustls::Error),
+ #[cfg(feature = "tls")]
#[error("Invalid DNS Name: {0}")]
InvalidDnsNameError(#[from] rustls_pki_types::InvalidDnsNameError),
diff --git a/net/src/lib.rs b/net/src/lib.rs
index ddb53cf..cd5fc8b 100644
--- a/net/src/lib.rs
+++ b/net/src/lib.rs
@@ -3,6 +3,7 @@ mod connection;
mod endpoint;
mod error;
mod listener;
+#[cfg(feature = "stream")]
mod stream;
mod transports;
@@ -10,9 +11,23 @@ pub use {
connection::{Conn, Connection, ToConn},
endpoint::{Addr, Endpoint, Port, ToEndpoint},
listener::{ConnListener, Listener, ToListener},
- transports::{tcp, tls, udp, unix, ws},
};
+#[cfg(feature = "tcp")]
+pub use transports::tcp;
+
+#[cfg(feature = "tls")]
+pub use transports::tls;
+
+#[cfg(feature = "ws")]
+pub use transports::ws;
+
+#[cfg(feature = "udp")]
+pub use transports::udp;
+
+#[cfg(all(feature = "unix", target_family = "unix"))]
+pub use transports::unix;
+
/// Represents karyon's Net Error
pub use error::Error;
diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs
index b792292..ce48a77 100644
--- a/net/src/stream/mod.rs
+++ b/net/src/stream/mod.rs
@@ -1,6 +1,8 @@
mod buffer;
+#[cfg(feature = "ws")]
mod websocket;
+#[cfg(feature = "ws")]
pub use websocket::WsStream;
use std::{
diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs
index 2552eaf..9d41626 100644
--- a/net/src/stream/websocket.rs
+++ b/net/src/stream/websocket.rs
@@ -6,9 +6,9 @@ use std::{
use async_tungstenite::tungstenite::Message;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
-#[cfg(feature = "smol")]
+#[cfg(all(feature = "smol", feature = "tls"))]
use futures_rustls::TlsStream;
-#[cfg(feature = "tokio")]
+#[cfg(all(feature = "tokio", feature = "tls"))]
use tokio_rustls::TlsStream;
use karyon_core::async_runtime::net::TcpStream;
@@ -37,6 +37,7 @@ where
}
}
+ #[cfg(feature = "tls")]
pub fn new_wss(conn: WebSocketStream<TlsStream<TcpStream>>, codec: C) -> Self {
Self {
inner: InnerWSConn::Tls(conn),
@@ -59,6 +60,7 @@ where
enum InnerWSConn {
Plain(WebSocketStream<TcpStream>),
+ #[cfg(feature = "tls")]
Tls(WebSocketStream<TlsStream<TcpStream>>),
}
@@ -68,6 +70,7 @@ impl Sink<Message> for InnerWSConn {
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match &mut *self {
InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from),
+ #[cfg(feature = "tls")]
InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from),
}
}
@@ -75,6 +78,7 @@ impl Sink<Message> for InnerWSConn {
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> {
match &mut *self {
InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from),
+ #[cfg(feature = "tls")]
InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from),
}
}
@@ -82,6 +86,7 @@ impl Sink<Message> for InnerWSConn {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match &mut *self {
InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from),
+ #[cfg(feature = "tls")]
InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from),
}
}
@@ -89,6 +94,7 @@ impl Sink<Message> for InnerWSConn {
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match &mut *self {
InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from),
+ #[cfg(feature = "tls")]
InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from),
}
.map_err(Error::from)
@@ -101,6 +107,7 @@ impl Stream for InnerWSConn {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match &mut *self {
InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from),
+ #[cfg(feature = "tls")]
InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from),
}
}
diff --git a/net/src/transports/mod.rs b/net/src/transports/mod.rs
index 14ef6f3..c7d684b 100644
--- a/net/src/transports/mod.rs
+++ b/net/src/transports/mod.rs
@@ -1,5 +1,10 @@
+#[cfg(feature = "tcp")]
pub mod tcp;
+#[cfg(feature = "tls")]
pub mod tls;
+#[cfg(feature = "udp")]
pub mod udp;
+#[cfg(all(feature = "unix", target_family = "unix"))]
pub mod unix;
+#[cfg(feature = "ws")]
pub mod ws;
diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs
index 17fe924..6107999 100644
--- a/net/src/transports/ws.rs
+++ b/net/src/transports/ws.rs
@@ -1,14 +1,18 @@
-use std::{net::SocketAddr, sync::Arc};
+use std::net::SocketAddr;
+
+#[cfg(feature = "tls")]
+use std::sync::Arc;
use async_trait::async_trait;
+#[cfg(feature = "tls")]
use rustls_pki_types as pki_types;
#[cfg(feature = "tokio")]
use async_tungstenite::tokio as async_tungstenite;
-#[cfg(feature = "smol")]
+#[cfg(all(feature = "smol", feature = "tls"))]
use futures_rustls::{rustls, TlsAcceptor, TlsConnector};
-#[cfg(feature = "tokio")]
+#[cfg(all(feature = "tokio", feature = "tls"))]
use tokio_rustls::{rustls, TlsAcceptor, TlsConnector};
use karyon_core::async_runtime::{
@@ -30,12 +34,14 @@ use super::tcp::TcpConfig;
/// WSS configuration
#[derive(Clone)]
pub struct ServerWssConfig {
+ #[cfg(feature = "tls")]
pub server_config: rustls::ServerConfig,
}
/// WSS configuration
#[derive(Clone)]
pub struct ClientWssConfig {
+ #[cfg(feature = "tls")]
pub client_config: rustls::ClientConfig,
pub dns_name: String,
}
@@ -104,6 +110,7 @@ pub struct WsListener<C> {
inner: TcpListener,
config: ServerWsConfig,
codec: C,
+ #[cfg(feature = "tls")]
tls_acceptor: Option<TlsAcceptor>,
}
@@ -125,6 +132,7 @@ where
socket.set_nodelay(self.config.tcp_config.nodelay)?;
match &self.config.wss_config {
+ #[cfg(feature = "tls")]
Some(_) => match &self.tls_acceptor {
Some(acceptor) => {
let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
@@ -152,6 +160,8 @@ where
local_endpoint,
)))
}
+ #[cfg(not(feature = "tls"))]
+ _ => unreachable!(),
}
}
}
@@ -166,6 +176,7 @@ where
socket.set_nodelay(config.tcp_config.nodelay)?;
match &config.wss_config {
+ #[cfg(feature = "tls")]
Some(conf) => {
let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?;
let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?;
@@ -193,6 +204,8 @@ where
local_endpoint,
))
}
+ #[cfg(not(feature = "tls"))]
+ _ => unreachable!(),
}
}
@@ -206,6 +219,7 @@ pub async fn listen<C>(
let listener = TcpListener::bind(addr).await?;
match &config.wss_config {
+ #[cfg(feature = "tls")]
Some(conf) => {
let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone()));
Ok(WsListener {
@@ -219,8 +233,11 @@ pub async fn listen<C>(
inner: listener,
config,
codec,
+ #[cfg(feature = "tls")]
tls_acceptor: None,
}),
+ #[cfg(not(feature = "tls"))]
+ _ => unreachable!(),
}
}