From e15d3e6fd20b3f87abaad7ddec1c88b0e66419f9 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Mon, 15 Jul 2024 13:16:01 +0200 Subject: p2p: Major refactoring of the handshake protocol Introduce a new protocol InitProtocol which can be used as the core protocol for initializing a connection with a peer. Move the handshake logic from the PeerPool module to the protocols directory and build a handshake protocol that implements InitProtocol trait. --- p2p/src/conn_queue.rs | 53 +++----- p2p/src/connection.rs | 110 +++++++++++++++ p2p/src/discovery/lookup.rs | 114 ++++++++-------- p2p/src/discovery/mod.rs | 17 +-- p2p/src/discovery/refresh.rs | 28 ++-- p2p/src/lib.rs | 1 + p2p/src/message.rs | 20 +-- p2p/src/monitor/mod.rs | 6 +- p2p/src/peer/mod.rs | 247 ++++++++++++++++------------------ p2p/src/peer_pool.rs | 296 +++++++++++------------------------------ p2p/src/protocol.rs | 14 +- p2p/src/protocols/handshake.rs | 139 +++++++++++++++++++ p2p/src/protocols/mod.rs | 4 +- p2p/src/protocols/ping.rs | 39 +++--- 14 files changed, 571 insertions(+), 517 deletions(-) create mode 100644 p2p/src/connection.rs create mode 100644 p2p/src/protocols/handshake.rs (limited to 'p2p/src') diff --git a/p2p/src/conn_queue.rs b/p2p/src/conn_queue.rs index 9a153f3..1b6ef98 100644 --- a/p2p/src/conn_queue.rs +++ b/p2p/src/conn_queue.rs @@ -1,37 +1,13 @@ -use std::{collections::VecDeque, fmt, sync::Arc}; - -use async_channel::Sender; +use std::{collections::VecDeque, sync::Arc}; use karyon_core::{async_runtime::lock::Mutex, async_util::CondVar}; use karyon_net::Conn; -use crate::{message::NetMsg, Result}; - -/// Defines the direction of a network connection. -#[derive(Clone, Debug)] -pub enum ConnDirection { - Inbound, - Outbound, -} - -impl fmt::Display for ConnDirection { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ConnDirection::Inbound => write!(f, "Inbound"), - ConnDirection::Outbound => write!(f, "Outbound"), - } - } -} - -pub struct NewConn { - pub direction: ConnDirection, - pub conn: Conn, - pub disconnect_signal: Sender>, -} +use crate::{connection::ConnDirection, connection::Connection, message::NetMsg, Result}; /// Connection queue pub struct ConnQueue { - queue: Mutex>, + queue: Mutex>, conn_available: CondVar, } @@ -43,24 +19,27 @@ impl ConnQueue { }) } - /// Push a connection into the queue and wait for the disconnect signal + /// Handle a connection by pushing it into the queue and wait for the disconnect signal pub async fn handle(&self, conn: Conn, direction: ConnDirection) -> Result<()> { - let (disconnect_signal, chan) = async_channel::bounded(1); - let new_conn = NewConn { - direction, - conn, - disconnect_signal, - }; + let endpoint = conn.peer_endpoint()?; + + let (disconnect_tx, disconnect_rx) = async_channel::bounded(1); + let new_conn = Connection::new(conn, disconnect_tx, direction, endpoint); + + // Push a new conn to the queue self.queue.lock().await.push_back(new_conn); self.conn_available.signal(); - if let Ok(result) = chan.recv().await { + + // Wait for the disconnect signal from the connection handler + if let Ok(result) = disconnect_rx.recv().await { return result; } + Ok(()) } - /// Receive the next connection in the queue - pub async fn next(&self) -> NewConn { + /// Waits for the next connection in the queue + pub async fn next(&self) -> Connection { let mut queue = self.queue.lock().await; while queue.is_empty() { queue = self.conn_available.wait(queue).await; diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs new file mode 100644 index 0000000..52190a8 --- /dev/null +++ b/p2p/src/connection.rs @@ -0,0 +1,110 @@ +use std::{collections::HashMap, fmt, sync::Arc}; + +use async_channel::Sender; +use bincode::Encode; + +use karyon_core::{ + event::{EventListener, EventSys}, + util::encode, +}; + +use karyon_net::{Conn, Endpoint}; + +use crate::{ + message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg}, + protocol::{Protocol, ProtocolEvent, ProtocolID}, + Error, Result, +}; + +/// Defines the direction of a network connection. +#[derive(Clone, Debug)] +pub enum ConnDirection { + Inbound, + Outbound, +} + +impl fmt::Display for ConnDirection { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ConnDirection::Inbound => write!(f, "Inbound"), + ConnDirection::Outbound => write!(f, "Outbound"), + } + } +} + +pub struct Connection { + pub(crate) direction: ConnDirection, + conn: Conn, + disconnect_signal: Sender>, + /// `EventSys` responsible for sending events to the registered protocols. + protocol_events: Arc>, + pub(crate) remote_endpoint: Endpoint, + listeners: HashMap>, +} + +impl Connection { + pub fn new( + conn: Conn, + signal: Sender>, + direction: ConnDirection, + remote_endpoint: Endpoint, + ) -> Self { + Self { + conn, + direction, + protocol_events: EventSys::new(), + disconnect_signal: signal, + remote_endpoint, + listeners: HashMap::new(), + } + } + + pub async fn send(&self, protocol_id: ProtocolID, msg: T) -> Result<()> { + let payload = encode(&msg)?; + + let proto_msg = ProtocolMsg { + protocol_id, + payload: payload.to_vec(), + }; + + let msg = NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?; + self.conn.send(msg).await.map_err(Error::from) + } + + pub async fn recv(&self) -> Result { + match self.listeners.get(&P::id()) { + Some(l) => l.recv().await.map_err(Error::from), + // TODO + None => todo!(), + } + } + + /// Registers a listener for the given Protocol `P`. + pub async fn register_protocol(&mut self, protocol_id: String) { + let listener = self.protocol_events.register(&protocol_id).await; + self.listeners.insert(protocol_id, listener); + } + + pub async fn emit_msg(&self, id: &ProtocolID, event: &ProtocolEvent) -> Result<()> { + self.protocol_events.emit_by_topic(id, event).await?; + Ok(()) + } + + pub async fn recv_inner(&self) -> Result { + self.conn.recv().await.map_err(Error::from) + } + + pub async fn send_inner(&self, msg: NetMsg) -> Result<()> { + self.conn.send(msg).await.map_err(Error::from) + } + + pub async fn disconnect(&self, res: Result<()>) -> Result<()> { + self.protocol_events.clear().await; + self.disconnect_signal.send(res).await?; + + let m = NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("Create shutdown message"); + self.conn.send(m).await.map_err(Error::from)?; + + Ok(()) + } +} diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs index 9ddf614..47a1d09 100644 --- a/p2p/src/discovery/lookup.rs +++ b/p2p/src/discovery/lookup.rs @@ -2,24 +2,17 @@ use std::{sync::Arc, time::Duration}; use futures_util::stream::{FuturesUnordered, StreamExt}; use log::{error, trace}; +use parking_lot::RwLock; use rand::{rngs::OsRng, seq::SliceRandom, RngCore}; -use karyon_core::{ - async_runtime::{lock::RwLock, Executor}, - async_util::timeout, - crypto::KeyPair, - util::decode, -}; +use karyon_core::{async_runtime::Executor, async_util::timeout, crypto::KeyPair, util::decode}; use karyon_net::{Conn, Endpoint}; use crate::{ connector::Connector, listener::Listener, - message::{ - get_msg_payload, FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, - ShutdownMsg, - }, + message::{FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, ShutdownMsg}, monitor::{ConnEvent, DiscvEvent, Monitor}, routing_table::RoutingTable, slots::ConnectionSlots, @@ -46,7 +39,7 @@ pub struct LookupService { outbound_slots: Arc, /// Resolved listen endpoint - listen_endpoint: Option>, + listen_endpoint: RwLock>, /// Holds the configuration for the P2P network. config: Arc, @@ -85,18 +78,13 @@ impl LookupService { ex, ); - let listen_endpoint = config - .listen_endpoint - .as_ref() - .map(|endpoint| RwLock::new(endpoint.clone())); - Self { id: id.clone(), table, listener, connector, outbound_slots, - listen_endpoint, + listen_endpoint: RwLock::new(None), config, monitor, } @@ -109,10 +97,18 @@ impl LookupService { } /// Set the resolved listen endpoint. - pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) { - if let Some(endpoint) = &self.listen_endpoint { - *endpoint.write().await = resolved_endpoint.clone(); - } + pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> { + let resolved_endpoint = Endpoint::Tcp( + resolved_endpoint.addr()?.clone(), + self.config.discovery_port, + ); + *self.listen_endpoint.write() = Some(resolved_endpoint); + Ok(()) + } + + /// Get the listening endpoint. + pub fn listen_endpoint(&self) -> Option { + self.listen_endpoint.read().clone() } /// Shuts down the lookup service. @@ -253,36 +249,51 @@ impl LookupService { target_peer_id: &PeerID, ) -> Result> { trace!("Send Ping msg"); - self.send_ping_msg(&conn).await?; + let peers; - trace!("Send FindPeer msg"); - let peers = self.send_findpeer_msg(&conn, target_peer_id).await?; + let ping_msg = self.send_ping_msg(&conn).await?; - if peers.0.len() >= MAX_PEERS_IN_PEERSMSG { - return Err(Error::Lookup("Received too many peers in PeersMsg")); + loop { + let t = Duration::from_secs(self.config.lookup_response_timeout); + let msg: NetMsg = timeout(t, conn.recv()).await??; + match msg.header.command { + NetMsgCmd::Pong => { + let (pong_msg, _) = decode::(&msg.payload)?; + if ping_msg.nonce != pong_msg.0 { + return Err(Error::InvalidPongMsg); + } + trace!("Send FindPeer msg"); + self.send_findpeer_msg(&conn, target_peer_id).await?; + } + NetMsgCmd::Peers => { + peers = decode::(&msg.payload)?.0.peers; + if peers.len() >= MAX_PEERS_IN_PEERSMSG { + return Err(Error::Lookup("Received too many peers in PeersMsg")); + } + break; + } + c => return Err(Error::InvalidMsg(format!("Unexpected msg: {:?}", c))), + }; } trace!("Send Peer msg"); - if let Some(endpoint) = &self.listen_endpoint { - self.send_peer_msg(&conn, endpoint.read().await.clone()) - .await?; + if let Some(endpoint) = self.listen_endpoint() { + self.send_peer_msg(&conn, endpoint.clone()).await?; } trace!("Send Shutdown msg"); self.send_shutdown_msg(&conn).await?; - Ok(peers.0) + Ok(peers) } /// Start a listener. async fn start_listener(self: &Arc) -> Result<()> { - let addr = match &self.listen_endpoint { - Some(a) => a.read().await.addr()?.clone(), + let endpoint: Endpoint = match self.listen_endpoint() { + Some(e) => e.clone(), None => return Ok(()), }; - let endpoint = Endpoint::Tcp(addr, self.config.discovery_port); - let callback = { let this = self.clone(); |conn: Conn| async move { @@ -292,7 +303,7 @@ impl LookupService { } }; - self.listener.start(endpoint.clone(), callback).await?; + self.listener.start(endpoint, callback).await?; Ok(()) } @@ -329,10 +340,9 @@ impl LookupService { } } - /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, conn: &Conn) -> Result<()> { + /// Sends a Ping msg. + async fn send_ping_msg(&self, conn: &Conn) -> Result { trace!("Send Pong msg"); - let mut nonce: [u8; 32] = [0; 32]; RngCore::fill_bytes(&mut OsRng, &mut nonce); @@ -341,18 +351,7 @@ impl LookupService { nonce, }; conn.send(NetMsg::new(NetMsgCmd::Ping, &ping_msg)?).await?; - - let t = Duration::from_secs(self.config.lookup_response_timeout); - let recv_msg: NetMsg = timeout(t, conn.recv()).await??; - - let payload = get_msg_payload!(Pong, recv_msg); - let (pong_msg, _) = decode::(&payload)?; - - if ping_msg.nonce != pong_msg.0 { - return Err(Error::InvalidPongMsg); - } - - Ok(()) + Ok(ping_msg) } /// Sends a Pong msg @@ -363,22 +362,15 @@ impl LookupService { Ok(()) } - /// Sends a FindPeer msg and wait to receivet the Peers msg. - async fn send_findpeer_msg(&self, conn: &Conn, peer_id: &PeerID) -> Result { + /// Sends a FindPeer msg + async fn send_findpeer_msg(&self, conn: &Conn, peer_id: &PeerID) -> Result<()> { trace!("Send FindPeer msg"); conn.send(NetMsg::new( NetMsgCmd::FindPeer, FindPeerMsg(peer_id.clone()), )?) .await?; - - let t = Duration::from_secs(self.config.lookup_response_timeout); - let recv_msg: NetMsg = timeout(t, conn.recv()).await??; - - let payload = get_msg_payload!(Peers, recv_msg); - let (peers, _) = decode(&payload)?; - - Ok(peers) + Ok(()) } /// Sends a Peers msg. @@ -389,7 +381,7 @@ impl LookupService { .closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG); let peers: Vec = entries.into_iter().map(|e| e.into()).collect(); - conn.send(NetMsg::new(NetMsgCmd::Peers, PeersMsg(peers))?) + conn.send(NetMsg::new(NetMsgCmd::Peers, PeersMsg { peers })?) .await?; Ok(()) } diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs index a9d99d6..a81a817 100644 --- a/p2p/src/discovery/mod.rs +++ b/p2p/src/discovery/mod.rs @@ -16,7 +16,8 @@ use karyon_net::{Conn, Endpoint}; use crate::{ config::Config, - conn_queue::{ConnDirection, ConnQueue}, + conn_queue::ConnQueue, + connection::ConnDirection, connector::Connector, listener::Listener, message::NetMsg, @@ -132,15 +133,11 @@ impl Discovery { let resolved_endpoint = self.start_listener(endpoint).await?; - if endpoint.addr()? != resolved_endpoint.addr()? { - info!("Resolved listen endpoint: {resolved_endpoint}"); - self.lookup_service - .set_listen_endpoint(&resolved_endpoint) - .await; - self.refresh_service - .set_listen_endpoint(&resolved_endpoint) - .await; - } + info!("Resolved listen endpoint: {resolved_endpoint}"); + self.lookup_service + .set_listen_endpoint(&resolved_endpoint)?; + self.refresh_service + .set_listen_endpoint(&resolved_endpoint)?; } // Start the lookup service diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index b4f5396..1452a1b 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -2,10 +2,11 @@ use std::{sync::Arc, time::Duration}; use bincode::{Decode, Encode}; use log::{error, info, trace}; +use parking_lot::RwLock; use rand::{rngs::OsRng, RngCore}; use karyon_core::{ - async_runtime::{lock::RwLock, Executor}, + async_runtime::Executor, async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult}, }; @@ -33,7 +34,7 @@ pub struct RefreshService { table: Arc, /// Resolved listen endpoint - listen_endpoint: Option>, + listen_endpoint: RwLock>, /// Managing spawned tasks. task_group: TaskGroup, @@ -53,14 +54,9 @@ impl RefreshService { monitor: Arc, executor: Executor, ) -> Self { - let listen_endpoint = config - .listen_endpoint - .as_ref() - .map(|endpoint| RwLock::new(endpoint.clone())); - Self { table, - listen_endpoint, + listen_endpoint: RwLock::new(None), task_group: TaskGroup::with_executor(executor.clone()), config, monitor, @@ -69,9 +65,8 @@ impl RefreshService { /// Start the refresh service pub async fn start(self: &Arc) -> Result<()> { - if let Some(endpoint) = &self.listen_endpoint { - let endpoint = endpoint.read().await.clone(); - + if let Some(endpoint) = self.listen_endpoint.read().as_ref() { + let endpoint = endpoint.clone(); self.task_group.spawn( { let this = self.clone(); @@ -101,10 +96,13 @@ impl RefreshService { } /// Set the resolved listen endpoint. - pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) { - if let Some(endpoint) = &self.listen_endpoint { - *endpoint.write().await = resolved_endpoint.clone(); - } + pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> { + let resolved_endpoint = Endpoint::Udp( + resolved_endpoint.addr()?.clone(), + self.config.discovery_port, + ); + *self.listen_endpoint.write() = Some(resolved_endpoint); + Ok(()) } /// Shuts down the refresh service diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index b21a353..f0dc725 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -41,6 +41,7 @@ mod backend; mod codec; mod config; mod conn_queue; +mod connection; mod connector; mod discovery; mod error; diff --git a/p2p/src/message.rs b/p2p/src/message.rs index 6498ef7..5bf0853 100644 --- a/p2p/src/message.rs +++ b/p2p/src/message.rs @@ -110,7 +110,9 @@ pub struct PeerMsg { /// PeersMsg a list of `PeerMsg`. #[derive(Decode, Encode, Debug)] -pub struct PeersMsg(pub Vec); +pub struct PeersMsg { + pub peers: Vec, +} impl From for PeerMsg { fn from(entry: Entry) -> PeerMsg { @@ -133,19 +135,3 @@ impl From for Entry { } } } - -macro_rules! get_msg_payload { - ($a:ident, $b:ident) => { - if let NetMsgCmd::$a = $b.header.command { - $b.payload - } else { - return Err(Error::InvalidMsg(format!( - "Expected {:?} msg found {:?} msg", - stringify!($a), - $b.header.command - ))); - } - }; -} - -pub(super) use get_msg_payload; diff --git a/p2p/src/monitor/mod.rs b/p2p/src/monitor/mod.rs index 4ecb431..86db23e 100644 --- a/p2p/src/monitor/mod.rs +++ b/p2p/src/monitor/mod.rs @@ -2,6 +2,8 @@ mod event; use std::sync::Arc; +use log::error; + use karyon_core::event::{EventListener, EventSys, EventValue, EventValueTopic}; use karyon_net::Endpoint; @@ -62,7 +64,9 @@ impl Monitor { pub(crate) async fn notify(&self, event: E) { if self.config.enable_monitor { let event = event.to_struct(); - self.event_sys.emit(&event).await + if let Err(err) = self.event_sys.emit(&event).await { + error!("Failed to notify monitor event {:?}: {err}", event); + } } } diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs index 6903294..a5ac7ad 100644 --- a/p2p/src/peer/mod.rs +++ b/p2p/src/peer/mod.rs @@ -1,138 +1,111 @@ mod peer_id; -pub use peer_id::PeerID; - use std::sync::{Arc, Weak}; use async_channel::{Receiver, Sender}; -use bincode::{Decode, Encode}; +use bincode::Encode; use log::{error, trace}; +use parking_lot::RwLock; use karyon_core::{ - async_runtime::{lock::RwLock, Executor}, + async_runtime::Executor, async_util::{select, Either, TaskGroup, TaskResult}, - event::{EventListener, EventSys}, - util::{decode, encode}, + util::decode, }; -use karyon_net::{Conn, Endpoint}; - use crate::{ - conn_queue::ConnDirection, - message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg}, + connection::{ConnDirection, Connection}, + endpoint::Endpoint, + message::{NetMsgCmd, ProtocolMsg}, peer_pool::PeerPool, - protocol::{Protocol, ProtocolEvent, ProtocolID}, + protocol::{InitProtocol, Protocol, ProtocolEvent, ProtocolID}, + protocols::HandshakeProtocol, Config, Error, Result, }; +pub use peer_id::PeerID; + pub struct Peer { + /// Own ID + own_id: PeerID, + /// Peer's ID - id: PeerID, + id: RwLock>, - /// A weak pointer to `PeerPool` + /// A weak pointer to [`PeerPool`] peer_pool: Weak, /// Holds the peer connection - conn: Conn, - - /// Remote endpoint for the peer - remote_endpoint: Endpoint, - - /// The direction of the connection, either `Inbound` or `Outbound` - conn_direction: ConnDirection, - - /// A list of protocol IDs - protocol_ids: RwLock>, - - /// `EventSys` responsible for sending events to the protocols. - protocol_events: Arc>, + pub(crate) conn: Arc, /// This channel is used to send a stop signal to the read loop. stop_chan: (Sender>, Receiver>), + /// The Configuration for the P2P network. + config: Arc, + /// Managing spawned tasks. task_group: TaskGroup, } impl Peer { /// Creates a new peer - pub fn new( + pub(crate) fn new( + own_id: PeerID, peer_pool: Weak, - id: &PeerID, - conn: Conn, - remote_endpoint: Endpoint, - conn_direction: ConnDirection, + conn: Arc, + config: Arc, ex: Executor, ) -> Arc { Arc::new(Peer { - id: id.clone(), + own_id, + id: RwLock::new(None), peer_pool, conn, - protocol_ids: RwLock::new(Vec::new()), - remote_endpoint, - conn_direction, - protocol_events: EventSys::new(), + config, task_group: TaskGroup::with_executor(ex), stop_chan: async_channel::bounded(1), }) } - /// Run the peer - pub async fn run(self: Arc) -> Result<()> { - self.start_protocols().await; - self.read_loop().await + /// Send a msg to this peer connection using the specified protocol. + pub async fn send(&self, proto_id: ProtocolID, msg: T) -> Result<()> { + self.conn.send(proto_id, msg).await } - /// Send a message to the peer connection using the specified protocol. - pub async fn send(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { - let payload = encode(msg)?; - - let proto_msg = ProtocolMsg { - protocol_id: protocol_id.to_string(), - payload: payload.to_vec(), - }; - - self.conn - .send(NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?) - .await?; - Ok(()) + /// Receives a new msg from this peer connection. + pub async fn recv(&self) -> Result { + self.conn.recv::

