From e15d3e6fd20b3f87abaad7ddec1c88b0e66419f9 Mon Sep 17 00:00:00 2001
From: hozan23 <hozan23@karyontech.net>
Date: Mon, 15 Jul 2024 13:16:01 +0200
Subject: p2p: Major refactoring of the handshake protocol

Introduce a new protocol InitProtocol which can be used as the core protocol
for initializing a connection with a peer.
Move the handshake logic from the PeerPool module to the protocols directory and
build a handshake protocol that implements InitProtocol trait.
---
 p2p/src/discovery/lookup.rs  | 114 ++++++++++++++++++++-----------------------
 p2p/src/discovery/mod.rs     |  17 +++----
 p2p/src/discovery/refresh.rs |  28 +++++------
 3 files changed, 73 insertions(+), 86 deletions(-)

(limited to 'p2p/src/discovery')

diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs
index 9ddf614..47a1d09 100644
--- a/p2p/src/discovery/lookup.rs
+++ b/p2p/src/discovery/lookup.rs
@@ -2,24 +2,17 @@ use std::{sync::Arc, time::Duration};
 
 use futures_util::stream::{FuturesUnordered, StreamExt};
 use log::{error, trace};
+use parking_lot::RwLock;
 use rand::{rngs::OsRng, seq::SliceRandom, RngCore};
 
