aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-06-27 02:39:31 +0200
committerhozan23 <hozan23@karyontech.net>2024-06-27 02:39:31 +0200
commitb8b5f00e9695f46ea30af3ce63aec6dd17f356ae (patch)
tree3f1b07539c248f9536f5c7b6e3870e235d4f49d7 /core
parent1a3ef2d77ab54bfe286f7400ac0cee2e25ea14e3 (diff)
Improve async channels error handling and replace unbounded channels with bounded channels
Remove all unbounded channels to prevent unbounded memory usage and potential crashes. Use `FuturesUnordered` for sending to multiple channels simultaneously. This prevents the sending loop from blocking if one channel is blocked, and helps handle errors properly.
Diffstat (limited to 'core')
-rw-r--r--core/Cargo.toml7
-rw-r--r--core/src/async_runtime/executor.rs5
-rw-r--r--core/src/async_util/task_group.rs1
-rw-r--r--core/src/event.rs74
-rw-r--r--core/src/pubsub.rs56
5 files changed, 104 insertions, 39 deletions
diff --git a/core/Cargo.toml b/core/Cargo.toml
index d30b956..895ddf6 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -10,9 +10,9 @@ authors.workspace = true
[features]
default = ["smol"]
-crypto = ["dep:ed25519-dalek"]
+crypto = ["ed25519-dalek"]
tokio = ["dep:tokio"]
-smol = ["dep:smol", "dep:async-process"]
+smol = ["dep:smol", "async-process"]
[dependencies]
log = "0.4.21"
@@ -29,6 +29,9 @@ pin-project-lite = "0.2.14"
async-process = { version = "2.2.3", optional = true }
smol = { version = "2.0.0", optional = true }
tokio = { version = "1.38.0", features = ["full"], optional = true }
+futures-util = { version = "0.3.5", features = [
+ "alloc",
+], default-features = false }
# encode
bincode = "2.0.0-rc.3"
diff --git a/core/src/async_runtime/executor.rs b/core/src/async_runtime/executor.rs
index 9335f12..88f6370 100644
--- a/core/src/async_runtime/executor.rs
+++ b/core/src/async_runtime/executor.rs
@@ -25,6 +25,11 @@ impl Executor {
) -> Task<T> {
self.inner.spawn(future).into()
}
+
+ #[cfg(feature = "tokio")]
+ pub fn handle(&self) -> &tokio::runtime::Handle {
+ return self.inner.handle();
+ }
}
static GLOBAL_EXECUTOR: OnceCell<Executor> = OnceCell::new();
diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs
index 63c1541..c55b9a1 100644
--- a/core/src/async_util/task_group.rs
+++ b/core/src/async_util/task_group.rs
@@ -87,6 +87,7 @@ impl TaskGroup {
self.stop_signal.broadcast().await;
loop {
+ // XXX BE CAREFUL HERE, it hold synchronous mutex across .await point.
let task = self.tasks.lock().pop();
if let Some(t) = task {
t.cancel().await
diff --git a/core/src/event.rs b/core/src/event.rs
index 771d661..1632df3 100644
--- a/core/src/event.rs
+++ b/core/src/event.rs
@@ -7,17 +7,19 @@ use std::{
use async_channel::{Receiver, Sender};
use chrono::{DateTime, Utc};
+use futures_util::stream::{FuturesUnordered, StreamExt};
use log::{debug, error};
-use crate::{async_runtime::lock::Mutex, util::random_16, Result};
+use crate::{async_runtime::lock::Mutex, util::random_32, Result};
+
+const CHANNEL_BUFFER_SIZE: usize = 1000;
pub type ArcEventSys<T> = Arc<EventSys<T>>;
-pub type WeakEventSys<T> = Weak<EventSys<T>>;
-pub type EventListenerID = u16;
+pub type EventListenerID = u32;
type Listeners<T> = HashMap<T, HashMap<String, HashMap<EventListenerID, Sender<Event>>>>;
-/// EventSys supports event emission to registered listeners based on topics.
+/// EventSys emits events to registered listeners based on topics.
/// # Example
///
/// ```
@@ -74,22 +76,41 @@ type Listeners<T> = HashMap<T, HashMap<String, HashMap<EventListenerID, Sender<E
///
pub struct EventSys<T> {
listeners: Mutex<Listeners<T>>,
+ listener_buffer_size: usize,
}
impl<T> EventSys<T>
where
T: std::hash::Hash + Eq + std::fmt::Debug + Clone,
{
- /// Creates a new `EventSys`
+ /// Creates a new [`EventSys`]
pub fn new() -> ArcEventSys<T> {
Arc::new(Self {
listeners: Mutex::new(HashMap::new()),
+ listener_buffer_size: CHANNEL_BUFFER_SIZE,
+ })
+ }
+
+ /// Creates a new [`EventSys`] with the provided buffer size for the
+ /// [`EventListener`] channel.
+ ///
+ /// This is important to control the memory used by the listener channel.
+ /// If the consumer for the event listener can't keep up with the new events
+ /// coming, then the channel buffer will fill with new events, and if the
+ /// buffer is full, the emit function will block until the listener
+ /// starts to consume the buffered events.
+ ///
+ /// If `size` is zero, this function will panic.
+ pub fn with_buffer_size(size: usize) -> ArcEventSys<T> {
+ Arc::new(Self {
+ listeners: Mutex::new(HashMap::new()),
+ listener_buffer_size: size,
})
}
/// Emits an event to the listeners.
///
- /// The event must implement the `EventValueTopic` trait to indicate the
+ /// The event must implement the [`EventValueTopic`] trait to indicate the
/// topic of the event. Otherwise, you can use `emit_by_topic()`.
pub async fn emit<E: EventValueTopic<Topic = T> + Clone>(&self, value: &E) {
let topic = E::topic();
@@ -115,22 +136,26 @@ where
let event_id = E::id().to_string();
if !event_ids.contains_key(&event_id) {
- debug!(
- "Failed to emit an event to a non-existent event id: {:?}",
- event_id
- );
+ debug!("Failed to emit an event: unknown event id {:?}", event_id);
return;
}
- let mut failed_listeners = vec![];
+ let mut results = FuturesUnordered::new();
let listeners = event_ids.get_mut(&event_id).unwrap();
for (listener_id, listener) in listeners.iter() {
- if let Err(err) = listener.send(event.clone()).await {
+ let result = async { (*listener_id, listener.send(event.clone()).await) };
+ results.push(result);
+ }
+
+ let mut failed_listeners = vec![];
+ while let Some((id, fut_err)) = results.next().await {
+ if let Err(err) = fut_err {
debug!("Failed to emit event for topic {:?}: {}", topic, err);
- failed_listeners.push(*listener_id);
+ failed_listeners.push(id);
}
}
+ drop(results);
for listener_id in failed_listeners.iter() {
listeners.remove(listener_id);
@@ -142,7 +167,7 @@ where
self: &Arc<Self>,
topic: &T,
) -> EventListener<T, E> {
- let chan = async_channel::unbounded();
+ let chan = async_channel::bounded(self.listener_buffer_size);
let topics = &mut self.listeners.lock().await;
@@ -159,9 +184,10 @@ where
let listeners = event_ids.get_mut(&event_id).unwrap();
- let mut listener_id = random_16();
+ let mut listener_id = random_32();
+ // Generate a new one if listener_id already exists
while listeners.contains_key(&listener_id) {
- listener_id = random_16();
+ listener_id = random_32();
}
let listener =
@@ -197,7 +223,7 @@ where
pub struct EventListener<T, E> {
id: EventListenerID,
recv_chan: Receiver<Event>,
- event_sys: WeakEventSys<T>,
+ event_sys: Weak<EventSys<T>>,
event_id: String,
topic: T,
phantom: PhantomData<E>,
@@ -208,10 +234,10 @@ where
T: std::hash::Hash + Eq + Clone + std::fmt::Debug,
E: EventValueAny + Clone + EventValue,
{
- /// Create a new event listener.
+ /// Creates a new [`EventListener`].
fn new(
id: EventListenerID,
- event_sys: WeakEventSys<T>,
+ event_sys: Weak<EventSys<T>>,
recv_chan: Receiver<Event>,
event_id: &str,
topic: &T,
@@ -226,12 +252,12 @@ where
}
}
- /// Receive the next event.
+ /// Receives the next event.
pub async fn recv(&self) -> Result<E> {
match self.recv_chan.recv().await {
Ok(event) => match ((*event.value).value_as_any()).downcast_ref::<E>() {
Some(v) => Ok(v.clone()),
- None => unreachable!("Error when attempting to downcast the event value."),
+ None => unreachable!("Failed to downcast the event value."),
},
Err(err) => {
error!("Failed to receive new event: {err}");
@@ -241,7 +267,7 @@ where
}
}
- /// Cancels the listener and removes it from the `EventSys`.
+ /// Cancels the event listener and removes it from the [`EventSys`].
pub async fn cancel(&self) {
if let Some(es) = self.event_sys.upgrade() {
es.remove(&self.topic, &self.event_id, &self.id).await;
@@ -249,12 +275,12 @@ where
}
/// Returns the topic for this event listener.
- pub async fn topic(&self) -> &T {
+ pub fn topic(&self) -> &T {
&self.topic
}
/// Returns the event id for this event listener.
- pub async fn event_id(&self) -> &String {
+ pub fn event_id(&self) -> &String {
&self.event_id
}
}
diff --git a/core/src/pubsub.rs b/core/src/pubsub.rs
index bcc24ef..09b62ea 100644
--- a/core/src/pubsub.rs
+++ b/core/src/pubsub.rs
@@ -1,11 +1,14 @@
use std::{collections::HashMap, sync::Arc};
+use futures_util::stream::{FuturesUnordered, StreamExt};
use log::error;
-use crate::{async_runtime::lock::Mutex, util::random_16, Result};
+use crate::{async_runtime::lock::Mutex, util::random_32, Result};
+
+const CHANNEL_BUFFER_SIZE: usize = 1000;
pub type ArcPublisher<T> = Arc<Publisher<T>>;
-pub type SubscriptionID = u16;
+pub type SubscriptionID = u32;
/// A simple publish-subscribe system.
// # Example
@@ -28,27 +31,46 @@ pub type SubscriptionID = u16;
/// ```
pub struct Publisher<T> {
subs: Mutex<HashMap<SubscriptionID, async_channel::Sender<T>>>,
+ subscription_buffer_size: usize,
}
impl<T: Clone> Publisher<T> {
- /// Creates a new Publisher
+ /// Creates a new [`Publisher`]
pub fn new() -> ArcPublisher<T> {
Arc::new(Self {
subs: Mutex::new(HashMap::new()),
+ subscription_buffer_size: CHANNEL_BUFFER_SIZE,
+ })
+ }
+
+ /// Creates a new [`Publisher`] with the provided buffer size for the
+ /// [`Subscription`] channel.
+ ///
+ /// This is important to control the memory used by the [`Subscription`] channel.
+ /// If the subscriber can't keep up with the new messages coming, then the
+ /// channel buffer will fill with new messages, and if the buffer is full,
+ /// the emit function will block until the subscriber starts to process
+ /// the buffered messages.
+ ///
+ /// If `size` is zero, this function will panic.
+ pub fn with_buffer_size(size: usize) -> ArcPublisher<T> {
+ Arc::new(Self {
+ subs: Mutex::new(HashMap::new()),
+ subscription_buffer_size: size,
})
}
- /// Subscribe and return a Subscription
+ /// Subscribes and return a [`Subscription`]
pub async fn subscribe(self: &Arc<Self>) -> Subscription<T> {
let mut subs = self.subs.lock().await;
- let chan = async_channel::unbounded();
+ let chan = async_channel::bounded(self.subscription_buffer_size);
- let mut sub_id = random_16();
+ let mut sub_id = random_32();
- // While the SubscriptionID already exists, generate a new one
+ // Generate a new one if sub_id already exists
while subs.contains_key(&sub_id) {
- sub_id = random_16();
+ sub_id = random_32();
}
let sub = Subscription::new(sub_id, self.clone(), chan.1);
@@ -57,22 +79,30 @@ impl<T: Clone> Publisher<T> {
sub
}
- /// Unsubscribe from the Publisher
+ /// Unsubscribes from the publisher
pub async fn unsubscribe(self: &Arc<Self>, id: &SubscriptionID) {
self.subs.lock().await.remove(id);
}
- /// Notify all subscribers
+ /// Notifies all subscribers
pub async fn notify(self: &Arc<Self>, value: &T) {
let mut subs = self.subs.lock().await;
+
+ let mut results = FuturesUnordered::new();
let mut closed_subs = vec![];
for (sub_id, sub) in subs.iter() {
- if let Err(err) = sub.send(value.clone()).await {
- error!("failed to notify {}: {}", sub_id, err);
- closed_subs.push(*sub_id);
+ let result = async { (*sub_id, sub.send(value.clone()).await) };
+ results.push(result);
+ }
+
+ while let Some((id, fut_err)) = results.next().await {
+ if let Err(err) = fut_err {
+ error!("failed to notify {}: {}", id, err);
+ closed_subs.push(id);
}
}
+ drop(results);
for sub_id in closed_subs.iter() {
subs.remove(sub_id);