diff options
Diffstat (limited to 'karyons_p2p/src/peer')
-rw-r--r-- | karyons_p2p/src/peer/mod.rs | 237 | ||||
-rw-r--r-- | karyons_p2p/src/peer/peer_id.rs | 41 |
2 files changed, 278 insertions, 0 deletions
diff --git a/karyons_p2p/src/peer/mod.rs b/karyons_p2p/src/peer/mod.rs new file mode 100644 index 0000000..ee0fdc4 --- /dev/null +++ b/karyons_p2p/src/peer/mod.rs @@ -0,0 +1,237 @@ +mod peer_id; + +pub use peer_id::PeerID; + +use std::sync::Arc; + +use log::{error, trace}; +use smol::{ + channel::{self, Receiver, Sender}, + lock::RwLock, +}; + +use karyons_core::{ + async_utils::{select, Either, TaskGroup, TaskResult}, + event::{ArcEventSys, EventListener, EventSys}, + utils::{decode, encode}, + Executor, +}; + +use karyons_net::Endpoint; + +use crate::{ + io_codec::{CodecMsg, IOCodec}, + message::{NetMsgCmd, ProtocolMsg, ShutdownMsg}, + net::ConnDirection, + peer_pool::{ArcPeerPool, WeakPeerPool}, + protocol::{Protocol, ProtocolEvent, ProtocolID}, + Config, Error, Result, +}; + +pub type ArcPeer = Arc<Peer>; + +pub struct Peer { + /// Peer's ID + id: PeerID, + + /// A weak pointer to `PeerPool` + peer_pool: WeakPeerPool, + + /// Holds the IOCodec for the peer connection + io_codec: IOCodec, + + /// Remote endpoint for the peer + remote_endpoint: Endpoint, + + /// The direction of the connection, either `Inbound` or `Outbound` + conn_direction: ConnDirection, + + /// A list of protocol IDs + protocol_ids: RwLock<Vec<ProtocolID>>, + + /// `EventSys` responsible for sending events to the protocols. + protocol_events: ArcEventSys<ProtocolID>, + + /// This channel is used to send a stop signal to the read loop. + stop_chan: (Sender<Result<()>>, Receiver<Result<()>>), + + /// Managing spawned tasks. + task_group: TaskGroup, +} + +impl Peer { + /// Creates a new peer + pub fn new( + peer_pool: WeakPeerPool, + id: &PeerID, + io_codec: IOCodec, + remote_endpoint: Endpoint, + conn_direction: ConnDirection, + ) -> ArcPeer { + Arc::new(Peer { + id: id.clone(), + peer_pool, + io_codec, + protocol_ids: RwLock::new(Vec::new()), + remote_endpoint, + conn_direction, + protocol_events: EventSys::new(), + task_group: TaskGroup::new(), + stop_chan: channel::bounded(1), + }) + } + + /// Run the peer + pub async fn run(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + self.start_protocols(ex.clone()).await; + self.read_loop().await + } + + /// Send a message to the peer connection using the specified protocol. + pub async fn send<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { + let payload = encode(msg)?; + + let proto_msg = ProtocolMsg { + protocol_id: protocol_id.to_string(), + payload: payload.to_vec(), + }; + + self.io_codec.write(NetMsgCmd::Protocol, &proto_msg).await?; + Ok(()) + } + + /// Broadcast a message to all connected peers using the specified protocol. + pub async fn broadcast<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) { + self.peer_pool().broadcast(protocol_id, msg).await; + } + + /// Shuts down the peer + pub async fn shutdown(&self) { + trace!("peer {} start shutting down", self.id); + + // Send shutdown event to all protocols + for protocol_id in self.protocol_ids.read().await.iter() { + self.protocol_events + .emit_by_topic(protocol_id, &ProtocolEvent::Shutdown) + .await; + } + + // Send a stop signal to the read loop + // + // No need to handle the error here; a dropped channel and + // sending a stop signal have the same effect. + let _ = self.stop_chan.0.try_send(Ok(())); + + // No need to handle the error here + let _ = self + .io_codec + .write(NetMsgCmd::Shutdown, &ShutdownMsg(0)) + .await; + + // Force shutting down + self.task_group.cancel().await; + } + + /// Check if the connection is Inbound + #[inline] + pub fn is_inbound(&self) -> bool { + match self.conn_direction { + ConnDirection::Inbound => true, + ConnDirection::Outbound => false, + } + } + + /// Returns the direction of the connection, which can be either `Inbound` + /// or `Outbound`. + #[inline] + pub fn direction(&self) -> &ConnDirection { + &self.conn_direction + } + + /// Returns the remote endpoint for the peer + #[inline] + pub fn remote_endpoint(&self) -> &Endpoint { + &self.remote_endpoint + } + + /// Return the peer's ID + #[inline] + pub fn id(&self) -> &PeerID { + &self.id + } + + /// Returns the `Config` instance. + pub fn config(&self) -> Arc<Config> { + self.peer_pool().config.clone() + } + + /// Registers a listener for the given Protocol `P`. + pub async fn register_listener<P: Protocol>(&self) -> EventListener<ProtocolID, ProtocolEvent> { + self.protocol_events.register(&P::id()).await + } + + /// Start a read loop to handle incoming messages from the peer connection. + async fn read_loop(&self) -> Result<()> { + loop { + let fut = select(self.stop_chan.1.recv(), self.io_codec.read()).await; + let result = match fut { + Either::Left(stop_signal) => { + trace!("Peer {} received a stop signal", self.id); + return stop_signal?; + } + Either::Right(result) => result, + }; + + let msg = result?; + + match msg.header.command { + NetMsgCmd::Protocol => { + let msg: ProtocolMsg = decode(&msg.payload)?.0; + + if !self.protocol_ids.read().await.contains(&msg.protocol_id) { + return Err(Error::UnsupportedProtocol(msg.protocol_id)); + } + + let proto_id = &msg.protocol_id; + let msg = ProtocolEvent::Message(msg.payload); + self.protocol_events.emit_by_topic(proto_id, &msg).await; + } + NetMsgCmd::Shutdown => { + return Err(Error::PeerShutdown); + } + command => return Err(Error::InvalidMsg(format!("Unexpected msg {:?}", command))), + } + } + } + + /// Start running the protocols for this peer connection. + async fn start_protocols(self: &Arc<Self>, ex: Executor<'_>) { + for (protocol_id, constructor) in self.peer_pool().protocols.read().await.iter() { + trace!("peer {} start protocol {protocol_id}", self.id); + let protocol = constructor(self.clone()); + + self.protocol_ids.write().await.push(protocol_id.clone()); + + let selfc = self.clone(); + let exc = ex.clone(); + let proto_idc = protocol_id.clone(); + + let on_failure = |result: TaskResult<Result<()>>| async move { + if let TaskResult::Completed(res) = result { + if res.is_err() { + error!("protocol {} stopped", proto_idc); + } + // Send a stop signal to read loop + let _ = selfc.stop_chan.0.try_send(res); + } + }; + + self.task_group + .spawn(ex.clone(), protocol.start(exc), on_failure); + } + } + + fn peer_pool(&self) -> ArcPeerPool { + self.peer_pool.upgrade().unwrap() + } +} diff --git a/karyons_p2p/src/peer/peer_id.rs b/karyons_p2p/src/peer/peer_id.rs new file mode 100644 index 0000000..c8aec7d --- /dev/null +++ b/karyons_p2p/src/peer/peer_id.rs @@ -0,0 +1,41 @@ +use bincode::{Decode, Encode}; +use rand::{rngs::OsRng, RngCore}; +use sha2::{Digest, Sha256}; + +/// Represents a unique identifier for a peer. +#[derive(Clone, Debug, Eq, PartialEq, Hash, Decode, Encode)] +pub struct PeerID(pub [u8; 32]); + +impl std::fmt::Display for PeerID { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let id = self.0[0..8] + .iter() + .map(|b| format!("{:x}", b)) + .collect::<Vec<String>>() + .join(""); + + write!(f, "{}", id) + } +} + +impl PeerID { + /// Creates a new PeerID. + pub fn new(src: &[u8]) -> Self { + let mut hasher = Sha256::new(); + hasher.update(src); + Self(hasher.finalize().into()) + } + + /// Generates a random PeerID. + pub fn random() -> Self { + let mut id: [u8; 32] = [0; 32]; + OsRng.fill_bytes(&mut id); + Self(id) + } +} + +impl From<[u8; 32]> for PeerID { + fn from(b: [u8; 32]) -> Self { + PeerID(b) + } +} |