From 0992071a7f1a36424bcfaf1fbc84541ea041df1a Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 11 Apr 2024 10:19:20 +0200 Subject: add support for tokio & improve net crate api --- core/src/async_util/task_group.rs | 117 ++++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 43 deletions(-) (limited to 'core/src/async_util/task_group.rs') diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs index 7f05696..5af75ed 100644 --- a/core/src/async_util/task_group.rs +++ b/core/src/async_util/task_group.rs @@ -1,8 +1,8 @@ use std::{future::Future, sync::Arc, sync::Mutex}; -use async_task::FallibleTask; +use crate::async_runtime::{global_executor, Executor, Task}; -use super::{executor::global_executor, select, CondWait, Either, Executor}; +use super::{select, CondWait, Either}; /// TaskGroup A group that contains spawned tasks. /// @@ -12,28 +12,25 @@ use super::{executor::global_executor, select, CondWait, Either, Executor}; /// /// use std::sync::Arc; /// -/// use karyon_core::async_util::TaskGroup; +/// use karyon_core::async_util::{TaskGroup, sleep}; /// /// async { +/// let group = TaskGroup::new(); /// -/// let ex = Arc::new(smol::Executor::new()); -/// let group = TaskGroup::with_executor(ex); -/// -/// group.spawn(smol::Timer::never(), |_| async {}); +/// group.spawn(sleep(std::time::Duration::MAX), |_| async {}); /// /// group.cancel().await; /// /// }; /// /// ``` -/// -pub struct TaskGroup<'a> { +pub struct TaskGroup { tasks: Mutex>, stop_signal: Arc, - executor: Executor<'a>, + executor: Executor, } -impl TaskGroup<'static> { +impl TaskGroup { /// Creates a new TaskGroup without providing an executor /// /// This will spawn a task onto a global executor (single-threaded by default). @@ -44,11 +41,9 @@ impl TaskGroup<'static> { executor: global_executor(), } } -} -impl<'a> TaskGroup<'a> { /// Creates a new TaskGroup by providing an executor - pub fn with_executor(executor: Executor<'a>) -> Self { + pub fn with_executor(executor: Executor) -> Self { Self { tasks: Mutex::new(Vec::new()), stop_signal: Arc::new(CondWait::new()), @@ -61,10 +56,10 @@ impl<'a> TaskGroup<'a> { /// parameter, indicating whether the task completed or was canceled. pub fn spawn(&self, fut: Fut, callback: CallbackF) where - T: Send + Sync + 'a, - Fut: Future + Send + 'a, - CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, - CallbackFut: Future + Send + 'a, + T: Send + Sync + 'static, + Fut: Future + Send + 'static, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'static, + CallbackFut: Future + Send + 'static, { let task = TaskHandler::new( self.executor.clone(), @@ -100,7 +95,7 @@ impl<'a> TaskGroup<'a> { } } -impl Default for TaskGroup<'static> { +impl Default for TaskGroup { fn default() -> Self { Self::new() } @@ -124,42 +119,40 @@ impl std::fmt::Display for TaskResult { /// TaskHandler pub struct TaskHandler { - task: FallibleTask<()>, + task: Task<()>, cancel_flag: Arc, } impl<'a> TaskHandler { /// Creates a new task handler fn new( - ex: Executor<'a>, + ex: Executor, fut: Fut, callback: CallbackF, stop_signal: Arc, ) -> TaskHandler where - T: Send + Sync + 'a, - Fut: Future + Send + 'a, - CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, - CallbackFut: Future + Send + 'a, + T: Send + Sync + 'static, + Fut: Future + Send + 'static, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'static, + CallbackFut: Future + Send + 'static, { let cancel_flag = Arc::new(CondWait::new()); let cancel_flag_c = cancel_flag.clone(); - let task = ex - .spawn(async move { - // Waits for either the stop signal or the task to complete. - let result = select(stop_signal.wait(), fut).await; + let task = ex.spawn(async move { + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; - let result = match result { - Either::Left(_) => TaskResult::Cancelled, - Either::Right(res) => TaskResult::Completed(res), - }; + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; - // Call the callback - callback(result).await; + // Call the callback + callback(result).await; - cancel_flag_c.signal().await; - }) - .fallible(); + cancel_flag_c.signal().await; + }); TaskHandler { task, cancel_flag } } @@ -173,14 +166,52 @@ impl<'a> TaskHandler { #[cfg(test)] mod tests { - use super::*; use std::{future, sync::Arc}; + use crate::async_runtime::block_on; + use crate::async_util::sleep; + + use super::*; + + #[cfg(feature = "tokio")] + #[test] + fn test_task_group_with_tokio_executor() { + let ex = Arc::new(tokio::runtime::Runtime::new().unwrap()); + ex.clone().block_on(async move { + let group = Arc::new(TaskGroup::with_executor(ex.into())); + + group.spawn(future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + group.spawn( + async move { + groupc.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + group.cancel().await; + }); + } + + #[cfg(feature = "smol")] #[test] - fn test_task_group_with_executor() { + fn test_task_group_with_smol_executor() { let ex = Arc::new(smol::Executor::new()); smol::block_on(ex.clone().run(async move { - let group = Arc::new(TaskGroup::with_executor(ex)); + let group = Arc::new(TaskGroup::with_executor(ex.into())); group.spawn(future::ready(0), |res| async move { assert!(matches!(res, TaskResult::Completed(0))); @@ -210,7 +241,7 @@ mod tests { #[test] fn test_task_group() { - smol::block_on(async { + block_on(async { let group = Arc::new(TaskGroup::new()); group.spawn(future::ready(0), |res| async move { @@ -234,7 +265,7 @@ mod tests { ); // Do something - smol::Timer::after(std::time::Duration::from_millis(50)).await; + sleep(std::time::Duration::from_millis(50)).await; group.cancel().await; }); } -- cgit v1.2.3