From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- core/src/async_util/backoff.rs | 23 +++++--- core/src/async_util/condvar.rs | 38 +++++++------ core/src/async_util/condwait.rs | 25 ++++---- core/src/async_util/executor.rs | 30 ---------- core/src/async_util/mod.rs | 4 +- core/src/async_util/select.rs | 12 ++-- core/src/async_util/sleep.rs | 6 ++ core/src/async_util/task_group.rs | 117 ++++++++++++++++++++++++-------------- core/src/async_util/timeout.rs | 11 ++-- 9 files changed, 142 insertions(+), 124 deletions(-) delete mode 100644 core/src/async_util/executor.rs create mode 100644 core/src/async_util/sleep.rs (limited to 'core/src/async_util') diff --git a/core/src/async_util/backoff.rs b/core/src/async_util/backoff.rs index 4a0ab35..70e63b3 100644 --- a/core/src/async_util/backoff.rs +++ b/core/src/async_util/backoff.rs @@ -4,7 +4,7 @@ use std::{ time::Duration, }; -use smol::Timer; +use super::sleep; /// Exponential backoff /// @@ -57,7 +57,7 @@ impl Backoff { /// Retruns the delay value. pub async fn sleep(&self) -> u64 { if self.stop.load(Ordering::SeqCst) { - Timer::after(Duration::from_millis(self.max_delay)).await; + sleep(Duration::from_millis(self.max_delay)).await; return self.max_delay; } @@ -71,7 +71,7 @@ impl Backoff { self.retries.store(retries + 1, Ordering::SeqCst); - Timer::after(Duration::from_millis(delay)).await; + sleep(Duration::from_millis(delay)).await; delay } @@ -84,15 +84,18 @@ impl Backoff { #[cfg(test)] mod tests { - use super::*; use std::sync::Arc; + use crate::async_runtime::{block_on, spawn}; + + use super::*; + #[test] fn test_backoff() { - smol::block_on(async move { + block_on(async move { let backoff = Arc::new(Backoff::new(5, 15)); let backoff_c = backoff.clone(); - smol::spawn(async move { + spawn(async move { let delay = backoff_c.sleep().await; assert_eq!(delay, 5); @@ -102,14 +105,16 @@ mod tests { let delay = backoff_c.sleep().await; assert_eq!(delay, 15); }) - .await; + .await + .unwrap(); - smol::spawn(async move { + spawn(async move { backoff.reset(); let delay = backoff.sleep().await; assert_eq!(delay, 5); }) - .await; + .await + .unwrap(); }); } } diff --git a/core/src/async_util/condvar.rs b/core/src/async_util/condvar.rs index d3bc15b..c3f373d 100644 --- a/core/src/async_util/condvar.rs +++ b/core/src/async_util/condvar.rs @@ -6,9 +6,7 @@ use std::{ task::{Context, Poll, Waker}, }; -use smol::lock::MutexGuard; - -use crate::util::random_16; +use crate::{async_runtime::lock::MutexGuard, util::random_16}; /// CondVar is an async version of /// @@ -17,9 +15,8 @@ use crate::util::random_16; ///``` /// use std::sync::Arc; /// -/// use smol::lock::Mutex; -/// /// use karyon_core::async_util::CondVar; +/// use karyon_core::async_runtime::{spawn, lock::Mutex}; /// /// async { /// @@ -28,7 +25,7 @@ use crate::util::random_16; /// /// let val_cloned = val.clone(); /// let condvar_cloned = condvar.clone(); -/// smol::spawn(async move { +/// spawn(async move { /// let mut val = val_cloned.lock().await; /// /// // While the boolean flag is false, wait for a signal. @@ -40,7 +37,7 @@ use crate::util::random_16; /// }); /// /// let condvar_cloned = condvar.clone(); -/// smol::spawn(async move { +/// spawn(async move { /// let mut val = val.lock().await; /// /// // While the boolean flag is false, wait for a signal. @@ -71,7 +68,10 @@ impl CondVar { /// Blocks the current task until this condition variable receives a notification. pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { + #[cfg(feature = "smol")] let m = MutexGuard::source(&g); + #[cfg(feature = "tokio")] + let m = MutexGuard::mutex(&g); CondVarAwait::new(self, g).await; @@ -206,8 +206,6 @@ impl Wakers { #[cfg(test)] mod tests { - use super::*; - use smol::lock::Mutex; use std::{ collections::VecDeque, sync::{ @@ -216,6 +214,10 @@ mod tests { }, }; + use crate::async_runtime::{block_on, lock::Mutex, spawn}; + + use super::*; + // The tests below demonstrate a solution to a problem in the Wikipedia // explanation of condition variables: // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. @@ -243,7 +245,7 @@ mod tests { #[test] fn test_condvar_signal() { - smol::block_on(async { + block_on(async { let number_of_tasks = 30; let queue = Arc::new(Mutex::new(Queue::new(5))); @@ -254,7 +256,7 @@ mod tests { let condvar_full_cloned = condvar_full.clone(); let condvar_empty_cloned = condvar_empty.clone(); - let _producer1 = smol::spawn(async move { + let _producer1 = spawn(async move { for i in 1..number_of_tasks { // Lock queue mtuex let mut queue = queue_cloned.lock().await; @@ -275,7 +277,7 @@ mod tests { let queue_cloned = queue.clone(); let task_consumed = Arc::new(AtomicUsize::new(0)); let task_consumed_ = task_consumed.clone(); - let consumer = smol::spawn(async move { + let consumer = spawn(async move { for _ in 1..number_of_tasks { // Lock queue mtuex let mut queue = queue_cloned.lock().await; @@ -297,7 +299,7 @@ mod tests { } }); - consumer.await; + let _ = consumer.await; assert!(queue.lock().await.is_empty()); assert_eq!(task_consumed.load(Ordering::Relaxed), 29); }); @@ -305,7 +307,7 @@ mod tests { #[test] fn test_condvar_broadcast() { - smol::block_on(async { + block_on(async { let tasks = 30; let queue = Arc::new(Mutex::new(Queue::new(5))); @@ -313,7 +315,7 @@ mod tests { let queue_cloned = queue.clone(); let condvar_cloned = condvar.clone(); - let _producer1 = smol::spawn(async move { + let _producer1 = spawn(async move { for i in 1..tasks { // Lock queue mtuex let mut queue = queue_cloned.lock().await; @@ -333,7 +335,7 @@ mod tests { let queue_cloned = queue.clone(); let condvar_cloned = condvar.clone(); - let _producer2 = smol::spawn(async move { + let _producer2 = spawn(async move { for i in 1..tasks { // Lock queue mtuex let mut queue = queue_cloned.lock().await; @@ -355,7 +357,7 @@ mod tests { let task_consumed = Arc::new(AtomicUsize::new(0)); let task_consumed_ = task_consumed.clone(); - let consumer = smol::spawn(async move { + let consumer = spawn(async move { for _ in 1..((tasks * 2) - 1) { { // Lock queue mutex @@ -379,7 +381,7 @@ mod tests { } }); - consumer.await; + let _ = consumer.await; assert!(queue.lock().await.is_empty()); assert_eq!(task_consumed.load(Ordering::Relaxed), 58); }); diff --git a/core/src/async_util/condwait.rs b/core/src/async_util/condwait.rs index 6aa8a3c..76c6a05 100644 --- a/core/src/async_util/condwait.rs +++ b/core/src/async_util/condwait.rs @@ -1,6 +1,5 @@ -use smol::lock::Mutex; - use super::CondVar; +use crate::async_runtime::lock::Mutex; /// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. /// @@ -10,11 +9,12 @@ use super::CondVar; /// use std::sync::Arc; /// /// use karyon_core::async_util::CondWait; +/// use karyon_core::async_runtime::spawn; /// /// async { /// let cond_wait = Arc::new(CondWait::new()); /// let cond_wait_cloned = cond_wait.clone(); -/// let task = smol::spawn(async move { +/// let task = spawn(async move { /// cond_wait_cloned.wait().await; /// // ... /// }); @@ -76,21 +76,24 @@ impl Default for CondWait { #[cfg(test)] mod tests { - use super::*; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; + use crate::async_runtime::{block_on, spawn}; + + use super::*; + #[test] fn test_cond_wait() { - smol::block_on(async { + block_on(async { let cond_wait = Arc::new(CondWait::new()); let count = Arc::new(AtomicUsize::new(0)); let cond_wait_cloned = cond_wait.clone(); let count_cloned = count.clone(); - let task = smol::spawn(async move { + let task = spawn(async move { cond_wait_cloned.wait().await; count_cloned.fetch_add(1, Ordering::Relaxed); // do something @@ -99,7 +102,7 @@ mod tests { // Send a signal to the waiting task cond_wait.signal().await; - task.await; + let _ = task.await; // Reset the boolean flag cond_wait.reset().await; @@ -108,7 +111,7 @@ mod tests { let cond_wait_cloned = cond_wait.clone(); let count_cloned = count.clone(); - let task1 = smol::spawn(async move { + let task1 = spawn(async move { cond_wait_cloned.wait().await; count_cloned.fetch_add(1, Ordering::Relaxed); // do something @@ -116,7 +119,7 @@ mod tests { let cond_wait_cloned = cond_wait.clone(); let count_cloned = count.clone(); - let task2 = smol::spawn(async move { + let task2 = spawn(async move { cond_wait_cloned.wait().await; count_cloned.fetch_add(1, Ordering::Relaxed); // do something @@ -125,8 +128,8 @@ mod tests { // Broadcast a signal to all waiting tasks cond_wait.broadcast().await; - task1.await; - task2.await; + let _ = task1.await; + let _ = task2.await; assert_eq!(count.load(Ordering::Relaxed), 3); }); } diff --git a/core/src/async_util/executor.rs b/core/src/async_util/executor.rs deleted file mode 100644 index 3e7aa06..0000000 --- a/core/src/async_util/executor.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::{panic::catch_unwind, sync::Arc, thread}; - -use async_lock::OnceCell; -use smol::Executor as SmolEx; - -static GLOBAL_EXECUTOR: OnceCell>> = OnceCell::new(); - -/// A pointer to an Executor -pub type Executor<'a> = Arc>; - -/// Returns a single-threaded global executor -pub(crate) fn global_executor() -> Executor<'static> { - fn init_executor() -> Executor<'static> { - let ex = smol::Executor::new(); - thread::Builder::new() - .spawn(|| loop { - catch_unwind(|| { - smol::block_on(global_executor().run(smol::future::pending::<()>())) - }) - .ok(); - }) - .expect("cannot spawn executor thread"); - // Prevent spawning another thread by running the process driver on this - // thread. see https://github.com/smol-rs/smol/blob/master/src/spawn.rs - ex.spawn(async_process::driver()).detach(); - Arc::new(ex) - } - - GLOBAL_EXECUTOR.get_or_init_blocking(init_executor).clone() -} diff --git a/core/src/async_util/mod.rs b/core/src/async_util/mod.rs index 2916118..54b9607 100644 --- a/core/src/async_util/mod.rs +++ b/core/src/async_util/mod.rs @@ -1,15 +1,15 @@ mod backoff; mod condvar; mod condwait; -mod executor; mod select; +mod sleep; mod task_group; mod timeout; pub use backoff::Backoff; pub use condvar::CondVar; pub use condwait::CondWait; -pub use executor::Executor; pub use select::{select, Either}; +pub use sleep::sleep; pub use task_group::{TaskGroup, TaskResult}; pub use timeout::timeout; diff --git a/core/src/async_util/select.rs b/core/src/async_util/select.rs index 0977fa9..2008cb5 100644 --- a/core/src/async_util/select.rs +++ b/core/src/async_util/select.rs @@ -1,8 +1,8 @@ +use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use pin_project_lite::pin_project; -use smol::future::Future; /// Returns the result of the future that completes first, preferring future1 /// if both are ready. @@ -75,14 +75,16 @@ where #[cfg(test)] mod tests { - use super::{select, Either}; - use smol::Timer; use std::future; + use crate::{async_runtime::block_on, async_util::sleep}; + + use super::{select, Either}; + #[test] fn test_async_select() { - smol::block_on(async move { - let fut = select(Timer::never(), future::ready(0 as u32)).await; + block_on(async move { + let fut = select(sleep(std::time::Duration::MAX), future::ready(0 as u32)).await; assert!(matches!(fut, Either::Right(0))); let fut1 = future::pending::(); diff --git a/core/src/async_util/sleep.rs b/core/src/async_util/sleep.rs new file mode 100644 index 0000000..f72b825 --- /dev/null +++ b/core/src/async_util/sleep.rs @@ -0,0 +1,6 @@ +pub async fn sleep(duration: std::time::Duration) { + #[cfg(feature = "smol")] + smol::Timer::after(duration).await; + #[cfg(feature = "tokio")] + tokio::time::sleep(duration).await; +} diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs index 7f05696..5af75ed 100644 --- a/core/src/async_util/task_group.rs +++ b/core/src/async_util/task_group.rs @@ -1,8 +1,8 @@ use std::{future::Future, sync::Arc, sync::Mutex}; -use async_task::FallibleTask; +use crate::async_runtime::{global_executor, Executor, Task}; -use super::{executor::global_executor, select, CondWait, Either, Executor}; +use super::{select, CondWait, Either}; /// TaskGroup A group that contains spawned tasks. /// @@ -12,28 +12,25 @@ use super::{executor::global_executor, select, CondWait, Either, Executor}; /// /// use std::sync::Arc; /// -/// use karyon_core::async_util::TaskGroup; +/// use karyon_core::async_util::{TaskGroup, sleep}; /// /// async { +/// let group = TaskGroup::new(); /// -/// let ex = Arc::new(smol::Executor::new()); -/// let group = TaskGroup::with_executor(ex); -/// -/// group.spawn(smol::Timer::never(), |_| async {}); +/// group.spawn(sleep(std::time::Duration::MAX), |_| async {}); /// /// group.cancel().await; /// /// }; /// /// ``` -/// -pub struct TaskGroup<'a> { +pub struct TaskGroup { tasks: Mutex>, stop_signal: Arc, - executor: Executor<'a>, + executor: Executor, } -impl TaskGroup<'static> { +impl TaskGroup { /// Creates a new TaskGroup without providing an executor /// /// This will spawn a task onto a global executor (single-threaded by default). @@ -44,11 +41,9 @@ impl TaskGroup<'static> { executor: global_executor(), } } -} -impl<'a> TaskGroup<'a> { /// Creates a new TaskGroup by providing an executor - pub fn with_executor(executor: Executor<'a>) -> Self { + pub fn with_executor(executor: Executor) -> Self { Self { tasks: Mutex::new(Vec::new()), stop_signal: Arc::new(CondWait::new()), @@ -61,10 +56,10 @@ impl<'a> TaskGroup<'a> { /// parameter, indicating whether the task completed or was canceled. pub fn spawn(&self, fut: Fut, callback: CallbackF) where - T: Send + Sync + 'a, - Fut: Future + Send + 'a, - CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, - CallbackFut: Future + Send + 'a, + T: Send + Sync + 'static, + Fut: Future + Send + 'static, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'static, + CallbackFut: Future + Send + 'static, { let task = TaskHandler::new( self.executor.clone(), @@ -100,7 +95,7 @@ impl<'a> TaskGroup<'a> { } } -impl Default for TaskGroup<'static> { +impl Default for TaskGroup { fn default() -> Self { Self::new() } @@ -124,42 +119,40 @@ impl std::fmt::Display for TaskResult { /// TaskHandler pub struct TaskHandler { - task: FallibleTask<()>, + task: Task<()>, cancel_flag: Arc, } impl<'a> TaskHandler { /// Creates a new task handler fn new( - ex: Executor<'a>, + ex: Executor, fut: Fut, callback: CallbackF, stop_signal: Arc, ) -> TaskHandler where - T: Send + Sync + 'a, - Fut: Future + Send + 'a, - CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, - CallbackFut: Future + Send + 'a, + T: Send + Sync + 'static, + Fut: Future + Send + 'static, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'static, + CallbackFut: Future + Send + 'static, { let cancel_flag = Arc::new(CondWait::new()); let cancel_flag_c = cancel_flag.clone(); - let task = ex - .spawn(async move { - // Waits for either the stop signal or the task to complete. - let result = select(stop_signal.wait(), fut).await; + let task = ex.spawn(async move { + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; - let result = match result { - Either::Left(_) => TaskResult::Cancelled, - Either::Right(res) => TaskResult::Completed(res), - }; + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; - // Call the callback - callback(result).await; + // Call the callback + callback(result).await; - cancel_flag_c.signal().await; - }) - .fallible(); + cancel_flag_c.signal().await; + }); TaskHandler { task, cancel_flag } } @@ -173,14 +166,52 @@ impl<'a> TaskHandler { #[cfg(test)] mod tests { - use super::*; use std::{future, sync::Arc}; + use crate::async_runtime::block_on; + use crate::async_util::sleep; + + use super::*; + + #[cfg(feature = "tokio")] + #[test] + fn test_task_group_with_tokio_executor() { + let ex = Arc::new(tokio::runtime::Runtime::new().unwrap()); + ex.clone().block_on(async move { + let group = Arc::new(TaskGroup::with_executor(ex.into())); + + group.spawn(future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + group.spawn( + async move { + groupc.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + group.cancel().await; + }); + } + + #[cfg(feature = "smol")] #[test] - fn test_task_group_with_executor() { + fn test_task_group_with_smol_executor() { let ex = Arc::new(smol::Executor::new()); smol::block_on(ex.clone().run(async move { - let group = Arc::new(TaskGroup::with_executor(ex)); + let group = Arc::new(TaskGroup::with_executor(ex.into())); group.spawn(future::ready(0), |res| async move { assert!(matches!(res, TaskResult::Completed(0))); @@ -210,7 +241,7 @@ mod tests { #[test] fn test_task_group() { - smol::block_on(async { + block_on(async { let group = Arc::new(TaskGroup::new()); group.spawn(future::ready(0), |res| async move { @@ -234,7 +265,7 @@ mod tests { ); // Do something - smol::Timer::after(std::time::Duration::from_millis(50)).await; + sleep(std::time::Duration::from_millis(50)).await; group.cancel().await; }); } diff --git a/core/src/async_util/timeout.rs b/core/src/async_util/timeout.rs index cf3c490..9ac64c8 100644 --- a/core/src/async_util/timeout.rs +++ b/core/src/async_util/timeout.rs @@ -1,10 +1,9 @@ use std::{future::Future, time::Duration}; -use smol::Timer; - -use super::{select, Either}; use crate::{error::Error, Result}; +use super::{select, sleep, Either}; + /// Waits for a future to complete or times out if it exceeds a specified /// duration. /// @@ -26,7 +25,7 @@ pub async fn timeout(delay: Duration, future1: F) -> Result where F: Future, { - let result = select(Timer::after(delay), future1).await; + let result = select(sleep(delay), future1).await; match result { Either::Left(_) => Err(Error::Timeout), @@ -41,11 +40,11 @@ mod tests { #[test] fn test_timeout() { - smol::block_on(async move { + crate::async_runtime::block_on(async move { let fut = future::pending::<()>(); assert!(timeout(Duration::from_millis(10), fut).await.is_err()); - let fut = smol::Timer::after(Duration::from_millis(10)); + let fut = sleep(Duration::from_millis(10)); assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) }); } -- cgit v1.2.3