diff options
author | hozan23 <hozan23@karyontech.net> | 2024-05-19 23:41:31 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-19 23:41:31 +0200 |
commit | f6f44784fff5488bb59d563ee7ff7b94c08a48c1 (patch) | |
tree | 63fa6fa0d620748a92d819f4773773ea9d53afc5 /net/src | |
parent | a6016c7eeb11fc8aeaa1a3b160b970b15362695d (diff) |
use cargo features to enable/disable protocols for net crate
Diffstat (limited to 'net/src')
-rw-r--r-- | net/src/codec/mod.rs | 3 | ||||
-rw-r--r-- | net/src/error.rs | 2 | ||||
-rw-r--r-- | net/src/lib.rs | 17 | ||||
-rw-r--r-- | net/src/stream/mod.rs | 2 | ||||
-rw-r--r-- | net/src/stream/websocket.rs | 11 | ||||
-rw-r--r-- | net/src/transports/mod.rs | 5 | ||||
-rw-r--r-- | net/src/transports/ws.rs | 23 |
7 files changed, 57 insertions, 6 deletions
diff --git a/net/src/codec/mod.rs b/net/src/codec/mod.rs index 565cb07..43a02f3 100644 --- a/net/src/codec/mod.rs +++ b/net/src/codec/mod.rs @@ -1,9 +1,12 @@ mod bytes_codec; mod length_codec; +#[cfg(feature = "ws")] mod websocket; pub use bytes_codec::BytesCodec; pub use length_codec::LengthCodec; + +#[cfg(feature = "ws")] pub use websocket::{WebSocketCodec, WebSocketDecoder, WebSocketEncoder}; use crate::Result; diff --git a/net/src/error.rs b/net/src/error.rs index ee93168..102a343 100644 --- a/net/src/error.rs +++ b/net/src/error.rs @@ -37,6 +37,7 @@ pub enum Error { #[error(transparent)] ChannelRecv(#[from] async_channel::RecvError), + #[cfg(feature = "ws")] #[error("Ws Error: {0}")] WsError(#[from] async_tungstenite::tungstenite::Error), @@ -48,6 +49,7 @@ pub enum Error { #[error("Tls Error: {0}")] Rustls(#[from] tokio_rustls::rustls::Error), + #[cfg(feature = "tls")] #[error("Invalid DNS Name: {0}")] InvalidDnsNameError(#[from] rustls_pki_types::InvalidDnsNameError), diff --git a/net/src/lib.rs b/net/src/lib.rs index ddb53cf..cd5fc8b 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -3,6 +3,7 @@ mod connection; mod endpoint; mod error; mod listener; +#[cfg(feature = "stream")] mod stream; mod transports; @@ -10,9 +11,23 @@ pub use { connection::{Conn, Connection, ToConn}, endpoint::{Addr, Endpoint, Port, ToEndpoint}, listener::{ConnListener, Listener, ToListener}, - transports::{tcp, tls, udp, unix, ws}, }; +#[cfg(feature = "tcp")] +pub use transports::tcp; + +#[cfg(feature = "tls")] +pub use transports::tls; + +#[cfg(feature = "ws")] +pub use transports::ws; + +#[cfg(feature = "udp")] +pub use transports::udp; + +#[cfg(all(feature = "unix", target_family = "unix"))] +pub use transports::unix; + /// Represents karyon's Net Error pub use error::Error; diff --git a/net/src/stream/mod.rs b/net/src/stream/mod.rs index b792292..ce48a77 100644 --- a/net/src/stream/mod.rs +++ b/net/src/stream/mod.rs @@ -1,6 +1,8 @@ mod buffer; +#[cfg(feature = "ws")] mod websocket; +#[cfg(feature = "ws")] pub use websocket::WsStream; use std::{ diff --git a/net/src/stream/websocket.rs b/net/src/stream/websocket.rs index 2552eaf..9d41626 100644 --- a/net/src/stream/websocket.rs +++ b/net/src/stream/websocket.rs @@ -6,9 +6,9 @@ use std::{ use async_tungstenite::tungstenite::Message; use futures_util::{Sink, SinkExt, Stream, StreamExt}; -#[cfg(feature = "smol")] +#[cfg(all(feature = "smol", feature = "tls"))] use futures_rustls::TlsStream; -#[cfg(feature = "tokio")] +#[cfg(all(feature = "tokio", feature = "tls"))] use tokio_rustls::TlsStream; use karyon_core::async_runtime::net::TcpStream; @@ -37,6 +37,7 @@ where } } + #[cfg(feature = "tls")] pub fn new_wss(conn: WebSocketStream<TlsStream<TcpStream>>, codec: C) -> Self { Self { inner: InnerWSConn::Tls(conn), @@ -59,6 +60,7 @@ where enum InnerWSConn { Plain(WebSocketStream<TcpStream>), + #[cfg(feature = "tls")] Tls(WebSocketStream<TlsStream<TcpStream>>), } @@ -68,6 +70,7 @@ impl Sink<Message> for InnerWSConn { fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { match &mut *self { InnerWSConn::Plain(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), + #[cfg(feature = "tls")] InnerWSConn::Tls(s) => Pin::new(s).poll_ready(cx).map_err(Error::from), } } @@ -75,6 +78,7 @@ impl Sink<Message> for InnerWSConn { fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<()> { match &mut *self { InnerWSConn::Plain(s) => Pin::new(s).start_send(item).map_err(Error::from), + #[cfg(feature = "tls")] InnerWSConn::Tls(s) => Pin::new(s).start_send(item).map_err(Error::from), } } @@ -82,6 +86,7 @@ impl Sink<Message> for InnerWSConn { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { match &mut *self { InnerWSConn::Plain(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), + #[cfg(feature = "tls")] InnerWSConn::Tls(s) => Pin::new(s).poll_flush(cx).map_err(Error::from), } } @@ -89,6 +94,7 @@ impl Sink<Message> for InnerWSConn { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { match &mut *self { InnerWSConn::Plain(s) => Pin::new(s).poll_close(cx).map_err(Error::from), + #[cfg(feature = "tls")] InnerWSConn::Tls(s) => Pin::new(s).poll_close(cx).map_err(Error::from), } .map_err(Error::from) @@ -101,6 +107,7 @@ impl Stream for InnerWSConn { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { match &mut *self { InnerWSConn::Plain(s) => Pin::new(s).poll_next(cx).map_err(Error::from), + #[cfg(feature = "tls")] InnerWSConn::Tls(s) => Pin::new(s).poll_next(cx).map_err(Error::from), } } diff --git a/net/src/transports/mod.rs b/net/src/transports/mod.rs index 14ef6f3..c7d684b 100644 --- a/net/src/transports/mod.rs +++ b/net/src/transports/mod.rs @@ -1,5 +1,10 @@ +#[cfg(feature = "tcp")] pub mod tcp; +#[cfg(feature = "tls")] pub mod tls; +#[cfg(feature = "udp")] pub mod udp; +#[cfg(all(feature = "unix", target_family = "unix"))] pub mod unix; +#[cfg(feature = "ws")] pub mod ws; diff --git a/net/src/transports/ws.rs b/net/src/transports/ws.rs index 17fe924..6107999 100644 --- a/net/src/transports/ws.rs +++ b/net/src/transports/ws.rs @@ -1,14 +1,18 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::net::SocketAddr; + +#[cfg(feature = "tls")] +use std::sync::Arc; use async_trait::async_trait; +#[cfg(feature = "tls")] use rustls_pki_types as pki_types; #[cfg(feature = "tokio")] use async_tungstenite::tokio as async_tungstenite; -#[cfg(feature = "smol")] +#[cfg(all(feature = "smol", feature = "tls"))] use futures_rustls::{rustls, TlsAcceptor, TlsConnector}; -#[cfg(feature = "tokio")] +#[cfg(all(feature = "tokio", feature = "tls"))] use tokio_rustls::{rustls, TlsAcceptor, TlsConnector}; use karyon_core::async_runtime::{ @@ -30,12 +34,14 @@ use super::tcp::TcpConfig; /// WSS configuration #[derive(Clone)] pub struct ServerWssConfig { + #[cfg(feature = "tls")] pub server_config: rustls::ServerConfig, } /// WSS configuration #[derive(Clone)] pub struct ClientWssConfig { + #[cfg(feature = "tls")] pub client_config: rustls::ClientConfig, pub dns_name: String, } @@ -104,6 +110,7 @@ pub struct WsListener<C> { inner: TcpListener, config: ServerWsConfig, codec: C, + #[cfg(feature = "tls")] tls_acceptor: Option<TlsAcceptor>, } @@ -125,6 +132,7 @@ where socket.set_nodelay(self.config.tcp_config.nodelay)?; match &self.config.wss_config { + #[cfg(feature = "tls")] Some(_) => match &self.tls_acceptor { Some(acceptor) => { let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; @@ -152,6 +160,8 @@ where local_endpoint, ))) } + #[cfg(not(feature = "tls"))] + _ => unreachable!(), } } } @@ -166,6 +176,7 @@ where socket.set_nodelay(config.tcp_config.nodelay)?; match &config.wss_config { + #[cfg(feature = "tls")] Some(conf) => { let peer_endpoint = socket.peer_addr().map(Endpoint::new_wss_addr)?; let local_endpoint = socket.local_addr().map(Endpoint::new_wss_addr)?; @@ -193,6 +204,8 @@ where local_endpoint, )) } + #[cfg(not(feature = "tls"))] + _ => unreachable!(), } } @@ -206,6 +219,7 @@ pub async fn listen<C>( let listener = TcpListener::bind(addr).await?; match &config.wss_config { + #[cfg(feature = "tls")] Some(conf) => { let acceptor = TlsAcceptor::from(Arc::new(conf.server_config.clone())); Ok(WsListener { @@ -219,8 +233,11 @@ pub async fn listen<C>( inner: listener, config, codec, + #[cfg(feature = "tls")] tls_acceptor: None, }), + #[cfg(not(feature = "tls"))] + _ => unreachable!(), } } |