aboutsummaryrefslogtreecommitdiff
path: root/core/src/async_util
diff options
context:
space:
mode:
authorhozan23 <hozan23@proton.me>2023-11-28 22:41:33 +0300
committerhozan23 <hozan23@proton.me>2023-11-28 22:41:33 +0300
commit98a1de91a2dae06323558422c239e5a45fc86e7b (patch)
tree38c640248824fcb3b4ca5ba12df47c13ef26ccda /core/src/async_util
parentca2a5f8bbb6983d9555abd10eaaf86950b794957 (diff)
implement TLS for inbound and outbound connections
Diffstat (limited to 'core/src/async_util')
-rw-r--r--core/src/async_util/backoff.rs115
-rw-r--r--core/src/async_util/condvar.rs387
-rw-r--r--core/src/async_util/condwait.rs133
-rw-r--r--core/src/async_util/mod.rs13
-rw-r--r--core/src/async_util/select.rs99
-rw-r--r--core/src/async_util/task_group.rs194
-rw-r--r--core/src/async_util/timeout.rs52
7 files changed, 993 insertions, 0 deletions
diff --git a/core/src/async_util/backoff.rs b/core/src/async_util/backoff.rs
new file mode 100644
index 0000000..a231229
--- /dev/null
+++ b/core/src/async_util/backoff.rs
@@ -0,0 +1,115 @@
+use std::{
+ cmp::min,
+ sync::atomic::{AtomicBool, AtomicU32, Ordering},
+ time::Duration,
+};
+
+use smol::Timer;
+
+/// Exponential backoff
+/// <https://en.wikipedia.org/wiki/Exponential_backoff>
+///
+/// # Examples
+///
+/// ```
+/// use karyons_core::async_util::Backoff;
+///
+/// async {
+/// let backoff = Backoff::new(300, 3000);
+///
+/// loop {
+/// backoff.sleep().await;
+///
+/// // do something
+/// break;
+/// }
+///
+/// backoff.reset();
+///
+/// // ....
+/// };
+///
+/// ```
+///
+pub struct Backoff {
+ /// The base delay in milliseconds for the initial retry.
+ base_delay: u64,
+ /// The max delay in milliseconds allowed for a retry.
+ max_delay: u64,
+ /// Atomic counter
+ retries: AtomicU32,
+ /// Stop flag
+ stop: AtomicBool,
+}
+
+impl Backoff {
+ /// Creates a new Backoff.
+ pub fn new(base_delay: u64, max_delay: u64) -> Self {
+ Self {
+ base_delay,
+ max_delay,
+ retries: AtomicU32::new(0),
+ stop: AtomicBool::new(false),
+ }
+ }
+
+ /// Sleep based on the current retry count and delay values.
+ /// Retruns the delay value.
+ pub async fn sleep(&self) -> u64 {
+ if self.stop.load(Ordering::SeqCst) {
+ Timer::after(Duration::from_millis(self.max_delay)).await;
+ return self.max_delay;
+ }
+
+ let retries = self.retries.load(Ordering::SeqCst);
+ let delay = self.base_delay * (2_u64).pow(retries);
+ let delay = min(delay, self.max_delay);
+
+ if delay == self.max_delay {
+ self.stop.store(true, Ordering::SeqCst);
+ }
+
+ self.retries.store(retries + 1, Ordering::SeqCst);
+
+ Timer::after(Duration::from_millis(delay)).await;
+ delay
+ }
+
+ /// Reset the retry counter to 0.
+ pub fn reset(&self) {
+ self.retries.store(0, Ordering::SeqCst);
+ self.stop.store(false, Ordering::SeqCst);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::sync::Arc;
+
+ #[test]
+ fn test_backoff() {
+ smol::block_on(async move {
+ let backoff = Arc::new(Backoff::new(5, 15));
+ let backoff_c = backoff.clone();
+ smol::spawn(async move {
+ let delay = backoff_c.sleep().await;
+ assert_eq!(delay, 5);
+
+ let delay = backoff_c.sleep().await;
+ assert_eq!(delay, 10);
+
+ let delay = backoff_c.sleep().await;
+ assert_eq!(delay, 15);
+ })
+ .await;
+
+ smol::spawn(async move {
+ backoff.reset();
+ let delay = backoff.sleep().await;
+ assert_eq!(delay, 5);
+ })
+ .await;
+ });
+ }
+}
diff --git a/core/src/async_util/condvar.rs b/core/src/async_util/condvar.rs
new file mode 100644
index 0000000..7396d0d
--- /dev/null
+++ b/core/src/async_util/condvar.rs
@@ -0,0 +1,387 @@
+use std::{
+ collections::HashMap,
+ future::Future,
+ pin::Pin,
+ sync::Mutex,
+ task::{Context, Poll, Waker},
+};
+
+use smol::lock::MutexGuard;
+
+use crate::util::random_16;
+
+/// CondVar is an async version of <https://doc.rust-lang.org/std/sync/struct.Condvar.html>
+///
+/// # Example
+///
+///```
+/// use std::sync::Arc;
+///
+/// use smol::lock::Mutex;
+///
+/// use karyons_core::async_util::CondVar;
+///
+/// async {
+///
+/// let val = Arc::new(Mutex::new(false));
+/// let condvar = Arc::new(CondVar::new());
+///
+/// let val_cloned = val.clone();
+/// let condvar_cloned = condvar.clone();
+/// smol::spawn(async move {
+/// let mut val = val_cloned.lock().await;
+///
+/// // While the boolean flag is false, wait for a signal.
+/// while !*val {
+/// val = condvar_cloned.wait(val).await;
+/// }
+///
+/// // ...
+/// });
+///
+/// let condvar_cloned = condvar.clone();
+/// smol::spawn(async move {
+/// let mut val = val.lock().await;
+///
+/// // While the boolean flag is false, wait for a signal.
+/// while !*val {
+/// val = condvar_cloned.wait(val).await;
+/// }
+///
+/// // ...
+/// });
+///
+/// // Wake up all waiting tasks on this condvar
+/// condvar.broadcast();
+/// };
+///
+/// ```
+
+pub struct CondVar {
+ inner: Mutex<Wakers>,
+}
+
+impl CondVar {
+ /// Creates a new CondVar
+ pub fn new() -> Self {
+ Self {
+ inner: Mutex::new(Wakers::new()),
+ }
+ }
+
+ /// Blocks the current task until this condition variable receives a notification.
+ pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
+ let m = MutexGuard::source(&g);
+
+ CondVarAwait::new(self, g).await;
+
+ m.lock().await
+ }
+
+ /// Wakes up one blocked task waiting on this condvar.
+ pub fn signal(&self) {
+ self.inner.lock().unwrap().wake(true);
+ }
+
+ /// Wakes up all blocked tasks waiting on this condvar.
+ pub fn broadcast(&self) {
+ self.inner.lock().unwrap().wake(false);
+ }
+}
+
+impl Default for CondVar {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+struct CondVarAwait<'a, T> {
+ id: Option<u16>,
+ condvar: &'a CondVar,
+ guard: Option<MutexGuard<'a, T>>,
+}
+
+impl<'a, T> CondVarAwait<'a, T> {
+ fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self {
+ Self {
+ condvar,
+ guard: Some(guard),
+ id: None,
+ }
+ }
+}
+
+impl<'a, T> Future for CondVarAwait<'a, T> {
+ type Output = ();
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let mut inner = self.condvar.inner.lock().unwrap();
+
+ match self.guard.take() {
+ Some(_) => {
+ // the first pooll will release the Mutexguard
+ self.id = Some(inner.put(Some(cx.waker().clone())));
+ Poll::Pending
+ }
+ None => {
+ // Return Ready if it has already been polled and removed
+ // from the waker list.
+ if self.id.is_none() {
+ return Poll::Ready(());
+ }
+
+ let i = self.id.as_ref().unwrap();
+ match inner.wakers.get_mut(i).unwrap() {
+ Some(wk) => {
+ // This will prevent cloning again
+ if !wk.will_wake(cx.waker()) {
+ *wk = cx.waker().clone();
+ }
+ Poll::Pending
+ }
+ None => {
+ inner.delete(i);
+ self.id = None;
+ Poll::Ready(())
+ }
+ }
+ }
+ }
+ }
+}
+
+impl<'a, T> Drop for CondVarAwait<'a, T> {
+ fn drop(&mut self) {
+ if let Some(id) = self.id {
+ let mut inner = self.condvar.inner.lock().unwrap();
+ if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() {
+ wk.wake()
+ }
+ }
+ }
+}
+
+/// Wakers is a helper struct to store the task wakers
+struct Wakers {
+ wakers: HashMap<u16, Option<Waker>>,
+}
+
+impl Wakers {
+ fn new() -> Self {
+ Self {
+ wakers: HashMap::new(),
+ }
+ }
+
+ fn put(&mut self, waker: Option<Waker>) -> u16 {
+ let mut id: u16;
+
+ id = random_16();
+ while self.wakers.contains_key(&id) {
+ id = random_16();
+ }
+
+ self.wakers.insert(id, waker);
+ id
+ }
+
+ fn delete(&mut self, id: &u16) -> Option<Option<Waker>> {
+ self.wakers.remove(id)
+ }
+
+ fn wake(&mut self, signal: bool) {
+ for (_, wk) in self.wakers.iter_mut() {
+ match wk.take() {
+ Some(w) => {
+ w.wake();
+ if signal {
+ break;
+ }
+ }
+ None => continue,
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use smol::lock::Mutex;
+ use std::{
+ collections::VecDeque,
+ sync::{
+ atomic::{AtomicUsize, Ordering},
+ Arc,
+ },
+ };
+
+ // The tests below demonstrate a solution to a problem in the Wikipedia
+ // explanation of condition variables:
+ // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem.
+
+ struct Queue {
+ items: VecDeque<String>,
+ max_len: usize,
+ }
+ impl Queue {
+ fn new(max_len: usize) -> Self {
+ Self {
+ items: VecDeque::new(),
+ max_len,
+ }
+ }
+
+ fn is_full(&self) -> bool {
+ self.items.len() == self.max_len
+ }
+
+ fn is_empty(&self) -> bool {
+ self.items.is_empty()
+ }
+ }
+
+ #[test]
+ fn test_condvar_signal() {
+ smol::block_on(async {
+ let number_of_tasks = 30;
+
+ let queue = Arc::new(Mutex::new(Queue::new(5)));
+ let condvar_full = Arc::new(CondVar::new());
+ let condvar_empty = Arc::new(CondVar::new());
+
+ let queue_cloned = queue.clone();
+ let condvar_full_cloned = condvar_full.clone();
+ let condvar_empty_cloned = condvar_empty.clone();
+
+ let _producer1 = smol::spawn(async move {
+ for i in 1..number_of_tasks {
+ // Lock queue mtuex
+ let mut queue = queue_cloned.lock().await;
+
+ // Check if the queue is non-full
+ while queue.is_full() {
+ // Release queue mutex and sleep
+ queue = condvar_full_cloned.wait(queue).await;
+ }
+
+ queue.items.push_back(format!("task {i}"));
+
+ // Wake up the consumer
+ condvar_empty_cloned.signal();
+ }
+ });
+
+ let queue_cloned = queue.clone();
+ let task_consumed = Arc::new(AtomicUsize::new(0));
+ let task_consumed_ = task_consumed.clone();
+ let consumer = smol::spawn(async move {
+ for _ in 1..number_of_tasks {
+ // Lock queue mtuex
+ let mut queue = queue_cloned.lock().await;
+
+ // Check if the queue is non-empty
+ while queue.is_empty() {
+ // Release queue mutex and sleep
+ queue = condvar_empty.wait(queue).await;
+ }
+
+ let _ = queue.items.pop_front().unwrap();
+
+ task_consumed_.fetch_add(1, Ordering::Relaxed);
+
+ // Do something
+
+ // Wake up the producer
+ condvar_full.signal();
+ }
+ });
+
+ consumer.await;
+ assert!(queue.lock().await.is_empty());
+ assert_eq!(task_consumed.load(Ordering::Relaxed), 29);
+ });
+ }
+
+ #[test]
+ fn test_condvar_broadcast() {
+ smol::block_on(async {
+ let tasks = 30;
+
+ let queue = Arc::new(Mutex::new(Queue::new(5)));
+ let condvar = Arc::new(CondVar::new());
+
+ let queue_cloned = queue.clone();
+ let condvar_cloned = condvar.clone();
+ let _producer1 = smol::spawn(async move {
+ for i in 1..tasks {
+ // Lock queue mtuex
+ let mut queue = queue_cloned.lock().await;
+
+ // Check if the queue is non-full
+ while queue.is_full() {
+ // Release queue mutex and sleep
+ queue = condvar_cloned.wait(queue).await;
+ }
+
+ queue.items.push_back(format!("producer1: task {i}"));
+
+ // Wake up all producer and consumer tasks
+ condvar_cloned.broadcast();
+ }
+ });
+
+ let queue_cloned = queue.clone();
+ let condvar_cloned = condvar.clone();
+ let _producer2 = smol::spawn(async move {
+ for i in 1..tasks {
+ // Lock queue mtuex
+ let mut queue = queue_cloned.lock().await;
+
+ // Check if the queue is non-full
+ while queue.is_full() {
+ // Release queue mutex and sleep
+ queue = condvar_cloned.wait(queue).await;
+ }
+
+ queue.items.push_back(format!("producer2: task {i}"));
+
+ // Wake up all producer and consumer tasks
+ condvar_cloned.broadcast();
+ }
+ });
+
+ let queue_cloned = queue.clone();
+ let task_consumed = Arc::new(AtomicUsize::new(0));
+ let task_consumed_ = task_consumed.clone();
+
+ let consumer = smol::spawn(async move {
+ for _ in 1..((tasks * 2) - 1) {
+ {
+ // Lock queue mutex
+ let mut queue = queue_cloned.lock().await;
+
+ // Check if the queue is non-empty
+ while queue.is_empty() {
+ // Release queue mutex and sleep
+ queue = condvar.wait(queue).await;
+ }
+
+ let _ = queue.items.pop_front().unwrap();
+
+ task_consumed_.fetch_add(1, Ordering::Relaxed);
+
+ // Do something
+
+ // Wake up all producer and consumer tasks
+ condvar.broadcast();
+ }
+ }
+ });
+
+ consumer.await;
+ assert!(queue.lock().await.is_empty());
+ assert_eq!(task_consumed.load(Ordering::Relaxed), 58);
+ });
+ }
+}
diff --git a/core/src/async_util/condwait.rs b/core/src/async_util/condwait.rs
new file mode 100644
index 0000000..cd4b269
--- /dev/null
+++ b/core/src/async_util/condwait.rs
@@ -0,0 +1,133 @@
+use smol::lock::Mutex;
+
+use super::CondVar;
+
+/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag.
+///
+/// # Example
+///
+///```
+/// use std::sync::Arc;
+///
+/// use karyons_core::async_util::CondWait;
+///
+/// async {
+/// let cond_wait = Arc::new(CondWait::new());
+/// let cond_wait_cloned = cond_wait.clone();
+/// let task = smol::spawn(async move {
+/// cond_wait_cloned.wait().await;
+/// // ...
+/// });
+///
+/// cond_wait.signal().await;
+/// };
+///
+/// ```
+///
+pub struct CondWait {
+ /// The CondVar
+ condvar: CondVar,
+ /// Boolean flag
+ w: Mutex<bool>,
+}
+
+impl CondWait {
+ /// Creates a new CondWait.
+ pub fn new() -> Self {
+ Self {
+ condvar: CondVar::new(),
+ w: Mutex::new(false),
+ }
+ }
+
+ /// Waits for a signal or broadcast.
+ pub async fn wait(&self) {
+ let mut w = self.w.lock().await;
+
+ // While the boolean flag is false, wait for a signal.
+ while !*w {
+ w = self.condvar.wait(w).await;
+ }
+ }
+
+ /// Signal a waiting task.
+ pub async fn signal(&self) {
+ *self.w.lock().await = true;
+ self.condvar.signal();
+ }
+
+ /// Signal all waiting tasks.
+ pub async fn broadcast(&self) {
+ *self.w.lock().await = true;
+ self.condvar.broadcast();
+ }
+
+ /// Reset the boolean flag value to false.
+ pub async fn reset(&self) {
+ *self.w.lock().await = false;
+ }
+}
+
+impl Default for CondWait {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::sync::{
+ atomic::{AtomicUsize, Ordering},
+ Arc,
+ };
+
+ #[test]
+ fn test_cond_wait() {
+ smol::block_on(async {
+ let cond_wait = Arc::new(CondWait::new());
+ let count = Arc::new(AtomicUsize::new(0));
+
+ let cond_wait_cloned = cond_wait.clone();
+ let count_cloned = count.clone();
+ let task = smol::spawn(async move {
+ cond_wait_cloned.wait().await;
+ count_cloned.fetch_add(1, Ordering::Relaxed);
+ // do something
+ });
+
+ // Send a signal to the waiting task
+ cond_wait.signal().await;
+
+ task.await;
+
+ // Reset the boolean flag
+ cond_wait.reset().await;
+
+ assert_eq!(count.load(Ordering::Relaxed), 1);
+
+ let cond_wait_cloned = cond_wait.clone();
+ let count_cloned = count.clone();
+ let task1 = smol::spawn(async move {
+ cond_wait_cloned.wait().await;
+ count_cloned.fetch_add(1, Ordering::Relaxed);
+ // do something
+ });
+
+ let cond_wait_cloned = cond_wait.clone();
+ let count_cloned = count.clone();
+ let task2 = smol::spawn(async move {
+ cond_wait_cloned.wait().await;
+ count_cloned.fetch_add(1, Ordering::Relaxed);
+ // do something
+ });
+
+ // Broadcast a signal to all waiting tasks
+ cond_wait.broadcast().await;
+
+ task1.await;
+ task2.await;
+ assert_eq!(count.load(Ordering::Relaxed), 3);
+ });
+ }
+}
diff --git a/core/src/async_util/mod.rs b/core/src/async_util/mod.rs
new file mode 100644
index 0000000..c871bad
--- /dev/null
+++ b/core/src/async_util/mod.rs
@@ -0,0 +1,13 @@
+mod backoff;
+mod condvar;
+mod condwait;
+mod select;
+mod task_group;
+mod timeout;
+
+pub use backoff::Backoff;
+pub use condvar::CondVar;
+pub use condwait::CondWait;
+pub use select::{select, Either};
+pub use task_group::{TaskGroup, TaskResult};
+pub use timeout::timeout;
diff --git a/core/src/async_util/select.rs b/core/src/async_util/select.rs
new file mode 100644
index 0000000..8f2f7f6
--- /dev/null
+++ b/core/src/async_util/select.rs
@@ -0,0 +1,99 @@
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use pin_project_lite::pin_project;
+use smol::future::Future;
+
+/// Returns the result of the future that completes first, preferring future1
+/// if both are ready.
+///
+/// # Examples
+///
+/// ```
+/// use std::future;
+///
+/// use karyons_core::async_util::{select, Either};
+///
+/// async {
+/// let fut1 = future::pending::<String>();
+/// let fut2 = future::ready(0);
+/// let res = select(fut1, fut2).await;
+/// assert!(matches!(res, Either::Right(0)));
+/// // ....
+/// };
+///
+/// ```
+///
+pub fn select<T1, T2, F1, F2>(future1: F1, future2: F2) -> Select<F1, F2>
+where
+ F1: Future<Output = T1>,
+ F2: Future<Output = T2>,
+{
+ Select { future1, future2 }
+}
+
+pin_project! {
+ #[derive(Debug)]
+ pub struct Select<F1, F2> {
+ #[pin]
+ future1: F1,
+ #[pin]
+ future2: F2,
+ }
+}
+
+/// The return value from the [`select`] function, indicating which future
+/// completed first.
+#[derive(Debug)]
+pub enum Either<T1, T2> {
+ Left(T1),
+ Right(T2),
+}
+
+// Implement the Future trait for the Select struct.
+impl<T1, T2, F1, F2> Future for Select<F1, F2>
+where
+ F1: Future<Output = T1>,
+ F2: Future<Output = T2>,
+{
+ type Output = Either<T1, T2>;
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let this = self.project();
+
+ if let Poll::Ready(t) = this.future1.poll(cx) {
+ return Poll::Ready(Either::Left(t));
+ }
+
+ if let Poll::Ready(t) = this.future2.poll(cx) {
+ return Poll::Ready(Either::Right(t));
+ }
+
+ Poll::Pending
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{select, Either};
+ use smol::Timer;
+ use std::future;
+
+ #[test]
+ fn test_async_select() {
+ smol::block_on(async move {
+ let fut = select(Timer::never(), future::ready(0 as u32)).await;
+ assert!(matches!(fut, Either::Right(0)));
+
+ let fut1 = future::pending::<String>();
+ let fut2 = future::ready(0);
+ let res = select(fut1, fut2).await;
+ assert!(matches!(res, Either::Right(0)));
+
+ let fut1 = future::ready(0);
+ let fut2 = future::pending::<String>();
+ let res = select(fut1, fut2).await;
+ assert!(matches!(res, Either::Left(_)));
+ });
+ }
+}
diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs
new file mode 100644
index 0000000..3fc0cb7
--- /dev/null
+++ b/core/src/async_util/task_group.rs
@@ -0,0 +1,194 @@
+use std::{future::Future, sync::Arc, sync::Mutex};
+
+use async_task::FallibleTask;
+
+use crate::Executor;
+
+use super::{select, CondWait, Either};
+
+/// TaskGroup is a group of spawned tasks.
+///
+/// # Example
+///
+/// ```
+///
+/// use std::sync::Arc;
+///
+/// use karyons_core::async_util::TaskGroup;
+///
+/// async {
+///
+/// let ex = Arc::new(smol::Executor::new());
+/// let group = TaskGroup::new(ex);
+///
+/// group.spawn(smol::Timer::never(), |_| async {});
+///
+/// group.cancel().await;
+///
+/// };
+///
+/// ```
+///
+pub struct TaskGroup<'a> {
+ tasks: Mutex<Vec<TaskHandler>>,
+ stop_signal: Arc<CondWait>,
+ executor: Executor<'a>,
+}
+
+impl<'a> TaskGroup<'a> {
+ /// Creates a new task group
+ pub fn new(executor: Executor<'a>) -> Self {
+ Self {
+ tasks: Mutex::new(Vec::new()),
+ stop_signal: Arc::new(CondWait::new()),
+ executor,
+ }
+ }
+
+ /// Spawns a new task and calls the callback after it has completed
+ /// or been canceled. The callback will have the `TaskResult` as a
+ /// parameter, indicating whether the task completed or was canceled.
+ pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF)
+ where
+ T: Send + Sync + 'a,
+ Fut: Future<Output = T> + Send + 'a,
+ CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a,
+ CallbackFut: Future<Output = ()> + Send + 'a,
+ {
+ let task = TaskHandler::new(
+ self.executor.clone(),
+ fut,
+ callback,
+ self.stop_signal.clone(),
+ );
+ self.tasks.lock().unwrap().push(task);
+ }
+
+ /// Checks if the task group is empty.
+ pub fn is_empty(&self) -> bool {
+ self.tasks.lock().unwrap().is_empty()
+ }
+
+ /// Get the number of the tasks in the group.
+ pub fn len(&self) -> usize {
+ self.tasks.lock().unwrap().len()
+ }
+
+ /// Cancels all tasks in the group.
+ pub async fn cancel(&self) {
+ self.stop_signal.broadcast().await;
+
+ loop {
+ let task = self.tasks.lock().unwrap().pop();
+ if let Some(t) = task {
+ t.cancel().await
+ } else {
+ break;
+ }
+ }
+ }
+}
+
+/// The result of a spawned task.
+#[derive(Debug)]
+pub enum TaskResult<T> {
+ Completed(T),
+ Cancelled,
+}
+
+impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ match self {
+ TaskResult::Cancelled => write!(f, "Task cancelled"),
+ TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res),
+ }
+ }
+}
+
+/// TaskHandler
+pub struct TaskHandler {
+ task: FallibleTask<()>,
+ cancel_flag: Arc<CondWait>,
+}
+
+impl<'a> TaskHandler {
+ /// Creates a new task handle
+ fn new<T, Fut, CallbackF, CallbackFut>(
+ ex: Executor<'a>,
+ fut: Fut,
+ callback: CallbackF,
+ stop_signal: Arc<CondWait>,
+ ) -> TaskHandler
+ where
+ T: Send + Sync + 'a,
+ Fut: Future<Output = T> + Send + 'a,
+ CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a,
+ CallbackFut: Future<Output = ()> + Send + 'a,
+ {
+ let cancel_flag = Arc::new(CondWait::new());
+ let cancel_flag_c = cancel_flag.clone();
+ let task = ex
+ .spawn(async move {
+ //start_signal.signal().await;
+ // Waits for either the stop signal or the task to complete.
+ let result = select(stop_signal.wait(), fut).await;
+
+ let result = match result {
+ Either::Left(_) => TaskResult::Cancelled,
+ Either::Right(res) => TaskResult::Completed(res),
+ };
+
+ // Call the callback with the result.
+ callback(result).await;
+
+ cancel_flag_c.signal().await;
+ })
+ .fallible();
+
+ TaskHandler { task, cancel_flag }
+ }
+
+ /// Cancels the task.
+ async fn cancel(self) {
+ self.cancel_flag.wait().await;
+ self.task.cancel().await;
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::{future, sync::Arc};
+
+ #[test]
+ fn test_task_group() {
+ let ex = Arc::new(smol::Executor::new());
+ smol::block_on(ex.clone().run(async move {
+ let group = Arc::new(TaskGroup::new(ex));
+
+ group.spawn(future::ready(0), |res| async move {
+ assert!(matches!(res, TaskResult::Completed(0)));
+ });
+
+ group.spawn(future::pending::<()>(), |res| async move {
+ assert!(matches!(res, TaskResult::Cancelled));
+ });
+
+ let groupc = group.clone();
+ group.spawn(
+ async move {
+ groupc.spawn(future::pending::<()>(), |res| async move {
+ assert!(matches!(res, TaskResult::Cancelled));
+ });
+ },
+ |res| async move {
+ assert!(matches!(res, TaskResult::Completed(_)));
+ },
+ );
+
+ // Do something
+ smol::Timer::after(std::time::Duration::from_millis(50)).await;
+ group.cancel().await;
+ }));
+ }
+}
diff --git a/core/src/async_util/timeout.rs b/core/src/async_util/timeout.rs
new file mode 100644
index 0000000..6ab35c4
--- /dev/null
+++ b/core/src/async_util/timeout.rs
@@ -0,0 +1,52 @@
+use std::{future::Future, time::Duration};
+
+use smol::Timer;
+
+use super::{select, Either};
+use crate::{error::Error, Result};
+
+/// Waits for a future to complete or times out if it exceeds a specified
+/// duration.
+///
+/// # Example
+///
+/// ```
+/// use std::{future, time::Duration};
+///
+/// use karyons_core::async_util::timeout;
+///
+/// async {
+/// let fut = future::pending::<()>();
+/// assert!(timeout(Duration::from_millis(100), fut).await.is_err());
+/// };
+///
+/// ```
+///
+pub async fn timeout<T, F>(delay: Duration, future1: F) -> Result<T>
+where
+ F: Future<Output = T>,
+{
+ let result = select(Timer::after(delay), future1).await;
+
+ match result {
+ Either::Left(_) => Err(Error::Timeout),
+ Either::Right(res) => Ok(res),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::{future, time::Duration};
+
+ #[test]
+ fn test_timeout() {
+ smol::block_on(async move {
+ let fut = future::pending::<()>();
+ assert!(timeout(Duration::from_millis(10), fut).await.is_err());
+
+ let fut = smol::Timer::after(Duration::from_millis(10));
+ assert!(timeout(Duration::from_millis(50), fut).await.is_ok())
+ });
+ }
+}