diff options
author | hozan23 <hozan23@karyontech.net> | 2024-04-11 10:19:20 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-19 13:51:30 +0200 |
commit | 0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch) | |
tree | 961d73218af672797d49f899289bef295bc56493 /core | |
parent | a69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff) |
add support for tokio & improve net crate api
Diffstat (limited to 'core')
-rw-r--r-- | core/Cargo.toml | 28 | ||||
-rw-r--r-- | core/src/async_runtime/executor.rs | 100 | ||||
-rw-r--r-- | core/src/async_runtime/io.rs | 9 | ||||
-rw-r--r-- | core/src/async_runtime/lock.rs | 5 | ||||
-rw-r--r-- | core/src/async_runtime/mod.rs | 25 | ||||
-rw-r--r-- | core/src/async_runtime/net.rs | 12 | ||||
-rw-r--r-- | core/src/async_runtime/spawn.rs | 12 | ||||
-rw-r--r-- | core/src/async_runtime/task.rs | 52 | ||||
-rw-r--r-- | core/src/async_runtime/timer.rs | 1 | ||||
-rw-r--r-- | core/src/async_util/backoff.rs | 23 | ||||
-rw-r--r-- | core/src/async_util/condvar.rs | 38 | ||||
-rw-r--r-- | core/src/async_util/condwait.rs | 25 | ||||
-rw-r--r-- | core/src/async_util/executor.rs | 30 | ||||
-rw-r--r-- | core/src/async_util/mod.rs | 4 | ||||
-rw-r--r-- | core/src/async_util/select.rs | 12 | ||||
-rw-r--r-- | core/src/async_util/sleep.rs | 6 | ||||
-rw-r--r-- | core/src/async_util/task_group.rs | 117 | ||||
-rw-r--r-- | core/src/async_util/timeout.rs | 11 | ||||
-rw-r--r-- | core/src/error.rs | 10 | ||||
-rw-r--r-- | core/src/event.rs | 24 | ||||
-rw-r--r-- | core/src/lib.rs | 15 | ||||
-rw-r--r-- | core/src/pubsub.rs | 11 | ||||
-rw-r--r-- | core/src/util/encode.rs | 11 |
23 files changed, 408 insertions, 173 deletions
diff --git a/core/Cargo.toml b/core/Cargo.toml index c8e2b8d..4bb7f4f 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,13 +1,15 @@ [package] name = "karyon_core" -version.workspace = true +version.workspace = true edition.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +default = ["smol"] +crypto = ["dep:ed25519-dalek"] +tokio = ["dep:tokio"] +smol = ["dep:smol", "dep:async-process"] [dependencies] -smol = "2.0.0" pin-project-lite = "0.2.13" log = "0.4.21" bincode = "2.0.0-rc.3" @@ -15,15 +17,15 @@ chrono = "0.4.35" rand = "0.8.5" thiserror = "1.0.58" dirs = "5.0.1" -async-task = "4.7.0" -async-lock = "3.3.0" -async-process = "2.1.0" - -ed25519-dalek = { version = "2.1.1", features = ["rand_core"], optional = true} +async-channel = "2.2.0" +# crypto feature deps +ed25519-dalek = { version = "2.1.1", features = ["rand_core"], optional = true } -[features] -default = [] -crypto = ["dep:ed25519-dalek"] - +# smol feature deps +async-process = { version = "2.1.0", optional = true } +smol = { version = "2.0.0", optional = true } +# tokio feature deps +tokio = { version = "1.37.0", features = ["full"], optional = true } +once_cell = "1.19.0" diff --git a/core/src/async_runtime/executor.rs b/core/src/async_runtime/executor.rs new file mode 100644 index 0000000..9335f12 --- /dev/null +++ b/core/src/async_runtime/executor.rs @@ -0,0 +1,100 @@ +use std::{future::Future, panic::catch_unwind, sync::Arc, thread}; + +use once_cell::sync::OnceCell; + +#[cfg(feature = "smol")] +pub use smol::Executor as SmolEx; + +#[cfg(feature = "tokio")] +pub use tokio::runtime::Runtime; + +use super::Task; + +#[derive(Clone)] +pub struct Executor { + #[cfg(feature = "smol")] + inner: Arc<SmolEx<'static>>, + #[cfg(feature = "tokio")] + inner: Arc<Runtime>, +} + +impl Executor { + pub fn spawn<T: Send + 'static>( + &self, + future: impl Future<Output = T> + Send + 'static, + ) -> Task<T> { + self.inner.spawn(future).into() + } +} + +static GLOBAL_EXECUTOR: OnceCell<Executor> = OnceCell::new(); + +/// Returns a single-threaded global executor +pub fn global_executor() -> Executor { + #[cfg(feature = "smol")] + fn init_executor() -> Executor { + let ex = smol::Executor::new(); + thread::Builder::new() + .name("smol-executor".to_string()) + .spawn(|| loop { + catch_unwind(|| { + smol::block_on(global_executor().inner.run(std::future::pending::<()>())) + }) + .ok(); + }) + .expect("cannot spawn executor thread"); + // Prevent spawning another thread by running the process driver on this + // thread. see https://github.com/smol-rs/smol/blob/master/src/spawn.rs + ex.spawn(async_process::driver()).detach(); + Executor { + inner: Arc::new(ex), + } + } + + #[cfg(feature = "tokio")] + fn init_executor() -> Executor { + let ex = Arc::new(tokio::runtime::Runtime::new().expect("cannot build tokio runtime")); + let ex_cloned = ex.clone(); + thread::Builder::new() + .name("tokio-executor".to_string()) + .spawn(move || { + catch_unwind(|| ex_cloned.block_on(std::future::pending::<()>())).ok(); + }) + .expect("cannot spawn tokio runtime thread"); + Executor { inner: ex } + } + + GLOBAL_EXECUTOR.get_or_init(init_executor).clone() +} + +#[cfg(feature = "smol")] +impl From<Arc<smol::Executor<'static>>> for Executor { + fn from(ex: Arc<smol::Executor<'static>>) -> Executor { + Executor { inner: ex } + } +} + +#[cfg(feature = "tokio")] +impl From<Arc<tokio::runtime::Runtime>> for Executor { + fn from(rt: Arc<tokio::runtime::Runtime>) -> Executor { + Executor { inner: rt } + } +} + +#[cfg(feature = "smol")] +impl From<smol::Executor<'static>> for Executor { + fn from(ex: smol::Executor<'static>) -> Executor { + Executor { + inner: Arc::new(ex), + } + } +} + +#[cfg(feature = "tokio")] +impl From<tokio::runtime::Runtime> for Executor { + fn from(rt: tokio::runtime::Runtime) -> Executor { + Executor { + inner: Arc::new(rt), + } + } +} diff --git a/core/src/async_runtime/io.rs b/core/src/async_runtime/io.rs new file mode 100644 index 0000000..161c258 --- /dev/null +++ b/core/src/async_runtime/io.rs @@ -0,0 +1,9 @@ +#[cfg(feature = "smol")] +pub use smol::io::{ + split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, +}; + +#[cfg(feature = "tokio")] +pub use tokio::io::{ + split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, +}; diff --git a/core/src/async_runtime/lock.rs b/core/src/async_runtime/lock.rs new file mode 100644 index 0000000..fc84d1d --- /dev/null +++ b/core/src/async_runtime/lock.rs @@ -0,0 +1,5 @@ +#[cfg(feature = "smol")] +pub use smol::lock::{Mutex, MutexGuard, OnceCell, RwLock}; + +#[cfg(feature = "tokio")] +pub use tokio::sync::{Mutex, MutexGuard, OnceCell, RwLock}; diff --git a/core/src/async_runtime/mod.rs b/core/src/async_runtime/mod.rs new file mode 100644 index 0000000..d91d01b --- /dev/null +++ b/core/src/async_runtime/mod.rs @@ -0,0 +1,25 @@ +mod executor; +pub mod io; +pub mod lock; +pub mod net; +mod spawn; +mod task; +mod timer; + +pub use executor::{global_executor, Executor}; +pub use spawn::spawn; +pub use task::Task; + +#[cfg(test)] +pub fn block_on<T>(future: impl std::future::Future<Output = T>) -> T { + #[cfg(feature = "smol")] + let result = smol::block_on(future); + #[cfg(feature = "tokio")] + let result = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(future); + + result +} diff --git a/core/src/async_runtime/net.rs b/core/src/async_runtime/net.rs new file mode 100644 index 0000000..5c004ce --- /dev/null +++ b/core/src/async_runtime/net.rs @@ -0,0 +1,12 @@ +pub use std::os::unix::net::SocketAddr; + +#[cfg(feature = "smol")] +pub use smol::net::{ + unix::{SocketAddr as UnixSocketAddr, UnixListener, UnixStream}, + TcpListener, TcpStream, UdpSocket, +}; + +#[cfg(feature = "tokio")] +pub use tokio::net::{ + unix::SocketAddr as UnixSocketAddr, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream, +}; diff --git a/core/src/async_runtime/spawn.rs b/core/src/async_runtime/spawn.rs new file mode 100644 index 0000000..2760982 --- /dev/null +++ b/core/src/async_runtime/spawn.rs @@ -0,0 +1,12 @@ +use std::future::Future; + +use super::Task; + +pub fn spawn<T: Send + 'static>(future: impl Future<Output = T> + Send + 'static) -> Task<T> { + #[cfg(feature = "smol")] + let result: Task<T> = smol::spawn(future).into(); + #[cfg(feature = "tokio")] + let result: Task<T> = tokio::spawn(future).into(); + + result +} diff --git a/core/src/async_runtime/task.rs b/core/src/async_runtime/task.rs new file mode 100644 index 0000000..a681b0f --- /dev/null +++ b/core/src/async_runtime/task.rs @@ -0,0 +1,52 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::error::Error; + +pub struct Task<T> { + #[cfg(feature = "smol")] + inner_task: smol::Task<T>, + #[cfg(feature = "tokio")] + inner_task: tokio::task::JoinHandle<T>, +} + +impl<T> Task<T> { + pub async fn cancel(self) { + #[cfg(feature = "smol")] + self.inner_task.cancel().await; + #[cfg(feature = "tokio")] + self.inner_task.abort(); + } +} + +impl<T> Future for Task<T> { + type Output = Result<T, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + #[cfg(feature = "smol")] + let result = smol::Task::poll(Pin::new(&mut self.inner_task), cx); + #[cfg(feature = "tokio")] + let result = tokio::task::JoinHandle::poll(Pin::new(&mut self.inner_task), cx); + + #[cfg(feature = "smol")] + return result.map(Ok); + + #[cfg(feature = "tokio")] + return result.map_err(|e| e.into()); + } +} + +#[cfg(feature = "smol")] +impl<T> From<smol::Task<T>> for Task<T> { + fn from(t: smol::Task<T>) -> Task<T> { + Task { inner_task: t } + } +} + +#[cfg(feature = "tokio")] +impl<T> From<tokio::task::JoinHandle<T>> for Task<T> { + fn from(t: tokio::task::JoinHandle<T>) -> Task<T> { + Task { inner_task: t } + } +} diff --git a/core/src/async_runtime/timer.rs b/core/src/async_runtime/timer.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/core/src/async_runtime/timer.rs @@ -0,0 +1 @@ + diff --git a/core/src/async_util/backoff.rs b/core/src/async_util/backoff.rs index 4a0ab35..70e63b3 100644 --- a/core/src/async_util/backoff.rs +++ b/core/src/async_util/backoff.rs @@ -4,7 +4,7 @@ use std::{ time::Duration, }; -use smol::Timer; +use super::sleep; /// Exponential backoff /// <https://en.wikipedia.org/wiki/Exponential_backoff> @@ -57,7 +57,7 @@ impl Backoff { /// Retruns the delay value. pub async fn sleep(&self) -> u64 { if self.stop.load(Ordering::SeqCst) { - Timer::after(Duration::from_millis(self.max_delay)).await; + sleep(Duration::from_millis(self.max_delay)).await; return self.max_delay; } @@ -71,7 +71,7 @@ impl Backoff { self.retries.store(retries + 1, Ordering::SeqCst); - Timer::after(Duration::from_millis(delay)).await; + sleep(Duration::from_millis(delay)).await; delay } @@ -84,15 +84,18 @@ impl Backoff { #[cfg(test)] mod tests { - use super::*; use std::sync::Arc; + use crate::async_runtime::{block_on, spawn}; + + use super::*; + #[test] fn test_backoff() { - smol::block_on(async move { + block_on(async move { let backoff = Arc::new(Backoff::new(5, 15)); let backoff_c = backoff.clone(); - smol::spawn(async move { + spawn(async move { let delay = backoff_c.sleep().await; assert_eq!(delay, 5); @@ -102,14 +105,16 @@ mod tests { let delay = backoff_c.sleep().await; assert_eq!(delay, 15); }) - .await; + .await + .unwrap(); - smol::spawn(async move { + spawn(async move { backoff.reset(); let delay = backoff.sleep().await; assert_eq!(delay, 5); }) - .await; + .await + .unwrap(); }); } } diff --git a/core/src/async_util/condvar.rs b/core/src/async_util/condvar.rs index d3bc15b..c3f373d 100644 --- a/core/src/async_util/condvar.rs +++ b/core/src/async_util/condvar.rs @@ -6,9 +6,7 @@ use std::{ task::{Context, Poll, Waker}, }; -use smol::lock::MutexGuard; - -use crate::util::random_16; +use crate::{async_runtime::lock::MutexGuard, util::random_16}; /// CondVar is an async version of <https://doc.rust-lang.org/std/sync/struct.Condvar.html> /// @@ -17,9 +15,8 @@ use crate::util::random_16; ///``` /// use std::sync::Arc; /// -/// use smol::lock::Mutex; -/// /// use karyon_core::async_util::CondVar; +/// use karyon_core::async_runtime::{spawn, lock::Mutex}; /// /// async { /// @@ -28,7 +25,7 @@ use crate::util::random_16; /// /// let val_cloned = val.clone(); /// let condvar_cloned = condvar.clone(); -/// smol::spawn(async move { +/// spawn(async move { /// let mut val = val_cloned.lock().await; /// /// // While the boolean flag is false, wait for a signal. @@ -40,7 +37,7 @@ use crate::util::random_16; /// }); /// /// let condvar_cloned = condvar.clone(); -/// smol::spawn(async move { +/// spawn(async move { /// let mut val = val.lock().await; /// /// // While the boolean flag is false, wait for a signal. @@ -71,7 +68,10 @@ impl CondVar { /// Blocks the current task until this condition variable receives a notification. pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { + #[cfg(feature = "smol")] let m = MutexGuard::source(&g); + #[cfg(feature = "tokio")] + let m = MutexGuard::mutex(&g); CondVarAwait::new(self, g).await; @@ -206,8 +206,6 @@ impl Wakers { #[cfg(test)] mod tests { - use super::*; - use smol::lock::Mutex; use std::{ collections::VecDeque, sync::{ @@ -216,6 +214,10 @@ mod tests { }, }; + use crate::async_runtime::{block_on, lock::Mutex, spawn}; + + use super::*; + // The tests below demonstrate a solution to a problem in the Wikipedia // explanation of condition variables: // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. @@ -243,7 +245,7 @@ mod tests { #[test] fn test_condvar_signal() { - smol::block_on(async { + block_on(async { let number_of_tasks = 30; let queue = Arc::new(Mutex::new(Queue::new(5))); @@ -254,7 +256,7 @@ mod tests { let condvar_full_cloned = condvar_full.clone(); let condvar_empty_cloned = condvar_empty.clone(); - let _producer1 = smol::spawn(async move { + let _producer1 = spawn(async move { for i in 1..number_of_tasks { // Lock queue mtuex let mut queue = queue_cloned.lock().await; @@ -275,7 +277,7 @@ mod tests { let queue_cloned = queue.clone(); let task_consumed = Arc::new(AtomicUsize::new(0)); let task_consumed_ = task_consumed.clone(); - let consumer = smol::spawn(async move { + let consumer = spawn(async move { for _ in 1..number_of_tasks { // Lock queue mtuex let mut queue = queue_cloned.lock().await; @@ -297,7 +299,7 @@ mod tests { } }); - consumer.await; + let _ = consumer.await; assert!(queue.lock().await.is_empty()); assert_eq!(task_consumed.load(Ordering::Relaxed), 29); }); @@ -305,7 +307,7 @@ mod tests { #[test] fn test_condvar_broadcast() { - smol::block_on(async { + block_on(async { let tasks = 30; let queue = Arc::new(Mutex::new(Queue::new(5))); @@ -313,7 +315,7 @@ mod tests { let queue_cloned = queue.clone(); let condvar_cloned = condvar.clone(); - let _producer1 = smol::spawn(async move { + let _producer1 = spawn(async move { for i in 1..tasks { // Lock queue mtuex let mut queue = queue_cloned.lock().await; @@ -333,7 +335,7 @@ mod tests { let queue_cloned = queue.clone(); let condvar_cloned = condvar.clone(); - let _producer2 = smol::spawn(async move { + let _producer2 = spawn(async move { for i in 1..tasks { // Lock queue mtuex let mut queue = queue_cloned.lock().await; @@ -355,7 +357,7 @@ mod tests { let task_consumed = Arc::new(AtomicUsize::new(0)); let task_consumed_ = task_consumed.clone(); - let consumer = smol::spawn(async move { + let consumer = spawn(async move { for _ in 1..((tasks * 2) - 1) { { // Lock queue mutex @@ -379,7 +381,7 @@ mod tests { } }); - consumer.await; + let _ = consumer.await; assert!(queue.lock().await.is_empty()); assert_eq!(task_consumed.load(Ordering::Relaxed), 58); }); diff --git a/core/src/async_util/condwait.rs b/core/src/async_util/condwait.rs index 6aa8a3c..76c6a05 100644 --- a/core/src/async_util/condwait.rs +++ b/core/src/async_util/condwait.rs @@ -1,6 +1,5 @@ -use smol::lock::Mutex; - use super::CondVar; +use crate::async_runtime::lock::Mutex; /// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. /// @@ -10,11 +9,12 @@ use super::CondVar; /// use std::sync::Arc; /// /// use karyon_core::async_util::CondWait; +/// use karyon_core::async_runtime::spawn; /// /// async { /// let cond_wait = Arc::new(CondWait::new()); /// let cond_wait_cloned = cond_wait.clone(); -/// let task = smol::spawn(async move { +/// let task = spawn(async move { /// cond_wait_cloned.wait().await; /// // ... /// }); @@ -76,21 +76,24 @@ impl Default for CondWait { #[cfg(test)] mod tests { - use super::*; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; + use crate::async_runtime::{block_on, spawn}; + + use super::*; + #[test] fn test_cond_wait() { - smol::block_on(async { + block_on(async { let cond_wait = Arc::new(CondWait::new()); let count = Arc::new(AtomicUsize::new(0)); let cond_wait_cloned = cond_wait.clone(); let count_cloned = count.clone(); - let task = smol::spawn(async move { + let task = spawn(async move { cond_wait_cloned.wait().await; count_cloned.fetch_add(1, Ordering::Relaxed); // do something @@ -99,7 +102,7 @@ mod tests { // Send a signal to the waiting task cond_wait.signal().await; - task.await; + let _ = task.await; // Reset the boolean flag cond_wait.reset().await; @@ -108,7 +111,7 @@ mod tests { let cond_wait_cloned = cond_wait.clone(); let count_cloned = count.clone(); - let task1 = smol::spawn(async move { + let task1 = spawn(async move { cond_wait_cloned.wait().await; count_cloned.fetch_add(1, Ordering::Relaxed); // do something @@ -116,7 +119,7 @@ mod tests { let cond_wait_cloned = cond_wait.clone(); let count_cloned = count.clone(); - let task2 = smol::spawn(async move { + let task2 = spawn(async move { cond_wait_cloned.wait().await; count_cloned.fetch_add(1, Ordering::Relaxed); // do something @@ -125,8 +128,8 @@ mod tests { // Broadcast a signal to all waiting tasks cond_wait.broadcast().await; - task1.await; - task2.await; + let _ = task1.await; + let _ = task2.await; assert_eq!(count.load(Ordering::Relaxed), 3); }); } diff --git a/core/src/async_util/executor.rs b/core/src/async_util/executor.rs deleted file mode 100644 index 3e7aa06..0000000 --- a/core/src/async_util/executor.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::{panic::catch_unwind, sync::Arc, thread}; - -use async_lock::OnceCell; -use smol::Executor as SmolEx; - -static GLOBAL_EXECUTOR: OnceCell<Arc<smol::Executor<'_>>> = OnceCell::new(); - -/// A pointer to an Executor -pub type Executor<'a> = Arc<SmolEx<'a>>; - -/// Returns a single-threaded global executor -pub(crate) fn global_executor() -> Executor<'static> { - fn init_executor() -> Executor<'static> { - let ex = smol::Executor::new(); - thread::Builder::new() - .spawn(|| loop { - catch_unwind(|| { - smol::block_on(global_executor().run(smol::future::pending::<()>())) - }) - .ok(); - }) - .expect("cannot spawn executor thread"); - // Prevent spawning another thread by running the process driver on this - // thread. see https://github.com/smol-rs/smol/blob/master/src/spawn.rs - ex.spawn(async_process::driver()).detach(); - Arc::new(ex) - } - - GLOBAL_EXECUTOR.get_or_init_blocking(init_executor).clone() -} diff --git a/core/src/async_util/mod.rs b/core/src/async_util/mod.rs index 2916118..54b9607 100644 --- a/core/src/async_util/mod.rs +++ b/core/src/async_util/mod.rs @@ -1,15 +1,15 @@ mod backoff; mod condvar; mod condwait; -mod executor; mod select; +mod sleep; mod task_group; mod timeout; pub use backoff::Backoff; pub use condvar::CondVar; pub use condwait::CondWait; -pub use executor::Executor; pub use select::{select, Either}; +pub use sleep::sleep; pub use task_group::{TaskGroup, TaskResult}; pub use timeout::timeout; diff --git a/core/src/async_util/select.rs b/core/src/async_util/select.rs index 0977fa9..2008cb5 100644 --- a/core/src/async_util/select.rs +++ b/core/src/async_util/select.rs @@ -1,8 +1,8 @@ +use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use pin_project_lite::pin_project; -use smol::future::Future; /// Returns the result of the future that completes first, preferring future1 /// if both are ready. @@ -75,14 +75,16 @@ where #[cfg(test)] mod tests { - use super::{select, Either}; - use smol::Timer; use std::future; + use crate::{async_runtime::block_on, async_util::sleep}; + + use super::{select, Either}; + #[test] fn test_async_select() { - smol::block_on(async move { - let fut = select(Timer::never(), future::ready(0 as u32)).await; + block_on(async move { + let fut = select(sleep(std::time::Duration::MAX), future::ready(0 as u32)).await; assert!(matches!(fut, Either::Right(0))); let fut1 = future::pending::<String>(); diff --git a/core/src/async_util/sleep.rs b/core/src/async_util/sleep.rs new file mode 100644 index 0000000..f72b825 --- /dev/null +++ b/core/src/async_util/sleep.rs @@ -0,0 +1,6 @@ +pub async fn sleep(duration: std::time::Duration) { + #[cfg(feature = "smol")] + smol::Timer::after(duration).await; + #[cfg(feature = "tokio")] + tokio::time::sleep(duration).await; +} diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs index 7f05696..5af75ed 100644 --- a/core/src/async_util/task_group.rs +++ b/core/src/async_util/task_group.rs @@ -1,8 +1,8 @@ use std::{future::Future, sync::Arc, sync::Mutex}; -use async_task::FallibleTask; +use crate::async_runtime::{global_executor, Executor, Task}; -use super::{executor::global_executor, select, CondWait, Either, Executor}; +use super::{select, CondWait, Either}; /// TaskGroup A group that contains spawned tasks. /// @@ -12,28 +12,25 @@ use super::{executor::global_executor, select, CondWait, Either, Executor}; /// /// use std::sync::Arc; /// -/// use karyon_core::async_util::TaskGroup; +/// use karyon_core::async_util::{TaskGroup, sleep}; /// /// async { +/// let group = TaskGroup::new(); /// -/// let ex = Arc::new(smol::Executor::new()); -/// let group = TaskGroup::with_executor(ex); -/// -/// group.spawn(smol::Timer::never(), |_| async {}); +/// group.spawn(sleep(std::time::Duration::MAX), |_| async {}); /// /// group.cancel().await; /// /// }; /// /// ``` -/// -pub struct TaskGroup<'a> { +pub struct TaskGroup { tasks: Mutex<Vec<TaskHandler>>, stop_signal: Arc<CondWait>, - executor: Executor<'a>, + executor: Executor, } -impl TaskGroup<'static> { +impl TaskGroup { /// Creates a new TaskGroup without providing an executor /// /// This will spawn a task onto a global executor (single-threaded by default). @@ -44,11 +41,9 @@ impl TaskGroup<'static> { executor: global_executor(), } } -} -impl<'a> TaskGroup<'a> { /// Creates a new TaskGroup by providing an executor - pub fn with_executor(executor: Executor<'a>) -> Self { + pub fn with_executor(executor: Executor) -> Self { Self { tasks: Mutex::new(Vec::new()), stop_signal: Arc::new(CondWait::new()), @@ -61,10 +56,10 @@ impl<'a> TaskGroup<'a> { /// parameter, indicating whether the task completed or was canceled. pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF) where - T: Send + Sync + 'a, - Fut: Future<Output = T> + Send + 'a, - CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, - CallbackFut: Future<Output = ()> + Send + 'a, + T: Send + Sync + 'static, + Fut: Future<Output = T> + Send + 'static, + CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'static, + CallbackFut: Future<Output = ()> + Send + 'static, { let task = TaskHandler::new( self.executor.clone(), @@ -100,7 +95,7 @@ impl<'a> TaskGroup<'a> { } } -impl Default for TaskGroup<'static> { +impl Default for TaskGroup { fn default() -> Self { Self::new() } @@ -124,42 +119,40 @@ impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> { /// TaskHandler pub struct TaskHandler { - task: FallibleTask<()>, + task: Task<()>, cancel_flag: Arc<CondWait>, } impl<'a> TaskHandler { /// Creates a new task handler fn new<T, Fut, CallbackF, CallbackFut>( - ex: Executor<'a>, + ex: Executor, fut: Fut, callback: CallbackF, stop_signal: Arc<CondWait>, ) -> TaskHandler where - T: Send + Sync + 'a, - Fut: Future<Output = T> + Send + 'a, - CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, - CallbackFut: Future<Output = ()> + Send + 'a, + T: Send + Sync + 'static, + Fut: Future<Output = T> + Send + 'static, + CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'static, + CallbackFut: Future<Output = ()> + Send + 'static, { let cancel_flag = Arc::new(CondWait::new()); let cancel_flag_c = cancel_flag.clone(); - let task = ex - .spawn(async move { - // Waits for either the stop signal or the task to complete. - let result = select(stop_signal.wait(), fut).await; + let task = ex.spawn(async move { + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; - let result = match result { - Either::Left(_) => TaskResult::Cancelled, - Either::Right(res) => TaskResult::Completed(res), - }; + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; - // Call the callback - callback(result).await; + // Call the callback + callback(result).await; - cancel_flag_c.signal().await; - }) - .fallible(); + cancel_flag_c.signal().await; + }); TaskHandler { task, cancel_flag } } @@ -173,14 +166,52 @@ impl<'a> TaskHandler { #[cfg(test)] mod tests { - use super::*; use std::{future, sync::Arc}; + use crate::async_runtime::block_on; + use crate::async_util::sleep; + + use super::*; + + #[cfg(feature = "tokio")] + #[test] + fn test_task_group_with_tokio_executor() { + let ex = Arc::new(tokio::runtime::Runtime::new().unwrap()); + ex.clone().block_on(async move { + let group = Arc::new(TaskGroup::with_executor(ex.into())); + + group.spawn(future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + group.spawn( + async move { + groupc.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + group.cancel().await; + }); + } + + #[cfg(feature = "smol")] #[test] - fn test_task_group_with_executor() { + fn test_task_group_with_smol_executor() { let ex = Arc::new(smol::Executor::new()); smol::block_on(ex.clone().run(async move { - let group = Arc::new(TaskGroup::with_executor(ex)); + let group = Arc::new(TaskGroup::with_executor(ex.into())); group.spawn(future::ready(0), |res| async move { assert!(matches!(res, TaskResult::Completed(0))); @@ -210,7 +241,7 @@ mod tests { #[test] fn test_task_group() { - smol::block_on(async { + block_on(async { let group = Arc::new(TaskGroup::new()); group.spawn(future::ready(0), |res| async move { @@ -234,7 +265,7 @@ mod tests { ); // Do something - smol::Timer::after(std::time::Duration::from_millis(50)).await; + sleep(std::time::Duration::from_millis(50)).await; group.cancel().await; }); } diff --git a/core/src/async_util/timeout.rs b/core/src/async_util/timeout.rs index cf3c490..9ac64c8 100644 --- a/core/src/async_util/timeout.rs +++ b/core/src/async_util/timeout.rs @@ -1,10 +1,9 @@ use std::{future::Future, time::Duration}; -use smol::Timer; - -use super::{select, Either}; use crate::{error::Error, Result}; +use super::{select, sleep, Either}; + /// Waits for a future to complete or times out if it exceeds a specified /// duration. /// @@ -26,7 +25,7 @@ pub async fn timeout<T, F>(delay: Duration, future1: F) -> Result<T> where F: Future<Output = T>, { - let result = select(Timer::after(delay), future1).await; + let result = select(sleep(delay), future1).await; match result { Either::Left(_) => Err(Error::Timeout), @@ -41,11 +40,11 @@ mod tests { #[test] fn test_timeout() { - smol::block_on(async move { + crate::async_runtime::block_on(async move { let fut = future::pending::<()>(); assert!(timeout(Duration::from_millis(10), fut).await.is_err()); - let fut = smol::Timer::after(Duration::from_millis(10)); + let fut = sleep(Duration::from_millis(10)); assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) }); } diff --git a/core/src/error.rs b/core/src/error.rs index cc60696..2b8f641 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -20,11 +20,15 @@ pub enum Error { #[error(transparent)] Ed25519(#[from] ed25519_dalek::ed25519::Error), + #[cfg(feature = "tokio")] + #[error(transparent)] + TokioJoinError(#[from] tokio::task::JoinError), + #[error("Channel Send Error: {0}")] ChannelSend(String), #[error(transparent)] - ChannelRecv(#[from] smol::channel::RecvError), + ChannelRecv(#[from] async_channel::RecvError), #[error(transparent)] BincodeDecode(#[from] bincode::error::DecodeError), @@ -33,8 +37,8 @@ pub enum Error { BincodeEncode(#[from] bincode::error::EncodeError), } -impl<T> From<smol::channel::SendError<T>> for Error { - fn from(error: smol::channel::SendError<T>) -> Self { +impl<T> From<async_channel::SendError<T>> for Error { + fn from(error: async_channel::SendError<T>) -> Self { Error::ChannelSend(error.to_string()) } } diff --git a/core/src/event.rs b/core/src/event.rs index ef40205..e8692ef 100644 --- a/core/src/event.rs +++ b/core/src/event.rs @@ -5,14 +5,11 @@ use std::{ sync::{Arc, Weak}, }; +use async_channel::{Receiver, Sender}; use chrono::{DateTime, Utc}; use log::{error, trace}; -use smol::{ - channel::{Receiver, Sender}, - lock::Mutex, -}; -use crate::{util::random_16, Result}; +use crate::{async_runtime::lock::Mutex, util::random_16, Result}; pub type ArcEventSys<T> = Arc<EventSys<T>>; pub type WeakEventSys<T> = Weak<EventSys<T>>; @@ -139,7 +136,7 @@ where self: &Arc<Self>, topic: &T, ) -> EventListener<T, E> { - let chan = smol::channel::unbounded(); + let chan = async_channel::unbounded(); let topics = &mut self.listeners.lock().await; @@ -310,6 +307,8 @@ pub trait EventValueTopic: EventValueAny + EventValue { #[cfg(test)] mod tests { + use crate::async_runtime::block_on; + use super::*; #[derive(Hash, PartialEq, Eq, Debug, Clone)] @@ -337,11 +336,6 @@ mod tests { } #[derive(Clone, Debug, PartialEq)] - struct D { - d_value: usize, - } - - #[derive(Clone, Debug, PartialEq)] struct E { e_value: usize, } @@ -369,12 +363,6 @@ mod tests { } } - impl EventValue for D { - fn id() -> &'static str { - "D" - } - } - impl EventValue for E { fn id() -> &'static str { "E" @@ -396,7 +384,7 @@ mod tests { #[test] fn test_event_sys() { - smol::block_on(async move { + block_on(async move { let event_sys = EventSys::<Topic>::new(); let a_listener = event_sys.register::<A>(&Topic::TopicA).await; diff --git a/core/src/lib.rs b/core/src/lib.rs index ae88188..62052a8 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,8 +1,13 @@ +#[cfg(all(feature = "smol", feature = "tokio"))] +compile_error!("Only one async runtime feature should be enabled"); + +#[cfg(not(any(feature = "smol", feature = "tokio")))] +compile_error!("At least one async runtime feature must be enabled for this crate."); + /// A set of helper tools and functions. pub mod util; -/// A module containing async utilities that work with the -/// [`smol`](https://github.com/smol-rs/smol) async runtime. +/// A set of async utilities. pub mod async_util; /// Represents karyon's Core Error. @@ -14,8 +19,12 @@ pub mod event; /// A simple publish-subscribe system [`Read More`](./pubsub/struct.Publisher.html) pub mod pubsub; +/// A cross-compatible async runtime +pub mod async_runtime; + #[cfg(feature = "crypto")] + /// Collects common cryptographic tools pub mod crypto; -use error::Result; +pub use error::{Error, Result}; diff --git a/core/src/pubsub.rs b/core/src/pubsub.rs index f5cb69b..bcc24ef 100644 --- a/core/src/pubsub.rs +++ b/core/src/pubsub.rs @@ -1,9 +1,8 @@ use std::{collections::HashMap, sync::Arc}; use log::error; -use smol::lock::Mutex; -use crate::{util::random_16, Result}; +use crate::{async_runtime::lock::Mutex, util::random_16, Result}; pub type ArcPublisher<T> = Arc<Publisher<T>>; pub type SubscriptionID = u16; @@ -28,7 +27,7 @@ pub type SubscriptionID = u16; /// /// ``` pub struct Publisher<T> { - subs: Mutex<HashMap<SubscriptionID, smol::channel::Sender<T>>>, + subs: Mutex<HashMap<SubscriptionID, async_channel::Sender<T>>>, } impl<T: Clone> Publisher<T> { @@ -43,7 +42,7 @@ impl<T: Clone> Publisher<T> { pub async fn subscribe(self: &Arc<Self>) -> Subscription<T> { let mut subs = self.subs.lock().await; - let chan = smol::channel::unbounded(); + let chan = async_channel::unbounded(); let mut sub_id = random_16(); @@ -84,7 +83,7 @@ impl<T: Clone> Publisher<T> { // Subscription pub struct Subscription<T> { id: SubscriptionID, - recv_chan: smol::channel::Receiver<T>, + recv_chan: async_channel::Receiver<T>, publisher: ArcPublisher<T>, } @@ -93,7 +92,7 @@ impl<T: Clone> Subscription<T> { pub fn new( id: SubscriptionID, publisher: ArcPublisher<T>, - recv_chan: smol::channel::Receiver<T>, + recv_chan: async_channel::Receiver<T>, ) -> Subscription<T> { Self { id, diff --git a/core/src/util/encode.rs b/core/src/util/encode.rs index 7d1061b..bf63671 100644 --- a/core/src/util/encode.rs +++ b/core/src/util/encode.rs @@ -1,15 +1,14 @@ use bincode::Encode; -use crate::Result; +use crate::{Error, Result}; /// Encode the given type `T` into a `Vec<u8>`. -pub fn encode<T: Encode>(msg: &T) -> Result<Vec<u8>> { - let vec = bincode::encode_to_vec(msg, bincode::config::standard())?; +pub fn encode<T: Encode>(src: &T) -> Result<Vec<u8>> { + let vec = bincode::encode_to_vec(src, bincode::config::standard())?; Ok(vec) } /// Encode the given type `T` into the given slice.. -pub fn encode_into_slice<T: Encode>(msg: &T, dst: &mut [u8]) -> Result<()> { - bincode::encode_into_slice(msg, dst, bincode::config::standard())?; - Ok(()) +pub fn encode_into_slice<T: Encode>(src: &T, dst: &mut [u8]) -> Result<usize> { + bincode::encode_into_slice(src, dst, bincode::config::standard()).map_err(Error::from) } |