-use karyon_core::{
-    async_runtime::{lock::RwLock, Executor},
-    async_util::timeout,
-    crypto::KeyPair,
-    util::decode,
-};
+use karyon_core::{async_runtime::Executor, async_util::timeout, crypto::KeyPair, util::decode};
 
 use karyon_net::{Conn, Endpoint};
 
 use crate::{
     connector::Connector,
     listener::Listener,
-    message::{
-        get_msg_payload, FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg,
-        ShutdownMsg,
-    },
+    message::{FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, ShutdownMsg},
     monitor::{ConnEvent, DiscvEvent, Monitor},
     routing_table::RoutingTable,
     slots::ConnectionSlots,
@@ -46,7 +39,7 @@ pub struct LookupService {
     outbound_slots: Arc<ConnectionSlots>,
 
     /// Resolved listen endpoint
-    listen_endpoint: Option<RwLock<Endpoint>>,
+    listen_endpoint: RwLock<Option<Endpoint>>,
 
     /// Holds the configuration for the P2P network.
     config: Arc<Config>,
@@ -85,18 +78,13 @@ impl LookupService {
             ex,
         );
 
-        let listen_endpoint = config
-            .listen_endpoint
-            .as_ref()
-            .map(|endpoint| RwLock::new(endpoint.clone()));
-
         Self {
             id: id.clone(),
             table,
             listener,
             connector,
             outbound_slots,
-            listen_endpoint,
+            listen_endpoint: RwLock::new(None),
             config,
             monitor,
         }
@@ -109,10 +97,18 @@ impl LookupService {
     }
 
     /// Set the resolved listen endpoint.
-    pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) {
-        if let Some(endpoint) = &self.listen_endpoint {
-            *endpoint.write().await = resolved_endpoint.clone();
-        }
+    pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
+        let resolved_endpoint = Endpoint::Tcp(
+            resolved_endpoint.addr()?.clone(),
+            self.config.discovery_port,
+        );
+        *self.listen_endpoint.write() = Some(resolved_endpoint);
+        Ok(())
+    }
+
+    /// Get the listening endpoint.
+    pub fn listen_endpoint(&self) -> Option<Endpoint> {
+        self.listen_endpoint.read().clone()
     }
 
     /// Shuts down the lookup service.
@@ -253,36 +249,51 @@ impl LookupService {
         target_peer_id: &PeerID,
     ) -> Result<Vec<PeerMsg>> {
         trace!("Send Ping msg");
-        self.send_ping_msg(&conn).await?;
+        let peers;
 
-        trace!("Send FindPeer msg");
-        let peers = self.send_findpeer_msg(&conn, target_peer_id).await?;
+        let ping_msg = self.send_ping_msg(&conn).await?;
 
-        if peers.0.len() >= MAX_PEERS_IN_PEERSMSG {
-            return Err(Error::Lookup("Received too many peers in PeersMsg"));
+        loop {
+            let t = Duration::from_secs(self.config.lookup_response_timeout);
+            let msg: NetMsg = timeout(t, conn.recv()).await??;
+            match msg.header.command {
+                NetMsgCmd::Pong => {
+                    let (pong_msg, _) = decode::<PongMsg>(&msg.payload)?;
+                    if ping_msg.nonce != pong_msg.0 {
+                        return Err(Error::InvalidPongMsg);
+                    }
+                    trace!("Send FindPeer msg");
+                    self.send_findpeer_msg(&conn, target_peer_id).await?;
+                }
+                NetMsgCmd::Peers => {
+                    peers = decode::<PeersMsg>(&msg.payload)?.0.peers;
+                    if peers.len() >= MAX_PEERS_IN_PEERSMSG {
+                        return Err(Error::Lookup("Received too many peers in PeersMsg"));
+                    }
+                    break;
+                }
+                c => return Err(Error::InvalidMsg(format!("Unexpected msg: {:?}", c))),
+            };
         }
 
         trace!("Send Peer msg");
-        if let Some(endpoint) = &self.listen_endpoint {
-            self.send_peer_msg(&conn, endpoint.read().await.clone())
-                .await?;
+        if let Some(endpoint) = self.listen_endpoint() {
+            self.send_peer_msg(&conn, endpoint.clone()).await?;
         }
 
         trace!("Send Shutdown msg");
         self.send_shutdown_msg(&conn).await?;
 
-        Ok(peers.0)
+        Ok(peers)
     }
 
     /// Start a listener.
     async fn start_listener(self: &Arc<Self>) -> Result<()> {
-        let addr = match &self.listen_endpoint {
-            Some(a) => a.read().await.addr()?.clone(),
+        let endpoint: Endpoint = match self.listen_endpoint() {
+            Some(e) => e.clone(),
             None => return Ok(()),
         };
 
-        let endpoint = Endpoint::Tcp(addr, self.config.discovery_port);
-
         let callback = {
             let this = self.clone();
             |conn: Conn<NetMsg>| async move {
@@ -292,7 +303,7 @@ impl LookupService {
             }
         };
 
-        self.listener.start(endpoint.clone(), callback).await?;
+        self.listener.start(endpoint, callback).await?;
         Ok(())
     }
 
@@ -329,10 +340,9 @@ impl LookupService {
         }
     }
 
-    /// Sends a Ping msg and wait to receive the Pong message.
-    async fn send_ping_msg(&self, conn: &Conn<NetMsg>) -> Result<()> {
+    /// Sends a Ping msg.
+    async fn send_ping_msg(&self, conn: &Conn<NetMsg>) -> Result<PingMsg> {
         trace!("Send Pong msg");
-
         let mut nonce: [u8; 32] = [0; 32];
         RngCore::fill_bytes(&mut OsRng, &mut nonce);
 
@@ -341,18 +351,7 @@ impl LookupService {
             nonce,
         };
         conn.send(NetMsg::new(NetMsgCmd::Ping, &ping_msg)?).await?;
-
-        let t = Duration::from_secs(self.config.lookup_response_timeout);
-        let recv_msg: NetMsg = timeout(t, conn.recv()).await??;
-
-        let payload = get_msg_payload!(Pong, recv_msg);
-        let (pong_msg, _) = decode::<PongMsg>(&payload)?;
-
-        if ping_msg.nonce != pong_msg.0 {
-            return Err(Error::InvalidPongMsg);
-        }
-
-        Ok(())
+        Ok(ping_msg)
     }
 
     /// Sends a Pong msg
@@ -363,22 +362,15 @@ impl LookupService {
         Ok(())
     }
 
-    /// Sends a FindPeer msg and wait to receivet the Peers msg.
-    async fn send_findpeer_msg(&self, conn: &Conn<NetMsg>, peer_id: &PeerID) -> Result<PeersMsg> {
+    /// Sends a FindPeer msg
+    async fn send_findpeer_msg(&self, conn: &Conn<NetMsg>, peer_id: &PeerID) -> Result<()> {
         trace!("Send FindPeer msg");
         conn.send(NetMsg::new(
             NetMsgCmd::FindPeer,
             FindPeerMsg(peer_id.clone()),
         )?)
         .await?;
-
-        let t = Duration::from_secs(self.config.lookup_response_timeout);
-        let recv_msg: NetMsg = timeout(t, conn.recv()).await??;
-
-        let payload = get_msg_payload!(Peers, recv_msg);
-        let (peers, _) = decode(&payload)?;
-
-        Ok(peers)
+        Ok(())
     }
 
     /// Sends a Peers msg.
@@ -389,7 +381,7 @@ impl LookupService {
             .closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG);
 
         let peers: Vec<PeerMsg> = entries.into_iter().map(|e| e.into()).collect();
-        conn.send(NetMsg::new(NetMsgCmd::Peers, PeersMsg(peers))?)
+        conn.send(NetMsg::new(NetMsgCmd::Peers, PeersMsg { peers })?)
             .await?;
         Ok(())
     }
diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs
index a9d99d6..a81a817 100644
--- a/p2p/src/discovery/mod.rs
+++ b/p2p/src/discovery/mod.rs
@@ -16,7 +16,8 @@ use karyon_net::{Conn, Endpoint};
 
 use crate::{
     config::Config,
-    conn_queue::{ConnDirection, ConnQueue},
+    conn_queue::ConnQueue,
+    connection::ConnDirection,
     connector::Connector,
     listener::Listener,
     message::NetMsg,
@@ -132,15 +133,11 @@ impl Discovery {
 
             let resolved_endpoint = self.start_listener(endpoint).await?;
 
-            if endpoint.addr()? != resolved_endpoint.addr()? {
-                info!("Resolved listen endpoint: {resolved_endpoint}");
-                self.lookup_service
-                    .set_listen_endpoint(&resolved_endpoint)
-                    .await;
-                self.refresh_service
-                    .set_listen_endpoint(&resolved_endpoint)
-                    .await;
-            }
+            info!("Resolved listen endpoint: {resolved_endpoint}");
+            self.lookup_service
+                .set_listen_endpoint(&resolved_endpoint)?;
+            self.refresh_service
+                .set_listen_endpoint(&resolved_endpoint)?;
         }
 
         // Start the lookup service
diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs
index b4f5396..1452a1b 100644
--- a/p2p/src/discovery/refresh.rs
+++ b/p2p/src/discovery/refresh.rs
@@ -2,10 +2,11 @@ use std::{sync::Arc, time::Duration};
 
 use bincode::{Decode, Encode};
 use log::{error, info, trace};
+use parking_lot::RwLock;
 use rand::{rngs::OsRng, RngCore};
 
 use karyon_core::{
-    async_runtime::{lock::RwLock, Executor},
+    async_runtime::Executor,
     async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
 };
 
@@ -33,7 +34,7 @@ pub struct RefreshService {
     table: Arc<RoutingTable>,
 
     /// Resolved listen endpoint
-    listen_endpoint: Option<RwLock<Endpoint>>,
+    listen_endpoint: RwLock<Option<Endpoint>>,
 
     /// Managing spawned tasks.
     task_group: TaskGroup,
@@ -53,14 +54,9 @@ impl RefreshService {
         monitor: Arc<Monitor>,
         executor: Executor,
     ) -> Self {
-        let listen_endpoint = config
-            .listen_endpoint
-            .as_ref()
-            .map(|endpoint| RwLock::new(endpoint.clone()));
-
         Self {
             table,
-            listen_endpoint,
+            listen_endpoint: RwLock::new(None),
             task_group: TaskGroup::with_executor(executor.clone()),
             config,
             monitor,
@@ -69,9 +65,8 @@ impl RefreshService {
 
     /// Start the refresh service
     pub async fn start(self: &Arc<Self>) -> Result<()> {
-        if let Some(endpoint) = &self.listen_endpoint {
-            let endpoint = endpoint.read().await.clone();
-
+        if let Some(endpoint) = self.listen_endpoint.read().as_ref() {
+            let endpoint = endpoint.clone();
             self.task_group.spawn(
                 {
                     let this = self.clone();
@@ -101,10 +96,13 @@ impl RefreshService {
     }
 
     /// Set the resolved listen endpoint.
-    pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) {
-        if let Some(endpoint) = &self.listen_endpoint {
-            *endpoint.write().await = resolved_endpoint.clone();
-        }
+    pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
+        let resolved_endpoint = Endpoint::Udp(
+            resolved_endpoint.addr()?.clone(),
+            self.config.discovery_port,
+        );
+        *self.listen_endpoint.write() = Some(resolved_endpoint);
+        Ok(())
     }
 
     /// Shuts down the refresh service
-- 
cgit v1.2.3