diff options
author | hozan23 <hozan23@proton.me> | 2023-11-28 22:41:33 +0300 |
---|---|---|
committer | hozan23 <hozan23@proton.me> | 2023-11-28 22:41:33 +0300 |
commit | 98a1de91a2dae06323558422c239e5a45fc86e7b (patch) | |
tree | 38c640248824fcb3b4ca5ba12df47c13ef26ccda /core/src/async_util/task_group.rs | |
parent | ca2a5f8bbb6983d9555abd10eaaf86950b794957 (diff) |
implement TLS for inbound and outbound connections
Diffstat (limited to 'core/src/async_util/task_group.rs')
-rw-r--r-- | core/src/async_util/task_group.rs | 194 |
1 files changed, 194 insertions, 0 deletions
diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs new file mode 100644 index 0000000..3fc0cb7 --- /dev/null +++ b/core/src/async_util/task_group.rs @@ -0,0 +1,194 @@ +use std::{future::Future, sync::Arc, sync::Mutex}; + +use async_task::FallibleTask; + +use crate::Executor; + +use super::{select, CondWait, Either}; + +/// TaskGroup is a group of spawned tasks. +/// +/// # Example +/// +/// ``` +/// +/// use std::sync::Arc; +/// +/// use karyons_core::async_util::TaskGroup; +/// +/// async { +/// +/// let ex = Arc::new(smol::Executor::new()); +/// let group = TaskGroup::new(ex); +/// +/// group.spawn(smol::Timer::never(), |_| async {}); +/// +/// group.cancel().await; +/// +/// }; +/// +/// ``` +/// +pub struct TaskGroup<'a> { + tasks: Mutex<Vec<TaskHandler>>, + stop_signal: Arc<CondWait>, + executor: Executor<'a>, +} + +impl<'a> TaskGroup<'a> { + /// Creates a new task group + pub fn new(executor: Executor<'a>) -> Self { + Self { + tasks: Mutex::new(Vec::new()), + stop_signal: Arc::new(CondWait::new()), + executor, + } + } + + /// Spawns a new task and calls the callback after it has completed + /// or been canceled. The callback will have the `TaskResult` as a + /// parameter, indicating whether the task completed or was canceled. + pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF) + where + T: Send + Sync + 'a, + Fut: Future<Output = T> + Send + 'a, + CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, + CallbackFut: Future<Output = ()> + Send + 'a, + { + let task = TaskHandler::new( + self.executor.clone(), + fut, + callback, + self.stop_signal.clone(), + ); + self.tasks.lock().unwrap().push(task); + } + + /// Checks if the task group is empty. + pub fn is_empty(&self) -> bool { + self.tasks.lock().unwrap().is_empty() + } + + /// Get the number of the tasks in the group. + pub fn len(&self) -> usize { + self.tasks.lock().unwrap().len() + } + + /// Cancels all tasks in the group. + pub async fn cancel(&self) { + self.stop_signal.broadcast().await; + + loop { + let task = self.tasks.lock().unwrap().pop(); + if let Some(t) = task { + t.cancel().await + } else { + break; + } + } + } +} + +/// The result of a spawned task. +#[derive(Debug)] +pub enum TaskResult<T> { + Completed(T), + Cancelled, +} + +impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TaskResult::Cancelled => write!(f, "Task cancelled"), + TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), + } + } +} + +/// TaskHandler +pub struct TaskHandler { + task: FallibleTask<()>, + cancel_flag: Arc<CondWait>, +} + +impl<'a> TaskHandler { + /// Creates a new task handle + fn new<T, Fut, CallbackF, CallbackFut>( + ex: Executor<'a>, + fut: Fut, + callback: CallbackF, + stop_signal: Arc<CondWait>, + ) -> TaskHandler + where + T: Send + Sync + 'a, + Fut: Future<Output = T> + Send + 'a, + CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, + CallbackFut: Future<Output = ()> + Send + 'a, + { + let cancel_flag = Arc::new(CondWait::new()); + let cancel_flag_c = cancel_flag.clone(); + let task = ex + .spawn(async move { + //start_signal.signal().await; + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; + + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; + + // Call the callback with the result. + callback(result).await; + + cancel_flag_c.signal().await; + }) + .fallible(); + + TaskHandler { task, cancel_flag } + } + + /// Cancels the task. + async fn cancel(self) { + self.cancel_flag.wait().await; + self.task.cancel().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, sync::Arc}; + + #[test] + fn test_task_group() { + let ex = Arc::new(smol::Executor::new()); + smol::block_on(ex.clone().run(async move { + let group = Arc::new(TaskGroup::new(ex)); + + group.spawn(future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + group.spawn( + async move { + groupc.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + smol::Timer::after(std::time::Duration::from_millis(50)).await; + group.cancel().await; + })); + } +} |