From b8b5f00e9695f46ea30af3ce63aec6dd17f356ae Mon Sep 17 00:00:00 2001
From: hozan23
Date: Thu, 27 Jun 2024 02:39:31 +0200
Subject: Improve async channels error handling and replace unbounded channels
 with bounded channels

Remove all unbounded channels to prevent unbounded memory growth and
potential out-of-memory crashes.

Use `FuturesUnordered` to send to multiple channels concurrently. This
keeps the sending loop from stalling on a single blocked channel and
makes it possible to handle send errors properly.
---
 Cargo.lock                         |  1 +
 core/Cargo.toml                    |  7 ++--
 core/src/async_runtime/executor.rs |  5 +++
 core/src/async_util/task_group.rs  |  1 +
 core/src/event.rs                  | 74 +++++++++++++++++++++++++-------------
 core/src/pubsub.rs                 | 56 ++++++++++++++++++++++-------
 jsonrpc/src/server/mod.rs          | 12 ++++---
 p2p/src/discovery/lookup.rs        | 23 ++++++++----
 8 files changed, 129 insertions(+), 50 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index b193f82..8ca75d2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1229,6 +1229,7 @@ dependencies = [
  "chrono",
  "dirs",
  "ed25519-dalek",
+ "futures-util",
  "log",
  "once_cell",
  "parking_lot",
diff --git a/core/Cargo.toml b/core/Cargo.toml
index d30b956..895ddf6 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -10,9 +10,9 @@ authors.workspace = true
 
 [features]
 default = ["smol"]
-crypto = ["dep:ed25519-dalek"]
+crypto = ["ed25519-dalek"]
 tokio = ["dep:tokio"]
-smol = ["dep:smol", "dep:async-process"]
+smol = ["dep:smol", "async-process"]
 
 [dependencies]
 log = "0.4.21"
@@ -29,6 +29,9 @@ pin-project-lite = "0.2.14"
 async-process = { version = "2.2.3", optional = true }
 smol = { version = "2.0.0", optional = true }
 tokio = { version = "1.38.0", features = ["full"], optional = true }
+futures-util = { version = "0.3.5", features = [
+    "alloc",
+], default-features = false }
 
 # encode
 bincode = "2.0.0-rc.3"
diff --git a/core/src/async_runtime/executor.rs b/core/src/async_runtime/executor.rs
index 9335f12..88f6370 100644
--- a/core/src/async_runtime/executor.rs
+++ b/core/src/async_runtime/executor.rs
@@ -25,6 +25,11 @@ impl Executor {
     ) -> Task<T> {
         self.inner.spawn(future).into()
     }
+
+    #[cfg(feature = "tokio")]
+    pub fn handle(&self) -> &tokio::runtime::Handle {
+        self.inner.handle()
+    }
 }
 
 static GLOBAL_EXECUTOR: OnceCell<Executor> = OnceCell::new();
diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs
index 63c1541..c55b9a1 100644
--- a/core/src/async_util/task_group.rs
+++ b/core/src/async_util/task_group.rs
@@ -87,6 +87,7 @@ impl TaskGroup {
         self.stop_signal.broadcast().await;
 
         loop {
+            // XXX BE CAREFUL HERE: this holds a synchronous mutex across an .await point.
             let task = self.tasks.lock().pop();
             if let Some(t) = task {
                 t.cancel().await
diff --git a/core/src/event.rs b/core/src/event.rs
index 771d661..1632df3 100644
--- a/core/src/event.rs
+++ b/core/src/event.rs
@@ -7,17 +7,19 @@ use std::{
 
 use async_channel::{Receiver, Sender};
 use chrono::{DateTime, Utc};
+use futures_util::stream::{FuturesUnordered, StreamExt};
 use log::{debug, error};
 
-use crate::{async_runtime::lock::Mutex, util::random_16, Result};
+use crate::{async_runtime::lock::Mutex, util::random_32, Result};
+
+const CHANNEL_BUFFER_SIZE: usize = 1000;
 
 pub type ArcEventSys<T> = Arc<EventSys<T>>;
-pub type WeakEventSys<T> = Weak<EventSys<T>>;
-pub type EventListenerID = u16;
+pub type EventListenerID = u32;
 
 type Listeners<T> = HashMap<T, HashMap<String, HashMap<EventListenerID, Sender<Event>>>>;
 
-/// EventSys supports event emission to registered listeners based on topics.
+/// EventSys emits events to registered listeners based on topics.
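+///
+/// Each listener receives events over a bounded channel (see
+/// [`EventSys::with_buffer_size`]); when a listener's buffer is full,
+/// emitting blocks until that listener drains some of its buffered events.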
 /// # Example
 ///
 /// ```
@@ -74,22 +76,41 @@ type Listeners<T> = HashMap<T, HashMap<String, HashMap<EventListenerID, Sender<Event>>>>;
 pub struct EventSys<T> {
     listeners: Mutex<Listeners<T>>,
+    listener_buffer_size: usize,
 }
 
 impl<T> EventSys<T>
 where
     T: std::hash::Hash + Eq + std::fmt::Debug + Clone,
 {
-    /// Creates a new `EventSys`
+    /// Creates a new [`EventSys`]
     pub fn new() -> ArcEventSys<T> {
         Arc::new(Self {
             listeners: Mutex::new(HashMap::new()),
+            listener_buffer_size: CHANNEL_BUFFER_SIZE,
         })
     }
+
+    /// Creates a new [`EventSys`] with the provided buffer size for the
+    /// [`EventListener`] channel.
+    ///
+    /// This controls the memory used by the listener channel. If the
+    /// listener's consumer can't keep up with incoming events, the channel
+    /// buffer fills up; once the buffer is full, the emit functions block
+    /// until the listener starts consuming the buffered events.
+    ///
+    /// If `size` is zero, this function will panic.
+    pub fn with_buffer_size(size: usize) -> ArcEventSys<T> {
+        Arc::new(Self {
+            listeners: Mutex::new(HashMap::new()),
+            listener_buffer_size: size,
+        })
+    }
 
     /// Emits an event to the listeners.
     ///
-    /// The event must implement the `EventValueTopic` trait to indicate the
+    /// The event must implement the [`EventValueTopic`] trait to indicate the
     /// topic of the event. Otherwise, you can use `emit_by_topic()`.
     pub async fn emit<E: EventValueTopic<Topic = T> + Clone>(&self, value: &E) {
         let topic = E::topic();
@@ -115,22 +136,26 @@ where
         let event_id = E::id().to_string();
 
         if !event_ids.contains_key(&event_id) {
-            debug!(
-                "Failed to emit an event to a non-existent event id: {:?}",
-                event_id
-            );
+            debug!("Failed to emit an event: unknown event id {:?}", event_id);
             return;
         }
 
-        let mut failed_listeners = vec![];
+        let mut results = FuturesUnordered::new();
 
         let listeners = event_ids.get_mut(&event_id).unwrap();
         for (listener_id, listener) in listeners.iter() {
-            if let Err(err) = listener.send(event.clone()).await {
+            let fut = async { (*listener_id, listener.send(event.clone()).await) };
+            results.push(fut);
+        }
+
+        let mut failed_listeners = vec![];
+        while let Some((id, result)) = results.next().await {
+            if let Err(err) = result {
                 debug!("Failed to emit event for topic {:?}: {}", topic, err);
-                failed_listeners.push(*listener_id);
+                failed_listeners.push(id);
             }
         }
+        drop(results);
 
         for listener_id in failed_listeners.iter() {
             listeners.remove(listener_id);
@@ -142,7 +167,7 @@ where
         self: &Arc<Self>,
         topic: &T,
     ) -> EventListener<T, E> {
-        let chan = async_channel::unbounded();
+        let chan = async_channel::bounded(self.listener_buffer_size);
 
         let topics = &mut self.listeners.lock().await;
@@ -159,9 +184,10 @@ where
 
         let listeners = event_ids.get_mut(&event_id).unwrap();
 
-        let mut listener_id = random_16();
+        let mut listener_id = random_32();
+        // Generate a new ID if this one already exists
         while listeners.contains_key(&listener_id) {
-            listener_id = random_16();
+            listener_id = random_32();
         }
 
         let listener =
@@ -197,7 +223,7 @@ where
 pub struct EventListener<T, E> {
     id: EventListenerID,
     recv_chan: Receiver<Event>,
-    event_sys: WeakEventSys<T>,
+    event_sys: Weak<EventSys<T>>,
     event_id: String,
     topic: T,
     phantom: PhantomData<E>,
@@ -208,10 +234,10 @@ where
     T: std::hash::Hash + Eq + Clone + std::fmt::Debug,
     E: EventValueAny + Clone + EventValue,
 {
-    /// Create a new event listener.
+    /// Creates a new [`EventListener`].
     fn new(
         id: EventListenerID,
-        event_sys: WeakEventSys<T>,
+        event_sys: Weak<EventSys<T>>,
         recv_chan: Receiver<Event>,
         event_id: &str,
         topic: &T,
@@ -226,12 +252,12 @@ where
         }
     }
 
-    /// Receive the next event.
+    /// Receives the next event.
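+    ///
+    /// A usage sketch (the event type, topic, and `event_sys` handle here
+    /// are assumptions for illustration):
+    ///
+    /// ```ignore
+    /// let listener = event_sys.register::<MyEvent>(&topic).await;
+    /// let event: MyEvent = listener.recv().await?;
+    /// ```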
     pub async fn recv(&self) -> Result<E> {
         match self.recv_chan.recv().await {
             Ok(event) => match ((*event.value).value_as_any()).downcast_ref::<E>() {
                 Some(v) => Ok(v.clone()),
-                None => unreachable!("Error when attempting to downcast the event value."),
+                None => unreachable!("Failed to downcast the event value."),
             },
             Err(err) => {
                 error!("Failed to receive new event: {err}");
@@ -241,7 +267,7 @@ where
         }
     }
 
-    /// Cancels the listener and removes it from the `EventSys`.
+    /// Cancels the event listener and removes it from the [`EventSys`].
     pub async fn cancel(&self) {
         if let Some(es) = self.event_sys.upgrade() {
             es.remove(&self.topic, &self.event_id, &self.id).await;
         }
     }
 
     /// Returns the topic for this event listener.
-    pub async fn topic(&self) -> &T {
+    pub fn topic(&self) -> &T {
         &self.topic
     }
 
     /// Returns the event id for this event listener.
-    pub async fn event_id(&self) -> &String {
+    pub fn event_id(&self) -> &String {
         &self.event_id
     }
 }
diff --git a/core/src/pubsub.rs b/core/src/pubsub.rs
index bcc24ef..09b62ea 100644
--- a/core/src/pubsub.rs
+++ b/core/src/pubsub.rs
@@ -1,11 +1,14 @@
 use std::{collections::HashMap, sync::Arc};
 
+use futures_util::stream::{FuturesUnordered, StreamExt};
 use log::error;
 
-use crate::{async_runtime::lock::Mutex, util::random_16, Result};
+use crate::{async_runtime::lock::Mutex, util::random_32, Result};
+
+const CHANNEL_BUFFER_SIZE: usize = 1000;
 
 pub type ArcPublisher<T> = Arc<Publisher<T>>;
-pub type SubscriptionID = u16;
+pub type SubscriptionID = u32;
 
 /// A simple publish-subscribe system.
 // # Example
@@ -28,27 +31,46 @@ pub type SubscriptionID = u16;
 /// ```
 pub struct Publisher<T> {
     subs: Mutex<HashMap<SubscriptionID, async_channel::Sender<T>>>,
+    subscription_buffer_size: usize,
 }
 
 impl<T: Clone> Publisher<T> {
-    /// Creates a new Publisher
+    /// Creates a new [`Publisher`]
     pub fn new() -> ArcPublisher<T> {
         Arc::new(Self {
             subs: Mutex::new(HashMap::new()),
+            subscription_buffer_size: CHANNEL_BUFFER_SIZE,
         })
     }
+
+    /// Creates a new [`Publisher`] with the provided buffer size for the
+    /// [`Subscription`] channel.
+    ///
+    /// This controls the memory used by the [`Subscription`] channel. If a
+    /// subscriber can't keep up with incoming messages, the channel buffer
+    /// fills up; once the buffer is full, `notify` blocks until the
+    /// subscriber starts processing the buffered messages.
+    ///
+    /// If `size` is zero, this function will panic.
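+    ///
+    /// A usage sketch (the `String` message type is an assumption):
+    ///
+    /// ```ignore
+    /// // Each subscription buffers at most 100 messages; `notify` waits when full.
+    /// let publisher = Publisher::<String>::with_buffer_size(100);
+    /// let sub = publisher.subscribe().await;
+    /// publisher.notify(&"hello".to_string()).await;
+    /// ```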
+    pub fn with_buffer_size(size: usize) -> ArcPublisher<T> {
+        Arc::new(Self {
+            subs: Mutex::new(HashMap::new()),
+            subscription_buffer_size: size,
+        })
+    }
 
-    /// Subscribe and return a Subscription
+    /// Subscribes and returns a [`Subscription`]
     pub async fn subscribe(self: &Arc<Self>) -> Subscription<T> {
         let mut subs = self.subs.lock().await;
 
-        let chan = async_channel::unbounded();
+        let chan = async_channel::bounded(self.subscription_buffer_size);
 
-        let mut sub_id = random_16();
+        let mut sub_id = random_32();
 
-        // While the SubscriptionID already exists, generate a new one
+        // Generate a new ID if sub_id already exists
         while subs.contains_key(&sub_id) {
-            sub_id = random_16();
+            sub_id = random_32();
         }
 
         let sub = Subscription::new(sub_id, self.clone(), chan.1);
@@ -57,22 +79,30 @@ impl Publisher<T> {
 
         sub
     }
 
-    /// Unsubscribe from the Publisher
+    /// Unsubscribes from the publisher
     pub async fn unsubscribe(self: &Arc<Self>, id: &SubscriptionID) {
         self.subs.lock().await.remove(id);
     }
 
-    /// Notify all subscribers
+    /// Notifies all subscribers
     pub async fn notify(self: &Arc<Self>, value: &T) {
         let mut subs = self.subs.lock().await;
+
+        let mut results = FuturesUnordered::new();
         let mut closed_subs = vec![];
 
         for (sub_id, sub) in subs.iter() {
-            if let Err(err) = sub.send(value.clone()).await {
-                error!("failed to notify {}: {}", sub_id, err);
-                closed_subs.push(*sub_id);
+            let fut = async { (*sub_id, sub.send(value.clone()).await) };
+            results.push(fut);
+        }
+
+        while let Some((id, result)) = results.next().await {
+            if let Err(err) = result {
+                error!("failed to notify {}: {}", id, err);
+                closed_subs.push(id);
             }
         }
+        drop(results);
 
         for sub_id in closed_subs.iter() {
             subs.remove(sub_id);
diff --git a/jsonrpc/src/server/mod.rs b/jsonrpc/src/server/mod.rs
index 8fa8a1c..8d5cd2c 100644
--- a/jsonrpc/src/server/mod.rs
+++ b/jsonrpc/src/server/mod.rs
@@ -33,6 +33,8 @@ pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse";
 pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found";
 pub const UNSUPPORTED_JSONRPC_VERSION: &str = "Unsupported jsonrpc version";
 
+const CHANNEL_SUBSCRIPTION_BUFFER_SIZE: usize = 100;
+
 struct NewRequest {
     srvc_name: String,
     method_name: String,
@@ -108,7 +110,7 @@ impl Server {
 
         let conn = Arc::new(conn);
 
-        let (ch_tx, ch_rx) = async_channel::unbounded();
+        let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
         // Create a new connection channel for managing subscriptions
         let channel = Channel::new(ch_tx);
@@ -120,13 +122,13 @@ impl Server {
             if let TaskResult::Completed(Err(err)) = result {
                 debug!("Notification loop stopped: {err}");
             }
-            // Close the connection subscription channel
+            // Close the connection channel
             chan.close();
         };
 
         let conn_cloned = conn.clone();
         let queue_cloned = queue.clone();
-        // Start listening for responses in the queue or new notifications
+        // Start listening for new responses in the queue or new notifications
         self.task_group.spawn(
             async move {
                 loop {
@@ -163,12 +165,12 @@ impl Server {
             } else {
                 warn!("Connection {} dropped", endpoint);
             }
-            // Close the subscription channel when the connection dropped
+            // Close the connection channel when the connection is dropped
             chan.close();
         };
 
         let selfc = self.clone();
-        // Spawn a new task and wait for requests.
+        // Spawn a new task and wait for new requests.
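+        // (The connection channel above is bounded, so notification senders
+        // now wait for free buffer space instead of queueing without limit.)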
         self.task_group.spawn(
             async move {
                 loop {
diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs
index a941986..8e06eef 100644
--- a/p2p/src/discovery/lookup.rs
+++ b/p2p/src/discovery/lookup.rs
@@ -1,6 +1,6 @@
 use std::{sync::Arc, time::Duration};
 
-use futures_util::{stream::FuturesUnordered, StreamExt};
+use futures_util::stream::{FuturesUnordered, StreamExt};
 use log::{error, trace};
 use rand::{rngs::OsRng, seq::SliceRandom, RngCore};
 
@@ -146,7 +146,12 @@ impl LookupService {
         };
 
         let mut peer_buffer = vec![];
-        self.self_lookup(&random_peers, &mut peer_buffer).await;
+        if let Err(err) = self.self_lookup(&random_peers, &mut peer_buffer).await {
+            self.monitor
+                .notify(DiscvEvent::LookupFailed(endpoint.clone()))
+                .await;
+            return Err(err);
+        }
 
         while peer_buffer.len() < MAX_PEERS_IN_PEERSMSG {
             match random_peers.pop() {
@@ -201,14 +206,18 @@ impl LookupService {
     }
 
     /// Starts a self lookup
-    async fn self_lookup(&self, random_peers: &Vec<PeerMsg>, peer_buffer: &mut Vec<PeerMsg>) {
-        let mut tasks = FuturesUnordered::new();
+    async fn self_lookup(
+        &self,
+        random_peers: &Vec<PeerMsg>,
+        peer_buffer: &mut Vec<PeerMsg>,
+    ) -> Result<()> {
+        let mut results = FuturesUnordered::new();
         for peer in random_peers.choose_multiple(&mut OsRng, random_peers.len()) {
             let endpoint = Endpoint::Tcp(peer.addr.clone(), peer.discovery_port);
-            tasks.push(self.connect(endpoint, Some(peer.peer_id.clone()), &self.id))
+            results.push(self.connect(endpoint, Some(peer.peer_id.clone()), &self.id))
         }
 
-        while let Some(result) = tasks.next().await {
+        while let Some(result) = results.next().await {
             match result {
                 Ok(peers) => peer_buffer.extend(peers),
                 Err(err) => {
@@ -216,6 +225,8 @@ impl LookupService {
                 }
             }
         }
+
+        Ok(())
     }
 
     /// Connects to the given endpoint and initiates a lookup process for the
-- 
cgit v1.2.3