author    hozan23 <hozan23@karyontech.net>  2024-07-15 13:16:01 +0200
committer hozan23 <hozan23@karyontech.net>  2024-07-15 13:16:01 +0200
commit    e15d3e6fd20b3f87abaad7ddec1c88b0e66419f9 (patch)
tree      7976f6993e4f6b3646f5bd6954189346d5ffd330 /p2p/src
parent    6c65232d741229635151671708556b9af7ef75ac (diff)
p2p: Major refactoring of the handshake protocol
Introduce a new trait, InitProtocol, which serves as the core protocol for initializing a connection with a peer. Move the handshake logic from the PeerPool module to the protocols directory and build a handshake protocol that implements the InitProtocol trait.
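For orientation, below is a minimal, self-contained sketch of the InitProtocol pattern this patch introduces. The ToyHandshake type, the string-based peer id, and the tokio runtime are illustrative assumptions for the sketch only; the actual trait (p2p/src/protocol.rs) and its implementor HandshakeProtocol (p2p/src/protocols/handshake.rs) appear in the diff below and use the crate's own Peer, PeerID, and Result types.

// Minimal sketch of the InitProtocol pattern (illustrative names; the real
// trait is defined in p2p/src/protocol.rs and is driven by Peer::init()).
use std::sync::Arc;

use async_trait::async_trait;

/// Core trait for protocols that initialize a connection before the regular
/// Connect protocols are started.
#[async_trait]
trait InitProtocol: Send + Sync {
    /// Result produced by the initialization step.
    type T;
    /// Initialize the protocol (e.g. perform a handshake).
    async fn init(self: Arc<Self>) -> Self::T;
}

/// Toy stand-in for HandshakeProtocol: returns a peer id string instead of
/// exchanging real Version/Verack messages over a connection.
struct ToyHandshake {
    remote: String,
}

#[async_trait]
impl InitProtocol for ToyHandshake {
    type T = Result<String, String>;

    async fn init(self: Arc<Self>) -> Self::T {
        // A real implementation validates the peer's version and protocol
        // list here and fails the connection on a mismatch.
        Ok(format!("peer-id-for-{}", self.remote))
    }
}

#[tokio::main]
async fn main() {
    let handshake = Arc::new(ToyHandshake {
        remote: "127.0.0.1:30000".into(),
    });
    match handshake.init().await {
        Ok(pid) => println!("handshake completed, peer id: {pid}"),
        Err(err) => eprintln!("handshake failed: {err}"),
    }
}

In the patch itself, Peer::init() builds a HandshakeProtocol from the registered protocol versions, awaits init(), and stores the returned PeerID before any Connect protocols are spawned.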
Diffstat (limited to 'p2p/src')
-rw-r--r--  p2p/src/conn_queue.rs          |  53
-rw-r--r--  p2p/src/connection.rs          | 110
-rw-r--r--  p2p/src/discovery/lookup.rs    | 114
-rw-r--r--  p2p/src/discovery/mod.rs       |  17
-rw-r--r--  p2p/src/discovery/refresh.rs   |  28
-rw-r--r--  p2p/src/lib.rs                 |   1
-rw-r--r--  p2p/src/message.rs             |  20
-rw-r--r--  p2p/src/monitor/mod.rs         |   6
-rw-r--r--  p2p/src/peer/mod.rs            | 247
-rw-r--r--  p2p/src/peer_pool.rs           | 296
-rw-r--r--  p2p/src/protocol.rs            |  14
-rw-r--r--  p2p/src/protocols/handshake.rs | 139
-rw-r--r--  p2p/src/protocols/mod.rs       |   4
-rw-r--r--  p2p/src/protocols/ping.rs      |  39
14 files changed, 571 insertions, 517 deletions
diff --git a/p2p/src/conn_queue.rs b/p2p/src/conn_queue.rs
index 9a153f3..1b6ef98 100644
--- a/p2p/src/conn_queue.rs
+++ b/p2p/src/conn_queue.rs
@@ -1,37 +1,13 @@
-use std::{collections::VecDeque, fmt, sync::Arc};
-
-use async_channel::Sender;
+use std::{collections::VecDeque, sync::Arc};
use karyon_core::{async_runtime::lock::Mutex, async_util::CondVar};
use karyon_net::Conn;
-use crate::{message::NetMsg, Result};
-
-/// Defines the direction of a network connection.
-#[derive(Clone, Debug)]
-pub enum ConnDirection {
- Inbound,
- Outbound,
-}
-
-impl fmt::Display for ConnDirection {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- match self {
- ConnDirection::Inbound => write!(f, "Inbound"),
- ConnDirection::Outbound => write!(f, "Outbound"),
- }
- }
-}
-
-pub struct NewConn {
- pub direction: ConnDirection,
- pub conn: Conn<NetMsg>,
- pub disconnect_signal: Sender<Result<()>>,
-}
+use crate::{connection::ConnDirection, connection::Connection, message::NetMsg, Result};
/// Connection queue
pub struct ConnQueue {
- queue: Mutex<VecDeque<NewConn>>,
+ queue: Mutex<VecDeque<Connection>>,
conn_available: CondVar,
}
@@ -43,24 +19,27 @@ impl ConnQueue {
})
}
- /// Push a connection into the queue and wait for the disconnect signal
+ /// Handle a connection by pushing it into the queue and waiting for the disconnect signal
pub async fn handle(&self, conn: Conn<NetMsg>, direction: ConnDirection) -> Result<()> {
- let (disconnect_signal, chan) = async_channel::bounded(1);
- let new_conn = NewConn {
- direction,
- conn,
- disconnect_signal,
- };
+ let endpoint = conn.peer_endpoint()?;
+
+ let (disconnect_tx, disconnect_rx) = async_channel::bounded(1);
+ let new_conn = Connection::new(conn, disconnect_tx, direction, endpoint);
+
+ // Push a new conn to the queue
self.queue.lock().await.push_back(new_conn);
self.conn_available.signal();
- if let Ok(result) = chan.recv().await {
+
+ // Wait for the disconnect signal from the connection handler
+ if let Ok(result) = disconnect_rx.recv().await {
return result;
}
+
Ok(())
}
- /// Receive the next connection in the queue
- pub async fn next(&self) -> NewConn {
+ /// Waits for the next connection in the queue
+ pub async fn next(&self) -> Connection {
let mut queue = self.queue.lock().await;
while queue.is_empty() {
queue = self.conn_available.wait(queue).await;
diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs
new file mode 100644
index 0000000..52190a8
--- /dev/null
+++ b/p2p/src/connection.rs
@@ -0,0 +1,110 @@
+use std::{collections::HashMap, fmt, sync::Arc};
+
+use async_channel::Sender;
+use bincode::Encode;
+
+use karyon_core::{
+ event::{EventListener, EventSys},
+ util::encode,
+};
+
+use karyon_net::{Conn, Endpoint};
+
+use crate::{
+ message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg},
+ protocol::{Protocol, ProtocolEvent, ProtocolID},
+ Error, Result,
+};
+
+/// Defines the direction of a network connection.
+#[derive(Clone, Debug)]
+pub enum ConnDirection {
+ Inbound,
+ Outbound,
+}
+
+impl fmt::Display for ConnDirection {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ ConnDirection::Inbound => write!(f, "Inbound"),
+ ConnDirection::Outbound => write!(f, "Outbound"),
+ }
+ }
+}
+
+pub struct Connection {
+ pub(crate) direction: ConnDirection,
+ conn: Conn<NetMsg>,
+ disconnect_signal: Sender<Result<()>>,
+ /// `EventSys` responsible for sending events to the registered protocols.
+ protocol_events: Arc<EventSys<ProtocolID>>,
+ pub(crate) remote_endpoint: Endpoint,
+ listeners: HashMap<ProtocolID, EventListener<ProtocolID, ProtocolEvent>>,
+}
+
+impl Connection {
+ pub fn new(
+ conn: Conn<NetMsg>,
+ signal: Sender<Result<()>>,
+ direction: ConnDirection,
+ remote_endpoint: Endpoint,
+ ) -> Self {
+ Self {
+ conn,
+ direction,
+ protocol_events: EventSys::new(),
+ disconnect_signal: signal,
+ remote_endpoint,
+ listeners: HashMap::new(),
+ }
+ }
+
+ pub async fn send<T: Encode>(&self, protocol_id: ProtocolID, msg: T) -> Result<()> {
+ let payload = encode(&msg)?;
+
+ let proto_msg = ProtocolMsg {
+ protocol_id,
+ payload: payload.to_vec(),
+ };
+
+ let msg = NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?;
+ self.conn.send(msg).await.map_err(Error::from)
+ }
+
+ pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
+ match self.listeners.get(&P::id()) {
+ Some(l) => l.recv().await.map_err(Error::from),
+ // TODO
+ None => todo!(),
+ }
+ }
+
+ /// Registers a listener for the given Protocol `P`.
+ pub async fn register_protocol(&mut self, protocol_id: String) {
+ let listener = self.protocol_events.register(&protocol_id).await;
+ self.listeners.insert(protocol_id, listener);
+ }
+
+ pub async fn emit_msg(&self, id: &ProtocolID, event: &ProtocolEvent) -> Result<()> {
+ self.protocol_events.emit_by_topic(id, event).await?;
+ Ok(())
+ }
+
+ pub async fn recv_inner(&self) -> Result<NetMsg> {
+ self.conn.recv().await.map_err(Error::from)
+ }
+
+ pub async fn send_inner(&self, msg: NetMsg) -> Result<()> {
+ self.conn.send(msg).await.map_err(Error::from)
+ }
+
+ pub async fn disconnect(&self, res: Result<()>) -> Result<()> {
+ self.protocol_events.clear().await;
+ self.disconnect_signal.send(res).await?;
+
+ let m = NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("Create shutdown message");
+ self.conn.send(m).await.map_err(Error::from)?;
+
+ Ok(())
+ }
+}
diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs
index 9ddf614..47a1d09 100644
--- a/p2p/src/discovery/lookup.rs
+++ b/p2p/src/discovery/lookup.rs
@@ -2,24 +2,17 @@ use std::{sync::Arc, time::Duration};
use futures_util::stream::{FuturesUnordered, StreamExt};
use log::{error, trace};
+use parking_lot::RwLock;
use rand::{rngs::OsRng, seq::SliceRandom, RngCore};
-use karyon_core::{
- async_runtime::{lock::RwLock, Executor},
- async_util::timeout,
- crypto::KeyPair,
- util::decode,
-};
+use karyon_core::{async_runtime::Executor, async_util::timeout, crypto::KeyPair, util::decode};
use karyon_net::{Conn, Endpoint};
use crate::{
connector::Connector,
listener::Listener,
- message::{
- get_msg_payload, FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg,
- ShutdownMsg,
- },
+ message::{FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, ShutdownMsg},
monitor::{ConnEvent, DiscvEvent, Monitor},
routing_table::RoutingTable,
slots::ConnectionSlots,
@@ -46,7 +39,7 @@ pub struct LookupService {
outbound_slots: Arc<ConnectionSlots>,
/// Resolved listen endpoint
- listen_endpoint: Option<RwLock<Endpoint>>,
+ listen_endpoint: RwLock<Option<Endpoint>>,
/// Holds the configuration for the P2P network.
config: Arc<Config>,
@@ -85,18 +78,13 @@ impl LookupService {
ex,
);
- let listen_endpoint = config
- .listen_endpoint
- .as_ref()
- .map(|endpoint| RwLock::new(endpoint.clone()));
-
Self {
id: id.clone(),
table,
listener,
connector,
outbound_slots,
- listen_endpoint,
+ listen_endpoint: RwLock::new(None),
config,
monitor,
}
@@ -109,10 +97,18 @@ impl LookupService {
}
/// Set the resolved listen endpoint.
- pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) {
- if let Some(endpoint) = &self.listen_endpoint {
- *endpoint.write().await = resolved_endpoint.clone();
- }
+ pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
+ let resolved_endpoint = Endpoint::Tcp(
+ resolved_endpoint.addr()?.clone(),
+ self.config.discovery_port,
+ );
+ *self.listen_endpoint.write() = Some(resolved_endpoint);
+ Ok(())
+ }
+
+ /// Get the listening endpoint.
+ pub fn listen_endpoint(&self) -> Option<Endpoint> {
+ self.listen_endpoint.read().clone()
}
/// Shuts down the lookup service.
@@ -253,36 +249,51 @@ impl LookupService {
target_peer_id: &PeerID,
) -> Result<Vec<PeerMsg>> {
trace!("Send Ping msg");
- self.send_ping_msg(&conn).await?;
+ let peers;
- trace!("Send FindPeer msg");
- let peers = self.send_findpeer_msg(&conn, target_peer_id).await?;
+ let ping_msg = self.send_ping_msg(&conn).await?;
- if peers.0.len() >= MAX_PEERS_IN_PEERSMSG {
- return Err(Error::Lookup("Received too many peers in PeersMsg"));
+ loop {
+ let t = Duration::from_secs(self.config.lookup_response_timeout);
+ let msg: NetMsg = timeout(t, conn.recv()).await??;
+ match msg.header.command {
+ NetMsgCmd::Pong => {
+ let (pong_msg, _) = decode::<PongMsg>(&msg.payload)?;
+ if ping_msg.nonce != pong_msg.0 {
+ return Err(Error::InvalidPongMsg);
+ }
+ trace!("Send FindPeer msg");
+ self.send_findpeer_msg(&conn, target_peer_id).await?;
+ }
+ NetMsgCmd::Peers => {
+ peers = decode::<PeersMsg>(&msg.payload)?.0.peers;
+ if peers.len() >= MAX_PEERS_IN_PEERSMSG {
+ return Err(Error::Lookup("Received too many peers in PeersMsg"));
+ }
+ break;
+ }
+ c => return Err(Error::InvalidMsg(format!("Unexpected msg: {:?}", c))),
+ };
}
trace!("Send Peer msg");
- if let Some(endpoint) = &self.listen_endpoint {
- self.send_peer_msg(&conn, endpoint.read().await.clone())
- .await?;
+ if let Some(endpoint) = self.listen_endpoint() {
+ self.send_peer_msg(&conn, endpoint.clone()).await?;
}
trace!("Send Shutdown msg");
self.send_shutdown_msg(&conn).await?;
- Ok(peers.0)
+ Ok(peers)
}
/// Start a listener.
async fn start_listener(self: &Arc<Self>) -> Result<()> {
- let addr = match &self.listen_endpoint {
- Some(a) => a.read().await.addr()?.clone(),
+ let endpoint: Endpoint = match self.listen_endpoint() {
+ Some(e) => e.clone(),
None => return Ok(()),
};
- let endpoint = Endpoint::Tcp(addr, self.config.discovery_port);
-
let callback = {
let this = self.clone();
|conn: Conn<NetMsg>| async move {
@@ -292,7 +303,7 @@ impl LookupService {
}
};
- self.listener.start(endpoint.clone(), callback).await?;
+ self.listener.start(endpoint, callback).await?;
Ok(())
}
@@ -329,10 +340,9 @@ impl LookupService {
}
}
- /// Sends a Ping msg and wait to receive the Pong message.
- async fn send_ping_msg(&self, conn: &Conn<NetMsg>) -> Result<()> {
+ /// Sends a Ping msg.
+ async fn send_ping_msg(&self, conn: &Conn<NetMsg>) -> Result<PingMsg> {
trace!("Send Pong msg");
-
let mut nonce: [u8; 32] = [0; 32];
RngCore::fill_bytes(&mut OsRng, &mut nonce);
@@ -341,18 +351,7 @@ impl LookupService {
nonce,
};
conn.send(NetMsg::new(NetMsgCmd::Ping, &ping_msg)?).await?;
-
- let t = Duration::from_secs(self.config.lookup_response_timeout);
- let recv_msg: NetMsg = timeout(t, conn.recv()).await??;
-
- let payload = get_msg_payload!(Pong, recv_msg);
- let (pong_msg, _) = decode::<PongMsg>(&payload)?;
-
- if ping_msg.nonce != pong_msg.0 {
- return Err(Error::InvalidPongMsg);
- }
-
- Ok(())
+ Ok(ping_msg)
}
/// Sends a Pong msg
@@ -363,22 +362,15 @@ impl LookupService {
Ok(())
}
- /// Sends a FindPeer msg and wait to receivet the Peers msg.
- async fn send_findpeer_msg(&self, conn: &Conn<NetMsg>, peer_id: &PeerID) -> Result<PeersMsg> {
+ /// Sends a FindPeer msg
+ async fn send_findpeer_msg(&self, conn: &Conn<NetMsg>, peer_id: &PeerID) -> Result<()> {
trace!("Send FindPeer msg");
conn.send(NetMsg::new(
NetMsgCmd::FindPeer,
FindPeerMsg(peer_id.clone()),
)?)
.await?;
-
- let t = Duration::from_secs(self.config.lookup_response_timeout);
- let recv_msg: NetMsg = timeout(t, conn.recv()).await??;
-
- let payload = get_msg_payload!(Peers, recv_msg);
- let (peers, _) = decode(&payload)?;
-
- Ok(peers)
+ Ok(())
}
/// Sends a Peers msg.
@@ -389,7 +381,7 @@ impl LookupService {
.closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG);
let peers: Vec<PeerMsg> = entries.into_iter().map(|e| e.into()).collect();
- conn.send(NetMsg::new(NetMsgCmd::Peers, PeersMsg(peers))?)
+ conn.send(NetMsg::new(NetMsgCmd::Peers, PeersMsg { peers })?)
.await?;
Ok(())
}
diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs
index a9d99d6..a81a817 100644
--- a/p2p/src/discovery/mod.rs
+++ b/p2p/src/discovery/mod.rs
@@ -16,7 +16,8 @@ use karyon_net::{Conn, Endpoint};
use crate::{
config::Config,
- conn_queue::{ConnDirection, ConnQueue},
+ conn_queue::ConnQueue,
+ connection::ConnDirection,
connector::Connector,
listener::Listener,
message::NetMsg,
@@ -132,15 +133,11 @@ impl Discovery {
let resolved_endpoint = self.start_listener(endpoint).await?;
- if endpoint.addr()? != resolved_endpoint.addr()? {
- info!("Resolved listen endpoint: {resolved_endpoint}");
- self.lookup_service
- .set_listen_endpoint(&resolved_endpoint)
- .await;
- self.refresh_service
- .set_listen_endpoint(&resolved_endpoint)
- .await;
- }
+ info!("Resolved listen endpoint: {resolved_endpoint}");
+ self.lookup_service
+ .set_listen_endpoint(&resolved_endpoint)?;
+ self.refresh_service
+ .set_listen_endpoint(&resolved_endpoint)?;
}
// Start the lookup service
diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs
index b4f5396..1452a1b 100644
--- a/p2p/src/discovery/refresh.rs
+++ b/p2p/src/discovery/refresh.rs
@@ -2,10 +2,11 @@ use std::{sync::Arc, time::Duration};
use bincode::{Decode, Encode};
use log::{error, info, trace};
+use parking_lot::RwLock;
use rand::{rngs::OsRng, RngCore};
use karyon_core::{
- async_runtime::{lock::RwLock, Executor},
+ async_runtime::Executor,
async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
};
@@ -33,7 +34,7 @@ pub struct RefreshService {
table: Arc<RoutingTable>,
/// Resolved listen endpoint
- listen_endpoint: Option<RwLock<Endpoint>>,
+ listen_endpoint: RwLock<Option<Endpoint>>,
/// Managing spawned tasks.
task_group: TaskGroup,
@@ -53,14 +54,9 @@ impl RefreshService {
monitor: Arc<Monitor>,
executor: Executor,
) -> Self {
- let listen_endpoint = config
- .listen_endpoint
- .as_ref()
- .map(|endpoint| RwLock::new(endpoint.clone()));
-
Self {
table,
- listen_endpoint,
+ listen_endpoint: RwLock::new(None),
task_group: TaskGroup::with_executor(executor.clone()),
config,
monitor,
@@ -69,9 +65,8 @@ impl RefreshService {
/// Start the refresh service
pub async fn start(self: &Arc<Self>) -> Result<()> {
- if let Some(endpoint) = &self.listen_endpoint {
- let endpoint = endpoint.read().await.clone();
-
+ if let Some(endpoint) = self.listen_endpoint.read().as_ref() {
+ let endpoint = endpoint.clone();
self.task_group.spawn(
{
let this = self.clone();
@@ -101,10 +96,13 @@ impl RefreshService {
}
/// Set the resolved listen endpoint.
- pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) {
- if let Some(endpoint) = &self.listen_endpoint {
- *endpoint.write().await = resolved_endpoint.clone();
- }
+ pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
+ let resolved_endpoint = Endpoint::Udp(
+ resolved_endpoint.addr()?.clone(),
+ self.config.discovery_port,
+ );
+ *self.listen_endpoint.write() = Some(resolved_endpoint);
+ Ok(())
}
/// Shuts down the refresh service
diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs
index b21a353..f0dc725 100644
--- a/p2p/src/lib.rs
+++ b/p2p/src/lib.rs
@@ -41,6 +41,7 @@ mod backend;
mod codec;
mod config;
mod conn_queue;
+mod connection;
mod connector;
mod discovery;
mod error;
diff --git a/p2p/src/message.rs b/p2p/src/message.rs
index 6498ef7..5bf0853 100644
--- a/p2p/src/message.rs
+++ b/p2p/src/message.rs
@@ -110,7 +110,9 @@ pub struct PeerMsg {
/// PeersMsg a list of `PeerMsg`.
#[derive(Decode, Encode, Debug)]
-pub struct PeersMsg(pub Vec<PeerMsg>);
+pub struct PeersMsg {
+ pub peers: Vec<PeerMsg>,
+}
impl From<Entry> for PeerMsg {
fn from(entry: Entry) -> PeerMsg {
@@ -133,19 +135,3 @@ impl From<PeerMsg> for Entry {
}
}
}
-
-macro_rules! get_msg_payload {
- ($a:ident, $b:ident) => {
- if let NetMsgCmd::$a = $b.header.command {
- $b.payload
- } else {
- return Err(Error::InvalidMsg(format!(
- "Expected {:?} msg found {:?} msg",
- stringify!($a),
- $b.header.command
- )));
- }
- };
-}
-
-pub(super) use get_msg_payload;
diff --git a/p2p/src/monitor/mod.rs b/p2p/src/monitor/mod.rs
index 4ecb431..86db23e 100644
--- a/p2p/src/monitor/mod.rs
+++ b/p2p/src/monitor/mod.rs
@@ -2,6 +2,8 @@ mod event;
use std::sync::Arc;
+use log::error;
+
use karyon_core::event::{EventListener, EventSys, EventValue, EventValueTopic};
use karyon_net::Endpoint;
@@ -62,7 +64,9 @@ impl Monitor {
pub(crate) async fn notify<E: ToEventStruct>(&self, event: E) {
if self.config.enable_monitor {
let event = event.to_struct();
- self.event_sys.emit(&event).await
+ if let Err(err) = self.event_sys.emit(&event).await {
+ error!("Failed to notify monitor event {:?}: {err}", event);
+ }
}
}
diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs
index 6903294..a5ac7ad 100644
--- a/p2p/src/peer/mod.rs
+++ b/p2p/src/peer/mod.rs
@@ -1,138 +1,111 @@
mod peer_id;
-pub use peer_id::PeerID;
-
use std::sync::{Arc, Weak};
use async_channel::{Receiver, Sender};
-use bincode::{Decode, Encode};
+use bincode::Encode;
use log::{error, trace};
+use parking_lot::RwLock;
use karyon_core::{
- async_runtime::{lock::RwLock, Executor},
+ async_runtime::Executor,
async_util::{select, Either, TaskGroup, TaskResult},
- event::{EventListener, EventSys},
- util::{decode, encode},
+ util::decode,
};
-use karyon_net::{Conn, Endpoint};
-
use crate::{
- conn_queue::ConnDirection,
- message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg},
+ connection::{ConnDirection, Connection},
+ endpoint::Endpoint,
+ message::{NetMsgCmd, ProtocolMsg},
peer_pool::PeerPool,
- protocol::{Protocol, ProtocolEvent, ProtocolID},
+ protocol::{InitProtocol, Protocol, ProtocolEvent, ProtocolID},
+ protocols::HandshakeProtocol,
Config, Error, Result,
};
+pub use peer_id::PeerID;
+
pub struct Peer {
+ /// Own ID
+ own_id: PeerID,
+
/// Peer's ID
- id: PeerID,
+ id: RwLock<Option<PeerID>>,
- /// A weak pointer to `PeerPool`
+ /// A weak pointer to [`PeerPool`]
peer_pool: Weak<PeerPool>,
/// Holds the peer connection
- conn: Conn<NetMsg>,
-
- /// Remote endpoint for the peer
- remote_endpoint: Endpoint,
-
- /// The direction of the connection, either `Inbound` or `Outbound`
- conn_direction: ConnDirection,
-
- /// A list of protocol IDs
- protocol_ids: RwLock<Vec<ProtocolID>>,
-
- /// `EventSys` responsible for sending events to the protocols.
- protocol_events: Arc<EventSys<ProtocolID>>,
+ pub(crate) conn: Arc<Connection>,
/// This channel is used to send a stop signal to the read loop.
stop_chan: (Sender<Result<()>>, Receiver<Result<()>>),
+ /// The Configuration for the P2P network.
+ config: Arc<Config>,
+
/// Managing spawned tasks.
task_group: TaskGroup,
}
impl Peer {
/// Creates a new peer
- pub fn new(
+ pub(crate) fn new(
+ own_id: PeerID,
peer_pool: Weak<PeerPool>,
- id: &PeerID,
- conn: Conn<NetMsg>,
- remote_endpoint: Endpoint,
- conn_direction: ConnDirection,
+ conn: Arc<Connection>,
+ config: Arc<Config>,
ex: Executor,
) -> Arc<Peer> {
Arc::new(Peer {
- id: id.clone(),
+ own_id,
+ id: RwLock::new(None),
peer_pool,
conn,
- protocol_ids: RwLock::new(Vec::new()),
- remote_endpoint,
- conn_direction,
- protocol_events: EventSys::new(),
+ config,
task_group: TaskGroup::with_executor(ex),
stop_chan: async_channel::bounded(1),
})
}
- /// Run the peer
- pub async fn run(self: Arc<Self>) -> Result<()> {
- self.start_protocols().await;
- self.read_loop().await
+ /// Send a msg to this peer connection using the specified protocol.
+ pub async fn send<T: Encode>(&self, proto_id: ProtocolID, msg: T) -> Result<()> {
+ self.conn.send(proto_id, msg).await
}
- /// Send a message to the peer connection using the specified protocol.
- pub async fn send<T: Encode + Decode>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> {
- let payload = encode(msg)?;
-
- let proto_msg = ProtocolMsg {
- protocol_id: protocol_id.to_string(),
- payload: payload.to_vec(),
- };
-
- self.conn
- .send(NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?)
- .await?;
- Ok(())
+ /// Receives a new msg from this peer connection.
+ pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
+ self.conn.recv::<P>().await
}
/// Broadcast a message to all connected peers using the specified protocol.
- pub async fn broadcast<T: Encode + Decode>(&self, protocol_id: &ProtocolID, msg: &T) {
- self.peer_pool().broadcast(protocol_id, msg).await;
+ pub async fn broadcast<T: Encode>(&self, proto_id: &ProtocolID, msg: &T) {
+ self.peer_pool().broadcast(proto_id, msg).await;
}
- /// Shuts down the peer
- pub async fn shutdown(&self) {
- trace!("peer {} start shutting down", self.id);
-
- // Send shutdown event to all protocols
- for protocol_id in self.protocol_ids.read().await.iter() {
- self.protocol_events
- .emit_by_topic(protocol_id, &ProtocolEvent::Shutdown)
- .await;
- }
+ /// Returns the peer's ID
+ pub fn id(&self) -> Option<PeerID> {
+ self.id.read().clone()
+ }
- // Send a stop signal to the read loop
- //
- // No need to handle the error here; a dropped channel and
- // sending a stop signal have the same effect.
- let _ = self.stop_chan.0.try_send(Ok(()));
+ /// Returns own ID
+ pub fn own_id(&self) -> &PeerID {
+ &self.own_id
+ }
- // No need to handle the error here
- let shutdown_msg =
- NetMsg::new(NetMsgCmd::Shutdown, ShutdownMsg(0)).expect("pack shutdown message");
- let _ = self.conn.send(shutdown_msg).await;
+ /// Returns the [`Config`]
+ pub fn config(&self) -> Arc<Config> {
+ self.config.clone()
+ }
- // Force shutting down
- self.task_group.cancel().await;
+ /// Returns the remote endpoint for the peer
+ pub fn remote_endpoint(&self) -> &Endpoint {
+ &self.conn.remote_endpoint
}
/// Check if the connection is Inbound
- #[inline]
pub fn is_inbound(&self) -> bool {
- match self.conn_direction {
+ match self.conn.direction {
ConnDirection::Inbound => true,
ConnDirection::Outbound => false,
}
@@ -140,40 +113,82 @@ impl Peer {
/// Returns the direction of the connection, which can be either `Inbound`
/// or `Outbound`.
- #[inline]
pub fn direction(&self) -> &ConnDirection {
- &self.conn_direction
+ &self.conn.direction
}
- /// Returns the remote endpoint for the peer
- #[inline]
- pub fn remote_endpoint(&self) -> &Endpoint {
- &self.remote_endpoint
+ pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
+ let handshake_protocol = HandshakeProtocol::new(
+ self.clone(),
+ self.peer_pool().protocol_versions.read().await.clone(),
+ );
+
+ let pid = handshake_protocol.init().await?;
+ *self.id.write() = Some(pid);
+
+ Ok(())
}
- /// Return the peer's ID
- #[inline]
- pub fn id(&self) -> &PeerID {
- &self.id
+ /// Run the peer
+ pub(crate) async fn run(self: Arc<Self>) -> Result<()> {
+ self.run_connect_protocols().await;
+ self.read_loop().await
}
- /// Returns the `Config` instance.
- pub fn config(&self) -> Arc<Config> {
- self.peer_pool().config.clone()
+ /// Shuts down the peer
+ pub(crate) async fn shutdown(self: &Arc<Self>) -> Result<()> {
+ trace!("peer {:?} shutting down", self.id());
+
+ // Send shutdown event to the attached protocols
+ for proto_id in self.peer_pool().protocols.read().await.keys() {
+ let _ = self.conn.emit_msg(proto_id, &ProtocolEvent::Shutdown).await;
+ }
+
+ // Send a stop signal to the read loop
+ //
+ // No need to handle the error here; a dropped channel and
+ // sending a stop signal have the same effect.
+ let _ = self.stop_chan.0.try_send(Ok(()));
+
+ self.conn.disconnect(Ok(())).await?;
+
+ // Force shutting down
+ self.task_group.cancel().await;
+ Ok(())
}
- /// Registers a listener for the given Protocol `P`.
- pub async fn register_listener<P: Protocol>(&self) -> EventListener<ProtocolID, ProtocolEvent> {
- self.protocol_events.register(&P::id()).await
+ /// Run the Connect Protocols for this peer connection.
+ async fn run_connect_protocols(self: &Arc<Self>) {
+ for (proto_id, constructor) in self.peer_pool().protocols.read().await.iter() {
+ trace!("peer {:?} run protocol {proto_id}", self.id());
+
+ let protocol = constructor(self.clone());
+
+ let on_failure = {
+ let this = self.clone();
+ let proto_id = proto_id.clone();
+ |result: TaskResult<Result<()>>| async move {
+ if let TaskResult::Completed(res) = result {
+ if res.is_err() {
+ error!("protocol {} stopped", proto_id);
+ }
+ // Send a stop signal to read loop
+ let _ = this.stop_chan.0.try_send(res);
+ }
+ }
+ };
+
+ self.task_group.spawn(protocol.start(), on_failure);
+ }
}
- /// Start a read loop to handle incoming messages from the peer connection.
+ /// Run a read loop to handle incoming messages from the peer connection.
async fn read_loop(&self) -> Result<()> {
loop {
- let fut = select(self.stop_chan.1.recv(), self.conn.recv()).await;
+ let fut = select(self.stop_chan.1.recv(), self.conn.recv_inner()).await;
let result = match fut {
Either::Left(stop_signal) => {
- trace!("Peer {} received a stop signal", self.id);
+ trace!("Peer {:?} received a stop signal", self.id());
return stop_signal?;
}
Either::Right(result) => result,
@@ -184,14 +199,9 @@ impl Peer {
match msg.header.command {
NetMsgCmd::Protocol => {
let msg: ProtocolMsg = decode(&msg.payload)?.0;
-
- if !self.protocol_ids.read().await.contains(&msg.protocol_id) {
- return Err(Error::UnsupportedProtocol(msg.protocol_id));
- }
-
- let proto_id = &msg.protocol_id;
- let msg = ProtocolEvent::Message(msg.payload);
- self.protocol_events.emit_by_topic(proto_id, &msg).await;
+ self.conn
+ .emit_msg(&msg.protocol_id, &ProtocolEvent::Message(msg.payload))
+ .await?;
}
NetMsgCmd::Shutdown => {
return Err(Error::PeerShutdown);
@@ -201,32 +211,7 @@ impl Peer {
}
}
- /// Start running the protocols for this peer connection.
- async fn start_protocols(self: &Arc<Self>) {
- for (protocol_id, constructor) in self.peer_pool().protocols.read().await.iter() {
- trace!("peer {} start protocol {protocol_id}", self.id);
- let protocol = constructor(self.clone());
-
- self.protocol_ids.write().await.push(protocol_id.clone());
-
- let on_failure = {
- let this = self.clone();
- let protocol_id = protocol_id.clone();
- |result: TaskResult<Result<()>>| async move {
- if let TaskResult::Completed(res) = result {
- if res.is_err() {
- error!("protocol {} stopped", protocol_id);
- }
- // Send a stop signal to read loop
- let _ = this.stop_chan.0.try_send(res);
- }
- }
- };
-
- self.task_group.spawn(protocol.start(), on_failure);
- }
- }
-
+ /// Returns `PeerPool` pointer
fn peer_pool(&self) -> Arc<PeerPool> {
self.peer_pool.upgrade().unwrap()
}
diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs
index 1f3ca55..549dc76 100644
--- a/p2p/src/peer_pool.rs
+++ b/p2p/src/peer_pool.rs
@@ -1,26 +1,24 @@
-use std::{collections::HashMap, sync::Arc, time::Duration};
+use std::{collections::HashMap, sync::Arc};
-use async_channel::Sender;
-use bincode::{Decode, Encode};
-use log::{error, info, trace, warn};
+use bincode::Encode;
+use log::{error, info, warn};
use karyon_core::{
async_runtime::{lock::RwLock, Executor},
- async_util::{timeout, TaskGroup, TaskResult},
- util::decode,
+ async_util::{TaskGroup, TaskResult},
};
-use karyon_net::{Conn, Endpoint};
+use karyon_net::Endpoint;
use crate::{
config::Config,
- conn_queue::{ConnDirection, ConnQueue},
- message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg},
+ conn_queue::ConnQueue,
+ connection::Connection,
monitor::{Monitor, PPEvent},
peer::Peer,
protocol::{Protocol, ProtocolConstructor, ProtocolID},
protocols::PingProtocol,
- version::{version_match, Version, VersionInt},
+ version::Version,
Error, PeerID, Result,
};
@@ -37,8 +35,8 @@ pub struct PeerPool {
/// Hashmap contains protocol constructors.
pub(crate) protocols: RwLock<HashMap<ProtocolID, Box<ProtocolConstructor>>>,
- /// Hashmap contains protocol IDs and their versions.
- protocol_versions: Arc<RwLock<HashMap<ProtocolID, Version>>>,
+ /// Hashmap contains protocols with their versions
+ pub(crate) protocol_versions: RwLock<HashMap<ProtocolID, Version>>,
/// Managing spawned tasks.
task_group: TaskGroup,
@@ -47,7 +45,7 @@ pub struct PeerPool {
executor: Executor,
/// The Configuration for the P2P network.
- pub(crate) config: Arc<Config>,
+ config: Arc<Config>,
/// Responsible for network and system monitoring.
monitor: Arc<Monitor>,
@@ -62,15 +60,12 @@ impl PeerPool {
monitor: Arc<Monitor>,
executor: Executor,
) -> Arc<Self> {
- let protocols = RwLock::new(HashMap::new());
- let protocol_versions = Arc::new(RwLock::new(HashMap::new()));
-
Arc::new(Self {
id: id.clone(),
conn_queue,
peers: RwLock::new(HashMap::new()),
- protocols,
- protocol_versions,
+ protocols: RwLock::new(HashMap::new()),
+ protocol_versions: RwLock::new(HashMap::new()),
task_group: TaskGroup::with_executor(executor.clone()),
executor,
monitor,
@@ -80,21 +75,15 @@ impl PeerPool {
/// Starts the [`PeerPool`]
pub async fn start(self: &Arc<Self>) -> Result<()> {
- self.setup_protocols().await?;
- self.task_group.spawn(
- {
- let this = self.clone();
- async move { this.listen_loop().await }
- },
- |_| async {},
- );
+ self.setup_core_protocols().await?;
+ self.task_group.spawn(self.clone().run(), |_| async {});
Ok(())
}
/// Shuts down
pub async fn shutdown(&self) {
for (_, peer) in self.peers.read().await.iter() {
- peer.shutdown().await;
+ let _ = peer.shutdown().await;
}
self.task_group.cancel().await;
@@ -102,76 +91,24 @@ impl PeerPool {
/// Attach a custom protocol to the network
pub async fn attach_protocol<P: Protocol>(&self, c: Box<ProtocolConstructor>) -> Result<()> {
- let protocol_versions = &mut self.protocol_versions.write().await;
- let protocols = &mut self.protocols.write().await;
-
- protocol_versions.insert(P::id(), P::version()?);
- protocols.insert(P::id(), c);
+ self.protocols.write().await.insert(P::id(), c);
+ self.protocol_versions
+ .write()
+ .await
+ .insert(P::id(), P::version()?);
Ok(())
}
/// Broadcast a message to all connected peers using the specified protocol.
- pub async fn broadcast<T: Decode + Encode>(&self, proto_id: &ProtocolID, msg: &T) {
+ pub async fn broadcast<T: Encode>(&self, proto_id: &ProtocolID, msg: &T) {
for (pid, peer) in self.peers.read().await.iter() {
- if let Err(err) = peer.send(proto_id, msg).await {
+ if let Err(err) = peer.conn.send(proto_id.to_string(), msg).await {
error!("failed to send msg to {pid}: {err}");
continue;
}
}
}
- /// Add a new peer to the peer list.
- pub async fn new_peer(
- self: &Arc<Self>,
- conn: Conn<NetMsg>,
- conn_direction: &ConnDirection,
- disconnect_signal: Sender<Result<()>>,
- ) -> Result<()> {
- let endpoint = conn.peer_endpoint()?;
-
- // Do a handshake with the connection before creating a new peer.
- let pid = self.do_handshake(&conn, conn_direction).await?;
-
- // TODO: Consider restricting the subnet for inbound connections
- if self.contains_peer(&pid).await {
- return Err(Error::PeerAlreadyConnected);
- }
-
- // Create a new peer
- let peer = Peer::new(
- Arc::downgrade(self),
- &pid,
- conn,
- endpoint.clone(),
- conn_direction.clone(),
- self.executor.clone(),
- );
-
- // Insert the new peer
- self.peers.write().await.insert(pid.clone(), peer.clone());
-
- let on_disconnect = {
- let this = self.clone();
- let pid = pid.clone();
- |result| async move {
- if let TaskResult::Completed(result) = result {
- if let Err(err) = this.remove_peer(&pid).await {
- error!("Failed to remove peer {pid}: {err}");
- }
- let _ = disconnect_signal.send(result).await;
- }
- }
- };
-
- self.task_group.spawn(peer.run(), on_disconnect);
-
- info!("Add new peer {pid}, direction: {conn_direction}, endpoint: {endpoint}");
-
- self.monitor.notify(PPEvent::NewPeer(pid.clone())).await;
-
- Ok(())
- }
-
/// Checks if the peer list contains a peer with the given peer id
pub async fn contains_peer(&self, pid: &PeerID) -> bool {
self.peers.read().await.contains_key(pid)
@@ -204,162 +141,89 @@ impl PeerPool {
peers
}
- /// Listens to a new connection from the connection queue
- async fn listen_loop(self: Arc<Self>) {
+ async fn run(self: Arc<Self>) {
loop {
- let conn = self.conn_queue.next().await;
- let signal = conn.disconnect_signal;
+ let mut conn = self.conn_queue.next().await;
+
+ for protocol_id in self.protocols.read().await.keys() {
+ conn.register_protocol(protocol_id.to_string()).await;
+ }
- let result = self
- .new_peer(conn.conn, &conn.direction, signal.clone())
- .await;
+ let conn = Arc::new(conn);
- // Only send a disconnect signal if there is an error when adding a peer.
+ let result = self.new_peer(conn.clone()).await;
+
+ // Disconnect if there is an error when adding a peer.
if result.is_err() {
- let _ = signal.send(result).await;
+ let _ = conn.disconnect(result).await;
}
}
}
- /// Shuts down the peer and remove it from the peer list.
- async fn remove_peer(&self, pid: &PeerID) -> Result<()> {
- let result = self.peers.write().await.remove(pid);
-
- let peer = match result {
- Some(p) => p,
- None => return Ok(()),
- };
-
- peer.shutdown().await;
-
- self.monitor.notify(PPEvent::RemovePeer(pid.clone())).await;
-
- let endpoint = peer.remote_endpoint();
- let direction = peer.direction();
+ /// Add a new peer to the peer list.
+ async fn new_peer(self: &Arc<Self>, conn: Arc<Connection>) -> Result<()> {
+ // Create a new peer
+ let peer = Peer::new(
+ self.id.clone(),
+ Arc::downgrade(self),
+ conn.clone(),
+ self.config.clone(),
+ self.executor.clone(),
+ );
+ peer.init().await?;
+ let pid = peer.id().expect("Get peer id after peer initialization");
- warn!("Peer {pid} removed, direction: {direction}, endpoint: {endpoint}",);
- Ok(())
- }
+ // TODO: Consider restricting the subnet for inbound connections
+ if self.contains_peer(&pid).await {
+ return Err(Error::PeerAlreadyConnected);
+ }
- /// Attach the core protocols.
- async fn setup_protocols(&self) -> Result<()> {
- let executor = self.executor.clone();
- let c = move |peer| PingProtocol::new(peer, executor.clone());
- self.attach_protocol::<PingProtocol>(Box::new(c)).await
- }
+ // Insert the new peer
+ self.peers.write().await.insert(pid.clone(), peer.clone());
- /// Initiate a handshake with a connection.
- async fn do_handshake(
- &self,
- conn: &Conn<NetMsg>,
- conn_direction: &ConnDirection,
- ) -> Result<PeerID> {
- trace!("Handshake started: {}", conn.peer_endpoint()?);
- match conn_direction {
- ConnDirection::Inbound => {
- let result = self.wait_vermsg(conn).await;
- match result {
- Ok(_) => {
- self.send_verack(conn, true).await?;
- }
- Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => {
- self.send_verack(conn, false).await?;
+ let on_disconnect = {
+ let this = self.clone();
+ let pid = pid.clone();
+ |result| async move {
+ if let TaskResult::Completed(_) = result {
+ if let Err(err) = this.remove_peer(&pid).await {
+ error!("Failed to remove peer {pid}: {err}");
}
- _ => {}
}
- result
- }
-
- ConnDirection::Outbound => {
- self.send_vermsg(conn).await?;
- self.wait_verack(conn).await
}
- }
- }
+ };
- /// Send a Version message
- async fn send_vermsg(&self, conn: &Conn<NetMsg>) -> Result<()> {
- let pids = self.protocol_versions.read().await;
- let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect();
- drop(pids);
+ self.task_group.spawn(peer.run(), on_disconnect);
- let vermsg = VerMsg {
- peer_id: self.id.clone(),
- protocols,
- version: self.config.version.v.clone(),
- };
+ info!("Add new peer {pid}");
+ self.monitor.notify(PPEvent::NewPeer(pid)).await;
- trace!("Send VerMsg");
- conn.send(NetMsg::new(NetMsgCmd::Version, &vermsg)?).await?;
Ok(())
}
- /// Wait for a Version message
- ///
- /// Returns the peer's ID upon successfully receiving the Version message.
- async fn wait_vermsg(&self, conn: &Conn<NetMsg>) -> Result<PeerID> {
- let t = Duration::from_secs(self.config.handshake_timeout);
- let msg: NetMsg = timeout(t, conn.recv()).await??;
-
- let payload = get_msg_payload!(Version, msg);
- let (vermsg, _) = decode::<VerMsg>(&payload)?;
-
- if !version_match(&self.config.version.req, &vermsg.version) {
- return Err(Error::IncompatibleVersion("system: {}".into()));
- }
-
- self.protocols_match(&vermsg.protocols).await?;
-
- trace!("Received VerMsg from: {}", vermsg.peer_id);
- Ok(vermsg.peer_id)
- }
+ /// Shuts down the peer and remove it from the peer list.
+ async fn remove_peer(&self, pid: &PeerID) -> Result<()> {
+ let result = self.peers.write().await.remove(pid);
- /// Send a Verack message
- async fn send_verack(&self, conn: &Conn<NetMsg>, ack: bool) -> Result<()> {
- let verack = VerAckMsg {
- peer_id: self.id.clone(),
- ack,
+ let peer = match result {
+ Some(p) => p,
+ None => return Ok(()),
};
- trace!("Send VerAckMsg {:?}", verack);
- conn.send(NetMsg::new(NetMsgCmd::Verack, &verack)?).await?;
- Ok(())
- }
-
- /// Wait for a Verack message
- ///
- /// Returns the peer's ID upon successfully receiving the Verack message.
- async fn wait_verack(&self, conn: &Conn<NetMsg>) -> Result<PeerID> {
- let t = Duration::from_secs(self.config.handshake_timeout);
- let msg: NetMsg = timeout(t, conn.recv()).await??;
+ let _ = peer.shutdown().await;
- let payload = get_msg_payload!(Verack, msg);
- let (verack, _) = decode::<VerAckMsg>(&payload)?;
-
- if !verack.ack {
- return Err(Error::IncompatiblePeer);
- }
+ self.monitor.notify(PPEvent::RemovePeer(pid.clone())).await;
- trace!("Received VerAckMsg from: {}", verack.peer_id);
- Ok(verack.peer_id)
+ warn!("Peer {pid} removed",);
+ Ok(())
}
- /// Check if the new connection has compatible protocols.
- async fn protocols_match(&self, protocols: &HashMap<ProtocolID, VersionInt>) -> Result<()> {
- for (n, pv) in protocols.iter() {
- let pids = self.protocol_versions.read().await;
-
- match pids.get(n) {
- Some(v) => {
- if !version_match(&v.req, pv) {
- return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}")));
- }
- }
- None => {
- return Err(Error::UnsupportedProtocol(n.to_string()));
- }
- }
- }
- Ok(())
+ /// Attach the core protocols.
+ async fn setup_core_protocols(&self) -> Result<()> {
+ let executor = self.executor.clone();
+ let ping_interval = self.config.ping_interval;
+ let ping_timeout = self.config.ping_timeout;
+ let c = move |peer| PingProtocol::new(peer, ping_interval, ping_timeout, executor.clone());
+ self.attach_protocol::<PingProtocol>(Box::new(c)).await
}
}
diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs
index 021844f..249692b 100644
--- a/p2p/src/protocol.rs
+++ b/p2p/src/protocol.rs
@@ -56,11 +56,8 @@ impl EventValue for ProtocolEvent {
/// #[async_trait]
/// impl Protocol for NewProtocol {
/// async fn start(self: Arc<Self>) -> Result<(), Error> {
-/// let listener = self.peer.register_listener::<Self>().await;
/// loop {
-/// let event = listener.recv().await.unwrap();
-///
-/// match event {
+/// match self.peer.recv::<Self>().await.expect("Receive msg") {
/// ProtocolEvent::Message(msg) => {
/// println!("{:?}", msg);
/// }
@@ -69,8 +66,6 @@ impl EventValue for ProtocolEvent {
/// }
/// }
/// }
-///
-/// listener.cancel().await;
/// Ok(())
/// }
///
@@ -114,3 +109,10 @@ pub trait Protocol: Send + Sync {
where
Self: Sized;
}
+
+#[async_trait]
+pub(crate) trait InitProtocol: Send + Sync {
+ type T;
+ /// Initialize the protocol
+ async fn init(self: Arc<Self>) -> Self::T;
+}
diff --git a/p2p/src/protocols/handshake.rs b/p2p/src/protocols/handshake.rs
new file mode 100644
index 0000000..b3fe989
--- /dev/null
+++ b/p2p/src/protocols/handshake.rs
@@ -0,0 +1,139 @@
+use std::{collections::HashMap, sync::Arc, time::Duration};
+
+use async_trait::async_trait;
+use log::trace;
+
+use karyon_core::{async_util::timeout, util::decode};
+
+use crate::{
+ message::{NetMsg, NetMsgCmd, VerAckMsg, VerMsg},
+ peer::Peer,
+ protocol::{InitProtocol, ProtocolID},
+ version::{version_match, VersionInt},
+ Error, PeerID, Result, Version,
+};
+
+pub struct HandshakeProtocol {
+ peer: Arc<Peer>,
+ protocols: HashMap<ProtocolID, Version>,
+}
+
+#[async_trait]
+impl InitProtocol for HandshakeProtocol {
+ type T = Result<PeerID>;
+ /// Initiate a handshake with a connection.
+ async fn init(self: Arc<Self>) -> Self::T {
+ trace!("Init Handshake: {}", self.peer.remote_endpoint());
+
+ if !self.peer.is_inbound() {
+ self.send_vermsg().await?;
+ }
+
+ let t = Duration::from_secs(self.peer.config().handshake_timeout);
+ let msg: NetMsg = timeout(t, self.peer.conn.recv_inner()).await??;
+ match msg.header.command {
+ NetMsgCmd::Version => {
+ let result = self.validate_version_msg(&msg).await;
+ match result {
+ Ok(_) => {
+ self.send_verack(true).await?;
+ }
+ Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => {
+ self.send_verack(false).await?;
+ }
+ _ => {}
+ };
+ result
+ }
+ NetMsgCmd::Verack => self.validate_verack_msg(&msg).await,
+ cmd => Err(Error::InvalidMsg(format!("unexpected msg found {:?}", cmd))),
+ }
+ }
+}
+
+impl HandshakeProtocol {
+ pub fn new(peer: Arc<Peer>, protocols: HashMap<ProtocolID, Version>) -> Arc<Self> {
+ Arc::new(Self { peer, protocols })
+ }
+
+ /// Sends a Version message
+ async fn send_vermsg(&self) -> Result<()> {
+ let protocols = self
+ .protocols
+ .clone()
+ .into_iter()
+ .map(|p| (p.0, p.1.v))
+ .collect();
+
+ let vermsg = VerMsg {
+ peer_id: self.peer.own_id().clone(),
+ protocols,
+ version: self.peer.config().version.v.clone(),
+ };
+
+ trace!("Send VerMsg");
+ self.peer
+ .conn
+ .send_inner(NetMsg::new(NetMsgCmd::Version, &vermsg)?)
+ .await?;
+ Ok(())
+ }
+
+ /// Sends a Verack message
+ async fn send_verack(&self, ack: bool) -> Result<()> {
+ let verack = VerAckMsg {
+ peer_id: self.peer.own_id().clone(),
+ ack,
+ };
+
+ trace!("Send VerAckMsg {:?}", verack);
+ self.peer
+ .conn
+ .send_inner(NetMsg::new(NetMsgCmd::Verack, &verack)?)
+ .await?;
+ Ok(())
+ }
+
+ /// Validates the given version msg
+ async fn validate_version_msg(&self, msg: &NetMsg) -> Result<PeerID> {
+ let (vermsg, _) = decode::<VerMsg>(&msg.payload)?;
+
+ if !version_match(&self.peer.config().version.req, &vermsg.version) {
+ return Err(Error::IncompatibleVersion("system: {}".into()));
+ }
+
+ self.protocols_match(&vermsg.protocols).await?;
+
+ trace!("Received VerMsg from: {}", vermsg.peer_id);
+ Ok(vermsg.peer_id)
+ }
+
+ /// Validates the given verack msg
+ async fn validate_verack_msg(&self, msg: &NetMsg) -> Result<PeerID> {
+ let (verack, _) = decode::<VerAckMsg>(&msg.payload)?;
+
+ if !verack.ack {
+ return Err(Error::IncompatiblePeer);
+ }
+
+ trace!("Received VerAckMsg from: {}", verack.peer_id);
+ Ok(verack.peer_id)
+ }
+
+ /// Check if the new connection has compatible protocols.
+ async fn protocols_match(&self, protocols: &HashMap<ProtocolID, VersionInt>) -> Result<()> {
+ for (n, pv) in protocols.iter() {
+ match self.protocols.get(n) {
+ Some(v) => {
+ if !version_match(&v.req, pv) {
+ return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}")));
+ }
+ }
+ None => {
+ return Err(Error::UnsupportedProtocol(n.to_string()));
+ }
+ }
+ }
+ Ok(())
+ }
+}
diff --git a/p2p/src/protocols/mod.rs b/p2p/src/protocols/mod.rs
index 4a8f6b9..c58df03 100644
--- a/p2p/src/protocols/mod.rs
+++ b/p2p/src/protocols/mod.rs
@@ -1,3 +1,5 @@
+mod handshake;
mod ping;
-pub use ping::PingProtocol;
+pub(crate) use handshake::HandshakeProtocol;
+pub(crate) use ping::PingProtocol;
diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs
index b800b23..f35b203 100644
--- a/p2p/src/protocols/ping.rs
+++ b/p2p/src/protocols/ping.rs
@@ -9,7 +9,6 @@ use rand::{rngs::OsRng, RngCore};
use karyon_core::{
async_runtime::Executor,
async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult},
- event::EventListener,
util::decode,
};
@@ -39,9 +38,12 @@ pub struct PingProtocol {
impl PingProtocol {
#[allow(clippy::new_ret_no_self)]
- pub fn new(peer: Arc<Peer>, executor: Executor) -> Arc<dyn Protocol> {
- let ping_interval = peer.config().ping_interval;
- let ping_timeout = peer.config().ping_timeout;
+ pub fn new(
+ peer: Arc<Peer>,
+ ping_interval: u64,
+ ping_timeout: u64,
+ executor: Executor,
+ ) -> Arc<dyn Protocol> {
Arc::new(Self {
peer,
ping_interval,
@@ -50,13 +52,9 @@ impl PingProtocol {
})
}
- async fn recv_loop(
- &self,
- listener: &EventListener<ProtocolID, ProtocolEvent>,
- pong_chan: Sender<[u8; 32]>,
- ) -> Result<()> {
+ async fn recv_loop(&self, pong_chan: Sender<[u8; 32]>) -> Result<()> {
loop {
- let event = listener.recv().await?;
+ let event = self.peer.recv::<Self>().await?;
let msg_payload = match event.clone() {
ProtocolEvent::Message(m) => m,
ProtocolEvent::Shutdown => {
@@ -70,7 +68,7 @@ impl PingProtocol {
PingProtocolMsg::Ping(nonce) => {
trace!("Received Ping message {:?}", nonce);
self.peer
- .send(&Self::id(), &PingProtocolMsg::Pong(nonce))
+ .send(Self::id(), &PingProtocolMsg::Pong(nonce))
.await?;
trace!("Send back Pong message {:?}", nonce);
}
@@ -82,7 +80,7 @@ impl PingProtocol {
Ok(())
}
- async fn ping_loop(self: Arc<Self>, chan: Receiver<[u8; 32]>) -> Result<()> {
+ async fn ping_loop(&self, chan: Receiver<[u8; 32]>) -> Result<()> {
let rng = &mut OsRng;
let mut retry = 0;
@@ -94,12 +92,11 @@ impl PingProtocol {
trace!("Send Ping message {:?}", ping_nonce);
self.peer
- .send(&Self::id(), &PingProtocolMsg::Ping(ping_nonce))
+ .send(Self::id(), &PingProtocolMsg::Ping(ping_nonce))
.await?;
- let d = Duration::from_secs(self.ping_timeout);
-
// Wait for Pong message
+ let d = Duration::from_secs(self.ping_timeout);
let pong_msg = match timeout(d, chan.recv()).await {
Ok(m) => m?,
Err(_) => {
@@ -107,13 +104,14 @@ impl PingProtocol {
continue;
}
};
-
trace!("Received Pong message {:?}", pong_msg);
if pong_msg != ping_nonce {
retry += 1;
continue;
}
+
+ retry = 0;
}
Err(NetError::Timeout.into())
@@ -125,8 +123,8 @@ impl Protocol for PingProtocol {
async fn start(self: Arc<Self>) -> Result<()> {
trace!("Start Ping protocol");
+ let stop_signal = async_channel::bounded::<Result<()>>(1);
let (pong_chan, pong_chan_recv) = async_channel::bounded(1);
- let (stop_signal_s, stop_signal) = async_channel::bounded::<Result<()>>(1);
self.task_group.spawn(
{
@@ -135,15 +133,12 @@ impl Protocol for PingProtocol {
},
|res| async move {
if let TaskResult::Completed(result) = res {
- let _ = stop_signal_s.send(result).await;
+ let _ = stop_signal.0.send(result).await;
}
},
);
- let listener = self.peer.register_listener::<Self>().await;
-
- let result = select(self.recv_loop(&listener, pong_chan), stop_signal.recv()).await;
- listener.cancel().await;
+ let result = select(self.recv_loop(pong_chan), stop_signal.1.recv()).await;
self.task_group.cancel().await;
match result {