From cae0c15d10235bf0ec0bd6f8b20814dc7b63dfd5 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 16 Jul 2024 08:16:57 +0200 Subject: p2p: check for the endpoints before listen/connect to them --- net/src/endpoint.rs | 148 +++++++++++++++++++++++++++----------------- p2p/src/connection.rs | 3 +- p2p/src/connector.rs | 8 +++ p2p/src/discovery/lookup.rs | 27 +++++--- p2p/src/discovery/mod.rs | 1 - p2p/src/error.rs | 3 + p2p/src/listener.rs | 8 +++ 7 files changed, 131 insertions(+), 67 deletions(-) diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs index c3626ec..dff703d 100644 --- a/net/src/endpoint.rs +++ b/net/src/endpoint.rs @@ -45,6 +45,98 @@ pub enum Endpoint { Unix(PathBuf), } +impl Endpoint { + /// Creates a new TCP endpoint from a `SocketAddr`. + pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint { + Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port()) + } + + /// Creates a new UDP endpoint from a `SocketAddr`. + pub fn new_udp_addr(addr: SocketAddr) -> Endpoint { + Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) + } + + /// Creates a new TLS endpoint from a `SocketAddr`. + pub fn new_tls_addr(addr: SocketAddr) -> Endpoint { + Endpoint::Tls(Addr::Ip(addr.ip()), addr.port()) + } + + /// Creates a new WS endpoint from a `SocketAddr`. + pub fn new_ws_addr(addr: SocketAddr) -> Endpoint { + Endpoint::Ws(Addr::Ip(addr.ip()), addr.port()) + } + + /// Creates a new WSS endpoint from a `SocketAddr`. + pub fn new_wss_addr(addr: SocketAddr) -> Endpoint { + Endpoint::Wss(Addr::Ip(addr.ip()), addr.port()) + } + + /// Creates a new Unix endpoint from a `UnixSocketAddr`. + pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint { + Endpoint::Unix(addr.to_path_buf()) + } + + #[inline] + /// Checks if the `Endpoint` is of type `Tcp`. + pub fn is_tcp(&self) -> bool { + matches!(self, Endpoint::Tcp(..)) + } + + #[inline] + /// Checks if the `Endpoint` is of type `Tls`. + pub fn is_tls(&self) -> bool { + matches!(self, Endpoint::Tls(..)) + } + + #[inline] + /// Checks if the `Endpoint` is of type `Ws`. + pub fn is_ws(&self) -> bool { + matches!(self, Endpoint::Ws(..)) + } + + #[inline] + /// Checks if the `Endpoint` is of type `Wss`. + pub fn is_wss(&self) -> bool { + matches!(self, Endpoint::Wss(..)) + } + + #[inline] + /// Checks if the `Endpoint` is of type `Udp`. + pub fn is_udp(&self) -> bool { + matches!(self, Endpoint::Udp(..)) + } + + #[inline] + /// Checks if the `Endpoint` is of type `Unix`. + pub fn is_unix(&self) -> bool { + matches!(self, Endpoint::Unix(..)) + } + + /// Returns the `Port` of the endpoint. + pub fn port(&self) -> Result<&Port> { + match self { + Endpoint::Tcp(_, port) + | Endpoint::Udp(_, port) + | Endpoint::Tls(_, port) + | Endpoint::Ws(_, port) + | Endpoint::Wss(_, port) => Ok(port), + _ => Err(Error::TryFromEndpoint), + } + } + + /// Returns the `Addr` of the endpoint. + pub fn addr(&self) -> Result<&Addr> { + match self { + Endpoint::Tcp(addr, _) + | Endpoint::Udp(addr, _) + | Endpoint::Tls(addr, _) + | Endpoint::Ws(addr, _) + | Endpoint::Wss(addr, _) => Ok(addr), + _ => Err(Error::TryFromEndpoint), + } + } +} + impl std::fmt::Display for Endpoint { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { @@ -152,62 +244,6 @@ impl FromStr for Endpoint { } } -impl Endpoint { - /// Creates a new TCP endpoint from a `SocketAddr`. - pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint { - Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port()) - } - - /// Creates a new UDP endpoint from a `SocketAddr`. - pub fn new_udp_addr(addr: SocketAddr) -> Endpoint { - Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) - } - - /// Creates a new TLS endpoint from a `SocketAddr`. - pub fn new_tls_addr(addr: SocketAddr) -> Endpoint { - Endpoint::Tls(Addr::Ip(addr.ip()), addr.port()) - } - - /// Creates a new WS endpoint from a `SocketAddr`. - pub fn new_ws_addr(addr: SocketAddr) -> Endpoint { - Endpoint::Ws(Addr::Ip(addr.ip()), addr.port()) - } - - /// Creates a new WSS endpoint from a `SocketAddr`. - pub fn new_wss_addr(addr: SocketAddr) -> Endpoint { - Endpoint::Wss(Addr::Ip(addr.ip()), addr.port()) - } - - /// Creates a new Unix endpoint from a `UnixSocketAddr`. - pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint { - Endpoint::Unix(addr.to_path_buf()) - } - - /// Returns the `Port` of the endpoint. - pub fn port(&self) -> Result<&Port> { - match self { - Endpoint::Tcp(_, port) - | Endpoint::Udp(_, port) - | Endpoint::Tls(_, port) - | Endpoint::Ws(_, port) - | Endpoint::Wss(_, port) => Ok(port), - _ => Err(Error::TryFromEndpoint), - } - } - - /// Returns the `Addr` of the endpoint. - pub fn addr(&self) -> Result<&Addr> { - match self { - Endpoint::Tcp(addr, _) - | Endpoint::Udp(addr, _) - | Endpoint::Tls(addr, _) - | Endpoint::Ws(addr, _) - | Endpoint::Wss(addr, _) => Ok(addr), - _ => Err(Error::TryFromEndpoint), - } - } -} - /// Addr defines a type for an address, either IP or domain. #[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs index c1a7a8c..f2e9d1e 100644 --- a/p2p/src/connection.rs +++ b/p2p/src/connection.rs @@ -74,8 +74,7 @@ impl Connection { pub async fn recv(&self) -> Result { match self.listeners.get(&P::id()) { Some(l) => l.recv().await.map_err(Error::from), - // TODO - None => todo!(), + None => Err(Error::UnsupportedProtocol(P::id())), } } diff --git a/p2p/src/connector.rs b/p2p/src/connector.rs index cfa661b..98cdfc7 100644 --- a/p2p/src/connector.rs +++ b/p2p/src/connector.rs @@ -148,6 +148,10 @@ impl Connector { async fn dial(&self, endpoint: &Endpoint, peer_id: &Option) -> Result> { if self.enable_tls { + if !endpoint.is_tcp() && !endpoint.is_tls() { + return Err(Error::UnsupportedEndpoint(endpoint.to_string())); + } + let tls_config = tls::ClientTlsConfig { tcp_config: Default::default(), client_config: tls_client_config(&self.key_pair, peer_id.clone())?, @@ -157,6 +161,10 @@ impl Connector { .await .map(|l| Box::new(l) as karyon_net::Conn) } else { + if !endpoint.is_tcp() { + return Err(Error::UnsupportedEndpoint(endpoint.to_string())); + } + tcp::dial(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()) .await .map(|l| Box::new(l) as karyon_net::Conn) diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs index 47a1d09..ba15da3 100644 --- a/p2p/src/discovery/lookup.rs +++ b/p2p/src/discovery/lookup.rs @@ -41,6 +41,9 @@ pub struct LookupService { /// Resolved listen endpoint listen_endpoint: RwLock>, + /// Resolved discovery endpoint + discovery_endpoint: RwLock>, + /// Holds the configuration for the P2P network. config: Arc, @@ -52,7 +55,6 @@ impl LookupService { /// Creates a new lookup service pub fn new( key_pair: &KeyPair, - id: &PeerID, table: Arc, config: Arc, monitor: Arc, @@ -78,13 +80,18 @@ impl LookupService { ex, ); + let id = key_pair + .public() + .try_into() + .expect("Get PeerID from KeyPair"); Self { - id: id.clone(), + id, table, listener, connector, outbound_slots, listen_endpoint: RwLock::new(None), + discovery_endpoint: RwLock::new(None), config, monitor, } @@ -98,19 +105,23 @@ impl LookupService { /// Set the resolved listen endpoint. pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> { - let resolved_endpoint = Endpoint::Tcp( + let discovery_endpoint = Endpoint::Tcp( resolved_endpoint.addr()?.clone(), self.config.discovery_port, ); - *self.listen_endpoint.write() = Some(resolved_endpoint); + *self.listen_endpoint.write() = Some(resolved_endpoint.clone()); + *self.discovery_endpoint.write() = Some(discovery_endpoint.clone()); Ok(()) } - /// Get the listening endpoint. pub fn listen_endpoint(&self) -> Option { self.listen_endpoint.read().clone() } + pub fn discovery_endpoint(&self) -> Option { + self.discovery_endpoint.read().clone() + } + /// Shuts down the lookup service. pub async fn shutdown(&self) { self.connector.shutdown().await; @@ -278,7 +289,7 @@ impl LookupService { trace!("Send Peer msg"); if let Some(endpoint) = self.listen_endpoint() { - self.send_peer_msg(&conn, endpoint.clone()).await?; + self.send_peer_msg(&conn, endpoint).await?; } trace!("Send Shutdown msg"); @@ -289,8 +300,8 @@ impl LookupService { /// Start a listener. async fn start_listener(self: &Arc) -> Result<()> { - let endpoint: Endpoint = match self.listen_endpoint() { - Some(e) => e.clone(), + let endpoint: Endpoint = match self.discovery_endpoint() { + Some(e) => e, None => return Ok(()), }; diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs index a81a817..b4c20bf 100644 --- a/p2p/src/discovery/mod.rs +++ b/p2p/src/discovery/mod.rs @@ -80,7 +80,6 @@ impl Discovery { let lookup_service = Arc::new(LookupService::new( key_pair, - peer_id, table.clone(), config.clone(), monitor.clone(), diff --git a/p2p/src/error.rs b/p2p/src/error.rs index a490b57..cc30aff 100644 --- a/p2p/src/error.rs +++ b/p2p/src/error.rs @@ -11,6 +11,9 @@ pub enum Error { #[error("Unsupported protocol error: {0}")] UnsupportedProtocol(String), + #[error("Unsupported Endpoint: {0}")] + UnsupportedEndpoint(String), + #[error("PeerID try from PublicKey Error")] PeerIDTryFromPublicKey, diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs index 8a5deaa..347099d 100644 --- a/p2p/src/listener.rs +++ b/p2p/src/listener.rs @@ -157,6 +157,10 @@ impl Listener { async fn listen(&self, endpoint: &Endpoint) -> Result> { if self.enable_tls { + if !endpoint.is_tcp() && !endpoint.is_tls() { + return Err(Error::UnsupportedEndpoint(endpoint.to_string())); + } + let tls_config = tls::ServerTlsConfig { tcp_config: Default::default(), server_config: tls_server_config(&self.key_pair)?, @@ -165,6 +169,10 @@ impl Listener { .await .map(|l| Box::new(l) as karyon_net::Listener) } else { + if !endpoint.is_tcp() { + return Err(Error::UnsupportedEndpoint(endpoint.to_string())); + } + tcp::listen(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()) .await .map(|l| Box::new(l) as karyon_net::Listener) -- cgit v1.2.3