aboutsummaryrefslogtreecommitdiff
path: root/net/src/transports/tls.rs
diff options
context:
space:
mode:
authorhozan23 <hozan23@proton.me>2023-11-28 22:41:33 +0300
committerhozan23 <hozan23@proton.me>2023-11-28 22:41:33 +0300
commit98a1de91a2dae06323558422c239e5a45fc86e7b (patch)
tree38c640248824fcb3b4ca5ba12df47c13ef26ccda /net/src/transports/tls.rs
parentca2a5f8bbb6983d9555abd10eaaf86950b794957 (diff)
implement TLS for inbound and outbound connections
Diffstat (limited to 'net/src/transports/tls.rs')
-rw-r--r--net/src/transports/tls.rs140
1 files changed, 140 insertions, 0 deletions
diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs
new file mode 100644
index 0000000..01bb5aa
--- /dev/null
+++ b/net/src/transports/tls.rs
@@ -0,0 +1,140 @@
+use std::sync::Arc;
+
+use async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
+use async_trait::async_trait;
+use smol::{
+ io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
+ lock::Mutex,
+ net::{TcpListener, TcpStream},
+};
+
+use crate::{
+ connection::Connection,
+ endpoint::{Addr, Endpoint, Port},
+ listener::Listener,
+ Error, Result,
+};
+
+/// TLS network connection implementation of the [`Connection`] trait.
+pub struct TlsConn {
+ inner: TcpStream,
+ read: Mutex<ReadHalf<TlsStream<TcpStream>>>,
+ write: Mutex<WriteHalf<TlsStream<TcpStream>>>,
+}
+
+impl TlsConn {
+ /// Creates a new TlsConn
+ pub fn new(sock: TcpStream, conn: TlsStream<TcpStream>) -> Self {
+ let (read, write) = split(conn);
+ Self {
+ inner: sock,
+ read: Mutex::new(read),
+ write: Mutex::new(write),
+ }
+ }
+}
+
+#[async_trait]
+impl Connection for TlsConn {
+ fn peer_endpoint(&self) -> Result<Endpoint> {
+ Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?))
+ }
+
+ fn local_endpoint(&self) -> Result<Endpoint> {
+ Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?))
+ }
+
+ async fn read(&self, buf: &mut [u8]) -> Result<usize> {
+ self.read.lock().await.read(buf).await.map_err(Error::from)
+ }
+
+ async fn write(&self, buf: &[u8]) -> Result<usize> {
+ self.write
+ .lock()
+ .await
+ .write(buf)
+ .await
+ .map_err(Error::from)
+ }
+}
+
+/// Connects to the given TLS address and port.
+pub async fn dial_tls(
+ addr: &Addr,
+ port: &Port,
+ config: rustls::ClientConfig,
+ dns_name: &str,
+) -> Result<TlsConn> {
+ let address = format!("{}:{}", addr, port);
+
+ let connector = TlsConnector::from(Arc::new(config));
+
+ let sock = TcpStream::connect(&address).await?;
+ sock.set_nodelay(true)?;
+
+ let altname = rustls::ServerName::try_from(dns_name)?;
+ let conn = connector.connect(altname, sock.clone()).await?;
+ Ok(TlsConn::new(sock, TlsStream::Client(conn)))
+}
+
+/// Connects to the given TLS endpoint, returns `Conn` ([`Connection`]).
+pub async fn dial(
+ endpoint: &Endpoint,
+ config: rustls::ClientConfig,
+ dns_name: &str,
+) -> Result<Box<dyn Connection>> {
+ match endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) => {}
+ _ => return Err(Error::InvalidEndpoint(endpoint.to_string())),
+ }
+
+ dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name)
+ .await
+ .map(|c| Box::new(c) as Box<dyn Connection>)
+}
+/// Tls network listener implementation of the [`Listener`] trait.
+pub struct TlsListener {
+ acceptor: TlsAcceptor,
+ listener: TcpListener,
+}
+
+#[async_trait]
+impl Listener for TlsListener {
+ fn local_endpoint(&self) -> Result<Endpoint> {
+ Ok(Endpoint::new_tls_addr(&self.listener.local_addr()?))
+ }
+
+ async fn accept(&self) -> Result<Box<dyn Connection>> {
+ let (sock, _) = self.listener.accept().await?;
+ sock.set_nodelay(true)?;
+ let conn = self.acceptor.accept(sock.clone()).await?;
+ Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn))))
+ }
+}
+
+/// Listens on the given TLS address and port.
+pub async fn listen_tls(
+ addr: &Addr,
+ port: &Port,
+ config: rustls::ServerConfig,
+) -> Result<TlsListener> {
+ let address = format!("{}:{}", addr, port);
+ let acceptor = TlsAcceptor::from(Arc::new(config));
+ let listener = TcpListener::bind(&address).await?;
+ Ok(TlsListener { acceptor, listener })
+}
+
+/// Listens on the given TLS endpoint, returns [`Listener`].
+pub async fn listen(
+ endpoint: &Endpoint,
+ config: rustls::ServerConfig,
+) -> Result<Box<dyn Listener>> {
+ match endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) => {}
+ _ => return Err(Error::InvalidEndpoint(endpoint.to_string())),
+ }
+
+ listen_tls(endpoint.addr()?, endpoint.port()?, config)
+ .await
+ .map(|l| Box::new(l) as Box<dyn Listener>)
+}