From 98a1de91a2dae06323558422c239e5a45fc86e7b Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 28 Nov 2023 22:41:33 +0300 Subject: implement TLS for inbound and outbound connections --- net/src/endpoint.rs | 70 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 13 deletions(-) (limited to 'net/src/endpoint.rs') diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs index 50dfe6b..720eea3 100644 --- a/net/src/endpoint.rs +++ b/net/src/endpoint.rs @@ -5,7 +5,7 @@ use std::{ str::FromStr, }; -use bincode::{Decode, Encode}; +use bincode::{impl_borrow_decode, Decode, Encode}; use url::Url; use crate::{Error, Result}; @@ -33,6 +33,7 @@ pub type Port = u16; pub enum Endpoint { Udp(Addr, Port), Tcp(Addr, Port), + Tls(Addr, Port), Unix(String), } @@ -45,6 +46,9 @@ impl std::fmt::Display for Endpoint { Endpoint::Tcp(ip, port) => { write!(f, "tcp://{}:{}", ip, port) } + Endpoint::Tls(ip, port) => { + write!(f, "tls://{}:{}", ip, port) + } Endpoint::Unix(path) => { if path.is_empty() { write!(f, "unix:/UNNAMED") @@ -60,9 +64,10 @@ impl TryFrom for SocketAddr { type Error = Error; fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { - Endpoint::Udp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), - Endpoint::Tcp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), - Endpoint::Unix(_) => Err(Error::TryFromEndpointError), + Endpoint::Udp(ip, port) | Endpoint::Tcp(ip, port) | Endpoint::Tls(ip, port) => { + Ok(SocketAddr::new(ip.try_into()?, port)) + } + Endpoint::Unix(_) => Err(Error::TryFromEndpoint), } } } @@ -72,7 +77,7 @@ impl TryFrom for PathBuf { fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { Endpoint::Unix(path) => Ok(PathBuf::from(&path)), - _ => Err(Error::TryFromEndpointError), + _ => Err(Error::TryFromEndpoint), } } } @@ -82,7 +87,7 @@ impl TryFrom for UnixSocketAddress { fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { Endpoint::Unix(a) => Ok(UnixSocketAddress::from_pathname(a)?), - _ => Err(Error::TryFromEndpointError), + _ => Err(Error::TryFromEndpoint), } } } @@ -112,6 +117,7 @@ impl FromStr for Endpoint { match url.scheme() { "tcp" => Ok(Endpoint::Tcp(addr, port)), "udp" => Ok(Endpoint::Udp(addr, port)), + "tls" => Ok(Endpoint::Tls(addr, port)), _ => Err(Error::InvalidEndpoint(s.to_string())), } } else { @@ -133,6 +139,11 @@ impl Endpoint { Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port()) } + /// Creates a new TLS endpoint from a `SocketAddr`. + pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint { + Endpoint::Tls(Addr::Ip(addr.ip()), addr.port()) + } + /// Creates a new UDP endpoint from a `SocketAddr`. pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint { Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) @@ -151,29 +162,62 @@ impl Endpoint { /// Returns the `Port` of the endpoint. pub fn port(&self) -> Result<&Port> { match self { - Endpoint::Tcp(_, port) => Ok(port), - Endpoint::Udp(_, port) => Ok(port), - _ => Err(Error::TryFromEndpointError), + Endpoint::Udp(_, port) | Endpoint::Tcp(_, port) | Endpoint::Tls(_, port) => Ok(port), + _ => Err(Error::TryFromEndpoint), } } /// Returns the `Addr` of the endpoint. pub fn addr(&self) -> Result<&Addr> { match self { - Endpoint::Tcp(addr, _) => Ok(addr), - Endpoint::Udp(addr, _) => Ok(addr), - _ => Err(Error::TryFromEndpointError), + Endpoint::Udp(addr, _) | Endpoint::Tcp(addr, _) | Endpoint::Tls(addr, _) => Ok(addr), + _ => Err(Error::TryFromEndpoint), } } } /// Addr defines a type for an address, either IP or domain. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Addr { Ip(IpAddr), Domain(String), } +impl Encode for Addr { + fn encode( + &self, + encoder: &mut E, + ) -> std::result::Result<(), bincode::error::EncodeError> { + match self { + Addr::Ip(addr) => { + 0u32.encode(encoder)?; + addr.encode(encoder) + } + Addr::Domain(domain) => { + 1u32.encode(encoder)?; + domain.encode(encoder) + } + } + } +} + +impl Decode for Addr { + fn decode( + decoder: &mut D, + ) -> std::result::Result { + match u32::decode(decoder)? { + 0 => Ok(Addr::Ip(IpAddr::decode(decoder)?)), + 1 => Ok(Addr::Domain(String::decode(decoder)?)), + found => Err(bincode::error::DecodeError::UnexpectedVariant { + allowed: &bincode::error::AllowedEnumVariants::Range { min: 0, max: 1 }, + found, + type_name: core::any::type_name::(), + }), + } + } +} +impl_borrow_decode!(Addr); + impl TryFrom for IpAddr { type Error = Error; fn try_from(addr: Addr) -> std::result::Result { -- cgit v1.2.3