diff options
Diffstat (limited to 'p2p/src')
-rw-r--r-- | p2p/src/backend.rs | 139 | ||||
-rw-r--r-- | p2p/src/config.rs | 105 | ||||
-rw-r--r-- | p2p/src/discovery/lookup.rs | 366 | ||||
-rw-r--r-- | p2p/src/discovery/mod.rs | 262 | ||||
-rw-r--r-- | p2p/src/discovery/refresh.rs | 289 | ||||
-rw-r--r-- | p2p/src/error.rs | 82 | ||||
-rw-r--r-- | p2p/src/io_codec.rs | 102 | ||||
-rw-r--r-- | p2p/src/lib.rs | 27 | ||||
-rw-r--r-- | p2p/src/message.rs | 133 | ||||
-rw-r--r-- | p2p/src/monitor.rs | 154 | ||||
-rw-r--r-- | p2p/src/net/connection_queue.rs | 52 | ||||
-rw-r--r-- | p2p/src/net/connector.rs | 125 | ||||
-rw-r--r-- | p2p/src/net/listener.rs | 142 | ||||
-rw-r--r-- | p2p/src/net/mod.rs | 27 | ||||
-rw-r--r-- | p2p/src/net/slots.rs | 54 | ||||
-rw-r--r-- | p2p/src/peer/mod.rs | 237 | ||||
-rw-r--r-- | p2p/src/peer/peer_id.rs | 41 | ||||
-rw-r--r-- | p2p/src/peer_pool.rs | 337 | ||||
-rw-r--r-- | p2p/src/protocol.rs | 113 | ||||
-rw-r--r-- | p2p/src/protocols/mod.rs | 3 | ||||
-rw-r--r-- | p2p/src/protocols/ping.rs | 173 | ||||
-rw-r--r-- | p2p/src/routing_table/bucket.rs | 123 | ||||
-rw-r--r-- | p2p/src/routing_table/entry.rs | 41 | ||||
-rw-r--r-- | p2p/src/routing_table/mod.rs | 461 | ||||
-rw-r--r-- | p2p/src/utils/mod.rs | 21 | ||||
-rw-r--r-- | p2p/src/utils/version.rs | 93 |
26 files changed, 3702 insertions, 0 deletions
diff --git a/p2p/src/backend.rs b/p2p/src/backend.rs new file mode 100644 index 0000000..290e3e7 --- /dev/null +++ b/p2p/src/backend.rs @@ -0,0 +1,139 @@ +use std::sync::Arc; + +use log::info; + +use karyons_core::{pubsub::Subscription, Executor}; + +use crate::{ + config::Config, + discovery::{ArcDiscovery, Discovery}, + monitor::{Monitor, MonitorEvent}, + net::ConnQueue, + peer_pool::PeerPool, + protocol::{ArcProtocol, Protocol}, + ArcPeer, PeerID, Result, +}; + +pub type ArcBackend = Arc<Backend>; + +/// Backend serves as the central entry point for initiating and managing +/// the P2P network. +/// +/// +/// # Example +/// ``` +/// use std::sync::Arc; +/// +/// use easy_parallel::Parallel; +/// use smol::{channel as smol_channel, future, Executor}; +/// +/// use karyons_p2p::{Backend, Config, PeerID}; +/// +/// let peer_id = PeerID::random(); +/// +/// // Create the configuration for the backend. +/// let mut config = Config::default(); +/// +/// // Create a new Backend +/// let backend = Backend::new(peer_id, config); +/// +/// // Create a new Executor +/// let ex = Arc::new(Executor::new()); +/// +/// let task = async { +/// // Run the backend +/// backend.run(ex.clone()).await.unwrap(); +/// +/// // .... +/// +/// // Shutdown the backend +/// backend.shutdown().await; +/// }; +/// +/// future::block_on(ex.run(task)); +/// +/// ``` +pub struct Backend { + /// The Configuration for the P2P network. + config: Arc<Config>, + + /// Peer ID. + id: PeerID, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, + + /// Discovery instance. + discovery: ArcDiscovery, + + /// PeerPool instance. + peer_pool: Arc<PeerPool>, +} + +impl Backend { + /// Creates a new Backend. 
+ pub fn new(id: PeerID, config: Config) -> ArcBackend { + let config = Arc::new(config); + let monitor = Arc::new(Monitor::new()); + + let conn_queue = ConnQueue::new(); + + let peer_pool = PeerPool::new(&id, conn_queue.clone(), config.clone(), monitor.clone()); + let discovery = Discovery::new(&id, conn_queue, config.clone(), monitor.clone()); + + Arc::new(Self { + id: id.clone(), + monitor, + discovery, + config, + peer_pool, + }) + } + + /// Run the Backend, starting the PeerPool and Discovery instances. + pub async fn run(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + info!("Run the backend {}", self.id); + self.peer_pool.start(ex.clone()).await?; + self.discovery.start(ex.clone()).await?; + Ok(()) + } + + /// Attach a custom protocol to the network + pub async fn attach_protocol<P: Protocol>( + &self, + c: impl Fn(ArcPeer) -> ArcProtocol + Send + Sync + 'static, + ) -> Result<()> { + self.peer_pool.attach_protocol::<P>(Box::new(c)).await + } + + /// Returns the number of currently connected peers. + pub async fn peers(&self) -> usize { + self.peer_pool.peers_len().await + } + + /// Returns the `Config`. + pub fn config(&self) -> Arc<Config> { + self.config.clone() + } + + /// Returns the number of occupied inbound slots. + pub fn inbound_slots(&self) -> usize { + self.discovery.inbound_slots.load() + } + + /// Returns the number of occupied outbound slots. + pub fn outbound_slots(&self) -> usize { + self.discovery.outbound_slots.load() + } + + /// Subscribes to the monitor to receive network events. + pub async fn monitor(&self) -> Subscription<MonitorEvent> { + self.monitor.subscribe().await + } + + /// Shuts down the Backend. 
+ pub async fn shutdown(&self) { + self.discovery.shutdown().await; + self.peer_pool.shutdown().await; + } +} diff --git a/p2p/src/config.rs b/p2p/src/config.rs new file mode 100644 index 0000000..ebecbf0 --- /dev/null +++ b/p2p/src/config.rs @@ -0,0 +1,105 @@ +use karyons_net::{Endpoint, Port}; + +use crate::utils::Version; + +/// the Configuration for the P2P network. +pub struct Config { + /// Represents the network version. + pub version: Version, + + ///////////////// + // PeerPool + //////////////// + /// Timeout duration for the handshake with new peers, in seconds. + pub handshake_timeout: u64, + /// Interval at which the ping protocol sends ping messages to a peer to + /// maintain connections, in seconds. + pub ping_interval: u64, + /// Timeout duration for receiving the pong message corresponding to the + /// sent ping message, in seconds. + pub ping_timeout: u64, + /// The maximum number of retries for outbound connection establishment. + pub max_connect_retries: usize, + + ///////////////// + // DISCOVERY + //////////////// + /// A list of bootstrap peers for the seeding process. + pub bootstrap_peers: Vec<Endpoint>, + /// An optional listening endpoint to accept incoming connections. + pub listen_endpoint: Option<Endpoint>, + /// A list of endpoints representing peers that the `Discovery` will + /// manually connect to. + pub peer_endpoints: Vec<Endpoint>, + /// The number of available inbound slots for incoming connections. + pub inbound_slots: usize, + /// The number of available outbound slots for outgoing connections. + pub outbound_slots: usize, + /// TCP/UDP port for lookup and refresh processes. + pub discovery_port: Port, + /// Time interval, in seconds, at which the Discovery restarts the + /// seeding process. + pub seeding_interval: u64, + + ///////////////// + // LOOKUP + //////////////// + /// The number of available inbound slots for incoming connections during + /// the lookup process. 
+ pub lookup_inbound_slots: usize, + /// The number of available outbound slots for outgoing connections during + /// the lookup process. + pub lookup_outbound_slots: usize, + /// Timeout duration for a peer response during the lookup process, in + /// seconds. + pub lookup_response_timeout: u64, + /// Maximum allowable time for a live connection with a peer during the + /// lookup process, in seconds. + pub lookup_connection_lifespan: u64, + /// The maximum number of retries for outbound connection establishment + /// during the lookup process. + pub lookup_connect_retries: usize, + + ///////////////// + // REFRESH + //////////////// + /// Interval at which the table refreshes its entries, in seconds. + pub refresh_interval: u64, + /// Timeout duration for a peer response during the table refresh process, + /// in seconds. + pub refresh_response_timeout: u64, + /// The maximum number of retries for outbound connection establishment + /// during the refresh process. + pub refresh_connect_retries: usize, +} + +impl Default for Config { + fn default() -> Self { + Config { + version: "0.1.0".parse().unwrap(), + + handshake_timeout: 2, + ping_interval: 20, + ping_timeout: 2, + + bootstrap_peers: vec![], + listen_endpoint: None, + peer_endpoints: vec![], + inbound_slots: 12, + outbound_slots: 12, + max_connect_retries: 3, + discovery_port: 0, + seeding_interval: 60, + + lookup_inbound_slots: 20, + lookup_outbound_slots: 20, + lookup_response_timeout: 1, + lookup_connection_lifespan: 3, + lookup_connect_retries: 3, + + refresh_interval: 1800, + refresh_response_timeout: 1, + refresh_connect_retries: 3, + } + } +} diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs new file mode 100644 index 0000000..f404133 --- /dev/null +++ b/p2p/src/discovery/lookup.rs @@ -0,0 +1,366 @@ +use std::{sync::Arc, time::Duration}; + +use futures_util::{stream::FuturesUnordered, StreamExt}; +use log::{error, trace}; +use rand::{rngs::OsRng, seq::SliceRandom, RngCore}; +use 
smol::lock::{Mutex, RwLock}; + +use karyons_core::{async_utils::timeout, utils::decode, Executor}; + +use karyons_net::{Conn, Endpoint}; + +use crate::{ + io_codec::IOCodec, + message::{ + get_msg_payload, FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, + ShutdownMsg, + }, + monitor::{ConnEvent, DiscoveryEvent, Monitor}, + net::{ConnectionSlots, Connector, Listener}, + routing_table::RoutingTable, + utils::version_match, + Config, Error, PeerID, Result, +}; + +/// Maximum number of peers that can be returned in a PeersMsg. +pub const MAX_PEERS_IN_PEERSMSG: usize = 10; + +pub struct LookupService { + /// Peer's ID + id: PeerID, + + /// Routing Table + table: Arc<Mutex<RoutingTable>>, + + /// Listener + listener: Arc<Listener>, + /// Connector + connector: Arc<Connector>, + + /// Outbound slots. + outbound_slots: Arc<ConnectionSlots>, + + /// Resolved listen endpoint + listen_endpoint: Option<RwLock<Endpoint>>, + + /// Holds the configuration for the P2P network. + config: Arc<Config>, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl LookupService { + /// Creates a new lookup service + pub fn new( + id: &PeerID, + table: Arc<Mutex<RoutingTable>>, + config: Arc<Config>, + monitor: Arc<Monitor>, + ) -> Self { + let inbound_slots = Arc::new(ConnectionSlots::new(config.lookup_inbound_slots)); + let outbound_slots = Arc::new(ConnectionSlots::new(config.lookup_outbound_slots)); + + let listener = Listener::new(inbound_slots.clone(), monitor.clone()); + let connector = Connector::new( + config.lookup_connect_retries, + outbound_slots.clone(), + monitor.clone(), + ); + + let listen_endpoint = config + .listen_endpoint + .as_ref() + .map(|endpoint| RwLock::new(endpoint.clone())); + + Self { + id: id.clone(), + table, + listener, + connector, + outbound_slots, + listen_endpoint, + config, + monitor, + } + } + + /// Start the lookup service. 
+ pub async fn start(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + self.start_listener(ex).await?; + Ok(()) + } + + /// Set the resolved listen endpoint. + pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) { + if let Some(endpoint) = &self.listen_endpoint { + *endpoint.write().await = resolved_endpoint.clone(); + } + } + + /// Shuts down the lookup service. + pub async fn shutdown(&self) { + self.connector.shutdown().await; + self.listener.shutdown().await; + } + + /// Starts iterative lookup and populate the routing table. + /// + /// This method begins by generating a random peer ID and connecting to the + /// provided endpoint. It then sends a FindPeer message containing the + /// randomly generated peer ID. Upon receiving peers from the initial lookup, + /// it starts connecting to these received peers and sends them a FindPeer + /// message that contains our own peer ID. + pub async fn start_lookup(&self, endpoint: &Endpoint) -> Result<()> { + trace!("Lookup started {endpoint}"); + self.monitor + .notify(&DiscoveryEvent::LookupStarted(endpoint.clone()).into()) + .await; + + let mut random_peers = vec![]; + if let Err(err) = self.random_lookup(endpoint, &mut random_peers).await { + self.monitor + .notify(&DiscoveryEvent::LookupFailed(endpoint.clone()).into()) + .await; + return Err(err); + }; + + let mut peer_buffer = vec![]; + self.self_lookup(&random_peers, &mut peer_buffer).await; + + while peer_buffer.len() < MAX_PEERS_IN_PEERSMSG { + match random_peers.pop() { + Some(p) => peer_buffer.push(p), + None => break, + } + } + + for peer in peer_buffer.iter() { + let mut table = self.table.lock().await; + let result = table.add_entry(peer.clone().into()); + trace!("Add entry {:?}", result); + } + + self.monitor + .notify(&DiscoveryEvent::LookupSucceeded(endpoint.clone(), peer_buffer.len()).into()) + .await; + + Ok(()) + } + + /// Starts a random lookup + /// + /// This will perfom lookup on a random generated PeerID + async fn 
random_lookup( + &self, + endpoint: &Endpoint, + random_peers: &mut Vec<PeerMsg>, + ) -> Result<()> { + for _ in 0..2 { + let peer_id = PeerID::random(); + let peers = self.connect(&peer_id, endpoint.clone()).await?; + for peer in peers { + if random_peers.contains(&peer) + || peer.peer_id == self.id + || self.table.lock().await.contains_key(&peer.peer_id.0) + { + continue; + } + + random_peers.push(peer); + } + } + + Ok(()) + } + + /// Starts a self lookup + async fn self_lookup(&self, random_peers: &Vec<PeerMsg>, peer_buffer: &mut Vec<PeerMsg>) { + let mut tasks = FuturesUnordered::new(); + for peer in random_peers.choose_multiple(&mut OsRng, random_peers.len()) { + let endpoint = Endpoint::Tcp(peer.addr.clone(), peer.discovery_port); + tasks.push(self.connect(&self.id, endpoint)) + } + + while let Some(result) = tasks.next().await { + match result { + Ok(peers) => peer_buffer.extend(peers), + Err(err) => { + error!("Failed to do self lookup: {err}"); + } + } + } + } + + /// Connects to the given endpoint + async fn connect(&self, peer_id: &PeerID, endpoint: Endpoint) -> Result<Vec<PeerMsg>> { + let conn = self.connector.connect(&endpoint).await?; + let io_codec = IOCodec::new(conn); + let result = self.handle_outbound(io_codec, peer_id).await; + + self.monitor + .notify(&ConnEvent::Disconnected(endpoint).into()) + .await; + self.outbound_slots.remove().await; + + result + } + + /// Handles outbound connection + async fn handle_outbound(&self, io_codec: IOCodec, peer_id: &PeerID) -> Result<Vec<PeerMsg>> { + trace!("Send Ping msg"); + self.send_ping_msg(&io_codec).await?; + + trace!("Send FindPeer msg"); + let peers = self.send_findpeer_msg(&io_codec, peer_id).await?; + + if peers.0.len() >= MAX_PEERS_IN_PEERSMSG { + return Err(Error::Lookup("Received too many peers in PeersMsg")); + } + + trace!("Send Peer msg"); + if let Some(endpoint) = &self.listen_endpoint { + self.send_peer_msg(&io_codec, endpoint.read().await.clone()) + .await?; + } + + trace!("Send 
Shutdown msg"); + self.send_shutdown_msg(&io_codec).await?; + + Ok(peers.0) + } + + /// Start a listener. + async fn start_listener(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + let addr = match &self.listen_endpoint { + Some(a) => a.read().await.addr()?.clone(), + None => return Ok(()), + }; + + let endpoint = Endpoint::Tcp(addr, self.config.discovery_port); + + let selfc = self.clone(); + let callback = |conn: Conn| async move { + let t = Duration::from_secs(selfc.config.lookup_connection_lifespan); + timeout(t, selfc.handle_inbound(conn)).await??; + Ok(()) + }; + + self.listener.start(ex, endpoint.clone(), callback).await?; + Ok(()) + } + + /// Handles inbound connection + async fn handle_inbound(self: &Arc<Self>, conn: Conn) -> Result<()> { + let io_codec = IOCodec::new(conn); + loop { + let msg: NetMsg = io_codec.read().await?; + trace!("Receive msg {:?}", msg.header.command); + + if let NetMsgCmd::Shutdown = msg.header.command { + return Ok(()); + } + + match &msg.header.command { + NetMsgCmd::Ping => { + let (ping_msg, _) = decode::<PingMsg>(&msg.payload)?; + if !version_match(&self.config.version.req, &ping_msg.version) { + return Err(Error::IncompatibleVersion("system: {}".into())); + } + self.send_pong_msg(ping_msg.nonce, &io_codec).await?; + } + NetMsgCmd::FindPeer => { + let (findpeer_msg, _) = decode::<FindPeerMsg>(&msg.payload)?; + let peer_id = findpeer_msg.0; + self.send_peers_msg(&peer_id, &io_codec).await?; + } + NetMsgCmd::Peer => { + let (peer, _) = decode::<PeerMsg>(&msg.payload)?; + let result = self.table.lock().await.add_entry(peer.clone().into()); + trace!("Add entry result: {:?}", result); + } + c => return Err(Error::InvalidMsg(format!("Unexpected msg: {:?}", c))), + } + } + } + + /// Sends a Ping msg and wait to receive the Pong message. 
+ async fn send_ping_msg(&self, io_codec: &IOCodec) -> Result<()> { + trace!("Send Pong msg"); + + let mut nonce: [u8; 32] = [0; 32]; + RngCore::fill_bytes(&mut OsRng, &mut nonce); + + let ping_msg = PingMsg { + version: self.config.version.v.clone(), + nonce, + }; + io_codec.write(NetMsgCmd::Ping, &ping_msg).await?; + + let t = Duration::from_secs(self.config.lookup_response_timeout); + let recv_msg: NetMsg = io_codec.read_timeout(t).await?; + + let payload = get_msg_payload!(Pong, recv_msg); + let (pong_msg, _) = decode::<PongMsg>(&payload)?; + + if ping_msg.nonce != pong_msg.0 { + return Err(Error::InvalidPongMsg); + } + + Ok(()) + } + + /// Sends a Pong msg + async fn send_pong_msg(&self, nonce: [u8; 32], io_codec: &IOCodec) -> Result<()> { + trace!("Send Pong msg"); + io_codec.write(NetMsgCmd::Pong, &PongMsg(nonce)).await?; + Ok(()) + } + + /// Sends a FindPeer msg and wait to receivet the Peers msg. + async fn send_findpeer_msg(&self, io_codec: &IOCodec, peer_id: &PeerID) -> Result<PeersMsg> { + trace!("Send FindPeer msg"); + io_codec + .write(NetMsgCmd::FindPeer, &FindPeerMsg(peer_id.clone())) + .await?; + + let t = Duration::from_secs(self.config.lookup_response_timeout); + let recv_msg: NetMsg = io_codec.read_timeout(t).await?; + + let payload = get_msg_payload!(Peers, recv_msg); + let (peers, _) = decode(&payload)?; + + Ok(peers) + } + + /// Sends a Peers msg. + async fn send_peers_msg(&self, peer_id: &PeerID, io_codec: &IOCodec) -> Result<()> { + trace!("Send Peers msg"); + let table = self.table.lock().await; + let entries = table.closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG); + let peers: Vec<PeerMsg> = entries.into_iter().map(|e| e.into()).collect(); + drop(table); + io_codec.write(NetMsgCmd::Peers, &PeersMsg(peers)).await?; + Ok(()) + } + + /// Sends a Peer msg. 
+ async fn send_peer_msg(&self, io_codec: &IOCodec, endpoint: Endpoint) -> Result<()> { + trace!("Send Peer msg"); + let peer_msg = PeerMsg { + addr: endpoint.addr()?.clone(), + port: *endpoint.port()?, + discovery_port: self.config.discovery_port, + peer_id: self.id.clone(), + }; + io_codec.write(NetMsgCmd::Peer, &peer_msg).await?; + Ok(()) + } + + /// Sends a Shutdown msg. + async fn send_shutdown_msg(&self, io_codec: &IOCodec) -> Result<()> { + trace!("Send Shutdown msg"); + io_codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await?; + Ok(()) + } +} diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs new file mode 100644 index 0000000..94b350b --- /dev/null +++ b/p2p/src/discovery/mod.rs @@ -0,0 +1,262 @@ +mod lookup; +mod refresh; + +use std::sync::Arc; + +use log::{error, info}; +use rand::{rngs::OsRng, seq::SliceRandom}; +use smol::lock::Mutex; + +use karyons_core::{ + async_utils::{Backoff, TaskGroup, TaskResult}, + Executor, +}; + +use karyons_net::{Conn, Endpoint}; + +use crate::{ + config::Config, + monitor::Monitor, + net::ConnQueue, + net::{ConnDirection, ConnectionSlots, Connector, Listener}, + routing_table::{ + Entry, EntryStatusFlag, RoutingTable, CONNECTED_ENTRY, DISCONNECTED_ENTRY, PENDING_ENTRY, + UNREACHABLE_ENTRY, UNSTABLE_ENTRY, + }, + Error, PeerID, Result, +}; + +use lookup::LookupService; +use refresh::RefreshService; + +pub type ArcDiscovery = Arc<Discovery>; + +pub struct Discovery { + /// Routing table + table: Arc<Mutex<RoutingTable>>, + + /// Lookup Service + lookup_service: Arc<LookupService>, + + /// Refresh Service + refresh_service: Arc<RefreshService>, + + /// Connector + connector: Arc<Connector>, + /// Listener + listener: Arc<Listener>, + + /// Connection queue + conn_queue: Arc<ConnQueue>, + + /// Inbound slots. + pub(crate) inbound_slots: Arc<ConnectionSlots>, + /// Outbound slots. + pub(crate) outbound_slots: Arc<ConnectionSlots>, + + /// Managing spawned tasks. 
+ task_group: TaskGroup, + + /// Holds the configuration for the P2P network. + config: Arc<Config>, +} + +impl Discovery { + /// Creates a new Discovery + pub fn new( + peer_id: &PeerID, + conn_queue: Arc<ConnQueue>, + config: Arc<Config>, + monitor: Arc<Monitor>, + ) -> ArcDiscovery { + let inbound_slots = Arc::new(ConnectionSlots::new(config.inbound_slots)); + let outbound_slots = Arc::new(ConnectionSlots::new(config.outbound_slots)); + + let table_key = peer_id.0; + let table = Arc::new(Mutex::new(RoutingTable::new(table_key))); + + let refresh_service = RefreshService::new(config.clone(), table.clone(), monitor.clone()); + let lookup_service = + LookupService::new(peer_id, table.clone(), config.clone(), monitor.clone()); + + let connector = Connector::new( + config.max_connect_retries, + outbound_slots.clone(), + monitor.clone(), + ); + let listener = Listener::new(inbound_slots.clone(), monitor.clone()); + + Arc::new(Self { + refresh_service: Arc::new(refresh_service), + lookup_service: Arc::new(lookup_service), + conn_queue, + table, + inbound_slots, + outbound_slots, + connector, + listener, + task_group: TaskGroup::new(), + config, + }) + } + + /// Start the Discovery + pub async fn start(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + // Check if the listen_endpoint is provided, and if so, start a listener. + if let Some(endpoint) = &self.config.listen_endpoint { + // Return an error if the discovery port is set to 0. + if self.config.discovery_port == 0 { + return Err(Error::Config( + "Please add a valid discovery port".to_string(), + )); + } + + let resolved_endpoint = self.start_listener(endpoint, ex.clone()).await?; + + if endpoint.addr()? != resolved_endpoint.addr()? 
{ + info!("Resolved listen endpoint: {resolved_endpoint}"); + self.lookup_service + .set_listen_endpoint(&resolved_endpoint) + .await; + self.refresh_service + .set_listen_endpoint(&resolved_endpoint) + .await; + } + } + + // Start the lookup service + self.lookup_service.start(ex.clone()).await?; + // Start the refresh service + self.refresh_service.start(ex.clone()).await?; + + // Attempt to manually connect to peer endpoints provided in the Config. + for endpoint in self.config.peer_endpoints.iter() { + let _ = self.connect(endpoint, None, ex.clone()).await; + } + + // Start connect loop + let selfc = self.clone(); + self.task_group + .spawn(ex.clone(), selfc.connect_loop(ex), |res| async move { + if let TaskResult::Completed(Err(err)) = res { + error!("Connect loop stopped: {err}"); + } + }); + + Ok(()) + } + + /// Shuts down the discovery + pub async fn shutdown(&self) { + self.task_group.cancel().await; + self.connector.shutdown().await; + self.listener.shutdown().await; + + self.refresh_service.shutdown().await; + self.lookup_service.shutdown().await; + } + + /// Start a listener and on success, return the resolved endpoint. + async fn start_listener( + self: &Arc<Self>, + endpoint: &Endpoint, + ex: Executor<'_>, + ) -> Result<Endpoint> { + let selfc = self.clone(); + let callback = |conn: Conn| async move { + selfc.conn_queue.handle(conn, ConnDirection::Inbound).await; + Ok(()) + }; + + let resolved_endpoint = self.listener.start(ex, endpoint.clone(), callback).await?; + Ok(resolved_endpoint) + } + + /// This method will attempt to connect to a peer in the routing table. + /// If the routing table is empty, it will start the seeding process for + /// finding new peers. + /// + /// This will perform a backoff to prevent getting stuck in the loop + /// if the seeding process couldn't find any peers. 
+ async fn connect_loop(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + let backoff = Backoff::new(500, self.config.seeding_interval * 1000); + loop { + let random_entry = self.random_entry(PENDING_ENTRY).await; + match random_entry { + Some(entry) => { + backoff.reset(); + let endpoint = Endpoint::Tcp(entry.addr, entry.port); + self.connect(&endpoint, Some(entry.key.into()), ex.clone()) + .await; + } + None => { + backoff.sleep().await; + self.start_seeding().await; + } + } + } + } + + /// Connect to the given endpoint using the connector + async fn connect(self: &Arc<Self>, endpoint: &Endpoint, pid: Option<PeerID>, ex: Executor<'_>) { + let selfc = self.clone(); + let pid_cloned = pid.clone(); + let cback = |conn: Conn| async move { + selfc.conn_queue.handle(conn, ConnDirection::Outbound).await; + if let Some(pid) = pid_cloned { + selfc.update_entry(&pid, DISCONNECTED_ENTRY).await; + } + Ok(()) + }; + + let res = self.connector.connect_with_cback(ex, endpoint, cback).await; + + if let Some(pid) = &pid { + match res { + Ok(_) => { + self.update_entry(pid, CONNECTED_ENTRY).await; + } + Err(_) => { + self.update_entry(pid, UNREACHABLE_ENTRY).await; + } + } + } + } + + /// Starts seeding process. + /// + /// This method randomly selects a peer from the routing table and + /// attempts to connect to that peer for the initial lookup. If the routing + /// table doesn't have an available entry, it will connect to one of the + /// provided bootstrap endpoints in the `Config` and initiate the lookup. 
+ async fn start_seeding(&self) { + match self.random_entry(PENDING_ENTRY | CONNECTED_ENTRY).await { + Some(entry) => { + let endpoint = Endpoint::Tcp(entry.addr, entry.discovery_port); + if let Err(err) = self.lookup_service.start_lookup(&endpoint).await { + self.update_entry(&entry.key.into(), UNSTABLE_ENTRY).await; + error!("Failed to do lookup: {endpoint}: {err}"); + } + } + None => { + let peers = &self.config.bootstrap_peers; + for endpoint in peers.choose_multiple(&mut OsRng, peers.len()) { + if let Err(err) = self.lookup_service.start_lookup(endpoint).await { + error!("Failed to do lookup: {endpoint}: {err}"); + } + } + } + } + } + + /// Returns a random entry from routing table. + async fn random_entry(&self, entry_flag: EntryStatusFlag) -> Option<Entry> { + self.table.lock().await.random_entry(entry_flag).cloned() + } + + /// Update the entry status + async fn update_entry(&self, pid: &PeerID, entry_flag: EntryStatusFlag) { + let table = &mut self.table.lock().await; + table.update_entry(&pid.0, entry_flag); + } +} diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs new file mode 100644 index 0000000..7582c84 --- /dev/null +++ b/p2p/src/discovery/refresh.rs @@ -0,0 +1,289 @@ +use std::{sync::Arc, time::Duration}; + +use bincode::{Decode, Encode}; +use log::{error, info, trace}; +use rand::{rngs::OsRng, RngCore}; +use smol::{ + lock::{Mutex, RwLock}, + stream::StreamExt, + Timer, +}; + +use karyons_core::{ + async_utils::{timeout, Backoff, TaskGroup, TaskResult}, + utils::{decode, encode}, + Executor, +}; + +use karyons_net::{dial_udp, listen_udp, Addr, Connection, Endpoint, NetError, Port, UdpConn}; + +/// Maximum failures for an entry before removing it from the routing table. 
+pub const MAX_FAILURES: u32 = 3; + +/// Ping message size +const PINGMSG_SIZE: usize = 32; + +use crate::{ + monitor::{ConnEvent, DiscoveryEvent, Monitor}, + routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY}, + Config, Error, Result, +}; + +#[derive(Decode, Encode, Debug, Clone)] +pub struct PingMsg(pub [u8; 32]); + +#[derive(Decode, Encode, Debug)] +pub struct PongMsg(pub [u8; 32]); + +pub struct RefreshService { + /// Routing table + table: Arc<Mutex<RoutingTable>>, + + /// Resolved listen endpoint + listen_endpoint: Option<RwLock<Endpoint>>, + + /// Managing spawned tasks. + task_group: TaskGroup, + + /// Holds the configuration for the P2P network. + config: Arc<Config>, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl RefreshService { + /// Creates a new refresh service + pub fn new( + config: Arc<Config>, + table: Arc<Mutex<RoutingTable>>, + monitor: Arc<Monitor>, + ) -> Self { + let listen_endpoint = config + .listen_endpoint + .as_ref() + .map(|endpoint| RwLock::new(endpoint.clone())); + + Self { + table, + listen_endpoint, + task_group: TaskGroup::new(), + config, + monitor, + } + } + + /// Start the refresh service + pub async fn start(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + if let Some(endpoint) = &self.listen_endpoint { + let endpoint = endpoint.read().await; + let addr = endpoint.addr()?; + let port = self.config.discovery_port; + + let selfc = self.clone(); + self.task_group.spawn( + ex.clone(), + selfc.listen_loop(addr.clone(), port), + |res| async move { + if let TaskResult::Completed(Err(err)) = res { + error!("Listen loop stopped: {err}"); + } + }, + ); + } + + let selfc = self.clone(); + self.task_group.spawn( + ex.clone(), + selfc.refresh_loop(ex.clone()), + |res| async move { + if let TaskResult::Completed(Err(err)) = res { + error!("Refresh loop stopped: {err}"); + } + }, + ); + + Ok(()) + } + + /// Set the resolved listen endpoint. 
+ pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) { + if let Some(endpoint) = &self.listen_endpoint { + *endpoint.write().await = resolved_endpoint.clone(); + } + } + + /// Shuts down the refresh service + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + /// Initiates periodic refreshing of the routing table. This function will + /// select 8 random entries from each bucket in the routing table and start + /// sending Ping messages to the entries. + async fn refresh_loop(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + let mut timer = Timer::interval(Duration::from_secs(self.config.refresh_interval)); + loop { + timer.next().await; + trace!("Start refreshing the routing table..."); + + self.monitor + .notify(&DiscoveryEvent::RefreshStarted.into()) + .await; + + let table = self.table.lock().await; + let mut entries: Vec<BucketEntry> = vec![]; + for bucket in table.iter() { + for entry in bucket.random_iter(8) { + entries.push(entry.clone()) + } + } + // Release the table lock before refreshing; do_refresh locks the + // table again when it removes entries. + drop(table); + + self.clone().do_refresh(&entries, ex.clone()).await; + } + } + + /// Iterates over the entries in chunks of 16 and spawns a new task for + /// each entry to initiate a connection attempt to that entry, awaiting + /// the tasks of one chunk before moving on to the next. + /// + /// Connected entries are skipped. Entries that have reached + /// `MAX_FAILURES` are removed from the routing table without being + /// refreshed. + async fn do_refresh(self: Arc<Self>, entries: &[BucketEntry], ex: Executor<'_>) { + for chunk in entries.chunks(16) { + let mut tasks = Vec::new(); + for bucket_entry in chunk { + if bucket_entry.is_connected() { + continue; + } + + if bucket_entry.failures >= MAX_FAILURES { + self.table + .lock() + .await + .remove_entry(&bucket_entry.entry.key); + // Skip only this entry; returning here would abandon the + // remaining entries and the tasks already spawned for this chunk. + continue; + } + + tasks.push(ex.spawn(self.clone().refresh_entry(bucket_entry.clone()))) + } + + for task in tasks { + task.await; + } + } + } + + /// Initiates refresh for a specific entry within the routing table. It + /// updates the routing table according to the result. 
+ async fn refresh_entry(self: Arc<Self>, bucket_entry: BucketEntry) { + let key = &bucket_entry.entry.key; + match self.connect(&bucket_entry.entry).await { + Ok(_) => { + self.table.lock().await.update_entry(key, PENDING_ENTRY); + } + Err(err) => { + trace!("Failed to refresh entry {:?}: {err}", key); + let table = &mut self.table.lock().await; + if bucket_entry.failures >= MAX_FAILURES { + table.remove_entry(key); + return; + } + table.update_entry(key, UNREACHABLE_ENTRY); + } + } + } + + /// Initiates a UDP connection with the entry and attempts to send a Ping + /// message. If it fails, it retries according to the allowed retries + /// specified in the Config, with backoff between each retry. + async fn connect(&self, entry: &Entry) -> Result<()> { + let mut retry = 0; + let conn = dial_udp(&entry.addr, &entry.discovery_port).await?; + let backoff = Backoff::new(100, 5000); + while retry < self.config.refresh_connect_retries { + match self.send_ping_msg(&conn).await { + Ok(()) => return Ok(()), + Err(Error::KaryonsNet(NetError::Timeout)) => { + retry += 1; + backoff.sleep().await; + } + Err(err) => { + return Err(err); + } + } + } + + Err(NetError::Timeout.into()) + } + + /// Set up a UDP listener and start listening for Ping messages from other + /// peers. 
+ async fn listen_loop(self: Arc<Self>, addr: Addr, port: Port) -> Result<()> { + let endpoint = Endpoint::Udp(addr.clone(), port); + let conn = match listen_udp(&addr, &port).await { + Ok(c) => { + self.monitor + .notify(&ConnEvent::Listening(endpoint.clone()).into()) + .await; + c + } + Err(err) => { + self.monitor + .notify(&ConnEvent::ListenFailed(endpoint.clone()).into()) + .await; + return Err(err.into()); + } + }; + info!("Start listening on {endpoint}"); + + loop { + let res = self.listen_to_ping_msg(&conn).await; + if let Err(err) = res { + trace!("Failed to handle ping msg {err}"); + self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; + } + } + } + + /// Listen to receive a Ping message and respond with a Pong message. + async fn listen_to_ping_msg(&self, conn: &UdpConn) -> Result<()> { + let mut buf = [0; PINGMSG_SIZE]; + let (_, endpoint) = conn.recv_from(&mut buf).await?; + + self.monitor + .notify(&ConnEvent::Accepted(endpoint.clone()).into()) + .await; + + let (ping_msg, _) = decode::<PingMsg>(&buf)?; + + let pong_msg = PongMsg(ping_msg.0); + let buffer = encode(&pong_msg)?; + + conn.send_to(&buffer, &endpoint).await?; + + self.monitor + .notify(&ConnEvent::Disconnected(endpoint.clone()).into()) + .await; + Ok(()) + } + + /// Sends a Ping msg and wait to receive the Pong message. 
+ async fn send_ping_msg(&self, conn: &UdpConn) -> Result<()> { + let mut nonce: [u8; 32] = [0; 32]; + RngCore::fill_bytes(&mut OsRng, &mut nonce); + + let ping_msg = PingMsg(nonce); + let buffer = encode(&ping_msg)?; + conn.send(&buffer).await?; + + let buf = &mut [0; PINGMSG_SIZE]; + let t = Duration::from_secs(self.config.refresh_response_timeout); + timeout(t, conn.recv(buf)).await??; + + let (pong_msg, _) = decode::<PongMsg>(buf)?; + + if ping_msg.0 != pong_msg.0 { + return Err(Error::InvalidPongMsg); + } + + Ok(()) + } +} diff --git a/p2p/src/error.rs b/p2p/src/error.rs new file mode 100644 index 0000000..945e90a --- /dev/null +++ b/p2p/src/error.rs @@ -0,0 +1,82 @@ +use thiserror::Error as ThisError; + +pub type Result<T> = std::result::Result<T, Error>; + +/// Represents Karyons's p2p Error. +#[derive(ThisError, Debug)] +pub enum Error { + #[error("IO Error: {0}")] + IO(#[from] std::io::Error), + + #[error("Unsupported protocol error: {0}")] + UnsupportedProtocol(String), + + #[error("Invalid message error: {0}")] + InvalidMsg(String), + + #[error("Parse error: {0}")] + ParseError(String), + + #[error("Incompatible version error: {0}")] + IncompatibleVersion(String), + + #[error("Config error: {0}")] + Config(String), + + #[error("Peer shutdown")] + PeerShutdown, + + #[error("Invalid Pong Msg")] + InvalidPongMsg, + + #[error("Discovery error: {0}")] + Discovery(&'static str), + + #[error("Lookup error: {0}")] + Lookup(&'static str), + + #[error("Peer already connected")] + PeerAlreadyConnected, + + #[error("Channel Send Error: {0}")] + ChannelSend(String), + + #[error("Channel Receive Error: {0}")] + ChannelRecv(String), + + #[error("CORE::ERROR : {0}")] + KaryonsCore(#[from] karyons_core::error::Error), + + #[error("NET::ERROR : {0}")] + KaryonsNet(#[from] karyons_net::NetError), +} + +impl<T> From<smol::channel::SendError<T>> for Error { + fn from(error: smol::channel::SendError<T>) -> Self { + Error::ChannelSend(error.to_string()) + } +} + +impl 
From<smol::channel::RecvError> for Error { + fn from(error: smol::channel::RecvError) -> Self { + Error::ChannelRecv(error.to_string()) + } +} + +impl From<std::num::ParseIntError> for Error { + fn from(error: std::num::ParseIntError) -> Self { + Error::ParseError(error.to_string()) + } +} + +impl From<std::num::ParseFloatError> for Error { + fn from(error: std::num::ParseFloatError) -> Self { + Error::ParseError(error.to_string()) + } +} + +impl From<semver::Error> for Error { + fn from(error: semver::Error) -> Self { + Error::ParseError(error.to_string()) + } +} diff --git a/p2p/src/io_codec.rs b/p2p/src/io_codec.rs new file mode 100644 index 0000000..4515832 --- /dev/null +++ b/p2p/src/io_codec.rs @@ -0,0 +1,102 @@ +use std::time::Duration; + +use bincode::{Decode, Encode}; + +use karyons_core::{ + async_utils::timeout, + utils::{decode, encode, encode_into_slice}, +}; + +use karyons_net::{Connection, NetError}; + +use crate::{ + message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, + Error, Result, +}; + +pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} +impl<T: Encode + Decode + std::fmt::Debug> CodecMsg for T {} + +/// I/O codec working with generic network connections. +/// +/// It is responsible for both decoding data received from the network and +/// encoding data before sending it. +pub struct IOCodec { + conn: Box<dyn Connection>, +} + +impl IOCodec { + /// Creates a new IOCodec. + pub fn new(conn: Box<dyn Connection>) -> Self { + Self { conn } + } + + /// Reads a message of type `NetMsg` from the connection. + /// + /// It reads the first 6 bytes as the header of the message, then reads + /// and decodes the remaining message data based on the determined header. 
+ pub async fn read(&self) -> Result<NetMsg> { + // Read 6 bytes to get the header of the incoming message + let mut buf = [0; MSG_HEADER_SIZE]; + self.conn.recv(&mut buf).await?; + + // Decode the header from bytes to NetMsgHeader + let (header, _) = decode::<NetMsgHeader>(&buf)?; + + if header.payload_size > MAX_ALLOWED_MSG_SIZE { + return Err(Error::InvalidMsg( + "Message exceeds the maximum allowed size".to_string(), + )); + } + + // Create a buffer to hold the message based on its length + let mut payload = vec![0; header.payload_size as usize]; + self.conn.recv(&mut payload).await?; + + Ok(NetMsg { header, payload }) + } + + /// Writes a message of type `T` to the connection. + /// + /// Before appending the actual message payload, it calculates the length of + /// the encoded message in bytes and appends this length to the message header. + pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> { + let payload = encode(msg)?; + + // Create a buffer to hold the message header (6 bytes) + let header_buf = &mut [0; MSG_HEADER_SIZE]; + let header = NetMsgHeader { + command, + payload_size: payload.len() as u32, + }; + encode_into_slice(&header, header_buf)?; + + let mut buffer = vec![]; + // Append the header bytes to the buffer + buffer.extend_from_slice(header_buf); + // Append the message payload to the buffer + buffer.extend_from_slice(&payload); + + self.conn.send(&buffer).await?; + Ok(()) + } + + /// Reads a message of type `NetMsg` with the given timeout. + pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> { + timeout(duration, self.read()) + .await + .map_err(|_| NetError::Timeout)? + } + + /// Writes a message of type `T` with the given timeout. + pub async fn write_timeout<T: CodecMsg>( + &self, + command: NetMsgCmd, + msg: &T, + duration: Duration, + ) -> Result<()> { + timeout(duration, self.write(command, msg)) + .await + .map_err(|_| NetError::Timeout)? 
+ } +} diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs new file mode 100644 index 0000000..08ba059 --- /dev/null +++ b/p2p/src/lib.rs @@ -0,0 +1,27 @@ +mod backend; +mod config; +mod discovery; +mod error; +mod io_codec; +mod message; +mod net; +mod peer; +mod peer_pool; +mod protocols; +mod routing_table; +mod utils; + +/// Responsible for network and system monitoring. +/// [`Read More`](./monitor/struct.Monitor.html) +pub mod monitor; +/// Defines the protocol trait. +/// [`Read More`](./protocol/trait.Protocol.html) +pub mod protocol; + +pub use backend::{ArcBackend, Backend}; +pub use config::Config; +pub use error::Error as P2pError; +pub use peer::{ArcPeer, PeerID}; +pub use utils::Version; + +use error::{Error, Result}; diff --git a/p2p/src/message.rs b/p2p/src/message.rs new file mode 100644 index 0000000..833f6f4 --- /dev/null +++ b/p2p/src/message.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; + +use bincode::{Decode, Encode}; + +use karyons_net::{Addr, Port}; + +use crate::{protocol::ProtocolID, routing_table::Entry, utils::VersionInt, PeerID}; + +/// The size of the message header, in bytes. +pub const MSG_HEADER_SIZE: usize = 6; + +/// The maximum allowed size for a message in bytes. +pub const MAX_ALLOWED_MSG_SIZE: u32 = 1000000; + +/// Defines the main message in the Karyon P2P network. +/// +/// This message structure consists of a header and payload, where the header +/// typically contains essential information about the message, and the payload +/// contains the actual data being transmitted. +#[derive(Decode, Encode, Debug, Clone)] +pub struct NetMsg { + pub header: NetMsgHeader, + pub payload: Vec<u8>, +} + +/// Represents the header of a message. +#[derive(Decode, Encode, Debug, Clone)] +pub struct NetMsgHeader { + pub command: NetMsgCmd, + pub payload_size: u32, +} + +/// Defines message commands. 
+#[derive(Decode, Encode, Debug, Clone)] +#[repr(u8)] +pub enum NetMsgCmd { + Version, + Verack, + Protocol, + Shutdown, + + // NOTE: The following commands are used during the lookup process. + Ping, + Pong, + FindPeer, + Peer, + Peers, +} + +/// Defines a message related to a specific protocol. +#[derive(Decode, Encode, Debug, Clone)] +pub struct ProtocolMsg { + pub protocol_id: ProtocolID, + pub payload: Vec<u8>, +} + +/// Version message, providing information about a peer's capabilities. +#[derive(Decode, Encode, Debug, Clone)] +pub struct VerMsg { + pub peer_id: PeerID, + pub version: VersionInt, + pub protocols: HashMap<ProtocolID, VersionInt>, +} + +/// VerAck message acknowledging the receipt of a Version message. +#[derive(Decode, Encode, Debug, Clone)] +pub struct VerAckMsg(pub PeerID); + +/// Shutdown message. +#[derive(Decode, Encode, Debug, Clone)] +pub struct ShutdownMsg(pub u8); + +/// Ping message with a nonce and version information. +#[derive(Decode, Encode, Debug, Clone)] +pub struct PingMsg { + pub nonce: [u8; 32], + pub version: VersionInt, +} + +/// Ping message with a nonce. +#[derive(Decode, Encode, Debug)] +pub struct PongMsg(pub [u8; 32]); + +/// FindPeer message used to find a specific peer. +#[derive(Decode, Encode, Debug)] +pub struct FindPeerMsg(pub PeerID); + +/// PeerMsg containing information about a peer. +#[derive(Decode, Encode, Debug, Clone, PartialEq, Eq)] +pub struct PeerMsg { + pub peer_id: PeerID, + pub addr: Addr, + pub port: Port, + pub discovery_port: Port, +} + +/// PeersMsg a list of `PeerMsg`. +#[derive(Decode, Encode, Debug)] +pub struct PeersMsg(pub Vec<PeerMsg>); + +macro_rules! 
get_msg_payload { + ($a:ident, $b:expr) => { + if let NetMsgCmd::$a = $b.header.command { + $b.payload + } else { + return Err(Error::InvalidMsg(format!("Unexpected msg{:?}", $b))); + } + }; +} + +pub(super) use get_msg_payload; + +impl From<Entry> for PeerMsg { + fn from(entry: Entry) -> PeerMsg { + PeerMsg { + peer_id: PeerID(entry.key), + addr: entry.addr, + port: entry.port, + discovery_port: entry.discovery_port, + } + } +} + +impl From<PeerMsg> for Entry { + fn from(peer: PeerMsg) -> Entry { + Entry { + key: peer.peer_id.0, + addr: peer.addr, + port: peer.port, + discovery_port: peer.discovery_port, + } + } +} diff --git a/p2p/src/monitor.rs b/p2p/src/monitor.rs new file mode 100644 index 0000000..ee0bf44 --- /dev/null +++ b/p2p/src/monitor.rs @@ -0,0 +1,154 @@ +use std::fmt; + +use crate::PeerID; + +use karyons_core::pubsub::{ArcPublisher, Publisher, Subscription}; + +use karyons_net::Endpoint; + +/// Responsible for network and system monitoring. +/// +/// It use pub-sub pattern to notify the subscribers with new events. +/// +/// # Example +/// +/// ``` +/// use karyons_p2p::{Config, Backend, PeerID}; +/// async { +/// +/// let backend = Backend::new(PeerID::random(), Config::default()); +/// +/// // Create a new Subscription +/// let sub = backend.monitor().await; +/// +/// let event = sub.recv().await; +/// }; +/// ``` +pub struct Monitor { + inner: ArcPublisher<MonitorEvent>, +} + +impl Monitor { + /// Creates a new Monitor + pub(crate) fn new() -> Monitor { + Self { + inner: Publisher::new(), + } + } + + /// Sends a new monitor event to all subscribers. + pub async fn notify(&self, event: &MonitorEvent) { + self.inner.notify(event).await; + } + + /// Subscribes to listen to new events. + pub async fn subscribe(&self) -> Subscription<MonitorEvent> { + self.inner.subscribe().await + } +} + +/// Defines various type of event that can be monitored. 
+#[derive(Clone, Debug)] +pub enum MonitorEvent { + Conn(ConnEvent), + PeerPool(PeerPoolEvent), + Discovery(DiscoveryEvent), +} + +/// Defines connection-related events. +#[derive(Clone, Debug)] +pub enum ConnEvent { + Connected(Endpoint), + ConnectRetried(Endpoint), + ConnectFailed(Endpoint), + Accepted(Endpoint), + AcceptFailed, + Disconnected(Endpoint), + Listening(Endpoint), + ListenFailed(Endpoint), +} + +/// Defines `PeerPool` events. +#[derive(Clone, Debug)] +pub enum PeerPoolEvent { + NewPeer(PeerID), + RemovePeer(PeerID), +} + +/// Defines `Discovery` events. +#[derive(Clone, Debug)] +pub enum DiscoveryEvent { + LookupStarted(Endpoint), + LookupFailed(Endpoint), + LookupSucceeded(Endpoint, usize), + RefreshStarted, +} + +impl fmt::Display for MonitorEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let val = match self { + MonitorEvent::Conn(e) => format!("Connection Event: {e}"), + MonitorEvent::PeerPool(e) => format!("PeerPool Event: {e}"), + MonitorEvent::Discovery(e) => format!("Discovery Event: {e}"), + }; + write!(f, "{}", val) + } +} + +impl fmt::Display for ConnEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let val = match self { + ConnEvent::Connected(endpoint) => format!("Connected: {endpoint}"), + ConnEvent::ConnectFailed(endpoint) => format!("ConnectFailed: {endpoint}"), + ConnEvent::ConnectRetried(endpoint) => format!("ConnectRetried: {endpoint}"), + ConnEvent::AcceptFailed => "AcceptFailed".to_string(), + ConnEvent::Accepted(endpoint) => format!("Accepted: {endpoint}"), + ConnEvent::Disconnected(endpoint) => format!("Disconnected: {endpoint}"), + ConnEvent::Listening(endpoint) => format!("Listening: {endpoint}"), + ConnEvent::ListenFailed(endpoint) => format!("ListenFailed: {endpoint}"), + }; + write!(f, "{}", val) + } +} + +impl fmt::Display for PeerPoolEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let val = match self { + PeerPoolEvent::NewPeer(pid) => format!("NewPeer: {pid}"), + 
PeerPoolEvent::RemovePeer(pid) => format!("RemovePeer: {pid}"), + }; + write!(f, "{}", val) + } +} + +impl fmt::Display for DiscoveryEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let val = match self { + DiscoveryEvent::LookupStarted(endpoint) => format!("LookupStarted: {endpoint}"), + DiscoveryEvent::LookupFailed(endpoint) => format!("LookupFailed: {endpoint}"), + DiscoveryEvent::LookupSucceeded(endpoint, len) => { + format!("LookupSucceeded: {endpoint} {len}") + } + DiscoveryEvent::RefreshStarted => "RefreshStarted".to_string(), + }; + write!(f, "{}", val) + } +} + +impl From<ConnEvent> for MonitorEvent { + fn from(val: ConnEvent) -> Self { + MonitorEvent::Conn(val) + } +} + +impl From<PeerPoolEvent> for MonitorEvent { + fn from(val: PeerPoolEvent) -> Self { + MonitorEvent::PeerPool(val) + } +} + +impl From<DiscoveryEvent> for MonitorEvent { + fn from(val: DiscoveryEvent) -> Self { + MonitorEvent::Discovery(val) + } +} diff --git a/p2p/src/net/connection_queue.rs b/p2p/src/net/connection_queue.rs new file mode 100644 index 0000000..fbc4bfc --- /dev/null +++ b/p2p/src/net/connection_queue.rs @@ -0,0 +1,52 @@ +use std::sync::Arc; + +use smol::{channel::Sender, lock::Mutex}; + +use karyons_core::async_utils::CondVar; + +use karyons_net::Conn; + +use crate::net::ConnDirection; + +pub struct NewConn { + pub direction: ConnDirection, + pub conn: Conn, + pub disconnect_signal: Sender<()>, +} + +/// Connection queue +pub struct ConnQueue { + queue: Mutex<Vec<NewConn>>, + conn_available: CondVar, +} + +impl ConnQueue { + pub fn new() -> Arc<Self> { + Arc::new(Self { + queue: Mutex::new(Vec::new()), + conn_available: CondVar::new(), + }) + } + + /// Push a connection into the queue and wait for the disconnect signal + pub async fn handle(&self, conn: Conn, direction: ConnDirection) { + let (disconnect_signal, chan) = smol::channel::bounded(1); + let new_conn = NewConn { + direction, + conn, + disconnect_signal, + }; + 
self.queue.lock().await.push(new_conn); + self.conn_available.signal(); + let _ = chan.recv().await; + } + + /// Receive the next connection in the queue + pub async fn next(&self) -> NewConn { + let mut queue = self.queue.lock().await; + while queue.is_empty() { + queue = self.conn_available.wait(queue).await; + } + queue.pop().unwrap() + } +} diff --git a/p2p/src/net/connector.rs b/p2p/src/net/connector.rs new file mode 100644 index 0000000..72dc0d8 --- /dev/null +++ b/p2p/src/net/connector.rs @@ -0,0 +1,125 @@ +use std::{future::Future, sync::Arc}; + +use log::{trace, warn}; + +use karyons_core::{ + async_utils::{Backoff, TaskGroup, TaskResult}, + Executor, +}; +use karyons_net::{dial, Conn, Endpoint, NetError}; + +use crate::{ + monitor::{ConnEvent, Monitor}, + Result, +}; + +use super::slots::ConnectionSlots; + +/// Responsible for creating outbound connections with other peers. +pub struct Connector { + /// Managing spawned tasks. + task_group: TaskGroup, + + /// Manages available outbound slots. + connection_slots: Arc<ConnectionSlots>, + + /// The maximum number of retries allowed before successfully + /// establishing a connection. + max_retries: usize, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl Connector { + /// Creates a new Connector + pub fn new( + max_retries: usize, + connection_slots: Arc<ConnectionSlots>, + monitor: Arc<Monitor>, + ) -> Arc<Self> { + Arc::new(Self { + task_group: TaskGroup::new(), + monitor, + connection_slots, + max_retries, + }) + } + + /// Shuts down the connector + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + /// Establish a connection to the specified `endpoint`. If the connection + /// attempt fails, it performs a backoff and retries until the maximum allowed + /// number of retries is exceeded. On a successful connection, it returns a + /// `Conn` instance. + /// + /// This method will block until it finds an available slot. 
+ pub async fn connect(&self, endpoint: &Endpoint) -> Result<Conn> { + self.connection_slots.wait_for_slot().await; + self.connection_slots.add(); + + let mut retry = 0; + let backoff = Backoff::new(500, 2000); + while retry < self.max_retries { + let conn_result = dial(endpoint).await; + + if let Ok(conn) = conn_result { + self.monitor + .notify(&ConnEvent::Connected(endpoint.clone()).into()) + .await; + return Ok(conn); + } + + self.monitor + .notify(&ConnEvent::ConnectRetried(endpoint.clone()).into()) + .await; + + backoff.sleep().await; + + warn!("try to reconnect {endpoint}"); + retry += 1; + } + + self.monitor + .notify(&ConnEvent::ConnectFailed(endpoint.clone()).into()) + .await; + + self.connection_slots.remove().await; + Err(NetError::Timeout.into()) + } + + /// Establish a connection to the given `endpoint`. For each new connection, + /// it invokes the provided `callback`, and pass the connection to the callback. + pub async fn connect_with_cback<'a, Fut>( + self: &Arc<Self>, + ex: Executor<'a>, + endpoint: &Endpoint, + callback: impl FnOnce(Conn) -> Fut + Send + 'a, + ) -> Result<()> + where + Fut: Future<Output = Result<()>> + Send + 'a, + { + let conn = self.connect(endpoint).await?; + + let selfc = self.clone(); + let endpoint = endpoint.clone(); + let on_disconnect = |res| async move { + if let TaskResult::Completed(Err(err)) = res { + trace!("Outbound connection dropped: {err}"); + } + selfc + .monitor + .notify(&ConnEvent::Disconnected(endpoint.clone()).into()) + .await; + selfc.connection_slots.remove().await; + }; + + self.task_group + .spawn(ex.clone(), callback(conn), on_disconnect); + + Ok(()) + } +} diff --git a/p2p/src/net/listener.rs b/p2p/src/net/listener.rs new file mode 100644 index 0000000..d1a7bfb --- /dev/null +++ b/p2p/src/net/listener.rs @@ -0,0 +1,142 @@ +use std::{future::Future, sync::Arc}; + +use log::{error, info, trace}; + +use karyons_core::{ + async_utils::{TaskGroup, TaskResult}, + Executor, +}; + +use 
karyons_net::{listen, Conn, Endpoint, Listener as NetListener}; + +use crate::{ + monitor::{ConnEvent, Monitor}, + Result, +}; + +use super::slots::ConnectionSlots; + +/// Responsible for creating inbound connections with other peers. +pub struct Listener { + /// Managing spawned tasks. + task_group: TaskGroup, + + /// Manages available inbound slots. + connection_slots: Arc<ConnectionSlots>, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl Listener { + /// Creates a new Listener + pub fn new(connection_slots: Arc<ConnectionSlots>, monitor: Arc<Monitor>) -> Arc<Self> { + Arc::new(Self { + connection_slots, + task_group: TaskGroup::new(), + monitor, + }) + } + + /// Starts a listener on the given `endpoint`. For each incoming connection + /// that is accepted, it invokes the provided `callback`, and pass the + /// connection to the callback. + /// + /// Returns the resloved listening endpoint. + pub async fn start<'a, Fut>( + self: &Arc<Self>, + ex: Executor<'a>, + endpoint: Endpoint, + // https://github.com/rust-lang/rfcs/pull/2132 + callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'a, + ) -> Result<Endpoint> + where + Fut: Future<Output = Result<()>> + Send + 'a, + { + let listener = match listen(&endpoint).await { + Ok(listener) => { + self.monitor + .notify(&ConnEvent::Listening(endpoint.clone()).into()) + .await; + listener + } + Err(err) => { + error!("Failed to listen on {endpoint}: {err}"); + self.monitor + .notify(&ConnEvent::ListenFailed(endpoint).into()) + .await; + return Err(err.into()); + } + }; + + let resolved_endpoint = listener.local_endpoint()?; + + info!("Start listening on {endpoint}"); + + let selfc = self.clone(); + self.task_group.spawn( + ex.clone(), + selfc.listen_loop(ex.clone(), listener, callback), + |res| async move { + if let TaskResult::Completed(Err(err)) = res { + error!("Listen loop stopped: {endpoint} {err}"); + } + }, + ); + Ok(resolved_endpoint) + } + + /// Shuts down the listener + 
pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + async fn listen_loop<'a, Fut>( + self: Arc<Self>, + ex: Executor<'a>, + listener: Box<dyn NetListener>, + callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'a, + ) -> Result<()> + where + Fut: Future<Output = Result<()>> + Send + 'a, + { + loop { + // Wait for an available inbound slot. + self.connection_slots.wait_for_slot().await; + let result = listener.accept().await; + + let conn = match result { + Ok(c) => { + self.monitor + .notify(&ConnEvent::Accepted(c.peer_endpoint()?).into()) + .await; + c + } + Err(err) => { + error!("Failed to accept a new connection: {err}"); + self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; + return Err(err.into()); + } + }; + + self.connection_slots.add(); + + let selfc = self.clone(); + let endpoint = conn.peer_endpoint()?; + let on_disconnect = |res| async move { + if let TaskResult::Completed(Err(err)) = res { + trace!("Inbound connection dropped: {err}"); + } + selfc + .monitor + .notify(&ConnEvent::Disconnected(endpoint).into()) + .await; + selfc.connection_slots.remove().await; + }; + + let callback = callback.clone(); + self.task_group + .spawn(ex.clone(), callback(conn), on_disconnect); + } + } +} diff --git a/p2p/src/net/mod.rs b/p2p/src/net/mod.rs new file mode 100644 index 0000000..9cdc748 --- /dev/null +++ b/p2p/src/net/mod.rs @@ -0,0 +1,27 @@ +mod connection_queue; +mod connector; +mod listener; +mod slots; + +pub use connection_queue::ConnQueue; +pub use connector::Connector; +pub use listener::Listener; +pub use slots::ConnectionSlots; + +use std::fmt; + +/// Defines the direction of a network connection. 
+#[derive(Clone, Debug)] +pub enum ConnDirection { + Inbound, + Outbound, +} + +impl fmt::Display for ConnDirection { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ConnDirection::Inbound => write!(f, "Inbound"), + ConnDirection::Outbound => write!(f, "Outbound"), + } + } +} diff --git a/p2p/src/net/slots.rs b/p2p/src/net/slots.rs new file mode 100644 index 0000000..99f0a78 --- /dev/null +++ b/p2p/src/net/slots.rs @@ -0,0 +1,54 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + +use karyons_core::async_utils::CondWait; + +/// Manages available inbound and outbound slots. +pub struct ConnectionSlots { + /// A condvar for notifying when a slot become available. + signal: CondWait, + /// The number of occupied slots + slots: AtomicUsize, + /// The maximum number of slots. + max_slots: usize, +} + +impl ConnectionSlots { + /// Creates a new ConnectionSlots + pub fn new(max_slots: usize) -> Self { + Self { + signal: CondWait::new(), + slots: AtomicUsize::new(0), + max_slots, + } + } + + /// Returns the number of occupied slots + pub fn load(&self) -> usize { + self.slots.load(Ordering::SeqCst) + } + + /// Increases the occupied slots by one. + pub fn add(&self) { + self.slots.fetch_add(1, Ordering::SeqCst); + } + + /// Decreases the occupied slots by one and notifies the waiting signal + /// to start accepting/connecting new connections. + pub async fn remove(&self) { + self.slots.fetch_sub(1, Ordering::SeqCst); + if self.slots.load(Ordering::SeqCst) < self.max_slots { + self.signal.signal().await; + } + } + + /// Waits for a slot to become available. 
+ pub async fn wait_for_slot(&self) { + if self.slots.load(Ordering::SeqCst) < self.max_slots { + return; + } + + // Wait for a signal + self.signal.wait().await; + self.signal.reset().await; + } +} diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs new file mode 100644 index 0000000..ee0fdc4 --- /dev/null +++ b/p2p/src/peer/mod.rs @@ -0,0 +1,237 @@ +mod peer_id; + +pub use peer_id::PeerID; + +use std::sync::Arc; + +use log::{error, trace}; +use smol::{ + channel::{self, Receiver, Sender}, + lock::RwLock, +}; + +use karyons_core::{ + async_utils::{select, Either, TaskGroup, TaskResult}, + event::{ArcEventSys, EventListener, EventSys}, + utils::{decode, encode}, + Executor, +}; + +use karyons_net::Endpoint; + +use crate::{ + io_codec::{CodecMsg, IOCodec}, + message::{NetMsgCmd, ProtocolMsg, ShutdownMsg}, + net::ConnDirection, + peer_pool::{ArcPeerPool, WeakPeerPool}, + protocol::{Protocol, ProtocolEvent, ProtocolID}, + Config, Error, Result, +}; + +pub type ArcPeer = Arc<Peer>; + +pub struct Peer { + /// Peer's ID + id: PeerID, + + /// A weak pointer to `PeerPool` + peer_pool: WeakPeerPool, + + /// Holds the IOCodec for the peer connection + io_codec: IOCodec, + + /// Remote endpoint for the peer + remote_endpoint: Endpoint, + + /// The direction of the connection, either `Inbound` or `Outbound` + conn_direction: ConnDirection, + + /// A list of protocol IDs + protocol_ids: RwLock<Vec<ProtocolID>>, + + /// `EventSys` responsible for sending events to the protocols. + protocol_events: ArcEventSys<ProtocolID>, + + /// This channel is used to send a stop signal to the read loop. + stop_chan: (Sender<Result<()>>, Receiver<Result<()>>), + + /// Managing spawned tasks. 
+ task_group: TaskGroup, +} + +impl Peer { + /// Creates a new peer + pub fn new( + peer_pool: WeakPeerPool, + id: &PeerID, + io_codec: IOCodec, + remote_endpoint: Endpoint, + conn_direction: ConnDirection, + ) -> ArcPeer { + Arc::new(Peer { + id: id.clone(), + peer_pool, + io_codec, + protocol_ids: RwLock::new(Vec::new()), + remote_endpoint, + conn_direction, + protocol_events: EventSys::new(), + task_group: TaskGroup::new(), + stop_chan: channel::bounded(1), + }) + } + + /// Run the peer + pub async fn run(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + self.start_protocols(ex.clone()).await; + self.read_loop().await + } + + /// Send a message to the peer connection using the specified protocol. + pub async fn send<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { + let payload = encode(msg)?; + + let proto_msg = ProtocolMsg { + protocol_id: protocol_id.to_string(), + payload: payload.to_vec(), + }; + + self.io_codec.write(NetMsgCmd::Protocol, &proto_msg).await?; + Ok(()) + } + + /// Broadcast a message to all connected peers using the specified protocol. + pub async fn broadcast<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) { + self.peer_pool().broadcast(protocol_id, msg).await; + } + + /// Shuts down the peer + pub async fn shutdown(&self) { + trace!("peer {} start shutting down", self.id); + + // Send shutdown event to all protocols + for protocol_id in self.protocol_ids.read().await.iter() { + self.protocol_events + .emit_by_topic(protocol_id, &ProtocolEvent::Shutdown) + .await; + } + + // Send a stop signal to the read loop + // + // No need to handle the error here; a dropped channel and + // sending a stop signal have the same effect. 
+ let _ = self.stop_chan.0.try_send(Ok(())); + + // No need to handle the error here + let _ = self + .io_codec + .write(NetMsgCmd::Shutdown, &ShutdownMsg(0)) + .await; + + // Force shutting down + self.task_group.cancel().await; + } + + /// Check if the connection is Inbound + #[inline] + pub fn is_inbound(&self) -> bool { + match self.conn_direction { + ConnDirection::Inbound => true, + ConnDirection::Outbound => false, + } + } + + /// Returns the direction of the connection, which can be either `Inbound` + /// or `Outbound`. + #[inline] + pub fn direction(&self) -> &ConnDirection { + &self.conn_direction + } + + /// Returns the remote endpoint for the peer + #[inline] + pub fn remote_endpoint(&self) -> &Endpoint { + &self.remote_endpoint + } + + /// Return the peer's ID + #[inline] + pub fn id(&self) -> &PeerID { + &self.id + } + + /// Returns the `Config` instance. + pub fn config(&self) -> Arc<Config> { + self.peer_pool().config.clone() + } + + /// Registers a listener for the given Protocol `P`. + pub async fn register_listener<P: Protocol>(&self) -> EventListener<ProtocolID, ProtocolEvent> { + self.protocol_events.register(&P::id()).await + } + + /// Start a read loop to handle incoming messages from the peer connection. 
+ async fn read_loop(&self) -> Result<()> { + loop { + let fut = select(self.stop_chan.1.recv(), self.io_codec.read()).await; + let result = match fut { + Either::Left(stop_signal) => { + trace!("Peer {} received a stop signal", self.id); + return stop_signal?; + } + Either::Right(result) => result, + }; + + let msg = result?; + + match msg.header.command { + NetMsgCmd::Protocol => { + let msg: ProtocolMsg = decode(&msg.payload)?.0; + + if !self.protocol_ids.read().await.contains(&msg.protocol_id) { + return Err(Error::UnsupportedProtocol(msg.protocol_id)); + } + + let proto_id = &msg.protocol_id; + let msg = ProtocolEvent::Message(msg.payload); + self.protocol_events.emit_by_topic(proto_id, &msg).await; + } + NetMsgCmd::Shutdown => { + return Err(Error::PeerShutdown); + } + command => return Err(Error::InvalidMsg(format!("Unexpected msg {:?}", command))), + } + } + } + + /// Start running the protocols for this peer connection. + async fn start_protocols(self: &Arc<Self>, ex: Executor<'_>) { + for (protocol_id, constructor) in self.peer_pool().protocols.read().await.iter() { + trace!("peer {} start protocol {protocol_id}", self.id); + let protocol = constructor(self.clone()); + + self.protocol_ids.write().await.push(protocol_id.clone()); + + let selfc = self.clone(); + let exc = ex.clone(); + let proto_idc = protocol_id.clone(); + + let on_failure = |result: TaskResult<Result<()>>| async move { + if let TaskResult::Completed(res) = result { + if res.is_err() { + error!("protocol {} stopped", proto_idc); + } + // Send a stop signal to read loop + let _ = selfc.stop_chan.0.try_send(res); + } + }; + + self.task_group + .spawn(ex.clone(), protocol.start(exc), on_failure); + } + } + + fn peer_pool(&self) -> ArcPeerPool { + self.peer_pool.upgrade().unwrap() + } +} diff --git a/p2p/src/peer/peer_id.rs b/p2p/src/peer/peer_id.rs new file mode 100644 index 0000000..c8aec7d --- /dev/null +++ b/p2p/src/peer/peer_id.rs @@ -0,0 +1,41 @@ +use bincode::{Decode, Encode}; +use 
rand::{rngs::OsRng, RngCore}; +use sha2::{Digest, Sha256}; + +/// Represents a unique identifier for a peer. +#[derive(Clone, Debug, Eq, PartialEq, Hash, Decode, Encode)] +pub struct PeerID(pub [u8; 32]); + +impl std::fmt::Display for PeerID { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let id = self.0[0..8] + .iter() + .map(|b| format!("{:x}", b)) + .collect::<Vec<String>>() + .join(""); + + write!(f, "{}", id) + } +} + +impl PeerID { + /// Creates a new PeerID. + pub fn new(src: &[u8]) -> Self { + let mut hasher = Sha256::new(); + hasher.update(src); + Self(hasher.finalize().into()) + } + + /// Generates a random PeerID. + pub fn random() -> Self { + let mut id: [u8; 32] = [0; 32]; + OsRng.fill_bytes(&mut id); + Self(id) + } +} + +impl From<[u8; 32]> for PeerID { + fn from(b: [u8; 32]) -> Self { + PeerID(b) + } +} diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs new file mode 100644 index 0000000..eac4d3d --- /dev/null +++ b/p2p/src/peer_pool.rs @@ -0,0 +1,337 @@ +use std::{ + collections::HashMap, + sync::{Arc, Weak}, + time::Duration, +}; + +use log::{error, info, trace, warn}; +use smol::{ + channel::Sender, + lock::{Mutex, RwLock}, +}; + +use karyons_core::{ + async_utils::{TaskGroup, TaskResult}, + utils::decode, + Executor, +}; + +use karyons_net::Conn; + +use crate::{ + config::Config, + io_codec::{CodecMsg, IOCodec}, + message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, + monitor::{Monitor, PeerPoolEvent}, + net::ConnDirection, + net::ConnQueue, + peer::{ArcPeer, Peer, PeerID}, + protocol::{Protocol, ProtocolConstructor, ProtocolID}, + protocols::PingProtocol, + utils::{version_match, Version, VersionInt}, + Error, Result, +}; + +pub type ArcPeerPool = Arc<PeerPool>; +pub type WeakPeerPool = Weak<PeerPool>; + +pub struct PeerPool { + /// Peer's ID + pub id: PeerID, + + /// Connection queue + conn_queue: Arc<ConnQueue>, + + /// Holds the running peers. 
+ peers: Mutex<HashMap<PeerID, ArcPeer>>, + + /// Hashmap contains protocol constructors. + pub(crate) protocols: RwLock<HashMap<ProtocolID, Box<ProtocolConstructor>>>, + + /// Hashmap contains protocol IDs and their versions. + protocol_versions: Arc<RwLock<HashMap<ProtocolID, Version>>>, + + /// Managing spawned tasks. + task_group: TaskGroup, + + /// The Configuration for the P2P network. + pub config: Arc<Config>, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl PeerPool { + /// Creates a new PeerPool + pub fn new( + id: &PeerID, + conn_queue: Arc<ConnQueue>, + config: Arc<Config>, + monitor: Arc<Monitor>, + ) -> Arc<Self> { + let protocols = RwLock::new(HashMap::new()); + let protocol_versions = Arc::new(RwLock::new(HashMap::new())); + + Arc::new(Self { + id: id.clone(), + conn_queue, + peers: Mutex::new(HashMap::new()), + protocols, + protocol_versions, + task_group: TaskGroup::new(), + monitor, + config, + }) + } + + /// Start + pub async fn start(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + self.setup_protocols().await?; + let selfc = self.clone(); + self.task_group + .spawn(ex.clone(), selfc.listen_loop(ex.clone()), |_| async {}); + Ok(()) + } + + /// Listens to a new connection from the connection queue + pub async fn listen_loop(self: Arc<Self>, ex: Executor<'_>) { + loop { + let new_conn = self.conn_queue.next().await; + let disconnect_signal = new_conn.disconnect_signal; + + let result = self + .new_peer( + new_conn.conn, + &new_conn.direction, + disconnect_signal.clone(), + ex.clone(), + ) + .await; + + if result.is_err() { + let _ = disconnect_signal.send(()).await; + } + } + } + + /// Shuts down + pub async fn shutdown(&self) { + for (_, peer) in self.peers.lock().await.iter() { + peer.shutdown().await; + } + + self.task_group.cancel().await; + } + + /// Attach a custom protocol to the network + pub async fn attach_protocol<P: Protocol>(&self, c: Box<ProtocolConstructor>) -> Result<()> { + let 
protocol_versions = &mut self.protocol_versions.write().await; + let protocols = &mut self.protocols.write().await; + + protocol_versions.insert(P::id(), P::version()?); + protocols.insert(P::id(), Box::new(c) as Box<ProtocolConstructor>); + Ok(()) + } + + /// Returns the number of currently connected peers. + pub async fn peers_len(&self) -> usize { + self.peers.lock().await.len() + } + + /// Broadcast a message to all connected peers using the specified protocol. + pub async fn broadcast<T: CodecMsg>(&self, proto_id: &ProtocolID, msg: &T) { + for (pid, peer) in self.peers.lock().await.iter() { + if let Err(err) = peer.send(proto_id, msg).await { + error!("failed to send msg to {pid}: {err}"); + continue; + } + } + } + + /// Add a new peer to the peer list. + pub async fn new_peer( + self: &Arc<Self>, + conn: Conn, + conn_direction: &ConnDirection, + disconnect_signal: Sender<()>, + ex: Executor<'_>, + ) -> Result<PeerID> { + let endpoint = conn.peer_endpoint()?; + let io_codec = IOCodec::new(conn); + + // Do a handshake with a connection before creating a new peer. 
+ let pid = self.do_handshake(&io_codec, conn_direction).await?; + + // TODO: Consider restricting the subnet for inbound connections + if self.contains_peer(&pid).await { + return Err(Error::PeerAlreadyConnected); + } + + // Create a new peer + let peer = Peer::new( + Arc::downgrade(self), + &pid, + io_codec, + endpoint.clone(), + conn_direction.clone(), + ); + + // Insert the new peer + self.peers.lock().await.insert(pid.clone(), peer.clone()); + + let selfc = self.clone(); + let pid_c = pid.clone(); + let on_disconnect = |result| async move { + if let TaskResult::Completed(_) = result { + if let Err(err) = selfc.remove_peer(&pid_c).await { + error!("Failed to remove peer {pid_c}: {err}"); + } + let _ = disconnect_signal.send(()).await; + } + }; + + self.task_group + .spawn(ex.clone(), peer.run(ex.clone()), on_disconnect); + + info!("Add new peer {pid}, direction: {conn_direction}, endpoint: {endpoint}"); + + self.monitor + .notify(&PeerPoolEvent::NewPeer(pid.clone()).into()) + .await; + Ok(pid) + } + + /// Checks if the peer list contains a peer with the given peer id + pub async fn contains_peer(&self, pid: &PeerID) -> bool { + self.peers.lock().await.contains_key(pid) + } + + /// Shuts down the peer and remove it from the peer list. + async fn remove_peer(&self, pid: &PeerID) -> Result<()> { + let mut peers = self.peers.lock().await; + let result = peers.remove(pid); + + drop(peers); + + let peer = match result { + Some(p) => p, + None => return Ok(()), + }; + + peer.shutdown().await; + + self.monitor + .notify(&PeerPoolEvent::RemovePeer(pid.clone()).into()) + .await; + + let endpoint = peer.remote_endpoint(); + let direction = peer.direction(); + + warn!("Peer {pid} removed, direction: {direction}, endpoint: {endpoint}",); + Ok(()) + } + + /// Attach the core protocols. + async fn setup_protocols(&self) -> Result<()> { + self.attach_protocol::<PingProtocol>(Box::new(PingProtocol::new)) + .await + } + + /// Initiate a handshake with a connection. 
+ async fn do_handshake( + &self, + io_codec: &IOCodec, + conn_direction: &ConnDirection, + ) -> Result<PeerID> { + match conn_direction { + ConnDirection::Inbound => { + let pid = self.wait_vermsg(io_codec).await?; + self.send_verack(io_codec).await?; + Ok(pid) + } + ConnDirection::Outbound => { + self.send_vermsg(io_codec).await?; + self.wait_verack(io_codec).await + } + } + } + + /// Send a Version message + async fn send_vermsg(&self, io_codec: &IOCodec) -> Result<()> { + let pids = self.protocol_versions.read().await; + let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); + drop(pids); + + let vermsg = VerMsg { + peer_id: self.id.clone(), + protocols, + version: self.config.version.v.clone(), + }; + + trace!("Send VerMsg"); + io_codec.write(NetMsgCmd::Version, &vermsg).await?; + Ok(()) + } + + /// Wait for a Version message + /// + /// Returns the peer's ID upon successfully receiving the Version message. + async fn wait_vermsg(&self, io_codec: &IOCodec) -> Result<PeerID> { + let timeout = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = io_codec.read_timeout(timeout).await?; + + let payload = get_msg_payload!(Version, msg); + let (vermsg, _) = decode::<VerMsg>(&payload)?; + + if !version_match(&self.config.version.req, &vermsg.version) { + return Err(Error::IncompatibleVersion("system: {}".into())); + } + + self.protocols_match(&vermsg.protocols).await?; + + trace!("Received VerMsg from: {}", vermsg.peer_id); + Ok(vermsg.peer_id) + } + + /// Send a Verack message + async fn send_verack(&self, io_codec: &IOCodec) -> Result<()> { + let verack = VerAckMsg(self.id.clone()); + + trace!("Send VerAckMsg"); + io_codec.write(NetMsgCmd::Verack, &verack).await?; + Ok(()) + } + + /// Wait for a Verack message + /// + /// Returns the peer's ID upon successfully receiving the Verack message. 
+ async fn wait_verack(&self, io_codec: &IOCodec) -> Result<PeerID> { + let timeout = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = io_codec.read_timeout(timeout).await?; + + let payload = get_msg_payload!(Verack, msg); + let (verack, _) = decode::<VerAckMsg>(&payload)?; + + trace!("Received VerAckMsg from: {}", verack.0); + Ok(verack.0) + } + + /// Check if the new connection has compatible protocols. + async fn protocols_match(&self, protocols: &HashMap<ProtocolID, VersionInt>) -> Result<()> { + for (n, pv) in protocols.iter() { + let pids = self.protocol_versions.read().await; + + match pids.get(n) { + Some(v) => { + if !version_match(&v.req, pv) { + return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}"))); + } + } + None => { + return Err(Error::UnsupportedProtocol(n.to_string())); + } + } + } + Ok(()) + } +} diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs new file mode 100644 index 0000000..515efc6 --- /dev/null +++ b/p2p/src/protocol.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use karyons_core::{event::EventValue, Executor}; + +use crate::{peer::ArcPeer, utils::Version, Result}; + +pub type ArcProtocol = Arc<dyn Protocol>; + +pub type ProtocolConstructor = dyn Fn(ArcPeer) -> Arc<dyn Protocol> + Send + Sync; + +pub type ProtocolID = String; + +/// Protocol event +#[derive(Debug, Clone)] +pub enum ProtocolEvent { + /// Message event, contains a vector of bytes. + Message(Vec<u8>), + /// Shutdown event signals the protocol to gracefully shut down. + Shutdown, +} + +impl EventValue for ProtocolEvent { + fn id() -> &'static str { + "ProtocolEvent" + } +} + +/// The Protocol trait defines the interface for core protocols +/// and custom protocols. 
+/// +/// # Example +/// ``` +/// use std::sync::Arc; +/// +/// use async_trait::async_trait; +/// use smol::Executor; +/// +/// use karyons_p2p::{ +/// protocol::{ArcProtocol, Protocol, ProtocolID, ProtocolEvent}, +/// Backend, PeerID, Config, Version, P2pError, ArcPeer}; +/// +/// pub struct NewProtocol { +/// peer: ArcPeer, +/// } +/// +/// impl NewProtocol { +/// fn new(peer: ArcPeer) -> ArcProtocol { +/// Arc::new(Self { +/// peer, +/// }) +/// } +/// } +/// +/// #[async_trait] +/// impl Protocol for NewProtocol { +/// async fn start(self: Arc<Self>, ex: Arc<Executor<'_>>) -> Result<(), P2pError> { +/// let listener = self.peer.register_listener::<Self>().await; +/// loop { +/// let event = listener.recv().await.unwrap(); +/// +/// match event { +/// ProtocolEvent::Message(msg) => { +/// println!("{:?}", msg); +/// } +/// ProtocolEvent::Shutdown => { +/// break; +/// } +/// } +/// } +/// +/// listener.cancel().await; +/// Ok(()) +/// } +/// +/// fn version() -> Result<Version, P2pError> { +/// "0.2.0, >0.1.0".parse() +/// } +/// +/// fn id() -> ProtocolID { +/// "NEWPROTOCOLID".into() +/// } +/// } +/// +/// async { +/// let peer_id = PeerID::random(); +/// let config = Config::default(); +/// +/// // Create a new Backend +/// let backend = Backend::new(peer_id, config); +/// +/// // Attach the NewProtocol +/// let c = move |peer| NewProtocol::new(peer); +/// backend.attach_protocol::<NewProtocol>(c).await.unwrap(); +/// }; +/// +/// ``` +#[async_trait] +pub trait Protocol: Send + Sync { + /// Start the protocol + async fn start(self: Arc<Self>, ex: Executor<'_>) -> Result<()>; + + /// Returns the version of the protocol. + fn version() -> Result<Version> + where + Self: Sized; + + /// Returns the unique ProtocolID associated with the protocol. 
+ fn id() -> ProtocolID + where + Self: Sized; +} diff --git a/p2p/src/protocols/mod.rs b/p2p/src/protocols/mod.rs new file mode 100644 index 0000000..4a8f6b9 --- /dev/null +++ b/p2p/src/protocols/mod.rs @@ -0,0 +1,3 @@ +mod ping; + +pub use ping::PingProtocol; diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs new file mode 100644 index 0000000..b337494 --- /dev/null +++ b/p2p/src/protocols/ping.rs @@ -0,0 +1,173 @@ +use std::{sync::Arc, time::Duration}; + +use async_trait::async_trait; +use bincode::{Decode, Encode}; +use log::trace; +use rand::{rngs::OsRng, RngCore}; +use smol::{ + channel, + channel::{Receiver, Sender}, + stream::StreamExt, + Timer, +}; + +use karyons_core::{ + async_utils::{select, timeout, Either, TaskGroup, TaskResult}, + event::EventListener, + utils::decode, + Executor, +}; + +use karyons_net::NetError; + +use crate::{ + peer::ArcPeer, + protocol::{ArcProtocol, Protocol, ProtocolEvent, ProtocolID}, + utils::Version, + Result, +}; + +const MAX_FAILUERS: u32 = 3; + +#[derive(Clone, Debug, Encode, Decode)] +enum PingProtocolMsg { + Ping([u8; 32]), + Pong([u8; 32]), +} + +pub struct PingProtocol { + peer: ArcPeer, + ping_interval: u64, + ping_timeout: u64, + task_group: TaskGroup, +} + +impl PingProtocol { + #[allow(clippy::new_ret_no_self)] + pub fn new(peer: ArcPeer) -> ArcProtocol { + let ping_interval = peer.config().ping_interval; + let ping_timeout = peer.config().ping_timeout; + Arc::new(Self { + peer, + ping_interval, + ping_timeout, + task_group: TaskGroup::new(), + }) + } + + async fn recv_loop( + &self, + listener: &EventListener<ProtocolID, ProtocolEvent>, + pong_chan: Sender<[u8; 32]>, + ) -> Result<()> { + loop { + let event = listener.recv().await?; + let msg_payload = match event.clone() { + ProtocolEvent::Message(m) => m, + ProtocolEvent::Shutdown => { + break; + } + }; + + let (msg, _) = decode::<PingProtocolMsg>(&msg_payload)?; + + match msg { + PingProtocolMsg::Ping(nonce) => { + trace!("Received Ping 
message {:?}", nonce); + self.peer + .send(&Self::id(), &PingProtocolMsg::Pong(nonce)) + .await?; + trace!("Send back Pong message {:?}", nonce); + } + PingProtocolMsg::Pong(nonce) => { + pong_chan.send(nonce).await?; + } + } + } + Ok(()) + } + + async fn ping_loop(self: Arc<Self>, chan: Receiver<[u8; 32]>) -> Result<()> { + let mut timer = Timer::interval(Duration::from_secs(self.ping_interval)); + let rng = &mut OsRng; + let mut retry = 0; + + while retry < MAX_FAILUERS { + timer.next().await; + + let mut ping_nonce: [u8; 32] = [0; 32]; + rng.fill_bytes(&mut ping_nonce); + + trace!("Send Ping message {:?}", ping_nonce); + self.peer + .send(&Self::id(), &PingProtocolMsg::Ping(ping_nonce)) + .await?; + + let d = Duration::from_secs(self.ping_timeout); + + // Wait for Pong message + let pong_msg = match timeout(d, chan.recv()).await { + Ok(m) => m?, + Err(_) => { + retry += 1; + continue; + } + }; + + trace!("Received Pong message {:?}", pong_msg); + + if pong_msg != ping_nonce { + retry += 1; + continue; + } + } + + Err(NetError::Timeout.into()) + } +} + +#[async_trait] +impl Protocol for PingProtocol { + async fn start(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + trace!("Start Ping protocol"); + let (pong_chan, pong_chan_recv) = channel::bounded(1); + let (stop_signal_s, stop_signal) = channel::bounded::<Result<()>>(1); + + let selfc = self.clone(); + self.task_group.spawn( + ex.clone(), + selfc.clone().ping_loop(pong_chan_recv.clone()), + |res| async move { + if let TaskResult::Completed(result) = res { + let _ = stop_signal_s.send(result).await; + } + }, + ); + + let listener = self.peer.register_listener::<Self>().await; + + let result = select(self.recv_loop(&listener, pong_chan), stop_signal.recv()).await; + listener.cancel().await; + self.task_group.cancel().await; + + match result { + Either::Left(res) => { + trace!("Receive loop stopped {:?}", res); + res + } + Either::Right(res) => { + let res = res?; + trace!("Ping loop stopped {:?}", res); + res 
+ } + } + } + + fn version() -> Result<Version> { + "0.1.0".parse() + } + + fn id() -> ProtocolID { + "PING".into() + } +} diff --git a/p2p/src/routing_table/bucket.rs b/p2p/src/routing_table/bucket.rs new file mode 100644 index 0000000..13edd24 --- /dev/null +++ b/p2p/src/routing_table/bucket.rs @@ -0,0 +1,123 @@ +use super::{Entry, Key}; + +use rand::{rngs::OsRng, seq::SliceRandom}; + +/// BITFLAGS represent the status of an Entry within a bucket. +pub type EntryStatusFlag = u16; + +/// The entry is connected. +pub const CONNECTED_ENTRY: EntryStatusFlag = 0b00001; + +/// The entry is disconnected. This will increase the failure counter. +pub const DISCONNECTED_ENTRY: EntryStatusFlag = 0b00010; + +/// The entry is ready to reconnect, meaning it has either been added and +/// has no connection attempts, or it has been refreshed. +pub const PENDING_ENTRY: EntryStatusFlag = 0b00100; + +/// The entry is unreachable. This will increase the failure counter. +pub const UNREACHABLE_ENTRY: EntryStatusFlag = 0b01000; + +/// The entry is unstable. This will increase the failure counter. +pub const UNSTABLE_ENTRY: EntryStatusFlag = 0b10000; + +#[allow(dead_code)] +pub const ALL_ENTRY: EntryStatusFlag = 0b11111; + +/// A BucketEntry represents a peer in the routing table. +#[derive(Clone, Debug)] +pub struct BucketEntry { + pub status: EntryStatusFlag, + pub entry: Entry, + pub failures: u32, + pub last_seen: i64, +} + +impl BucketEntry { + pub fn is_connected(&self) -> bool { + self.status ^ CONNECTED_ENTRY == 0 + } + + pub fn is_unreachable(&self) -> bool { + self.status ^ UNREACHABLE_ENTRY == 0 + } + + pub fn is_unstable(&self) -> bool { + self.status ^ UNSTABLE_ENTRY == 0 + } +} + +/// The number of entries that can be stored within a single bucket. +pub const BUCKET_SIZE: usize = 20; + +/// A Bucket represents a group of entries in the routing table. 
#[derive(Debug, Clone)]
pub struct Bucket {
    /// The entries currently stored in this bucket.
    entries: Vec<BucketEntry>,
}

