From 98a1de91a2dae06323558422c239e5a45fc86e7b Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 28 Nov 2023 22:41:33 +0300 Subject: implement TLS for inbound and outbound connections --- core/src/async_util/backoff.rs | 115 +++++++++++ core/src/async_util/condvar.rs | 387 +++++++++++++++++++++++++++++++++++++ core/src/async_util/condwait.rs | 133 +++++++++++++ core/src/async_util/mod.rs | 13 ++ core/src/async_util/select.rs | 99 ++++++++++ core/src/async_util/task_group.rs | 194 +++++++++++++++++++ core/src/async_util/timeout.rs | 52 +++++ core/src/async_utils/backoff.rs | 115 ----------- core/src/async_utils/condvar.rs | 387 ------------------------------------- core/src/async_utils/condwait.rs | 133 ------------- core/src/async_utils/mod.rs | 13 -- core/src/async_utils/select.rs | 99 ---------- core/src/async_utils/task_group.rs | 194 ------------------- core/src/async_utils/timeout.rs | 52 ----- core/src/error.rs | 6 + core/src/event.rs | 2 +- core/src/key_pair.rs | 189 ++++++++++++++++++ core/src/lib.rs | 9 +- core/src/pubsub.rs | 2 +- core/src/util/decode.rs | 10 + core/src/util/encode.rs | 15 ++ core/src/util/mod.rs | 19 ++ core/src/util/path.rs | 39 ++++ core/src/utils/decode.rs | 10 - core/src/utils/encode.rs | 15 -- core/src/utils/mod.rs | 19 -- core/src/utils/path.rs | 39 ---- 27 files changed, 1279 insertions(+), 1081 deletions(-) create mode 100644 core/src/async_util/backoff.rs create mode 100644 core/src/async_util/condvar.rs create mode 100644 core/src/async_util/condwait.rs create mode 100644 core/src/async_util/mod.rs create mode 100644 core/src/async_util/select.rs create mode 100644 core/src/async_util/task_group.rs create mode 100644 core/src/async_util/timeout.rs delete mode 100644 core/src/async_utils/backoff.rs delete mode 100644 core/src/async_utils/condvar.rs delete mode 100644 core/src/async_utils/condwait.rs delete mode 100644 core/src/async_utils/mod.rs delete mode 100644 core/src/async_utils/select.rs delete mode 100644 core/src/async_utils/task_group.rs delete mode 100644 core/src/async_utils/timeout.rs create mode 100644 core/src/key_pair.rs create mode 100644 core/src/util/decode.rs create mode 100644 core/src/util/encode.rs create mode 100644 core/src/util/mod.rs create mode 100644 core/src/util/path.rs delete mode 100644 core/src/utils/decode.rs delete mode 100644 core/src/utils/encode.rs delete mode 100644 core/src/utils/mod.rs delete mode 100644 core/src/utils/path.rs (limited to 'core/src') diff --git a/core/src/async_util/backoff.rs b/core/src/async_util/backoff.rs new file mode 100644 index 0000000..a231229 --- /dev/null +++ b/core/src/async_util/backoff.rs @@ -0,0 +1,115 @@ +use std::{ + cmp::min, + sync::atomic::{AtomicBool, AtomicU32, Ordering}, + time::Duration, +}; + +use smol::Timer; + +/// Exponential backoff +/// +/// +/// # Examples +/// +/// ``` +/// use karyons_core::async_util::Backoff; +/// +/// async { +/// let backoff = Backoff::new(300, 3000); +/// +/// loop { +/// backoff.sleep().await; +/// +/// // do something +/// break; +/// } +/// +/// backoff.reset(); +/// +/// // .... +/// }; +/// +/// ``` +/// +pub struct Backoff { + /// The base delay in milliseconds for the initial retry. + base_delay: u64, + /// The max delay in milliseconds allowed for a retry. + max_delay: u64, + /// Atomic counter + retries: AtomicU32, + /// Stop flag + stop: AtomicBool, +} + +impl Backoff { + /// Creates a new Backoff. + pub fn new(base_delay: u64, max_delay: u64) -> Self { + Self { + base_delay, + max_delay, + retries: AtomicU32::new(0), + stop: AtomicBool::new(false), + } + } + + /// Sleep based on the current retry count and delay values. + /// Retruns the delay value. + pub async fn sleep(&self) -> u64 { + if self.stop.load(Ordering::SeqCst) { + Timer::after(Duration::from_millis(self.max_delay)).await; + return self.max_delay; + } + + let retries = self.retries.load(Ordering::SeqCst); + let delay = self.base_delay * (2_u64).pow(retries); + let delay = min(delay, self.max_delay); + + if delay == self.max_delay { + self.stop.store(true, Ordering::SeqCst); + } + + self.retries.store(retries + 1, Ordering::SeqCst); + + Timer::after(Duration::from_millis(delay)).await; + delay + } + + /// Reset the retry counter to 0. + pub fn reset(&self) { + self.retries.store(0, Ordering::SeqCst); + self.stop.store(false, Ordering::SeqCst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_backoff() { + smol::block_on(async move { + let backoff = Arc::new(Backoff::new(5, 15)); + let backoff_c = backoff.clone(); + smol::spawn(async move { + let delay = backoff_c.sleep().await; + assert_eq!(delay, 5); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 10); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 15); + }) + .await; + + smol::spawn(async move { + backoff.reset(); + let delay = backoff.sleep().await; + assert_eq!(delay, 5); + }) + .await; + }); + } +} diff --git a/core/src/async_util/condvar.rs b/core/src/async_util/condvar.rs new file mode 100644 index 0000000..7396d0d --- /dev/null +++ b/core/src/async_util/condvar.rs @@ -0,0 +1,387 @@ +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + sync::Mutex, + task::{Context, Poll, Waker}, +}; + +use smol::lock::MutexGuard; + +use crate::util::random_16; + +/// CondVar is an async version of +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use smol::lock::Mutex; +/// +/// use karyons_core::async_util::CondVar; +/// +/// async { +/// +/// let val = Arc::new(Mutex::new(false)); +/// let condvar = Arc::new(CondVar::new()); +/// +/// let val_cloned = val.clone(); +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val_cloned.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// // Wake up all waiting tasks on this condvar +/// condvar.broadcast(); +/// }; +/// +/// ``` + +pub struct CondVar { + inner: Mutex, +} + +impl CondVar { + /// Creates a new CondVar + pub fn new() -> Self { + Self { + inner: Mutex::new(Wakers::new()), + } + } + + /// Blocks the current task until this condition variable receives a notification. + pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { + let m = MutexGuard::source(&g); + + CondVarAwait::new(self, g).await; + + m.lock().await + } + + /// Wakes up one blocked task waiting on this condvar. + pub fn signal(&self) { + self.inner.lock().unwrap().wake(true); + } + + /// Wakes up all blocked tasks waiting on this condvar. + pub fn broadcast(&self) { + self.inner.lock().unwrap().wake(false); + } +} + +impl Default for CondVar { + fn default() -> Self { + Self::new() + } +} + +struct CondVarAwait<'a, T> { + id: Option, + condvar: &'a CondVar, + guard: Option>, +} + +impl<'a, T> CondVarAwait<'a, T> { + fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self { + Self { + condvar, + guard: Some(guard), + id: None, + } + } +} + +impl<'a, T> Future for CondVarAwait<'a, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut inner = self.condvar.inner.lock().unwrap(); + + match self.guard.take() { + Some(_) => { + // the first pooll will release the Mutexguard + self.id = Some(inner.put(Some(cx.waker().clone()))); + Poll::Pending + } + None => { + // Return Ready if it has already been polled and removed + // from the waker list. + if self.id.is_none() { + return Poll::Ready(()); + } + + let i = self.id.as_ref().unwrap(); + match inner.wakers.get_mut(i).unwrap() { + Some(wk) => { + // This will prevent cloning again + if !wk.will_wake(cx.waker()) { + *wk = cx.waker().clone(); + } + Poll::Pending + } + None => { + inner.delete(i); + self.id = None; + Poll::Ready(()) + } + } + } + } + } +} + +impl<'a, T> Drop for CondVarAwait<'a, T> { + fn drop(&mut self) { + if let Some(id) = self.id { + let mut inner = self.condvar.inner.lock().unwrap(); + if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() { + wk.wake() + } + } + } +} + +/// Wakers is a helper struct to store the task wakers +struct Wakers { + wakers: HashMap>, +} + +impl Wakers { + fn new() -> Self { + Self { + wakers: HashMap::new(), + } + } + + fn put(&mut self, waker: Option) -> u16 { + let mut id: u16; + + id = random_16(); + while self.wakers.contains_key(&id) { + id = random_16(); + } + + self.wakers.insert(id, waker); + id + } + + fn delete(&mut self, id: &u16) -> Option> { + self.wakers.remove(id) + } + + fn wake(&mut self, signal: bool) { + for (_, wk) in self.wakers.iter_mut() { + match wk.take() { + Some(w) => { + w.wake(); + if signal { + break; + } + } + None => continue, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smol::lock::Mutex; + use std::{ + collections::VecDeque, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + }; + + // The tests below demonstrate a solution to a problem in the Wikipedia + // explanation of condition variables: + // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. + + struct Queue { + items: VecDeque, + max_len: usize, + } + impl Queue { + fn new(max_len: usize) -> Self { + Self { + items: VecDeque::new(), + max_len, + } + } + + fn is_full(&self) -> bool { + self.items.len() == self.max_len + } + + fn is_empty(&self) -> bool { + self.items.is_empty() + } + } + + #[test] + fn test_condvar_signal() { + smol::block_on(async { + let number_of_tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar_full = Arc::new(CondVar::new()); + let condvar_empty = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_full_cloned = condvar_full.clone(); + let condvar_empty_cloned = condvar_empty.clone(); + + let _producer1 = smol::spawn(async move { + for i in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_full_cloned.wait(queue).await; + } + + queue.items.push_back(format!("task {i}")); + + // Wake up the consumer + condvar_empty_cloned.signal(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + let consumer = smol::spawn(async move { + for _ in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar_empty.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up the producer + condvar_full.signal(); + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 29); + }); + } + + #[test] + fn test_condvar_broadcast() { + smol::block_on(async { + let tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer1 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer1: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer2 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer2: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + + let consumer = smol::spawn(async move { + for _ in 1..((tasks * 2) - 1) { + { + // Lock queue mutex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up all producer and consumer tasks + condvar.broadcast(); + } + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 58); + }); + } +} diff --git a/core/src/async_util/condwait.rs b/core/src/async_util/condwait.rs new file mode 100644 index 0000000..cd4b269 --- /dev/null +++ b/core/src/async_util/condwait.rs @@ -0,0 +1,133 @@ +use smol::lock::Mutex; + +use super::CondVar; + +/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use karyons_core::async_util::CondWait; +/// +/// async { +/// let cond_wait = Arc::new(CondWait::new()); +/// let cond_wait_cloned = cond_wait.clone(); +/// let task = smol::spawn(async move { +/// cond_wait_cloned.wait().await; +/// // ... +/// }); +/// +/// cond_wait.signal().await; +/// }; +/// +/// ``` +/// +pub struct CondWait { + /// The CondVar + condvar: CondVar, + /// Boolean flag + w: Mutex, +} + +impl CondWait { + /// Creates a new CondWait. + pub fn new() -> Self { + Self { + condvar: CondVar::new(), + w: Mutex::new(false), + } + } + + /// Waits for a signal or broadcast. + pub async fn wait(&self) { + let mut w = self.w.lock().await; + + // While the boolean flag is false, wait for a signal. + while !*w { + w = self.condvar.wait(w).await; + } + } + + /// Signal a waiting task. + pub async fn signal(&self) { + *self.w.lock().await = true; + self.condvar.signal(); + } + + /// Signal all waiting tasks. + pub async fn broadcast(&self) { + *self.w.lock().await = true; + self.condvar.broadcast(); + } + + /// Reset the boolean flag value to false. + pub async fn reset(&self) { + *self.w.lock().await = false; + } +} + +impl Default for CondWait { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + #[test] + fn test_cond_wait() { + smol::block_on(async { + let cond_wait = Arc::new(CondWait::new()); + let count = Arc::new(AtomicUsize::new(0)); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + // Send a signal to the waiting task + cond_wait.signal().await; + + task.await; + + // Reset the boolean flag + cond_wait.reset().await; + + assert_eq!(count.load(Ordering::Relaxed), 1); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task1 = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task2 = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + // Broadcast a signal to all waiting tasks + cond_wait.broadcast().await; + + task1.await; + task2.await; + assert_eq!(count.load(Ordering::Relaxed), 3); + }); + } +} diff --git a/core/src/async_util/mod.rs b/core/src/async_util/mod.rs new file mode 100644 index 0000000..c871bad --- /dev/null +++ b/core/src/async_util/mod.rs @@ -0,0 +1,13 @@ +mod backoff; +mod condvar; +mod condwait; +mod select; +mod task_group; +mod timeout; + +pub use backoff::Backoff; +pub use condvar::CondVar; +pub use condwait::CondWait; +pub use select::{select, Either}; +pub use task_group::{TaskGroup, TaskResult}; +pub use timeout::timeout; diff --git a/core/src/async_util/select.rs b/core/src/async_util/select.rs new file mode 100644 index 0000000..8f2f7f6 --- /dev/null +++ b/core/src/async_util/select.rs @@ -0,0 +1,99 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use pin_project_lite::pin_project; +use smol::future::Future; + +/// Returns the result of the future that completes first, preferring future1 +/// if both are ready. +/// +/// # Examples +/// +/// ``` +/// use std::future; +/// +/// use karyons_core::async_util::{select, Either}; +/// +/// async { +/// let fut1 = future::pending::(); +/// let fut2 = future::ready(0); +/// let res = select(fut1, fut2).await; +/// assert!(matches!(res, Either::Right(0))); +/// // .... +/// }; +/// +/// ``` +/// +pub fn select(future1: F1, future2: F2) -> Select +where + F1: Future, + F2: Future, +{ + Select { future1, future2 } +} + +pin_project! { + #[derive(Debug)] + pub struct Select { + #[pin] + future1: F1, + #[pin] + future2: F2, + } +} + +/// The return value from the [`select`] function, indicating which future +/// completed first. +#[derive(Debug)] +pub enum Either { + Left(T1), + Right(T2), +} + +// Implement the Future trait for the Select struct. +impl Future for Select +where + F1: Future, + F2: Future, +{ + type Output = Either; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Poll::Ready(t) = this.future1.poll(cx) { + return Poll::Ready(Either::Left(t)); + } + + if let Poll::Ready(t) = this.future2.poll(cx) { + return Poll::Ready(Either::Right(t)); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::{select, Either}; + use smol::Timer; + use std::future; + + #[test] + fn test_async_select() { + smol::block_on(async move { + let fut = select(Timer::never(), future::ready(0 as u32)).await; + assert!(matches!(fut, Either::Right(0))); + + let fut1 = future::pending::(); + let fut2 = future::ready(0); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Right(0))); + + let fut1 = future::ready(0); + let fut2 = future::pending::(); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Left(_))); + }); + } +} diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs new file mode 100644 index 0000000..3fc0cb7 --- /dev/null +++ b/core/src/async_util/task_group.rs @@ -0,0 +1,194 @@ +use std::{future::Future, sync::Arc, sync::Mutex}; + +use async_task::FallibleTask; + +use crate::Executor; + +use super::{select, CondWait, Either}; + +/// TaskGroup is a group of spawned tasks. +/// +/// # Example +/// +/// ``` +/// +/// use std::sync::Arc; +/// +/// use karyons_core::async_util::TaskGroup; +/// +/// async { +/// +/// let ex = Arc::new(smol::Executor::new()); +/// let group = TaskGroup::new(ex); +/// +/// group.spawn(smol::Timer::never(), |_| async {}); +/// +/// group.cancel().await; +/// +/// }; +/// +/// ``` +/// +pub struct TaskGroup<'a> { + tasks: Mutex>, + stop_signal: Arc, + executor: Executor<'a>, +} + +impl<'a> TaskGroup<'a> { + /// Creates a new task group + pub fn new(executor: Executor<'a>) -> Self { + Self { + tasks: Mutex::new(Vec::new()), + stop_signal: Arc::new(CondWait::new()), + executor, + } + } + + /// Spawns a new task and calls the callback after it has completed + /// or been canceled. The callback will have the `TaskResult` as a + /// parameter, indicating whether the task completed or was canceled. + pub fn spawn(&self, fut: Fut, callback: CallbackF) + where + T: Send + Sync + 'a, + Fut: Future + Send + 'a, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, + CallbackFut: Future + Send + 'a, + { + let task = TaskHandler::new( + self.executor.clone(), + fut, + callback, + self.stop_signal.clone(), + ); + self.tasks.lock().unwrap().push(task); + } + + /// Checks if the task group is empty. + pub fn is_empty(&self) -> bool { + self.tasks.lock().unwrap().is_empty() + } + + /// Get the number of the tasks in the group. + pub fn len(&self) -> usize { + self.tasks.lock().unwrap().len() + } + + /// Cancels all tasks in the group. + pub async fn cancel(&self) { + self.stop_signal.broadcast().await; + + loop { + let task = self.tasks.lock().unwrap().pop(); + if let Some(t) = task { + t.cancel().await + } else { + break; + } + } + } +} + +/// The result of a spawned task. +#[derive(Debug)] +pub enum TaskResult { + Completed(T), + Cancelled, +} + +impl std::fmt::Display for TaskResult { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TaskResult::Cancelled => write!(f, "Task cancelled"), + TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), + } + } +} + +/// TaskHandler +pub struct TaskHandler { + task: FallibleTask<()>, + cancel_flag: Arc, +} + +impl<'a> TaskHandler { + /// Creates a new task handle + fn new( + ex: Executor<'a>, + fut: Fut, + callback: CallbackF, + stop_signal: Arc, + ) -> TaskHandler + where + T: Send + Sync + 'a, + Fut: Future + Send + 'a, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, + CallbackFut: Future + Send + 'a, + { + let cancel_flag = Arc::new(CondWait::new()); + let cancel_flag_c = cancel_flag.clone(); + let task = ex + .spawn(async move { + //start_signal.signal().await; + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; + + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; + + // Call the callback with the result. + callback(result).await; + + cancel_flag_c.signal().await; + }) + .fallible(); + + TaskHandler { task, cancel_flag } + } + + /// Cancels the task. + async fn cancel(self) { + self.cancel_flag.wait().await; + self.task.cancel().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, sync::Arc}; + + #[test] + fn test_task_group() { + let ex = Arc::new(smol::Executor::new()); + smol::block_on(ex.clone().run(async move { + let group = Arc::new(TaskGroup::new(ex)); + + group.spawn(future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + group.spawn( + async move { + groupc.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + smol::Timer::after(std::time::Duration::from_millis(50)).await; + group.cancel().await; + })); + } +} diff --git a/core/src/async_util/timeout.rs b/core/src/async_util/timeout.rs new file mode 100644 index 0000000..6ab35c4 --- /dev/null +++ b/core/src/async_util/timeout.rs @@ -0,0 +1,52 @@ +use std::{future::Future, time::Duration}; + +use smol::Timer; + +use super::{select, Either}; +use crate::{error::Error, Result}; + +/// Waits for a future to complete or times out if it exceeds a specified +/// duration. +/// +/// # Example +/// +/// ``` +/// use std::{future, time::Duration}; +/// +/// use karyons_core::async_util::timeout; +/// +/// async { +/// let fut = future::pending::<()>(); +/// assert!(timeout(Duration::from_millis(100), fut).await.is_err()); +/// }; +/// +/// ``` +/// +pub async fn timeout(delay: Duration, future1: F) -> Result +where + F: Future, +{ + let result = select(Timer::after(delay), future1).await; + + match result { + Either::Left(_) => Err(Error::Timeout), + Either::Right(res) => Ok(res), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, time::Duration}; + + #[test] + fn test_timeout() { + smol::block_on(async move { + let fut = future::pending::<()>(); + assert!(timeout(Duration::from_millis(10), fut).await.is_err()); + + let fut = smol::Timer::after(Duration::from_millis(10)); + assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) + }); + } +} diff --git a/core/src/async_utils/backoff.rs b/core/src/async_utils/backoff.rs deleted file mode 100644 index f7e131d..0000000 --- a/core/src/async_utils/backoff.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::{ - cmp::min, - sync::atomic::{AtomicBool, AtomicU32, Ordering}, - time::Duration, -}; - -use smol::Timer; - -/// Exponential backoff -/// -/// -/// # Examples -/// -/// ``` -/// use karyons_core::async_utils::Backoff; -/// -/// async { -/// let backoff = Backoff::new(300, 3000); -/// -/// loop { -/// backoff.sleep().await; -/// -/// // do something -/// break; -/// } -/// -/// backoff.reset(); -/// -/// // .... -/// }; -/// -/// ``` -/// -pub struct Backoff { - /// The base delay in milliseconds for the initial retry. - base_delay: u64, - /// The max delay in milliseconds allowed for a retry. - max_delay: u64, - /// Atomic counter - retries: AtomicU32, - /// Stop flag - stop: AtomicBool, -} - -impl Backoff { - /// Creates a new Backoff. - pub fn new(base_delay: u64, max_delay: u64) -> Self { - Self { - base_delay, - max_delay, - retries: AtomicU32::new(0), - stop: AtomicBool::new(false), - } - } - - /// Sleep based on the current retry count and delay values. - /// Retruns the delay value. - pub async fn sleep(&self) -> u64 { - if self.stop.load(Ordering::SeqCst) { - Timer::after(Duration::from_millis(self.max_delay)).await; - return self.max_delay; - } - - let retries = self.retries.load(Ordering::SeqCst); - let delay = self.base_delay * (2_u64).pow(retries); - let delay = min(delay, self.max_delay); - - if delay == self.max_delay { - self.stop.store(true, Ordering::SeqCst); - } - - self.retries.store(retries + 1, Ordering::SeqCst); - - Timer::after(Duration::from_millis(delay)).await; - delay - } - - /// Reset the retry counter to 0. - pub fn reset(&self) { - self.retries.store(0, Ordering::SeqCst); - self.stop.store(false, Ordering::SeqCst); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - #[test] - fn test_backoff() { - smol::block_on(async move { - let backoff = Arc::new(Backoff::new(5, 15)); - let backoff_c = backoff.clone(); - smol::spawn(async move { - let delay = backoff_c.sleep().await; - assert_eq!(delay, 5); - - let delay = backoff_c.sleep().await; - assert_eq!(delay, 10); - - let delay = backoff_c.sleep().await; - assert_eq!(delay, 15); - }) - .await; - - smol::spawn(async move { - backoff.reset(); - let delay = backoff.sleep().await; - assert_eq!(delay, 5); - }) - .await; - }); - } -} diff --git a/core/src/async_utils/condvar.rs b/core/src/async_utils/condvar.rs deleted file mode 100644 index 814f78f..0000000 --- a/core/src/async_utils/condvar.rs +++ /dev/null @@ -1,387 +0,0 @@ -use std::{ - collections::HashMap, - future::Future, - pin::Pin, - sync::Mutex, - task::{Context, Poll, Waker}, -}; - -use smol::lock::MutexGuard; - -use crate::utils::random_16; - -/// CondVar is an async version of -/// -/// # Example -/// -///``` -/// use std::sync::Arc; -/// -/// use smol::lock::Mutex; -/// -/// use karyons_core::async_utils::CondVar; -/// -/// async { -/// -/// let val = Arc::new(Mutex::new(false)); -/// let condvar = Arc::new(CondVar::new()); -/// -/// let val_cloned = val.clone(); -/// let condvar_cloned = condvar.clone(); -/// smol::spawn(async move { -/// let mut val = val_cloned.lock().await; -/// -/// // While the boolean flag is false, wait for a signal. -/// while !*val { -/// val = condvar_cloned.wait(val).await; -/// } -/// -/// // ... -/// }); -/// -/// let condvar_cloned = condvar.clone(); -/// smol::spawn(async move { -/// let mut val = val.lock().await; -/// -/// // While the boolean flag is false, wait for a signal. -/// while !*val { -/// val = condvar_cloned.wait(val).await; -/// } -/// -/// // ... -/// }); -/// -/// // Wake up all waiting tasks on this condvar -/// condvar.broadcast(); -/// }; -/// -/// ``` - -pub struct CondVar { - inner: Mutex, -} - -impl CondVar { - /// Creates a new CondVar - pub fn new() -> Self { - Self { - inner: Mutex::new(Wakers::new()), - } - } - - /// Blocks the current task until this condition variable receives a notification. - pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { - let m = MutexGuard::source(&g); - - CondVarAwait::new(self, g).await; - - m.lock().await - } - - /// Wakes up one blocked task waiting on this condvar. - pub fn signal(&self) { - self.inner.lock().unwrap().wake(true); - } - - /// Wakes up all blocked tasks waiting on this condvar. - pub fn broadcast(&self) { - self.inner.lock().unwrap().wake(false); - } -} - -impl Default for CondVar { - fn default() -> Self { - Self::new() - } -} - -struct CondVarAwait<'a, T> { - id: Option, - condvar: &'a CondVar, - guard: Option>, -} - -impl<'a, T> CondVarAwait<'a, T> { - fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self { - Self { - condvar, - guard: Some(guard), - id: None, - } - } -} - -impl<'a, T> Future for CondVarAwait<'a, T> { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut inner = self.condvar.inner.lock().unwrap(); - - match self.guard.take() { - Some(_) => { - // the first pooll will release the Mutexguard - self.id = Some(inner.put(Some(cx.waker().clone()))); - Poll::Pending - } - None => { - // Return Ready if it has already been polled and removed - // from the waker list. - if self.id.is_none() { - return Poll::Ready(()); - } - - let i = self.id.as_ref().unwrap(); - match inner.wakers.get_mut(i).unwrap() { - Some(wk) => { - // This will prevent cloning again - if !wk.will_wake(cx.waker()) { - *wk = cx.waker().clone(); - } - Poll::Pending - } - None => { - inner.delete(i); - self.id = None; - Poll::Ready(()) - } - } - } - } - } -} - -impl<'a, T> Drop for CondVarAwait<'a, T> { - fn drop(&mut self) { - if let Some(id) = self.id { - let mut inner = self.condvar.inner.lock().unwrap(); - if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() { - wk.wake() - } - } - } -} - -/// Wakers is a helper struct to store the task wakers -struct Wakers { - wakers: HashMap>, -} - -impl Wakers { - fn new() -> Self { - Self { - wakers: HashMap::new(), - } - } - - fn put(&mut self, waker: Option) -> u16 { - let mut id: u16; - - id = random_16(); - while self.wakers.contains_key(&id) { - id = random_16(); - } - - self.wakers.insert(id, waker); - id - } - - fn delete(&mut self, id: &u16) -> Option> { - self.wakers.remove(id) - } - - fn wake(&mut self, signal: bool) { - for (_, wk) in self.wakers.iter_mut() { - match wk.take() { - Some(w) => { - w.wake(); - if signal { - break; - } - } - None => continue, - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use smol::lock::Mutex; - use std::{ - collections::VecDeque, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - }; - - // The tests below demonstrate a solution to a problem in the Wikipedia - // explanation of condition variables: - // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. - - struct Queue { - items: VecDeque, - max_len: usize, - } - impl Queue { - fn new(max_len: usize) -> Self { - Self { - items: VecDeque::new(), - max_len, - } - } - - fn is_full(&self) -> bool { - self.items.len() == self.max_len - } - - fn is_empty(&self) -> bool { - self.items.is_empty() - } - } - - #[test] - fn test_condvar_signal() { - smol::block_on(async { - let number_of_tasks = 30; - - let queue = Arc::new(Mutex::new(Queue::new(5))); - let condvar_full = Arc::new(CondVar::new()); - let condvar_empty = Arc::new(CondVar::new()); - - let queue_cloned = queue.clone(); - let condvar_full_cloned = condvar_full.clone(); - let condvar_empty_cloned = condvar_empty.clone(); - - let _producer1 = smol::spawn(async move { - for i in 1..number_of_tasks { - // Lock queue mtuex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-full - while queue.is_full() { - // Release queue mutex and sleep - queue = condvar_full_cloned.wait(queue).await; - } - - queue.items.push_back(format!("task {i}")); - - // Wake up the consumer - condvar_empty_cloned.signal(); - } - }); - - let queue_cloned = queue.clone(); - let task_consumed = Arc::new(AtomicUsize::new(0)); - let task_consumed_ = task_consumed.clone(); - let consumer = smol::spawn(async move { - for _ in 1..number_of_tasks { - // Lock queue mtuex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-empty - while queue.is_empty() { - // Release queue mutex and sleep - queue = condvar_empty.wait(queue).await; - } - - let _ = queue.items.pop_front().unwrap(); - - task_consumed_.fetch_add(1, Ordering::Relaxed); - - // Do something - - // Wake up the producer - condvar_full.signal(); - } - }); - - consumer.await; - assert!(queue.lock().await.is_empty()); - assert_eq!(task_consumed.load(Ordering::Relaxed), 29); - }); - } - - #[test] - fn test_condvar_broadcast() { - smol::block_on(async { - let tasks = 30; - - let queue = Arc::new(Mutex::new(Queue::new(5))); - let condvar = Arc::new(CondVar::new()); - - let queue_cloned = queue.clone(); - let condvar_cloned = condvar.clone(); - let _producer1 = smol::spawn(async move { - for i in 1..tasks { - // Lock queue mtuex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-full - while queue.is_full() { - // Release queue mutex and sleep - queue = condvar_cloned.wait(queue).await; - } - - queue.items.push_back(format!("producer1: task {i}")); - - // Wake up all producer and consumer tasks - condvar_cloned.broadcast(); - } - }); - - let queue_cloned = queue.clone(); - let condvar_cloned = condvar.clone(); - let _producer2 = smol::spawn(async move { - for i in 1..tasks { - // Lock queue mtuex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-full - while queue.is_full() { - // Release queue mutex and sleep - queue = condvar_cloned.wait(queue).await; - } - - queue.items.push_back(format!("producer2: task {i}")); - - // Wake up all producer and consumer tasks - condvar_cloned.broadcast(); - } - }); - - let queue_cloned = queue.clone(); - let task_consumed = Arc::new(AtomicUsize::new(0)); - let task_consumed_ = task_consumed.clone(); - - let consumer = smol::spawn(async move { - for _ in 1..((tasks * 2) - 1) { - { - // Lock queue mutex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-empty - while queue.is_empty() { - // Release queue mutex and sleep - queue = condvar.wait(queue).await; - } - - let _ = queue.items.pop_front().unwrap(); - - task_consumed_.fetch_add(1, Ordering::Relaxed); - - // Do something - - // Wake up all producer and consumer tasks - condvar.broadcast(); - } - } - }); - - consumer.await; - assert!(queue.lock().await.is_empty()); - assert_eq!(task_consumed.load(Ordering::Relaxed), 58); - }); - } -} diff --git a/core/src/async_utils/condwait.rs b/core/src/async_utils/condwait.rs deleted file mode 100644 index e31fac3..0000000 --- a/core/src/async_utils/condwait.rs +++ /dev/null @@ -1,133 +0,0 @@ -use smol::lock::Mutex; - -use super::CondVar; - -/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. -/// -/// # Example -/// -///``` -/// use std::sync::Arc; -/// -/// use karyons_core::async_utils::CondWait; -/// -/// async { -/// let cond_wait = Arc::new(CondWait::new()); -/// let cond_wait_cloned = cond_wait.clone(); -/// let task = smol::spawn(async move { -/// cond_wait_cloned.wait().await; -/// // ... -/// }); -/// -/// cond_wait.signal().await; -/// }; -/// -/// ``` -/// -pub struct CondWait { - /// The CondVar - condvar: CondVar, - /// Boolean flag - w: Mutex, -} - -impl CondWait { - /// Creates a new CondWait. - pub fn new() -> Self { - Self { - condvar: CondVar::new(), - w: Mutex::new(false), - } - } - - /// Waits for a signal or broadcast. - pub async fn wait(&self) { - let mut w = self.w.lock().await; - - // While the boolean flag is false, wait for a signal. - while !*w { - w = self.condvar.wait(w).await; - } - } - - /// Signal a waiting task. - pub async fn signal(&self) { - *self.w.lock().await = true; - self.condvar.signal(); - } - - /// Signal all waiting tasks. - pub async fn broadcast(&self) { - *self.w.lock().await = true; - self.condvar.broadcast(); - } - - /// Reset the boolean flag value to false. - pub async fn reset(&self) { - *self.w.lock().await = false; - } -} - -impl Default for CondWait { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }; - - #[test] - fn test_cond_wait() { - smol::block_on(async { - let cond_wait = Arc::new(CondWait::new()); - let count = Arc::new(AtomicUsize::new(0)); - - let cond_wait_cloned = cond_wait.clone(); - let count_cloned = count.clone(); - let task = smol::spawn(async move { - cond_wait_cloned.wait().await; - count_cloned.fetch_add(1, Ordering::Relaxed); - // do something - }); - - // Send a signal to the waiting task - cond_wait.signal().await; - - task.await; - - // Reset the boolean flag - cond_wait.reset().await; - - assert_eq!(count.load(Ordering::Relaxed), 1); - - let cond_wait_cloned = cond_wait.clone(); - let count_cloned = count.clone(); - let task1 = smol::spawn(async move { - cond_wait_cloned.wait().await; - count_cloned.fetch_add(1, Ordering::Relaxed); - // do something - }); - - let cond_wait_cloned = cond_wait.clone(); - let count_cloned = count.clone(); - let task2 = smol::spawn(async move { - cond_wait_cloned.wait().await; - count_cloned.fetch_add(1, Ordering::Relaxed); - // do something - }); - - // Broadcast a signal to all waiting tasks - cond_wait.broadcast().await; - - task1.await; - task2.await; - assert_eq!(count.load(Ordering::Relaxed), 3); - }); - } -} diff --git a/core/src/async_utils/mod.rs b/core/src/async_utils/mod.rs deleted file mode 100644 index c871bad..0000000 --- a/core/src/async_utils/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -mod backoff; -mod condvar; -mod condwait; -mod select; -mod task_group; -mod timeout; - -pub use backoff::Backoff; -pub use condvar::CondVar; -pub use condwait::CondWait; -pub use select::{select, Either}; -pub use task_group::{TaskGroup, TaskResult}; -pub use timeout::timeout; diff --git a/core/src/async_utils/select.rs b/core/src/async_utils/select.rs deleted file mode 100644 index 9fe3c77..0000000 --- a/core/src/async_utils/select.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::pin::Pin; -use std::task::{Context, Poll}; - -use pin_project_lite::pin_project; -use smol::future::Future; - -/// Returns the result of the future that completes first, preferring future1 -/// if both are ready. -/// -/// # Examples -/// -/// ``` -/// use std::future; -/// -/// use karyons_core::async_utils::{select, Either}; -/// -/// async { -/// let fut1 = future::pending::(); -/// let fut2 = future::ready(0); -/// let res = select(fut1, fut2).await; -/// assert!(matches!(res, Either::Right(0))); -/// // .... -/// }; -/// -/// ``` -/// -pub fn select(future1: F1, future2: F2) -> Select -where - F1: Future, - F2: Future, -{ - Select { future1, future2 } -} - -pin_project! { - #[derive(Debug)] - pub struct Select { - #[pin] - future1: F1, - #[pin] - future2: F2, - } -} - -/// The return value from the [`select`] function, indicating which future -/// completed first. -#[derive(Debug)] -pub enum Either { - Left(T1), - Right(T2), -} - -// Implement the Future trait for the Select struct. -impl Future for Select -where - F1: Future, - F2: Future, -{ - type Output = Either; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - if let Poll::Ready(t) = this.future1.poll(cx) { - return Poll::Ready(Either::Left(t)); - } - - if let Poll::Ready(t) = this.future2.poll(cx) { - return Poll::Ready(Either::Right(t)); - } - - Poll::Pending - } -} - -#[cfg(test)] -mod tests { - use super::{select, Either}; - use smol::Timer; - use std::future; - - #[test] - fn test_async_select() { - smol::block_on(async move { - let fut = select(Timer::never(), future::ready(0 as u32)).await; - assert!(matches!(fut, Either::Right(0))); - - let fut1 = future::pending::(); - let fut2 = future::ready(0); - let res = select(fut1, fut2).await; - assert!(matches!(res, Either::Right(0))); - - let fut1 = future::ready(0); - let fut2 = future::pending::(); - let res = select(fut1, fut2).await; - assert!(matches!(res, Either::Left(_))); - }); - } -} diff --git a/core/src/async_utils/task_group.rs b/core/src/async_utils/task_group.rs deleted file mode 100644 index afc9648..0000000 --- a/core/src/async_utils/task_group.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::{future::Future, sync::Arc, sync::Mutex}; - -use async_task::FallibleTask; - -use crate::Executor; - -use super::{select, CondWait, Either}; - -/// TaskGroup is a group of spawned tasks. -/// -/// # Example -/// -/// ``` -/// -/// use std::sync::Arc; -/// -/// use karyons_core::async_utils::TaskGroup; -/// -/// async { -/// -/// let ex = Arc::new(smol::Executor::new()); -/// let group = TaskGroup::new(ex); -/// -/// group.spawn(smol::Timer::never(), |_| async {}); -/// -/// group.cancel().await; -/// -/// }; -/// -/// ``` -/// -pub struct TaskGroup<'a> { - tasks: Mutex>, - stop_signal: Arc, - executor: Executor<'a>, -} - -impl<'a> TaskGroup<'a> { - /// Creates a new task group - pub fn new(executor: Executor<'a>) -> Self { - Self { - tasks: Mutex::new(Vec::new()), - stop_signal: Arc::new(CondWait::new()), - executor, - } - } - - /// Spawns a new task and calls the callback after it has completed - /// or been canceled. The callback will have the `TaskResult` as a - /// parameter, indicating whether the task completed or was canceled. - pub fn spawn(&self, fut: Fut, callback: CallbackF) - where - T: Send + Sync + 'a, - Fut: Future + Send + 'a, - CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, - CallbackFut: Future + Send + 'a, - { - let task = TaskHandler::new( - self.executor.clone(), - fut, - callback, - self.stop_signal.clone(), - ); - self.tasks.lock().unwrap().push(task); - } - - /// Checks if the task group is empty. - pub fn is_empty(&self) -> bool { - self.tasks.lock().unwrap().is_empty() - } - - /// Get the number of the tasks in the group. - pub fn len(&self) -> usize { - self.tasks.lock().unwrap().len() - } - - /// Cancels all tasks in the group. - pub async fn cancel(&self) { - self.stop_signal.broadcast().await; - - loop { - let task = self.tasks.lock().unwrap().pop(); - if let Some(t) = task { - t.cancel().await - } else { - break; - } - } - } -} - -/// The result of a spawned task. -#[derive(Debug)] -pub enum TaskResult { - Completed(T), - Cancelled, -} - -impl std::fmt::Display for TaskResult { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - TaskResult::Cancelled => write!(f, "Task cancelled"), - TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), - } - } -} - -/// TaskHandler -pub struct TaskHandler { - task: FallibleTask<()>, - cancel_flag: Arc, -} - -impl<'a> TaskHandler { - /// Creates a new task handle - fn new( - ex: Executor<'a>, - fut: Fut, - callback: CallbackF, - stop_signal: Arc, - ) -> TaskHandler - where - T: Send + Sync + 'a, - Fut: Future + Send + 'a, - CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, - CallbackFut: Future + Send + 'a, - { - let cancel_flag = Arc::new(CondWait::new()); - let cancel_flag_c = cancel_flag.clone(); - let task = ex - .spawn(async move { - //start_signal.signal().await; - // Waits for either the stop signal or the task to complete. - let result = select(stop_signal.wait(), fut).await; - - let result = match result { - Either::Left(_) => TaskResult::Cancelled, - Either::Right(res) => TaskResult::Completed(res), - }; - - // Call the callback with the result. - callback(result).await; - - cancel_flag_c.signal().await; - }) - .fallible(); - - TaskHandler { task, cancel_flag } - } - - /// Cancels the task. - async fn cancel(self) { - self.cancel_flag.wait().await; - self.task.cancel().await; - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::{future, sync::Arc}; - - #[test] - fn test_task_group() { - let ex = Arc::new(smol::Executor::new()); - smol::block_on(ex.clone().run(async move { - let group = Arc::new(TaskGroup::new(ex)); - - group.spawn(future::ready(0), |res| async move { - assert!(matches!(res, TaskResult::Completed(0))); - }); - - group.spawn(future::pending::<()>(), |res| async move { - assert!(matches!(res, TaskResult::Cancelled)); - }); - - let groupc = group.clone(); - group.spawn( - async move { - groupc.spawn(future::pending::<()>(), |res| async move { - assert!(matches!(res, TaskResult::Cancelled)); - }); - }, - |res| async move { - assert!(matches!(res, TaskResult::Completed(_))); - }, - ); - - // Do something - smol::Timer::after(std::time::Duration::from_millis(50)).await; - group.cancel().await; - })); - } -} diff --git a/core/src/async_utils/timeout.rs b/core/src/async_utils/timeout.rs deleted file mode 100644 index 7c55e1b..0000000 --- a/core/src/async_utils/timeout.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::{future::Future, time::Duration}; - -use smol::Timer; - -use super::{select, Either}; -use crate::{error::Error, Result}; - -/// Waits for a future to complete or times out if it exceeds a specified -/// duration. -/// -/// # Example -/// -/// ``` -/// use std::{future, time::Duration}; -/// -/// use karyons_core::async_utils::timeout; -/// -/// async { -/// let fut = future::pending::<()>(); -/// assert!(timeout(Duration::from_millis(100), fut).await.is_err()); -/// }; -/// -/// ``` -/// -pub async fn timeout(delay: Duration, future1: F) -> Result -where - F: Future, -{ - let result = select(Timer::after(delay), future1).await; - - match result { - Either::Left(_) => Err(Error::Timeout), - Either::Right(res) => Ok(res), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::{future, time::Duration}; - - #[test] - fn test_timeout() { - smol::block_on(async move { - let fut = future::pending::<()>(); - assert!(timeout(Duration::from_millis(10), fut).await.is_err()); - - let fut = smol::Timer::after(Duration::from_millis(10)); - assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) - }); - } -} diff --git a/core/src/error.rs b/core/src/error.rs index 63b45d3..7c547c4 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -7,12 +7,18 @@ pub enum Error { #[error(transparent)] IO(#[from] std::io::Error), + #[error("TryInto Error: {0}")] + TryInto(&'static str), + #[error("Timeout Error")] Timeout, #[error("Path Not Found Error: {0}")] PathNotFound(&'static str), + #[error(transparent)] + Ed25519(#[from] ed25519_dalek::ed25519::Error), + #[error("Channel Send Error: {0}")] ChannelSend(String), diff --git a/core/src/event.rs b/core/src/event.rs index 0503e88..f2c5510 100644 --- a/core/src/event.rs +++ b/core/src/event.rs @@ -12,7 +12,7 @@ use smol::{ lock::Mutex, }; -use crate::{utils::random_16, Result}; +use crate::{util::random_16, Result}; pub type ArcEventSys = Arc>; pub type WeakEventSys = Weak>; diff --git a/core/src/key_pair.rs b/core/src/key_pair.rs new file mode 100644 index 0000000..4016351 --- /dev/null +++ b/core/src/key_pair.rs @@ -0,0 +1,189 @@ +use ed25519_dalek::{Signer as _, Verifier as _}; +use rand::rngs::OsRng; + +use crate::{error::Error, Result}; + +/// key cryptography type +pub enum KeyPairType { + Ed25519, +} + +/// A Public key +pub struct PublicKey(PublicKeyInner); + +/// A Secret key +pub struct SecretKey(Vec); + +impl PublicKey { + pub fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() + } + + /// Verify a signature on a message with this public key. + pub fn verify(&self, msg: &[u8], signature: &[u8]) -> Result<()> { + self.0.verify(msg, signature) + } +} + +impl PublicKey { + pub fn from_bytes(kp_type: &KeyPairType, pk: &[u8]) -> Result { + Ok(Self(PublicKeyInner::from_bytes(kp_type, pk)?)) + } +} + +/// A KeyPair. +#[derive(Clone)] +pub struct KeyPair(KeyPairInner); + +impl KeyPair { + /// Generate a new random keypair. + pub fn generate(kp_type: &KeyPairType) -> Self { + Self(KeyPairInner::generate(kp_type)) + } + + /// Sign a message using the private key. + pub fn sign(&self, msg: &[u8]) -> Vec { + self.0.sign(msg) + } + + /// Get the public key of this keypair. + pub fn public(&self) -> PublicKey { + self.0.public() + } + + /// Get the secret key of this keypair. + pub fn secret(&self) -> SecretKey { + self.0.secret() + } +} + +/// An extension trait, adding essential methods to all [`KeyPair`] types. +trait KeyPairExt { + /// Sign a message using the private key. + fn sign(&self, msg: &[u8]) -> Vec; + + /// Get the public key of this keypair. + fn public(&self) -> PublicKey; + + /// Get the secret key of this keypair. + fn secret(&self) -> SecretKey; +} + +#[derive(Clone)] +enum KeyPairInner { + Ed25519(Ed25519KeyPair), +} + +impl KeyPairInner { + fn generate(kp_type: &KeyPairType) -> Self { + match kp_type { + KeyPairType::Ed25519 => Self::Ed25519(Ed25519KeyPair::generate()), + } + } +} + +impl KeyPairExt for KeyPairInner { + fn sign(&self, msg: &[u8]) -> Vec { + match self { + KeyPairInner::Ed25519(kp) => kp.sign(msg), + } + } + + fn public(&self) -> PublicKey { + match self { + KeyPairInner::Ed25519(kp) => kp.public(), + } + } + + fn secret(&self) -> SecretKey { + match self { + KeyPairInner::Ed25519(kp) => kp.secret(), + } + } +} + +#[derive(Clone)] +struct Ed25519KeyPair(ed25519_dalek::SigningKey); + +impl Ed25519KeyPair { + fn generate() -> Self { + Self(ed25519_dalek::SigningKey::generate(&mut OsRng)) + } +} + +impl KeyPairExt for Ed25519KeyPair { + fn sign(&self, msg: &[u8]) -> Vec { + self.0.sign(msg).to_bytes().to_vec() + } + + fn public(&self) -> PublicKey { + PublicKey(PublicKeyInner::Ed25519(Ed25519PublicKey( + self.0.verifying_key(), + ))) + } + + fn secret(&self) -> SecretKey { + SecretKey(self.0.to_bytes().to_vec()) + } +} + +/// An extension trait, adding essential methods to all [`PublicKey`] types. +trait PublicKeyExt { + fn as_bytes(&self) -> &[u8]; + + /// Verify a signature on a message with this public key. + fn verify(&self, msg: &[u8], signature: &[u8]) -> Result<()>; +} + +enum PublicKeyInner { + Ed25519(Ed25519PublicKey), +} + +impl PublicKeyInner { + pub fn from_bytes(kp_type: &KeyPairType, pk: &[u8]) -> Result { + match kp_type { + KeyPairType::Ed25519 => Ok(Self::Ed25519(Ed25519PublicKey::from_bytes(pk)?)), + } + } +} + +impl PublicKeyExt for PublicKeyInner { + fn as_bytes(&self) -> &[u8] { + match self { + Self::Ed25519(pk) => pk.as_bytes(), + } + } + + fn verify(&self, msg: &[u8], signature: &[u8]) -> Result<()> { + match self { + Self::Ed25519(pk) => pk.verify(msg, signature), + } + } +} + +struct Ed25519PublicKey(ed25519_dalek::VerifyingKey); + +impl Ed25519PublicKey { + pub fn from_bytes(pk: &[u8]) -> Result { + let pk_bytes: [u8; 32] = pk + .try_into() + .map_err(|_| Error::TryInto("Failed to convert slice to [u8; 32]"))?; + + Ok(Self(ed25519_dalek::VerifyingKey::from_bytes(&pk_bytes)?)) + } +} + +impl PublicKeyExt for Ed25519PublicKey { + fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() + } + + fn verify(&self, msg: &[u8], signature: &[u8]) -> Result<()> { + let sig_bytes: [u8; 64] = signature + .try_into() + .map_err(|_| Error::TryInto("Failed to convert slice to [u8; 64]"))?; + self.0 + .verify(msg, &ed25519_dalek::Signature::from_bytes(&sig_bytes))?; + Ok(()) + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index 67e6610..276ed89 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,19 +1,22 @@ /// A set of helper tools and functions. -pub mod utils; +pub mod util; /// A module containing async utilities that work with the /// [`smol`](https://github.com/smol-rs/smol) async runtime. -pub mod async_utils; +pub mod async_util; /// Represents karyons's Core Error. pub mod error; -/// [`event::EventSys`] Implementation +/// [`event::EventSys`] implementation. pub mod event; /// A simple publish-subscribe system [`Read More`](./pubsub/struct.Publisher.html) pub mod pubsub; +/// A cryptographic key pair +pub mod key_pair; + use smol::Executor as SmolEx; use std::sync::Arc; diff --git a/core/src/pubsub.rs b/core/src/pubsub.rs index 4cc0ab7..306d42f 100644 --- a/core/src/pubsub.rs +++ b/core/src/pubsub.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use log::error; use smol::lock::Mutex; -use crate::{utils::random_16, Result}; +use crate::{util::random_16, Result}; pub type ArcPublisher = Arc>; pub type SubscriptionID = u16; diff --git a/core/src/util/decode.rs b/core/src/util/decode.rs new file mode 100644 index 0000000..a8a6522 --- /dev/null +++ b/core/src/util/decode.rs @@ -0,0 +1,10 @@ +use bincode::Decode; + +use crate::Result; + +/// Decodes a given type `T` from the given slice. returns the decoded value +/// along with the number of bytes read. +pub fn decode(src: &[u8]) -> Result<(T, usize)> { + let (result, bytes_read) = bincode::decode_from_slice(src, bincode::config::standard())?; + Ok((result, bytes_read)) +} diff --git a/core/src/util/encode.rs b/core/src/util/encode.rs new file mode 100644 index 0000000..7d1061b --- /dev/null +++ b/core/src/util/encode.rs @@ -0,0 +1,15 @@ +use bincode::Encode; + +use crate::Result; + +/// Encode the given type `T` into a `Vec`. +pub fn encode(msg: &T) -> Result> { + let vec = bincode::encode_to_vec(msg, bincode::config::standard())?; + Ok(vec) +} + +/// Encode the given type `T` into the given slice.. +pub fn encode_into_slice(msg: &T, dst: &mut [u8]) -> Result<()> { + bincode::encode_into_slice(msg, dst, bincode::config::standard())?; + Ok(()) +} diff --git a/core/src/util/mod.rs b/core/src/util/mod.rs new file mode 100644 index 0000000..a3c3f50 --- /dev/null +++ b/core/src/util/mod.rs @@ -0,0 +1,19 @@ +mod decode; +mod encode; +mod path; + +pub use decode::decode; +pub use encode::{encode, encode_into_slice}; +pub use path::{home_dir, tilde_expand}; + +use rand::{rngs::OsRng, Rng}; + +/// Generates and returns a random u32 using `rand::rngs::OsRng`. +pub fn random_32() -> u32 { + OsRng.gen() +} + +/// Generates and returns a random u16 using `rand::rngs::OsRng`. +pub fn random_16() -> u16 { + OsRng.gen() +} diff --git a/core/src/util/path.rs b/core/src/util/path.rs new file mode 100644 index 0000000..2cd900a --- /dev/null +++ b/core/src/util/path.rs @@ -0,0 +1,39 @@ +use std::path::PathBuf; + +use crate::{error::Error, Result}; + +/// Returns the user's home directory as a `PathBuf`. +#[allow(dead_code)] +pub fn home_dir() -> Result { + dirs::home_dir().ok_or(Error::PathNotFound("Home dir not found")) +} + +/// Expands a tilde (~) in a path and returns the expanded `PathBuf`. +#[allow(dead_code)] +pub fn tilde_expand(path: &str) -> Result { + match path { + "~" => home_dir(), + p if p.starts_with("~/") => Ok(home_dir()?.join(&path[2..])), + _ => Ok(PathBuf::from(path)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tilde_expand() { + let path = "~/src"; + let expanded_path = dirs::home_dir().unwrap().join("src"); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + + let path = "~"; + let expanded_path = dirs::home_dir().unwrap(); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + + let path = ""; + let expanded_path = PathBuf::from(""); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + } +} diff --git a/core/src/utils/decode.rs b/core/src/utils/decode.rs deleted file mode 100644 index a8a6522..0000000 --- a/core/src/utils/decode.rs +++ /dev/null @@ -1,10 +0,0 @@ -use bincode::Decode; - -use crate::Result; - -/// Decodes a given type `T` from the given slice. returns the decoded value -/// along with the number of bytes read. -pub fn decode(src: &[u8]) -> Result<(T, usize)> { - let (result, bytes_read) = bincode::decode_from_slice(src, bincode::config::standard())?; - Ok((result, bytes_read)) -} diff --git a/core/src/utils/encode.rs b/core/src/utils/encode.rs deleted file mode 100644 index 7d1061b..0000000 --- a/core/src/utils/encode.rs +++ /dev/null @@ -1,15 +0,0 @@ -use bincode::Encode; - -use crate::Result; - -/// Encode the given type `T` into a `Vec`. -pub fn encode(msg: &T) -> Result> { - let vec = bincode::encode_to_vec(msg, bincode::config::standard())?; - Ok(vec) -} - -/// Encode the given type `T` into the given slice.. -pub fn encode_into_slice(msg: &T, dst: &mut [u8]) -> Result<()> { - bincode::encode_into_slice(msg, dst, bincode::config::standard())?; - Ok(()) -} diff --git a/core/src/utils/mod.rs b/core/src/utils/mod.rs deleted file mode 100644 index a3c3f50..0000000 --- a/core/src/utils/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -mod decode; -mod encode; -mod path; - -pub use decode::decode; -pub use encode::{encode, encode_into_slice}; -pub use path::{home_dir, tilde_expand}; - -use rand::{rngs::OsRng, Rng}; - -/// Generates and returns a random u32 using `rand::rngs::OsRng`. -pub fn random_32() -> u32 { - OsRng.gen() -} - -/// Generates and returns a random u16 using `rand::rngs::OsRng`. -pub fn random_16() -> u16 { - OsRng.gen() -} diff --git a/core/src/utils/path.rs b/core/src/utils/path.rs deleted file mode 100644 index 2cd900a..0000000 --- a/core/src/utils/path.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::path::PathBuf; - -use crate::{error::Error, Result}; - -/// Returns the user's home directory as a `PathBuf`. -#[allow(dead_code)] -pub fn home_dir() -> Result { - dirs::home_dir().ok_or(Error::PathNotFound("Home dir not found")) -} - -/// Expands a tilde (~) in a path and returns the expanded `PathBuf`. -#[allow(dead_code)] -pub fn tilde_expand(path: &str) -> Result { - match path { - "~" => home_dir(), - p if p.starts_with("~/") => Ok(home_dir()?.join(&path[2..])), - _ => Ok(PathBuf::from(path)), - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_tilde_expand() { - let path = "~/src"; - let expanded_path = dirs::home_dir().unwrap().join("src"); - assert_eq!(tilde_expand(path).unwrap(), expanded_path); - - let path = "~"; - let expanded_path = dirs::home_dir().unwrap(); - assert_eq!(tilde_expand(path).unwrap(), expanded_path); - - let path = ""; - let expanded_path = PathBuf::from(""); - assert_eq!(tilde_expand(path).unwrap(), expanded_path); - } -} -- cgit v1.2.3