From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- p2p/Cargo.toml | 40 +++++------ p2p/README.md | 11 ++- p2p/examples/chat.rs | 4 +- p2p/examples/monitor.rs | 2 +- p2p/examples/peer.rs | 2 +- p2p/examples/shared/mod.rs | 8 +-- p2p/src/backend.rs | 4 +- p2p/src/codec.rs | 149 ++++++++++++++--------------------------- p2p/src/connection.rs | 12 ++-- p2p/src/connector.rs | 35 +++++++--- p2p/src/discovery/lookup.rs | 71 +++++++++++--------- p2p/src/discovery/mod.rs | 45 +++++++------ p2p/src/discovery/refresh.rs | 92 +++++++++++++------------ p2p/src/error.rs | 18 +++-- p2p/src/lib.rs | 4 +- p2p/src/listener.rs | 36 +++++----- p2p/src/message.rs | 58 +++++++++------- p2p/src/monitor.rs | 2 +- p2p/src/peer/mod.rs | 42 ++++++------ p2p/src/peer_pool.rs | 65 +++++++++--------- p2p/src/protocol.rs | 2 +- p2p/src/protocols/ping.rs | 23 +++---- p2p/src/routing_table/entry.rs | 8 +-- p2p/src/routing_table/mod.rs | 4 +- p2p/src/tls_config.rs | 12 +++- 25 files changed, 381 insertions(+), 368 deletions(-) (limited to 'p2p') diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml index fc14de2..3327810 100644 --- a/p2p/Cargo.toml +++ b/p2p/Cargo.toml @@ -1,42 +1,43 @@ [package] name = "karyon_p2p" -version.workspace = true +version.workspace = true edition.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["smol"] +smol = ["karyon_core/smol", "karyon_net/smol", "dep:futures-rustls"] +tokio = ["karyon_core/tokio", "karyon_net/tokio", "dep:tokio-rustls"] + [dependencies] -karyon_core = { workspace = true, features=["crypto"] } -karyon_net.workspace = true +karyon_core = { workspace = true, features = [ + "crypto", +], default-features = false } +karyon_net = { workspace = true, default-features = false } -smol = "2.0.0" async-trait = "0.1.77" -futures-util = {version = "0.3.5", features=["std"], default-features = false } +async-channel = "2.3.0" +futures-util = { version = "0.3.5", features = [ + "std", +], default-features = false } log = "0.4.21" chrono = "0.4.35" -bincode = { version="2.0.0-rc.3", features = ["derive"]} +bincode = { version = "2.0.0-rc.3", features = ["derive"] } rand = "0.8.5" thiserror = "1.0.58" semver = "1.0.22" sha2 = "0.10.8" # tls -futures-rustls = { version = "0.25.1", features = ["aws-lc-rs"] } rcgen = "0.12.1" yasna = "0.5.2" x509-parser = "0.16.0" - -[[example]] -name = "peer" -path = "examples/peer.rs" - -[[example]] -name = "chat" -path = "examples/chat.rs" - -[[example]] -name = "monitor" -path = "examples/monitor.rs" +futures-rustls = { version = "0.25.1", features = [ + "aws-lc-rs", +], optional = true } +tokio-rustls = { version = "0.26.0", features = ["aws-lc-rs"], optional = true } +rustls-pki-types = "1.7.0" [dev-dependencies] async-std = "1.12.0" @@ -44,3 +45,4 @@ clap = { version = "4.5.2", features = ["derive"] } ctrlc = "3.4.4" easy-parallel = "3.3.1" env_logger = "0.11.3" +smol = "2.0.0" diff --git a/p2p/README.md b/p2p/README.md index e00d9e5..768fd19 100644 --- a/p2p/README.md +++ b/p2p/README.md @@ -1,6 +1,6 @@ # karyon p2p -karyon p2p serves as the foundational stack for the karyon project. It offers +karyon p2p serves as the foundational stack for the Karyon project. It offers a lightweight, extensible, and customizable peer-to-peer (p2p) network stack that seamlessly integrates with any p2p project. @@ -130,7 +130,14 @@ boolean `enable_tls` field in the configuration. However, implementing TLS for a P2P network is not trivial and is still unstable, requiring a comprehensive audit. -## Usage + +## Choosing the async runtime + +karyon p2p currently supports both smol(async-std) and tokio. The default is +smol, but if you want to use tokio, you need to disable the default features +and then select the `tokio` feature. + +## Examples You can check out the examples [here](./examples). diff --git a/p2p/examples/chat.rs b/p2p/examples/chat.rs index cc822d9..4eafb07 100644 --- a/p2p/examples/chat.rs +++ b/p2p/examples/chat.rs @@ -121,7 +121,7 @@ fn main() { let ex = Arc::new(Executor::new()); // Create a new Backend - let backend = Backend::new(&key_pair, config, ex.clone()); + let backend = Backend::new(&key_pair, config, ex.clone().into()); let (ctrlc_s, ctrlc_r) = channel::unbounded(); let handle = move || ctrlc_s.try_send(()).unwrap(); @@ -133,7 +133,7 @@ fn main() { let username = cli.username; // Attach the ChatProtocol - let c = move |peer| ChatProtocol::new(&username, peer, ex_cloned.clone()); + let c = move |peer| ChatProtocol::new(&username, peer, ex_cloned.clone().into()); backend.attach_protocol::(c).await.unwrap(); // Run the backend diff --git a/p2p/examples/monitor.rs b/p2p/examples/monitor.rs index b074352..019f751 100644 --- a/p2p/examples/monitor.rs +++ b/p2p/examples/monitor.rs @@ -51,7 +51,7 @@ fn main() { let ex = Arc::new(Executor::new()); // Create a new Backend - let backend = Backend::new(&key_pair, config, ex.clone()); + let backend = Backend::new(&key_pair, config, ex.clone().into()); let (ctrlc_s, ctrlc_r) = channel::unbounded(); let handle = move || ctrlc_s.try_send(()).unwrap(); diff --git a/p2p/examples/peer.rs b/p2p/examples/peer.rs index 06586b6..db747c9 100644 --- a/p2p/examples/peer.rs +++ b/p2p/examples/peer.rs @@ -51,7 +51,7 @@ fn main() { let ex = Arc::new(Executor::new()); // Create a new Backend - let backend = Backend::new(&key_pair, config, ex.clone()); + let backend = Backend::new(&key_pair, config, ex.clone().into()); let (ctrlc_s, ctrlc_r) = channel::unbounded(); let handle = move || ctrlc_s.try_send(()).unwrap(); diff --git a/p2p/examples/shared/mod.rs b/p2p/examples/shared/mod.rs index 8065e63..57d89ef 100644 --- a/p2p/examples/shared/mod.rs +++ b/p2p/examples/shared/mod.rs @@ -1,9 +1,7 @@ -use std::{num::NonZeroUsize, thread}; +use std::{num::NonZeroUsize, sync::Arc, thread}; use easy_parallel::Parallel; -use smol::{channel, future, future::Future}; - -use karyon_core::async_util::Executor; +use smol::{channel, future, future::Future, Executor}; /// Returns an estimate of the default amount of parallelism a program should use. /// see `std::thread::available_parallelism` @@ -14,7 +12,7 @@ fn available_parallelism() -> usize { } /// Run a multi-threaded executor -pub fn run_executor(main_future: impl Future, ex: Executor<'_>) { +pub fn run_executor(main_future: impl Future, ex: Arc>) { let (signal, shutdown) = channel::unbounded::<()>(); let num_threads = available_parallelism(); diff --git a/p2p/src/backend.rs b/p2p/src/backend.rs index d33f3dc..2f21b3e 100644 --- a/p2p/src/backend.rs +++ b/p2p/src/backend.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use log::info; -use karyon_core::{async_util::Executor, crypto::KeyPair, pubsub::Subscription}; +use karyon_core::{async_runtime::Executor, crypto::KeyPair, pubsub::Subscription}; use crate::{ config::Config, @@ -37,7 +37,7 @@ pub struct Backend { impl Backend { /// Creates a new Backend. - pub fn new(key_pair: &KeyPair, config: Config, ex: Executor<'static>) -> ArcBackend { + pub fn new(key_pair: &KeyPair, config: Config, ex: Executor) -> ArcBackend { let config = Arc::new(config); let monitor = Arc::new(Monitor::new()); let conn_queue = ConnQueue::new(); diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs index 726a2f7..3d0f323 100644 --- a/p2p/src/codec.rs +++ b/p2p/src/codec.rs @@ -1,120 +1,69 @@ -use std::time::Duration; +use karyon_core::util::{decode, encode, encode_into_slice}; -use bincode::{Decode, Encode}; - -use karyon_core::{ - async_util::timeout, - util::{decode, encode, encode_into_slice}, -}; - -use karyon_net::{Connection, NetError}; - -use crate::{ - message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, - Error, Result, +use karyon_net::{ + codec::{Codec, Decoder, Encoder, LengthCodec}, + Result, }; -pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} -impl CodecMsg for T {} +use crate::message::{NetMsg, RefreshMsg}; -/// A Codec working with generic network connections. -/// -/// It is responsible for both decoding data received from the network and -/// encoding data before sending it. -pub struct Codec { - conn: Box, +#[derive(Clone)] +pub struct NetMsgCodec { + inner_codec: LengthCodec, } -impl Codec { - /// Creates a new Codec. - pub fn new(conn: Box) -> Self { - Self { conn } - } - - /// Reads a message of type `NetMsg` from the connection. - /// - /// It reads the first 6 bytes as the header of the message, then reads - /// and decodes the remaining message data based on the determined header. - pub async fn read(&self) -> Result { - // Read 6 bytes to get the header of the incoming message - let mut buf = [0; MSG_HEADER_SIZE]; - self.read_exact(&mut buf).await?; - - // Decode the header from bytes to NetMsgHeader - let (header, _) = decode::(&buf)?; - - if header.payload_size > MAX_ALLOWED_MSG_SIZE { - return Err(Error::InvalidMsg( - "Message exceeds the maximum allowed size".to_string(), - )); +impl NetMsgCodec { + pub fn new() -> Self { + Self { + inner_codec: LengthCodec {}, } - - // Create a buffer to hold the message based on its length - let mut payload = vec![0; header.payload_size as usize]; - self.read_exact(&mut payload).await?; - - Ok(NetMsg { header, payload }) } +} - /// Writes a message of type `T` to the connection. - /// - /// Before appending the actual message payload, it calculates the length of - /// the encoded message in bytes and appends this length to the message header. - pub async fn write(&self, command: NetMsgCmd, msg: &T) -> Result<()> { - let payload = encode(msg)?; - - // Create a buffer to hold the message header (6 bytes) - let header_buf = &mut [0; MSG_HEADER_SIZE]; - let header = NetMsgHeader { - command, - payload_size: payload.len() as u32, - }; - encode_into_slice(&header, header_buf)?; - - let mut buffer = vec![]; - // Append the header bytes to the buffer - buffer.extend_from_slice(header_buf); - // Append the message payload to the buffer - buffer.extend_from_slice(&payload); - - self.write_all(&buffer).await?; - Ok(()) - } +impl Codec for NetMsgCodec { + type Item = NetMsg; +} - /// Reads a message of type `NetMsg` with the given timeout. - pub async fn read_timeout(&self, duration: Duration) -> Result { - timeout(duration, self.read()) - .await - .map_err(|_| NetError::Timeout)? +impl Encoder for NetMsgCodec { + type EnItem = NetMsg; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result { + let src = encode(src)?; + self.inner_codec.encode(&src, dst) } +} - /// Reads the exact number of bytes required to fill `buf`. - async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.read(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); +impl Decoder for NetMsgCodec { + type DeItem = NetMsg; + fn decode(&self, src: &mut [u8]) -> Result> { + match self.inner_codec.decode(src)? { + Some((n, s)) => { + let (m, _) = decode::(&s)?; + Ok(Some((n, m))) } + None => Ok(None), } - - Ok(()) } +} - /// Writes an entire buffer into the connection. - async fn write_all(&self, mut buf: &[u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.write(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at(n); - buf = rest; +#[derive(Clone)] +pub struct RefreshMsgCodec {} - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } +impl Codec for RefreshMsgCodec { + type Item = RefreshMsg; +} + +impl Encoder for RefreshMsgCodec { + type EnItem = RefreshMsg; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result { + let n = encode_into_slice(src, dst)?; + Ok(n) + } +} - Ok(()) +impl Decoder for RefreshMsgCodec { + type DeItem = RefreshMsg; + fn decode(&self, src: &mut [u8]) -> Result> { + let (m, n) = decode::(src)?; + Ok(Some((n, m))) } } diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs index 9fa57cb..9a153f3 100644 --- a/p2p/src/connection.rs +++ b/p2p/src/connection.rs @@ -1,11 +1,11 @@ use std::{collections::VecDeque, fmt, sync::Arc}; -use smol::{channel::Sender, lock::Mutex}; +use async_channel::Sender; -use karyon_core::async_util::CondVar; +use karyon_core::{async_runtime::lock::Mutex, async_util::CondVar}; use karyon_net::Conn; -use crate::Result; +use crate::{message::NetMsg, Result}; /// Defines the direction of a network connection. #[derive(Clone, Debug)] @@ -25,7 +25,7 @@ impl fmt::Display for ConnDirection { pub struct NewConn { pub direction: ConnDirection, - pub conn: Conn, + pub conn: Conn, pub disconnect_signal: Sender>, } @@ -44,8 +44,8 @@ impl ConnQueue { } /// Push a connection into the queue and wait for the disconnect signal - pub async fn handle(&self, conn: Conn, direction: ConnDirection) -> Result<()> { - let (disconnect_signal, chan) = smol::channel::bounded(1); + pub async fn handle(&self, conn: Conn, direction: ConnDirection) -> Result<()> { + let (disconnect_signal, chan) = async_channel::bounded(1); let new_conn = NewConn { direction, conn, diff --git a/p2p/src/connector.rs b/p2p/src/connector.rs index de9e746..aea21ab 100644 --- a/p2p/src/connector.rs +++ b/p2p/src/connector.rs @@ -3,12 +3,15 @@ use std::{future::Future, sync::Arc}; use log::{error, trace, warn}; use karyon_core::{ - async_util::{Backoff, Executor, TaskGroup, TaskResult}, + async_runtime::Executor, + async_util::{Backoff, TaskGroup, TaskResult}, crypto::KeyPair, }; -use karyon_net::{tcp, tls, Conn, Endpoint, NetError}; +use karyon_net::{tcp, tls, Conn, Endpoint, Error as NetError}; use crate::{ + codec::NetMsgCodec, + message::NetMsg, monitor::{ConnEvent, Monitor}, slots::ConnectionSlots, tls_config::tls_client_config, @@ -23,7 +26,7 @@ pub struct Connector { key_pair: KeyPair, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// Manages available outbound slots. connection_slots: Arc, @@ -47,7 +50,7 @@ impl Connector { connection_slots: Arc, enable_tls: bool, monitor: Arc, - ex: Executor<'static>, + ex: Executor, ) -> Arc { Arc::new(Self { key_pair: key_pair.clone(), @@ -70,7 +73,11 @@ impl Connector { /// `Conn` instance. /// /// This method will block until it finds an available slot. - pub async fn connect(&self, endpoint: &Endpoint, peer_id: &Option) -> Result { + pub async fn connect( + &self, + endpoint: &Endpoint, + peer_id: &Option, + ) -> Result> { self.connection_slots.wait_for_slot().await; self.connection_slots.add(); @@ -113,7 +120,7 @@ impl Connector { self: &Arc, endpoint: &Endpoint, peer_id: &Option, - callback: impl FnOnce(Conn) -> Fut + Send + 'static, + callback: impl FnOnce(Conn) -> Fut + Send + 'static, ) -> Result<()> where Fut: Future> + Send + 'static, @@ -138,14 +145,20 @@ impl Connector { Ok(()) } - async fn dial(&self, endpoint: &Endpoint, peer_id: &Option) -> Result { + async fn dial(&self, endpoint: &Endpoint, peer_id: &Option) -> Result> { if self.enable_tls { - let tls_config = tls_client_config(&self.key_pair, peer_id.clone())?; - tls::dial(endpoint, tls_config, DNS_NAME) + let tls_config = tls::ClientTlsConfig { + tcp_config: Default::default(), + client_config: tls_client_config(&self.key_pair, peer_id.clone())?, + dns_name: DNS_NAME.to_string(), + }; + tls::dial(endpoint, tls_config, NetMsgCodec::new()) .await - .map(|l| Box::new(l) as Conn) + .map(|l| Box::new(l) as karyon_net::Conn) } else { - tcp::dial(endpoint).await.map(|l| Box::new(l) as Conn) + tcp::dial(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()) + .await + .map(|l| Box::new(l) as karyon_net::Conn) } .map_err(Error::KaryonNet) } diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs index c81fbc6..cff4610 100644 --- a/p2p/src/discovery/lookup.rs +++ b/p2p/src/discovery/lookup.rs @@ -3,10 +3,13 @@ use std::{sync::Arc, time::Duration}; use futures_util::{stream::FuturesUnordered, StreamExt}; use log::{error, trace}; use rand::{rngs::OsRng, seq::SliceRandom, RngCore}; -use smol::lock::{Mutex, RwLock}; use karyon_core::{ - async_util::{timeout, Executor}, + async_runtime::{ + lock::{Mutex, RwLock}, + Executor, + }, + async_util::timeout, crypto::KeyPair, util::decode, }; @@ -14,7 +17,6 @@ use karyon_core::{ use karyon_net::{Conn, Endpoint}; use crate::{ - codec::Codec, connector::Connector, listener::Listener, message::{ @@ -64,7 +66,7 @@ impl LookupService { table: Arc>, config: Arc, monitor: Arc, - ex: Executor<'static>, + ex: Executor, ) -> Self { let inbound_slots = Arc::new(ConnectionSlots::new(config.lookup_inbound_slots)); let outbound_slots = Arc::new(ConnectionSlots::new(config.lookup_outbound_slots)); @@ -228,8 +230,7 @@ impl LookupService { target_peer_id: &PeerID, ) -> Result> { let conn = self.connector.connect(&endpoint, &peer_id).await?; - let io_codec = Codec::new(conn); - let result = self.handle_outbound(io_codec, target_peer_id).await; + let result = self.handle_outbound(conn, target_peer_id).await; self.monitor .notify(&ConnEvent::Disconnected(endpoint).into()) @@ -242,14 +243,14 @@ impl LookupService { /// Handles outbound connection async fn handle_outbound( &self, - io_codec: Codec, + conn: Conn, target_peer_id: &PeerID, ) -> Result> { trace!("Send Ping msg"); - self.send_ping_msg(&io_codec).await?; + self.send_ping_msg(&conn).await?; trace!("Send FindPeer msg"); - let peers = self.send_findpeer_msg(&io_codec, target_peer_id).await?; + let peers = self.send_findpeer_msg(&conn, target_peer_id).await?; if peers.0.len() >= MAX_PEERS_IN_PEERSMSG { return Err(Error::Lookup("Received too many peers in PeersMsg")); @@ -257,12 +258,12 @@ impl LookupService { trace!("Send Peer msg"); if let Some(endpoint) = &self.listen_endpoint { - self.send_peer_msg(&io_codec, endpoint.read().await.clone()) + self.send_peer_msg(&conn, endpoint.read().await.clone()) .await?; } trace!("Send Shutdown msg"); - self.send_shutdown_msg(&io_codec).await?; + self.send_shutdown_msg(&conn).await?; Ok(peers.0) } @@ -277,7 +278,7 @@ impl LookupService { let endpoint = Endpoint::Tcp(addr, self.config.discovery_port); let selfc = self.clone(); - let callback = |conn: Conn| async move { + let callback = |conn: Conn| async move { let t = Duration::from_secs(selfc.config.lookup_connection_lifespan); timeout(t, selfc.handle_inbound(conn)).await??; Ok(()) @@ -288,10 +289,9 @@ impl LookupService { } /// Handles inbound connection - async fn handle_inbound(self: &Arc, conn: Conn) -> Result<()> { - let io_codec = Codec::new(conn); + async fn handle_inbound(self: &Arc, conn: Conn) -> Result<()> { loop { - let msg: NetMsg = io_codec.read().await?; + let msg: NetMsg = conn.recv().await?; trace!("Receive msg {:?}", msg.header.command); if let NetMsgCmd::Shutdown = msg.header.command { @@ -304,12 +304,12 @@ impl LookupService { if !version_match(&self.config.version.req, &ping_msg.version) { return Err(Error::IncompatibleVersion("system: {}".into())); } - self.send_pong_msg(ping_msg.nonce, &io_codec).await?; + self.send_pong_msg(ping_msg.nonce, &conn).await?; } NetMsgCmd::FindPeer => { let (findpeer_msg, _) = decode::(&msg.payload)?; let peer_id = findpeer_msg.0; - self.send_peers_msg(&peer_id, &io_codec).await?; + self.send_peers_msg(&peer_id, &conn).await?; } NetMsgCmd::Peer => { let (peer, _) = decode::(&msg.payload)?; @@ -322,7 +322,7 @@ impl LookupService { } /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, io_codec: &Codec) -> Result<()> { + async fn send_ping_msg(&self, conn: &Conn) -> Result<()> { trace!("Send Pong msg"); let mut nonce: [u8; 32] = [0; 32]; @@ -332,10 +332,10 @@ impl LookupService { version: self.config.version.v.clone(), nonce, }; - io_codec.write(NetMsgCmd::Ping, &ping_msg).await?; + conn.send(NetMsg::new(NetMsgCmd::Ping, &ping_msg)?).await?; let t = Duration::from_secs(self.config.lookup_response_timeout); - let recv_msg: NetMsg = io_codec.read_timeout(t).await?; + let recv_msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Pong, recv_msg); let (pong_msg, _) = decode::(&payload)?; @@ -348,21 +348,24 @@ impl LookupService { } /// Sends a Pong msg - async fn send_pong_msg(&self, nonce: [u8; 32], io_codec: &Codec) -> Result<()> { + async fn send_pong_msg(&self, nonce: [u8; 32], conn: &Conn) -> Result<()> { trace!("Send Pong msg"); - io_codec.write(NetMsgCmd::Pong, &PongMsg(nonce)).await?; + conn.send(NetMsg::new(NetMsgCmd::Pong, &PongMsg(nonce))?) + .await?; Ok(()) } /// Sends a FindPeer msg and wait to receivet the Peers msg. - async fn send_findpeer_msg(&self, io_codec: &Codec, peer_id: &PeerID) -> Result { + async fn send_findpeer_msg(&self, conn: &Conn, peer_id: &PeerID) -> Result { trace!("Send FindPeer msg"); - io_codec - .write(NetMsgCmd::FindPeer, &FindPeerMsg(peer_id.clone())) - .await?; + conn.send(NetMsg::new( + NetMsgCmd::FindPeer, + &FindPeerMsg(peer_id.clone()), + )?) + .await?; let t = Duration::from_secs(self.config.lookup_response_timeout); - let recv_msg: NetMsg = io_codec.read_timeout(t).await?; + let recv_msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Peers, recv_msg); let (peers, _) = decode(&payload)?; @@ -371,19 +374,20 @@ impl LookupService { } /// Sends a Peers msg. - async fn send_peers_msg(&self, peer_id: &PeerID, io_codec: &Codec) -> Result<()> { + async fn send_peers_msg(&self, peer_id: &PeerID, conn: &Conn) -> Result<()> { trace!("Send Peers msg"); let table = self.table.lock().await; let entries = table.closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG); drop(table); let peers: Vec = entries.into_iter().map(|e| e.into()).collect(); - io_codec.write(NetMsgCmd::Peers, &PeersMsg(peers)).await?; + conn.send(NetMsg::new(NetMsgCmd::Peers, &PeersMsg(peers))?) + .await?; Ok(()) } /// Sends a Peer msg. - async fn send_peer_msg(&self, io_codec: &Codec, endpoint: Endpoint) -> Result<()> { + async fn send_peer_msg(&self, conn: &Conn, endpoint: Endpoint) -> Result<()> { trace!("Send Peer msg"); let peer_msg = PeerMsg { addr: endpoint.addr()?.clone(), @@ -391,14 +395,15 @@ impl LookupService { discovery_port: self.config.discovery_port, peer_id: self.id.clone(), }; - io_codec.write(NetMsgCmd::Peer, &peer_msg).await?; + conn.send(NetMsg::new(NetMsgCmd::Peer, &peer_msg)?).await?; Ok(()) } /// Sends a Shutdown msg. - async fn send_shutdown_msg(&self, io_codec: &Codec) -> Result<()> { + async fn send_shutdown_msg(&self, conn: &Conn) -> Result<()> { trace!("Send Shutdown msg"); - io_codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await?; + conn.send(NetMsg::new(NetMsgCmd::Shutdown, &ShutdownMsg(0))?) + .await?; Ok(()) } } diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs index 3e437aa..19ae77a 100644 --- a/p2p/src/discovery/mod.rs +++ b/p2p/src/discovery/mod.rs @@ -5,10 +5,10 @@ use std::sync::Arc; use log::{error, info}; use rand::{rngs::OsRng, seq::SliceRandom}; -use smol::lock::Mutex; use karyon_core::{ - async_util::{Backoff, Executor, TaskGroup, TaskResult}, + async_runtime::{lock::Mutex, Executor}, + async_util::{Backoff, TaskGroup, TaskResult}, crypto::KeyPair, }; @@ -19,6 +19,7 @@ use crate::{ connection::{ConnDirection, ConnQueue}, connector::Connector, listener::Listener, + message::NetMsg, monitor::Monitor, routing_table::{ Entry, EntryStatusFlag, RoutingTable, CONNECTED_ENTRY, DISCONNECTED_ENTRY, @@ -45,6 +46,7 @@ pub struct Discovery { /// Connector connector: Arc, + /// Listener listener: Arc, @@ -53,11 +55,12 @@ pub struct Discovery { /// Inbound slots. pub(crate) inbound_slots: Arc, + /// Outbound slots. pub(crate) outbound_slots: Arc, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// Holds the configuration for the P2P network. config: Arc, @@ -71,7 +74,7 @@ impl Discovery { conn_queue: Arc, config: Arc, monitor: Arc, - ex: Executor<'static>, + ex: Executor, ) -> ArcDiscovery { let inbound_slots = Arc::new(ConnectionSlots::new(config.inbound_slots)); let outbound_slots = Arc::new(ConnectionSlots::new(config.outbound_slots)); @@ -180,7 +183,7 @@ impl Discovery { /// Start a listener and on success, return the resolved endpoint. async fn start_listener(self: &Arc, endpoint: &Endpoint) -> Result { let selfc = self.clone(); - let callback = |c: Conn| async move { + let callback = |c: Conn| async move { selfc.conn_queue.handle(c, ConnDirection::Inbound).await?; Ok(()) }; @@ -198,8 +201,8 @@ impl Discovery { async fn connect_loop(self: Arc) -> Result<()> { let backoff = Backoff::new(500, self.config.seeding_interval * 1000); loop { - let random_entry = self.random_entry(PENDING_ENTRY).await; - match random_entry { + let random_table_entry = self.random_table_entry(PENDING_ENTRY).await; + match random_table_entry { Some(entry) => { backoff.reset(); let endpoint = Endpoint::Tcp(entry.addr, entry.port); @@ -218,7 +221,7 @@ impl Discovery { let selfc = self.clone(); let pid_c = pid.clone(); let endpoint_c = endpoint.clone(); - let cback = |conn: Conn| async move { + let cback = |conn: Conn| async move { let result = selfc.conn_queue.handle(conn, ConnDirection::Outbound).await; // If the entry is not in the routing table, ignore the result @@ -230,17 +233,17 @@ impl Discovery { match result { Err(Error::IncompatiblePeer) => { error!("Failed to do handshake: {endpoint_c} incompatible peer"); - selfc.update_entry(&pid, INCOMPATIBLE_ENTRY).await; + selfc.update_table_entry(&pid, INCOMPATIBLE_ENTRY).await; } Err(Error::PeerAlreadyConnected) => { - // TODO: Use the appropriate status. - selfc.update_entry(&pid, DISCONNECTED_ENTRY).await; + // TODO: Use an appropriate status. + selfc.update_table_entry(&pid, DISCONNECTED_ENTRY).await; } Err(_) => { - selfc.update_entry(&pid, UNSTABLE_ENTRY).await; + selfc.update_table_entry(&pid, UNSTABLE_ENTRY).await; } Ok(_) => { - selfc.update_entry(&pid, DISCONNECTED_ENTRY).await; + selfc.update_table_entry(&pid, DISCONNECTED_ENTRY).await; } } @@ -255,10 +258,10 @@ impl Discovery { if let Some(pid) = &pid { match result { Ok(_) => { - self.update_entry(pid, CONNECTED_ENTRY).await; + self.update_table_entry(pid, CONNECTED_ENTRY).await; } Err(_) => { - self.update_entry(pid, UNREACHABLE_ENTRY).await; + self.update_table_entry(pid, UNREACHABLE_ENTRY).await; } } } @@ -271,12 +274,16 @@ impl Discovery { /// table doesn't have an available entry, it will connect to one of the /// provided bootstrap endpoints in the `Config` and initiate the lookup. async fn start_seeding(&self) { - match self.random_entry(PENDING_ENTRY | CONNECTED_ENTRY).await { + match self + .random_table_entry(PENDING_ENTRY | CONNECTED_ENTRY) + .await + { Some(entry) => { let endpoint = Endpoint::Tcp(entry.addr, entry.discovery_port); let peer_id = Some(entry.key.into()); if let Err(err) = self.lookup_service.start_lookup(&endpoint, peer_id).await { - self.update_entry(&entry.key.into(), UNSTABLE_ENTRY).await; + self.update_table_entry(&entry.key.into(), UNSTABLE_ENTRY) + .await; error!("Failed to do lookup: {endpoint}: {err}"); } } @@ -292,12 +299,12 @@ impl Discovery { } /// Returns a random entry from routing table. - async fn random_entry(&self, entry_flag: EntryStatusFlag) -> Option { + async fn random_table_entry(&self, entry_flag: EntryStatusFlag) -> Option { self.table.lock().await.random_entry(entry_flag).cloned() } /// Update the entry status - async fn update_entry(&self, pid: &PeerID, entry_flag: EntryStatusFlag) { + async fn update_table_entry(&self, pid: &PeerID, entry_flag: EntryStatusFlag) { let table = &mut self.table.lock().await; table.update_entry(&pid.0, entry_flag); } diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index 035a581..0c49ac2 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -3,31 +3,28 @@ use std::{sync::Arc, time::Duration}; use bincode::{Decode, Encode}; use log::{error, info, trace}; use rand::{rngs::OsRng, RngCore}; -use smol::{ - lock::{Mutex, RwLock}, - stream::StreamExt, - Timer, -}; use karyon_core::{ - async_util::{timeout, Backoff, Executor, TaskGroup, TaskResult}, - util::{decode, encode}, + async_runtime::{ + lock::{Mutex, RwLock}, + Executor, + }, + async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult}, }; -use karyon_net::{udp, Connection, Endpoint, NetError}; - -/// Maximum failures for an entry before removing it from the routing table. -pub const MAX_FAILURES: u32 = 3; - -/// Ping message size -const PINGMSG_SIZE: usize = 32; +use karyon_net::{udp, Connection, Endpoint, Error as NetError}; use crate::{ + codec::RefreshMsgCodec, + message::RefreshMsg, monitor::{ConnEvent, DiscoveryEvent, Monitor}, routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY}, Config, Error, Result, }; +/// Maximum failures for an entry before removing it from the routing table. +pub const MAX_FAILURES: u32 = 3; + #[derive(Decode, Encode, Debug, Clone)] pub struct PingMsg(pub [u8; 32]); @@ -42,10 +39,10 @@ pub struct RefreshService { listen_endpoint: Option>, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// A global executor - executor: Executor<'static>, + executor: Executor, /// Holds the configuration for the P2P network. config: Arc, @@ -60,7 +57,7 @@ impl RefreshService { config: Arc, table: Arc>, monitor: Arc, - executor: Executor<'static>, + executor: Executor, ) -> Self { let listen_endpoint = config .listen_endpoint @@ -118,9 +115,8 @@ impl RefreshService { /// selects the first 8 entries (oldest entries) from each bucket in the /// routing table and starts sending Ping messages to the collected entries. async fn refresh_loop(self: Arc) -> Result<()> { - let mut timer = Timer::interval(Duration::from_secs(self.config.refresh_interval)); loop { - timer.next().await; + sleep(Duration::from_secs(self.config.refresh_interval)).await; trace!("Start refreshing the routing table..."); self.monitor @@ -162,7 +158,7 @@ impl RefreshService { } for task in tasks { - task.await; + let _ = task.await; } } } @@ -193,10 +189,10 @@ impl RefreshService { async fn connect(&self, entry: &Entry) -> Result<()> { let mut retry = 0; let endpoint = Endpoint::Udp(entry.addr.clone(), entry.discovery_port); - let conn = udp::dial(&endpoint).await?; + let conn = udp::dial(&endpoint, Default::default(), RefreshMsgCodec {}).await?; let backoff = Backoff::new(100, 5000); while retry < self.config.refresh_connect_retries { - match self.send_ping_msg(&conn).await { + match self.send_ping_msg(&conn, &endpoint).await { Ok(()) => return Ok(()), Err(Error::KaryonNet(NetError::Timeout)) => { retry += 1; @@ -214,7 +210,7 @@ impl RefreshService { /// Set up a UDP listener and start listening for Ping messages from other /// peers. async fn listen_loop(self: Arc, endpoint: Endpoint) -> Result<()> { - let conn = match udp::listen(&endpoint).await { + let conn = match udp::listen(&endpoint, Default::default(), RefreshMsgCodec {}).await { Ok(c) => { self.monitor .notify(&ConnEvent::Listening(endpoint.clone()).into()) @@ -240,46 +236,48 @@ impl RefreshService { } /// Listen to receive a Ping message and respond with a Pong message. - async fn listen_to_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> { - let mut buf = [0; PINGMSG_SIZE]; - let (_, endpoint) = conn.recv_from(&mut buf).await?; - + async fn listen_to_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> { + let (msg, endpoint) = conn.recv().await?; self.monitor .notify(&ConnEvent::Accepted(endpoint.clone()).into()) .await; - let (ping_msg, _) = decode::(&buf)?; - - let pong_msg = PongMsg(ping_msg.0); - let buffer = encode(&pong_msg)?; - - conn.send_to(&buffer, &endpoint).await?; + match msg { + RefreshMsg::Ping(m) => { + let pong_msg = RefreshMsg::Pong(m); + conn.send((pong_msg, endpoint.clone())).await?; + } + RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())), + } self.monitor - .notify(&ConnEvent::Disconnected(endpoint.clone()).into()) + .notify(&ConnEvent::Disconnected(endpoint).into()) .await; Ok(()) } /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> { + async fn send_ping_msg( + &self, + conn: &udp::UdpConn, + endpoint: &Endpoint, + ) -> Result<()> { let mut nonce: [u8; 32] = [0; 32]; RngCore::fill_bytes(&mut OsRng, &mut nonce); + conn.send((RefreshMsg::Ping(nonce), endpoint.clone())) + .await?; - let ping_msg = PingMsg(nonce); - let buffer = encode(&ping_msg)?; - conn.write(&buffer).await?; - - let buf = &mut [0; PINGMSG_SIZE]; let t = Duration::from_secs(self.config.refresh_response_timeout); - timeout(t, conn.read(buf)).await??; - - let (pong_msg, _) = decode::(buf)?; + let (msg, _) = timeout(t, conn.recv()).await??; - if ping_msg.0 != pong_msg.0 { - return Err(Error::InvalidPongMsg); + match msg { + RefreshMsg::Pong(n) => { + if n != nonce { + return Err(Error::InvalidPongMsg); + } + Ok(()) + } + _ => Err(Error::InvalidMsg("Unexpected ping msg".into())), } - - Ok(()) } } diff --git a/p2p/src/error.rs b/p2p/src/error.rs index b4ddc2e..97b7b7f 100644 --- a/p2p/src/error.rs +++ b/p2p/src/error.rs @@ -62,27 +62,35 @@ pub enum Error { #[error("Rcgen Error: {0}")] Rcgen(#[from] rcgen::Error), + #[cfg(feature = "smol")] #[error("Tls Error: {0}")] Rustls(#[from] futures_rustls::rustls::Error), + #[cfg(feature = "tokio")] + #[error("Tls Error: {0}")] + Rustls(#[from] tokio_rustls::rustls::Error), + #[error("Invalid DNS Name: {0}")] - InvalidDnsNameError(#[from] futures_rustls::pki_types::InvalidDnsNameError), + InvalidDnsNameError(#[from] rustls_pki_types::InvalidDnsNameError), #[error("Channel Send Error: {0}")] ChannelSend(String), #[error(transparent)] - ChannelRecv(#[from] smol::channel::RecvError), + ChannelRecv(#[from] async_channel::RecvError), #[error(transparent)] KaryonCore(#[from] karyon_core::error::Error), #[error(transparent)] - KaryonNet(#[from] karyon_net::NetError), + KaryonNet(#[from] karyon_net::Error), + + #[error("Other Error: {0}")] + Other(String), } -impl From> for Error { - fn from(error: smol::channel::SendError) -> Self { +impl From> for Error { + fn from(error: async_channel::SendError) -> Self { Error::ChannelSend(error.to_string()) } } diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index 3605359..8f3cf45 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -5,7 +5,7 @@ //! use std::sync::Arc; //! //! use easy_parallel::Parallel; -//! use smol::{channel as smol_channel, future, Executor}; +//! use smol::{future, Executor}; //! //! use karyon_core::crypto::{KeyPair, KeyPairType}; //! use karyon_p2p::{Backend, Config, PeerID}; @@ -19,7 +19,7 @@ //! let ex = Arc::new(Executor::new()); //! //! // Create a new Backend -//! let backend = Backend::new(&key_pair, config, ex.clone()); +//! let backend = Backend::new(&key_pair, config, ex.clone().into()); //! //! let task = async { //! // Run the backend diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs index 4a41482..1abf79a 100644 --- a/p2p/src/listener.rs +++ b/p2p/src/listener.rs @@ -3,13 +3,16 @@ use std::{future::Future, sync::Arc}; use log::{debug, error, info}; use karyon_core::{ - async_util::{Executor, TaskGroup, TaskResult}, + async_runtime::Executor, + async_util::{TaskGroup, TaskResult}, crypto::KeyPair, }; -use karyon_net::{tcp, tls, Conn, ConnListener, Endpoint}; +use karyon_net::{tcp, tls, Conn, Endpoint}; use crate::{ + codec::NetMsgCodec, + message::NetMsg, monitor::{ConnEvent, Monitor}, slots::ConnectionSlots, tls_config::tls_server_config, @@ -22,7 +25,7 @@ pub struct Listener { key_pair: KeyPair, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// Manages available inbound slots. connection_slots: Arc, @@ -41,7 +44,7 @@ impl Listener { connection_slots: Arc, enable_tls: bool, monitor: Arc, - ex: Executor<'static>, + ex: Executor, ) -> Arc { Arc::new(Self { key_pair: key_pair.clone(), @@ -61,7 +64,7 @@ impl Listener { self: &Arc, endpoint: Endpoint, // https://github.com/rust-lang/rfcs/pull/2132 - callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static, + callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static, ) -> Result where Fut: Future> + Send + 'static, @@ -82,7 +85,7 @@ impl Listener { } }; - let resolved_endpoint = listener.local_endpoint()?; + let resolved_endpoint = listener.local_endpoint().map_err(Error::from)?; info!("Start listening on {resolved_endpoint}"); @@ -99,8 +102,8 @@ impl Listener { async fn listen_loop( self: Arc, - listener: Box, - callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static, + listener: karyon_net::Listener, + callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static, ) where Fut: Future> + Send + 'static, { @@ -112,7 +115,7 @@ impl Listener { let (conn, endpoint) = match result { Ok(c) => { let endpoint = match c.peer_endpoint() { - Ok(e) => e, + Ok(ep) => ep, Err(err) => { self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; error!("Failed to accept a new connection: {err}"); @@ -151,16 +154,19 @@ impl Listener { } } - async fn listen(&self, endpoint: &Endpoint) -> Result { + async fn listen(&self, endpoint: &Endpoint) -> Result> { if self.enable_tls { - let tls_config = tls_server_config(&self.key_pair)?; - tls::listen(endpoint, tls_config) + let tls_config = tls::ServerTlsConfig { + tcp_config: Default::default(), + server_config: tls_server_config(&self.key_pair)?, + }; + tls::listen(endpoint, tls_config, NetMsgCodec::new()) .await - .map(|l| Box::new(l) as karyon_net::Listener) + .map(|l| Box::new(l) as karyon_net::Listener) } else { - tcp::listen(endpoint) + tcp::listen(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()) .await - .map(|l| Box::new(l) as karyon_net::Listener) + .map(|l| Box::new(l) as karyon_net::Listener) } .map_err(Error::KaryonNet) } diff --git a/p2p/src/message.rs b/p2p/src/message.rs index 1342110..6498ef7 100644 --- a/p2p/src/message.rs +++ b/p2p/src/message.rs @@ -2,15 +2,10 @@ use std::collections::HashMap; use bincode::{Decode, Encode}; +use karyon_core::util::encode; use karyon_net::{Addr, Port}; -use crate::{protocol::ProtocolID, routing_table::Entry, version::VersionInt, PeerID}; - -/// The size of the message header, in bytes. -pub const MSG_HEADER_SIZE: usize = 6; - -/// The maximum allowed size for a message in bytes. -pub const MAX_ALLOWED_MSG_SIZE: u32 = 1024 * 1024; // 1MB +use crate::{protocol::ProtocolID, routing_table::Entry, version::VersionInt, PeerID, Result}; /// Defines the main message in the karyon p2p network. /// @@ -23,11 +18,19 @@ pub struct NetMsg { pub payload: Vec, } +impl NetMsg { + pub fn new(command: NetMsgCmd, t: T) -> Result { + Ok(Self { + header: NetMsgHeader { command }, + payload: encode(&t)?, + }) + } +} + /// Represents the header of a message. #[derive(Decode, Encode, Debug, Clone)] pub struct NetMsgHeader { pub command: NetMsgCmd, - pub payload_size: u32, } /// Defines message commands. @@ -39,7 +42,7 @@ pub enum NetMsgCmd { Protocol, Shutdown, - // NOTE: The following commands are used during the lookup process. + // The following commands are used during the lookup process. Ping, Pong, FindPeer, @@ -47,6 +50,12 @@ pub enum NetMsgCmd { Peers, } +#[derive(Decode, Encode, Debug, Clone)] +pub enum RefreshMsg { + Ping([u8; 32]), + Pong([u8; 32]), +} + /// Defines a message related to a specific protocol. #[derive(Decode, Encode, Debug, Clone)] pub struct ProtocolMsg { @@ -103,21 +112,6 @@ pub struct PeerMsg { #[derive(Decode, Encode, Debug)] pub struct PeersMsg(pub Vec); -macro_rules! get_msg_payload { - ($a:ident, $b:ident) => { - if let NetMsgCmd::$a = $b.header.command { - $b.payload - } else { - return Err(Error::InvalidMsg(format!( - "Unexpected msg {:?}", - $b.header.command - ))); - } - }; -} - -pub(super) use get_msg_payload; - impl From for PeerMsg { fn from(entry: Entry) -> PeerMsg { PeerMsg { @@ -139,3 +133,19 @@ impl From for Entry { } } } + +macro_rules! get_msg_payload { + ($a:ident, $b:ident) => { + if let NetMsgCmd::$a = $b.header.command { + $b.payload + } else { + return Err(Error::InvalidMsg(format!( + "Expected {:?} msg found {:?} msg", + stringify!($a), + $b.header.command + ))); + } + }; +} + +pub(super) use get_msg_payload; diff --git a/p2p/src/monitor.rs b/p2p/src/monitor.rs index b0ce028..48719c0 100644 --- a/p2p/src/monitor.rs +++ b/p2p/src/monitor.rs @@ -26,7 +26,7 @@ use karyon_net::Endpoint; /// let ex = Arc::new(Executor::new()); /// /// let key_pair = KeyPair::generate(&KeyPairType::Ed25519); -/// let backend = Backend::new(&key_pair, Config::default(), ex); +/// let backend = Backend::new(&key_pair, Config::default(), ex.into()); /// /// // Create a new Subscription /// let sub = backend.monitor().await; diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs index ca68530..f0f6f17 100644 --- a/p2p/src/peer/mod.rs +++ b/p2p/src/peer/mod.rs @@ -4,24 +4,22 @@ pub use peer_id::PeerID; use std::sync::Arc; +use async_channel::{Receiver, Sender}; +use bincode::{Decode, Encode}; use log::{error, trace}; -use smol::{ - channel::{self, Receiver, Sender}, - lock::RwLock, -}; use karyon_core::{ - async_util::{select, Either, Executor, TaskGroup, TaskResult}, + async_runtime::{lock::RwLock, Executor}, + async_util::{select, Either, TaskGroup, TaskResult}, event::{ArcEventSys, EventListener, EventSys}, util::{decode, encode}, }; -use karyon_net::Endpoint; +use karyon_net::{Conn, Endpoint}; use crate::{ - codec::{Codec, CodecMsg}, connection::ConnDirection, - message::{NetMsgCmd, ProtocolMsg, ShutdownMsg}, + message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg}, peer_pool::{ArcPeerPool, WeakPeerPool}, protocol::{Protocol, ProtocolEvent, ProtocolID}, Config, Error, Result, @@ -36,8 +34,8 @@ pub struct Peer { /// A weak pointer to `PeerPool` peer_pool: WeakPeerPool, - /// Holds the Codec for the peer connection - codec: Codec, + /// Holds the peer connection + conn: Conn, /// Remote endpoint for the peer remote_endpoint: Endpoint, @@ -55,7 +53,7 @@ pub struct Peer { stop_chan: (Sender>, Receiver>), /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, } impl Peer { @@ -63,21 +61,21 @@ impl Peer { pub fn new( peer_pool: WeakPeerPool, id: &PeerID, - codec: Codec, + conn: Conn, remote_endpoint: Endpoint, conn_direction: ConnDirection, - ex: Executor<'static>, + ex: Executor, ) -> ArcPeer { Arc::new(Peer { id: id.clone(), peer_pool, - codec, + conn, protocol_ids: RwLock::new(Vec::new()), remote_endpoint, conn_direction, protocol_events: EventSys::new(), task_group: TaskGroup::with_executor(ex), - stop_chan: channel::bounded(1), + stop_chan: async_channel::bounded(1), }) } @@ -88,7 +86,7 @@ impl Peer { } /// Send a message to the peer connection using the specified protocol. - pub async fn send(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { + pub async fn send(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { let payload = encode(msg)?; let proto_msg = ProtocolMsg { @@ -96,12 +94,14 @@ impl Peer { payload: payload.to_vec(), }; - self.codec.write(NetMsgCmd::Protocol, &proto_msg).await?; + self.conn + .send(NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?) + .await?; Ok(()) } /// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast(&self, protocol_id: &ProtocolID, msg: &T) { + pub async fn broadcast(&self, protocol_id: &ProtocolID, msg: &T) { self.peer_pool().broadcast(protocol_id, msg).await; } @@ -123,7 +123,9 @@ impl Peer { let _ = self.stop_chan.0.try_send(Ok(())); // No need to handle the error here - let _ = self.codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await; + let shutdown_msg = + NetMsg::new(NetMsgCmd::Shutdown, &ShutdownMsg(0)).expect("pack shutdown message"); + let _ = self.conn.send(shutdown_msg).await; // Force shutting down self.task_group.cancel().await; @@ -170,7 +172,7 @@ impl Peer { /// Start a read loop to handle incoming messages from the peer connection. async fn read_loop(&self) -> Result<()> { loop { - let fut = select(self.stop_chan.1.recv(), self.codec.read()).await; + let fut = select(self.stop_chan.1.recv(), self.conn.recv()).await; let result = match fut { Either::Left(stop_signal) => { trace!("Peer {} received a stop signal", self.id); diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index 4e20c99..8b16ef5 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -4,21 +4,22 @@ use std::{ time::Duration, }; +use async_channel::Sender; +use bincode::{Decode, Encode}; use log::{error, info, trace, warn}; -use smol::{ - channel::Sender, - lock::{Mutex, RwLock}, -}; use karyon_core::{ - async_util::{Executor, TaskGroup, TaskResult}, + async_runtime::{ + lock::{Mutex, RwLock}, + Executor, + }, + async_util::{timeout, TaskGroup, TaskResult}, util::decode, }; use karyon_net::Conn; use crate::{ - codec::{Codec, CodecMsg}, config::Config, connection::{ConnDirection, ConnQueue}, message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, @@ -50,10 +51,10 @@ pub struct PeerPool { protocol_versions: Arc>>, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// A global Executor - executor: Executor<'static>, + executor: Executor, /// The Configuration for the P2P network. pub(crate) config: Arc, @@ -69,7 +70,7 @@ impl PeerPool { conn_queue: Arc, config: Arc, monitor: Arc, - executor: Executor<'static>, + executor: Executor, ) -> Arc { let protocols = RwLock::new(HashMap::new()); let protocol_versions = Arc::new(RwLock::new(HashMap::new())); @@ -137,7 +138,7 @@ impl PeerPool { } /// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { + pub async fn broadcast(&self, proto_id: &ProtocolID, msg: &T) { for (pid, peer) in self.peers.lock().await.iter() { if let Err(err) = peer.send(proto_id, msg).await { error!("failed to send msg to {pid}: {err}"); @@ -149,15 +150,14 @@ impl PeerPool { /// Add a new peer to the peer list. pub async fn new_peer( self: &Arc, - conn: Conn, + conn: Conn, conn_direction: &ConnDirection, disconnect_signal: Sender>, ) -> Result<()> { let endpoint = conn.peer_endpoint()?; - let codec = Codec::new(conn); // Do a handshake with the connection before creating a new peer. - let pid = self.do_handshake(&codec, conn_direction).await?; + let pid = self.do_handshake(&conn, conn_direction).await?; // TODO: Consider restricting the subnet for inbound connections if self.contains_peer(&pid).await { @@ -168,7 +168,7 @@ impl PeerPool { let peer = Peer::new( Arc::downgrade(self), &pid, - codec, + conn, endpoint.clone(), conn_direction.clone(), self.executor.clone(), @@ -234,16 +234,21 @@ impl PeerPool { } /// Initiate a handshake with a connection. - async fn do_handshake(&self, codec: &Codec, conn_direction: &ConnDirection) -> Result { + async fn do_handshake( + &self, + conn: &Conn, + conn_direction: &ConnDirection, + ) -> Result { + trace!("Handshake started: {}", conn.peer_endpoint()?); match conn_direction { ConnDirection::Inbound => { - let result = self.wait_vermsg(codec).await; + let result = self.wait_vermsg(conn).await; match result { Ok(_) => { - self.send_verack(codec, true).await?; + self.send_verack(conn, true).await?; } Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { - self.send_verack(codec, false).await?; + self.send_verack(conn, false).await?; } _ => {} } @@ -251,14 +256,14 @@ impl PeerPool { } ConnDirection::Outbound => { - self.send_vermsg(codec).await?; - self.wait_verack(codec).await + self.send_vermsg(conn).await?; + self.wait_verack(conn).await } } } /// Send a Version message - async fn send_vermsg(&self, codec: &Codec) -> Result<()> { + async fn send_vermsg(&self, conn: &Conn) -> Result<()> { let pids = self.protocol_versions.read().await; let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); drop(pids); @@ -270,16 +275,16 @@ impl PeerPool { }; trace!("Send VerMsg"); - codec.write(NetMsgCmd::Version, &vermsg).await?; + conn.send(NetMsg::new(NetMsgCmd::Version, &vermsg)?).await?; Ok(()) } /// Wait for a Version message /// /// Returns the peer's ID upon successfully receiving the Version message. - async fn wait_vermsg(&self, codec: &Codec) -> Result { - let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = codec.read_timeout(timeout).await?; + async fn wait_vermsg(&self, conn: &Conn) -> Result { + let t = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Version, msg); let (vermsg, _) = decode::(&payload)?; @@ -295,23 +300,23 @@ impl PeerPool { } /// Send a Verack message - async fn send_verack(&self, codec: &Codec, ack: bool) -> Result<()> { + async fn send_verack(&self, conn: &Conn, ack: bool) -> Result<()> { let verack = VerAckMsg { peer_id: self.id.clone(), ack, }; trace!("Send VerAckMsg {:?}", verack); - codec.write(NetMsgCmd::Verack, &verack).await?; + conn.send(NetMsg::new(NetMsgCmd::Verack, &verack)?).await?; Ok(()) } /// Wait for a Verack message /// /// Returns the peer's ID upon successfully receiving the Verack message. - async fn wait_verack(&self, codec: &Codec) -> Result { - let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = codec.read_timeout(timeout).await?; + async fn wait_verack(&self, conn: &Conn) -> Result { + let t = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Verack, msg); let (verack, _) = decode::(&payload)?; diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index f28659c..6153ea1 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -92,7 +92,7 @@ impl EventValue for ProtocolEvent { /// let ex = Arc::new(Executor::new()); /// /// // Create a new Backend -/// let backend = Backend::new(&key_pair, config, ex); +/// let backend = Backend::new(&key_pair, config, ex.into()); /// /// // Attach the NewProtocol /// let c = move |peer| NewProtocol::new(peer); diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs index f04e059..654644a 100644 --- a/p2p/src/protocols/ping.rs +++ b/p2p/src/protocols/ping.rs @@ -1,23 +1,19 @@ use std::{sync::Arc, time::Duration}; +use async_channel::{Receiver, Sender}; use async_trait::async_trait; use bincode::{Decode, Encode}; use log::trace; use rand::{rngs::OsRng, RngCore}; -use smol::{ - channel, - channel::{Receiver, Sender}, - stream::StreamExt, - Timer, -}; use karyon_core::{ - async_util::{select, timeout, Either, Executor, TaskGroup, TaskResult}, + async_runtime::Executor, + async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult}, event::EventListener, util::decode, }; -use karyon_net::NetError; +use karyon_net::Error as NetError; use crate::{ peer::ArcPeer, @@ -38,12 +34,12 @@ pub struct PingProtocol { peer: ArcPeer, ping_interval: u64, ping_timeout: u64, - task_group: TaskGroup<'static>, + task_group: TaskGroup, } impl PingProtocol { #[allow(clippy::new_ret_no_self)] - pub fn new(peer: ArcPeer, executor: Executor<'static>) -> ArcProtocol { + pub fn new(peer: ArcPeer, executor: Executor) -> ArcProtocol { let ping_interval = peer.config().ping_interval; let ping_timeout = peer.config().ping_timeout; Arc::new(Self { @@ -87,12 +83,11 @@ impl PingProtocol { } async fn ping_loop(self: Arc, chan: Receiver<[u8; 32]>) -> Result<()> { - let mut timer = Timer::interval(Duration::from_secs(self.ping_interval)); let rng = &mut OsRng; let mut retry = 0; while retry < MAX_FAILUERS { - timer.next().await; + sleep(Duration::from_secs(self.ping_interval)).await; let mut ping_nonce: [u8; 32] = [0; 32]; rng.fill_bytes(&mut ping_nonce); @@ -130,8 +125,8 @@ impl Protocol for PingProtocol { async fn start(self: Arc) -> Result<()> { trace!("Start Ping protocol"); - let (pong_chan, pong_chan_recv) = channel::bounded(1); - let (stop_signal_s, stop_signal) = channel::bounded::>(1); + let (pong_chan, pong_chan_recv) = async_channel::bounded(1); + let (stop_signal_s, stop_signal) = async_channel::bounded::>(1); let selfc = self.clone(); self.task_group.spawn( diff --git a/p2p/src/routing_table/entry.rs b/p2p/src/routing_table/entry.rs index 3fc8a6b..1427c2b 100644 --- a/p2p/src/routing_table/entry.rs +++ b/p2p/src/routing_table/entry.rs @@ -5,6 +5,9 @@ use karyon_net::{Addr, Port}; /// Specifies the size of the key, in bytes. pub const KEY_SIZE: usize = 32; +/// The unique key identifying the peer. +pub type Key = [u8; KEY_SIZE]; + /// An Entry represents a peer in the routing table. #[derive(Encode, Decode, Clone, Debug)] pub struct Entry { @@ -20,14 +23,11 @@ pub struct Entry { impl PartialEq for Entry { fn eq(&self, other: &Self) -> bool { - // TODO: this should also compare both addresses (the self.addr == other.addr) + // XXX: should we compare both self.addr and other.addr??? self.key == other.key } } -/// The unique key identifying the peer. -pub type Key = [u8; KEY_SIZE]; - /// Calculates the XOR distance between two provided keys. /// /// The XOR distance is a metric used in Kademlia to measure the closeness diff --git a/p2p/src/routing_table/mod.rs b/p2p/src/routing_table/mod.rs index 6854546..bbf4801 100644 --- a/p2p/src/routing_table/mod.rs +++ b/p2p/src/routing_table/mod.rs @@ -266,7 +266,7 @@ impl RoutingTable { } /// Check if two addresses belong to the same subnet. -pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { +fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { match (addr, other_addr) { (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => { // TODO: Consider moving this to a different place @@ -275,6 +275,8 @@ pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { } ip.octets()[0..3] == other_ip.octets()[0..3] } + + // TODO: check ipv6 _ => false, } } diff --git a/p2p/src/tls_config.rs b/p2p/src/tls_config.rs index 893c321..65d2adc 100644 --- a/p2p/src/tls_config.rs +++ b/p2p/src/tls_config.rs @@ -1,19 +1,25 @@ use std::sync::Arc; -use futures_rustls::rustls::{ - self, +#[cfg(feature = "smol")] +use futures_rustls::rustls; + +#[cfg(feature = "tokio")] +use tokio_rustls::rustls; + +use rustls::{ client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, crypto::{ aws_lc_rs::{self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, kx_group}, CryptoProvider, SupportedKxGroup, }, - pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}, server::danger::{ClientCertVerified, ClientCertVerifier}, CertificateError, DigitallySignedStruct, DistinguishedName, Error::InvalidCertificate, SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion, }; +use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; + use log::error; use x509_parser::{certificate::X509Certificate, parse_x509_certificate}; -- cgit v1.2.3