impl Default for Bucket {
    /// Equivalent to [`Bucket::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Bucket {
    /// Creates a new empty Bucket with room reserved for `BUCKET_SIZE`
    /// entries.
    pub fn new() -> Self {
        Self {
            entries: Vec::with_capacity(BUCKET_SIZE),
        }
    }

    /// Add an entry to the bucket.
    ///
    /// The entry starts as `PENDING_ENTRY` with no failures, and its
    /// `last_seen` is set to the current UTC timestamp. This method does
    /// not check for duplicates or enforce `BUCKET_SIZE`; callers are
    /// expected to do so.
    pub fn add(&mut self, entry: &Entry) {
        self.entries.push(BucketEntry {
            status: PENDING_ENTRY,
            entry: entry.clone(),
            failures: 0,
            last_seen: chrono::Utc::now().timestamp(),
        })
    }

    /// Get the number of entries in the bucket.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns an iterator over the entries in the bucket.
    pub fn iter(&self) -> impl Iterator<Item = &BucketEntry> {
        self.entries.iter()
    }

    /// Remove the first entry whose key equals `key`, if there is one.
    pub fn remove(&mut self, key: &Key) {
        let position = self.entries.iter().position(|e| &e.entry.key == key);
        if let Some(i) = position {
            self.entries.remove(i);
        }
    }