().await } /// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast(&self, protocol_id: &ProtocolID, msg: &T) { - self.peer_pool().broadcast(protocol_id, msg).await; + pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { + self.peer_pool().broadcast(proto_id, msg).await; } - /// Shuts down the peer - pub async fn shutdown(&self) { - trace!("peer {} start shutting down", self.id); - - // Send shutdown event to all protocols - for protocol_id in self.protocol_ids.read().await.iter() { - self.protocol_events - .emit_by_topic(protocol_id, &ProtocolEvent::Shutdown) - .await; - } + /// Returns the peer's ID + pub fn id(&self) -> Option { + self.id.read().clone() + } - // Send a stop signal to the read loop - // - // No need to handle the error here; a dropped channel and - // sending a stop signal have the same effect. - let _ = self.stop_chan.0.try_send(Ok(())); + /// Returns own ID + pub fn own_id(&self) -> &PeerID { + &self.own_id + } - // No need to handle the error here - let shutdown_msg = - NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("pack shutdown message"); - let _ = self.conn.send(shutdown_msg).await; + /// Returns the [`Config`] + pub fn config(&self) -> Arc { + self.config.clone() + } - // Force shutting down - self.task_group.cancel().await; + /// Returns the remote endpoint for the peer + pub fn remote_endpoint(&self) -> &Endpoint { + &self.conn.remote_endpoint } /// Check if the connection is Inbound - #[inline] pub fn is_inbound(&self) -> bool { - match self.conn_direction { + match self.conn.direction { ConnDirection::Inbound => true, ConnDirection::Outbound => false, } @@ -140,40 +113,82 @@ impl Peer { /// Returns the direction of the connection, which can be either `Inbound` /// or `Outbound`. - #[inline] pub fn direction(&self) -> &ConnDirection { - &self.conn_direction + &self.conn.direction } - /// Returns the remote endpoint for the peer - #[inline] - pub fn remote_endpoint(&self) -> &Endpoint { - &self.remote_endpoint + pub(crate) async fn init(self: &Arc) -> Result<()> { + let handshake_protocol = HandshakeProtocol::new( + self.clone(), + self.peer_pool().protocol_versions.read().await.clone(), + ); + + let pid = handshake_protocol.init().await?; + *self.id.write() = Some(pid); + + Ok(()) } - /// Return the peer's ID - #[inline] - pub fn id(&self) -> &PeerID { - &self.id + /// Run the peer + pub(crate) async fn run(self: Arc) -> Result<()> { + self.run_connect_protocols().await; + self.read_loop().await } - /// Returns the `Config` instance. - pub fn config(&self) -> Arc { - self.peer_pool().config.clone() + /// Shuts down the peer + pub(crate) async fn shutdown(self: &Arc) -> Result<()> { + trace!("peer {:?} shutting down", self.id()); + + // Send shutdown event to the attached protocols + for proto_id in self.peer_pool().protocols.read().await.keys() { + let _ = self.conn.emit_msg(proto_id, &ProtocolEvent::Shutdown).await; + } + + // Send a stop signal to the read loop + // + // No need to handle the error here; a dropped channel and + // sendig a stop signal have the same effect. + let _ = self.stop_chan.0.try_send(Ok(())); + + self.conn.disconnect(Ok(())).await?; + + // Force shutting down + self.task_group.cancel().await; + Ok(()) } - /// Registers a listener for the given Protocol `P`. - pub async fn register_listener(&self) -> EventListener { - self.protocol_events.register(&P::id()).await + /// Run running the Connect Protocols for this peer connection. + async fn run_connect_protocols(self: &Arc) { + for (proto_id, constructor) in self.peer_pool().protocols.read().await.iter() { + trace!("peer {:?} run protocol {proto_id}", self.id()); + + let protocol = constructor(self.clone()); + + let on_failure = { + let this = self.clone(); + let proto_id = proto_id.clone(); + |result: TaskResult>| async move { + if let TaskResult::Completed(res) = result { + if res.is_err() { + error!("protocol {} stopped", proto_id); + } + // Send a stop signal to read loop + let _ = this.stop_chan.0.try_send(res); + } + } + }; + + self.task_group.spawn(protocol.start(), on_failure); + } } - /// Start a read loop to handle incoming messages from the peer connection. + /// Run a read loop to handle incoming messages from the peer connection. async fn read_loop(&self) -> Result<()> { loop { - let fut = select(self.stop_chan.1.recv(), self.conn.recv()).await; + let fut = select(self.stop_chan.1.recv(), self.conn.recv_inner()).await; let result = match fut { Either::Left(stop_signal) => { - trace!("Peer {} received a stop signal", self.id); + trace!("Peer {:?} received a stop signal", self.id()); return stop_signal?; } Either::Right(result) => result, @@ -184,14 +199,9 @@ impl Peer { match msg.header.command { NetMsgCmd::Protocol => { let msg: ProtocolMsg = decode(&msg.payload)?.0; - - if !self.protocol_ids.read().await.contains(&msg.protocol_id) { - return Err(Error::UnsupportedProtocol(msg.protocol_id)); - } - - let proto_id = &msg.protocol_id; - let msg = ProtocolEvent::Message(msg.payload); - self.protocol_events.emit_by_topic(proto_id, &msg).await; + self.conn + .emit_msg(&msg.protocol_id, &ProtocolEvent::Message(msg.payload)) + .await?; } NetMsgCmd::Shutdown => { return Err(Error::PeerShutdown); @@ -201,32 +211,7 @@ impl Peer { } } - /// Start running the protocols for this peer connection. - async fn start_protocols(self: &Arc) { - for (protocol_id, constructor) in self.peer_pool().protocols.read().await.iter() { - trace!("peer {} start protocol {protocol_id}", self.id); - let protocol = constructor(self.clone()); - - self.protocol_ids.write().await.push(protocol_id.clone()); - - let on_failure = { - let this = self.clone(); - let protocol_id = protocol_id.clone(); - |result: TaskResult>| async move { - if let TaskResult::Completed(res) = result { - if res.is_err() { - error!("protocol {} stopped", protocol_id); - } - // Send a stop signal to read loop - let _ = this.stop_chan.0.try_send(res); - } - } - }; - - self.task_group.spawn(protocol.start(), on_failure); - } - } - + /// Returns `PeerPool` pointer fn peer_pool(&self) -> Arc { self.peer_pool.upgrade().unwrap() } diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index 1f3ca55..549dc76 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -1,26 +1,24 @@ -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{collections::HashMap, sync::Arc}; -use async_channel::Sender; -use bincode::{Decode, Encode}; -use log::{error, info, trace, warn}; +use bincode::Encode; +use log::{error, info, warn}; use karyon_core::{ async_runtime::{lock::RwLock, Executor}, - async_util::{timeout, TaskGroup, TaskResult}, - util::decode, + async_util::{TaskGroup, TaskResult}, }; -use karyon_net::{Conn, Endpoint}; +use karyon_net::Endpoint; use crate::{ config::Config, - conn_queue::{ConnDirection, ConnQueue}, - message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, + conn_queue::ConnQueue, + connection::Connection, monitor::{Monitor, PPEvent}, peer::Peer, protocol::{Protocol, ProtocolConstructor, ProtocolID}, protocols::PingProtocol, - version::{version_match, Version, VersionInt}, + version::Version, Error, PeerID, Result, }; @@ -37,8 +35,8 @@ pub struct PeerPool { /// Hashmap contains protocol constructors. pub(crate) protocols: RwLock>>, - /// Hashmap contains protocol IDs and their versions. - protocol_versions: Arc>>, + /// Hashmap contains protocols with their versions + pub(crate) protocol_versions: RwLock>, /// Managing spawned tasks. task_group: TaskGroup, @@ -47,7 +45,7 @@ pub struct PeerPool { executor: Executor, /// The Configuration for the P2P network. - pub(crate) config: Arc, + config: Arc, /// Responsible for network and system monitoring. monitor: Arc, @@ -62,15 +60,12 @@ impl PeerPool { monitor: Arc, executor: Executor, ) -> Arc { - let protocols = RwLock::new(HashMap::new()); - let protocol_versions = Arc::new(RwLock::new(HashMap::new())); - Arc::new(Self { id: id.clone(), conn_queue, peers: RwLock::new(HashMap::new()), - protocols, - protocol_versions, + protocols: RwLock::new(HashMap::new()), + protocol_versions: RwLock::new(HashMap::new()), task_group: TaskGroup::with_executor(executor.clone()), executor, monitor, @@ -80,21 +75,15 @@ impl PeerPool { /// Starts the [`PeerPool`] pub async fn start(self: &Arc) -> Result<()> { - self.setup_protocols().await?; - self.task_group.spawn( - { - let this = self.clone(); - async move { this.listen_loop().await } - }, - |_| async {}, - ); + self.setup_core_protocols().await?; + self.task_group.spawn(self.clone().run(), |_| async {}); Ok(()) } /// Shuts down pub async fn shutdown(&self) { for (_, peer) in self.peers.read().await.iter() { - peer.shutdown().await; + let _ = peer.shutdown().await; } self.task_group.cancel().await; @@ -102,76 +91,24 @@ impl PeerPool { /// Attach a custom protocol to the network pub async fn attach_protocol(&self, c: Box) -> Result<()> { - let protocol_versions = &mut self.protocol_versions.write().await; - let protocols = &mut self.protocols.write().await; - - protocol_versions.insert(P::id(), P::version()?); - protocols.insert(P::id(), c); + self.protocols.write().await.insert(P::id(), c); + self.protocol_versions + .write() + .await + .insert(P::id(), P::version()?); Ok(()) } /// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { + pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { for (pid, peer) in self.peers.read().await.iter() { - if let Err(err) = peer.send(proto_id, msg).await { + if let Err(err) = peer.conn.send(proto_id.to_string(), msg).await { error!("failed to send msg to {pid}: {err}"); continue; } } } - /// Add a new peer to the peer list. - pub async fn new_peer( - self: &Arc, - conn: Conn, - conn_direction: &ConnDirection, - disconnect_signal: Sender>, - ) -> Result<()> { - let endpoint = conn.peer_endpoint()?; - - // Do a handshake with the connection before creating a new peer. - let pid = self.do_handshake(&conn, conn_direction).await?; - - // TODO: Consider restricting the subnet for inbound connections - if self.contains_peer(&pid).await { - return Err(Error::PeerAlreadyConnected); - } - - // Create a new peer - let peer = Peer::new( - Arc::downgrade(self), - &pid, - conn, - endpoint.clone(), - conn_direction.clone(), - self.executor.clone(), - ); - - // Insert the new peer - self.peers.write().await.insert(pid.clone(), peer.clone()); - - let on_disconnect = { - let this = self.clone(); - let pid = pid.clone(); - |result| async move { - if let TaskResult::Completed(result) = result { - if let Err(err) = this.remove_peer(&pid).await { - error!("Failed to remove peer {pid}: {err}"); - } - let _ = disconnect_signal.send(result).await; - } - } - }; - - self.task_group.spawn(peer.run(), on_disconnect); - - info!("Add new peer {pid}, direction: {conn_direction}, endpoint: {endpoint}"); - - self.monitor.notify(PPEvent::NewPeer(pid.clone())).await; - - Ok(()) - } - /// Checks if the peer list contains a peer with the given peer id pub async fn contains_peer(&self, pid: &PeerID) -> bool { self.peers.read().await.contains_key(pid) @@ -204,162 +141,89 @@ impl PeerPool { peers } - /// Listens to a new connection from the connection queue - async fn listen_loop(self: Arc) { + async fn run(self: Arc) { loop { - let conn = self.conn_queue.next().await; - let signal = conn.disconnect_signal; + let mut conn = self.conn_queue.next().await; + + for protocol_id in self.protocols.read().await.keys() { + conn.register_protocol(protocol_id.to_string()).await; + } - let result = self - .new_peer(conn.conn, &conn.direction, signal.clone()) - .await; + let conn = Arc::new(conn); - // Only send a disconnect signal if there is an error when adding a peer. + let result = self.new_peer(conn.clone()).await; + + // Disconnect if there is an error when adding a peer. if result.is_err() { - let _ = signal.send(result).await; + let _ = conn.disconnect(result).await; } } } - /// Shuts down the peer and remove it from the peer list. - async fn remove_peer(&self, pid: &PeerID) -> Result<()> { - let result = self.peers.write().await.remove(pid); - - let peer = match result { - Some(p) => p, - None => return Ok(()), - }; - - peer.shutdown().await; - - self.monitor.notify(PPEvent::RemovePeer(pid.clone())).await; - - let endpoint = peer.remote_endpoint(); - let direction = peer.direction(); + /// Add a new peer to the peer list. + async fn new_peer(self: &Arc, conn: Arc) -> Result<()> { + // Create a new peer + let peer = Peer::new( + self.id.clone(), + Arc::downgrade(self), + conn.clone(), + self.config.clone(), + self.executor.clone(), + ); + peer.init().await?; + let pid = peer.id().expect("Get peer id after peer initialization"); - warn!("Peer {pid} removed, direction: {direction}, endpoint: {endpoint}",); - Ok(()) - } + // TODO: Consider restricting the subnet for inbound connections + if self.contains_peer(&pid).await { + return Err(Error::PeerAlreadyConnected); + } - /// Attach the core protocols. - async fn setup_protocols(&self) -> Result<()> { - let executor = self.executor.clone(); - let c = move |peer| PingProtocol::new(peer, executor.clone()); - self.attach_protocol::(Box::new(c)).await - } + // Insert the new peer + self.peers.write().await.insert(pid.clone(), peer.clone()); - /// Initiate a handshake with a connection. - async fn do_handshake( - &self, - conn: &Conn, - conn_direction: &ConnDirection, - ) -> Result { - trace!("Handshake started: {}", conn.peer_endpoint()?); - match conn_direction { - ConnDirection::Inbound => { - let result = self.wait_vermsg(conn).await; - match result { - Ok(_) => { - self.send_verack(conn, true).await?; - } - Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { - self.send_verack(conn, false).await?; + let on_disconnect = { + let this = self.clone(); + let pid = pid.clone(); + |result| async move { + if let TaskResult::Completed(_) = result { + if let Err(err) = this.remove_peer(&pid).await { + error!("Failed to remove peer {pid}: {err}"); } - _ => {} } - result - } - - ConnDirection::Outbound => { - self.send_vermsg(conn).await?; - self.wait_verack(conn).await } - } - } + }; - /// Send a Version message - async fn send_vermsg(&self, conn: &Conn) -> Result<()> { - let pids = self.protocol_versions.read().await; - let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); - drop(pids); + self.task_group.spawn(peer.run(), on_disconnect); - let vermsg = VerMsg { - peer_id: self.id.clone(), - protocols, - version: self.config.version.v.clone(), - }; + info!("Add new peer {pid}"); + self.monitor.notify(PPEvent::NewPeer(pid)).await; - trace!("Send VerMsg"); - conn.send(NetMsg::new(NetMsgCmd::Version, &vermsg)?).await?; Ok(()) } - /// Wait for a Version message - /// - /// Returns the peer's ID upon successfully receiving the Version message. - async fn wait_vermsg(&self, conn: &Conn) -> Result { - let t = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = timeout(t, conn.recv()).await??; - - let payload = get_msg_payload!(Version, msg); - let (vermsg, _) = decode::(&payload)?; - - if !version_match(&self.config.version.req, &vermsg.version) { - return Err(Error::IncompatibleVersion("system: {}".into())); - } - - self.protocols_match(&vermsg.protocols).await?; - - trace!("Received VerMsg from: {}", vermsg.peer_id); - Ok(vermsg.peer_id) - } + /// Shuts down the peer and remove it from the peer list. + async fn remove_peer(&self, pid: &PeerID) -> Result<()> { + let result = self.peers.write().await.remove(pid); - /// Send a Verack message - async fn send_verack(&self, conn: &Conn, ack: bool) -> Result<()> { - let verack = VerAckMsg { - peer_id: self.id.clone(), - ack, + let peer = match result { + Some(p) => p, + None => return Ok(()), }; - trace!("Send VerAckMsg {:?}", verack); - conn.send(NetMsg::new(NetMsgCmd::Verack, &verack)?).await?; - Ok(()) - } - - /// Wait for a Verack message - /// - /// Returns the peer's ID upon successfully receiving the Verack message. - async fn wait_verack(&self, conn: &Conn) -> Result { - let t = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = timeout(t, conn.recv()).await??; + let _ = peer.shutdown().await; - let payload = get_msg_payload!(Verack, msg); - let (verack, _) = decode::(&payload)?; - - if !verack.ack { - return Err(Error::IncompatiblePeer); - } + self.monitor.notify(PPEvent::RemovePeer(pid.clone())).await; - trace!("Received VerAckMsg from: {}", verack.peer_id); - Ok(verack.peer_id) + warn!("Peer {pid} removed",); + Ok(()) } - /// Check if the new connection has compatible protocols. - async fn protocols_match(&self, protocols: &HashMap) -> Result<()> { - for (n, pv) in protocols.iter() { - let pids = self.protocol_versions.read().await; - - match pids.get(n) { - Some(v) => { - if !version_match(&v.req, pv) { - return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}"))); - } - } - None => { - return Err(Error::UnsupportedProtocol(n.to_string())); - } - } - } - Ok(()) + /// Attach the core protocols. + async fn setup_core_protocols(&self) -> Result<()> { + let executor = self.executor.clone(); + let ping_interval = self.config.ping_interval; + let ping_timeout = self.config.ping_timeout; + let c = move |peer| PingProtocol::new(peer, ping_interval, ping_timeout, executor.clone()); + self.attach_protocol::(Box::new(c)).await } } diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index 021844f..249692b 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -56,11 +56,8 @@ impl EventValue for ProtocolEvent { /// #[async_trait] /// impl Protocol for NewProtocol { /// async fn start(self: Arc) -> Result<(), Error> { -/// let listener = self.peer.register_listener::().await; /// loop { -/// let event = listener.recv().await.unwrap(); -/// -/// match event { +/// match self.peer.recv::().await.expect("Receive msg") { /// ProtocolEvent::Message(msg) => { /// println!("{:?}", msg); /// } @@ -69,8 +66,6 @@ impl EventValue for ProtocolEvent { /// } /// } /// } -/// -/// listener.cancel().await; /// Ok(()) /// } /// @@ -114,3 +109,10 @@ pub trait Protocol: Send + Sync { where Self: Sized; } + +#[async_trait] +pub(crate) trait InitProtocol: Send + Sync { + type T; + /// Initialize the protocol + async fn init(self: Arc) -> Self::T; +} diff --git a/p2p/src/protocols/handshake.rs b/p2p/src/protocols/handshake.rs new file mode 100644 index 0000000..b3fe989 --- /dev/null +++ b/p2p/src/protocols/handshake.rs @@ -0,0 +1,139 @@ +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use async_trait::async_trait; +use log::trace; + +use karyon_core::{async_util::timeout, util::decode}; + +use crate::{ + message::{NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, + peer::Peer, + protocol::{InitProtocol, ProtocolID}, + version::{version_match, VersionInt}, + Error, PeerID, Result, Version, +}; + +pub struct HandshakeProtocol { + peer: Arc, + protocols: HashMap, +} + +#[async_trait] +impl InitProtocol for HandshakeProtocol { + type T = Result; + /// Initiate a handshake with a connection. + async fn init(self: Arc) -> Self::T { + trace!("Init Handshake: {}", self.peer.remote_endpoint()); + + if !self.peer.is_inbound() { + self.send_vermsg().await?; + } + + let t = Duration::from_secs(self.peer.config().handshake_timeout); + let msg: NetMsg = timeout(t, self.peer.conn.recv_inner()).await??; + match msg.header.command { + NetMsgCmd::Version => { + let result = self.validate_version_msg(&msg).await; + match result { + Ok(_) => { + self.send_verack(true).await?; + } + Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { + self.send_verack(false).await?; + } + _ => {} + }; + result + } + NetMsgCmd::Verack => self.validate_verack_msg(&msg).await, + cmd => Err(Error::InvalidMsg(format!("unexpected msg found {:?}", cmd))), + } + } +} + +impl HandshakeProtocol { + pub fn new(peer: Arc, protocols: HashMap) -> Arc { + Arc::new(Self { peer, protocols }) + } + + /// Sends a Version message + async fn send_vermsg(&self) -> Result<()> { + let protocols = self + .protocols + .clone() + .into_iter() + .map(|p| (p.0, p.1.v)) + .collect(); + + let vermsg = VerMsg { + peer_id: self.peer.own_id().clone(), + protocols, + version: self.peer.config().version.v.clone(), + }; + + trace!("Send VerMsg"); + self.peer + .conn + .send_inner(NetMsg::new(NetMsgCmd::Version, &vermsg)?) + .await?; + Ok(()) + } + + /// Sends a Verack message + async fn send_verack(&self, ack: bool) -> Result<()> { + let verack = VerAckMsg { + peer_id: self.peer.own_id().clone(), + ack, + }; + + trace!("Send VerAckMsg {:?}", verack); + self.peer + .conn + .send_inner(NetMsg::new(NetMsgCmd::Verack, &verack)?) + .await?; + Ok(()) + } + + /// Validates the given version msg + async fn validate_version_msg(&self, msg: &NetMsg) -> Result { + let (vermsg, _) = decode::(&msg.payload)?; + + if !version_match(&self.peer.config().version.req, &vermsg.version) { + return Err(Error::IncompatibleVersion("system: {}".into())); + } + + self.protocols_match(&vermsg.protocols).await?; + + trace!("Received VerMsg from: {}", vermsg.peer_id); + Ok(vermsg.peer_id) + } + + /// Validates the given verack msg + async fn validate_verack_msg(&self, msg: &NetMsg) -> Result { + let (verack, _) = decode::(&msg.payload)?; + + if !verack.ack { + return Err(Error::IncompatiblePeer); + } + + trace!("Received VerAckMsg from: {}", verack.peer_id); + Ok(verack.peer_id) + } + + /// Check if the new connection has compatible protocols. + async fn protocols_match(&self, protocols: &HashMap) -> Result<()> { + for (n, pv) in protocols.iter() { + match self.protocols.get(n) { + Some(v) => { + if !version_match(&v.req, pv) { + return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}"))); + } + } + None => { + return Err(Error::UnsupportedProtocol(n.to_string())); + } + } + } + Ok(()) + } +} diff --git a/p2p/src/protocols/mod.rs b/p2p/src/protocols/mod.rs index 4a8f6b9..c58df03 100644 --- a/p2p/src/protocols/mod.rs +++ b/p2p/src/protocols/mod.rs @@ -1,3 +1,5 @@ +mod handshake; mod ping; -pub use ping::PingProtocol; +pub(crate) use handshake::HandshakeProtocol; +pub(crate) use ping::PingProtocol; diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs index b800b23..f35b203 100644 --- a/p2p/src/protocols/ping.rs +++ b/p2p/src/protocols/ping.rs @@ -9,7 +9,6 @@ use rand::{rngs::OsRng, RngCore}; use karyon_core::{ async_runtime::Executor, async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult}, - event::EventListener, util::decode, }; @@ -39,9 +38,12 @@ pub struct PingProtocol { impl PingProtocol { #[allow(clippy::new_ret_no_self)] - pub fn new(peer: Arc, executor: Executor) -> Arc { - let ping_interval = peer.config().ping_interval; - let ping_timeout = peer.config().ping_timeout; + pub fn new( + peer: Arc, + ping_interval: u64, + ping_timeout: u64, + executor: Executor, + ) -> Arc { Arc::new(Self { peer, ping_interval, @@ -50,13 +52,9 @@ impl PingProtocol { }) } - async fn recv_loop( - &self, - listener: &EventListener, - pong_chan: Sender<[u8; 32]>, - ) -> Result<()> { + async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> { loop { - let event = listener.recv().await?; + let event = self.peer.recv::().await?; let msg_payload = match event.clone() { ProtocolEvent::Message(m) => m, ProtocolEvent::Shutdown => { @@ -70,7 +68,7 @@ impl PingProtocol { PingProtocolMsg::Ping(nonce) => { trace!("Received Ping message {:?}", nonce); self.peer - .send(&Self::id(), &PingProtocolMsg::Pong(nonce)) + .send(Self::id(), &PingProtocolMsg::Pong(nonce)) .await?; trace!("Send back Pong message {:?}", nonce); } @@ -82,7 +80,7 @@ impl PingProtocol { Ok(()) } - async fn ping_loop(self: Arc, chan: Receiver<[u8; 32]>) -> Result<()> { + async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> { let rng = &mut OsRng; let mut retry = 0; @@ -94,12 +92,11 @@ impl PingProtocol { trace!("Send Ping message {:?}", ping_nonce); self.peer - .send(&Self::id(), &PingProtocolMsg::Ping(ping_nonce)) + .send(Self::id(), &PingProtocolMsg::Ping(ping_nonce)) .await?; - let d = Duration::from_secs(self.ping_timeout); - // Wait for Pong message + let d = Duration::from_secs(self.ping_timeout); let pong_msg = match timeout(d, chan.recv()).await { Ok(m) => m?, Err(_) => { @@ -107,13 +104,14 @@ impl PingProtocol { continue; } }; - trace!("Received Pong message {:?}", pong_msg); if pong_msg != ping_nonce { retry += 1; continue; } + + retry = 0; } Err(NetError::Timeout.into()) @@ -125,8 +123,8 @@ impl Protocol for PingProtocol { async fn start(self: Arc) -> Result<()> { trace!("Start Ping protocol"); + let stop_signal = async_channel::bounded::>(1); let (pong_chan, pong_chan_recv) = async_channel::bounded(1); - let (stop_signal_s, stop_signal) = async_channel::bounded::>(1); self.task_group.spawn( { @@ -135,15 +133,12 @@ impl Protocol for PingProtocol { }, |res| async move { if let TaskResult::Completed(result) = res { - let _ = stop_signal_s.send(result).await; + let _ = stop_signal.0.send(result).await; } }, ); - let listener = self.peer.register_listener::().await; - - let result = select(self.recv_loop(&listener, pong_chan), stop_signal.recv()).await; - listener.cancel().await; + let result = select(self.recv_loop(pong_chan), stop_signal.1.recv()).await; self.task_group.cancel().await; match result { -- cgit v1.2.3