aboutsummaryrefslogtreecommitdiff
path: root/net/src
diff options
context:
space:
mode:
authorhozan23 <hozan23@proton.me>2023-11-28 22:41:33 +0300
committerhozan23 <hozan23@proton.me>2023-11-28 22:41:33 +0300
commit98a1de91a2dae06323558422c239e5a45fc86e7b (patch)
tree38c640248824fcb3b4ca5ba12df47c13ef26ccda /net/src
parentca2a5f8bbb6983d9555abd10eaaf86950b794957 (diff)
implement TLS for inbound and outbound connections
Diffstat (limited to 'net/src')
-rw-r--r--net/src/connection.rs9
-rw-r--r--net/src/endpoint.rs70
-rw-r--r--net/src/error.rs8
-rw-r--r--net/src/lib.rs1
-rw-r--r--net/src/listener.rs5
-rw-r--r--net/src/transports/mod.rs1
-rw-r--r--net/src/transports/tcp.rs2
-rw-r--r--net/src/transports/tls.rs140
-rw-r--r--net/src/transports/udp.rs2
-rw-r--r--net/src/transports/unix.rs2
10 files changed, 217 insertions, 23 deletions
diff --git a/net/src/connection.rs b/net/src/connection.rs
index d8ec0a3..b1d7550 100644
--- a/net/src/connection.rs
+++ b/net/src/connection.rs
@@ -1,7 +1,9 @@
-use crate::{Endpoint, Result};
use async_trait::async_trait;
-use crate::transports::{tcp, udp, unix};
+use crate::{
+ transports::{tcp, udp, unix},
+ Endpoint, Error, Result,
+};
/// Alias for `Box<dyn Connection>`
pub type Conn = Box<dyn Connection>;
@@ -28,7 +30,7 @@ pub trait Connection: Send + Sync {
/// Connects to the provided endpoint.
///
-/// it only supports `tcp4/6`, `udp4/6` and `unix`.
+/// it only supports `tcp4/6`, `udp4/6`, and `unix`.
///
/// #Example
///
@@ -53,5 +55,6 @@ pub async fn dial(endpoint: &Endpoint) -> Result<Conn> {
Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::dial_tcp(addr, port).await?)),
Endpoint::Udp(addr, port) => Ok(Box::new(udp::dial_udp(addr, port).await?)),
Endpoint::Unix(addr) => Ok(Box::new(unix::dial_unix(addr).await?)),
+ _ => Err(Error::InvalidEndpoint(endpoint.to_string())),
}
}
diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs
index 50dfe6b..720eea3 100644
--- a/net/src/endpoint.rs
+++ b/net/src/endpoint.rs
@@ -5,7 +5,7 @@ use std::{
str::FromStr,
};
-use bincode::{Decode, Encode};
+use bincode::{impl_borrow_decode, Decode, Encode};
use url::Url;
use crate::{Error, Result};
@@ -33,6 +33,7 @@ pub type Port = u16;
pub enum Endpoint {
Udp(Addr, Port),
Tcp(Addr, Port),
+ Tls(Addr, Port),
Unix(String),
}
@@ -45,6 +46,9 @@ impl std::fmt::Display for Endpoint {
Endpoint::Tcp(ip, port) => {
write!(f, "tcp://{}:{}", ip, port)
}
+ Endpoint::Tls(ip, port) => {
+ write!(f, "tls://{}:{}", ip, port)
+ }
Endpoint::Unix(path) => {
if path.is_empty() {
write!(f, "unix:/UNNAMED")
@@ -60,9 +64,10 @@ impl TryFrom<Endpoint> for SocketAddr {
type Error = Error;
fn try_from(endpoint: Endpoint) -> std::result::Result<SocketAddr, Self::Error> {
match endpoint {
- Endpoint::Udp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
- Endpoint::Tcp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)),
- Endpoint::Unix(_) => Err(Error::TryFromEndpointError),
+ Endpoint::Udp(ip, port) | Endpoint::Tcp(ip, port) | Endpoint::Tls(ip, port) => {
+ Ok(SocketAddr::new(ip.try_into()?, port))
+ }
+ Endpoint::Unix(_) => Err(Error::TryFromEndpoint),
}
}
}
@@ -72,7 +77,7 @@ impl TryFrom<Endpoint> for PathBuf {
fn try_from(endpoint: Endpoint) -> std::result::Result<PathBuf, Self::Error> {
match endpoint {
Endpoint::Unix(path) => Ok(PathBuf::from(&path)),
- _ => Err(Error::TryFromEndpointError),
+ _ => Err(Error::TryFromEndpoint),
}
}
}
@@ -82,7 +87,7 @@ impl TryFrom<Endpoint> for UnixSocketAddress {
fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddress, Self::Error> {
match endpoint {
Endpoint::Unix(a) => Ok(UnixSocketAddress::from_pathname(a)?),
- _ => Err(Error::TryFromEndpointError),
+ _ => Err(Error::TryFromEndpoint),
}
}
}
@@ -112,6 +117,7 @@ impl FromStr for Endpoint {
match url.scheme() {
"tcp" => Ok(Endpoint::Tcp(addr, port)),
"udp" => Ok(Endpoint::Udp(addr, port)),
+ "tls" => Ok(Endpoint::Tls(addr, port)),
_ => Err(Error::InvalidEndpoint(s.to_string())),
}
} else {
@@ -133,6 +139,11 @@ impl Endpoint {
Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
}
+ /// Creates a new TLS endpoint from a `SocketAddr`.
+ pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint {
+ Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
+ }
+
/// Creates a new UDP endpoint from a `SocketAddr`.
pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint {
Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
@@ -151,29 +162,62 @@ impl Endpoint {
/// Returns the `Port` of the endpoint.
pub fn port(&self) -> Result<&Port> {
match self {
- Endpoint::Tcp(_, port) => Ok(port),
- Endpoint::Udp(_, port) => Ok(port),
- _ => Err(Error::TryFromEndpointError),
+ Endpoint::Udp(_, port) | Endpoint::Tcp(_, port) | Endpoint::Tls(_, port) => Ok(port),
+ _ => Err(Error::TryFromEndpoint),
}
}
/// Returns the `Addr` of the endpoint.
pub fn addr(&self) -> Result<&Addr> {
match self {
- Endpoint::Tcp(addr, _) => Ok(addr),
- Endpoint::Udp(addr, _) => Ok(addr),
- _ => Err(Error::TryFromEndpointError),
+ Endpoint::Udp(addr, _) | Endpoint::Tcp(addr, _) | Endpoint::Tls(addr, _) => Ok(addr),
+ _ => Err(Error::TryFromEndpoint),
}
}
}
/// Addr defines a type for an address, either IP or domain.
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Addr {
Ip(IpAddr),
Domain(String),
}
+impl Encode for Addr {
+ fn encode<E: bincode::enc::Encoder>(
+ &self,
+ encoder: &mut E,
+ ) -> std::result::Result<(), bincode::error::EncodeError> {
+ match self {
+ Addr::Ip(addr) => {
+ 0u32.encode(encoder)?;
+ addr.encode(encoder)
+ }
+ Addr::Domain(domain) => {
+ 1u32.encode(encoder)?;
+ domain.encode(encoder)
+ }
+ }
+ }
+}
+
+impl Decode for Addr {
+ fn decode<D: bincode::de::Decoder>(
+ decoder: &mut D,
+ ) -> std::result::Result<Self, bincode::error::DecodeError> {
+ match u32::decode(decoder)? {
+ 0 => Ok(Addr::Ip(IpAddr::decode(decoder)?)),
+ 1 => Ok(Addr::Domain(String::decode(decoder)?)),
+ found => Err(bincode::error::DecodeError::UnexpectedVariant {
+ allowed: &bincode::error::AllowedEnumVariants::Range { min: 0, max: 1 },
+ found,
+ type_name: core::any::type_name::<Addr>(),
+ }),
+ }
+ }
+}
+impl_borrow_decode!(Addr);
+
impl TryFrom<Addr> for IpAddr {
type Error = Error;
fn try_from(addr: Addr) -> std::result::Result<IpAddr, Self::Error> {
diff --git a/net/src/error.rs b/net/src/error.rs
index 346184a..5dd6348 100644
--- a/net/src/error.rs
+++ b/net/src/error.rs
@@ -8,7 +8,7 @@ pub enum Error {
IO(#[from] std::io::Error),
#[error("Try from endpoint Error")]
- TryFromEndpointError,
+ TryFromEndpoint,
#[error("invalid address {0}")]
InvalidAddress(String),
@@ -28,6 +28,12 @@ pub enum Error {
#[error(transparent)]
ChannelRecv(#[from] smol::channel::RecvError),
+ #[error("Tls Error: {0}")]
+ Rustls(#[from] async_rustls::rustls::Error),
+
+ #[error("Invalid DNS Name: {0}")]
+ InvalidDnsNameError(#[from] async_rustls::rustls::client::InvalidDnsNameError),
+
#[error(transparent)]
KaryonsCore(#[from] karyons_core::error::Error),
}
diff --git a/net/src/lib.rs b/net/src/lib.rs
index 0e4c361..61069ef 100644
--- a/net/src/lib.rs
+++ b/net/src/lib.rs
@@ -10,6 +10,7 @@ pub use {
listener::{listen, Listener},
transports::{
tcp::{dial_tcp, listen_tcp, TcpConn},
+ tls,
udp::{dial_udp, listen_udp, UdpConn},
unix::{dial_unix, listen_unix, UnixConn},
},
diff --git a/net/src/listener.rs b/net/src/listener.rs
index 31a63ae..c6c3d94 100644
--- a/net/src/listener.rs
+++ b/net/src/listener.rs
@@ -1,9 +1,8 @@
-use crate::{Endpoint, Error, Result};
use async_trait::async_trait;
use crate::{
transports::{tcp, unix},
- Conn,
+ Conn, Endpoint, Error, Result,
};
/// Listener is a generic network listener.
@@ -15,7 +14,7 @@ pub trait Listener: Send + Sync {
/// Listens to the provided endpoint.
///
-/// it only supports `tcp4/6` and `unix`.
+/// it only supports `tcp4/6`, and `unix`.
///
/// #Example
///
diff --git a/net/src/transports/mod.rs b/net/src/transports/mod.rs
index f399133..ac23021 100644
--- a/net/src/transports/mod.rs
+++ b/net/src/transports/mod.rs
@@ -1,3 +1,4 @@
pub mod tcp;
+pub mod tls;
pub mod udp;
pub mod unix;
diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs
index 84aa980..37f00a7 100644
--- a/net/src/transports/tcp.rs
+++ b/net/src/transports/tcp.rs
@@ -13,7 +13,7 @@ use crate::{
Error, Result,
};
-/// TCP network connection implementations of the [`Connection`] trait.
+/// TCP network connection implementation of the [`Connection`] trait.
pub struct TcpConn {
inner: TcpStream,
read: Mutex<ReadHalf<TcpStream>>,
diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs
new file mode 100644
index 0000000..01bb5aa
--- /dev/null
+++ b/net/src/transports/tls.rs
@@ -0,0 +1,140 @@
+use std::sync::Arc;
+
+use async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
+use async_trait::async_trait;
+use smol::{
+ io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+ lock::Mutex,
+ net::{TcpListener, TcpStream},
+};
+
+use crate::{
+ connection::Connection,
+ endpoint::{Addr, Endpoint, Port},
+ listener::Listener,
+ Error, Result,
+};
+
+/// TLS network connection implementation of the [`Connection`] trait.
+pub struct TlsConn {
+ inner: TcpStream,
+ read: Mutex<ReadHalf<TlsStream<TcpStream>>>,
+ write: Mutex<WriteHalf<TlsStream<TcpStream>>>,
+}
+
+impl TlsConn {
+ /// Creates a new TlsConn
+ pub fn new(sock: TcpStream, conn: TlsStream<TcpStream>) -> Self {
+ let (read, write) = split(conn);
+ Self {
+ inner: sock,
+ read: Mutex::new(read),
+ write: Mutex::new(write),
+ }
+ }
+}
+
+#[async_trait]
+impl Connection for TlsConn {
+ fn peer_endpoint(&self) -> Result<Endpoint> {
+ Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?))
+ }
+
+ fn local_endpoint(&self) -> Result<Endpoint> {
+ Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?))
+ }
+
+ async fn read(&self, buf: &mut [u8]) -> Result<usize> {
+ self.read.lock().await.read(buf).await.map_err(Error::from)
+ }
+
+ async fn write(&self, buf: &[u8]) -> Result<usize> {
+ self.write
+ .lock()
+ .await
+ .write(buf)
+ .await
+ .map_err(Error::from)
+ }
+}
+
+/// Connects to the given TLS address and port.
+pub async fn dial_tls(
+ addr: &Addr,
+ port: &Port,
+ config: rustls::ClientConfig,
+ dns_name: &str,
+) -> Result<TlsConn> {
+ let address = format!("{}:{}", addr, port);
+
+ let connector = TlsConnector::from(Arc::new(config));
+
+ let sock = TcpStream::connect(&address).await?;
+ sock.set_nodelay(true)?;
+
+ let altname = rustls::ServerName::try_from(dns_name)?;
+ let conn = connector.connect(altname, sock.clone()).await?;
+ Ok(TlsConn::new(sock, TlsStream::Client(conn)))
+}
+
+/// Connects to the given TLS endpoint, returns `Conn` ([`Connection`]).
+pub async fn dial(
+ endpoint: &Endpoint,
+ config: rustls::ClientConfig,
+ dns_name: &str,
+) -> Result<Box<dyn Connection>> {
+ match endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) => {}
+ _ => return Err(Error::InvalidEndpoint(endpoint.to_string())),
+ }
+
+ dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name)
+ .await
+ .map(|c| Box::new(c) as Box<dyn Connection>)
+}
+/// Tls network listener implementation of the [`Listener`] trait.
+pub struct TlsListener {
+ acceptor: TlsAcceptor,
+ listener: TcpListener,
+}
+
+#[async_trait]
+impl Listener for TlsListener {
+ fn local_endpoint(&self) -> Result<Endpoint> {
+ Ok(Endpoint::new_tls_addr(&self.listener.local_addr()?))
+ }
+
+ async fn accept(&self) -> Result<Box<dyn Connection>> {
+ let (sock, _) = self.listener.accept().await?;
+ sock.set_nodelay(true)?;
+ let conn = self.acceptor.accept(sock.clone()).await?;
+ Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn))))
+ }
+}
+
+/// Listens on the given TLS address and port.
+pub async fn listen_tls(
+ addr: &Addr,
+ port: &Port,
+ config: rustls::ServerConfig,
+) -> Result<TlsListener> {
+ let address = format!("{}:{}", addr, port);
+ let acceptor = TlsAcceptor::from(Arc::new(config));
+ let listener = TcpListener::bind(&address).await?;
+ Ok(TlsListener { acceptor, listener })
+}
+
+/// Listens on the given TLS endpoint, returns [`Listener`].
+pub async fn listen(
+ endpoint: &Endpoint,
+ config: rustls::ServerConfig,
+) -> Result<Box<dyn Listener>> {
+ match endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) => {}
+ _ => return Err(Error::InvalidEndpoint(endpoint.to_string())),
+ }
+
+ listen_tls(endpoint.addr()?, endpoint.port()?, config)
+ .await
+ .map(|l| Box::new(l) as Box<dyn Listener>)
+}
diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs
index ca5b94d..8a2fbec 100644
--- a/net/src/transports/udp.rs
+++ b/net/src/transports/udp.rs
@@ -9,7 +9,7 @@ use crate::{
Error, Result,
};
-/// UDP network connection implementations of the [`Connection`] trait.
+/// UDP network connection implementation of the [`Connection`] trait.
pub struct UdpConn {
inner: UdpSocket,
}
diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs
index a720d91..e504934 100644
--- a/net/src/transports/unix.rs
+++ b/net/src/transports/unix.rs
@@ -8,7 +8,7 @@ use smol::{
use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Error, Result};
-/// Unix domain socket implementations of the [`Connection`] trait.
+/// Unix domain socket implementation of the [`Connection`] trait.
pub struct UnixConn {
inner: UnixStream,
read: Mutex<ReadHalf<UnixStream>>,