From e15d3e6fd20b3f87abaad7ddec1c88b0e66419f9 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Mon, 15 Jul 2024 13:16:01 +0200 Subject: p2p: Major refactoring of the handshake protocol Introduce a new protocol InitProtocol which can be used as the core protocol for initializing a connection with a peer. Move the handshake logic from the PeerPool module to the protocols directory and build a handshake protocol that implements InitProtocol trait. --- p2p/src/protocols/handshake.rs | 139 +++++++++++++++++++++++++++++++++++++++++ p2p/src/protocols/mod.rs | 4 +- p2p/src/protocols/ping.rs | 39 +++++------- 3 files changed, 159 insertions(+), 23 deletions(-) create mode 100644 p2p/src/protocols/handshake.rs (limited to 'p2p/src/protocols') diff --git a/p2p/src/protocols/handshake.rs b/p2p/src/protocols/handshake.rs new file mode 100644 index 0000000..b3fe989 --- /dev/null +++ b/p2p/src/protocols/handshake.rs @@ -0,0 +1,139 @@ +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use async_trait::async_trait; +use log::trace; + +use karyon_core::{async_util::timeout, util::decode}; + +use crate::{ + message::{NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, + peer::Peer, + protocol::{InitProtocol, ProtocolID}, + version::{version_match, VersionInt}, + Error, PeerID, Result, Version, +}; + +pub struct HandshakeProtocol { + peer: Arc, + protocols: HashMap, +} + +#[async_trait] +impl InitProtocol for HandshakeProtocol { + type T = Result; + /// Initiate a handshake with a connection. + async fn init(self: Arc) -> Self::T { + trace!("Init Handshake: {}", self.peer.remote_endpoint()); + + if !self.peer.is_inbound() { + self.send_vermsg().await?; + } + + let t = Duration::from_secs(self.peer.config().handshake_timeout); + let msg: NetMsg = timeout(t, self.peer.conn.recv_inner()).await??; + match msg.header.command { + NetMsgCmd::Version => { + let result = self.validate_version_msg(&msg).await; + match result { + Ok(_) => { + self.send_verack(true).await?; + } + Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { + self.send_verack(false).await?; + } + _ => {} + }; + result + } + NetMsgCmd::Verack => self.validate_verack_msg(&msg).await, + cmd => Err(Error::InvalidMsg(format!("unexpected msg found {:?}", cmd))), + } + } +} + +impl HandshakeProtocol { + pub fn new(peer: Arc, protocols: HashMap) -> Arc { + Arc::new(Self { peer, protocols }) + } + + /// Sends a Version message + async fn send_vermsg(&self) -> Result<()> { + let protocols = self + .protocols + .clone() + .into_iter() + .map(|p| (p.0, p.1.v)) + .collect(); + + let vermsg = VerMsg { + peer_id: self.peer.own_id().clone(), + protocols, + version: self.peer.config().version.v.clone(), + }; + + trace!("Send VerMsg"); + self.peer + .conn + .send_inner(NetMsg::new(NetMsgCmd::Version, &vermsg)?) + .await?; + Ok(()) + } + + /// Sends a Verack message + async fn send_verack(&self, ack: bool) -> Result<()> { + let verack = VerAckMsg { + peer_id: self.peer.own_id().clone(), + ack, + }; + + trace!("Send VerAckMsg {:?}", verack); + self.peer + .conn + .send_inner(NetMsg::new(NetMsgCmd::Verack, &verack)?) + .await?; + Ok(()) + } + + /// Validates the given version msg + async fn validate_version_msg(&self, msg: &NetMsg) -> Result { + let (vermsg, _) = decode::(&msg.payload)?; + + if !version_match(&self.peer.config().version.req, &vermsg.version) { + return Err(Error::IncompatibleVersion("system: {}".into())); + } + + self.protocols_match(&vermsg.protocols).await?; + + trace!("Received VerMsg from: {}", vermsg.peer_id); + Ok(vermsg.peer_id) + } + + /// Validates the given verack msg + async fn validate_verack_msg(&self, msg: &NetMsg) -> Result { + let (verack, _) = decode::(&msg.payload)?; + + if !verack.ack { + return Err(Error::IncompatiblePeer); + } + + trace!("Received VerAckMsg from: {}", verack.peer_id); + Ok(verack.peer_id) + } + + /// Check if the new connection has compatible protocols. + async fn protocols_match(&self, protocols: &HashMap) -> Result<()> { + for (n, pv) in protocols.iter() { + match self.protocols.get(n) { + Some(v) => { + if !version_match(&v.req, pv) { + return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}"))); + } + } + None => { + return Err(Error::UnsupportedProtocol(n.to_string())); + } + } + } + Ok(()) + } +} diff --git a/p2p/src/protocols/mod.rs b/p2p/src/protocols/mod.rs index 4a8f6b9..c58df03 100644 --- a/p2p/src/protocols/mod.rs +++ b/p2p/src/protocols/mod.rs @@ -1,3 +1,5 @@ +mod handshake; mod ping; -pub use ping::PingProtocol; +pub(crate) use handshake::HandshakeProtocol; +pub(crate) use ping::PingProtocol; diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs index b800b23..f35b203 100644 --- a/p2p/src/protocols/ping.rs +++ b/p2p/src/protocols/ping.rs @@ -9,7 +9,6 @@ use rand::{rngs::OsRng, RngCore}; use karyon_core::{ async_runtime::Executor, async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult}, - event::EventListener, util::decode, }; @@ -39,9 +38,12 @@ pub struct PingProtocol { impl PingProtocol { #[allow(clippy::new_ret_no_self)] - pub fn new(peer: Arc, executor: Executor) -> Arc { - let ping_interval = peer.config().ping_interval; - let ping_timeout = peer.config().ping_timeout; + pub fn new( + peer: Arc, + ping_interval: u64, + ping_timeout: u64, + executor: Executor, + ) -> Arc { Arc::new(Self { peer, ping_interval, @@ -50,13 +52,9 @@ impl PingProtocol { }) } - async fn recv_loop( - &self, - listener: &EventListener, - pong_chan: Sender<[u8; 32]>, - ) -> Result<()> { + async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> { loop { - let event = listener.recv().await?; + let event = self.peer.recv::().await?; let msg_payload = match event.clone() { ProtocolEvent::Message(m) => m, ProtocolEvent::Shutdown => { @@ -70,7 +68,7 @@ impl PingProtocol { PingProtocolMsg::Ping(nonce) => { trace!("Received Ping message {:?}", nonce); self.peer - .send(&Self::id(), &PingProtocolMsg::Pong(nonce)) + .send(Self::id(), &PingProtocolMsg::Pong(nonce)) .await?; trace!("Send back Pong message {:?}", nonce); } @@ -82,7 +80,7 @@ impl PingProtocol { Ok(()) } - async fn ping_loop(self: Arc, chan: Receiver<[u8; 32]>) -> Result<()> { + async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> { let rng = &mut OsRng; let mut retry = 0; @@ -94,12 +92,11 @@ impl PingProtocol { trace!("Send Ping message {:?}", ping_nonce); self.peer - .send(&Self::id(), &PingProtocolMsg::Ping(ping_nonce)) + .send(Self::id(), &PingProtocolMsg::Ping(ping_nonce)) .await?; - let d = Duration::from_secs(self.ping_timeout); - // Wait for Pong message + let d = Duration::from_secs(self.ping_timeout); let pong_msg = match timeout(d, chan.recv()).await { Ok(m) => m?, Err(_) => { @@ -107,13 +104,14 @@ impl PingProtocol { continue; } }; - trace!("Received Pong message {:?}", pong_msg); if pong_msg != ping_nonce { retry += 1; continue; } + + retry = 0; } Err(NetError::Timeout.into()) @@ -125,8 +123,8 @@ impl Protocol for PingProtocol { async fn start(self: Arc) -> Result<()> { trace!("Start Ping protocol"); + let stop_signal = async_channel::bounded::>(1); let (pong_chan, pong_chan_recv) = async_channel::bounded(1); - let (stop_signal_s, stop_signal) = async_channel::bounded::>(1); self.task_group.spawn( { @@ -135,15 +133,12 @@ impl Protocol for PingProtocol { }, |res| async move { if let TaskResult::Completed(result) = res { - let _ = stop_signal_s.send(result).await; + let _ = stop_signal.0.send(result).await; } }, ); - let listener = self.peer.register_listener::().await; - - let result = select(self.recv_loop(&listener, pong_chan), stop_signal.recv()).await; - listener.cancel().await; + let result = select(self.recv_loop(pong_chan), stop_signal.1.recv()).await; self.task_group.cancel().await; match result { -- cgit v1.2.3