    /// Returns an iterator over at most `amount` entries, chosen with
    /// `choose_multiple` using `OsRng`.
    pub fn random_iter(&self, amount: usize) -> impl Iterator<Item = &BucketEntry> {
        self.entries.choose_multiple(&mut OsRng, amount)
    }

    /// Updates the status of an entry in the bucket identified by the given key.
    ///
    /// If the key is not found in the bucket, no action is taken.
    ///
    /// The failures counter is increased when the entry reports itself as
    /// unreachable or unstable after the update, and `last_seen` is
    /// refreshed unless the entry is unreachable.
    pub fn update_entry(&mut self, key: &Key, entry_flag: EntryStatusFlag) {
        if let Some(e) = self.entries.iter_mut().find(|e| &e.entry.key == key) {
            e.status = entry_flag;
            if e.is_unreachable() || e.is_unstable() {
                e.failures += 1;
            }

            if !e.is_unreachable() {
                e.last_seen = chrono::Utc::now().timestamp();
            }
        }
    }

    /// Check if the bucket contains the given key.
+ pub fn contains_key(&self, key: &Key) -> bool { + self.entries.iter().any(|e| &e.entry.key == key) + } +} diff --git a/p2p/src/routing_table/entry.rs b/p2p/src/routing_table/entry.rs new file mode 100644 index 0000000..b3f219f --- /dev/null +++ b/p2p/src/routing_table/entry.rs @@ -0,0 +1,41 @@ +use bincode::{Decode, Encode}; + +use karyons_net::{Addr, Port}; + +/// Specifies the size of the key, in bytes. +pub const KEY_SIZE: usize = 32; + +/// An Entry represents a peer in the routing table. +#[derive(Encode, Decode, Clone, Debug)] +pub struct Entry { + /// The unique key identifying the peer. + pub key: Key, + /// The IP address of the peer. + pub addr: Addr, + /// TCP port + pub port: Port, + /// UDP/TCP port + pub discovery_port: Port, +} + +impl PartialEq for Entry { + fn eq(&self, other: &Self) -> bool { + // XXX this should also compare both addresses (the self.addr == other.addr) + self.key == other.key + } +} + +/// The unique key identifying the peer. +pub type Key = [u8; KEY_SIZE]; + +/// Calculates the XOR distance between two provided keys. +/// +/// The XOR distance is a metric used in Kademlia to measure the closeness +/// of keys. +pub fn xor_distance(key: &Key, other: &Key) -> Key { + let mut res = [0; 32]; + for (i, (k, o)) in key.iter().zip(other.iter()).enumerate() { + res[i] = k ^ o; + } + res +} diff --git a/p2p/src/routing_table/mod.rs b/p2p/src/routing_table/mod.rs new file mode 100644 index 0000000..abf9a08 --- /dev/null +++ b/p2p/src/routing_table/mod.rs @@ -0,0 +1,461 @@ +mod bucket; +mod entry; +pub use bucket::{ + Bucket, BucketEntry, EntryStatusFlag, CONNECTED_ENTRY, DISCONNECTED_ENTRY, PENDING_ENTRY, + UNREACHABLE_ENTRY, UNSTABLE_ENTRY, +}; +pub use entry::{xor_distance, Entry, Key}; + +use rand::{rngs::OsRng, seq::SliceRandom}; + +use crate::utils::subnet_match; + +use bucket::BUCKET_SIZE; +use entry::KEY_SIZE; + +/// The total number of buckets in the routing table. 
+const TABLE_SIZE: usize = 32; + +/// The distance limit for the closest buckets. +const DISTANCE_LIMIT: usize = 32; + +/// The maximum number of matched subnets allowed within a single bucket. +const MAX_MATCHED_SUBNET_IN_BUCKET: usize = 1; + +/// The maximum number of matched subnets across the entire routing table. +const MAX_MATCHED_SUBNET_IN_TABLE: usize = 6; + +/// Represents the possible result when adding a new entry. +#[derive(Debug)] +pub enum AddEntryResult { + /// The entry is added. + Added, + /// The entry is already exists. + Exists, + /// The entry is ignored. + Ignored, + /// The entry is restricted and not allowed. + Restricted, +} + +/// This is a modified version of the Kademlia Distributed Hash Table (DHT). +/// https://en.wikipedia.org/wiki/Kademlia +#[derive(Debug)] +pub struct RoutingTable { + key: Key, + buckets: Vec<Bucket>, +} + +impl RoutingTable { + /// Creates a new RoutingTable + pub fn new(key: Key) -> Self { + let buckets: Vec<Bucket> = (0..TABLE_SIZE).map(|_| Bucket::new()).collect(); + Self { key, buckets } + } + + /// Adds a new entry to the table and returns a result indicating success, + /// failure, or restrictions. + pub fn add_entry(&mut self, entry: Entry) -> AddEntryResult { + // Determine the index of the bucket where the entry should be placed. + let bucket_idx = match self.bucket_index(&entry.key) { + Some(i) => i, + None => return AddEntryResult::Ignored, + }; + + let bucket = &self.buckets[bucket_idx]; + + // Check if the entry already exists in the bucket. + if bucket.contains_key(&entry.key) { + return AddEntryResult::Exists; + } + + // Check if the entry is restricted. + if self.subnet_restricted(bucket_idx, &entry) { + return AddEntryResult::Restricted; + } + + let bucket = &mut self.buckets[bucket_idx]; + + // If the bucket has free space, add the entry and return success. 
+ if bucket.len() < BUCKET_SIZE { + bucket.add(&entry); + return AddEntryResult::Added; + } + + // If the bucket is full, the entry is ignored. + AddEntryResult::Ignored + } + + /// Check if the table contains the given key. + pub fn contains_key(&self, key: &Key) -> bool { + // Determine the bucket index for the given key. + let bucket_idx = match self.bucket_index(key) { + Some(bi) => bi, + None => return false, + }; + + let bucket = &self.buckets[bucket_idx]; + bucket.contains_key(key) + } + + /// Updates the status of an entry in the routing table identified + /// by the given key. + /// + /// If the key is not found, no action is taken. + pub fn update_entry(&mut self, key: &Key, entry_flag: EntryStatusFlag) { + // Determine the bucket index for the given key. + let bucket_idx = match self.bucket_index(key) { + Some(bi) => bi, + None => return, + }; + + let bucket = &mut self.buckets[bucket_idx]; + bucket.update_entry(key, entry_flag); + } + + /// Returns a list of bucket indexes that are closest to the given target key. + pub fn bucket_indexes(&self, target_key: &Key) -> Vec<usize> { + let mut indexes = vec![]; + + // Determine the primary bucket index for the target key. + let bucket_idx = self.bucket_index(target_key).unwrap_or(0); + + indexes.push(bucket_idx); + + // Add additional bucket indexes within a certain distance limit. + for i in 1..DISTANCE_LIMIT { + if bucket_idx >= i && bucket_idx - i >= 1 { + indexes.push(bucket_idx - i); + } + + if bucket_idx + i < (TABLE_SIZE - 1) { + indexes.push(bucket_idx + i); + } + } + + indexes + } + + /// Returns a list of the closest entries to the given target key, limited by max_entries. 
+ pub fn closest_entries(&self, target_key: &Key, max_entries: usize) -> Vec<Entry> { + let mut entries: Vec<Entry> = vec![]; + + // Collect entries + 'outer: for idx in self.bucket_indexes(target_key) { + let bucket = &self.buckets[idx]; + for bucket_entry in bucket.iter() { + if bucket_entry.is_unreachable() || bucket_entry.is_unstable() { + continue; + } + + entries.push(bucket_entry.entry.clone()); + if entries.len() == max_entries { + break 'outer; + } + } + } + + // Sort the entries by their distance to the target key. + entries.sort_by(|a, b| { + xor_distance(target_key, &a.key).cmp(&xor_distance(target_key, &b.key)) + }); + + entries + } + + /// Removes an entry with the given key from the routing table, if it exists. + pub fn remove_entry(&mut self, key: &Key) { + // Determine the bucket index for the given key. + let bucket_idx = match self.bucket_index(key) { + Some(bi) => bi, + None => return, + }; + + let bucket = &mut self.buckets[bucket_idx]; + bucket.remove(key); + } + + /// Returns an iterator of entries. + pub fn iter(&self) -> impl Iterator<Item = &Bucket> { + self.buckets.iter() + } + + /// Returns a random entry from the routing table. + pub fn random_entry(&self, entry_flag: EntryStatusFlag) -> Option<&Entry> { + for bucket in self.buckets.choose_multiple(&mut OsRng, self.buckets.len()) { + for entry in bucket.random_iter(bucket.len()) { + if entry.status & entry_flag == 0 { + continue; + } + return Some(&entry.entry); + } + } + + None + } + + // Returns the bucket index for a given key in the table. + fn bucket_index(&self, key: &Key) -> Option<usize> { + // Calculate the XOR distance between the self key and the provided key. 
+ let distance = xor_distance(&self.key, key); + + for (i, b) in distance.iter().enumerate() { + if *b != 0 { + let lz = i * 8 + b.leading_zeros() as usize; + let bits = KEY_SIZE * 8 - 1; + let idx = (bits - lz) / 8; + return Some(idx); + } + } + None + } + + /// This function iterate through the routing table and counts how many + /// entries in the same subnet as the given Entry are already present. + /// + /// If the number of matching entries in the same bucket exceeds a + /// threshold (MAX_MATCHED_SUBNET_IN_BUCKET), or if the total count of + /// matching entries in the entire table exceeds a threshold + /// (MAX_MATCHED_SUBNET_IN_TABLE), the addition of the Entry + /// is considered restricted and returns true. + fn subnet_restricted(&self, idx: usize, entry: &Entry) -> bool { + let mut bucket_count = 0; + let mut table_count = 0; + + // Iterate through the routing table's buckets and entries to check + // for subnet matches. + for (i, bucket) in self.buckets.iter().enumerate() { + for e in bucket.iter() { + // If there is a subnet match, update the counts. + let matched = subnet_match(&e.entry.addr, &entry.addr); + if matched { + if i == idx { + bucket_count += 1; + } + table_count += 1; + } + + // If the number of matched entries in the same bucket exceeds + // the limit, return true + if bucket_count >= MAX_MATCHED_SUBNET_IN_BUCKET { + return true; + } + } + + // If the total matched entries in the table exceed the limit, + // return true. + if table_count >= MAX_MATCHED_SUBNET_IN_TABLE { + return true; + } + } + + // If no subnet restrictions are encountered, return false. 
+ false + } +} + +#[cfg(test)] +mod tests { + use super::bucket::ALL_ENTRY; + use super::*; + + use karyons_net::Addr; + + struct Setup { + local_key: Key, + keys: Vec<Key>, + } + + fn new_entry(key: &Key, addr: &Addr, port: u16, discovery_port: u16) -> Entry { + Entry { + key: key.clone(), + addr: addr.clone(), + port, + discovery_port, + } + } + + impl Setup { + fn new() -> Self { + let keys = vec![ + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 0, 1, 1, 2, + ], + [ + 0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 3, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 1, 18, 0, 0, 0, + 0, 0, 0, 0, 0, 4, + ], + [ + 223, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 5, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 50, 1, 18, 0, 0, 0, + 0, 0, 0, 0, 0, 6, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 50, 1, 18, 0, 0, + 0, 0, 0, 0, 0, 0, 7, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 50, 1, 18, 0, 0, + 0, 0, 0, 0, 0, 0, 8, + ], + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 50, 1, 18, 0, 0, + 0, 0, 0, 0, 0, 0, 9, + ], + ]; + + Self { + local_key: [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ], + keys, + } + } + + fn entries(&self) -> Vec<Entry> { + let mut entries = vec![]; + for (i, key) in self.keys.iter().enumerate() { + entries.push(new_entry( + key, + &Addr::Ip(format!("127.0.{i}.1").parse().unwrap()), + 3000, + 3010, + )); + } + entries + } + + fn table(&self) -> RoutingTable { + let mut table = RoutingTable::new(self.local_key.clone()); + + for entry in self.entries() { + let res = table.add_entry(entry); + assert!(matches!(res, 
AddEntryResult::Added)); + } + + table + } + } + + #[test] + fn test_bucket_index() { + let setup = Setup::new(); + let table = setup.table(); + + assert_eq!(table.bucket_index(&setup.local_key), None); + assert_eq!(table.bucket_index(&setup.keys[0]), Some(0)); + assert_eq!(table.bucket_index(&setup.keys[1]), Some(5)); + assert_eq!(table.bucket_index(&setup.keys[2]), Some(26)); + assert_eq!(table.bucket_index(&setup.keys[3]), Some(11)); + assert_eq!(table.bucket_index(&setup.keys[4]), Some(31)); + assert_eq!(table.bucket_index(&setup.keys[5]), Some(11)); + assert_eq!(table.bucket_index(&setup.keys[6]), Some(12)); + assert_eq!(table.bucket_index(&setup.keys[7]), Some(13)); + assert_eq!(table.bucket_index(&setup.keys[8]), Some(14)); + } + + #[test] + fn test_closest_entries() { + let setup = Setup::new(); + let table = setup.table(); + let entries = setup.entries(); + + assert_eq!( + table.closest_entries(&setup.keys[5], 8), + vec![ + entries[5].clone(), + entries[3].clone(), + entries[1].clone(), + entries[6].clone(), + entries[7].clone(), + entries[8].clone(), + entries[2].clone(), + ] + ); + + assert_eq!( + table.closest_entries(&setup.keys[4], 2), + vec![entries[4].clone(), entries[2].clone()] + ); + } + + #[test] + fn test_random_entry() { + let setup = Setup::new(); + let mut table = setup.table(); + let entries = setup.entries(); + + let entry = table.random_entry(ALL_ENTRY); + assert!(matches!(entry, Some(&_))); + + let entry = table.random_entry(CONNECTED_ENTRY); + assert!(matches!(entry, None)); + + for entry in entries { + table.remove_entry(&entry.key); + } + + let entry = table.random_entry(ALL_ENTRY); + assert!(matches!(entry, None)); + } + + #[test] + fn test_add_entries() { + let setup = Setup::new(); + let mut table = setup.table(); + + let key = [ + 0, 0, 0, 0, 0, 0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 5, + ]; + + let key2 = [ + 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, + 0, 0, 5, + ]; + + let entry1 = new_entry(&key, &Addr::Ip("240.120.3.1".parse().unwrap()), 3000, 3010); + assert!(matches!( + table.add_entry(entry1.clone()), + AddEntryResult::Added + )); + + assert!(matches!(table.add_entry(entry1), AddEntryResult::Exists)); + + let entry2 = new_entry(&key2, &Addr::Ip("240.120.3.2".parse().unwrap()), 3000, 3010); + assert!(matches!( + table.add_entry(entry2), + AddEntryResult::Restricted + )); + + let mut key: [u8; 32] = [0; 32]; + + for i in 0..BUCKET_SIZE { + key[i] += 1; + let entry = new_entry( + &key, + &Addr::Ip(format!("127.0.{i}.1").parse().unwrap()), + 3000, + 3010, + ); + table.add_entry(entry); + } + + key[BUCKET_SIZE] += 1; + let entry = new_entry(&key, &Addr::Ip("125.20.0.1".parse().unwrap()), 3000, 3010); + assert!(matches!(table.add_entry(entry), AddEntryResult::Ignored)); + } +} diff --git a/p2p/src/utils/mod.rs b/p2p/src/utils/mod.rs new file mode 100644 index 0000000..e8ff9d0 --- /dev/null +++ b/p2p/src/utils/mod.rs @@ -0,0 +1,21 @@ +mod version; + +pub use version::{version_match, Version, VersionInt}; + +use std::net::IpAddr; + +use karyons_net::Addr; + +/// Check if two addresses belong to the same subnet. +pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { + match (addr, other_addr) { + (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => { + // XXX Consider moving this to a different location + if other_ip.is_loopback() && ip.is_loopback() { + return false; + } + ip.octets()[0..3] == other_ip.octets()[0..3] + } + _ => false, + } +} diff --git a/p2p/src/utils/version.rs b/p2p/src/utils/version.rs new file mode 100644 index 0000000..4986495 --- /dev/null +++ b/p2p/src/utils/version.rs @@ -0,0 +1,93 @@ +use std::str::FromStr; + +use bincode::{Decode, Encode}; +use semver::VersionReq; + +use crate::{Error, Result}; + +/// Represents the network version and protocol version used in Karyons p2p. 
+/// +/// # Example +/// +/// ``` +/// use karyons_p2p::Version; +/// +/// let version: Version = "0.2.0, >0.1.0".parse().unwrap(); +/// +/// let version: Version = "0.2.0".parse().unwrap(); +/// +/// ``` +#[derive(Debug, Clone)] +pub struct Version { + pub v: VersionInt, + pub req: VersionReq, +} + +impl Version { + /// Creates a new Version + pub fn new(v: VersionInt, req: VersionReq) -> Self { + Self { v, req } + } +} + +#[derive(Debug, Decode, Encode, Clone)] +pub struct VersionInt { + major: u64, + minor: u64, + patch: u64, +} + +impl FromStr for Version { + type Err = Error; + + fn from_str(s: &str) -> Result<Self> { + let v: Vec<&str> = s.split(", ").collect(); + if v.is_empty() || v.len() > 2 { + return Err(Error::ParseError(format!("Invalid version{s}"))); + } + + let version: VersionInt = v[0].parse()?; + let req: VersionReq = if v.len() > 1 { v[1] } else { v[0] }.parse()?; + + Ok(Self { v: version, req }) + } +} + +impl FromStr for VersionInt { + type Err = Error; + + fn from_str(s: &str) -> Result<Self> { + let v: Vec<&str> = s.split('.').collect(); + if v.len() < 2 || v.len() > 3 { + return Err(Error::ParseError(format!("Invalid version{s}"))); + } + + let major = v[0].parse::<u64>()?; + let minor = v[1].parse::<u64>()?; + let patch = v.get(2).unwrap_or(&"0").parse::<u64>()?; + + Ok(Self { + major, + minor, + patch, + }) + } +} + +impl std::fmt::Display for VersionInt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.major, self.minor, self.patch) + } +} + +impl From<VersionInt> for semver::Version { + fn from(v: VersionInt) -> Self { + semver::Version::new(v.major, v.minor, v.patch) + } +} + +/// Check if a version satisfies a version request. +pub fn version_match(version_req: &VersionReq, version: &VersionInt) -> bool { + let version: semver::Version = version.clone().into(); + version_req.matches(&version) +} |