diff options
author | hozan23 <hozan23@proton.me> | 2023-11-28 22:41:33 +0300 |
---|---|---|
committer | hozan23 <hozan23@proton.me> | 2023-11-28 22:41:33 +0300 |
commit | 98a1de91a2dae06323558422c239e5a45fc86e7b (patch) | |
tree | 38c640248824fcb3b4ca5ba12df47c13ef26ccda /core/src/async_util | |
parent | ca2a5f8bbb6983d9555abd10eaaf86950b794957 (diff) |
implement TLS for inbound and outbound connections
Diffstat (limited to 'core/src/async_util')
-rw-r--r-- | core/src/async_util/backoff.rs | 115 | ||||
-rw-r--r-- | core/src/async_util/condvar.rs | 387 | ||||
-rw-r--r-- | core/src/async_util/condwait.rs | 133 | ||||
-rw-r--r-- | core/src/async_util/mod.rs | 13 | ||||
-rw-r--r-- | core/src/async_util/select.rs | 99 | ||||
-rw-r--r-- | core/src/async_util/task_group.rs | 194 | ||||
-rw-r--r-- | core/src/async_util/timeout.rs | 52 |
7 files changed, 993 insertions, 0 deletions
diff --git a/core/src/async_util/backoff.rs b/core/src/async_util/backoff.rs new file mode 100644 index 0000000..a231229 --- /dev/null +++ b/core/src/async_util/backoff.rs @@ -0,0 +1,115 @@ +use std::{ + cmp::min, + sync::atomic::{AtomicBool, AtomicU32, Ordering}, + time::Duration, +}; + +use smol::Timer; + +/// Exponential backoff +/// <https://en.wikipedia.org/wiki/Exponential_backoff> +/// +/// # Examples +/// +/// ``` +/// use karyons_core::async_util::Backoff; +/// +/// async { +/// let backoff = Backoff::new(300, 3000); +/// +/// loop { +/// backoff.sleep().await; +/// +/// // do something +/// break; +/// } +/// +/// backoff.reset(); +/// +/// // .... +/// }; +/// +/// ``` +/// +pub struct Backoff { + /// The base delay in milliseconds for the initial retry. + base_delay: u64, + /// The max delay in milliseconds allowed for a retry. + max_delay: u64, + /// Atomic counter + retries: AtomicU32, + /// Stop flag + stop: AtomicBool, +} + +impl Backoff { + /// Creates a new Backoff. + pub fn new(base_delay: u64, max_delay: u64) -> Self { + Self { + base_delay, + max_delay, + retries: AtomicU32::new(0), + stop: AtomicBool::new(false), + } + } + + /// Sleep based on the current retry count and delay values. + /// Retruns the delay value. + pub async fn sleep(&self) -> u64 { + if self.stop.load(Ordering::SeqCst) { + Timer::after(Duration::from_millis(self.max_delay)).await; + return self.max_delay; + } + + let retries = self.retries.load(Ordering::SeqCst); + let delay = self.base_delay * (2_u64).pow(retries); + let delay = min(delay, self.max_delay); + + if delay == self.max_delay { + self.stop.store(true, Ordering::SeqCst); + } + + self.retries.store(retries + 1, Ordering::SeqCst); + + Timer::after(Duration::from_millis(delay)).await; + delay + } + + /// Reset the retry counter to 0. + pub fn reset(&self) { + self.retries.store(0, Ordering::SeqCst); + self.stop.store(false, Ordering::SeqCst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_backoff() { + smol::block_on(async move { + let backoff = Arc::new(Backoff::new(5, 15)); + let backoff_c = backoff.clone(); + smol::spawn(async move { + let delay = backoff_c.sleep().await; + assert_eq!(delay, 5); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 10); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 15); + }) + .await; + + smol::spawn(async move { + backoff.reset(); + let delay = backoff.sleep().await; + assert_eq!(delay, 5); + }) + .await; + }); + } +} diff --git a/core/src/async_util/condvar.rs b/core/src/async_util/condvar.rs new file mode 100644 index 0000000..7396d0d --- /dev/null +++ b/core/src/async_util/condvar.rs @@ -0,0 +1,387 @@ +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + sync::Mutex, + task::{Context, Poll, Waker}, +}; + +use smol::lock::MutexGuard; + +use crate::util::random_16; + +/// CondVar is an async version of <https://doc.rust-lang.org/std/sync/struct.Condvar.html> +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use smol::lock::Mutex; +/// +/// use karyons_core::async_util::CondVar; +/// +/// async { +/// +/// let val = Arc::new(Mutex::new(false)); +/// let condvar = Arc::new(CondVar::new()); +/// +/// let val_cloned = val.clone(); +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val_cloned.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// // Wake up all waiting tasks on this condvar +/// condvar.broadcast(); +/// }; +/// +/// ``` + +pub struct CondVar { + inner: Mutex<Wakers>, +} + +impl CondVar { + /// Creates a new CondVar + pub fn new() -> Self { + Self { + inner: Mutex::new(Wakers::new()), + } + } + + /// Blocks the current task until this condition variable receives a notification. + pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { + let m = MutexGuard::source(&g); + + CondVarAwait::new(self, g).await; + + m.lock().await + } + + /// Wakes up one blocked task waiting on this condvar. + pub fn signal(&self) { + self.inner.lock().unwrap().wake(true); + } + + /// Wakes up all blocked tasks waiting on this condvar. + pub fn broadcast(&self) { + self.inner.lock().unwrap().wake(false); + } +} + +impl Default for CondVar { + fn default() -> Self { + Self::new() + } +} + +struct CondVarAwait<'a, T> { + id: Option<u16>, + condvar: &'a CondVar, + guard: Option<MutexGuard<'a, T>>, +} + +impl<'a, T> CondVarAwait<'a, T> { + fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self { + Self { + condvar, + guard: Some(guard), + id: None, + } + } +} + +impl<'a, T> Future for CondVarAwait<'a, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut inner = self.condvar.inner.lock().unwrap(); + + match self.guard.take() { + Some(_) => { + // the first pooll will release the Mutexguard + self.id = Some(inner.put(Some(cx.waker().clone()))); + Poll::Pending + } + None => { + // Return Ready if it has already been polled and removed + // from the waker list. + if self.id.is_none() { + return Poll::Ready(()); + } + + let i = self.id.as_ref().unwrap(); + match inner.wakers.get_mut(i).unwrap() { + Some(wk) => { + // This will prevent cloning again + if !wk.will_wake(cx.waker()) { + *wk = cx.waker().clone(); + } + Poll::Pending + } + None => { + inner.delete(i); + self.id = None; + Poll::Ready(()) + } + } + } + } + } +} + +impl<'a, T> Drop for CondVarAwait<'a, T> { + fn drop(&mut self) { + if let Some(id) = self.id { + let mut inner = self.condvar.inner.lock().unwrap(); + if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() { + wk.wake() + } + } + } +} + +/// Wakers is a helper struct to store the task wakers +struct Wakers { + wakers: HashMap<u16, Option<Waker>>, +} + +impl Wakers { + fn new() -> Self { + Self { + wakers: HashMap::new(), + } + } + + fn put(&mut self, waker: Option<Waker>) -> u16 { + let mut id: u16; + + id = random_16(); + while self.wakers.contains_key(&id) { + id = random_16(); + } + + self.wakers.insert(id, waker); + id + } + + fn delete(&mut self, id: &u16) -> Option<Option<Waker>> { + self.wakers.remove(id) + } + + fn wake(&mut self, signal: bool) { + for (_, wk) in self.wakers.iter_mut() { + match wk.take() { + Some(w) => { + w.wake(); + if signal { + break; + } + } + None => continue, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smol::lock::Mutex; + use std::{ + collections::VecDeque, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + }; + + // The tests below demonstrate a solution to a problem in the Wikipedia + // explanation of condition variables: + // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. + + struct Queue { + items: VecDeque<String>, + max_len: usize, + } + impl Queue { + fn new(max_len: usize) -> Self { + Self { + items: VecDeque::new(), + max_len, + } + } + + fn is_full(&self) -> bool { + self.items.len() == self.max_len + } + + fn is_empty(&self) -> bool { + self.items.is_empty() + } + } + + #[test] + fn test_condvar_signal() { + smol::block_on(async { + let number_of_tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar_full = Arc::new(CondVar::new()); + let condvar_empty = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_full_cloned = condvar_full.clone(); + let condvar_empty_cloned = condvar_empty.clone(); + + let _producer1 = smol::spawn(async move { + for i in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_full_cloned.wait(queue).await; + } + + queue.items.push_back(format!("task {i}")); + + // Wake up the consumer + condvar_empty_cloned.signal(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + let consumer = smol::spawn(async move { + for _ in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar_empty.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up the producer + condvar_full.signal(); + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 29); + }); + } + + #[test] + fn test_condvar_broadcast() { + smol::block_on(async { + let tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer1 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer1: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer2 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer2: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + + let consumer = smol::spawn(async move { + for _ in 1..((tasks * 2) - 1) { + { + // Lock queue mutex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up all producer and consumer tasks + condvar.broadcast(); + } + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 58); + }); + } +} diff --git a/core/src/async_util/condwait.rs b/core/src/async_util/condwait.rs new file mode 100644 index 0000000..cd4b269 --- /dev/null +++ b/core/src/async_util/condwait.rs @@ -0,0 +1,133 @@ +use smol::lock::Mutex; + +use super::CondVar; + +/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use karyons_core::async_util::CondWait; +/// +/// async { +/// let cond_wait = Arc::new(CondWait::new()); +/// let cond_wait_cloned = cond_wait.clone(); +/// let task = smol::spawn(async move { +/// cond_wait_cloned.wait().await; +/// // ... +/// }); +/// +/// cond_wait.signal().await; +/// }; +/// +/// ``` +/// +pub struct CondWait { + /// The CondVar + condvar: CondVar, + /// Boolean flag + w: Mutex<bool>, +} + +impl CondWait { + /// Creates a new CondWait. + pub fn new() -> Self { + Self { + condvar: CondVar::new(), + w: Mutex::new(false), + } + } + + /// Waits for a signal or broadcast. + pub async fn wait(&self) { + let mut w = self.w.lock().await; + + // While the boolean flag is false, wait for a signal. + while !*w { + w = self.condvar.wait(w).await; + } + } + + /// Signal a waiting task. + pub async fn signal(&self) { + *self.w.lock().await = true; + self.condvar.signal(); + } + + /// Signal all waiting tasks. + pub async fn broadcast(&self) { + *self.w.lock().await = true; + self.condvar.broadcast(); + } + + /// Reset the boolean flag value to false. + pub async fn reset(&self) { + *self.w.lock().await = false; + } +} + +impl Default for CondWait { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + #[test] + fn test_cond_wait() { + smol::block_on(async { + let cond_wait = Arc::new(CondWait::new()); + let count = Arc::new(AtomicUsize::new(0)); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + // Send a signal to the waiting task + cond_wait.signal().await; + + task.await; + + // Reset the boolean flag + cond_wait.reset().await; + + assert_eq!(count.load(Ordering::Relaxed), 1); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task1 = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task2 = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + // Broadcast a signal to all waiting tasks + cond_wait.broadcast().await; + + task1.await; + task2.await; + assert_eq!(count.load(Ordering::Relaxed), 3); + }); + } +} diff --git a/core/src/async_util/mod.rs b/core/src/async_util/mod.rs new file mode 100644 index 0000000..c871bad --- /dev/null +++ b/core/src/async_util/mod.rs @@ -0,0 +1,13 @@ +mod backoff; +mod condvar; +mod condwait; +mod select; +mod task_group; +mod timeout; + +pub use backoff::Backoff; +pub use condvar::CondVar; +pub use condwait::CondWait; +pub use select::{select, Either}; +pub use task_group::{TaskGroup, TaskResult}; +pub use timeout::timeout; diff --git a/core/src/async_util/select.rs b/core/src/async_util/select.rs new file mode 100644 index 0000000..8f2f7f6 --- /dev/null +++ b/core/src/async_util/select.rs @@ -0,0 +1,99 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use pin_project_lite::pin_project; +use smol::future::Future; + +/// Returns the result of the future that completes first, preferring future1 +/// if both are ready. +/// +/// # Examples +/// +/// ``` +/// use std::future; +/// +/// use karyons_core::async_util::{select, Either}; +/// +/// async { +/// let fut1 = future::pending::<String>(); +/// let fut2 = future::ready(0); +/// let res = select(fut1, fut2).await; +/// assert!(matches!(res, Either::Right(0))); +/// // .... +/// }; +/// +/// ``` +/// +pub fn select<T1, T2, F1, F2>(future1: F1, future2: F2) -> Select<F1, F2> +where + F1: Future<Output = T1>, + F2: Future<Output = T2>, +{ + Select { future1, future2 } +} + +pin_project! { + #[derive(Debug)] + pub struct Select<F1, F2> { + #[pin] + future1: F1, + #[pin] + future2: F2, + } +} + +/// The return value from the [`select`] function, indicating which future +/// completed first. +#[derive(Debug)] +pub enum Either<T1, T2> { + Left(T1), + Right(T2), +} + +// Implement the Future trait for the Select struct. +impl<T1, T2, F1, F2> Future for Select<F1, F2> +where + F1: Future<Output = T1>, + F2: Future<Output = T2>, +{ + type Output = Either<T1, T2>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + + if let Poll::Ready(t) = this.future1.poll(cx) { + return Poll::Ready(Either::Left(t)); + } + + if let Poll::Ready(t) = this.future2.poll(cx) { + return Poll::Ready(Either::Right(t)); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::{select, Either}; + use smol::Timer; + use std::future; + + #[test] + fn test_async_select() { + smol::block_on(async move { + let fut = select(Timer::never(), future::ready(0 as u32)).await; + assert!(matches!(fut, Either::Right(0))); + + let fut1 = future::pending::<String>(); + let fut2 = future::ready(0); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Right(0))); + + let fut1 = future::ready(0); + let fut2 = future::pending::<String>(); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Left(_))); + }); + } +} diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs new file mode 100644 index 0000000..3fc0cb7 --- /dev/null +++ b/core/src/async_util/task_group.rs @@ -0,0 +1,194 @@ +use std::{future::Future, sync::Arc, sync::Mutex}; + +use async_task::FallibleTask; + +use crate::Executor; + +use super::{select, CondWait, Either}; + +/// TaskGroup is a group of spawned tasks. +/// +/// # Example +/// +/// ``` +/// +/// use std::sync::Arc; +/// +/// use karyons_core::async_util::TaskGroup; +/// +/// async { +/// +/// let ex = Arc::new(smol::Executor::new()); +/// let group = TaskGroup::new(ex); +/// +/// group.spawn(smol::Timer::never(), |_| async {}); +/// +/// group.cancel().await; +/// +/// }; +/// +/// ``` +/// +pub struct TaskGroup<'a> { + tasks: Mutex<Vec<TaskHandler>>, + stop_signal: Arc<CondWait>, + executor: Executor<'a>, +} + +impl<'a> TaskGroup<'a> { + /// Creates a new task group + pub fn new(executor: Executor<'a>) -> Self { + Self { + tasks: Mutex::new(Vec::new()), + stop_signal: Arc::new(CondWait::new()), + executor, + } + } + + /// Spawns a new task and calls the callback after it has completed + /// or been canceled. The callback will have the `TaskResult` as a + /// parameter, indicating whether the task completed or was canceled. + pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF) + where + T: Send + Sync + 'a, + Fut: Future<Output = T> + Send + 'a, + CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, + CallbackFut: Future<Output = ()> + Send + 'a, + { + let task = TaskHandler::new( + self.executor.clone(), + fut, + callback, + self.stop_signal.clone(), + ); + self.tasks.lock().unwrap().push(task); + } + + /// Checks if the task group is empty. + pub fn is_empty(&self) -> bool { + self.tasks.lock().unwrap().is_empty() + } + + /// Get the number of the tasks in the group. + pub fn len(&self) -> usize { + self.tasks.lock().unwrap().len() + } + + /// Cancels all tasks in the group. + pub async fn cancel(&self) { + self.stop_signal.broadcast().await; + + loop { + let task = self.tasks.lock().unwrap().pop(); + if let Some(t) = task { + t.cancel().await + } else { + break; + } + } + } +} + +/// The result of a spawned task. +#[derive(Debug)] +pub enum TaskResult<T> { + Completed(T), + Cancelled, +} + +impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TaskResult::Cancelled => write!(f, "Task cancelled"), + TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), + } + } +} + +/// TaskHandler +pub struct TaskHandler { + task: FallibleTask<()>, + cancel_flag: Arc<CondWait>, +} + +impl<'a> TaskHandler { + /// Creates a new task handle + fn new<T, Fut, CallbackF, CallbackFut>( + ex: Executor<'a>, + fut: Fut, + callback: CallbackF, + stop_signal: Arc<CondWait>, + ) -> TaskHandler + where + T: Send + Sync + 'a, + Fut: Future<Output = T> + Send + 'a, + CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, + CallbackFut: Future<Output = ()> + Send + 'a, + { + let cancel_flag = Arc::new(CondWait::new()); + let cancel_flag_c = cancel_flag.clone(); + let task = ex + .spawn(async move { + //start_signal.signal().await; + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; + + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; + + // Call the callback with the result. + callback(result).await; + + cancel_flag_c.signal().await; + }) + .fallible(); + + TaskHandler { task, cancel_flag } + } + + /// Cancels the task. + async fn cancel(self) { + self.cancel_flag.wait().await; + self.task.cancel().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, sync::Arc}; + + #[test] + fn test_task_group() { + let ex = Arc::new(smol::Executor::new()); + smol::block_on(ex.clone().run(async move { + let group = Arc::new(TaskGroup::new(ex)); + + group.spawn(future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + group.spawn( + async move { + groupc.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + smol::Timer::after(std::time::Duration::from_millis(50)).await; + group.cancel().await; + })); + } +} diff --git a/core/src/async_util/timeout.rs b/core/src/async_util/timeout.rs new file mode 100644 index 0000000..6ab35c4 --- /dev/null +++ b/core/src/async_util/timeout.rs @@ -0,0 +1,52 @@ +use std::{future::Future, time::Duration}; + +use smol::Timer; + +use super::{select, Either}; +use crate::{error::Error, Result}; + +/// Waits for a future to complete or times out if it exceeds a specified +/// duration. +/// +/// # Example +/// +/// ``` +/// use std::{future, time::Duration}; +/// +/// use karyons_core::async_util::timeout; +/// +/// async { +/// let fut = future::pending::<()>(); +/// assert!(timeout(Duration::from_millis(100), fut).await.is_err()); +/// }; +/// +/// ``` +/// +pub async fn timeout<T, F>(delay: Duration, future1: F) -> Result<T> +where + F: Future<Output = T>, +{ + let result = select(Timer::after(delay), future1).await; + + match result { + Either::Left(_) => Err(Error::Timeout), + Either::Right(res) => Ok(res), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, time::Duration}; + + #[test] + fn test_timeout() { + smol::block_on(async move { + let fut = future::pending::<()>(); + assert!(timeout(Duration::from_millis(10), fut).await.is_err()); + + let fut = smol::Timer::after(Duration::from_millis(10)); + assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) + }); + } +} |