From e15d3e6fd20b3f87abaad7ddec1c88b0e66419f9 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Mon, 15 Jul 2024 13:16:01 +0200 Subject: p2p: Major refactoring of the handshake protocol Introduce a new protocol InitProtocol which can be used as the core protocol for initializing a connection with a peer. Move the handshake logic from the PeerPool module to the protocols directory and build a handshake protocol that implements InitProtocol trait. --- p2p/README.md | 45 +++---- p2p/examples/chat.rs | 6 +- p2p/src/conn_queue.rs | 53 +++----- p2p/src/connection.rs | 110 +++++++++++++++ p2p/src/discovery/lookup.rs | 114 ++++++++-------- p2p/src/discovery/mod.rs | 17 +-- p2p/src/discovery/refresh.rs | 28 ++-- p2p/src/lib.rs | 1 + p2p/src/message.rs | 20 +-- p2p/src/monitor/mod.rs | 6 +- p2p/src/peer/mod.rs | 247 ++++++++++++++++------------------ p2p/src/peer_pool.rs | 296 +++++++++++------------------------------ p2p/src/protocol.rs | 14 +- p2p/src/protocols/handshake.rs | 139 +++++++++++++++++++ p2p/src/protocols/mod.rs | 4 +- p2p/src/protocols/ping.rs | 39 +++--- 16 files changed, 590 insertions(+), 549 deletions(-) create mode 100644 p2p/src/connection.rs create mode 100644 p2p/src/protocols/handshake.rs (limited to 'p2p') diff --git a/p2p/README.md b/p2p/README.md index efd6d60..bd2a69e 100644 --- a/p2p/README.md +++ b/p2p/README.md @@ -1,6 +1,6 @@ -# karyon p2p +# Karyon p2p -karyon p2p serves as the foundational stack for the Karyon library. It offers +Karyon p2p serves as the foundational stack for the Karyon library. It offers a lightweight, extensible, and customizable peer-to-peer (p2p) network stack that seamlessly integrates with any p2p project. @@ -8,11 +8,11 @@ that seamlessly integrates with any p2p project. ### Discovery -karyon p2p uses a customized version of the Kademlia for discovering new peers +Karyon p2p uses a customized version of the Kademlia for discovering new peers in the network. 
This approach is based on Kademlia but with several significant differences and optimizations. Some of the main changes: -1. karyon p2p uses TCP for the lookup process, while UDP is used for +1. Karyon p2p uses TCP for the lookup process, while UDP is used for validating and refreshing the routing table. The reason for this choice is that the lookup process is infrequent, and the work required to manage messages with UDP is largely equivalent to using TCP for this purpose. @@ -21,11 +21,11 @@ differences and optimizations. Some of the main changes: use UDP. 2. In contrast to traditional Kademlia, which often employs 160 buckets, - karyon p2p reduces the number of buckets to 32. This optimization is a + Karyon p2p reduces the number of buckets to 32. This optimization is a result of the observation that most nodes tend to map into the last few buckets, with the majority of other buckets remaining empty. -3. While Kademlia typically uses a 160-bit key to identify a peer, karyon p2p +3. While Kademlia typically uses a 160-bit key to identify a peer, Karyon p2p uses a 256-bit key. > Despite criticisms of Kademlia's vulnerabilities, particularly concerning @@ -38,7 +38,7 @@ differences and optimizations. Some of the main changes: ### Peer ID -In the karyon p2p network, each peer is identified by a 256-bit (32-byte) Peer ID. +In the Karyon p2p network, each peer is identified by a 256-bit (32-byte) Peer ID. ### Seeding @@ -67,21 +67,20 @@ is added to the `PeerPool`. ### Protocols -In the karyon p2p network, we have two types of protocols: core protocols and -custom protocols. Core protocols are prebuilt into karyon p2p, such as the -Ping protocol used to maintain connections. Custom protocols, on the other -hand, are protocols that you define for your application to provide its core -functionality. +In the Karyon p2p network, there are two types of protocols: core protocols and +custom protocols. 
Core protocols, such as the Ping and Handshake protocols, +come prebuilt into Karyon p2p. Custom protocols, however, are ones that you +create to provide the specific functionality your application needs. Here's an example of a custom protocol: ```rust pub struct NewProtocol { - peer: ArcPeer, + peer: Arc, } impl NewProtocol { - fn new(peer: Arc) -> Arc { + fn new(peer: Arc) -> Arc { Arc::new(Self { peer, }) @@ -90,12 +89,9 @@ impl NewProtocol { #[async_trait] impl Protocol for NewProtocol { - async fn start(self: Arc) -> Result<(), P2pError> { - let listener = self.peer.register_listener::().await; + async fn start(self: Arc) -> Result<(), Error> { loop { - let event = listener.recv().await.unwrap(); - - match event { + match self.peer.recv::().await.expect("Receive msg") { ProtocolEvent::Message(msg) => { println!("{:?}", msg); } @@ -104,12 +100,10 @@ impl Protocol for NewProtocol { } } } - - listener.cancel().await; Ok(()) } - fn version() -> Result { + fn version() -> Result { "0.2.0, >0.1.0".parse() } @@ -120,20 +114,17 @@ impl Protocol for NewProtocol { ``` -Whenever a new peer is added to the `PeerPool`, all the protocols, including -the custom protocols, will automatically start running with the newly connected peer. - ## Network Security Using TLS is possible for all inbound and outbound connections by enabling the boolean `enable_tls` field in the configuration. However, implementing TLS for -a P2P network is not trivial and is still unstable, requiring a comprehensive +a p2p network is not trivial and is still unstable, requiring a comprehensive audit. ## Choosing the async runtime -karyon p2p currently supports both **smol(async-std)** and **tokio** async runtimes. +Karyon p2p currently supports both **smol(async-std)** and **tokio** async runtimes. The default is **smol**, but if you want to use **tokio**, you need to disable the default features and then select the `tokio` feature. 
diff --git a/p2p/examples/chat.rs b/p2p/examples/chat.rs index 5867c8b..2ea6b2c 100644 --- a/p2p/examples/chat.rs +++ b/p2p/examples/chat.rs @@ -69,11 +69,8 @@ impl Protocol for ChatProtocol { } }); - let listener = self.peer.register_listener::().await; loop { - let event = listener.recv().await.expect("Receive new protocol event"); - - match event { + match self.peer.recv::().await? { ProtocolEvent::Message(msg) => { let msg = String::from_utf8(msg).expect("Convert received bytes to string"); println!("{msg}"); @@ -85,7 +82,6 @@ impl Protocol for ChatProtocol { } task.cancel().await; - listener.cancel().await; Ok(()) } diff --git a/p2p/src/conn_queue.rs b/p2p/src/conn_queue.rs index 9a153f3..1b6ef98 100644 --- a/p2p/src/conn_queue.rs +++ b/p2p/src/conn_queue.rs @@ -1,37 +1,13 @@ -use std::{collections::VecDeque, fmt, sync::Arc}; - -use async_channel::Sender; +use std::{collections::VecDeque, sync::Arc}; use karyon_core::{async_runtime::lock::Mutex, async_util::CondVar}; use karyon_net::Conn; -use crate::{message::NetMsg, Result}; - -/// Defines the direction of a network connection. 
-#[derive(Clone, Debug)] -pub enum ConnDirection { - Inbound, - Outbound, -} - -impl fmt::Display for ConnDirection { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ConnDirection::Inbound => write!(f, "Inbound"), - ConnDirection::Outbound => write!(f, "Outbound"), - } - } -} - -pub struct NewConn { - pub direction: ConnDirection, - pub conn: Conn, - pub disconnect_signal: Sender>, -} +use crate::{connection::ConnDirection, connection::Connection, message::NetMsg, Result}; /// Connection queue pub struct ConnQueue { - queue: Mutex>, + queue: Mutex>, conn_available: CondVar, } @@ -43,24 +19,27 @@ impl ConnQueue { }) } - /// Push a connection into the queue and wait for the disconnect signal + /// Handle a connection by pushing it into the queue and wait for the disconnect signal pub async fn handle(&self, conn: Conn, direction: ConnDirection) -> Result<()> { - let (disconnect_signal, chan) = async_channel::bounded(1); - let new_conn = NewConn { - direction, - conn, - disconnect_signal, - }; + let endpoint = conn.peer_endpoint()?; + + let (disconnect_tx, disconnect_rx) = async_channel::bounded(1); + let new_conn = Connection::new(conn, disconnect_tx, direction, endpoint); + + // Push a new conn to the queue self.queue.lock().await.push_back(new_conn); self.conn_available.signal(); - if let Ok(result) = chan.recv().await { + + // Wait for the disconnect signal from the connection handler + if let Ok(result) = disconnect_rx.recv().await { return result; } + Ok(()) } - /// Receive the next connection in the queue - pub async fn next(&self) -> NewConn { + /// Waits for the next connection in the queue + pub async fn next(&self) -> Connection { let mut queue = self.queue.lock().await; while queue.is_empty() { queue = self.conn_available.wait(queue).await; diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs new file mode 100644 index 0000000..52190a8 --- /dev/null +++ b/p2p/src/connection.rs @@ -0,0 +1,110 @@ +use 
std::{collections::HashMap, fmt, sync::Arc}; + +use async_channel::Sender; +use bincode::Encode; + +use karyon_core::{ + event::{EventListener, EventSys}, + util::encode, +}; + +use karyon_net::{Conn, Endpoint}; + +use crate::{ + message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg}, + protocol::{Protocol, ProtocolEvent, ProtocolID}, + Error, Result, +}; + +/// Defines the direction of a network connection. +#[derive(Clone, Debug)] +pub enum ConnDirection { + Inbound, + Outbound, +} + +impl fmt::Display for ConnDirection { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ConnDirection::Inbound => write!(f, "Inbound"), + ConnDirection::Outbound => write!(f, "Outbound"), + } + } +} + +pub struct Connection { + pub(crate) direction: ConnDirection, + conn: Conn, + disconnect_signal: Sender>, + /// `EventSys` responsible for sending events to the registered protocols. + protocol_events: Arc>, + pub(crate) remote_endpoint: Endpoint, + listeners: HashMap>, +} + +impl Connection { + pub fn new( + conn: Conn, + signal: Sender>, + direction: ConnDirection, + remote_endpoint: Endpoint, + ) -> Self { + Self { + conn, + direction, + protocol_events: EventSys::new(), + disconnect_signal: signal, + remote_endpoint, + listeners: HashMap::new(), + } + } + + pub async fn send(&self, protocol_id: ProtocolID, msg: T) -> Result<()> { + let payload = encode(&msg)?; + + let proto_msg = ProtocolMsg { + protocol_id, + payload: payload.to_vec(), + }; + + let msg = NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?; + self.conn.send(msg).await.map_err(Error::from) + } + + pub async fn recv(&self) -> Result { + match self.listeners.get(&P::id()) { + Some(l) => l.recv().await.map_err(Error::from), + // TODO + None => todo!(), + } + } + + /// Registers a listener for the given Protocol `P`. 
+ pub async fn register_protocol(&mut self, protocol_id: String) { + let listener = self.protocol_events.register(&protocol_id).await; + self.listeners.insert(protocol_id, listener); + } + + pub async fn emit_msg(&self, id: &ProtocolID, event: &ProtocolEvent) -> Result<()> { + self.protocol_events.emit_by_topic(id, event).await?; + Ok(()) + } + + pub async fn recv_inner(&self) -> Result { + self.conn.recv().await.map_err(Error::from) + } + + pub async fn send_inner(&self, msg: NetMsg) -> Result<()> { + self.conn.send(msg).await.map_err(Error::from) + } + + pub async fn disconnect(&self, res: Result<()>) -> Result<()> { + self.protocol_events.clear().await; + self.disconnect_signal.send(res).await?; + + let m = NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("Create shutdown message"); + self.conn.send(m).await.map_err(Error::from)?; + + Ok(()) + } +} diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs index 9ddf614..47a1d09 100644 --- a/p2p/src/discovery/lookup.rs +++ b/p2p/src/discovery/lookup.rs @@ -2,24 +2,17 @@ use std::{sync::Arc, time::Duration}; use futures_util::stream::{FuturesUnordered, StreamExt}; use log::{error, trace}; +use parking_lot::RwLock; use rand::{rngs::OsRng, seq::SliceRandom, RngCore}; -use karyon_core::{ - async_runtime::{lock::RwLock, Executor}, - async_util::timeout, - crypto::KeyPair, - util::decode, -}; +use karyon_core::{async_runtime::Executor, async_util::timeout, crypto::KeyPair, util::decode}; use karyon_net::{Conn, Endpoint}; use crate::{ connector::Connector, listener::Listener, - message::{ - get_msg_payload, FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, - ShutdownMsg, - }, + message::{FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, ShutdownMsg}, monitor::{ConnEvent, DiscvEvent, Monitor}, routing_table::RoutingTable, slots::ConnectionSlots, @@ -46,7 +39,7 @@ pub struct LookupService { outbound_slots: Arc, /// Resolved listen endpoint - listen_endpoint: 
Option>, + listen_endpoint: RwLock>, /// Holds the configuration for the P2P network. config: Arc, @@ -85,18 +78,13 @@ impl LookupService { ex, ); - let listen_endpoint = config - .listen_endpoint - .as_ref() - .map(|endpoint| RwLock::new(endpoint.clone())); - Self { id: id.clone(), table, listener, connector, outbound_slots, - listen_endpoint, + listen_endpoint: RwLock::new(None), config, monitor, } @@ -109,10 +97,18 @@ impl LookupService { } /// Set the resolved listen endpoint. - pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) { - if let Some(endpoint) = &self.listen_endpoint { - *endpoint.write().await = resolved_endpoint.clone(); - } + pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> { + let resolved_endpoint = Endpoint::Tcp( + resolved_endpoint.addr()?.clone(), + self.config.discovery_port, + ); + *self.listen_endpoint.write() = Some(resolved_endpoint); + Ok(()) + } + + /// Get the listening endpoint. + pub fn listen_endpoint(&self) -> Option { + self.listen_endpoint.read().clone() } /// Shuts down the lookup service. 
@@ -253,36 +249,51 @@ impl LookupService { target_peer_id: &PeerID, ) -> Result> { trace!("Send Ping msg"); - self.send_ping_msg(&conn).await?; + let peers; - trace!("Send FindPeer msg"); - let peers = self.send_findpeer_msg(&conn, target_peer_id).await?; + let ping_msg = self.send_ping_msg(&conn).await?; - if peers.0.len() >= MAX_PEERS_IN_PEERSMSG { - return Err(Error::Lookup("Received too many peers in PeersMsg")); + loop { + let t = Duration::from_secs(self.config.lookup_response_timeout); + let msg: NetMsg = timeout(t, conn.recv()).await??; + match msg.header.command { + NetMsgCmd::Pong => { + let (pong_msg, _) = decode::(&msg.payload)?; + if ping_msg.nonce != pong_msg.0 { + return Err(Error::InvalidPongMsg); + } + trace!("Send FindPeer msg"); + self.send_findpeer_msg(&conn, target_peer_id).await?; + } + NetMsgCmd::Peers => { + peers = decode::(&msg.payload)?.0.peers; + if peers.len() >= MAX_PEERS_IN_PEERSMSG { + return Err(Error::Lookup("Received too many peers in PeersMsg")); + } + break; + } + c => return Err(Error::InvalidMsg(format!("Unexpected msg: {:?}", c))), + }; } trace!("Send Peer msg"); - if let Some(endpoint) = &self.listen_endpoint { - self.send_peer_msg(&conn, endpoint.read().await.clone()) - .await?; + if let Some(endpoint) = self.listen_endpoint() { + self.send_peer_msg(&conn, endpoint.clone()).await?; } trace!("Send Shutdown msg"); self.send_shutdown_msg(&conn).await?; - Ok(peers.0) + Ok(peers) } /// Start a listener. 
async fn start_listener(self: &Arc) -> Result<()> { - let addr = match &self.listen_endpoint { - Some(a) => a.read().await.addr()?.clone(), + let endpoint: Endpoint = match self.listen_endpoint() { + Some(e) => e.clone(), None => return Ok(()), }; - let endpoint = Endpoint::Tcp(addr, self.config.discovery_port); - let callback = { let this = self.clone(); |conn: Conn| async move { @@ -292,7 +303,7 @@ impl LookupService { } }; - self.listener.start(endpoint.clone(), callback).await?; + self.listener.start(endpoint, callback).await?; Ok(()) } @@ -329,10 +340,9 @@ impl LookupService { } } - /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, conn: &Conn) -> Result<()> { + /// Sends a Ping msg. + async fn send_ping_msg(&self, conn: &Conn) -> Result { trace!("Send Pong msg"); - let mut nonce: [u8; 32] = [0; 32]; RngCore::fill_bytes(&mut OsRng, &mut nonce); @@ -341,18 +351,7 @@ impl LookupService { nonce, }; conn.send(NetMsg::new(NetMsgCmd::Ping, &ping_msg)?).await?; - - let t = Duration::from_secs(self.config.lookup_response_timeout); - let recv_msg: NetMsg = timeout(t, conn.recv()).await??; - - let payload = get_msg_payload!(Pong, recv_msg); - let (pong_msg, _) = decode::(&payload)?; - - if ping_msg.nonce != pong_msg.0 { - return Err(Error::InvalidPongMsg); - } - - Ok(()) + Ok(ping_msg) } /// Sends a Pong msg @@ -363,22 +362,15 @@ impl LookupService { Ok(()) } - /// Sends a FindPeer msg and wait to receivet the Peers msg. - async fn send_findpeer_msg(&self, conn: &Conn, peer_id: &PeerID) -> Result { + /// Sends a FindPeer msg + async fn send_findpeer_msg(&self, conn: &Conn, peer_id: &PeerID) -> Result<()> { trace!("Send FindPeer msg"); conn.send(NetMsg::new( NetMsgCmd::FindPeer, FindPeerMsg(peer_id.clone()), )?) 
.await?; - - let t = Duration::from_secs(self.config.lookup_response_timeout); - let recv_msg: NetMsg = timeout(t, conn.recv()).await??; - - let payload = get_msg_payload!(Peers, recv_msg); - let (peers, _) = decode(&payload)?; - - Ok(peers) + Ok(()) } /// Sends a Peers msg. @@ -389,7 +381,7 @@ impl LookupService { .closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG); let peers: Vec = entries.into_iter().map(|e| e.into()).collect(); - conn.send(NetMsg::new(NetMsgCmd::Peers, PeersMsg(peers))?) + conn.send(NetMsg::new(NetMsgCmd::Peers, PeersMsg { peers })?) .await?; Ok(()) } diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs index a9d99d6..a81a817 100644 --- a/p2p/src/discovery/mod.rs +++ b/p2p/src/discovery/mod.rs @@ -16,7 +16,8 @@ use karyon_net::{Conn, Endpoint}; use crate::{ config::Config, - conn_queue::{ConnDirection, ConnQueue}, + conn_queue::ConnQueue, + connection::ConnDirection, connector::Connector, listener::Listener, message::NetMsg, @@ -132,15 +133,11 @@ impl Discovery { let resolved_endpoint = self.start_listener(endpoint).await?; - if endpoint.addr()? != resolved_endpoint.addr()? 
{ - info!("Resolved listen endpoint: {resolved_endpoint}"); - self.lookup_service - .set_listen_endpoint(&resolved_endpoint) - .await; - self.refresh_service - .set_listen_endpoint(&resolved_endpoint) - .await; - } + info!("Resolved listen endpoint: {resolved_endpoint}"); + self.lookup_service + .set_listen_endpoint(&resolved_endpoint)?; + self.refresh_service + .set_listen_endpoint(&resolved_endpoint)?; } // Start the lookup service diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index b4f5396..1452a1b 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -2,10 +2,11 @@ use std::{sync::Arc, time::Duration}; use bincode::{Decode, Encode}; use log::{error, info, trace}; +use parking_lot::RwLock; use rand::{rngs::OsRng, RngCore}; use karyon_core::{ - async_runtime::{lock::RwLock, Executor}, + async_runtime::Executor, async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult}, }; @@ -33,7 +34,7 @@ pub struct RefreshService { table: Arc, /// Resolved listen endpoint - listen_endpoint: Option>, + listen_endpoint: RwLock>, /// Managing spawned tasks. task_group: TaskGroup, @@ -53,14 +54,9 @@ impl RefreshService { monitor: Arc, executor: Executor, ) -> Self { - let listen_endpoint = config - .listen_endpoint - .as_ref() - .map(|endpoint| RwLock::new(endpoint.clone())); - Self { table, - listen_endpoint, + listen_endpoint: RwLock::new(None), task_group: TaskGroup::with_executor(executor.clone()), config, monitor, @@ -69,9 +65,8 @@ impl RefreshService { /// Start the refresh service pub async fn start(self: &Arc) -> Result<()> { - if let Some(endpoint) = &self.listen_endpoint { - let endpoint = endpoint.read().await.clone(); - + if let Some(endpoint) = self.listen_endpoint.read().as_ref() { + let endpoint = endpoint.clone(); self.task_group.spawn( { let this = self.clone(); @@ -101,10 +96,13 @@ impl RefreshService { } /// Set the resolved listen endpoint. 
- pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) { - if let Some(endpoint) = &self.listen_endpoint { - *endpoint.write().await = resolved_endpoint.clone(); - } + pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> { + let resolved_endpoint = Endpoint::Udp( + resolved_endpoint.addr()?.clone(), + self.config.discovery_port, + ); + *self.listen_endpoint.write() = Some(resolved_endpoint); + Ok(()) } /// Shuts down the refresh service diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index b21a353..f0dc725 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -41,6 +41,7 @@ mod backend; mod codec; mod config; mod conn_queue; +mod connection; mod connector; mod discovery; mod error; diff --git a/p2p/src/message.rs b/p2p/src/message.rs index 6498ef7..5bf0853 100644 --- a/p2p/src/message.rs +++ b/p2p/src/message.rs @@ -110,7 +110,9 @@ pub struct PeerMsg { /// PeersMsg a list of `PeerMsg`. #[derive(Decode, Encode, Debug)] -pub struct PeersMsg(pub Vec); +pub struct PeersMsg { + pub peers: Vec, +} impl From for PeerMsg { fn from(entry: Entry) -> PeerMsg { @@ -133,19 +135,3 @@ impl From for Entry { } } } - -macro_rules! 
get_msg_payload { - ($a:ident, $b:ident) => { - if let NetMsgCmd::$a = $b.header.command { - $b.payload - } else { - return Err(Error::InvalidMsg(format!( - "Expected {:?} msg found {:?} msg", - stringify!($a), - $b.header.command - ))); - } - }; -} - -pub(super) use get_msg_payload; diff --git a/p2p/src/monitor/mod.rs b/p2p/src/monitor/mod.rs index 4ecb431..86db23e 100644 --- a/p2p/src/monitor/mod.rs +++ b/p2p/src/monitor/mod.rs @@ -2,6 +2,8 @@ mod event; use std::sync::Arc; +use log::error; + use karyon_core::event::{EventListener, EventSys, EventValue, EventValueTopic}; use karyon_net::Endpoint; @@ -62,7 +64,9 @@ impl Monitor { pub(crate) async fn notify(&self, event: E) { if self.config.enable_monitor { let event = event.to_struct(); - self.event_sys.emit(&event).await + if let Err(err) = self.event_sys.emit(&event).await { + error!("Failed to notify monitor event {:?}: {err}", event); + } } } diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs index 6903294..a5ac7ad 100644 --- a/p2p/src/peer/mod.rs +++ b/p2p/src/peer/mod.rs @@ -1,138 +1,111 @@ mod peer_id; -pub use peer_id::PeerID; - use std::sync::{Arc, Weak}; use async_channel::{Receiver, Sender}; -use bincode::{Decode, Encode}; +use bincode::Encode; use log::{error, trace}; +use parking_lot::RwLock; use karyon_core::{ - async_runtime::{lock::RwLock, Executor}, + async_runtime::Executor, async_util::{select, Either, TaskGroup, TaskResult}, - event::{EventListener, EventSys}, - util::{decode, encode}, + util::decode, }; -use karyon_net::{Conn, Endpoint}; - use crate::{ - conn_queue::ConnDirection, - message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg}, + connection::{ConnDirection, Connection}, + endpoint::Endpoint, + message::{NetMsgCmd, ProtocolMsg}, peer_pool::PeerPool, - protocol::{Protocol, ProtocolEvent, ProtocolID}, + protocol::{InitProtocol, Protocol, ProtocolEvent, ProtocolID}, + protocols::HandshakeProtocol, Config, Error, Result, }; +pub use peer_id::PeerID; + pub struct Peer { + /// Own ID + 
own_id: PeerID, + /// Peer's ID - id: PeerID, + id: RwLock>, - /// A weak pointer to `PeerPool` + /// A weak pointer to [`PeerPool`] peer_pool: Weak, /// Holds the peer connection - conn: Conn, - - /// Remote endpoint for the peer - remote_endpoint: Endpoint, - - /// The direction of the connection, either `Inbound` or `Outbound` - conn_direction: ConnDirection, - - /// A list of protocol IDs - protocol_ids: RwLock>, - - /// `EventSys` responsible for sending events to the protocols. - protocol_events: Arc>, + pub(crate) conn: Arc, /// This channel is used to send a stop signal to the read loop. stop_chan: (Sender>, Receiver>), + /// The Configuration for the P2P network. + config: Arc, + /// Managing spawned tasks. task_group: TaskGroup, } impl Peer { /// Creates a new peer - pub fn new( + pub(crate) fn new( + own_id: PeerID, peer_pool: Weak, - id: &PeerID, - conn: Conn, - remote_endpoint: Endpoint, - conn_direction: ConnDirection, + conn: Arc, + config: Arc, ex: Executor, ) -> Arc { Arc::new(Peer { - id: id.clone(), + own_id, + id: RwLock::new(None), peer_pool, conn, - protocol_ids: RwLock::new(Vec::new()), - remote_endpoint, - conn_direction, - protocol_events: EventSys::new(), + config, task_group: TaskGroup::with_executor(ex), stop_chan: async_channel::bounded(1), }) } - /// Run the peer - pub async fn run(self: Arc) -> Result<()> { - self.start_protocols().await; - self.read_loop().await + /// Send a msg to this peer connection using the specified protocol. + pub async fn send(&self, proto_id: ProtocolID, msg: T) -> Result<()> { + self.conn.send(proto_id, msg).await } - /// Send a message to the peer connection using the specified protocol. - pub async fn send(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { - let payload = encode(msg)?; - - let proto_msg = ProtocolMsg { - protocol_id: protocol_id.to_string(), - payload: payload.to_vec(), - }; - - self.conn - .send(NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?) 
- .await?; - Ok(()) + /// Receives a new msg from this peer connection. + pub async fn recv(&self) -> Result { + self.conn.recv::

().await } /// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast(&self, protocol_id: &ProtocolID, msg: &T) { - self.peer_pool().broadcast(protocol_id, msg).await; + pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { + self.peer_pool().broadcast(proto_id, msg).await; } - /// Shuts down the peer - pub async fn shutdown(&self) { - trace!("peer {} start shutting down", self.id); - - // Send shutdown event to all protocols - for protocol_id in self.protocol_ids.read().await.iter() { - self.protocol_events - .emit_by_topic(protocol_id, &ProtocolEvent::Shutdown) - .await; - } + /// Returns the peer's ID + pub fn id(&self) -> Option { + self.id.read().clone() + } - // Send a stop signal to the read loop - // - // No need to handle the error here; a dropped channel and - // sending a stop signal have the same effect. - let _ = self.stop_chan.0.try_send(Ok(())); + /// Returns own ID + pub fn own_id(&self) -> &PeerID { + &self.own_id + } - // No need to handle the error here - let shutdown_msg = - NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("pack shutdown message"); - let _ = self.conn.send(shutdown_msg).await; + /// Returns the [`Config`] + pub fn config(&self) -> Arc { + self.config.clone() + } - // Force shutting down - self.task_group.cancel().await; + /// Returns the remote endpoint for the peer + pub fn remote_endpoint(&self) -> &Endpoint { + &self.conn.remote_endpoint } /// Check if the connection is Inbound - #[inline] pub fn is_inbound(&self) -> bool { - match self.conn_direction { + match self.conn.direction { ConnDirection::Inbound => true, ConnDirection::Outbound => false, } @@ -140,40 +113,82 @@ impl Peer { /// Returns the direction of the connection, which can be either `Inbound` /// or `Outbound`. 
- #[inline] pub fn direction(&self) -> &ConnDirection { - &self.conn_direction + &self.conn.direction } - /// Returns the remote endpoint for the peer - #[inline] - pub fn remote_endpoint(&self) -> &Endpoint { - &self.remote_endpoint + pub(crate) async fn init(self: &Arc) -> Result<()> { + let handshake_protocol = HandshakeProtocol::new( + self.clone(), + self.peer_pool().protocol_versions.read().await.clone(), + ); + + let pid = handshake_protocol.init().await?; + *self.id.write() = Some(pid); + + Ok(()) } - /// Return the peer's ID - #[inline] - pub fn id(&self) -> &PeerID { - &self.id + /// Run the peer + pub(crate) async fn run(self: Arc) -> Result<()> { + self.run_connect_protocols().await; + self.read_loop().await } - /// Returns the `Config` instance. - pub fn config(&self) -> Arc { - self.peer_pool().config.clone() + /// Shuts down the peer + pub(crate) async fn shutdown(self: &Arc) -> Result<()> { + trace!("peer {:?} shutting down", self.id()); + + // Send shutdown event to the attached protocols + for proto_id in self.peer_pool().protocols.read().await.keys() { + let _ = self.conn.emit_msg(proto_id, &ProtocolEvent::Shutdown).await; + } + + // Send a stop signal to the read loop + // + // No need to handle the error here; a dropped channel and + // sending a stop signal have the same effect. + let _ = self.stop_chan.0.try_send(Ok(())); + + self.conn.disconnect(Ok(())).await?; + + // Force shutting down + self.task_group.cancel().await; + Ok(()) } - /// Registers a listener for the given Protocol `P`. - pub async fn register_listener(&self) -> EventListener { - self.protocol_events.register(&P::id()).await + /// Run the Connect Protocols for this peer connection.
+ async fn run_connect_protocols(self: &Arc) { + for (proto_id, constructor) in self.peer_pool().protocols.read().await.iter() { + trace!("peer {:?} run protocol {proto_id}", self.id()); + + let protocol = constructor(self.clone()); + + let on_failure = { + let this = self.clone(); + let proto_id = proto_id.clone(); + |result: TaskResult>| async move { + if let TaskResult::Completed(res) = result { + if res.is_err() { + error!("protocol {} stopped", proto_id); + } + // Send a stop signal to read loop + let _ = this.stop_chan.0.try_send(res); + } + } + }; + + self.task_group.spawn(protocol.start(), on_failure); + } } - /// Start a read loop to handle incoming messages from the peer connection. + /// Run a read loop to handle incoming messages from the peer connection. async fn read_loop(&self) -> Result<()> { loop { - let fut = select(self.stop_chan.1.recv(), self.conn.recv()).await; + let fut = select(self.stop_chan.1.recv(), self.conn.recv_inner()).await; let result = match fut { Either::Left(stop_signal) => { - trace!("Peer {} received a stop signal", self.id); + trace!("Peer {:?} received a stop signal", self.id()); return stop_signal?; } Either::Right(result) => result, @@ -184,14 +199,9 @@ impl Peer { match msg.header.command { NetMsgCmd::Protocol => { let msg: ProtocolMsg = decode(&msg.payload)?.0; - - if !self.protocol_ids.read().await.contains(&msg.protocol_id) { - return Err(Error::UnsupportedProtocol(msg.protocol_id)); - } - - let proto_id = &msg.protocol_id; - let msg = ProtocolEvent::Message(msg.payload); - self.protocol_events.emit_by_topic(proto_id, &msg).await; + self.conn + .emit_msg(&msg.protocol_id, &ProtocolEvent::Message(msg.payload)) + .await?; } NetMsgCmd::Shutdown => { return Err(Error::PeerShutdown); @@ -201,32 +211,7 @@ impl Peer { } } - /// Start running the protocols for this peer connection. 
- async fn start_protocols(self: &Arc) { - for (protocol_id, constructor) in self.peer_pool().protocols.read().await.iter() { - trace!("peer {} start protocol {protocol_id}", self.id); - let protocol = constructor(self.clone()); - - self.protocol_ids.write().await.push(protocol_id.clone()); - - let on_failure = { - let this = self.clone(); - let protocol_id = protocol_id.clone(); - |result: TaskResult>| async move { - if let TaskResult::Completed(res) = result { - if res.is_err() { - error!("protocol {} stopped", protocol_id); - } - // Send a stop signal to read loop - let _ = this.stop_chan.0.try_send(res); - } - } - }; - - self.task_group.spawn(protocol.start(), on_failure); - } - } - + /// Returns `PeerPool` pointer fn peer_pool(&self) -> Arc { self.peer_pool.upgrade().unwrap() } diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index 1f3ca55..549dc76 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -1,26 +1,24 @@ -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{collections::HashMap, sync::Arc}; -use async_channel::Sender; -use bincode::{Decode, Encode}; -use log::{error, info, trace, warn}; +use bincode::Encode; +use log::{error, info, warn}; use karyon_core::{ async_runtime::{lock::RwLock, Executor}, - async_util::{timeout, TaskGroup, TaskResult}, - util::decode, + async_util::{TaskGroup, TaskResult}, }; -use karyon_net::{Conn, Endpoint}; +use karyon_net::Endpoint; use crate::{ config::Config, - conn_queue::{ConnDirection, ConnQueue}, - message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, + conn_queue::ConnQueue, + connection::Connection, monitor::{Monitor, PPEvent}, peer::Peer, protocol::{Protocol, ProtocolConstructor, ProtocolID}, protocols::PingProtocol, - version::{version_match, Version, VersionInt}, + version::Version, Error, PeerID, Result, }; @@ -37,8 +35,8 @@ pub struct PeerPool { /// Hashmap contains protocol constructors. 
pub(crate) protocols: RwLock>>, - /// Hashmap contains protocol IDs and their versions. - protocol_versions: Arc>>, + /// Hashmap contains protocols with their versions + pub(crate) protocol_versions: RwLock>, /// Managing spawned tasks. task_group: TaskGroup, @@ -47,7 +45,7 @@ pub struct PeerPool { executor: Executor, /// The Configuration for the P2P network. - pub(crate) config: Arc, + config: Arc, /// Responsible for network and system monitoring. monitor: Arc, @@ -62,15 +60,12 @@ impl PeerPool { monitor: Arc, executor: Executor, ) -> Arc { - let protocols = RwLock::new(HashMap::new()); - let protocol_versions = Arc::new(RwLock::new(HashMap::new())); - Arc::new(Self { id: id.clone(), conn_queue, peers: RwLock::new(HashMap::new()), - protocols, - protocol_versions, + protocols: RwLock::new(HashMap::new()), + protocol_versions: RwLock::new(HashMap::new()), task_group: TaskGroup::with_executor(executor.clone()), executor, monitor, @@ -80,21 +75,15 @@ impl PeerPool { /// Starts the [`PeerPool`] pub async fn start(self: &Arc) -> Result<()> { - self.setup_protocols().await?; - self.task_group.spawn( - { - let this = self.clone(); - async move { this.listen_loop().await } - }, - |_| async {}, - ); + self.setup_core_protocols().await?; + self.task_group.spawn(self.clone().run(), |_| async {}); Ok(()) } /// Shuts down pub async fn shutdown(&self) { for (_, peer) in self.peers.read().await.iter() { - peer.shutdown().await; + let _ = peer.shutdown().await; } self.task_group.cancel().await; @@ -102,76 +91,24 @@ impl PeerPool { /// Attach a custom protocol to the network pub async fn attach_protocol(&self, c: Box) -> Result<()> { - let protocol_versions = &mut self.protocol_versions.write().await; - let protocols = &mut self.protocols.write().await; - - protocol_versions.insert(P::id(), P::version()?); - protocols.insert(P::id(), c); + self.protocols.write().await.insert(P::id(), c); + self.protocol_versions + .write() + .await + .insert(P::id(), P::version()?); Ok(()) } 
/// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { + pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { for (pid, peer) in self.peers.read().await.iter() { - if let Err(err) = peer.send(proto_id, msg).await { + if let Err(err) = peer.conn.send(proto_id.to_string(), msg).await { error!("failed to send msg to {pid}: {err}"); continue; } } } - /// Add a new peer to the peer list. - pub async fn new_peer( - self: &Arc, - conn: Conn, - conn_direction: &ConnDirection, - disconnect_signal: Sender>, - ) -> Result<()> { - let endpoint = conn.peer_endpoint()?; - - // Do a handshake with the connection before creating a new peer. - let pid = self.do_handshake(&conn, conn_direction).await?; - - // TODO: Consider restricting the subnet for inbound connections - if self.contains_peer(&pid).await { - return Err(Error::PeerAlreadyConnected); - } - - // Create a new peer - let peer = Peer::new( - Arc::downgrade(self), - &pid, - conn, - endpoint.clone(), - conn_direction.clone(), - self.executor.clone(), - ); - - // Insert the new peer - self.peers.write().await.insert(pid.clone(), peer.clone()); - - let on_disconnect = { - let this = self.clone(); - let pid = pid.clone(); - |result| async move { - if let TaskResult::Completed(result) = result { - if let Err(err) = this.remove_peer(&pid).await { - error!("Failed to remove peer {pid}: {err}"); - } - let _ = disconnect_signal.send(result).await; - } - } - }; - - self.task_group.spawn(peer.run(), on_disconnect); - - info!("Add new peer {pid}, direction: {conn_direction}, endpoint: {endpoint}"); - - self.monitor.notify(PPEvent::NewPeer(pid.clone())).await; - - Ok(()) - } - /// Checks if the peer list contains a peer with the given peer id pub async fn contains_peer(&self, pid: &PeerID) -> bool { self.peers.read().await.contains_key(pid) @@ -204,162 +141,89 @@ impl PeerPool { peers } - /// Listens to a new connection from the 
connection queue - async fn listen_loop(self: Arc) { + async fn run(self: Arc) { loop { - let conn = self.conn_queue.next().await; - let signal = conn.disconnect_signal; + let mut conn = self.conn_queue.next().await; + + for protocol_id in self.protocols.read().await.keys() { + conn.register_protocol(protocol_id.to_string()).await; + } - let result = self - .new_peer(conn.conn, &conn.direction, signal.clone()) - .await; + let conn = Arc::new(conn); - // Only send a disconnect signal if there is an error when adding a peer. + let result = self.new_peer(conn.clone()).await; + + // Disconnect if there is an error when adding a peer. if result.is_err() { - let _ = signal.send(result).await; + let _ = conn.disconnect(result).await; } } } - /// Shuts down the peer and remove it from the peer list. - async fn remove_peer(&self, pid: &PeerID) -> Result<()> { - let result = self.peers.write().await.remove(pid); - - let peer = match result { - Some(p) => p, - None => return Ok(()), - }; - - peer.shutdown().await; - - self.monitor.notify(PPEvent::RemovePeer(pid.clone())).await; - - let endpoint = peer.remote_endpoint(); - let direction = peer.direction(); + /// Add a new peer to the peer list. + async fn new_peer(self: &Arc, conn: Arc) -> Result<()> { + // Create a new peer + let peer = Peer::new( + self.id.clone(), + Arc::downgrade(self), + conn.clone(), + self.config.clone(), + self.executor.clone(), + ); + peer.init().await?; + let pid = peer.id().expect("Get peer id after peer initialization"); - warn!("Peer {pid} removed, direction: {direction}, endpoint: {endpoint}",); - Ok(()) - } + // TODO: Consider restricting the subnet for inbound connections + if self.contains_peer(&pid).await { + return Err(Error::PeerAlreadyConnected); + } - /// Attach the core protocols. 
- async fn setup_protocols(&self) -> Result<()> { - let executor = self.executor.clone(); - let c = move |peer| PingProtocol::new(peer, executor.clone()); - self.attach_protocol::(Box::new(c)).await - } + // Insert the new peer + self.peers.write().await.insert(pid.clone(), peer.clone()); - /// Initiate a handshake with a connection. - async fn do_handshake( - &self, - conn: &Conn, - conn_direction: &ConnDirection, - ) -> Result { - trace!("Handshake started: {}", conn.peer_endpoint()?); - match conn_direction { - ConnDirection::Inbound => { - let result = self.wait_vermsg(conn).await; - match result { - Ok(_) => { - self.send_verack(conn, true).await?; - } - Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { - self.send_verack(conn, false).await?; + let on_disconnect = { + let this = self.clone(); + let pid = pid.clone(); + |result| async move { + if let TaskResult::Completed(_) = result { + if let Err(err) = this.remove_peer(&pid).await { + error!("Failed to remove peer {pid}: {err}"); } - _ => {} } - result - } - - ConnDirection::Outbound => { - self.send_vermsg(conn).await?; - self.wait_verack(conn).await } - } - } + }; - /// Send a Version message - async fn send_vermsg(&self, conn: &Conn) -> Result<()> { - let pids = self.protocol_versions.read().await; - let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); - drop(pids); + self.task_group.spawn(peer.run(), on_disconnect); - let vermsg = VerMsg { - peer_id: self.id.clone(), - protocols, - version: self.config.version.v.clone(), - }; + info!("Add new peer {pid}"); + self.monitor.notify(PPEvent::NewPeer(pid)).await; - trace!("Send VerMsg"); - conn.send(NetMsg::new(NetMsgCmd::Version, &vermsg)?).await?; Ok(()) } - /// Wait for a Version message - /// - /// Returns the peer's ID upon successfully receiving the Version message. 
- async fn wait_vermsg(&self, conn: &Conn) -> Result { - let t = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = timeout(t, conn.recv()).await??; - - let payload = get_msg_payload!(Version, msg); - let (vermsg, _) = decode::(&payload)?; - - if !version_match(&self.config.version.req, &vermsg.version) { - return Err(Error::IncompatibleVersion("system: {}".into())); - } - - self.protocols_match(&vermsg.protocols).await?; - - trace!("Received VerMsg from: {}", vermsg.peer_id); - Ok(vermsg.peer_id) - } + /// Shuts down the peer and remove it from the peer list. + async fn remove_peer(&self, pid: &PeerID) -> Result<()> { + let result = self.peers.write().await.remove(pid); - /// Send a Verack message - async fn send_verack(&self, conn: &Conn, ack: bool) -> Result<()> { - let verack = VerAckMsg { - peer_id: self.id.clone(), - ack, + let peer = match result { + Some(p) => p, + None => return Ok(()), }; - trace!("Send VerAckMsg {:?}", verack); - conn.send(NetMsg::new(NetMsgCmd::Verack, &verack)?).await?; - Ok(()) - } - - /// Wait for a Verack message - /// - /// Returns the peer's ID upon successfully receiving the Verack message. - async fn wait_verack(&self, conn: &Conn) -> Result { - let t = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = timeout(t, conn.recv()).await??; + let _ = peer.shutdown().await; - let payload = get_msg_payload!(Verack, msg); - let (verack, _) = decode::(&payload)?; - - if !verack.ack { - return Err(Error::IncompatiblePeer); - } + self.monitor.notify(PPEvent::RemovePeer(pid.clone())).await; - trace!("Received VerAckMsg from: {}", verack.peer_id); - Ok(verack.peer_id) + warn!("Peer {pid} removed",); + Ok(()) } - /// Check if the new connection has compatible protocols. 
- async fn protocols_match(&self, protocols: &HashMap) -> Result<()> { - for (n, pv) in protocols.iter() { - let pids = self.protocol_versions.read().await; - - match pids.get(n) { - Some(v) => { - if !version_match(&v.req, pv) { - return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}"))); - } - } - None => { - return Err(Error::UnsupportedProtocol(n.to_string())); - } - } - } - Ok(()) + /// Attach the core protocols. + async fn setup_core_protocols(&self) -> Result<()> { + let executor = self.executor.clone(); + let ping_interval = self.config.ping_interval; + let ping_timeout = self.config.ping_timeout; + let c = move |peer| PingProtocol::new(peer, ping_interval, ping_timeout, executor.clone()); + self.attach_protocol::(Box::new(c)).await } } diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index 021844f..249692b 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -56,11 +56,8 @@ impl EventValue for ProtocolEvent { /// #[async_trait] /// impl Protocol for NewProtocol { /// async fn start(self: Arc) -> Result<(), Error> { -/// let listener = self.peer.register_listener::().await; /// loop { -/// let event = listener.recv().await.unwrap(); -/// -/// match event { +/// match self.peer.recv::().await.expect("Receive msg") { /// ProtocolEvent::Message(msg) => { /// println!("{:?}", msg); /// } @@ -69,8 +66,6 @@ impl EventValue for ProtocolEvent { /// } /// } /// } -/// -/// listener.cancel().await; /// Ok(()) /// } /// @@ -114,3 +109,10 @@ pub trait Protocol: Send + Sync { where Self: Sized; } + +#[async_trait] +pub(crate) trait InitProtocol: Send + Sync { + type T; + /// Initialize the protocol + async fn init(self: Arc) -> Self::T; +} diff --git a/p2p/src/protocols/handshake.rs b/p2p/src/protocols/handshake.rs new file mode 100644 index 0000000..b3fe989 --- /dev/null +++ b/p2p/src/protocols/handshake.rs @@ -0,0 +1,139 @@ +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use async_trait::async_trait; +use log::trace; + +use 
karyon_core::{async_util::timeout, util::decode}; + +use crate::{ + message::{NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, + peer::Peer, + protocol::{InitProtocol, ProtocolID}, + version::{version_match, VersionInt}, + Error, PeerID, Result, Version, +}; + +pub struct HandshakeProtocol { + peer: Arc, + protocols: HashMap, +} + +#[async_trait] +impl InitProtocol for HandshakeProtocol { + type T = Result; + /// Initiate a handshake with a connection. + async fn init(self: Arc) -> Self::T { + trace!("Init Handshake: {}", self.peer.remote_endpoint()); + + if !self.peer.is_inbound() { + self.send_vermsg().await?; + } + + let t = Duration::from_secs(self.peer.config().handshake_timeout); + let msg: NetMsg = timeout(t, self.peer.conn.recv_inner()).await??; + match msg.header.command { + NetMsgCmd::Version => { + let result = self.validate_version_msg(&msg).await; + match result { + Ok(_) => { + self.send_verack(true).await?; + } + Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { + self.send_verack(false).await?; + } + _ => {} + }; + result + } + NetMsgCmd::Verack => self.validate_verack_msg(&msg).await, + cmd => Err(Error::InvalidMsg(format!("unexpected msg found {:?}", cmd))), + } + } +} + +impl HandshakeProtocol { + pub fn new(peer: Arc, protocols: HashMap) -> Arc { + Arc::new(Self { peer, protocols }) + } + + /// Sends a Version message + async fn send_vermsg(&self) -> Result<()> { + let protocols = self + .protocols + .clone() + .into_iter() + .map(|p| (p.0, p.1.v)) + .collect(); + + let vermsg = VerMsg { + peer_id: self.peer.own_id().clone(), + protocols, + version: self.peer.config().version.v.clone(), + }; + + trace!("Send VerMsg"); + self.peer + .conn + .send_inner(NetMsg::new(NetMsgCmd::Version, &vermsg)?) 
+ .await?; + Ok(()) + } + + /// Sends a Verack message + async fn send_verack(&self, ack: bool) -> Result<()> { + let verack = VerAckMsg { + peer_id: self.peer.own_id().clone(), + ack, + }; + + trace!("Send VerAckMsg {:?}", verack); + self.peer + .conn + .send_inner(NetMsg::new(NetMsgCmd::Verack, &verack)?) + .await?; + Ok(()) + } + + /// Validates the given version msg + async fn validate_version_msg(&self, msg: &NetMsg) -> Result { + let (vermsg, _) = decode::(&msg.payload)?; + + if !version_match(&self.peer.config().version.req, &vermsg.version) { + return Err(Error::IncompatibleVersion("system: {}".into())); + } + + self.protocols_match(&vermsg.protocols).await?; + + trace!("Received VerMsg from: {}", vermsg.peer_id); + Ok(vermsg.peer_id) + } + + /// Validates the given verack msg + async fn validate_verack_msg(&self, msg: &NetMsg) -> Result { + let (verack, _) = decode::(&msg.payload)?; + + if !verack.ack { + return Err(Error::IncompatiblePeer); + } + + trace!("Received VerAckMsg from: {}", verack.peer_id); + Ok(verack.peer_id) + } + + /// Check if the new connection has compatible protocols. 
+ async fn protocols_match(&self, protocols: &HashMap) -> Result<()> { + for (n, pv) in protocols.iter() { + match self.protocols.get(n) { + Some(v) => { + if !version_match(&v.req, pv) { + return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}"))); + } + } + None => { + return Err(Error::UnsupportedProtocol(n.to_string())); + } + } + } + Ok(()) + } +} diff --git a/p2p/src/protocols/mod.rs b/p2p/src/protocols/mod.rs index 4a8f6b9..c58df03 100644 --- a/p2p/src/protocols/mod.rs +++ b/p2p/src/protocols/mod.rs @@ -1,3 +1,5 @@ +mod handshake; mod ping; -pub use ping::PingProtocol; +pub(crate) use handshake::HandshakeProtocol; +pub(crate) use ping::PingProtocol; diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs index b800b23..f35b203 100644 --- a/p2p/src/protocols/ping.rs +++ b/p2p/src/protocols/ping.rs @@ -9,7 +9,6 @@ use rand::{rngs::OsRng, RngCore}; use karyon_core::{ async_runtime::Executor, async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult}, - event::EventListener, util::decode, }; @@ -39,9 +38,12 @@ pub struct PingProtocol { impl PingProtocol { #[allow(clippy::new_ret_no_self)] - pub fn new(peer: Arc, executor: Executor) -> Arc { - let ping_interval = peer.config().ping_interval; - let ping_timeout = peer.config().ping_timeout; + pub fn new( + peer: Arc, + ping_interval: u64, + ping_timeout: u64, + executor: Executor, + ) -> Arc { Arc::new(Self { peer, ping_interval, @@ -50,13 +52,9 @@ impl PingProtocol { }) } - async fn recv_loop( - &self, - listener: &EventListener, - pong_chan: Sender<[u8; 32]>, - ) -> Result<()> { + async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> { loop { - let event = listener.recv().await?; + let event = self.peer.recv::().await?; let msg_payload = match event.clone() { ProtocolEvent::Message(m) => m, ProtocolEvent::Shutdown => { @@ -70,7 +68,7 @@ impl PingProtocol { PingProtocolMsg::Ping(nonce) => { trace!("Received Ping message {:?}", nonce); self.peer - .send(&Self::id(), 
&PingProtocolMsg::Pong(nonce)) + .send(Self::id(), &PingProtocolMsg::Pong(nonce)) .await?; trace!("Send back Pong message {:?}", nonce); } @@ -82,7 +80,7 @@ impl PingProtocol { Ok(()) } - async fn ping_loop(self: Arc, chan: Receiver<[u8; 32]>) -> Result<()> { + async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> { let rng = &mut OsRng; let mut retry = 0; @@ -94,12 +92,11 @@ impl PingProtocol { trace!("Send Ping message {:?}", ping_nonce); self.peer - .send(&Self::id(), &PingProtocolMsg::Ping(ping_nonce)) + .send(Self::id(), &PingProtocolMsg::Ping(ping_nonce)) .await?; - let d = Duration::from_secs(self.ping_timeout); - // Wait for Pong message + let d = Duration::from_secs(self.ping_timeout); let pong_msg = match timeout(d, chan.recv()).await { Ok(m) => m?, Err(_) => { @@ -107,13 +104,14 @@ impl PingProtocol { continue; } }; - trace!("Received Pong message {:?}", pong_msg); if pong_msg != ping_nonce { retry += 1; continue; } + + retry = 0; } Err(NetError::Timeout.into()) @@ -125,8 +123,8 @@ impl Protocol for PingProtocol { async fn start(self: Arc) -> Result<()> { trace!("Start Ping protocol"); + let stop_signal = async_channel::bounded::>(1); let (pong_chan, pong_chan_recv) = async_channel::bounded(1); - let (stop_signal_s, stop_signal) = async_channel::bounded::>(1); self.task_group.spawn( { @@ -135,15 +133,12 @@ impl Protocol for PingProtocol { }, |res| async move { if let TaskResult::Completed(result) = res { - let _ = stop_signal_s.send(result).await; + let _ = stop_signal.0.send(result).await; } }, ); - let listener = self.peer.register_listener::().await; - - let result = select(self.recv_loop(&listener, pong_chan), stop_signal.recv()).await; - listener.cancel().await; + let result = select(self.recv_loop(pong_chan), stop_signal.1.recv()).await; self.task_group.cancel().await; match result { -- cgit v1.2.3