aboutsummaryrefslogtreecommitdiff
path: root/p2p/src/protocols
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-07-15 13:16:01 +0200
committerhozan23 <hozan23@karyontech.net>2024-07-15 13:16:01 +0200
commite15d3e6fd20b3f87abaad7ddec1c88b0e66419f9 (patch)
tree7976f6993e4f6b3646f5bd6954189346d5ffd330 /p2p/src/protocols
parent6c65232d741229635151671708556b9af7ef75ac (diff)
p2p: Major refactoring of the handshake protocol
Introduce a new protocol InitProtocol which can be used as the core protocol for initializing a connection with a peer. Move the handshake logic from the PeerPool module to the protocols directory and build a handshake protocol that implements InitProtocol trait.
Diffstat (limited to 'p2p/src/protocols')
-rw-r--r--p2p/src/protocols/handshake.rs139
-rw-r--r--p2p/src/protocols/mod.rs4
-rw-r--r--p2p/src/protocols/ping.rs39
3 files changed, 159 insertions, 23 deletions
diff --git a/p2p/src/protocols/handshake.rs b/p2p/src/protocols/handshake.rs
new file mode 100644
index 0000000..b3fe989
--- /dev/null
+++ b/p2p/src/protocols/handshake.rs
@@ -0,0 +1,139 @@
+use std::{collections::HashMap, sync::Arc, time::Duration};
+
+use async_trait::async_trait;
+use log::trace;
+
+use karyon_core::{async_util::timeout, util::decode};
+
+use crate::{
+ message::{NetMsg, NetMsgCmd, VerAckMsg, VerMsg},
+ peer::Peer,
+ protocol::{InitProtocol, ProtocolID},
+ version::{version_match, VersionInt},
+ Error, PeerID, Result, Version,
+};
+
+pub struct HandshakeProtocol {
+ peer: Arc<Peer>,
+ protocols: HashMap<ProtocolID, Version>,
+}
+
+#[async_trait]
+impl InitProtocol for HandshakeProtocol {
+ type T = Result<PeerID>;
+ /// Initiate a handshake with a connection.
+ async fn init(self: Arc<Self>) -> Self::T {
+ trace!("Init Handshake: {}", self.peer.remote_endpoint());
+
+ if !self.peer.is_inbound() {
+ self.send_vermsg().await?;
+ }
+
+ let t = Duration::from_secs(self.peer.config().handshake_timeout);
+ let msg: NetMsg = timeout(t, self.peer.conn.recv_inner()).await??;
+ match msg.header.command {
+ NetMsgCmd::Version => {
+ let result = self.validate_version_msg(&msg).await;
+ match result {
+ Ok(_) => {
+ self.send_verack(true).await?;
+ }
+ Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => {
+ self.send_verack(false).await?;
+ }
+ _ => {}
+ };
+ result
+ }
+ NetMsgCmd::Verack => self.validate_verack_msg(&msg).await,
+ cmd => Err(Error::InvalidMsg(format!("unexpected msg found {:?}", cmd))),
+ }
+ }
+}
+
+impl HandshakeProtocol {
+ pub fn new(peer: Arc<Peer>, protocols: HashMap<ProtocolID, Version>) -> Arc<Self> {
+ Arc::new(Self { peer, protocols })
+ }
+
+ /// Sends a Version message
+ async fn send_vermsg(&self) -> Result<()> {
+ let protocols = self
+ .protocols
+ .clone()
+ .into_iter()
+ .map(|p| (p.0, p.1.v))
+ .collect();
+
+ let vermsg = VerMsg {
+ peer_id: self.peer.own_id().clone(),
+ protocols,
+ version: self.peer.config().version.v.clone(),
+ };
+
+ trace!("Send VerMsg");
+ self.peer
+ .conn
+ .send_inner(NetMsg::new(NetMsgCmd::Version, &vermsg)?)
+ .await?;
+ Ok(())
+ }
+
+ /// Sends a Verack message
+ async fn send_verack(&self, ack: bool) -> Result<()> {
+ let verack = VerAckMsg {
+ peer_id: self.peer.own_id().clone(),
+ ack,
+ };
+
+ trace!("Send VerAckMsg {:?}", verack);
+ self.peer
+ .conn
+ .send_inner(NetMsg::new(NetMsgCmd::Verack, &verack)?)
+ .await?;
+ Ok(())
+ }
+
+ /// Validates the given version msg
+ async fn validate_version_msg(&self, msg: &NetMsg) -> Result<PeerID> {
+ let (vermsg, _) = decode::<VerMsg>(&msg.payload)?;
+
+ if !version_match(&self.peer.config().version.req, &vermsg.version) {
+ return Err(Error::IncompatibleVersion("system: {}".into()));
+ }
+
+ self.protocols_match(&vermsg.protocols).await?;
+
+ trace!("Received VerMsg from: {}", vermsg.peer_id);
+ Ok(vermsg.peer_id)
+ }
+
+ /// Validates the given verack msg
+ async fn validate_verack_msg(&self, msg: &NetMsg) -> Result<PeerID> {
+ let (verack, _) = decode::<VerAckMsg>(&msg.payload)?;
+
+ if !verack.ack {
+ return Err(Error::IncompatiblePeer);
+ }
+
+ trace!("Received VerAckMsg from: {}", verack.peer_id);
+ Ok(verack.peer_id)
+ }
+
+ /// Check if the new connection has compatible protocols.
+ async fn protocols_match(&self, protocols: &HashMap<ProtocolID, VersionInt>) -> Result<()> {
+ for (n, pv) in protocols.iter() {
+ match self.protocols.get(n) {
+ Some(v) => {
+ if !version_match(&v.req, pv) {
+ return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}")));
+ }
+ }
+ None => {
+ return Err(Error::UnsupportedProtocol(n.to_string()));
+ }
+ }
+ }
+ Ok(())
+ }
+}
diff --git a/p2p/src/protocols/mod.rs b/p2p/src/protocols/mod.rs
index 4a8f6b9..c58df03 100644
--- a/p2p/src/protocols/mod.rs
+++ b/p2p/src/protocols/mod.rs
@@ -1,3 +1,5 @@
+mod handshake;
mod ping;
-pub use ping::PingProtocol;
+pub(crate) use handshake::HandshakeProtocol;
+pub(crate) use ping::PingProtocol;
diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs
index b800b23..f35b203 100644
--- a/p2p/src/protocols/ping.rs
+++ b/p2p/src/protocols/ping.rs
@@ -9,7 +9,6 @@ use rand::{rngs::OsRng, RngCore};
use karyon_core::{
async_runtime::Executor,
async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult},
- event::EventListener,
util::decode,
};
@@ -39,9 +38,12 @@ pub struct PingProtocol {
impl PingProtocol {
#[allow(clippy::new_ret_no_self)]
- pub fn new(peer: Arc<Peer>, executor: Executor) -> Arc<dyn Protocol> {
- let ping_interval = peer.config().ping_interval;
- let ping_timeout = peer.config().ping_timeout;
+ pub fn new(
+ peer: Arc<Peer>,
+ ping_interval: u64,
+ ping_timeout: u64,
+ executor: Executor,
+ ) -> Arc<dyn Protocol> {
Arc::new(Self {
peer,
ping_interval,
@@ -50,13 +52,9 @@ impl PingProtocol {
})
}
- async fn recv_loop(
- &self,
- listener: &EventListener<ProtocolID, ProtocolEvent>,
- pong_chan: Sender<[u8; 32]>,
- ) -> Result<()> {
+ async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> {
loop {
- let event = listener.recv().await?;
+ let event = self.peer.recv::<Self>().await?;
let msg_payload = match event.clone() {
ProtocolEvent::Message(m) => m,
ProtocolEvent::Shutdown => {
@@ -70,7 +68,7 @@ impl PingProtocol {
PingProtocolMsg::Ping(nonce) => {
trace!("Received Ping message {:?}", nonce);
self.peer
- .send(&Self::id(), &PingProtocolMsg::Pong(nonce))
+ .send(Self::id(), &PingProtocolMsg::Pong(nonce))
.await?;
trace!("Send back Pong message {:?}", nonce);
}
@@ -82,7 +80,7 @@ impl PingProtocol {
Ok(())
}
- async fn ping_loop(self: Arc<Self>, chan: Receiver<[u8; 32]>) -> Result<()> {
+ async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> {
let rng = &mut OsRng;
let mut retry = 0;
@@ -94,12 +92,11 @@ impl PingProtocol {
trace!("Send Ping message {:?}", ping_nonce);
self.peer
- .send(&Self::id(), &PingProtocolMsg::Ping(ping_nonce))
+ .send(Self::id(), &PingProtocolMsg::Ping(ping_nonce))
.await?;
- let d = Duration::from_secs(self.ping_timeout);
-
// Wait for Pong message
+ let d = Duration::from_secs(self.ping_timeout);
let pong_msg = match timeout(d, chan.recv()).await {
Ok(m) => m?,
Err(_) => {
@@ -107,13 +104,14 @@ impl PingProtocol {
continue;
}
};
-
trace!("Received Pong message {:?}", pong_msg);
if pong_msg != ping_nonce {
retry += 1;
continue;
}
+
+ retry = 0;
}
Err(NetError::Timeout.into())
@@ -125,8 +123,8 @@ impl Protocol for PingProtocol {
async fn start(self: Arc<Self>) -> Result<()> {
trace!("Start Ping protocol");
+ let stop_signal = async_channel::bounded::<Result<()>>(1);
let (pong_chan, pong_chan_recv) = async_channel::bounded(1);
- let (stop_signal_s, stop_signal) = async_channel::bounded::<Result<()>>(1);
self.task_group.spawn(
{
@@ -135,15 +133,12 @@ impl Protocol for PingProtocol {
},
|res| async move {
if let TaskResult::Completed(result) = res {
- let _ = stop_signal_s.send(result).await;
+ let _ = stop_signal.0.send(result).await;
}
},
);
- let listener = self.peer.register_listener::<Self>().await;
-
- let result = select(self.recv_loop(&listener, pong_chan), stop_signal.recv()).await;
- listener.cancel().await;
+ let result = select(self.recv_loop(pong_chan), stop_signal.1.recv()).await;
self.task_group.cancel().await;
match result {