diff options
author | hozan23 <hozan23@karyontech.net> | 2024-04-11 10:19:20 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-19 13:51:30 +0200 |
commit | 0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch) | |
tree | 961d73218af672797d49f899289bef295bc56493 /p2p/src | |
parent | a69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff) |
add support for tokio & improve net crate api
Diffstat (limited to 'p2p/src')
-rw-r--r-- | p2p/src/backend.rs | 4 | ||||
-rw-r--r-- | p2p/src/codec.rs | 149 | ||||
-rw-r--r-- | p2p/src/connection.rs | 12 | ||||
-rw-r--r-- | p2p/src/connector.rs | 35 | ||||
-rw-r--r-- | p2p/src/discovery/lookup.rs | 71 | ||||
-rw-r--r-- | p2p/src/discovery/mod.rs | 45 | ||||
-rw-r--r-- | p2p/src/discovery/refresh.rs | 92 | ||||
-rw-r--r-- | p2p/src/error.rs | 18 | ||||
-rw-r--r-- | p2p/src/lib.rs | 4 | ||||
-rw-r--r-- | p2p/src/listener.rs | 36 | ||||
-rw-r--r-- | p2p/src/message.rs | 58 | ||||
-rw-r--r-- | p2p/src/monitor.rs | 2 | ||||
-rw-r--r-- | p2p/src/peer/mod.rs | 42 | ||||
-rw-r--r-- | p2p/src/peer_pool.rs | 65 | ||||
-rw-r--r-- | p2p/src/protocol.rs | 2 | ||||
-rw-r--r-- | p2p/src/protocols/ping.rs | 23 | ||||
-rw-r--r-- | p2p/src/routing_table/entry.rs | 8 | ||||
-rw-r--r-- | p2p/src/routing_table/mod.rs | 4 | ||||
-rw-r--r-- | p2p/src/tls_config.rs | 12 |
19 files changed, 344 insertions, 338 deletions
diff --git a/p2p/src/backend.rs b/p2p/src/backend.rs index d33f3dc..2f21b3e 100644 --- a/p2p/src/backend.rs +++ b/p2p/src/backend.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use log::info; -use karyon_core::{async_util::Executor, crypto::KeyPair, pubsub::Subscription}; +use karyon_core::{async_runtime::Executor, crypto::KeyPair, pubsub::Subscription}; use crate::{ config::Config, @@ -37,7 +37,7 @@ pub struct Backend { impl Backend { /// Creates a new Backend. - pub fn new(key_pair: &KeyPair, config: Config, ex: Executor<'static>) -> ArcBackend { + pub fn new(key_pair: &KeyPair, config: Config, ex: Executor) -> ArcBackend { let config = Arc::new(config); let monitor = Arc::new(Monitor::new()); let conn_queue = ConnQueue::new(); diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs index 726a2f7..3d0f323 100644 --- a/p2p/src/codec.rs +++ b/p2p/src/codec.rs @@ -1,120 +1,69 @@ -use std::time::Duration; +use karyon_core::util::{decode, encode, encode_into_slice}; -use bincode::{Decode, Encode}; - -use karyon_core::{ - async_util::timeout, - util::{decode, encode, encode_into_slice}, -}; - -use karyon_net::{Connection, NetError}; - -use crate::{ - message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, - Error, Result, +use karyon_net::{ + codec::{Codec, Decoder, Encoder, LengthCodec}, + Result, }; -pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} -impl<T: Encode + Decode + std::fmt::Debug> CodecMsg for T {} +use crate::message::{NetMsg, RefreshMsg}; -/// A Codec working with generic network connections. -/// -/// It is responsible for both decoding data received from the network and -/// encoding data before sending it. -pub struct Codec { - conn: Box<dyn Connection>, +#[derive(Clone)] +pub struct NetMsgCodec { + inner_codec: LengthCodec, } -impl Codec { - /// Creates a new Codec. - pub fn new(conn: Box<dyn Connection>) -> Self { - Self { conn } - } - - /// Reads a message of type `NetMsg` from the connection. - /// - /// It reads the first 6 bytes as the header of the message, then reads - /// and decodes the remaining message data based on the determined header. - pub async fn read(&self) -> Result<NetMsg> { - // Read 6 bytes to get the header of the incoming message - let mut buf = [0; MSG_HEADER_SIZE]; - self.read_exact(&mut buf).await?; - - // Decode the header from bytes to NetMsgHeader - let (header, _) = decode::<NetMsgHeader>(&buf)?; - - if header.payload_size > MAX_ALLOWED_MSG_SIZE { - return Err(Error::InvalidMsg( - "Message exceeds the maximum allowed size".to_string(), - )); +impl NetMsgCodec { + pub fn new() -> Self { + Self { + inner_codec: LengthCodec {}, } - - // Create a buffer to hold the message based on its length - let mut payload = vec![0; header.payload_size as usize]; - self.read_exact(&mut payload).await?; - - Ok(NetMsg { header, payload }) } +} - /// Writes a message of type `T` to the connection. - /// - /// Before appending the actual message payload, it calculates the length of - /// the encoded message in bytes and appends this length to the message header. - pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> { - let payload = encode(msg)?; - - // Create a buffer to hold the message header (6 bytes) - let header_buf = &mut [0; MSG_HEADER_SIZE]; - let header = NetMsgHeader { - command, - payload_size: payload.len() as u32, - }; - encode_into_slice(&header, header_buf)?; - - let mut buffer = vec![]; - // Append the header bytes to the buffer - buffer.extend_from_slice(header_buf); - // Append the message payload to the buffer - buffer.extend_from_slice(&payload); - - self.write_all(&buffer).await?; - Ok(()) - } +impl Codec for NetMsgCodec { + type Item = NetMsg; +} - /// Reads a message of type `NetMsg` with the given timeout. - pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> { - timeout(duration, self.read()) - .await - .map_err(|_| NetError::Timeout)? +impl Encoder for NetMsgCodec { + type EnItem = NetMsg; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> { + let src = encode(src)?; + self.inner_codec.encode(&src, dst) } +} - /// Reads the exact number of bytes required to fill `buf`. - async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.read(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); +impl Decoder for NetMsgCodec { + type DeItem = NetMsg; + fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> { + match self.inner_codec.decode(src)? { + Some((n, s)) => { + let (m, _) = decode::<Self::DeItem>(&s)?; + Ok(Some((n, m))) } + None => Ok(None), } - - Ok(()) } +} - /// Writes an entire buffer into the connection. - async fn write_all(&self, mut buf: &[u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.write(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at(n); - buf = rest; +#[derive(Clone)] +pub struct RefreshMsgCodec {} - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } +impl Codec for RefreshMsgCodec { + type Item = RefreshMsg; +} + +impl Encoder for RefreshMsgCodec { + type EnItem = RefreshMsg; + fn encode(&self, src: &Self::EnItem, dst: &mut [u8]) -> Result<usize> { + let n = encode_into_slice(src, dst)?; + Ok(n) + } +} - Ok(()) +impl Decoder for RefreshMsgCodec { + type DeItem = RefreshMsg; + fn decode(&self, src: &mut [u8]) -> Result<Option<(usize, Self::DeItem)>> { + let (m, n) = decode::<Self::DeItem>(src)?; + Ok(Some((n, m))) } } diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs index 9fa57cb..9a153f3 100644 --- a/p2p/src/connection.rs +++ b/p2p/src/connection.rs @@ -1,11 +1,11 @@ use std::{collections::VecDeque, fmt, sync::Arc}; -use smol::{channel::Sender, lock::Mutex}; +use async_channel::Sender; -use karyon_core::async_util::CondVar; +use karyon_core::{async_runtime::lock::Mutex, async_util::CondVar}; use karyon_net::Conn; -use crate::Result; +use crate::{message::NetMsg, Result}; /// Defines the direction of a network connection. #[derive(Clone, Debug)] @@ -25,7 +25,7 @@ impl fmt::Display for ConnDirection { pub struct NewConn { pub direction: ConnDirection, - pub conn: Conn, + pub conn: Conn<NetMsg>, pub disconnect_signal: Sender<Result<()>>, } @@ -44,8 +44,8 @@ impl ConnQueue { } /// Push a connection into the queue and wait for the disconnect signal - pub async fn handle(&self, conn: Conn, direction: ConnDirection) -> Result<()> { - let (disconnect_signal, chan) = smol::channel::bounded(1); + pub async fn handle(&self, conn: Conn<NetMsg>, direction: ConnDirection) -> Result<()> { + let (disconnect_signal, chan) = async_channel::bounded(1); let new_conn = NewConn { direction, conn, diff --git a/p2p/src/connector.rs b/p2p/src/connector.rs index de9e746..aea21ab 100644 --- a/p2p/src/connector.rs +++ b/p2p/src/connector.rs @@ -3,12 +3,15 @@ use std::{future::Future, sync::Arc}; use log::{error, trace, warn}; use karyon_core::{ - async_util::{Backoff, Executor, TaskGroup, TaskResult}, + async_runtime::Executor, + async_util::{Backoff, TaskGroup, TaskResult}, crypto::KeyPair, }; -use karyon_net::{tcp, tls, Conn, Endpoint, NetError}; +use karyon_net::{tcp, tls, Conn, Endpoint, Error as NetError}; use crate::{ + codec::NetMsgCodec, + message::NetMsg, monitor::{ConnEvent, Monitor}, slots::ConnectionSlots, tls_config::tls_client_config, @@ -23,7 +26,7 @@ pub struct Connector { key_pair: KeyPair, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// Manages available outbound slots. connection_slots: Arc<ConnectionSlots>, @@ -47,7 +50,7 @@ impl Connector { connection_slots: Arc<ConnectionSlots>, enable_tls: bool, monitor: Arc<Monitor>, - ex: Executor<'static>, + ex: Executor, ) -> Arc<Self> { Arc::new(Self { key_pair: key_pair.clone(), @@ -70,7 +73,11 @@ impl Connector { /// `Conn` instance. /// /// This method will block until it finds an available slot. - pub async fn connect(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<Conn> { + pub async fn connect( + &self, + endpoint: &Endpoint, + peer_id: &Option<PeerID>, + ) -> Result<Conn<NetMsg>> { self.connection_slots.wait_for_slot().await; self.connection_slots.add(); @@ -113,7 +120,7 @@ impl Connector { self: &Arc<Self>, endpoint: &Endpoint, peer_id: &Option<PeerID>, - callback: impl FnOnce(Conn) -> Fut + Send + 'static, + callback: impl FnOnce(Conn<NetMsg>) -> Fut + Send + 'static, ) -> Result<()> where Fut: Future<Output = Result<()>> + Send + 'static, @@ -138,14 +145,20 @@ impl Connector { Ok(()) } - async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<Conn> { + async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<Conn<NetMsg>> { if self.enable_tls { - let tls_config = tls_client_config(&self.key_pair, peer_id.clone())?; - tls::dial(endpoint, tls_config, DNS_NAME) + let tls_config = tls::ClientTlsConfig { + tcp_config: Default::default(), + client_config: tls_client_config(&self.key_pair, peer_id.clone())?, + dns_name: DNS_NAME.to_string(), + }; + tls::dial(endpoint, tls_config, NetMsgCodec::new()) .await - .map(|l| Box::new(l) as Conn) + .map(|l| Box::new(l) as karyon_net::Conn<NetMsg>) } else { - tcp::dial(endpoint).await.map(|l| Box::new(l) as Conn) + tcp::dial(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()) + .await + .map(|l| Box::new(l) as karyon_net::Conn<NetMsg>) } .map_err(Error::KaryonNet) } diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs index c81fbc6..cff4610 100644 --- a/p2p/src/discovery/lookup.rs +++ b/p2p/src/discovery/lookup.rs @@ -3,10 +3,13 @@ use std::{sync::Arc, time::Duration}; use futures_util::{stream::FuturesUnordered, StreamExt}; use log::{error, trace}; use rand::{rngs::OsRng, seq::SliceRandom, RngCore}; -use smol::lock::{Mutex, RwLock}; use karyon_core::{ - async_util::{timeout, Executor}, + async_runtime::{ + lock::{Mutex, RwLock}, + Executor, + }, + async_util::timeout, crypto::KeyPair, util::decode, }; @@ -14,7 +17,6 @@ use karyon_core::{ use karyon_net::{Conn, Endpoint}; use crate::{ - codec::Codec, connector::Connector, listener::Listener, message::{ @@ -64,7 +66,7 @@ impl LookupService { table: Arc<Mutex<RoutingTable>>, config: Arc<Config>, monitor: Arc<Monitor>, - ex: Executor<'static>, + ex: Executor, ) -> Self { let inbound_slots = Arc::new(ConnectionSlots::new(config.lookup_inbound_slots)); let outbound_slots = Arc::new(ConnectionSlots::new(config.lookup_outbound_slots)); @@ -228,8 +230,7 @@ impl LookupService { target_peer_id: &PeerID, ) -> Result<Vec<PeerMsg>> { let conn = self.connector.connect(&endpoint, &peer_id).await?; - let io_codec = Codec::new(conn); - let result = self.handle_outbound(io_codec, target_peer_id).await; + let result = self.handle_outbound(conn, target_peer_id).await; self.monitor .notify(&ConnEvent::Disconnected(endpoint).into()) @@ -242,14 +243,14 @@ impl LookupService { /// Handles outbound connection async fn handle_outbound( &self, - io_codec: Codec, + conn: Conn<NetMsg>, target_peer_id: &PeerID, ) -> Result<Vec<PeerMsg>> { trace!("Send Ping msg"); - self.send_ping_msg(&io_codec).await?; + self.send_ping_msg(&conn).await?; trace!("Send FindPeer msg"); - let peers = self.send_findpeer_msg(&io_codec, target_peer_id).await?; + let peers = self.send_findpeer_msg(&conn, target_peer_id).await?; if peers.0.len() >= MAX_PEERS_IN_PEERSMSG { return Err(Error::Lookup("Received too many peers in PeersMsg")); @@ -257,12 +258,12 @@ impl LookupService { trace!("Send Peer msg"); if let Some(endpoint) = &self.listen_endpoint { - self.send_peer_msg(&io_codec, endpoint.read().await.clone()) + self.send_peer_msg(&conn, endpoint.read().await.clone()) .await?; } trace!("Send Shutdown msg"); - self.send_shutdown_msg(&io_codec).await?; + self.send_shutdown_msg(&conn).await?; Ok(peers.0) } @@ -277,7 +278,7 @@ impl LookupService { let endpoint = Endpoint::Tcp(addr, self.config.discovery_port); let selfc = self.clone(); - let callback = |conn: Conn| async move { + let callback = |conn: Conn<NetMsg>| async move { let t = Duration::from_secs(selfc.config.lookup_connection_lifespan); timeout(t, selfc.handle_inbound(conn)).await??; Ok(()) @@ -288,10 +289,9 @@ impl LookupService { } /// Handles inbound connection - async fn handle_inbound(self: &Arc<Self>, conn: Conn) -> Result<()> { - let io_codec = Codec::new(conn); + async fn handle_inbound(self: &Arc<Self>, conn: Conn<NetMsg>) -> Result<()> { loop { - let msg: NetMsg = io_codec.read().await?; + let msg: NetMsg = conn.recv().await?; trace!("Receive msg {:?}", msg.header.command); if let NetMsgCmd::Shutdown = msg.header.command { @@ -304,12 +304,12 @@ impl LookupService { if !version_match(&self.config.version.req, &ping_msg.version) { return Err(Error::IncompatibleVersion("system: {}".into())); } - self.send_pong_msg(ping_msg.nonce, &io_codec).await?; + self.send_pong_msg(ping_msg.nonce, &conn).await?; } NetMsgCmd::FindPeer => { let (findpeer_msg, _) = decode::<FindPeerMsg>(&msg.payload)?; let peer_id = findpeer_msg.0; - self.send_peers_msg(&peer_id, &io_codec).await?; + self.send_peers_msg(&peer_id, &conn).await?; } NetMsgCmd::Peer => { let (peer, _) = decode::<PeerMsg>(&msg.payload)?; @@ -322,7 +322,7 @@ impl LookupService { } /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, io_codec: &Codec) -> Result<()> { + async fn send_ping_msg(&self, conn: &Conn<NetMsg>) -> Result<()> { trace!("Send Pong msg"); let mut nonce: [u8; 32] = [0; 32]; @@ -332,10 +332,10 @@ impl LookupService { version: self.config.version.v.clone(), nonce, }; - io_codec.write(NetMsgCmd::Ping, &ping_msg).await?; + conn.send(NetMsg::new(NetMsgCmd::Ping, &ping_msg)?).await?; let t = Duration::from_secs(self.config.lookup_response_timeout); - let recv_msg: NetMsg = io_codec.read_timeout(t).await?; + let recv_msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Pong, recv_msg); let (pong_msg, _) = decode::<PongMsg>(&payload)?; @@ -348,21 +348,24 @@ impl LookupService { } /// Sends a Pong msg - async fn send_pong_msg(&self, nonce: [u8; 32], io_codec: &Codec) -> Result<()> { + async fn send_pong_msg(&self, nonce: [u8; 32], conn: &Conn<NetMsg>) -> Result<()> { trace!("Send Pong msg"); - io_codec.write(NetMsgCmd::Pong, &PongMsg(nonce)).await?; + conn.send(NetMsg::new(NetMsgCmd::Pong, &PongMsg(nonce))?) + .await?; Ok(()) } /// Sends a FindPeer msg and wait to receivet the Peers msg. - async fn send_findpeer_msg(&self, io_codec: &Codec, peer_id: &PeerID) -> Result<PeersMsg> { + async fn send_findpeer_msg(&self, conn: &Conn<NetMsg>, peer_id: &PeerID) -> Result<PeersMsg> { trace!("Send FindPeer msg"); - io_codec - .write(NetMsgCmd::FindPeer, &FindPeerMsg(peer_id.clone())) - .await?; + conn.send(NetMsg::new( + NetMsgCmd::FindPeer, + &FindPeerMsg(peer_id.clone()), + )?) + .await?; let t = Duration::from_secs(self.config.lookup_response_timeout); - let recv_msg: NetMsg = io_codec.read_timeout(t).await?; + let recv_msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Peers, recv_msg); let (peers, _) = decode(&payload)?; @@ -371,19 +374,20 @@ impl LookupService { } /// Sends a Peers msg. - async fn send_peers_msg(&self, peer_id: &PeerID, io_codec: &Codec) -> Result<()> { + async fn send_peers_msg(&self, peer_id: &PeerID, conn: &Conn<NetMsg>) -> Result<()> { trace!("Send Peers msg"); let table = self.table.lock().await; let entries = table.closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG); drop(table); let peers: Vec<PeerMsg> = entries.into_iter().map(|e| e.into()).collect(); - io_codec.write(NetMsgCmd::Peers, &PeersMsg(peers)).await?; + conn.send(NetMsg::new(NetMsgCmd::Peers, &PeersMsg(peers))?) + .await?; Ok(()) } /// Sends a Peer msg. - async fn send_peer_msg(&self, io_codec: &Codec, endpoint: Endpoint) -> Result<()> { + async fn send_peer_msg(&self, conn: &Conn<NetMsg>, endpoint: Endpoint) -> Result<()> { trace!("Send Peer msg"); let peer_msg = PeerMsg { addr: endpoint.addr()?.clone(), @@ -391,14 +395,15 @@ impl LookupService { discovery_port: self.config.discovery_port, peer_id: self.id.clone(), }; - io_codec.write(NetMsgCmd::Peer, &peer_msg).await?; + conn.send(NetMsg::new(NetMsgCmd::Peer, &peer_msg)?).await?; Ok(()) } /// Sends a Shutdown msg. - async fn send_shutdown_msg(&self, io_codec: &Codec) -> Result<()> { + async fn send_shutdown_msg(&self, conn: &Conn<NetMsg>) -> Result<()> { trace!("Send Shutdown msg"); - io_codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await?; + conn.send(NetMsg::new(NetMsgCmd::Shutdown, &ShutdownMsg(0))?) + .await?; Ok(()) } } diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs index 3e437aa..19ae77a 100644 --- a/p2p/src/discovery/mod.rs +++ b/p2p/src/discovery/mod.rs @@ -5,10 +5,10 @@ use std::sync::Arc; use log::{error, info}; use rand::{rngs::OsRng, seq::SliceRandom}; -use smol::lock::Mutex; use karyon_core::{ - async_util::{Backoff, Executor, TaskGroup, TaskResult}, + async_runtime::{lock::Mutex, Executor}, + async_util::{Backoff, TaskGroup, TaskResult}, crypto::KeyPair, }; @@ -19,6 +19,7 @@ use crate::{ connection::{ConnDirection, ConnQueue}, connector::Connector, listener::Listener, + message::NetMsg, monitor::Monitor, routing_table::{ Entry, EntryStatusFlag, RoutingTable, CONNECTED_ENTRY, DISCONNECTED_ENTRY, @@ -45,6 +46,7 @@ pub struct Discovery { /// Connector connector: Arc<Connector>, + /// Listener listener: Arc<Listener>, @@ -53,11 +55,12 @@ pub struct Discovery { /// Inbound slots. pub(crate) inbound_slots: Arc<ConnectionSlots>, + /// Outbound slots. pub(crate) outbound_slots: Arc<ConnectionSlots>, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// Holds the configuration for the P2P network. config: Arc<Config>, @@ -71,7 +74,7 @@ impl Discovery { conn_queue: Arc<ConnQueue>, config: Arc<Config>, monitor: Arc<Monitor>, - ex: Executor<'static>, + ex: Executor, ) -> ArcDiscovery { let inbound_slots = Arc::new(ConnectionSlots::new(config.inbound_slots)); let outbound_slots = Arc::new(ConnectionSlots::new(config.outbound_slots)); @@ -180,7 +183,7 @@ impl Discovery { /// Start a listener and on success, return the resolved endpoint. async fn start_listener(self: &Arc<Self>, endpoint: &Endpoint) -> Result<Endpoint> { let selfc = self.clone(); - let callback = |c: Conn| async move { + let callback = |c: Conn<NetMsg>| async move { selfc.conn_queue.handle(c, ConnDirection::Inbound).await?; Ok(()) }; @@ -198,8 +201,8 @@ impl Discovery { async fn connect_loop(self: Arc<Self>) -> Result<()> { let backoff = Backoff::new(500, self.config.seeding_interval * 1000); loop { - let random_entry = self.random_entry(PENDING_ENTRY).await; - match random_entry { + let random_table_entry = self.random_table_entry(PENDING_ENTRY).await; + match random_table_entry { Some(entry) => { backoff.reset(); let endpoint = Endpoint::Tcp(entry.addr, entry.port); @@ -218,7 +221,7 @@ impl Discovery { let selfc = self.clone(); let pid_c = pid.clone(); let endpoint_c = endpoint.clone(); - let cback = |conn: Conn| async move { + let cback = |conn: Conn<NetMsg>| async move { let result = selfc.conn_queue.handle(conn, ConnDirection::Outbound).await; // If the entry is not in the routing table, ignore the result @@ -230,17 +233,17 @@ impl Discovery { match result { Err(Error::IncompatiblePeer) => { error!("Failed to do handshake: {endpoint_c} incompatible peer"); - selfc.update_entry(&pid, INCOMPATIBLE_ENTRY).await; + selfc.update_table_entry(&pid, INCOMPATIBLE_ENTRY).await; } Err(Error::PeerAlreadyConnected) => { - // TODO: Use the appropriate status. - selfc.update_entry(&pid, DISCONNECTED_ENTRY).await; + // TODO: Use an appropriate status. + selfc.update_table_entry(&pid, DISCONNECTED_ENTRY).await; } Err(_) => { - selfc.update_entry(&pid, UNSTABLE_ENTRY).await; + selfc.update_table_entry(&pid, UNSTABLE_ENTRY).await; } Ok(_) => { - selfc.update_entry(&pid, DISCONNECTED_ENTRY).await; + selfc.update_table_entry(&pid, DISCONNECTED_ENTRY).await; } } @@ -255,10 +258,10 @@ impl Discovery { if let Some(pid) = &pid { match result { Ok(_) => { - self.update_entry(pid, CONNECTED_ENTRY).await; + self.update_table_entry(pid, CONNECTED_ENTRY).await; } Err(_) => { - self.update_entry(pid, UNREACHABLE_ENTRY).await; + self.update_table_entry(pid, UNREACHABLE_ENTRY).await; } } } @@ -271,12 +274,16 @@ impl Discovery { /// table doesn't have an available entry, it will connect to one of the /// provided bootstrap endpoints in the `Config` and initiate the lookup. async fn start_seeding(&self) { - match self.random_entry(PENDING_ENTRY | CONNECTED_ENTRY).await { + match self + .random_table_entry(PENDING_ENTRY | CONNECTED_ENTRY) + .await + { Some(entry) => { let endpoint = Endpoint::Tcp(entry.addr, entry.discovery_port); let peer_id = Some(entry.key.into()); if let Err(err) = self.lookup_service.start_lookup(&endpoint, peer_id).await { - self.update_entry(&entry.key.into(), UNSTABLE_ENTRY).await; + self.update_table_entry(&entry.key.into(), UNSTABLE_ENTRY) + .await; error!("Failed to do lookup: {endpoint}: {err}"); } } @@ -292,12 +299,12 @@ impl Discovery { } /// Returns a random entry from routing table. - async fn random_entry(&self, entry_flag: EntryStatusFlag) -> Option<Entry> { + async fn random_table_entry(&self, entry_flag: EntryStatusFlag) -> Option<Entry> { self.table.lock().await.random_entry(entry_flag).cloned() } /// Update the entry status - async fn update_entry(&self, pid: &PeerID, entry_flag: EntryStatusFlag) { + async fn update_table_entry(&self, pid: &PeerID, entry_flag: EntryStatusFlag) { let table = &mut self.table.lock().await; table.update_entry(&pid.0, entry_flag); } diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index 035a581..0c49ac2 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -3,31 +3,28 @@ use std::{sync::Arc, time::Duration}; use bincode::{Decode, Encode}; use log::{error, info, trace}; use rand::{rngs::OsRng, RngCore}; -use smol::{ - lock::{Mutex, RwLock}, - stream::StreamExt, - Timer, -}; use karyon_core::{ - async_util::{timeout, Backoff, Executor, TaskGroup, TaskResult}, - util::{decode, encode}, + async_runtime::{ + lock::{Mutex, RwLock}, + Executor, + }, + async_util::{sleep, timeout, Backoff, TaskGroup, TaskResult}, }; -use karyon_net::{udp, Connection, Endpoint, NetError}; - -/// Maximum failures for an entry before removing it from the routing table. -pub const MAX_FAILURES: u32 = 3; - -/// Ping message size -const PINGMSG_SIZE: usize = 32; +use karyon_net::{udp, Connection, Endpoint, Error as NetError}; use crate::{ + codec::RefreshMsgCodec, + message::RefreshMsg, monitor::{ConnEvent, DiscoveryEvent, Monitor}, routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY}, Config, Error, Result, }; +/// Maximum failures for an entry before removing it from the routing table. +pub const MAX_FAILURES: u32 = 3; + #[derive(Decode, Encode, Debug, Clone)] pub struct PingMsg(pub [u8; 32]); @@ -42,10 +39,10 @@ pub struct RefreshService { listen_endpoint: Option<RwLock<Endpoint>>, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// A global executor - executor: Executor<'static>, + executor: Executor, /// Holds the configuration for the P2P network. config: Arc<Config>, @@ -60,7 +57,7 @@ impl RefreshService { config: Arc<Config>, table: Arc<Mutex<RoutingTable>>, monitor: Arc<Monitor>, - executor: Executor<'static>, + executor: Executor, ) -> Self { let listen_endpoint = config .listen_endpoint @@ -118,9 +115,8 @@ impl RefreshService { /// selects the first 8 entries (oldest entries) from each bucket in the /// routing table and starts sending Ping messages to the collected entries. async fn refresh_loop(self: Arc<Self>) -> Result<()> { - let mut timer = Timer::interval(Duration::from_secs(self.config.refresh_interval)); loop { - timer.next().await; + sleep(Duration::from_secs(self.config.refresh_interval)).await; trace!("Start refreshing the routing table..."); self.monitor @@ -162,7 +158,7 @@ impl RefreshService { } for task in tasks { - task.await; + let _ = task.await; } } } @@ -193,10 +189,10 @@ impl RefreshService { async fn connect(&self, entry: &Entry) -> Result<()> { let mut retry = 0; let endpoint = Endpoint::Udp(entry.addr.clone(), entry.discovery_port); - let conn = udp::dial(&endpoint).await?; + let conn = udp::dial(&endpoint, Default::default(), RefreshMsgCodec {}).await?; let backoff = Backoff::new(100, 5000); while retry < self.config.refresh_connect_retries { - match self.send_ping_msg(&conn).await { + match self.send_ping_msg(&conn, &endpoint).await { Ok(()) => return Ok(()), Err(Error::KaryonNet(NetError::Timeout)) => { retry += 1; @@ -214,7 +210,7 @@ impl RefreshService { /// Set up a UDP listener and start listening for Ping messages from other /// peers. async fn listen_loop(self: Arc<Self>, endpoint: Endpoint) -> Result<()> { - let conn = match udp::listen(&endpoint).await { + let conn = match udp::listen(&endpoint, Default::default(), RefreshMsgCodec {}).await { Ok(c) => { self.monitor .notify(&ConnEvent::Listening(endpoint.clone()).into()) @@ -240,46 +236,48 @@ impl RefreshService { } /// Listen to receive a Ping message and respond with a Pong message. - async fn listen_to_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> { - let mut buf = [0; PINGMSG_SIZE]; - let (_, endpoint) = conn.recv_from(&mut buf).await?; - + async fn listen_to_ping_msg(&self, conn: &udp::UdpConn<RefreshMsgCodec>) -> Result<()> { + let (msg, endpoint) = conn.recv().await?; self.monitor .notify(&ConnEvent::Accepted(endpoint.clone()).into()) .await; - let (ping_msg, _) = decode::<PingMsg>(&buf)?; - - let pong_msg = PongMsg(ping_msg.0); - let buffer = encode(&pong_msg)?; - - conn.send_to(&buffer, &endpoint).await?; + match msg { + RefreshMsg::Ping(m) => { + let pong_msg = RefreshMsg::Pong(m); + conn.send((pong_msg, endpoint.clone())).await?; + } + RefreshMsg::Pong(_) => return Err(Error::InvalidMsg("Unexpected pong msg".into())), + } self.monitor - .notify(&ConnEvent::Disconnected(endpoint.clone()).into()) + .notify(&ConnEvent::Disconnected(endpoint).into()) .await; Ok(()) } /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, conn: &udp::UdpConn) -> Result<()> { + async fn send_ping_msg( + &self, + conn: &udp::UdpConn<RefreshMsgCodec>, + endpoint: &Endpoint, + ) -> Result<()> { let mut nonce: [u8; 32] = [0; 32]; RngCore::fill_bytes(&mut OsRng, &mut nonce); + conn.send((RefreshMsg::Ping(nonce), endpoint.clone())) + .await?; - let ping_msg = PingMsg(nonce); - let buffer = encode(&ping_msg)?; - conn.write(&buffer).await?; - - let buf = &mut [0; PINGMSG_SIZE]; let t = Duration::from_secs(self.config.refresh_response_timeout); - timeout(t, conn.read(buf)).await??; - - let (pong_msg, _) = decode::<PongMsg>(buf)?; + let (msg, _) = timeout(t, conn.recv()).await??; - if ping_msg.0 != pong_msg.0 { - return Err(Error::InvalidPongMsg); + match msg { + RefreshMsg::Pong(n) => { + if n != nonce { + return Err(Error::InvalidPongMsg); + } + Ok(()) + } + _ => Err(Error::InvalidMsg("Unexpected ping msg".into())), } - - Ok(()) } } diff --git a/p2p/src/error.rs b/p2p/src/error.rs index b4ddc2e..97b7b7f 100644 --- a/p2p/src/error.rs +++ b/p2p/src/error.rs @@ -62,27 +62,35 @@ pub enum Error { #[error("Rcgen Error: {0}")] Rcgen(#[from] rcgen::Error), + #[cfg(feature = "smol")] #[error("Tls Error: {0}")] Rustls(#[from] futures_rustls::rustls::Error), + #[cfg(feature = "tokio")] + #[error("Tls Error: {0}")] + Rustls(#[from] tokio_rustls::rustls::Error), + #[error("Invalid DNS Name: {0}")] - InvalidDnsNameError(#[from] futures_rustls::pki_types::InvalidDnsNameError), + InvalidDnsNameError(#[from] rustls_pki_types::InvalidDnsNameError), #[error("Channel Send Error: {0}")] ChannelSend(String), #[error(transparent)] - ChannelRecv(#[from] smol::channel::RecvError), + ChannelRecv(#[from] async_channel::RecvError), #[error(transparent)] KaryonCore(#[from] karyon_core::error::Error), #[error(transparent)] - KaryonNet(#[from] karyon_net::NetError), + KaryonNet(#[from] karyon_net::Error), + + #[error("Other Error: {0}")] + Other(String), } -impl<T> From<smol::channel::SendError<T>> for Error { - fn from(error: smol::channel::SendError<T>) -> Self { +impl<T> From<async_channel::SendError<T>> for Error { + fn from(error: async_channel::SendError<T>) -> Self { Error::ChannelSend(error.to_string()) } } diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index 3605359..8f3cf45 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -5,7 +5,7 @@ //! use std::sync::Arc; //! //! use easy_parallel::Parallel; -//! use smol::{channel as smol_channel, future, Executor}; +//! use smol::{future, Executor}; //! //! use karyon_core::crypto::{KeyPair, KeyPairType}; //! use karyon_p2p::{Backend, Config, PeerID}; @@ -19,7 +19,7 @@ //! let ex = Arc::new(Executor::new()); //! //! // Create a new Backend -//! let backend = Backend::new(&key_pair, config, ex.clone()); +//! let backend = Backend::new(&key_pair, config, ex.clone().into()); //! //! let task = async { //! // Run the backend diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs index 4a41482..1abf79a 100644 --- a/p2p/src/listener.rs +++ b/p2p/src/listener.rs @@ -3,13 +3,16 @@ use std::{future::Future, sync::Arc}; use log::{debug, error, info}; use karyon_core::{ - async_util::{Executor, TaskGroup, TaskResult}, + async_runtime::Executor, + async_util::{TaskGroup, TaskResult}, crypto::KeyPair, }; -use karyon_net::{tcp, tls, Conn, ConnListener, Endpoint}; +use karyon_net::{tcp, tls, Conn, Endpoint}; use crate::{ + codec::NetMsgCodec, + message::NetMsg, monitor::{ConnEvent, Monitor}, slots::ConnectionSlots, tls_config::tls_server_config, @@ -22,7 +25,7 @@ pub struct Listener { key_pair: KeyPair, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// Manages available inbound slots. connection_slots: Arc<ConnectionSlots>, @@ -41,7 +44,7 @@ impl Listener { connection_slots: Arc<ConnectionSlots>, enable_tls: bool, monitor: Arc<Monitor>, - ex: Executor<'static>, + ex: Executor, ) -> Arc<Self> { Arc::new(Self { key_pair: key_pair.clone(), @@ -61,7 +64,7 @@ impl Listener { self: &Arc<Self>, endpoint: Endpoint, // https://github.com/rust-lang/rfcs/pull/2132 - callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static, + callback: impl FnOnce(Conn<NetMsg>) -> Fut + Clone + Send + 'static, ) -> Result<Endpoint> where Fut: Future<Output = Result<()>> + Send + 'static, @@ -82,7 +85,7 @@ impl Listener { } }; - let resolved_endpoint = listener.local_endpoint()?; + let resolved_endpoint = listener.local_endpoint().map_err(Error::from)?; info!("Start listening on {resolved_endpoint}"); @@ -99,8 +102,8 @@ impl Listener { async fn listen_loop<Fut>( self: Arc<Self>, - listener: Box<dyn ConnListener>, - callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static, + listener: karyon_net::Listener<NetMsg>, + callback: impl FnOnce(Conn<NetMsg>) -> Fut + Clone + Send + 'static, ) where Fut: Future<Output = Result<()>> + Send + 'static, { @@ -112,7 +115,7 @@ impl Listener { let (conn, endpoint) = match result { Ok(c) => { let endpoint = match c.peer_endpoint() { - Ok(e) => e, + Ok(ep) => ep, Err(err) => { self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; error!("Failed to accept a new connection: {err}"); @@ -151,16 +154,19 @@ impl Listener { } } - async fn listen(&self, endpoint: &Endpoint) -> Result<karyon_net::Listener> { + async fn listen(&self, endpoint: &Endpoint) -> Result<karyon_net::Listener<NetMsg>> { if self.enable_tls { - let tls_config = tls_server_config(&self.key_pair)?; - tls::listen(endpoint, tls_config) + let tls_config = tls::ServerTlsConfig { + tcp_config: Default::default(), + server_config: tls_server_config(&self.key_pair)?, + }; + tls::listen(endpoint, tls_config, NetMsgCodec::new()) .await - .map(|l| Box::new(l) as karyon_net::Listener) + .map(|l| Box::new(l) as karyon_net::Listener<NetMsg>) } else { - tcp::listen(endpoint) + tcp::listen(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new()) .await - .map(|l| Box::new(l) as karyon_net::Listener) + .map(|l| Box::new(l) as karyon_net::Listener<NetMsg>) } .map_err(Error::KaryonNet) } diff --git a/p2p/src/message.rs b/p2p/src/message.rs index 1342110..6498ef7 100644 --- a/p2p/src/message.rs +++ b/p2p/src/message.rs @@ -2,15 +2,10 @@ use std::collections::HashMap; use bincode::{Decode, Encode}; +use karyon_core::util::encode; use karyon_net::{Addr, Port}; -use crate::{protocol::ProtocolID, routing_table::Entry, version::VersionInt, PeerID}; - -/// The size of the message header, in bytes. -pub const MSG_HEADER_SIZE: usize = 6; - -/// The maximum allowed size for a message in bytes. -pub const MAX_ALLOWED_MSG_SIZE: u32 = 1024 * 1024; // 1MB +use crate::{protocol::ProtocolID, routing_table::Entry, version::VersionInt, PeerID, Result}; /// Defines the main message in the karyon p2p network. /// @@ -23,11 +18,19 @@ pub struct NetMsg { pub payload: Vec<u8>, } +impl NetMsg { + pub fn new<T: Encode>(command: NetMsgCmd, t: T) -> Result<Self> { + Ok(Self { + header: NetMsgHeader { command }, + payload: encode(&t)?, + }) + } +} + /// Represents the header of a message. #[derive(Decode, Encode, Debug, Clone)] pub struct NetMsgHeader { pub command: NetMsgCmd, - pub payload_size: u32, } /// Defines message commands. @@ -39,7 +42,7 @@ pub enum NetMsgCmd { Protocol, Shutdown, - // NOTE: The following commands are used during the lookup process. + // The following commands are used during the lookup process. Ping, Pong, FindPeer, @@ -47,6 +50,12 @@ pub enum NetMsgCmd { Peers, } +#[derive(Decode, Encode, Debug, Clone)] +pub enum RefreshMsg { + Ping([u8; 32]), + Pong([u8; 32]), +} + /// Defines a message related to a specific protocol. #[derive(Decode, Encode, Debug, Clone)] pub struct ProtocolMsg { @@ -103,21 +112,6 @@ pub struct PeerMsg { #[derive(Decode, Encode, Debug)] pub struct PeersMsg(pub Vec<PeerMsg>); -macro_rules! get_msg_payload { - ($a:ident, $b:ident) => { - if let NetMsgCmd::$a = $b.header.command { - $b.payload - } else { - return Err(Error::InvalidMsg(format!( - "Unexpected msg {:?}", - $b.header.command - ))); - } - }; -} - -pub(super) use get_msg_payload; - impl From<Entry> for PeerMsg { fn from(entry: Entry) -> PeerMsg { PeerMsg { @@ -139,3 +133,19 @@ impl From<PeerMsg> for Entry { } } } + +macro_rules! get_msg_payload { + ($a:ident, $b:ident) => { + if let NetMsgCmd::$a = $b.header.command { + $b.payload + } else { + return Err(Error::InvalidMsg(format!( + "Expected {:?} msg found {:?} msg", + stringify!($a), + $b.header.command + ))); + } + }; +} + +pub(super) use get_msg_payload; diff --git a/p2p/src/monitor.rs b/p2p/src/monitor.rs index b0ce028..48719c0 100644 --- a/p2p/src/monitor.rs +++ b/p2p/src/monitor.rs @@ -26,7 +26,7 @@ use karyon_net::Endpoint; /// let ex = Arc::new(Executor::new()); /// /// let key_pair = KeyPair::generate(&KeyPairType::Ed25519); -/// let backend = Backend::new(&key_pair, Config::default(), ex); +/// let backend = Backend::new(&key_pair, Config::default(), ex.into()); /// /// // Create a new Subscription /// let sub = backend.monitor().await; diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs index ca68530..f0f6f17 100644 --- a/p2p/src/peer/mod.rs +++ b/p2p/src/peer/mod.rs @@ -4,24 +4,22 @@ pub use peer_id::PeerID; use std::sync::Arc; +use async_channel::{Receiver, Sender}; +use bincode::{Decode, Encode}; use log::{error, trace}; -use smol::{ - channel::{self, Receiver, Sender}, - lock::RwLock, -}; use karyon_core::{ - async_util::{select, Either, Executor, TaskGroup, TaskResult}, + async_runtime::{lock::RwLock, Executor}, + async_util::{select, Either, TaskGroup, TaskResult}, event::{ArcEventSys, EventListener, EventSys}, util::{decode, encode}, }; -use karyon_net::Endpoint; +use karyon_net::{Conn, Endpoint}; use crate::{ - codec::{Codec, CodecMsg}, connection::ConnDirection, - message::{NetMsgCmd, ProtocolMsg, ShutdownMsg}, + message::{NetMsg, NetMsgCmd, ProtocolMsg, ShutdownMsg}, peer_pool::{ArcPeerPool, WeakPeerPool}, protocol::{Protocol, ProtocolEvent, ProtocolID}, Config, Error, Result, @@ -36,8 +34,8 @@ pub struct Peer { /// A weak pointer to `PeerPool` peer_pool: WeakPeerPool, - /// Holds the Codec for the peer connection - codec: Codec, + /// Holds the peer connection + conn: Conn<NetMsg>, /// Remote endpoint for the peer remote_endpoint: Endpoint, @@ -55,7 +53,7 @@ pub struct Peer { stop_chan: (Sender<Result<()>>, Receiver<Result<()>>), /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, } impl Peer { @@ -63,21 +61,21 @@ impl Peer { pub fn new( peer_pool: WeakPeerPool, id: &PeerID, - codec: Codec, + conn: Conn<NetMsg>, remote_endpoint: Endpoint, conn_direction: ConnDirection, - ex: Executor<'static>, + ex: Executor, ) -> ArcPeer { Arc::new(Peer { id: id.clone(), peer_pool, - codec, + conn, protocol_ids: RwLock::new(Vec::new()), remote_endpoint, conn_direction, protocol_events: EventSys::new(), task_group: TaskGroup::with_executor(ex), - stop_chan: channel::bounded(1), + stop_chan: async_channel::bounded(1), }) } @@ -88,7 +86,7 @@ impl Peer { } /// Send a message to the peer connection using the specified protocol. - pub async fn send<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { + pub async fn send<T: Encode + Decode>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { let payload = encode(msg)?; let proto_msg = ProtocolMsg { @@ -96,12 +94,14 @@ impl Peer { payload: payload.to_vec(), }; - self.codec.write(NetMsgCmd::Protocol, &proto_msg).await?; + self.conn + .send(NetMsg::new(NetMsgCmd::Protocol, &proto_msg)?) + .await?; Ok(()) } /// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) { + pub async fn broadcast<T: Encode + Decode>(&self, protocol_id: &ProtocolID, msg: &T) { self.peer_pool().broadcast(protocol_id, msg).await; } @@ -123,7 +123,9 @@ impl Peer { let _ = self.stop_chan.0.try_send(Ok(())); // No need to handle the error here - let _ = self.codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await; + let shutdown_msg = + NetMsg::new(NetMsgCmd::Shutdown, &ShutdownMsg(0)).expect("pack shutdown message"); + let _ = self.conn.send(shutdown_msg).await; // Force shutting down self.task_group.cancel().await; @@ -170,7 +172,7 @@ impl Peer { /// Start a read loop to handle incoming messages from the peer connection. async fn read_loop(&self) -> Result<()> { loop { - let fut = select(self.stop_chan.1.recv(), self.codec.read()).await; + let fut = select(self.stop_chan.1.recv(), self.conn.recv()).await; let result = match fut { Either::Left(stop_signal) => { trace!("Peer {} received a stop signal", self.id); diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index 4e20c99..8b16ef5 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -4,21 +4,22 @@ use std::{ time::Duration, }; +use async_channel::Sender; +use bincode::{Decode, Encode}; use log::{error, info, trace, warn}; -use smol::{ - channel::Sender, - lock::{Mutex, RwLock}, -}; use karyon_core::{ - async_util::{Executor, TaskGroup, TaskResult}, + async_runtime::{ + lock::{Mutex, RwLock}, + Executor, + }, + async_util::{timeout, TaskGroup, TaskResult}, util::decode, }; use karyon_net::Conn; use crate::{ - codec::{Codec, CodecMsg}, config::Config, connection::{ConnDirection, ConnQueue}, message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, @@ -50,10 +51,10 @@ pub struct PeerPool { protocol_versions: Arc<RwLock<HashMap<ProtocolID, Version>>>, /// Managing spawned tasks. - task_group: TaskGroup<'static>, + task_group: TaskGroup, /// A global Executor - executor: Executor<'static>, + executor: Executor, /// The Configuration for the P2P network. pub(crate) config: Arc<Config>, @@ -69,7 +70,7 @@ impl PeerPool { conn_queue: Arc<ConnQueue>, config: Arc<Config>, monitor: Arc<Monitor>, - executor: Executor<'static>, + executor: Executor, ) -> Arc<Self> { let protocols = RwLock::new(HashMap::new()); let protocol_versions = Arc::new(RwLock::new(HashMap::new())); @@ -137,7 +138,7 @@ impl PeerPool { } /// Broadcast a message to all connected peers using the specified protocol. - pub async fn broadcast<T: CodecMsg>(&self, proto_id: &ProtocolID, msg: &T) { + pub async fn broadcast<T: Decode + Encode>(&self, proto_id: &ProtocolID, msg: &T) { for (pid, peer) in self.peers.lock().await.iter() { if let Err(err) = peer.send(proto_id, msg).await { error!("failed to send msg to {pid}: {err}"); @@ -149,15 +150,14 @@ impl PeerPool { /// Add a new peer to the peer list. pub async fn new_peer( self: &Arc<Self>, - conn: Conn, + conn: Conn<NetMsg>, conn_direction: &ConnDirection, disconnect_signal: Sender<Result<()>>, ) -> Result<()> { let endpoint = conn.peer_endpoint()?; - let codec = Codec::new(conn); // Do a handshake with the connection before creating a new peer. - let pid = self.do_handshake(&codec, conn_direction).await?; + let pid = self.do_handshake(&conn, conn_direction).await?; // TODO: Consider restricting the subnet for inbound connections if self.contains_peer(&pid).await { @@ -168,7 +168,7 @@ impl PeerPool { let peer = Peer::new( Arc::downgrade(self), &pid, - codec, + conn, endpoint.clone(), conn_direction.clone(), self.executor.clone(), @@ -234,16 +234,21 @@ impl PeerPool { } /// Initiate a handshake with a connection. - async fn do_handshake(&self, codec: &Codec, conn_direction: &ConnDirection) -> Result<PeerID> { + async fn do_handshake( + &self, + conn: &Conn<NetMsg>, + conn_direction: &ConnDirection, + ) -> Result<PeerID> { + trace!("Handshake started: {}", conn.peer_endpoint()?); match conn_direction { ConnDirection::Inbound => { - let result = self.wait_vermsg(codec).await; + let result = self.wait_vermsg(conn).await; match result { Ok(_) => { - self.send_verack(codec, true).await?; + self.send_verack(conn, true).await?; } Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { - self.send_verack(codec, false).await?; + self.send_verack(conn, false).await?; } _ => {} } @@ -251,14 +256,14 @@ impl PeerPool { } ConnDirection::Outbound => { - self.send_vermsg(codec).await?; - self.wait_verack(codec).await + self.send_vermsg(conn).await?; + self.wait_verack(conn).await } } } /// Send a Version message - async fn send_vermsg(&self, codec: &Codec) -> Result<()> { + async fn send_vermsg(&self, conn: &Conn<NetMsg>) -> Result<()> { let pids = self.protocol_versions.read().await; let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); drop(pids); @@ -270,16 +275,16 @@ impl PeerPool { }; trace!("Send VerMsg"); - codec.write(NetMsgCmd::Version, &vermsg).await?; + conn.send(NetMsg::new(NetMsgCmd::Version, &vermsg)?).await?; Ok(()) } /// Wait for a Version message /// /// Returns the peer's ID upon successfully receiving the Version message. - async fn wait_vermsg(&self, codec: &Codec) -> Result<PeerID> { - let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = codec.read_timeout(timeout).await?; + async fn wait_vermsg(&self, conn: &Conn<NetMsg>) -> Result<PeerID> { + let t = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Version, msg); let (vermsg, _) = decode::<VerMsg>(&payload)?; @@ -295,23 +300,23 @@ impl PeerPool { } /// Send a Verack message - async fn send_verack(&self, codec: &Codec, ack: bool) -> Result<()> { + async fn send_verack(&self, conn: &Conn<NetMsg>, ack: bool) -> Result<()> { let verack = VerAckMsg { peer_id: self.id.clone(), ack, }; trace!("Send VerAckMsg {:?}", verack); - codec.write(NetMsgCmd::Verack, &verack).await?; + conn.send(NetMsg::new(NetMsgCmd::Verack, &verack)?).await?; Ok(()) } /// Wait for a Verack message /// /// Returns the peer's ID upon successfully receiving the Verack message. - async fn wait_verack(&self, codec: &Codec) -> Result<PeerID> { - let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = codec.read_timeout(timeout).await?; + async fn wait_verack(&self, conn: &Conn<NetMsg>) -> Result<PeerID> { + let t = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = timeout(t, conn.recv()).await??; let payload = get_msg_payload!(Verack, msg); let (verack, _) = decode::<VerAckMsg>(&payload)?; diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index f28659c..6153ea1 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -92,7 +92,7 @@ impl EventValue for ProtocolEvent { /// let ex = Arc::new(Executor::new()); /// /// // Create a new Backend -/// let backend = Backend::new(&key_pair, config, ex); +/// let backend = Backend::new(&key_pair, config, ex.into()); /// /// // Attach the NewProtocol /// let c = move |peer| NewProtocol::new(peer); diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs index f04e059..654644a 100644 --- a/p2p/src/protocols/ping.rs +++ b/p2p/src/protocols/ping.rs @@ -1,23 +1,19 @@ use std::{sync::Arc, time::Duration}; +use async_channel::{Receiver, Sender}; use async_trait::async_trait; use bincode::{Decode, Encode}; use log::trace; use rand::{rngs::OsRng, RngCore}; -use smol::{ - channel, - channel::{Receiver, Sender}, - stream::StreamExt, - Timer, -}; use karyon_core::{ - async_util::{select, timeout, Either, Executor, TaskGroup, TaskResult}, + async_runtime::Executor, + async_util::{select, sleep, timeout, Either, TaskGroup, TaskResult}, event::EventListener, util::decode, }; -use karyon_net::NetError; +use karyon_net::Error as NetError; use crate::{ peer::ArcPeer, @@ -38,12 +34,12 @@ pub struct PingProtocol { peer: ArcPeer, ping_interval: u64, ping_timeout: u64, - task_group: TaskGroup<'static>, + task_group: TaskGroup, } impl PingProtocol { #[allow(clippy::new_ret_no_self)] - pub fn new(peer: ArcPeer, executor: Executor<'static>) -> ArcProtocol { + pub fn new(peer: ArcPeer, executor: Executor) -> ArcProtocol { let ping_interval = peer.config().ping_interval; let ping_timeout = peer.config().ping_timeout; Arc::new(Self { @@ -87,12 +83,11 @@ impl PingProtocol { } async fn ping_loop(self: Arc<Self>, chan: Receiver<[u8; 32]>) -> Result<()> { - let mut timer = Timer::interval(Duration::from_secs(self.ping_interval)); let rng = &mut OsRng; let mut retry = 0; while retry < MAX_FAILUERS { - timer.next().await; + sleep(Duration::from_secs(self.ping_interval)).await; let mut ping_nonce: [u8; 32] = [0; 32]; rng.fill_bytes(&mut ping_nonce); @@ -130,8 +125,8 @@ impl Protocol for PingProtocol { async fn start(self: Arc<Self>) -> Result<()> { trace!("Start Ping protocol"); - let (pong_chan, pong_chan_recv) = channel::bounded(1); - let (stop_signal_s, stop_signal) = channel::bounded::<Result<()>>(1); + let (pong_chan, pong_chan_recv) = async_channel::bounded(1); + let (stop_signal_s, stop_signal) = async_channel::bounded::<Result<()>>(1); let selfc = self.clone(); self.task_group.spawn( diff --git a/p2p/src/routing_table/entry.rs b/p2p/src/routing_table/entry.rs index 3fc8a6b..1427c2b 100644 --- a/p2p/src/routing_table/entry.rs +++ b/p2p/src/routing_table/entry.rs @@ -5,6 +5,9 @@ use karyon_net::{Addr, Port}; /// Specifies the size of the key, in bytes. pub const KEY_SIZE: usize = 32; +/// The unique key identifying the peer. +pub type Key = [u8; KEY_SIZE]; + /// An Entry represents a peer in the routing table. #[derive(Encode, Decode, Clone, Debug)] pub struct Entry { @@ -20,14 +23,11 @@ pub struct Entry { impl PartialEq for Entry { fn eq(&self, other: &Self) -> bool { - // TODO: this should also compare both addresses (the self.addr == other.addr) + // XXX: should we compare both self.addr and other.addr??? self.key == other.key } } -/// The unique key identifying the peer. -pub type Key = [u8; KEY_SIZE]; - /// Calculates the XOR distance between two provided keys. /// /// The XOR distance is a metric used in Kademlia to measure the closeness diff --git a/p2p/src/routing_table/mod.rs b/p2p/src/routing_table/mod.rs index 6854546..bbf4801 100644 --- a/p2p/src/routing_table/mod.rs +++ b/p2p/src/routing_table/mod.rs @@ -266,7 +266,7 @@ impl RoutingTable { } /// Check if two addresses belong to the same subnet. -pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { +fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { match (addr, other_addr) { (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => { // TODO: Consider moving this to a different place @@ -275,6 +275,8 @@ pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { } ip.octets()[0..3] == other_ip.octets()[0..3] } + + // TODO: check ipv6 _ => false, } } diff --git a/p2p/src/tls_config.rs b/p2p/src/tls_config.rs index 893c321..65d2adc 100644 --- a/p2p/src/tls_config.rs +++ b/p2p/src/tls_config.rs @@ -1,19 +1,25 @@ use std::sync::Arc; -use futures_rustls::rustls::{ - self, +#[cfg(feature = "smol")] +use futures_rustls::rustls; + +#[cfg(feature = "tokio")] +use tokio_rustls::rustls; + +use rustls::{ client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, crypto::{ aws_lc_rs::{self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, kx_group}, CryptoProvider, SupportedKxGroup, }, - pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}, server::danger::{ClientCertVerified, ClientCertVerifier}, CertificateError, DigitallySignedStruct, DistinguishedName, Error::InvalidCertificate, SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion, }; +use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; + use log::error; use x509_parser::{certificate::X509Certificate, parse_x509_certificate}; |