diff options
Diffstat (limited to 'core/src/async_utils')
-rw-r--r-- | core/src/async_utils/task_group.rs | 73 |
1 files changed, 35 insertions, 38 deletions
diff --git a/core/src/async_utils/task_group.rs b/core/src/async_utils/task_group.rs index 8707d0e..afc9648 100644 --- a/core/src/async_utils/task_group.rs +++ b/core/src/async_utils/task_group.rs @@ -1,6 +1,6 @@ use std::{future::Future, sync::Arc, sync::Mutex}; -use smol::Task; +use async_task::FallibleTask; use crate::Executor; @@ -19,9 +19,9 @@ use super::{select, CondWait, Either}; /// async { /// /// let ex = Arc::new(smol::Executor::new()); -/// let group = TaskGroup::new(); +/// let group = TaskGroup::new(ex); /// -/// group.spawn(ex.clone(), smol::Timer::never(), |_| async {}); +/// group.spawn(smol::Timer::never(), |_| async {}); /// /// group.cancel().await; /// @@ -29,35 +29,38 @@ use super::{select, CondWait, Either}; /// /// ``` /// -pub struct TaskGroup { +pub struct TaskGroup<'a> { tasks: Mutex<Vec<TaskHandler>>, stop_signal: Arc<CondWait>, + executor: Executor<'a>, } -impl<'a> TaskGroup { +impl<'a> TaskGroup<'a> { /// Creates a new task group - pub fn new() -> Self { + pub fn new(executor: Executor<'a>) -> Self { Self { tasks: Mutex::new(Vec::new()), stop_signal: Arc::new(CondWait::new()), + executor, } } /// Spawns a new task and calls the callback after it has completed /// or been canceled. The callback will have the `TaskResult` as a /// parameter, indicating whether the task completed or was canceled. - pub fn spawn<T, Fut, CallbackF, CallbackFut>( - &self, - executor: Executor<'a>, - fut: Fut, - callback: CallbackF, - ) where + pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF) + where T: Send + Sync + 'a, Fut: Future<Output = T> + Send + 'a, CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, CallbackFut: Future<Output = ()> + Send + 'a, { - let task = TaskHandler::new(executor.clone(), fut, callback, self.stop_signal.clone()); + let task = TaskHandler::new( + self.executor.clone(), + fut, + callback, + self.stop_signal.clone(), + ); self.tasks.lock().unwrap().push(task); } @@ -86,12 +89,6 @@ impl<'a> TaskGroup { } } -impl Default for TaskGroup { - fn default() -> Self { - Self::new() - } -} - /// The result of a spawned task. #[derive(Debug)] pub enum TaskResult<T> { @@ -110,7 +107,7 @@ impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> { /// TaskHandler pub struct TaskHandler { - task: Task<()>, + task: FallibleTask<()>, cancel_flag: Arc<CondWait>, } @@ -130,21 +127,23 @@ impl<'a> TaskHandler { { let cancel_flag = Arc::new(CondWait::new()); let cancel_flag_c = cancel_flag.clone(); - let task = ex.spawn(async move { - //start_signal.signal().await; - // Waits for either the stop signal or the task to complete. - let result = select(stop_signal.wait(), fut).await; + let task = ex + .spawn(async move { + //start_signal.signal().await; + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; - let result = match result { - Either::Left(_) => TaskResult::Cancelled, - Either::Right(res) => TaskResult::Completed(res), - }; + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; - // Call the callback with the result. - callback(result).await; + // Call the callback with the result. + callback(result).await; - cancel_flag_c.signal().await; - }); + cancel_flag_c.signal().await; + }) + .fallible(); TaskHandler { task, cancel_flag } } @@ -165,22 +164,20 @@ mod tests { fn test_task_group() { let ex = Arc::new(smol::Executor::new()); smol::block_on(ex.clone().run(async move { - let group = Arc::new(TaskGroup::new()); + let group = Arc::new(TaskGroup::new(ex)); - group.spawn(ex.clone(), future::ready(0), |res| async move { + group.spawn(future::ready(0), |res| async move { assert!(matches!(res, TaskResult::Completed(0))); }); - group.spawn(ex.clone(), future::pending::<()>(), |res| async move { + group.spawn(future::pending::<()>(), |res| async move { assert!(matches!(res, TaskResult::Cancelled)); }); let groupc = group.clone(); - let exc = ex.clone(); group.spawn( - ex.clone(), async move { - groupc.spawn(exc.clone(), future::pending::<()>(), |res| async move { + groupc.spawn(future::pending::<()>(), |res| async move { assert!(matches!(res, TaskResult::Cancelled)); }); }, |