From 4fe665fc8bc6265baf5bfba6b6a5f3ee2dba63dc Mon Sep 17 00:00:00 2001 From: hozan23 Date: Wed, 8 Nov 2023 13:03:27 +0300 Subject: first commit --- karyons_core/Cargo.toml | 16 + karyons_core/src/async_utils/backoff.rs | 115 ++++++++ karyons_core/src/async_utils/condvar.rs | 387 +++++++++++++++++++++++++ karyons_core/src/async_utils/condwait.rs | 96 ++++++ karyons_core/src/async_utils/mod.rs | 13 + karyons_core/src/async_utils/select.rs | 99 +++++++ karyons_core/src/async_utils/task_group.rs | 197 +++++++++++++ karyons_core/src/async_utils/timeout.rs | 52 ++++ karyons_core/src/error.rs | 51 ++++ karyons_core/src/event.rs | 451 +++++++++++++++++++++++++++++ karyons_core/src/lib.rs | 21 ++ karyons_core/src/pubsub.rs | 115 ++++++++ karyons_core/src/utils/decode.rs | 10 + karyons_core/src/utils/encode.rs | 15 + karyons_core/src/utils/mod.rs | 19 ++ karyons_core/src/utils/path.rs | 39 +++ 16 files changed, 1696 insertions(+) create mode 100644 karyons_core/Cargo.toml create mode 100644 karyons_core/src/async_utils/backoff.rs create mode 100644 karyons_core/src/async_utils/condvar.rs create mode 100644 karyons_core/src/async_utils/condwait.rs create mode 100644 karyons_core/src/async_utils/mod.rs create mode 100644 karyons_core/src/async_utils/select.rs create mode 100644 karyons_core/src/async_utils/task_group.rs create mode 100644 karyons_core/src/async_utils/timeout.rs create mode 100644 karyons_core/src/error.rs create mode 100644 karyons_core/src/event.rs create mode 100644 karyons_core/src/lib.rs create mode 100644 karyons_core/src/pubsub.rs create mode 100644 karyons_core/src/utils/decode.rs create mode 100644 karyons_core/src/utils/encode.rs create mode 100644 karyons_core/src/utils/mod.rs create mode 100644 karyons_core/src/utils/path.rs (limited to 'karyons_core') diff --git a/karyons_core/Cargo.toml b/karyons_core/Cargo.toml new file mode 100644 index 0000000..712b7db --- /dev/null +++ b/karyons_core/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "karyons_core" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +smol = "1.3.0" +pin-project-lite = "0.2.13" +log = "0.4.20" +bincode = { version="2.0.0-rc.3", features = ["derive"]} +chrono = "0.4.30" +rand = "0.8.5" +thiserror = "1.0.47" +dirs = "5.0.1" diff --git a/karyons_core/src/async_utils/backoff.rs b/karyons_core/src/async_utils/backoff.rs new file mode 100644 index 0000000..f7e131d --- /dev/null +++ b/karyons_core/src/async_utils/backoff.rs @@ -0,0 +1,115 @@ +use std::{ + cmp::min, + sync::atomic::{AtomicBool, AtomicU32, Ordering}, + time::Duration, +}; + +use smol::Timer; + +/// Exponential backoff +/// +/// +/// # Examples +/// +/// ``` +/// use karyons_core::async_utils::Backoff; +/// +/// async { +/// let backoff = Backoff::new(300, 3000); +/// +/// loop { +/// backoff.sleep().await; +/// +/// // do something +/// break; +/// } +/// +/// backoff.reset(); +/// +/// // .... +/// }; +/// +/// ``` +/// +pub struct Backoff { + /// The base delay in milliseconds for the initial retry. + base_delay: u64, + /// The max delay in milliseconds allowed for a retry. + max_delay: u64, + /// Atomic counter + retries: AtomicU32, + /// Stop flag + stop: AtomicBool, +} + +impl Backoff { + /// Creates a new Backoff. + pub fn new(base_delay: u64, max_delay: u64) -> Self { + Self { + base_delay, + max_delay, + retries: AtomicU32::new(0), + stop: AtomicBool::new(false), + } + } + + /// Sleep based on the current retry count and delay values. + /// Retruns the delay value. + pub async fn sleep(&self) -> u64 { + if self.stop.load(Ordering::SeqCst) { + Timer::after(Duration::from_millis(self.max_delay)).await; + return self.max_delay; + } + + let retries = self.retries.load(Ordering::SeqCst); + let delay = self.base_delay * (2_u64).pow(retries); + let delay = min(delay, self.max_delay); + + if delay == self.max_delay { + self.stop.store(true, Ordering::SeqCst); + } + + self.retries.store(retries + 1, Ordering::SeqCst); + + Timer::after(Duration::from_millis(delay)).await; + delay + } + + /// Reset the retry counter to 0. + pub fn reset(&self) { + self.retries.store(0, Ordering::SeqCst); + self.stop.store(false, Ordering::SeqCst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_backoff() { + smol::block_on(async move { + let backoff = Arc::new(Backoff::new(5, 15)); + let backoff_c = backoff.clone(); + smol::spawn(async move { + let delay = backoff_c.sleep().await; + assert_eq!(delay, 5); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 10); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 15); + }) + .await; + + smol::spawn(async move { + backoff.reset(); + let delay = backoff.sleep().await; + assert_eq!(delay, 5); + }) + .await; + }); + } +} diff --git a/karyons_core/src/async_utils/condvar.rs b/karyons_core/src/async_utils/condvar.rs new file mode 100644 index 0000000..814f78f --- /dev/null +++ b/karyons_core/src/async_utils/condvar.rs @@ -0,0 +1,387 @@ +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + sync::Mutex, + task::{Context, Poll, Waker}, +}; + +use smol::lock::MutexGuard; + +use crate::utils::random_16; + +/// CondVar is an async version of +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use smol::lock::Mutex; +/// +/// use karyons_core::async_utils::CondVar; +/// +/// async { +/// +/// let val = Arc::new(Mutex::new(false)); +/// let condvar = Arc::new(CondVar::new()); +/// +/// let val_cloned = val.clone(); +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val_cloned.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// // Wake up all waiting tasks on this condvar +/// condvar.broadcast(); +/// }; +/// +/// ``` + +pub struct CondVar { + inner: Mutex, +} + +impl CondVar { + /// Creates a new CondVar + pub fn new() -> Self { + Self { + inner: Mutex::new(Wakers::new()), + } + } + + /// Blocks the current task until this condition variable receives a notification. + pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { + let m = MutexGuard::source(&g); + + CondVarAwait::new(self, g).await; + + m.lock().await + } + + /// Wakes up one blocked task waiting on this condvar. + pub fn signal(&self) { + self.inner.lock().unwrap().wake(true); + } + + /// Wakes up all blocked tasks waiting on this condvar. + pub fn broadcast(&self) { + self.inner.lock().unwrap().wake(false); + } +} + +impl Default for CondVar { + fn default() -> Self { + Self::new() + } +} + +struct CondVarAwait<'a, T> { + id: Option, + condvar: &'a CondVar, + guard: Option>, +} + +impl<'a, T> CondVarAwait<'a, T> { + fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self { + Self { + condvar, + guard: Some(guard), + id: None, + } + } +} + +impl<'a, T> Future for CondVarAwait<'a, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut inner = self.condvar.inner.lock().unwrap(); + + match self.guard.take() { + Some(_) => { + // the first pooll will release the Mutexguard + self.id = Some(inner.put(Some(cx.waker().clone()))); + Poll::Pending + } + None => { + // Return Ready if it has already been polled and removed + // from the waker list. + if self.id.is_none() { + return Poll::Ready(()); + } + + let i = self.id.as_ref().unwrap(); + match inner.wakers.get_mut(i).unwrap() { + Some(wk) => { + // This will prevent cloning again + if !wk.will_wake(cx.waker()) { + *wk = cx.waker().clone(); + } + Poll::Pending + } + None => { + inner.delete(i); + self.id = None; + Poll::Ready(()) + } + } + } + } + } +} + +impl<'a, T> Drop for CondVarAwait<'a, T> { + fn drop(&mut self) { + if let Some(id) = self.id { + let mut inner = self.condvar.inner.lock().unwrap(); + if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() { + wk.wake() + } + } + } +} + +/// Wakers is a helper struct to store the task wakers +struct Wakers { + wakers: HashMap>, +} + +impl Wakers { + fn new() -> Self { + Self { + wakers: HashMap::new(), + } + } + + fn put(&mut self, waker: Option) -> u16 { + let mut id: u16; + + id = random_16(); + while self.wakers.contains_key(&id) { + id = random_16(); + } + + self.wakers.insert(id, waker); + id + } + + fn delete(&mut self, id: &u16) -> Option> { + self.wakers.remove(id) + } + + fn wake(&mut self, signal: bool) { + for (_, wk) in self.wakers.iter_mut() { + match wk.take() { + Some(w) => { + w.wake(); + if signal { + break; + } + } + None => continue, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smol::lock::Mutex; + use std::{ + collections::VecDeque, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + }; + + // The tests below demonstrate a solution to a problem in the Wikipedia + // explanation of condition variables: + // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. + + struct Queue { + items: VecDeque, + max_len: usize, + } + impl Queue { + fn new(max_len: usize) -> Self { + Self { + items: VecDeque::new(), + max_len, + } + } + + fn is_full(&self) -> bool { + self.items.len() == self.max_len + } + + fn is_empty(&self) -> bool { + self.items.is_empty() + } + } + + #[test] + fn test_condvar_signal() { + smol::block_on(async { + let number_of_tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar_full = Arc::new(CondVar::new()); + let condvar_empty = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_full_cloned = condvar_full.clone(); + let condvar_empty_cloned = condvar_empty.clone(); + + let _producer1 = smol::spawn(async move { + for i in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_full_cloned.wait(queue).await; + } + + queue.items.push_back(format!("task {i}")); + + // Wake up the consumer + condvar_empty_cloned.signal(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + let consumer = smol::spawn(async move { + for _ in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar_empty.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up the producer + condvar_full.signal(); + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 29); + }); + } + + #[test] + fn test_condvar_broadcast() { + smol::block_on(async { + let tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer1 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer1: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer2 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer2: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + + let consumer = smol::spawn(async move { + for _ in 1..((tasks * 2) - 1) { + { + // Lock queue mutex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up all producer and consumer tasks + condvar.broadcast(); + } + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 58); + }); + } +} diff --git a/karyons_core/src/async_utils/condwait.rs b/karyons_core/src/async_utils/condwait.rs new file mode 100644 index 0000000..f16a99e --- /dev/null +++ b/karyons_core/src/async_utils/condwait.rs @@ -0,0 +1,96 @@ +use smol::lock::Mutex; + +use super::CondVar; + +/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use karyons_core::async_utils::CondWait; +/// +/// async { +/// let cond_wait = Arc::new(CondWait::new()); +/// let cond_wait_cloned = cond_wait.clone(); +/// let task = smol::spawn(async move { +/// cond_wait_cloned.wait().await; +/// // ... +/// }); +/// +/// cond_wait.signal().await; +/// }; +/// +/// ``` +/// +pub struct CondWait { + /// The CondVar + condvar: CondVar, + /// Boolean flag + w: Mutex, +} + +impl CondWait { + /// Creates a new CondWait. + pub fn new() -> Self { + Self { + condvar: CondVar::new(), + w: Mutex::new(false), + } + } + + /// Waits for a signal or broadcast. + pub async fn wait(&self) { + let mut w = self.w.lock().await; + + // While the boolean flag is false, wait for a signal. + while !*w { + w = self.condvar.wait(w).await; + } + } + + /// Signal a waiting task. + pub async fn signal(&self) { + *self.w.lock().await = true; + self.condvar.signal(); + } + + /// Signal all waiting tasks. + pub async fn broadcast(&self) { + *self.w.lock().await = true; + self.condvar.broadcast(); + } + + /// Reset the boolean flag value to false. + pub async fn reset(&self) { + *self.w.lock().await = false; + } +} + +impl Default for CondWait { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_cond_wait() { + smol::block_on(async { + let cond_wait = Arc::new(CondWait::new()); + let cond_wait_cloned = cond_wait.clone(); + let task = smol::spawn(async move { + cond_wait_cloned.wait().await; + true + }); + + cond_wait.signal().await; + assert!(task.await); + }); + } +} diff --git a/karyons_core/src/async_utils/mod.rs b/karyons_core/src/async_utils/mod.rs new file mode 100644 index 0000000..c871bad --- /dev/null +++ b/karyons_core/src/async_utils/mod.rs @@ -0,0 +1,13 @@ +mod backoff; +mod condvar; +mod condwait; +mod select; +mod task_group; +mod timeout; + +pub use backoff::Backoff; +pub use condvar::CondVar; +pub use condwait::CondWait; +pub use select::{select, Either}; +pub use task_group::{TaskGroup, TaskResult}; +pub use timeout::timeout; diff --git a/karyons_core/src/async_utils/select.rs b/karyons_core/src/async_utils/select.rs new file mode 100644 index 0000000..d61b355 --- /dev/null +++ b/karyons_core/src/async_utils/select.rs @@ -0,0 +1,99 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use pin_project_lite::pin_project; +use smol::future::Future; + +/// Returns the result of the future that completes first, preferring future1 +/// if both are ready. +/// +/// # Examples +/// +/// ``` +/// use std::future; +/// +/// use karyons_core::async_utils::{select, Either}; +/// +/// async { +/// let fut1 = future::pending::(); +/// let fut2 = future::ready(0); +/// let res = select(fut1, fut2).await; +/// assert!(matches!(res, Either::Right(0))); +/// // .... +/// }; +/// +/// ``` +/// +pub fn select(future1: F1, future2: F2) -> Select +where + F1: Future, + F2: Future, +{ + Select { future1, future2 } +} + +pin_project! { + #[derive(Debug)] + pub struct Select { + #[pin] + future1: F1, + #[pin] + future2: F2, + } +} + +/// The return value from the `select` function, indicating which future +/// completed first. +#[derive(Debug)] +pub enum Either { + Left(T1), + Right(T2), +} + +// Implement the Future trait for the Select struct. +impl Future for Select +where + F1: Future, + F2: Future, +{ + type Output = Either; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Poll::Ready(t) = this.future1.poll(cx) { + return Poll::Ready(Either::Left(t)); + } + + if let Poll::Ready(t) = this.future2.poll(cx) { + return Poll::Ready(Either::Right(t)); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::{select, Either}; + use smol::Timer; + use std::future; + + #[test] + fn test_async_select() { + smol::block_on(async move { + let fut = select(Timer::never(), future::ready(0 as u32)).await; + assert!(matches!(fut, Either::Right(0))); + + let fut1 = future::pending::(); + let fut2 = future::ready(0); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Right(0))); + + let fut1 = future::ready(0); + let fut2 = future::pending::(); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Left(_))); + }); + } +} diff --git a/karyons_core/src/async_utils/task_group.rs b/karyons_core/src/async_utils/task_group.rs new file mode 100644 index 0000000..8707d0e --- /dev/null +++ b/karyons_core/src/async_utils/task_group.rs @@ -0,0 +1,197 @@ +use std::{future::Future, sync::Arc, sync::Mutex}; + +use smol::Task; + +use crate::Executor; + +use super::{select, CondWait, Either}; + +/// TaskGroup is a group of spawned tasks. +/// +/// # Example +/// +/// ``` +/// +/// use std::sync::Arc; +/// +/// use karyons_core::async_utils::TaskGroup; +/// +/// async { +/// +/// let ex = Arc::new(smol::Executor::new()); +/// let group = TaskGroup::new(); +/// +/// group.spawn(ex.clone(), smol::Timer::never(), |_| async {}); +/// +/// group.cancel().await; +/// +/// }; +/// +/// ``` +/// +pub struct TaskGroup { + tasks: Mutex>, + stop_signal: Arc, +} + +impl<'a> TaskGroup { + /// Creates a new task group + pub fn new() -> Self { + Self { + tasks: Mutex::new(Vec::new()), + stop_signal: Arc::new(CondWait::new()), + } + } + + /// Spawns a new task and calls the callback after it has completed + /// or been canceled. The callback will have the `TaskResult` as a + /// parameter, indicating whether the task completed or was canceled. + pub fn spawn( + &self, + executor: Executor<'a>, + fut: Fut, + callback: CallbackF, + ) where + T: Send + Sync + 'a, + Fut: Future + Send + 'a, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, + CallbackFut: Future + Send + 'a, + { + let task = TaskHandler::new(executor.clone(), fut, callback, self.stop_signal.clone()); + self.tasks.lock().unwrap().push(task); + } + + /// Checks if the task group is empty. + pub fn is_empty(&self) -> bool { + self.tasks.lock().unwrap().is_empty() + } + + /// Get the number of the tasks in the group. + pub fn len(&self) -> usize { + self.tasks.lock().unwrap().len() + } + + /// Cancels all tasks in the group. + pub async fn cancel(&self) { + self.stop_signal.broadcast().await; + + loop { + let task = self.tasks.lock().unwrap().pop(); + if let Some(t) = task { + t.cancel().await + } else { + break; + } + } + } +} + +impl Default for TaskGroup { + fn default() -> Self { + Self::new() + } +} + +/// The result of a spawned task. +#[derive(Debug)] +pub enum TaskResult { + Completed(T), + Cancelled, +} + +impl std::fmt::Display for TaskResult { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TaskResult::Cancelled => write!(f, "Task cancelled"), + TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), + } + } +} + +/// TaskHandler +pub struct TaskHandler { + task: Task<()>, + cancel_flag: Arc, +} + +impl<'a> TaskHandler { + /// Creates a new task handle + fn new( + ex: Executor<'a>, + fut: Fut, + callback: CallbackF, + stop_signal: Arc, + ) -> TaskHandler + where + T: Send + Sync + 'a, + Fut: Future + Send + 'a, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, + CallbackFut: Future + Send + 'a, + { + let cancel_flag = Arc::new(CondWait::new()); + let cancel_flag_c = cancel_flag.clone(); + let task = ex.spawn(async move { + //start_signal.signal().await; + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; + + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; + + // Call the callback with the result. + callback(result).await; + + cancel_flag_c.signal().await; + }); + + TaskHandler { task, cancel_flag } + } + + /// Cancels the task. + async fn cancel(self) { + self.cancel_flag.wait().await; + self.task.cancel().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, sync::Arc}; + + #[test] + fn test_task_group() { + let ex = Arc::new(smol::Executor::new()); + smol::block_on(ex.clone().run(async move { + let group = Arc::new(TaskGroup::new()); + + group.spawn(ex.clone(), future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(ex.clone(), future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + let exc = ex.clone(); + group.spawn( + ex.clone(), + async move { + groupc.spawn(exc.clone(), future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + smol::Timer::after(std::time::Duration::from_millis(50)).await; + group.cancel().await; + })); + } +} diff --git a/karyons_core/src/async_utils/timeout.rs b/karyons_core/src/async_utils/timeout.rs new file mode 100644 index 0000000..7c55e1b --- /dev/null +++ b/karyons_core/src/async_utils/timeout.rs @@ -0,0 +1,52 @@ +use std::{future::Future, time::Duration}; + +use smol::Timer; + +use super::{select, Either}; +use crate::{error::Error, Result}; + +/// Waits for a future to complete or times out if it exceeds a specified +/// duration. +/// +/// # Example +/// +/// ``` +/// use std::{future, time::Duration}; +/// +/// use karyons_core::async_utils::timeout; +/// +/// async { +/// let fut = future::pending::<()>(); +/// assert!(timeout(Duration::from_millis(100), fut).await.is_err()); +/// }; +/// +/// ``` +/// +pub async fn timeout(delay: Duration, future1: F) -> Result +where + F: Future, +{ + let result = select(Timer::after(delay), future1).await; + + match result { + Either::Left(_) => Err(Error::Timeout), + Either::Right(res) => Ok(res), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, time::Duration}; + + #[test] + fn test_timeout() { + smol::block_on(async move { + let fut = future::pending::<()>(); + assert!(timeout(Duration::from_millis(10), fut).await.is_err()); + + let fut = smol::Timer::after(Duration::from_millis(10)); + assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) + }); + } +} diff --git a/karyons_core/src/error.rs b/karyons_core/src/error.rs new file mode 100644 index 0000000..15947c8 --- /dev/null +++ b/karyons_core/src/error.rs @@ -0,0 +1,51 @@ +use thiserror::Error as ThisError; + +pub type Result = std::result::Result; + +#[derive(ThisError, Debug)] +pub enum Error { + #[error("IO Error: {0}")] + IO(#[from] std::io::Error), + + #[error("Timeout Error")] + Timeout, + + #[error("Path Not Found Error: {0}")] + PathNotFound(&'static str), + + #[error("Channel Send Error: {0}")] + ChannelSend(String), + + #[error("Channel Receive Error: {0}")] + ChannelRecv(String), + + #[error("Decode Error: {0}")] + Decode(String), + + #[error("Encode Error: {0}")] + Encode(String), +} + +impl From> for Error { + fn from(error: smol::channel::SendError) -> Self { + Error::ChannelSend(error.to_string()) + } +} + +impl From for Error { + fn from(error: smol::channel::RecvError) -> Self { + Error::ChannelRecv(error.to_string()) + } +} + +impl From for Error { + fn from(error: bincode::error::DecodeError) -> Self { + Error::Decode(error.to_string()) + } +} + +impl From for Error { + fn from(error: bincode::error::EncodeError) -> Self { + Error::Encode(error.to_string()) + } +} diff --git a/karyons_core/src/event.rs b/karyons_core/src/event.rs new file mode 100644 index 0000000..b856385 --- /dev/null +++ b/karyons_core/src/event.rs @@ -0,0 +1,451 @@ +use std::{ + any::Any, + collections::HashMap, + marker::PhantomData, + sync::{Arc, Weak}, +}; + +use chrono::{DateTime, Utc}; +use log::{error, trace}; +use smol::{ + channel::{Receiver, Sender}, + lock::Mutex, +}; + +use crate::{utils::random_16, Result}; + +pub type ArcEventSys = Arc>; +pub type WeakEventSys = Weak>; +pub type EventListenerID = u16; + +type Listeners = HashMap>>>; + +/// EventSys supports event emission to registered listeners based on topics. +/// # Example +/// +/// ``` +/// use karyons_core::event::{EventSys, EventValueTopic, EventValue}; +/// +/// async { +/// let event_sys = EventSys::new(); +/// +/// #[derive(Hash, PartialEq, Eq, Debug, Clone)] +/// enum Topic { +/// TopicA, +/// TopicB, +/// } +/// +/// #[derive(Clone, Debug, PartialEq)] +/// struct A(usize); +/// +/// impl EventValue for A { +/// fn id() -> &'static str { +/// "A" +/// } +/// } +/// +/// let listener = event_sys.register::(&Topic::TopicA).await; +/// +/// event_sys.emit_by_topic(&Topic::TopicA, &A(3)) .await; +/// let msg: A = listener.recv().await.unwrap(); +/// +/// #[derive(Clone, Debug, PartialEq)] +/// struct B(usize); +/// +/// impl EventValue for B { +/// fn id() -> &'static str { +/// "B" +/// } +/// } +/// +/// impl EventValueTopic for B { +/// type Topic = Topic; +/// fn topic() -> Self::Topic{ +/// Topic::TopicB +/// } +/// } +/// +/// let listener = event_sys.register::(&Topic::TopicB).await; +/// +/// event_sys.emit(&B(3)) .await; +/// let msg: B = listener.recv().await.unwrap(); +/// +/// // .... +/// }; +/// +/// ``` +/// +pub struct EventSys { + listeners: Mutex>, +} + +impl EventSys +where + T: std::hash::Hash + Eq + std::fmt::Debug + Clone, +{ + /// Creates a new `EventSys` + pub fn new() -> ArcEventSys { + Arc::new(Self { + listeners: Mutex::new(HashMap::new()), + }) + } + + /// Emits an event to the listeners. + /// + /// The event must implement the `EventValueTopic` trait to indicate the + /// topic of the event. Otherwise, you can use `emit_by_topic()`. + pub async fn emit + Clone>(&self, value: &E) { + let topic = E::topic(); + self.emit_by_topic(&topic, value).await; + } + + /// Emits an event to the listeners. + pub async fn emit_by_topic(&self, topic: &T, value: &E) { + let value: Arc = Arc::new(value.clone()); + let event = Event::new(value); + + let mut topics = self.listeners.lock().await; + + if !topics.contains_key(topic) { + error!("Failed to emit an event to a non-existent topic"); + return; + } + + let event_ids = topics.get_mut(topic).unwrap(); + let event_id = E::id().to_string(); + + if !event_ids.contains_key(&event_id) { + error!("Failed to emit an event to a non-existent event id"); + return; + } + + let mut failed_listeners = vec![]; + + let listeners = event_ids.get_mut(&event_id).unwrap(); + for (listener_id, listener) in listeners.iter() { + if let Err(err) = listener.send(event.clone()).await { + trace!("Failed to emit event for topic {:?}: {}", topic, err); + failed_listeners.push(*listener_id); + } + } + + for listener_id in failed_listeners.iter() { + listeners.remove(listener_id); + } + } + + /// Registers a new event listener for the given topic. + pub async fn register( + self: &Arc, + topic: &T, + ) -> EventListener { + let chan = smol::channel::unbounded(); + + let topics = &mut self.listeners.lock().await; + + if !topics.contains_key(topic) { + topics.insert(topic.clone(), HashMap::new()); + } + + let event_ids = topics.get_mut(topic).unwrap(); + let event_id = E::id().to_string(); + + if !event_ids.contains_key(&event_id) { + event_ids.insert(event_id.clone(), HashMap::new()); + } + + let listeners = event_ids.get_mut(&event_id).unwrap(); + + let mut listener_id = random_16(); + while listeners.contains_key(&listener_id) { + listener_id = random_16(); + } + + let listener = + EventListener::new(listener_id, Arc::downgrade(self), chan.1, &event_id, topic); + + listeners.insert(listener_id, chan.0); + + listener + } + + /// Removes an event listener attached to the given topic. + async fn remove(&self, topic: &T, event_id: &str, listener_id: &EventListenerID) { + let topics = &mut self.listeners.lock().await; + if !topics.contains_key(topic) { + error!("Failed to remove a non-existent topic"); + return; + } + + let event_ids = topics.get_mut(topic).unwrap(); + if !event_ids.contains_key(event_id) { + error!("Failed to remove a non-existent event id"); + return; + } + + let listeners = event_ids.get_mut(event_id).unwrap(); + if listeners.remove(listener_id).is_none() { + error!("Failed to remove a non-existent event listener"); + } + } +} + +/// EventListener listens for and receives events from the `EventSys`. +pub struct EventListener { + id: EventListenerID, + recv_chan: Receiver, + event_sys: WeakEventSys, + event_id: String, + topic: T, + phantom: PhantomData, +} + +impl EventListener +where + T: std::hash::Hash + Eq + Clone + std::fmt::Debug, + E: EventValueAny + Clone + EventValue, +{ + /// Create a new event listener. + fn new( + id: EventListenerID, + event_sys: WeakEventSys, + recv_chan: Receiver, + event_id: &str, + topic: &T, + ) -> EventListener { + Self { + id, + recv_chan, + event_sys, + event_id: event_id.to_string(), + topic: topic.clone(), + phantom: PhantomData, + } + } + + /// Receive the next event. + pub async fn recv(&self) -> Result { + match self.recv_chan.recv().await { + Ok(event) => match ((*event.value).value_as_any()).downcast_ref::() { + Some(v) => Ok(v.clone()), + None => unreachable!("Error when attempting to downcast the event value."), + }, + Err(err) => { + error!("Failed to receive new event: {err}"); + self.cancel().await; + Err(err.into()) + } + } + } + + /// Cancels the listener and removes it from the `EventSys`. + pub async fn cancel(&self) { + self.event_sys() + .remove(&self.topic, &self.event_id, &self.id) + .await; + } + + /// Returns the topic for this event listener. + pub async fn topic(&self) -> &T { + &self.topic + } + + /// Returns the event id for this event listener. + pub async fn event_id(&self) -> &String { + &self.event_id + } + + fn event_sys(&self) -> ArcEventSys { + self.event_sys.upgrade().unwrap() + } +} + +/// An event within the `EventSys`. +#[derive(Clone, Debug)] +pub struct Event { + /// The time at which the event was created. + created_at: DateTime, + /// The value of the Event. + value: Arc, +} + +impl Event { + /// Creates a new Event. + pub fn new(value: Arc) -> Self { + Self { + created_at: Utc::now(), + value, + } + } +} + +impl std::fmt::Display for Event { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}: {:?}", self.created_at, self.value) + } +} + +pub trait EventValueAny: Any + Send + Sync + std::fmt::Debug { + fn value_as_any(&self) -> &dyn Any; +} + +impl EventValueAny for T { + fn value_as_any(&self) -> &dyn Any { + self + } +} + +pub trait EventValue: EventValueAny { + fn id() -> &'static str + where + Self: Sized; +} + +pub trait EventValueTopic: EventValueAny + EventValue { + type Topic; + fn topic() -> Self::Topic + where + Self: Sized; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Hash, PartialEq, Eq, Debug, Clone)] + enum Topic { + TopicA, + TopicB, + TopicC, + TopicD, + TopicE, + } + + #[derive(Clone, Debug, PartialEq)] + struct A { + a_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct B { + b_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct C { + c_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct D { + d_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct E { + e_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct F { + f_value: usize, + } + + impl EventValue for A { + fn id() -> &'static str { + "A" + } + } + + impl EventValue for B { + fn id() -> &'static str { + "B" + } + } + + impl EventValue for C { + fn id() -> &'static str { + "C" + } + } + + impl EventValue for D { + fn id() -> &'static str { + "D" + } + } + + impl EventValue for E { + fn id() -> &'static str { + "E" + } + } + + impl EventValue for F { + fn id() -> &'static str { + "F" + } + } + + impl EventValueTopic for C { + type Topic = Topic; + fn topic() -> Self::Topic { + Topic::TopicC + } + } + + #[test] + fn test_event_sys() { + smol::block_on(async move { + let event_sys = EventSys::::new(); + + let a_listener = event_sys.register::(&Topic::TopicA).await; + let b_listener = event_sys.register::(&Topic::TopicB).await; + + event_sys + .emit_by_topic(&Topic::TopicA, &A { a_value: 3 }) + .await; + event_sys + .emit_by_topic(&Topic::TopicB, &B { b_value: 5 }) + .await; + + let msg = a_listener.recv().await.unwrap(); + assert_eq!(msg, A { a_value: 3 }); + + let msg = b_listener.recv().await.unwrap(); + assert_eq!(msg, B { b_value: 5 }); + + // register the same event type to different topics + let c_listener = event_sys.register::(&Topic::TopicC).await; + let d_listener = event_sys.register::(&Topic::TopicD).await; + + event_sys.emit(&C { c_value: 10 }).await; + let msg = c_listener.recv().await.unwrap(); + assert_eq!(msg, C { c_value: 10 }); + + event_sys + .emit_by_topic(&Topic::TopicD, &C { c_value: 10 }) + .await; + let msg = d_listener.recv().await.unwrap(); + assert_eq!(msg, C { c_value: 10 }); + + // register different event types to the same topic + let e_listener = event_sys.register::(&Topic::TopicE).await; + let f_listener = event_sys.register::(&Topic::TopicE).await; + + event_sys + .emit_by_topic(&Topic::TopicE, &E { e_value: 5 }) + .await; + + let msg = e_listener.recv().await.unwrap(); + assert_eq!(msg, E { e_value: 5 }); + + event_sys + .emit_by_topic(&Topic::TopicE, &F { f_value: 5 }) + .await; + + let msg = f_listener.recv().await.unwrap(); + assert_eq!(msg, F { f_value: 5 }); + }); + } +} diff --git a/karyons_core/src/lib.rs b/karyons_core/src/lib.rs new file mode 100644 index 0000000..83af888 --- /dev/null +++ b/karyons_core/src/lib.rs @@ -0,0 +1,21 @@ +/// A set of helper tools and functions. +pub mod utils; + +/// A module containing async utilities that work with the `smol` async runtime. +pub mod async_utils; + +/// Represents Karyons's Core Error. +pub mod error; + +/// [`EventSys`](./event/struct.EventSys.html) Implementation +pub mod event; + +/// A simple publish-subscribe system.[`Read More`](./pubsub/struct.Publisher.html) +pub mod pubsub; + +use error::Result; +use smol::Executor as SmolEx; +use std::sync::Arc; + +/// A wrapper for smol::Executor +pub type Executor<'a> = Arc>; diff --git a/karyons_core/src/pubsub.rs b/karyons_core/src/pubsub.rs new file mode 100644 index 0000000..4cc0ab7 --- /dev/null +++ b/karyons_core/src/pubsub.rs @@ -0,0 +1,115 @@ +use std::{collections::HashMap, sync::Arc}; + +use log::error; +use smol::lock::Mutex; + +use crate::{utils::random_16, Result}; + +pub type ArcPublisher = Arc>; +pub type SubscriptionID = u16; + +/// A simple publish-subscribe system. +// # Example +/// +/// ``` +/// use karyons_core::pubsub::{Publisher}; +/// +/// async { +/// let publisher = Publisher::new(); +/// +/// let sub = publisher.subscribe().await; +/// +/// publisher.notify(&String::from("MESSAGE")).await; +/// +/// let msg = sub.recv().await; +/// +/// // .... +/// }; +/// +/// ``` +pub struct Publisher { + subs: Mutex>>, +} + +impl Publisher { + /// Creates a new Publisher + pub fn new() -> ArcPublisher { + Arc::new(Self { + subs: Mutex::new(HashMap::new()), + }) + } + + /// Subscribe and return a Subscription + pub async fn subscribe(self: &Arc) -> Subscription { + let mut subs = self.subs.lock().await; + + let chan = smol::channel::unbounded(); + + let mut sub_id = random_16(); + + // While the SubscriptionID already exists, generate a new one + while subs.contains_key(&sub_id) { + sub_id = random_16(); + } + + let sub = Subscription::new(sub_id, self.clone(), chan.1); + subs.insert(sub_id, chan.0); + + sub + } + + /// Unsubscribe from the Publisher + pub async fn unsubscribe(self: &Arc, id: &SubscriptionID) { + self.subs.lock().await.remove(id); + } + + /// Notify all subscribers + pub async fn notify(self: &Arc, value: &T) { + let mut subs = self.subs.lock().await; + let mut closed_subs = vec![]; + + for (sub_id, sub) in subs.iter() { + if let Err(err) = sub.send(value.clone()).await { + error!("failed to notify {}: {}", sub_id, err); + closed_subs.push(*sub_id); + } + } + + for sub_id in closed_subs.iter() { + subs.remove(sub_id); + } + } +} + +// Subscription +pub struct Subscription { + id: SubscriptionID, + recv_chan: smol::channel::Receiver, + publisher: ArcPublisher, +} + +impl Subscription { + /// Creates a new Subscription + pub fn new( + id: SubscriptionID, + publisher: ArcPublisher, + recv_chan: smol::channel::Receiver, + ) -> Subscription { + Self { + id, + recv_chan, + publisher, + } + } + + /// Receive a message from the Publisher + pub async fn recv(&self) -> Result { + let msg = self.recv_chan.recv().await?; + Ok(msg) + } + + /// Unsubscribe from the Publisher + pub async fn unsubscribe(&self) { + self.publisher.unsubscribe(&self.id).await; + } +} diff --git a/karyons_core/src/utils/decode.rs b/karyons_core/src/utils/decode.rs new file mode 100644 index 0000000..a8a6522 --- /dev/null +++ b/karyons_core/src/utils/decode.rs @@ -0,0 +1,10 @@ +use bincode::Decode; + +use crate::Result; + +/// Decodes a given type `T` from the given slice. returns the decoded value +/// along with the number of bytes read. +pub fn decode(src: &[u8]) -> Result<(T, usize)> { + let (result, bytes_read) = bincode::decode_from_slice(src, bincode::config::standard())?; + Ok((result, bytes_read)) +} diff --git a/karyons_core/src/utils/encode.rs b/karyons_core/src/utils/encode.rs new file mode 100644 index 0000000..7d1061b --- /dev/null +++ b/karyons_core/src/utils/encode.rs @@ -0,0 +1,15 @@ +use bincode::Encode; + +use crate::Result; + +/// Encode the given type `T` into a `Vec`. +pub fn encode(msg: &T) -> Result> { + let vec = bincode::encode_to_vec(msg, bincode::config::standard())?; + Ok(vec) +} + +/// Encode the given type `T` into the given slice.. +pub fn encode_into_slice(msg: &T, dst: &mut [u8]) -> Result<()> { + bincode::encode_into_slice(msg, dst, bincode::config::standard())?; + Ok(()) +} diff --git a/karyons_core/src/utils/mod.rs b/karyons_core/src/utils/mod.rs new file mode 100644 index 0000000..a3c3f50 --- /dev/null +++ b/karyons_core/src/utils/mod.rs @@ -0,0 +1,19 @@ +mod decode; +mod encode; +mod path; + +pub use decode::decode; +pub use encode::{encode, encode_into_slice}; +pub use path::{home_dir, tilde_expand}; + +use rand::{rngs::OsRng, Rng}; + +/// Generates and returns a random u32 using `rand::rngs::OsRng`. +pub fn random_32() -> u32 { + OsRng.gen() +} + +/// Generates and returns a random u16 using `rand::rngs::OsRng`. +pub fn random_16() -> u16 { + OsRng.gen() +} diff --git a/karyons_core/src/utils/path.rs b/karyons_core/src/utils/path.rs new file mode 100644 index 0000000..2cd900a --- /dev/null +++ b/karyons_core/src/utils/path.rs @@ -0,0 +1,39 @@ +use std::path::PathBuf; + +use crate::{error::Error, Result}; + +/// Returns the user's home directory as a `PathBuf`. +#[allow(dead_code)] +pub fn home_dir() -> Result { + dirs::home_dir().ok_or(Error::PathNotFound("Home dir not found")) +} + +/// Expands a tilde (~) in a path and returns the expanded `PathBuf`. +#[allow(dead_code)] +pub fn tilde_expand(path: &str) -> Result { + match path { + "~" => home_dir(), + p if p.starts_with("~/") => Ok(home_dir()?.join(&path[2..])), + _ => Ok(PathBuf::from(path)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tilde_expand() { + let path = "~/src"; + let expanded_path = dirs::home_dir().unwrap().join("src"); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + + let path = "~"; + let expanded_path = dirs::home_dir().unwrap(); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + + let path = ""; + let expanded_path = PathBuf::from(""); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + } +} -- cgit v1.2.3