aboutsummaryrefslogtreecommitdiff
path: root/p2p/src/peer_pool.rs
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-04-11 10:19:20 +0200
committerhozan23 <hozan23@karyontech.net>2024-05-19 13:51:30 +0200
commit0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch)
tree961d73218af672797d49f899289bef295bc56493 /p2p/src/peer_pool.rs
parenta69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff)
add support for tokio & improve net crate api
Diffstat (limited to 'p2p/src/peer_pool.rs')
-rw-r--r--p2p/src/peer_pool.rs65
1 files changed, 35 insertions, 30 deletions
diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs
index 4e20c99..8b16ef5 100644
--- a/p2p/src/peer_pool.rs
+++ b/p2p/src/peer_pool.rs
@@ -4,21 +4,22 @@ use std::{
time::Duration,
};
+use async_channel::Sender;
+use bincode::{Decode, Encode};
use log::{error, info, trace, warn};
-use smol::{
- channel::Sender,
- lock::{Mutex, RwLock},
-};
use karyon_core::{
- async_util::{Executor, TaskGroup, TaskResult},
+ async_runtime::{
+ lock::{Mutex, RwLock},
+ Executor,
+ },
+ async_util::{timeout, TaskGroup, TaskResult},
util::decode,
};
use karyon_net::Conn;
use crate::{
- codec::{Codec, CodecMsg},
config::Config,
connection::{ConnDirection, ConnQueue},
message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg},
@@ -50,10 +51,10 @@ pub struct PeerPool {
protocol_versions: Arc<RwLock<HashMap<ProtocolID, Version>>>,
/// Managing spawned tasks.
- task_group: TaskGroup<'static>,
+ task_group: TaskGroup,
/// A global Executor
- executor: Executor<'static>,
+ executor: Executor,
/// The Configuration for the P2P network.
pub(crate) config: Arc<Config>,
@@ -69,7 +70,7 @@ impl PeerPool {
conn_queue: Arc<ConnQueue>,
config: Arc<Config>,
monitor: Arc<Monitor>,
- executor: Executor<'static>,
+ executor: Executor,
) -> Arc<Self> {
let protocols = RwLock::new(HashMap::new());
let protocol_versions = Arc::new(RwLock::new(HashMap::new()));
@@ -137,7 +138,7 @@ impl PeerPool {
}
/// Broadcast a message to all connected peers using the specified protocol.
- pub async fn broadcast<T: CodecMsg>(&self, proto_id: &ProtocolID, msg: &T) {
+ pub async fn broadcast<T: Decode + Encode>(&self, proto_id: &ProtocolID, msg: &T) {
for (pid, peer) in self.peers.lock().await.iter() {
if let Err(err) = peer.send(proto_id, msg).await {
error!("failed to send msg to {pid}: {err}");
@@ -149,15 +150,14 @@ impl PeerPool {
/// Add a new peer to the peer list.
pub async fn new_peer(
self: &Arc<Self>,
- conn: Conn,
+ conn: Conn<NetMsg>,
conn_direction: &ConnDirection,
disconnect_signal: Sender<Result<()>>,
) -> Result<()> {
let endpoint = conn.peer_endpoint()?;
- let codec = Codec::new(conn);
// Do a handshake with the connection before creating a new peer.
- let pid = self.do_handshake(&codec, conn_direction).await?;
+ let pid = self.do_handshake(&conn, conn_direction).await?;
// TODO: Consider restricting the subnet for inbound connections
if self.contains_peer(&pid).await {
@@ -168,7 +168,7 @@ impl PeerPool {
let peer = Peer::new(
Arc::downgrade(self),
&pid,
- codec,
+ conn,
endpoint.clone(),
conn_direction.clone(),
self.executor.clone(),
@@ -234,16 +234,21 @@ impl PeerPool {
}
/// Initiate a handshake with a connection.
- async fn do_handshake(&self, codec: &Codec, conn_direction: &ConnDirection) -> Result<PeerID> {
+ async fn do_handshake(
+ &self,
+ conn: &Conn<NetMsg>,
+ conn_direction: &ConnDirection,
+ ) -> Result<PeerID> {
+ trace!("Handshake started: {}", conn.peer_endpoint()?);
match conn_direction {
ConnDirection::Inbound => {
- let result = self.wait_vermsg(codec).await;
+ let result = self.wait_vermsg(conn).await;
match result {
Ok(_) => {
- self.send_verack(codec, true).await?;
+ self.send_verack(conn, true).await?;
}
Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => {
- self.send_verack(codec, false).await?;
+ self.send_verack(conn, false).await?;
}
_ => {}
}
@@ -251,14 +256,14 @@ impl PeerPool {
}
ConnDirection::Outbound => {
- self.send_vermsg(codec).await?;
- self.wait_verack(codec).await
+ self.send_vermsg(conn).await?;
+ self.wait_verack(conn).await
}
}
}
/// Send a Version message
- async fn send_vermsg(&self, codec: &Codec) -> Result<()> {
+ async fn send_vermsg(&self, conn: &Conn<NetMsg>) -> Result<()> {
let pids = self.protocol_versions.read().await;
let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect();
drop(pids);
@@ -270,16 +275,16 @@ impl PeerPool {
};
trace!("Send VerMsg");
- codec.write(NetMsgCmd::Version, &vermsg).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Version, &vermsg)?).await?;
Ok(())
}
/// Wait for a Version message
///
/// Returns the peer's ID upon successfully receiving the Version message.
- async fn wait_vermsg(&self, codec: &Codec) -> Result<PeerID> {
- let timeout = Duration::from_secs(self.config.handshake_timeout);
- let msg: NetMsg = codec.read_timeout(timeout).await?;
+ async fn wait_vermsg(&self, conn: &Conn<NetMsg>) -> Result<PeerID> {
+ let t = Duration::from_secs(self.config.handshake_timeout);
+ let msg: NetMsg = timeout(t, conn.recv()).await??;
let payload = get_msg_payload!(Version, msg);
let (vermsg, _) = decode::<VerMsg>(&payload)?;
@@ -295,23 +300,23 @@ impl PeerPool {
}
/// Send a Verack message
- async fn send_verack(&self, codec: &Codec, ack: bool) -> Result<()> {
+ async fn send_verack(&self, conn: &Conn<NetMsg>, ack: bool) -> Result<()> {
let verack = VerAckMsg {
peer_id: self.id.clone(),
ack,
};
trace!("Send VerAckMsg {:?}", verack);
- codec.write(NetMsgCmd::Verack, &verack).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Verack, &verack)?).await?;
Ok(())
}
/// Wait for a Verack message
///
/// Returns the peer's ID upon successfully receiving the Verack message.
- async fn wait_verack(&self, codec: &Codec) -> Result<PeerID> {
- let timeout = Duration::from_secs(self.config.handshake_timeout);
- let msg: NetMsg = codec.read_timeout(timeout).await?;
+ async fn wait_verack(&self, conn: &Conn<NetMsg>) -> Result<PeerID> {
+ let t = Duration::from_secs(self.config.handshake_timeout);
+ let msg: NetMsg = timeout(t, conn.recv()).await??;
let payload = get_msg_payload!(Verack, msg);
let (verack, _) = decode::<VerAckMsg>(&payload)?;