From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- p2p/src/peer_pool.rs | 65 ++++++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 30 deletions(-) (limited to 'p2p/src/peer_pool.rs') diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index 4e20c99..8b16ef5 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -4,21 +4,22 @@ use std::{ time::Duration, }; +use async_channel::Sender; +use bincode::{Decode, Encode}; use log::{error, info, trace, warn}; -use smol::{ - channel::Sender, - lock::{Mutex, RwLock}, -}; use karyon_core::{ - async_util::{Executor, TaskGroup, TaskResult}, + async_runtime::{ + lock::{Mutex, RwLock}, + Executor, + }, + async_util::{timeout, TaskGroup, TaskResult}, util::decode, }; use karyon_net::Conn; use crate::{ - codec::{Codec, CodecMsg}, config::Config, connection::{ConnDirection, ConnQueue}, message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, @@ -50,10 +51,10 @@ pub struct PeerPool { protocol_versions: Arc>>, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// A global Executor - executor: Executor<'static>, + executor: Executor, /// The Configuration for the P2P network. pub(crate) config: Arc, @@ -69,7 +70,7 @@ impl PeerPool { conn_queue: Arc, config: Arc, monitor: Arc, - executor: Executor<'static>, + executor: Executor, ) -> Arc { let protocols = RwLock::new(HashMap::new()); let protocol_versions = Arc::new(RwLock::new(HashMap::new())); @@ -137,7 +138,7 @@ impl PeerPool { } /// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { + pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { for (pid, peer) in self.peers.lock().await.iter() { if let Err(err) = peer.send(proto_id, msg).await { error!("failed to send msg to {pid}: {err}"); @@ -149,15 +150,14 @@ impl PeerPool { /// Add a new peer to the peer list. pub async fn new_peer( self: &Arc, - conn: Conn, + conn: Conn, conn_direction: &ConnDirection, disconnect_signal: Sender>, ) -> Result<()> { let endpoint = conn.peer_endpoint()?; - let codec = Codec::new(conn); // Do a handshake with the connection before creating a new peer. - let pid = self.do_handshake(&codec, conn_direction).await?; + let pid = self.do_handshake(&conn, conn_direction).await?; // TODO: Consider restricting the subnet for inbound connections if self.contains_peer(&pid).await { @@ -168,7 +168,7 @@ impl PeerPool { let peer = Peer::new( Arc::downgrade(self), &pid, - codec, + conn, endpoint.clone(), conn_direction.clone(), self.executor.clone(), @@ -234,16 +234,21 @@ impl PeerPool { } /// Initiate a handshake with a connection. - async fn do_handshake(&self, codec: &Codec, conn_direction: &ConnDirection) -> Result { + async fn do_handshake( + &self, + conn: &Conn, + conn_direction: &ConnDirection, + ) -> Result { + trace!("Handshake started: {}", conn.peer_endpoint()?); match conn_direction { ConnDirection::Inbound => { - let result = self.wait_vermsg(codec).await; + let result = self.wait_vermsg(conn).await; match result { Ok(_) => { - self.send_verack(codec, true).await?; + self.send_verack(conn, true).await?; } Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { - self.send_verack(codec, false).await?; + self.send_verack(conn, false).await?; } _ => {} } @@ -251,14 +256,14 @@ impl PeerPool { } ConnDirection::Outbound => { - self.send_vermsg(codec).await?; - self.wait_verack(codec).await + self.send_vermsg(conn).await?; + self.wait_verack(conn).await } } } /// Send a Version message - async fn send_vermsg(&self, codec: &Codec) -> Result<()> { + async fn send_vermsg(&self, conn: &Conn) -> Result<()> { let pids = self.protocol_versions.read().await; let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); drop(pids); @@ -270,16 +275,16 @@ impl PeerPool { }; trace!("Send VerMsg"); - codec.write(NetMsgCmd::Version, &vermsg).await?; + conn.send(NetMsg::new(NetMsgCmd::Version, &vermsg)?).await?; Ok(()) } /// Wait for a Version message /// /// Returns the peer's ID upon successfully receiving the Version message. - async fn wait_vermsg(&self, codec: &Codec) -> Result { - let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = codec.read_timeout(timeout).await?; + async fn wait_vermsg(&self, conn: &Conn) -> Result { + let t = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Version, msg); let (vermsg, _) = decode::(&payload)?; @@ -295,23 +300,23 @@ impl PeerPool { } /// Send a Verack message - async fn send_verack(&self, codec: &Codec, ack: bool) -> Result<()> { + async fn send_verack(&self, conn: &Conn, ack: bool) -> Result<()> { let verack = VerAckMsg { peer_id: self.id.clone(), ack, }; trace!("Send VerAckMsg {:?}", verack); - codec.write(NetMsgCmd::Verack, &verack).await?; + conn.send(NetMsg::new(NetMsgCmd::Verack, &verack)?).await?; Ok(()) } /// Wait for a Verack message /// /// Returns the peer's ID upon successfully receiving the Verack message. - async fn wait_verack(&self, codec: &Codec) -> Result { - let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = codec.read_timeout(timeout).await?; + async fn wait_verack(&self, conn: &Conn) -> Result { + let t = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Verack, msg); let (verack, _) = decode::(&payload)?; -- cgit v1.2.3