aboutsummaryrefslogtreecommitdiff
path: root/p2p/src
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-04-11 10:19:20 +0200
committerhozan23 <hozan23@karyontech.net>2024-05-19 13:51:30 +0200
commit0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch)
tree961d73218af672797d49f899289bef295bc56493 /p2p/src
parenta69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff)
add support for tokio & improve net crate api
Diffstat (limited to 'p2p/src')
-rw-r--r--p2p/src/backend.rs4
-rw-r--r--p2p/src/codec.rs149
-rw-r--r--p2p/src/connection.rs12
-rw-r--r--p2p/src/connector.rs35
-rw-r--r--p2p/src/discovery/lookup.rs71
-rw-r--r--p2p/src/discovery/mod.rs45
-rw-r--r--p2p/src/discovery/refresh.rs92
-rw-r--r--p2p/src/error.rs18
-rw-r--r--p2p/src/lib.rs4
-rw-r--r--p2p/src/listener.rs36
-rw-r--r--p2p/src/message.rs58
-rw-r--r--p2p/src/monitor.rs2
-rw-r--r--p2p/src/peer/mod.rs42
-rw-r--r--p2p/src/peer_pool.rs65
-rw-r--r--p2p/src/protocol.rs2
-rw-r--r--p2p/src/protocols/ping.rs23
-rw-r--r--p2p/src/routing_table/entry.rs8
-rw-r--r--p2p/src/routing_table/mod.rs4
-rw-r--r--p2p/src/tls_config.rs12
19 files changed, 344 insertions, 338 deletions
diff --git a/p2p/src/backend.rs b/p2p/src/backend.rs
index d33f3dc..2f21b3e 100644
--- a/p2p/src/backend.rs
+++ b/p2p/src/backend.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
use log::info;
-use karyon_core::{async_util::Executor, crypto::KeyPair, pubsub::Subscription};
+use karyon_core::{async_runtime::Executor, crypto::KeyPair, pubsub::Subscription};
use crate::{
config::Config,
@@ -37,7 +37,7 @@ pub struct Backend {
impl Backend {
/// Creates a new Backend.
- pub fn new(key_pair: &KeyPair, config: Config, ex: Executor<'static>) -> ArcBackend {
+ pub fn new(key_pair: &KeyPair, config: Config, ex: Executor) -> ArcBackend {
let config = Arc::new(config);
let monitor = Arc::new(Monitor::new());
let conn_queue = ConnQueue::new();
diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs
index 726a2f7..3d0f323 100644
--- a/p2p/src/codec.rs
+++ b/p2p/src/codec.rs
@@ -1,120 +1,69 @@
-use std::time::Duration;
+use karyon_core::util::{decode, encode, encode_into_slice};
-use bincode::{Decode, Encode};
-
-use karyon_core::{
- async_util::timeout,
- util::{decode, encode, encode_into_slice},
-};
-
-use karyon_net::{Connection, NetError};
-
-use crate::{
- message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE},
- Error, Result,
+use karyon_net::{
+ codec::{Codec, Decoder, Encoder, LengthCodec},
+ Result,
};
-pub trait CodecMsg: Decode + Encode + std::fmt::Debug {}
-impl<T: Encode + Decode + std::fmt::Debug> CodecMsg for T {}
+use crate::message::{NetMsg, RefreshMsg};
-/// A Codec working with generic network connections.
-///
-/// It is responsible for both decoding data received from the network and
-/// encoding data before sending it.
-pub struct Codec {
- conn: Box<dyn Connection>,
+#[derive(Clone)]
+pub struct NetMsgCodec {
+ inner_codec: LengthCodec,
}
-impl Codec {
- /// Creates a new Codec.
- pub fn new(conn: Box<dyn Connection>) -> Self {
- Self { conn }
- }
-
- /// Reads a message of type `NetMsg` from the connection.
- ///
- /// It reads the first 6 bytes as the header of the message, then reads
- /// and decodes the remaining message data based on the determined header.
- pub async fn read(&self) -> Result<NetMsg> {
- // Read 6 bytes to get the header of the incoming message
- let mut buf = [0; MSG_HEADER_SIZE];
- self.read_exact(&mut buf).await?;
-
- // Decode the header from bytes to NetMsgHeader
- let (header, _) = decode::<NetMsgHeader>(&buf)?;
-
- if header.payload_size > MAX_ALLOWED_MSG_SIZE {
- return Err(Error::InvalidMsg(
- "Message exceeds the maximum allowed size".to_string(),
- ));
+impl NetMsgCodec {
+ pub fn new() -> Self {
+ Self {
+ inner_codec: LengthCodec {},
}
-
- // Create a buffer to hold the message based on its length
- let mut payload = vec![0; header.payload_size as usize];
- self.read_exact(&mut payload).await?;
-
- Ok(NetMsg { header, payload })
}
+}
- /// Writes a message of type `T` to the connection.
- ///
- /// Before appending the actual message payload, it calculates the length of
- /// the encoded message in bytes and appends this length to the message header.
- pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> {
- let payload = encode(msg)?;
-
- // Create a buffer to hold the message header (6 bytes)
- let header_buf = &mut [0; MSG_HEADER_SIZE];
- let header = NetMsgHeader {
- command,
- payload_size: payload.len() as u32,
- };
- encode_into_slice(&header, header_buf)?;
-
- let mut buffer = vec![];
- // Append the header bytes to the buffer
- buffer.extend_from_slice(header_buf);
- // Append the message payload to the buffer
- buffer.extend_from_slice(&payload);
-
- self.write_all(&buffer).await?;
- Ok(())
- }
+impl Codec for NetMsgCodec {
+ type Item = NetMsg;
+}
- /// Reads a message of type `NetMsg` with the given timeout.
- pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> {
- timeout(duration, self.read())
- .await
- .map_err(|_| NetError::Timeout)?
+impl Encoder for NetMsgCodec {
+ type EnItem = NetMsg;
+ fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> {
+ let src = encode(src)?;
+ self.inner_codec.encode(&src, dst)
}
+}
- /// Reads the exact number of bytes required to fill `buf`.
- async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> {
- while !buf.is_empty() {
- let n = self.conn.read(buf).await?;
- let (_, rest) = std::mem::take(&mut buf).split_at_mut(n);
- buf = rest;
-
- if n == 0 {
- return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
+impl Decoder for NetMsgCodec {
+ type DeItem = NetMsg;
+ fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> {
+ match self.inner_codec.decode(src)? {
+ Some((n, s)) => {
+ let (m, _) = decode::<Self::DeItem>(&s)?;
+ Ok(Some((n, m)))
}
+ None => Ok(None),
}
-
- Ok(())
}
+}
- /// Writes an entire buffer into the connection.
- async fn write_all(&self, mut buf: &[u8]) -> Result<()> {
- while !buf.is_empty() {
- let n = self.conn.write(buf).await?;
- let (_, rest) = std::mem::take(&mut buf).split_at(n);
- buf = rest;
+#[derive(Clone)]
+pub struct RefreshMsgCodec {}
- if n == 0 {
- return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
- }
- }
+impl Codec for RefreshMsgCodec {
+ type Item = RefreshMsg;
+}
+
+impl Encoder for RefreshMsgCodec {
+ type EnItem = RefreshMsg;
+ fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> {
+ let n = encode_into_slice(src, dst)?;
+ Ok(n)
+ }
+}
- Ok(())
+impl Decoder for RefreshMsgCodec {
+ type DeItem = RefreshMsg;
+ fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> {
+ let (m, n) = decode::<Self::DeItem>(src)?;
+ Ok(Some((n, m)))
}
}
diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs
index 9fa57cb..9a153f3 100644
--- a/p2p/src/connection.rs
+++ b/p2p/src/connection.rs
@@ -1,11 +1,11 @@
use std::{collections::VecDeque, fmt, sync::Arc};
-use smol::{channel::Sender, lock::Mutex};
+use async_channel::Sender;
-use karyon_core::async_util::CondVar;
+use karyon_core::{async_runtime::lock::Mutex, async_util::CondVar};
use karyon_net::Conn;
-use crate::Result;
+use crate::{message::NetMsg, Result};
/// Defines the direction of a network connection.
#[derive(Clone, Debug)]
@@ -25,7 +25,7 @@ impl fmt::Display for ConnDirection {
pub struct NewConn {
pub direction: ConnDirection,
- pub conn: Conn,
+ pub conn: Conn<NetMsg>,
pub disconnect_signal: Sender<Result<()>>,
}
@@ -44,8 +44,8 @@ impl ConnQueue {
}
/// Push a connection into the queue and wait for the disconnect signal
- pub async fn handle(&self, conn: Conn, direction: ConnDirection) -> Result<()> {
- let (disconnect_signal, chan) = smol::channel::bounded(1);
+ pub async fn handle(&self, conn: Conn<NetMsg>, direction: ConnDirection) -> Result<()> {
+ let (disconnect_signal, chan) = async_channel::bounded(1);
let new_conn = NewConn {
direction,
conn,
diff --git a/p2p/src/connector.rs b/p2p/src/connector.rs
index de9e746..aea21ab 100644
--- a/p2p/src/connector.rs
+++ b/p2p/src/connector.rs
@@ -3,12 +3,15 @@ use std::{future::Future, sync::Arc};
use log::{error, trace, warn};
use karyon_core::{
- async_util::{Backoff, Executor, TaskGroup, TaskResult},
+ async_runtime::Executor,
+ async_util::{Backoff, TaskGroup, TaskResult},
crypto::KeyPair,
};
-use karyon_net::{tcp, tls, Conn, Endpoint, NetError};
+use karyon_net::{tcp, tls, Conn, Endpoint, Error as NetError};
use crate::{
+ codec::NetMsgCodec,
+ message::NetMsg,
monitor::{ConnEvent, Monitor},
slots::ConnectionSlots,
tls_config::tls_client_config,
@@ -23,7 +26,7 @@ pub struct Connector {
key_pair: KeyPair,
/// Managing spawned tasks.
- task_group: TaskGroup<'static>,
+ task_group: TaskGroup,
/// Manages available outbound slots.
connection_slots: Arc<ConnectionSlots>,
@@ -47,7 +50,7 @@ impl Connector {
connection_slots: Arc<ConnectionSlots>,
enable_tls: bool,
monitor: Arc<Monitor>,
- ex: Executor<'static>,
+ ex: Executor,
) -> Arc<Self> {
Arc::new(Self {
key_pair: key_pair.clone(),
@@ -70,7 +73,11 @@ impl Connector {
/// `Conn` instance.
///
/// This method will block until it finds an available slot.
- pub async fn connect(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<Conn> {
+ pub async fn connect(
+ &self,
+ endpoint: &Endpoint,
+ peer_id: &Option<PeerID>,
+ ) -> Result<Conn<NetMsg>> {
self.connection_slots.wait_for_slot().await;
self.connection_slots.add();
@@ -113,7 +120,7 @@ impl Connector {
self: &Arc<Self>,
endpoint: &Endpoint,
peer_id: &Option<PeerID>,
- callback: impl FnOnce(Conn) -> Fut + Send + 'static,
+ callback: impl FnOnce(Conn<NetMsg>) -> Fut + Send + 'static,
) -> Result<()>
where
Fut: Future<Output = Result<()>> + Send + 'static,
@@ -138,14 +145,20 @@ impl Connector {
Ok(())
}
- async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<Conn> {
+ async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<Conn<NetMsg>> {
if self.enable_tls {
- let tls_config = tls_client_config(&self.key_pair, peer_id.clone())?;
- tls::dial(endpoint, tls_config, DNS_NAME)
+ let tls_config = tls::ClientTlsConfig {
+ tcp_config: Default::default(),
+ client_config: tls_client_config(&self.key_pair, peer_id.clone())?,
+ dns_name: DNS_NAME.to_string(),
+ };
+ tls::dial(endpoint, tls_config, NetMsgCodec::new())
.await
- .map(|l| Box::new(l) as Conn)
+ .map(|l| Box::new(l) as karyon_net::Conn<NetMsg>)
} else {
- tcp::dial(endpoint).await.map(|l| Box::new(l) as Conn)
+ tcp::dial(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new())
+ .await
+ .map(|l| Box::new(l) as karyon_net::Conn<NetMsg>)
}
.map_err(Error::KaryonNet)
}
diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs
index c81fbc6..cff4610 100644
--- a/p2p/src/discovery/lookup.rs
+++ b/p2p/src/discovery/lookup.rs
@@ -3,10 +3,13 @@ use std::{sync::Arc, time::Duration};
use futures_util::{stream::FuturesUnordered, StreamExt};
use log::{error, trace};
use rand::{rngs::OsRng, seq::SliceRandom, RngCore};
-use smol::lock::{Mutex, RwLock};
use karyon_core::{
- async_util::{timeout, Executor},
+ async_runtime::{
+ lock::{Mutex, RwLock},
+ Executor,
+ },
+ async_util::timeout,
crypto::KeyPair,
util::decode,
};
@@ -14,7 +17,6 @@ use karyon_core::{
use karyon_net::{Conn, Endpoint};
use crate::{
- codec::Codec,
connector::Connector,
listener::Listener,
message::{
@@ -64,7 +66,7 @@ impl LookupService {
table: Arc<Mutex<RoutingTable>>,
config: Arc<Config>,
monitor: Arc<Monitor>,
- ex: Executor<'static>,
+ ex: Executor,
) -> Self {
let inbound_slots = Arc::new(ConnectionSlots::new(config.lookup_inbound_slots));
let outbound_slots = Arc::new(ConnectionSlots::new(config.lookup_outbound_slots));
@@ -228,8 +230,7 @@ impl LookupService {
target_peer_id: &PeerID,
) -> Result<Vec<PeerMsg>> {
let conn = self.connector.connect(&endpoint, &peer_id).await?;
- let io_codec = Codec::new(conn);
- let result = self.handle_outbound(io_codec, target_peer_id).await;
+ let result = self.handle_outbound(conn, target_peer_id).await;
self.monitor
.notify(&ConnEvent::Disconnected(endpoint).into())
@@ -242,14 +243,14 @@ impl LookupService {
/// Handles outbound connection
async fn handle_outbound(
&self,
- io_codec: Codec,
+ conn: Conn<NetMsg>,
target_peer_id: &PeerID,
) -> Result<Vec<PeerMsg>> {
trace!("Send Ping msg");
- self.send_ping_msg(&io_codec).await?;
+ self.send_ping_msg(&conn).await?;
trace!("Send FindPeer msg");
- let peers = self.send_findpeer_msg(&io_codec, target_peer_id).await?;
+ let peers = self.send_findpeer_msg(&conn, target_peer_id).await?;
if peers.0.len() >= MAX_PEERS_IN_PEERSMSG {
return Err(Error::Lookup("Received too many peers in PeersMsg"));
@@ -257,12 +258,12 @@ impl LookupService {
trace!("Send Peer msg");
if let Some(endpoint) = &self.listen_endpoint {
- self.send_peer_msg(&io_codec, endpoint.read().await.clone())
+ self.send_peer_msg(&conn, endpoint.read().await.clone())
.await?;
}
trace!("Send Shutdown msg");
- self.send_shutdown_msg(&io_codec).await?;
+ self.send_shutdown_msg(&conn).await?;
Ok(peers.0)
}
@@ -277,7 +278,7 @@ impl LookupService {
let endpoint = Endpoint::Tcp(addr, self.config.discovery_port);
let selfc = self.clone();
- let callback = |conn: Conn| async move {
+ let callback = |conn: Conn<NetMsg>| async move {
let t = Duration::from_secs(selfc.config.lookup_connection_lifespan);
timeout(t, selfc.handle_inbound(conn)).await??;
Ok(())
@@ -288,10 +289,9 @@ impl LookupService {
}
/// Handles inbound connection
- async fn handle_inbound(self: &Arc<Self>, conn: Conn) -> Result<()> {
- let io_codec = Codec::new(conn);
+ async fn handle_inbound(self: &Arc<Self>, conn: Conn<NetMsg>) -> Result<()> {
loop {
- let msg: NetMsg = io_codec.read().await?;
+ let msg: NetMsg = conn.recv().await?;
trace!("Receive msg {:?}", msg.header.command);
if let NetMsgCmd::Shutdown = msg.header.command {
@@ -304,12 +304,12 @@ impl LookupService {
if !version_match(&self.config.version.req, &ping_msg.version) {
return Err(Error::IncompatibleVersion("system: {}".into()));
}
- self.send_pong_msg(ping_msg.nonce, &io_codec).await?;
+ self.send_pong_msg(ping_msg.nonce, &conn).await?;
}
NetMsgCmd::FindPeer => {
let (findpeer_msg, _) = decode::<FindPeerMsg>(&msg.payload)?;
let peer_id = findpeer_msg.0;
- self.send_peers_msg(&peer_id, &io_codec).await?;
+ self.send_peers_msg(&peer_id, &conn).await?;
}
NetMsgCmd::Peer => {
let (peer, _) = decode::<PeerMsg>(&msg.payload)?;
@@ -322,7 +322,7 @@ impl LookupService {
}
/// Sends a Ping msg and wait to receive the Pong message.
- async fn send_ping_msg(&self, io_codec: &Codec) -> Result<()> {
+ async fn send_ping_msg(&self, conn: &Conn<NetMsg>) -> Result<()> {
trace!("Send Pong msg");
let mut nonce: [u8; 32] = [0; 32];
@@ -332,10 +332,10 @@ impl LookupService {
version: self.config.version.v.clone(),
nonce,
};
- io_codec.write(NetMsgCmd::Ping, &ping_msg).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Ping, &ping_msg)?).await?;
let t = Duration::from_secs(self.config.lookup_response_timeout);
- let recv_msg: NetMsg = io_codec.read_timeout(t).await?;
+ let recv_msg: NetMsg = timeout(t, conn.recv()).await??;
let payload = get_msg_payload!(Pong, recv_msg);
let (pong_msg, _) = decode::<PongMsg>(&payload)?;
@@ -348,21 +348,24 @@ impl LookupService {
}
/// Sends a Pong msg
- async fn send_pong_msg(&self, nonce: [u8; 32], io_codec: &Codec) -> Result<()> {
+ async fn send_pong_msg(&self, nonce: [u8; 32], conn: &Conn<NetMsg>) -> Result<()> {
trace!("Send Pong msg");
- io_codec.write(NetMsgCmd::Pong, &PongMsg(nonce)).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Pong, &PongMsg(nonce))?)
+ .await?;
Ok(())
}
/// Sends a FindPeer msg and wait to receivet the Peers msg.
- async fn send_findpeer_msg(&self, io_codec: &Codec, peer_id: &PeerID) -> Result<PeersMsg> {
+ async fn send_findpeer_msg(&self, conn: &Conn<NetMsg>, peer_id: &PeerID) -> Result<PeersMsg> {
trace!("Send FindPeer msg");
- io_codec
- .write(NetMsgCmd::FindPeer, &FindPeerMsg(peer_id.clone()))
- .await?;
+ conn.send(NetMsg::new(
+ NetMsgCmd::FindPeer,
+ &FindPeerMsg(peer_id.clone()),
+ )?)
+ .await?;
let t = Duration::from_secs(self.config.lookup_response_timeout);
- let recv_msg: NetMsg = io_codec.read_timeout(t).await?;
+ let recv_msg: NetMsg = timeout(t, conn.recv()).await??;
let payload = get_msg_payload!(Peers, recv_msg);
let (peers, _) = decode(&payload)?;
@@ -371,19 +374,20 @@ impl LookupService {
}
/// Sends a Peers msg.
- async fn send_peers_msg(&self, peer_id: &PeerID, io_codec: &Codec) -> Result<()> {
+ async fn send_peers_msg(&self, peer_id: &PeerID, conn: &Conn<NetMsg>) -> Result<()> {
trace!("Send Peers msg");
let table = self.table.lock().await;
let entries = table.closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG);
drop(table);
let peers: Vec<PeerMsg> = entries.into_iter().map(|e| e.into()).collect();
- io_codec.write(NetMsgCmd::Peers, &PeersMsg(peers)).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Peers, &PeersMsg(peers))?)
+ .await?;
Ok(())
}
/// Sends a Peer msg.
- async fn send_peer_msg(&self, io_codec: &Codec, endpoint: Endpoint) -> Result<()> {
+ async fn send_peer_msg(&self, conn: &Conn<NetMsg>, endpoint: Endpoint) -> Result<()> {
trace!("Send Peer msg");
let peer_msg = PeerMsg {
addr: endpoint.addr()?.clone(),
@@ -391,14 +395,15 @@ impl LookupService {
discovery_port: self.config.discovery_port,
peer_id: self.id.clone(),
};
- io_codec.write(NetMsgCmd::Peer, &peer_msg).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Peer, &peer_msg)?).await?;
Ok(())
}
/// Sends a Shutdown msg.
- async fn send_shutdown_msg(&self, io_codec: &Codec) -> Result<()> {
+ async fn send_shutdown_msg(&self, conn: &Conn<NetMsg>) -> Result<()> {
trace!("Send Shutdown msg");
- io_codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Shutdown, &ShutdownMsg(0))?)
+ .await?;
Ok(())
}
}
diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs
index 3e437aa..19ae77a 100644
--- a/p2p/src/discovery/mod.rs
+++ b/p2p/src/discovery/mod.rs
@@ -5,10 +5,10 @@ use std::sync::Arc;
use log::{error, info};
use rand::{rngs::OsRng, seq::SliceRandom};
-use smol::lock::Mutex;
use karyon_core::{
- async_util::{Backoff, Executor, TaskGroup, TaskResult},
+ async_runtime::{lock::Mutex, Executor},
+ async_util::{Backoff, TaskGroup, TaskResult},
crypto::KeyPair,
};
@@ -19,6 +19,7 @@ use crate::{
connection::{ConnDirection, ConnQueue},
connector::Connector,
listener::Listener,
+ message::NetMsg,
monitor::Monitor,
routing_table::{
Entry, EntryStatusFlag, RoutingTable, CONNECTED_ENTRY, DISCONNECTED_ENTRY,
@@ -45,6 +46,7 @@ pub struct Discovery {
/// Connector
connector: Arc<Connector>,
+
/// Listener
listener: Arc<Listener>,
@@ -53,11 +55,12 @@ pub struct Discovery {
/// Inbound slots.
pub(crate) inbound_slots: Arc<ConnectionSlots>,
+
/// Outbound slots.
pub(crate) outbound_slots: Arc<ConnectionSlots>,
/// Managing spawned tasks.
- task_group: TaskGroup<'static>,
+ task_group: TaskGroup,
/// Holds the configuration for the P2P network.
config: Arc<Config>,
@@ -71,7 +74,7 @@ impl Discovery {
conn_queue: Arc<ConnQueue>,
config: Arc<Config>,
monitor: Arc<Monitor>,
- ex: Executor<'static>,
+ ex: Executor,
) -> ArcDiscovery {
let inbound_slots = Arc::new(ConnectionSlots::new(config.inbound_slots));
let outbound_slots = Arc::new(ConnectionSlots::new(config.outbound_slots));
@@ -180,7 +183,7 @@ impl Discovery {
/// Start a listener and on success, return the resolved endpoint.
async fn start_listener(self: &Arc<Self>, endpoint: &Endpoint) -> Result<Endpoint> {
let selfc = self.clone();
- let callback = |c: Conn| async move {
+ let callback = |c: Conn<NetMsg>| async move {
selfc.conn_queue.handle(c, ConnDirection::Inbound).await?;
Ok(())
};
@@ -198,8 +201,8 @@ impl Discovery {
async fn connect_loop(self: Arc<Self>) -> Result<()> {
let backoff = Backoff::new(500, self.config.seeding_interval * 1000);
loop {
- let random_entry = self.random_entry(PENDING_ENTRY).await;
- match random_entry {
+ let random_table_entry = self.random_table_entry(PENDING_ENTRY).await;
+ match random_table_entry {
Some(entry) => {
backoff.reset();
let endpoint = Endpoint::Tcp(entry.addr, entry.port);
@@ -218,7 +221,7 @@ impl Discovery {
let selfc = self.clone();
let pid_c = pid.clone();
let endpoint_c = endpoint.clone();
- let cback = |conn: Conn| async move {
+ let cback = |conn: Conn<NetMsg>| async move {
let result = selfc.conn_queue.handle(conn, ConnDirection::Outbound).await;
// If the entry is not in the routing table, ignore the result
@@ -230,17 +233,17 @@ impl Discovery {
match result {
Err(Error::IncompatiblePeer) => {
error!("Failed to do handshake: {endpoint_c} incompatible peer");
- selfc.update_entry(&pid, INCOMPATIBLE_ENTRY).await;
+ selfc.update_table_entry(&pid, INCOMPATIBLE_ENTRY).await;
}
Err(Error::PeerAlreadyConnected) => {
- // TODO: Use the appropriate status.
- selfc.update_entry(&pid, DISCONNECTED_ENTRY).await;
+ // TODO: Use an appropriate status.
+ selfc.update_table_entry(&pid, DISCONNECTED_ENTRY).await;
}
Err(_) => {
- selfc.update_entry(&pid, UNSTABLE_ENTRY).await;
+ selfc.update_table_entry(&pid, UNSTABLE_ENTRY).await;
}
Ok(_) => {
- selfc.update_entry(&pid, DISCONNECTED_ENTRY).await;
+ selfc.update_table_entry(&pid, DISCONNECTED_ENTRY).await;
}
}
@@ -255,10 +258,10 @@ impl Discovery {
if let Some(pid) = &pid {
match result {
Ok(_) => {
- self.update_entry(pid, CONNECTED_ENTRY).await;
+ self.update_table_entry(pid, CONNECTED_ENTRY).await;
}
Err(_) => {
- self.update_entry(pid, UNREACHABLE_ENTRY).await;
+ self.update_table_entry(pid, UNREACHABLE_ENTRY).await;
}
}
}
@@ -271,12 +274,16 @@ impl Discovery {
/// table doesn't have an available entry, it will connect to one of the
/// provided bootstrap endpoints in the `Config` and initiate the lookup.
async fn start_seeding(&self) {
- match self.random_entry(PENDING_ENTRY | CONNECTED_ENTRY).await {
+ match self
+ .random_table_entry(PENDING_ENTRY | CONNECTED_ENTRY)
+ .await
+ {
Some(entry) => {
let endpoint = Endpoint::Tcp(entry.addr, entry.discovery_port);
let peer_id = Some(entry.key.into());
if let Err(err) = self.lookup_service.start_lookup(&endpoint, peer_id).await {
- self.update_entry(&entry.key.into(), UNSTABLE_ENTRY).await;
+ self.update_table_entry(&entry.key.into(), UNSTABLE_ENTRY)
+ .await;
error!("Failed to do lookup: {endpoint}: {err}");
}
}
@@ -292,12 +299,12 @@ impl Discovery {
}
/// Returns a random entry from routing table.
- async fn random_entry(&self, entry_flag: EntryStatusFlag) -> Option<Entry> {
+ async fn random_table_entry(&self, entry_flag: EntryStatusFlag) -> Option<Entry> {
self.table.lock().await.random_entry(entry_flag).cloned()
}
/// Update the entry status
- async fn update_entry(&self, pid: &PeerID, entry_flag: EntryStatusFlag) {
+ async fn update_table_entry(&self, pid: &PeerID, entry_flag: EntryStatusFlag) {
let table = &mut self.table.lock().await;
table.update_entry(&pid.0, entry_flag);
}
diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs
index 035a581..0c49ac2 100644
--- a/p2p/src/discovery/refresh.rs
+++ b/p2p/src/discovery/refresh.rs
@@ -3,31 +3,28 @@ use std::{sync::Arc, time::Duration};
use bincode::{Decode, Encode};
use log::{error, info, trace};
use rand::{rngs::OsRng, RngCore};
-use smol::{
- lock::{Mutex, RwLock},
- stream::StreamExt,
- Timer,
-};
use karyon_core::{
- async_util::{timeout, Backoff, Executor, TaskGroup, TaskResult},
- util::{decode, encode},
+ async_runtime::{
+ lock::{Mutex, RwLock},
+ Executor,
+ },
+ async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult},
};
-use karyon_net::{udp, Connection, Endpoint, NetError};
-
-/// Maximum failures for an entry before removing it from the routing table.
-pub const MAX_FAILURES: u32 = 3;
-
-/// Ping message size
-const PINGMSG_SIZE: usize = 32;
+use karyon_net::{udp, Connection, Endpoint, Error as NetError};
use crate::{
+ codec::RefreshMsgCodec,
+ message::RefreshMsg,
monitor::{ConnEvent, DiscoveryEvent, Monitor},
routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY},
Config, Error, Result,
};
+/// Maximum failures for an entry before removing it from the routing table.
+pub const MAX_FAILURES: u32 = 3;
+
#[derive(Decode, Encode, Debug, Clone)]
pub struct PingMsg(pub [u8; 32]);
@@ -42,10 +39,10 @@ pub struct RefreshService {
listen_endpoint: Option<RwLock<Endpoint>>,
/// Managing spawned tasks.
- task_group: TaskGroup<'static>,
+ task_group: TaskGroup,
/// A global executor
- executor: Executor<'static>,
+ executor: Executor,
/// Holds the configuration for the P2P network.
config: Arc<Config>,
@@ -60,7 +57,7 @@ impl RefreshService {
config: Arc<Config>,
table: Arc<Mutex<RoutingTable>>,
monitor: Arc<Monitor>,
- executor: Executor<'static>,
+ executor: Executor,
) -> Self {
let listen_endpoint = config
.listen_endpoint
@@ -118,9 +115,8 @@ impl RefreshService {
/// selects the first 8 entries (oldest entries) from each bucket in the
/// routing table and starts sending Ping messages to the collected entries.
async fn refresh_loop(self: Arc<Self>) -> Result<()> {
- let mut timer = Timer::interval(Duration::from_secs(self.config.refresh_interval));
loop {
- timer.next().await;
+ sleep(Duration::from_secs(self.config.refresh_interval)).await;
trace!("Start refreshing the routing table...");
self.monitor
@@ -162,7 +158,7 @@ impl RefreshService {
}
for task in tasks {
- task.await;
+ let _ = task.await;
}
}
}
@@ -193,10 +189,10 @@ impl RefreshService {
async fn connect(&self, entry: &Entry) -> Result<()> {
let mut retry = 0;
let endpoint = Endpoint::Udp(entry.addr.clone(), entry.discovery_port);
- let conn = udp::dial(&endpoint).await?;
+ let conn = udp::dial(&endpoint, Default::default(), RefreshMsgCodec {}).await?;
let backoff = Backoff::new(100, 5000);
while retry < self.config.refresh_connect_retries {
- match self.send_ping_msg(&conn).await {
+ match self.send_ping_msg(&conn, &endpoint).await {
Ok(()) => return Ok(()),
Err(Error::KaryonNet(NetError::Timeout)) => {
retry += 1;
@@ -214,7 +210,7 @@ impl RefreshService {
/// Set up a UDP listener and start listening for Ping messages from other
/// peers.
async fn listen_loop(self: Arc<Self>, endpoint: Endpoint) -> Result<()> {
- let conn = match udp::listen(&endpoint).await {
+ let conn = match udp::listen(&endpoint, Default::default(), RefreshMsgCodec {}).await {
Ok(c) => {
self.monitor
.notify(&ConnEvent::Listening(endpoint.clone()).into())
@@ -240,46 +236,48 @@ impl RefreshService {
}
/// Listen to receive a Ping message and respond with a Pong message.
- async fn listen_to_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> {
- let mut buf = [0; PINGMSG_SIZE];
- let (_, endpoint) = conn.recv_from(&mut buf).await?;
-
+ async fn listen_to_ping_msg(&self, conn: &udp::UdpConn<RefreshMsgCodec>) -> Result<()> {
+ let (msg, endpoint) = conn.recv().await?;
self.monitor
.notify(&ConnEvent::Accepted(endpoint.clone()).into())
.await;
- let (ping_msg, _) = decode::<PingMsg>(&buf)?;
-
- let pong_msg = PongMsg(ping_msg.0);
- let buffer = encode(&pong_msg)?;
-
- conn.send_to(&buffer, &endpoint).await?;
+ match msg {
+ RefreshMsg::Ping(m) => {
+ let pong_msg = RefreshMsg::Pong(m);
+ conn.send((pong_msg, endpoint.clone())).await?;
+ }
+ RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())),
+ }
self.monitor
- .notify(&ConnEvent::Disconnected(endpoint.clone()).into())
+ .notify(&ConnEvent::Disconnected(endpoint).into())
.await;
Ok(())
}
/// Sends a Ping msg and wait to receive the Pong message.
- async fn send_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> {
+ async fn send_ping_msg(
+ &self,
+ conn: &udp::UdpConn<RefreshMsgCodec>,
+ endpoint: &Endpoint,
+ ) -> Result<()> {
let mut nonce: [u8; 32] = [0; 32];
RngCore::fill_bytes(&mut OsRng, &mut nonce);
+ conn.send((RefreshMsg::Ping(nonce), endpoint.clone()))
+ .await?;
- let ping_msg = PingMsg(nonce);
- let buffer = encode(&ping_msg)?;
- conn.write(&buffer).await?;
-
- let buf = &mut [0; PINGMSG_SIZE];
let t = Duration::from_secs(self.config.refresh_response_timeout);
- timeout(t, conn.read(buf)).await??;
-
- let (pong_msg, _) = decode::<PongMsg>(buf)?;
+ let (msg, _) = timeout(t, conn.recv()).await??;
- if ping_msg.0 != pong_msg.0 {
- return Err(Error::InvalidPongMsg);
+ match msg {
+ RefreshMsg::Pong(n) => {
+ if n != nonce {
+ return Err(Error::InvalidPongMsg);
+ }
+ Ok(())
+ }
+ _ => Err(Error::InvalidMsg("Unexpected ping msg".into())),
}
-
- Ok(())
}
}
diff --git a/p2p/src/error.rs b/p2p/src/error.rs
index b4ddc2e..97b7b7f 100644
--- a/p2p/src/error.rs
+++ b/p2p/src/error.rs
@@ -62,27 +62,35 @@ pub enum Error {
#[error("Rcgen Error: {0}")]
Rcgen(#[from] rcgen::Error),
+ #[cfg(feature = "smol")]
#[error("Tls Error: {0}")]
Rustls(#[from] futures_rustls::rustls::Error),
+ #[cfg(feature = "tokio")]
+ #[error("Tls Error: {0}")]
+ Rustls(#[from] tokio_rustls::rustls::Error),
+
#[error("Invalid DNS Name: {0}")]
- InvalidDnsNameError(#[from] futures_rustls::pki_types::InvalidDnsNameError),
+ InvalidDnsNameError(#[from] rustls_pki_types::InvalidDnsNameError),
#[error("Channel Send Error: {0}")]
ChannelSend(String),
#[error(transparent)]
- ChannelRecv(#[from] smol::channel::RecvError),
+ ChannelRecv(#[from] async_channel::RecvError),
#[error(transparent)]
KaryonCore(#[from] karyon_core::error::Error),
#[error(transparent)]
- KaryonNet(#[from] karyon_net::NetError),
+ KaryonNet(#[from] karyon_net::Error),
+
+ #[error("Other Error: {0}")]
+ Other(String),
}
-impl<T> From<smol::channel::SendError<T>> for Error {
- fn from(error: smol::channel::SendError<T>) -> Self {
+impl<T> From<async_channel::SendError<T>> for Error {
+ fn from(error: async_channel::SendError<T>) -> Self {
Error::ChannelSend(error.to_string())
}
}
diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs
index 3605359..8f3cf45 100644
--- a/p2p/src/lib.rs
+++ b/p2p/src/lib.rs
@@ -5,7 +5,7 @@
//! use std::sync::Arc;
//!
//! use easy_parallel::Parallel;
-//! use smol::{channel as smol_channel, future, Executor};
+//! use smol::{future, Executor};
//!
//! use karyon_core::crypto::{KeyPair, KeyPairType};
//! use karyon_p2p::{Backend, Config, PeerID};
@@ -19,7 +19,7 @@
//! let ex = Arc::new(Executor::new());
//!
//! // Create a new Backend
-//! let backend = Backend::new(&key_pair, config, ex.clone());
+//! let backend = Backend::new(&key_pair, config, ex.clone().into());
//!
//! let task = async {
//! // Run the backend
diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs
index 4a41482..1abf79a 100644
--- a/p2p/src/listener.rs
+++ b/p2p/src/listener.rs
@@ -3,13 +3,16 @@ use std::{future::Future, sync::Arc};
use log::{debug, error, info};
use karyon_core::{
- async_util::{Executor, TaskGroup, TaskResult},
+ async_runtime::Executor,
+ async_util::{TaskGroup, TaskResult},
crypto::KeyPair,
};
-use karyon_net::{tcp, tls, Conn, ConnListener, Endpoint};
+use karyon_net::{tcp, tls, Conn, Endpoint};
use crate::{
+ codec::NetMsgCodec,
+ message::NetMsg,
monitor::{ConnEvent, Monitor},
slots::ConnectionSlots,
tls_config::tls_server_config,
@@ -22,7 +25,7 @@ pub struct Listener {
key_pair: KeyPair,
/// Managing spawned tasks.
- task_group: TaskGroup<'static>,
+ task_group: TaskGroup,
/// Manages available inbound slots.
connection_slots: Arc<ConnectionSlots>,
@@ -41,7 +44,7 @@ impl Listener {
connection_slots: Arc<ConnectionSlots>,
enable_tls: bool,
monitor: Arc<Monitor>,
- ex: Executor<'static>,
+ ex: Executor,
) -> Arc<Self> {
Arc::new(Self {
key_pair: key_pair.clone(),
@@ -61,7 +64,7 @@ impl Listener {
self: &Arc<Self>,
endpoint: Endpoint,
// https://github.com/rust-lang/rfcs/pull/2132
- callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static,
+ callback: impl FnOnce(Conn<NetMsg>) -> Fut + Clone + Send + 'static,
) -> Result<Endpoint>
where
Fut: Future<Output = Result<()>> + Send + 'static,
@@ -82,7 +85,7 @@ impl Listener {
}
};
- let resolved_endpoint = listener.local_endpoint()?;
+ let resolved_endpoint = listener.local_endpoint().map_err(Error::from)?;
info!("Start listening on {resolved_endpoint}");
@@ -99,8 +102,8 @@ impl Listener {
async fn listen_loop<Fut>(
self: Arc<Self>,
- listener: Box<dyn ConnListener>,
- callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static,
+ listener: karyon_net::Listener<NetMsg>,
+ callback: impl FnOnce(Conn<NetMsg>) -> Fut + Clone + Send + 'static,
) where
Fut: Future<Output = Result<()>> + Send + 'static,
{
@@ -112,7 +115,7 @@ impl Listener {
let (conn, endpoint) = match result {
Ok(c) => {
let endpoint = match c.peer_endpoint() {
- Ok(e) => e,
+ Ok(ep) => ep,
Err(err) => {
self.monitor.notify(&ConnEvent::AcceptFailed.into()).await;
error!("Failed to accept a new connection: {err}");
@@ -151,16 +154,19 @@ impl Listener {
}
}
- async fn listen(&self, endpoint: &Endpoint) -> Result<karyon_net::Listener> {
+ async fn listen(&self, endpoint: &Endpoint) -> Result<karyon_net::Listener<NetMsg>> {
if self.enable_tls {
- let tls_config = tls_server_config(&self.key_pair)?;
- tls::listen(endpoint, tls_config)
+ let tls_config = tls::ServerTlsConfig {
+ tcp_config: Default::default(),
+ server_config: tls_server_config(&self.key_pair)?,
+ };
+ tls::listen(endpoint, tls_config, NetMsgCodec::new())
.await
- .map(|l| Box::new(l) as karyon_net::Listener)
+ .map(|l| Box::new(l) as karyon_net::Listener<NetMsg>)
} else {
- tcp::listen(endpoint)
+ tcp::listen(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new())
.await
- .map(|l| Box::new(l) as karyon_net::Listener)
+ .map(|l| Box::new(l) as karyon_net::Listener<NetMsg>)
}
.map_err(Error::KaryonNet)
}
diff --git a/p2p/src/message.rs b/p2p/src/message.rs
index 1342110..6498ef7 100644
--- a/p2p/src/message.rs
+++ b/p2p/src/message.rs
@@ -2,15 +2,10 @@ use std::collections::HashMap;
use bincode::{Decode, Encode};
+use karyon_core::util::encode;
use karyon_net::{Addr, Port};
-use crate::{protocol::ProtocolID, routing_table::Entry, version::VersionInt, PeerID};
-
-/// The size of the message header, in bytes.
-pub const MSG_HEADER_SIZE: usize = 6;
-
-/// The maximum allowed size for a message in bytes.
-pub const MAX_ALLOWED_MSG_SIZE: u32 = 1024 * 1024; // 1MB
+use crate::{protocol::ProtocolID, routing_table::Entry, version::VersionInt, PeerID, Result};
/// Defines the main message in the karyon p2p network.
///
@@ -23,11 +18,19 @@ pub struct NetMsg {
pub payload: Vec<u8>,
}
+impl NetMsg {
+ pub fn new<T: Encode>(command: NetMsgCmd, t: T) -> Result<Self> {
+ Ok(Self {
+ header: NetMsgHeader { command },
+ payload: encode(&t)?,
+ })
+ }
+}
+
/// Represents the header of a message.
#[derive(Decode, Encode, Debug, Clone)]
pub struct NetMsgHeader {
pub command: NetMsgCmd,
- pub payload_size: u32,
}
/// Defines message commands.
@@ -39,7 +42,7 @@ pub enum NetMsgCmd {
Protocol,
Shutdown,
- // NOTE: The following commands are used during the lookup process.
+ // The following commands are used during the lookup process.
Ping,
Pong,
FindPeer,
@@ -47,6 +50,12 @@ pub enum NetMsgCmd {
Peers,
}
+#[derive(Decode, Encode, Debug, Clone)]
+pub enum RefreshMsg {
+ Ping([u8; 32]),
+ Pong([u8; 32]),
+}
+
/// Defines a message related to a specific protocol.
#[derive(Decode, Encode, Debug, Clone)]
pub struct ProtocolMsg {
@@ -103,21 +112,6 @@ pub struct PeerMsg {
#[derive(Decode, Encode, Debug)]
pub struct PeersMsg(pub Vec<PeerMsg>);
-macro_rules! get_msg_payload {
- ($a:ident, $b:ident) => {
- if let NetMsgCmd::$a = $b.header.command {
- $b.payload
- } else {
- return Err(Error::InvalidMsg(format!(
- "Unexpected msg {:?}",
- $b.header.command
- )));
- }
- };
-}
-
-pub(super) use get_msg_payload;
-
impl From<Entry> for PeerMsg {
fn from(entry: Entry) -> PeerMsg {
PeerMsg {
@@ -139,3 +133,19 @@ impl From<PeerMsg> for Entry {
}
}
}
+
+macro_rules! get_msg_payload {
+ ($a:ident, $b:ident) => {
+ if let NetMsgCmd::$a = $b.header.command {
+ $b.payload
+ } else {
+ return Err(Error::InvalidMsg(format!(
+ "Expected {:?} msg found {:?} msg",
+ stringify!($a),
+ $b.header.command
+ )));
+ }
+ };
+}
+
+pub(super) use get_msg_payload;
diff --git a/p2p/src/monitor.rs b/p2p/src/monitor.rs
index b0ce028..48719c0 100644
--- a/p2p/src/monitor.rs
+++ b/p2p/src/monitor.rs
@@ -26,7 +26,7 @@ use karyon_net::Endpoint;
/// let ex = Arc::new(Executor::new());
///
/// let key_pair = KeyPair::generate(&KeyPairType::Ed25519);
-/// let backend = Backend::new(&key_pair, Config::default(), ex);
+/// let backend = Backend::new(&key_pair, Config::default(), ex.into());
///
/// // Create a new Subscription
/// let sub = backend.monitor().await;
diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs
index ca68530..f0f6f17 100644
--- a/p2p/src/peer/mod.rs
+++ b/p2p/src/peer/mod.rs
@@ -4,24 +4,22 @@ pub use peer_id::PeerID;
use std::sync::Arc;
+use async_channel::{Receiver, Sender};
+use bincode::{Decode, Encode};
use log::{error, trace};
-use smol::{
- channel::{self, Receiver, Sender},
- lock::RwLock,
-};
use karyon_core::{
- async_util::{select, Either, Executor, TaskGroup, TaskResult},
+ async_runtime::{lock::RwLock, Executor},
+ async_util::{select, Either, TaskGroup, TaskResult},
event::{ArcEventSys, EventListener, EventSys},
util::{decode, encode},
};
-use karyon_net::Endpoint;
+use karyon_net::{Conn, Endpoint};
use crate::{
- codec::{Codec, CodecMsg},
connection::ConnDirection,
- message::{NetMsgCmd, ProtocolMsg, ShutdownMsg},
+ message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg},
peer_pool::{ArcPeerPool, WeakPeerPool},
protocol::{Protocol, ProtocolEvent, ProtocolID},
Config, Error, Result,
@@ -36,8 +34,8 @@ pub struct Peer {
/// A weak pointer to `PeerPool`
peer_pool: WeakPeerPool,
- /// Holds the Codec for the peer connection
- codec: Codec,
+ /// Holds the peer connection
+ conn: Conn<NetMsg>,
/// Remote endpoint for the peer
remote_endpoint: Endpoint,
@@ -55,7 +53,7 @@ pub struct Peer {
stop_chan: (Sender<Result<()>>, Receiver<Result<()>>),
/// Managing spawned tasks.
- task_group: TaskGroup<'static>,
+ task_group: TaskGroup,
}
impl Peer {
@@ -63,21 +61,21 @@ impl Peer {
pub fn new(
peer_pool: WeakPeerPool,
id: &PeerID,
- codec: Codec,
+ conn: Conn<NetMsg>,
remote_endpoint: Endpoint,
conn_direction: ConnDirection,
- ex: Executor<'static>,
+ ex: Executor,
) -> ArcPeer {
Arc::new(Peer {
id: id.clone(),
peer_pool,
- codec,
+ conn,
protocol_ids: RwLock::new(Vec::new()),
remote_endpoint,
conn_direction,
protocol_events: EventSys::new(),
task_group: TaskGroup::with_executor(ex),
- stop_chan: channel::bounded(1),
+ stop_chan: async_channel::bounded(1),
})
}
@@ -88,7 +86,7 @@ impl Peer {
}
/// Send a message to the peer connection using the specified protocol.
- pub async fn send<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> {
+ pub async fn send<T: Encode + Decode>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> {
let payload = encode(msg)?;
let proto_msg = ProtocolMsg {
@@ -96,12 +94,14 @@ impl Peer {
payload: payload.to_vec(),
};
- self.codec.write(NetMsgCmd::Protocol, &proto_msg).await?;
+ self.conn
+ .send(NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?)
+ .await?;
Ok(())
}
/// Broadcast a message to all connected peers using the specified protocol.
- pub async fn broadcast<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) {
+ pub async fn broadcast<T: Encode + Decode>(&self, protocol_id: &ProtocolID, msg: &T) {
self.peer_pool().broadcast(protocol_id, msg).await;
}
@@ -123,7 +123,9 @@ impl Peer {
let _ = self.stop_chan.0.try_send(Ok(()));
// No need to handle the error here
- let _ = self.codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await;
+ let shutdown_msg =
+ NetMsg::new(NetMsgCmd::Shutdown, &ShutdownMsg(0)).expect("pack shutdown message");
+ let _ = self.conn.send(shutdown_msg).await;
// Force shutting down
self.task_group.cancel().await;
@@ -170,7 +172,7 @@ impl Peer {
/// Start a read loop to handle incoming messages from the peer connection.
async fn read_loop(&self) -> Result<()> {
loop {
- let fut = select(self.stop_chan.1.recv(), self.codec.read()).await;
+ let fut = select(self.stop_chan.1.recv(), self.conn.recv()).await;
let result = match fut {
Either::Left(stop_signal) => {
trace!("Peer {} received a stop signal", self.id);
diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs
index 4e20c99..8b16ef5 100644
--- a/p2p/src/peer_pool.rs
+++ b/p2p/src/peer_pool.rs
@@ -4,21 +4,22 @@ use std::{
time::Duration,
};
+use async_channel::Sender;
+use bincode::{Decode, Encode};
use log::{error, info, trace, warn};
-use smol::{
- channel::Sender,
- lock::{Mutex, RwLock},
-};
use karyon_core::{
- async_util::{Executor, TaskGroup, TaskResult},
+ async_runtime::{
+ lock::{Mutex, RwLock},
+ Executor,
+ },
+ async_util::{timeout, TaskGroup, TaskResult},
util::decode,
};
use karyon_net::Conn;
use crate::{
- codec::{Codec, CodecMsg},
config::Config,
connection::{ConnDirection, ConnQueue},
message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg},
@@ -50,10 +51,10 @@ pub struct PeerPool {
protocol_versions: Arc<RwLock<HashMap<ProtocolID, Version>>>,
/// Managing spawned tasks.
- task_group: TaskGroup<'static>,
+ task_group: TaskGroup,
/// A global Executor
- executor: Executor<'static>,
+ executor: Executor,
/// The Configuration for the P2P network.
pub(crate) config: Arc<Config>,
@@ -69,7 +70,7 @@ impl PeerPool {
conn_queue: Arc<ConnQueue>,
config: Arc<Config>,
monitor: Arc<Monitor>,
- executor: Executor<'static>,
+ executor: Executor,
) -> Arc<Self> {
let protocols = RwLock::new(HashMap::new());
let protocol_versions = Arc::new(RwLock::new(HashMap::new()));
@@ -137,7 +138,7 @@ impl PeerPool {
}
/// Broadcast a message to all connected peers using the specified protocol.
- pub async fn broadcast<T: CodecMsg>(&self, proto_id: &ProtocolID, msg: &T) {
+ pub async fn broadcast<T: Decode + Encode>(&self, proto_id: &ProtocolID, msg: &T) {
for (pid, peer) in self.peers.lock().await.iter() {
if let Err(err) = peer.send(proto_id, msg).await {
error!("failed to send msg to {pid}: {err}");
@@ -149,15 +150,14 @@ impl PeerPool {
/// Add a new peer to the peer list.
pub async fn new_peer(
self: &Arc<Self>,
- conn: Conn,
+ conn: Conn<NetMsg>,
conn_direction: &ConnDirection,
disconnect_signal: Sender<Result<()>>,
) -> Result<()> {
let endpoint = conn.peer_endpoint()?;
- let codec = Codec::new(conn);
// Do a handshake with the connection before creating a new peer.
- let pid = self.do_handshake(&codec, conn_direction).await?;
+ let pid = self.do_handshake(&conn, conn_direction).await?;
// TODO: Consider restricting the subnet for inbound connections
if self.contains_peer(&pid).await {
@@ -168,7 +168,7 @@ impl PeerPool {
let peer = Peer::new(
Arc::downgrade(self),
&pid,
- codec,
+ conn,
endpoint.clone(),
conn_direction.clone(),
self.executor.clone(),
@@ -234,16 +234,21 @@ impl PeerPool {
}
/// Initiate a handshake with a connection.
- async fn do_handshake(&self, codec: &Codec, conn_direction: &ConnDirection) -> Result<PeerID> {
+ async fn do_handshake(
+ &self,
+ conn: &Conn<NetMsg>,
+ conn_direction: &ConnDirection,
+ ) -> Result<PeerID> {
+ trace!("Handshake started: {}", conn.peer_endpoint()?);
match conn_direction {
ConnDirection::Inbound => {
- let result = self.wait_vermsg(codec).await;
+ let result = self.wait_vermsg(conn).await;
match result {
Ok(_) => {
- self.send_verack(codec, true).await?;
+ self.send_verack(conn, true).await?;
}
Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => {
- self.send_verack(codec, false).await?;
+ self.send_verack(conn, false).await?;
}
_ => {}
}
@@ -251,14 +256,14 @@ impl PeerPool {
}
ConnDirection::Outbound => {
- self.send_vermsg(codec).await?;
- self.wait_verack(codec).await
+ self.send_vermsg(conn).await?;
+ self.wait_verack(conn).await
}
}
}
/// Send a Version message
- async fn send_vermsg(&self, codec: &Codec) -> Result<()> {
+ async fn send_vermsg(&self, conn: &Conn<NetMsg>) -> Result<()> {
let pids = self.protocol_versions.read().await;
let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect();
drop(pids);
@@ -270,16 +275,16 @@ impl PeerPool {
};
trace!("Send VerMsg");
- codec.write(NetMsgCmd::Version, &vermsg).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Version, &vermsg)?).await?;
Ok(())
}
/// Wait for a Version message
///
/// Returns the peer's ID upon successfully receiving the Version message.
- async fn wait_vermsg(&self, codec: &Codec) -> Result<PeerID> {
- let timeout = Duration::from_secs(self.config.handshake_timeout);
- let msg: NetMsg = codec.read_timeout(timeout).await?;
+ async fn wait_vermsg(&self, conn: &Conn<NetMsg>) -> Result<PeerID> {
+ let t = Duration::from_secs(self.config.handshake_timeout);
+ let msg: NetMsg = timeout(t, conn.recv()).await??;
let payload = get_msg_payload!(Version, msg);
let (vermsg, _) = decode::<VerMsg>(&payload)?;
@@ -295,23 +300,23 @@ impl PeerPool {
}
/// Send a Verack message
- async fn send_verack(&self, codec: &Codec, ack: bool) -> Result<()> {
+ async fn send_verack(&self, conn: &Conn<NetMsg>, ack: bool) -> Result<()> {
let verack = VerAckMsg {
peer_id: self.id.clone(),
ack,
};
trace!("Send VerAckMsg {:?}", verack);
- codec.write(NetMsgCmd::Verack, &verack).await?;
+ conn.send(NetMsg::new(NetMsgCmd::Verack, &verack)?).await?;
Ok(())
}
/// Wait for a Verack message
///
/// Returns the peer's ID upon successfully receiving the Verack message.
- async fn wait_verack(&self, codec: &Codec) -> Result<PeerID> {
- let timeout = Duration::from_secs(self.config.handshake_timeout);
- let msg: NetMsg = codec.read_timeout(timeout).await?;
+ async fn wait_verack(&self, conn: &Conn<NetMsg>) -> Result<PeerID> {
+ let t = Duration::from_secs(self.config.handshake_timeout);
+ let msg: NetMsg = timeout(t, conn.recv()).await??;
let payload = get_msg_payload!(Verack, msg);
let (verack, _) = decode::<VerAckMsg>(&payload)?;
diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs
index f28659c..6153ea1 100644
--- a/p2p/src/protocol.rs
+++ b/p2p/src/protocol.rs
@@ -92,7 +92,7 @@ impl EventValue for ProtocolEvent {
/// let ex = Arc::new(Executor::new());
///
/// // Create a new Backend
-/// let backend = Backend::new(&key_pair, config, ex);
+/// let backend = Backend::new(&key_pair, config, ex.into());
///
/// // Attach the NewProtocol
/// let c = move |peer| NewProtocol::new(peer);
diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs
index f04e059..654644a 100644
--- a/p2p/src/protocols/ping.rs
+++ b/p2p/src/protocols/ping.rs
@@ -1,23 +1,19 @@
use std::{sync::Arc, time::Duration};
+use async_channel::{Receiver, Sender};
use async_trait::async_trait;
use bincode::{Decode, Encode};
use log::trace;
use rand::{rngs::OsRng, RngCore};
-use smol::{
- channel,
- channel::{Receiver, Sender},
- stream::StreamExt,
- Timer,
-};
use karyon_core::{
- async_util::{select, timeout, Either, Executor, TaskGroup, TaskResult},
+ async_runtime::Executor,
+ async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult},
event::EventListener,
util::decode,
};
-use karyon_net::NetError;
+use karyon_net::Error as NetError;
use crate::{
peer::ArcPeer,
@@ -38,12 +34,12 @@ pub struct PingProtocol {
peer: ArcPeer,
ping_interval: u64,
ping_timeout: u64,
- task_group: TaskGroup<'static>,
+ task_group: TaskGroup,
}
impl PingProtocol {
#[allow(clippy::new_ret_no_self)]
- pub fn new(peer: ArcPeer, executor: Executor<'static>) -> ArcProtocol {
+ pub fn new(peer: ArcPeer, executor: Executor) -> ArcProtocol {
let ping_interval = peer.config().ping_interval;
let ping_timeout = peer.config().ping_timeout;
Arc::new(Self {
@@ -87,12 +83,11 @@ impl PingProtocol {
}
async fn ping_loop(self: Arc<Self>, chan: Receiver<[u8; 32]>) -> Result<()> {
- let mut timer = Timer::interval(Duration::from_secs(self.ping_interval));
let rng = &mut OsRng;
let mut retry = 0;
while retry < MAX_FAILUERS {
- timer.next().await;
+ sleep(Duration::from_secs(self.ping_interval)).await;
let mut ping_nonce: [u8; 32] = [0; 32];
rng.fill_bytes(&mut ping_nonce);
@@ -130,8 +125,8 @@ impl Protocol for PingProtocol {
async fn start(self: Arc<Self>) -> Result<()> {
trace!("Start Ping protocol");
- let (pong_chan, pong_chan_recv) = channel::bounded(1);
- let (stop_signal_s, stop_signal) = channel::bounded::<Result<()>>(1);
+ let (pong_chan, pong_chan_recv) = async_channel::bounded(1);
+ let (stop_signal_s, stop_signal) = async_channel::bounded::<Result<()>>(1);
let selfc = self.clone();
self.task_group.spawn(
diff --git a/p2p/src/routing_table/entry.rs b/p2p/src/routing_table/entry.rs
index 3fc8a6b..1427c2b 100644
--- a/p2p/src/routing_table/entry.rs
+++ b/p2p/src/routing_table/entry.rs
@@ -5,6 +5,9 @@ use karyon_net::{Addr, Port};
/// Specifies the size of the key, in bytes.
pub const KEY_SIZE: usize = 32;
+/// The unique key identifying the peer.
+pub type Key = [u8; KEY_SIZE];
+
/// An Entry represents a peer in the routing table.
#[derive(Encode, Decode, Clone, Debug)]
pub struct Entry {
@@ -20,14 +23,11 @@ pub struct Entry {
impl PartialEq for Entry {
fn eq(&self, other: &Self) -> bool {
- // TODO: this should also compare both addresses (the self.addr == other.addr)
+ // XXX: should we compare both self.addr and other.addr???
self.key == other.key
}
}
-/// The unique key identifying the peer.
-pub type Key = [u8; KEY_SIZE];
-
/// Calculates the XOR distance between two provided keys.
///
/// The XOR distance is a metric used in Kademlia to measure the closeness
diff --git a/p2p/src/routing_table/mod.rs b/p2p/src/routing_table/mod.rs
index 6854546..bbf4801 100644
--- a/p2p/src/routing_table/mod.rs
+++ b/p2p/src/routing_table/mod.rs
@@ -266,7 +266,7 @@ impl RoutingTable {
}
/// Check if two addresses belong to the same subnet.
-pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool {
+fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool {
match (addr, other_addr) {
(Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => {
// TODO: Consider moving this to a different place
@@ -275,6 +275,8 @@ pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool {
}
ip.octets()[0..3] == other_ip.octets()[0..3]
}
+
+ // TODO: check ipv6
_ => false,
}
}
diff --git a/p2p/src/tls_config.rs b/p2p/src/tls_config.rs
index 893c321..65d2adc 100644
--- a/p2p/src/tls_config.rs
+++ b/p2p/src/tls_config.rs
@@ -1,19 +1,25 @@
use std::sync::Arc;
-use futures_rustls::rustls::{
- self,
+#[cfg(feature = "smol")]
+use futures_rustls::rustls;
+
+#[cfg(feature = "tokio")]
+use tokio_rustls::rustls;
+
+use rustls::{
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
crypto::{
aws_lc_rs::{self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, kx_group},
CryptoProvider, SupportedKxGroup,
},
- pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime},
server::danger::{ClientCertVerified, ClientCertVerifier},
CertificateError, DigitallySignedStruct, DistinguishedName,
Error::InvalidCertificate,
SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion,
};
+use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
+
use log::error;
use x509_parser::{certificate::X509Certificate, parse_x509_certificate};