From 98a1de91a2dae06323558422c239e5a45fc86e7b Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 28 Nov 2023 22:41:33 +0300 Subject: implement TLS for inbound and outbound connections --- core/src/async_util/task_group.rs | 194 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 core/src/async_util/task_group.rs (limited to 'core/src/async_util/task_group.rs') diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs new file mode 100644 index 0000000..3fc0cb7 --- /dev/null +++ b/core/src/async_util/task_group.rs @@ -0,0 +1,194 @@ +use std::{future::Future, sync::Arc, sync::Mutex}; + +use async_task::FallibleTask; + +use crate::Executor; + +use super::{select, CondWait, Either}; + +/// TaskGroup is a group of spawned tasks. +/// +/// # Example +/// +/// ``` +/// +/// use std::sync::Arc; +/// +/// use karyons_core::async_util::TaskGroup; +/// +/// async { +/// +/// let ex = Arc::new(smol::Executor::new()); +/// let group = TaskGroup::new(ex); +/// +/// group.spawn(smol::Timer::never(), |_| async {}); +/// +/// group.cancel().await; +/// +/// }; +/// +/// ``` +/// +pub struct TaskGroup<'a> { + tasks: Mutex>, + stop_signal: Arc, + executor: Executor<'a>, +} + +impl<'a> TaskGroup<'a> { + /// Creates a new task group + pub fn new(executor: Executor<'a>) -> Self { + Self { + tasks: Mutex::new(Vec::new()), + stop_signal: Arc::new(CondWait::new()), + executor, + } + } + + /// Spawns a new task and calls the callback after it has completed + /// or been canceled. The callback will have the `TaskResult` as a + /// parameter, indicating whether the task completed or was canceled. + pub fn spawn(&self, fut: Fut, callback: CallbackF) + where + T: Send + Sync + 'a, + Fut: Future + Send + 'a, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, + CallbackFut: Future + Send + 'a, + { + let task = TaskHandler::new( + self.executor.clone(), + fut, + callback, + self.stop_signal.clone(), + ); + self.tasks.lock().unwrap().push(task); + } + + /// Checks if the task group is empty. + pub fn is_empty(&self) -> bool { + self.tasks.lock().unwrap().is_empty() + } + + /// Get the number of the tasks in the group. + pub fn len(&self) -> usize { + self.tasks.lock().unwrap().len() + } + + /// Cancels all tasks in the group. + pub async fn cancel(&self) { + self.stop_signal.broadcast().await; + + loop { + let task = self.tasks.lock().unwrap().pop(); + if let Some(t) = task { + t.cancel().await + } else { + break; + } + } + } +} + +/// The result of a spawned task. +#[derive(Debug)] +pub enum TaskResult { + Completed(T), + Cancelled, +} + +impl std::fmt::Display for TaskResult { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TaskResult::Cancelled => write!(f, "Task cancelled"), + TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), + } + } +} + +/// TaskHandler +pub struct TaskHandler { + task: FallibleTask<()>, + cancel_flag: Arc, +} + +impl<'a> TaskHandler { + /// Creates a new task handle + fn new( + ex: Executor<'a>, + fut: Fut, + callback: CallbackF, + stop_signal: Arc, + ) -> TaskHandler + where + T: Send + Sync + 'a, + Fut: Future + Send + 'a, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, + CallbackFut: Future + Send + 'a, + { + let cancel_flag = Arc::new(CondWait::new()); + let cancel_flag_c = cancel_flag.clone(); + let task = ex + .spawn(async move { + //start_signal.signal().await; + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; + + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; + + // Call the callback with the result. + callback(result).await; + + cancel_flag_c.signal().await; + }) + .fallible(); + + TaskHandler { task, cancel_flag } + } + + /// Cancels the task. + async fn cancel(self) { + self.cancel_flag.wait().await; + self.task.cancel().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, sync::Arc}; + + #[test] + fn test_task_group() { + let ex = Arc::new(smol::Executor::new()); + smol::block_on(ex.clone().run(async move { + let group = Arc::new(TaskGroup::new(ex)); + + group.spawn(future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + group.spawn( + async move { + groupc.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + smol::Timer::after(std::time::Duration::from_millis(50)).await; + group.cancel().await; + })); + } +} -- cgit v1.2.3