aboutsummaryrefslogtreecommitdiff
path: root/core/src/async_utils/task_group.rs
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/async_utils/task_group.rs')
-rw-r--r--core/src/async_utils/task_group.rs73
1 files changed, 35 insertions, 38 deletions
diff --git a/core/src/async_utils/task_group.rs b/core/src/async_utils/task_group.rs
index 8707d0e..afc9648 100644
--- a/core/src/async_utils/task_group.rs
+++ b/core/src/async_utils/task_group.rs
@@ -1,6 +1,6 @@
use std::{future::Future, sync::Arc, sync::Mutex};
-use smol::Task;
+use async_task::FallibleTask;
use crate::Executor;
@@ -19,9 +19,9 @@ use super::{select, CondWait, Either};
/// async {
///
/// let ex = Arc::new(smol::Executor::new());
-/// let group = TaskGroup::new();
+/// let group = TaskGroup::new(ex);
///
-/// group.spawn(ex.clone(), smol::Timer::never(), |_| async {});
+/// group.spawn(smol::Timer::never(), |_| async {});
///
/// group.cancel().await;
///
@@ -29,35 +29,38 @@ use super::{select, CondWait, Either};
///
/// ```
///
-pub struct TaskGroup {
+pub struct TaskGroup<'a> {
tasks: Mutex<Vec<TaskHandler>>,
stop_signal: Arc<CondWait>,
+ executor: Executor<'a>,
}
-impl<'a> TaskGroup {
+impl<'a> TaskGroup<'a> {
/// Creates a new task group
- pub fn new() -> Self {
+ pub fn new(executor: Executor<'a>) -> Self {
Self {
tasks: Mutex::new(Vec::new()),
stop_signal: Arc::new(CondWait::new()),
+ executor,
}
}
/// Spawns a new task and calls the callback after it has completed
/// or been canceled. The callback will have the `TaskResult` as a
/// parameter, indicating whether the task completed or was canceled.
- pub fn spawn<T, Fut, CallbackF, CallbackFut>(
- &self,
- executor: Executor<'a>,
- fut: Fut,
- callback: CallbackF,
- ) where
+ pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF)
+ where
T: Send + Sync + 'a,
Fut: Future<Output = T> + Send + 'a,
CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a,
CallbackFut: Future<Output = ()> + Send + 'a,
{
- let task = TaskHandler::new(executor.clone(), fut, callback, self.stop_signal.clone());
+ let task = TaskHandler::new(
+ self.executor.clone(),
+ fut,
+ callback,
+ self.stop_signal.clone(),
+ );
self.tasks.lock().unwrap().push(task);
}
@@ -86,12 +89,6 @@ impl<'a> TaskGroup {
}
}
-impl Default for TaskGroup {
- fn default() -> Self {
- Self::new()
- }
-}
-
/// The result of a spawned task.
#[derive(Debug)]
pub enum TaskResult<T> {
@@ -110,7 +107,7 @@ impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> {
/// TaskHandler
pub struct TaskHandler {
- task: Task<()>,
+ task: FallibleTask<()>,
cancel_flag: Arc<CondWait>,
}
@@ -130,21 +127,23 @@ impl<'a> TaskHandler {
{
let cancel_flag = Arc::new(CondWait::new());
let cancel_flag_c = cancel_flag.clone();
- let task = ex.spawn(async move {
- //start_signal.signal().await;
- // Waits for either the stop signal or the task to complete.
- let result = select(stop_signal.wait(), fut).await;
+ let task = ex
+ .spawn(async move {
+ //start_signal.signal().await;
+ // Waits for either the stop signal or the task to complete.
+ let result = select(stop_signal.wait(), fut).await;
- let result = match result {
- Either::Left(_) => TaskResult::Cancelled,
- Either::Right(res) => TaskResult::Completed(res),
- };
+ let result = match result {
+ Either::Left(_) => TaskResult::Cancelled,
+ Either::Right(res) => TaskResult::Completed(res),
+ };
- // Call the callback with the result.
- callback(result).await;
+ // Call the callback with the result.
+ callback(result).await;
- cancel_flag_c.signal().await;
- });
+ cancel_flag_c.signal().await;
+ })
+ .fallible();
TaskHandler { task, cancel_flag }
}
@@ -165,22 +164,20 @@ mod tests {
fn test_task_group() {
let ex = Arc::new(smol::Executor::new());
smol::block_on(ex.clone().run(async move {
- let group = Arc::new(TaskGroup::new());
+ let group = Arc::new(TaskGroup::new(ex));
- group.spawn(ex.clone(), future::ready(0), |res| async move {
+ group.spawn(future::ready(0), |res| async move {
assert!(matches!(res, TaskResult::Completed(0)));
});
- group.spawn(ex.clone(), future::pending::<()>(), |res| async move {
+ group.spawn(future::pending::<()>(), |res| async move {
assert!(matches!(res, TaskResult::Cancelled));
});
let groupc = group.clone();
- let exc = ex.clone();
group.spawn(
- ex.clone(),
async move {
- groupc.spawn(exc.clone(), future::pending::<()>(), |res| async move {
+ groupc.spawn(future::pending::<()>(), |res| async move {
assert!(matches!(res, TaskResult::Cancelled));
});
},