aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-04-11 10:19:20 +0200
committerhozan23 <hozan23@karyontech.net>2024-05-19 13:51:30 +0200
commit0992071a7f1a36424bcfaf1fbc84541ea041df1a (patch)
tree961d73218af672797d49f899289bef295bc56493 /core
parenta69917ecd8272a4946cfd12c75bf8f8c075b0e50 (diff)
add support for tokio & improve net crate api
Diffstat (limited to 'core')
-rw-r--r--core/Cargo.toml28
-rw-r--r--core/src/async_runtime/executor.rs100
-rw-r--r--core/src/async_runtime/io.rs9
-rw-r--r--core/src/async_runtime/lock.rs5
-rw-r--r--core/src/async_runtime/mod.rs25
-rw-r--r--core/src/async_runtime/net.rs12
-rw-r--r--core/src/async_runtime/spawn.rs12
-rw-r--r--core/src/async_runtime/task.rs52
-rw-r--r--core/src/async_runtime/timer.rs1
-rw-r--r--core/src/async_util/backoff.rs23
-rw-r--r--core/src/async_util/condvar.rs38
-rw-r--r--core/src/async_util/condwait.rs25
-rw-r--r--core/src/async_util/executor.rs30
-rw-r--r--core/src/async_util/mod.rs4
-rw-r--r--core/src/async_util/select.rs12
-rw-r--r--core/src/async_util/sleep.rs6
-rw-r--r--core/src/async_util/task_group.rs117
-rw-r--r--core/src/async_util/timeout.rs11
-rw-r--r--core/src/error.rs10
-rw-r--r--core/src/event.rs24
-rw-r--r--core/src/lib.rs15
-rw-r--r--core/src/pubsub.rs11
-rw-r--r--core/src/util/encode.rs11
23 files changed, 408 insertions, 173 deletions
diff --git a/core/Cargo.toml b/core/Cargo.toml
index c8e2b8d..4bb7f4f 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -1,13 +1,15 @@
[package]
name = "karyon_core"
-version.workspace = true
+version.workspace = true
edition.workspace = true
-
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+[features]
+default = ["smol"]
+crypto = ["dep:ed25519-dalek"]
+tokio = ["dep:tokio"]
+smol = ["dep:smol", "dep:async-process"]
[dependencies]
-smol = "2.0.0"
pin-project-lite = "0.2.13"
log = "0.4.21"
bincode = "2.0.0-rc.3"
@@ -15,15 +17,15 @@ chrono = "0.4.35"
rand = "0.8.5"
thiserror = "1.0.58"
dirs = "5.0.1"
-async-task = "4.7.0"
-async-lock = "3.3.0"
-async-process = "2.1.0"
-
-ed25519-dalek = { version = "2.1.1", features = ["rand_core"], optional = true}
+async-channel = "2.2.0"
+# crypto feature deps
+ed25519-dalek = { version = "2.1.1", features = ["rand_core"], optional = true }
-[features]
-default = []
-crypto = ["dep:ed25519-dalek"]
-
+# smol feature deps
+async-process = { version = "2.1.0", optional = true }
+smol = { version = "2.0.0", optional = true }
+# tokio feature deps
+tokio = { version = "1.37.0", features = ["full"], optional = true }
+once_cell = "1.19.0"
diff --git a/core/src/async_runtime/executor.rs b/core/src/async_runtime/executor.rs
new file mode 100644
index 0000000..9335f12
--- /dev/null
+++ b/core/src/async_runtime/executor.rs
@@ -0,0 +1,100 @@
+use std::{future::Future, panic::catch_unwind, sync::Arc, thread};
+
+use once_cell::sync::OnceCell;
+
+#[cfg(feature = "smol")]
+pub use smol::Executor as SmolEx;
+
+#[cfg(feature = "tokio")]
+pub use tokio::runtime::Runtime;
+
+use super::Task;
+
+#[derive(Clone)]
+pub struct Executor {
+ #[cfg(feature = "smol")]
+ inner: Arc<SmolEx<'static>>,
+ #[cfg(feature = "tokio")]
+ inner: Arc<Runtime>,
+}
+
+impl Executor {
+ pub fn spawn<T: Send + 'static>(
+ &self,
+ future: impl Future<Output = T> + Send + 'static,
+ ) -> Task<T> {
+ self.inner.spawn(future).into()
+ }
+}
+
+static GLOBAL_EXECUTOR: OnceCell<Executor> = OnceCell::new();
+
+/// Returns a single-threaded global executor
+pub fn global_executor() -> Executor {
+ #[cfg(feature = "smol")]
+ fn init_executor() -> Executor {
+ let ex = smol::Executor::new();
+ thread::Builder::new()
+ .name("smol-executor".to_string())
+ .spawn(|| loop {
+ catch_unwind(|| {
+ smol::block_on(global_executor().inner.run(std::future::pending::<()>()))
+ })
+ .ok();
+ })
+ .expect("cannot spawn executor thread");
+ // Prevent spawning another thread by running the process driver on this
+ // thread. see https://github.com/smol-rs/smol/blob/master/src/spawn.rs
+ ex.spawn(async_process::driver()).detach();
+ Executor {
+ inner: Arc::new(ex),
+ }
+ }
+
+ #[cfg(feature = "tokio")]
+ fn init_executor() -> Executor {
+ let ex = Arc::new(tokio::runtime::Runtime::new().expect("cannot build tokio runtime"));
+ let ex_cloned = ex.clone();
+ thread::Builder::new()
+ .name("tokio-executor".to_string())
+ .spawn(move || {
+ catch_unwind(|| ex_cloned.block_on(std::future::pending::<()>())).ok();
+ })
+ .expect("cannot spawn tokio runtime thread");
+ Executor { inner: ex }
+ }
+
+ GLOBAL_EXECUTOR.get_or_init(init_executor).clone()
+}
+
+#[cfg(feature = "smol")]
+impl From<Arc<smol::Executor<'static>>> for Executor {
+ fn from(ex: Arc<smol::Executor<'static>>) -> Executor {
+ Executor { inner: ex }
+ }
+}
+
+#[cfg(feature = "tokio")]
+impl From<Arc<tokio::runtime::Runtime>> for Executor {
+ fn from(rt: Arc<tokio::runtime::Runtime>) -> Executor {
+ Executor { inner: rt }
+ }
+}
+
+#[cfg(feature = "smol")]
+impl From<smol::Executor<'static>> for Executor {
+ fn from(ex: smol::Executor<'static>) -> Executor {
+ Executor {
+ inner: Arc::new(ex),
+ }
+ }
+}
+
+#[cfg(feature = "tokio")]
+impl From<tokio::runtime::Runtime> for Executor {
+ fn from(rt: tokio::runtime::Runtime) -> Executor {
+ Executor {
+ inner: Arc::new(rt),
+ }
+ }
+}
diff --git a/core/src/async_runtime/io.rs b/core/src/async_runtime/io.rs
new file mode 100644
index 0000000..161c258
--- /dev/null
+++ b/core/src/async_runtime/io.rs
@@ -0,0 +1,9 @@
+#[cfg(feature = "smol")]
+pub use smol::io::{
+ split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf,
+};
+
+#[cfg(feature = "tokio")]
+pub use tokio::io::{
+ split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf,
+};
diff --git a/core/src/async_runtime/lock.rs b/core/src/async_runtime/lock.rs
new file mode 100644
index 0000000..fc84d1d
--- /dev/null
+++ b/core/src/async_runtime/lock.rs
@@ -0,0 +1,5 @@
+#[cfg(feature = "smol")]
+pub use smol::lock::{Mutex, MutexGuard, OnceCell, RwLock};
+
+#[cfg(feature = "tokio")]
+pub use tokio::sync::{Mutex, MutexGuard, OnceCell, RwLock};
diff --git a/core/src/async_runtime/mod.rs b/core/src/async_runtime/mod.rs
new file mode 100644
index 0000000..d91d01b
--- /dev/null
+++ b/core/src/async_runtime/mod.rs
@@ -0,0 +1,25 @@
+mod executor;
+pub mod io;
+pub mod lock;
+pub mod net;
+mod spawn;
+mod task;
+mod timer;
+
+pub use executor::{global_executor, Executor};
+pub use spawn::spawn;
+pub use task::Task;
+
+#[cfg(test)]
+pub fn block_on<T>(future: impl std::future::Future<Output = T>) -> T {
+ #[cfg(feature = "smol")]
+ let result = smol::block_on(future);
+ #[cfg(feature = "tokio")]
+ let result = tokio::runtime::Builder::new_current_thread()
+ .enable_all()
+ .build()
+ .unwrap()
+ .block_on(future);
+
+ result
+}
diff --git a/core/src/async_runtime/net.rs b/core/src/async_runtime/net.rs
new file mode 100644
index 0000000..5c004ce
--- /dev/null
+++ b/core/src/async_runtime/net.rs
@@ -0,0 +1,12 @@
+pub use std::os::unix::net::SocketAddr;
+
+#[cfg(feature = "smol")]
+pub use smol::net::{
+ unix::{SocketAddr as UnixSocketAddr, UnixListener, UnixStream},
+ TcpListener, TcpStream, UdpSocket,
+};
+
+#[cfg(feature = "tokio")]
+pub use tokio::net::{
+ unix::SocketAddr as UnixSocketAddr, TcpListener, TcpStream, UdpSocket, UnixListener, UnixStream,
+};
diff --git a/core/src/async_runtime/spawn.rs b/core/src/async_runtime/spawn.rs
new file mode 100644
index 0000000..2760982
--- /dev/null
+++ b/core/src/async_runtime/spawn.rs
@@ -0,0 +1,12 @@
+use std::future::Future;
+
+use super::Task;
+
+pub fn spawn<T: Send + 'static>(future: impl Future<Output = T> + Send + 'static) -> Task<T> {
+ #[cfg(feature = "smol")]
+ let result: Task<T> = smol::spawn(future).into();
+ #[cfg(feature = "tokio")]
+ let result: Task<T> = tokio::spawn(future).into();
+
+ result
+}
diff --git a/core/src/async_runtime/task.rs b/core/src/async_runtime/task.rs
new file mode 100644
index 0000000..a681b0f
--- /dev/null
+++ b/core/src/async_runtime/task.rs
@@ -0,0 +1,52 @@
+use std::future::Future;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use crate::error::Error;
+
+pub struct Task<T> {
+ #[cfg(feature = "smol")]
+ inner_task: smol::Task<T>,
+ #[cfg(feature = "tokio")]
+ inner_task: tokio::task::JoinHandle<T>,
+}
+
+impl<T> Task<T> {
+ pub async fn cancel(self) {
+ #[cfg(feature = "smol")]
+ self.inner_task.cancel().await;
+ #[cfg(feature = "tokio")]
+ self.inner_task.abort();
+ }
+}
+
+impl<T> Future for Task<T> {
+ type Output = Result<T, Error>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ #[cfg(feature = "smol")]
+ let result = smol::Task::poll(Pin::new(&mut self.inner_task), cx);
+ #[cfg(feature = "tokio")]
+ let result = tokio::task::JoinHandle::poll(Pin::new(&mut self.inner_task), cx);
+
+ #[cfg(feature = "smol")]
+ return result.map(Ok);
+
+ #[cfg(feature = "tokio")]
+ return result.map_err(|e| e.into());
+ }
+}
+
+#[cfg(feature = "smol")]
+impl<T> From<smol::Task<T>> for Task<T> {
+ fn from(t: smol::Task<T>) -> Task<T> {
+ Task { inner_task: t }
+ }
+}
+
+#[cfg(feature = "tokio")]
+impl<T> From<tokio::task::JoinHandle<T>> for Task<T> {
+ fn from(t: tokio::task::JoinHandle<T>) -> Task<T> {
+ Task { inner_task: t }
+ }
+}
diff --git a/core/src/async_runtime/timer.rs b/core/src/async_runtime/timer.rs
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/core/src/async_runtime/timer.rs
@@ -0,0 +1 @@
+
diff --git a/core/src/async_util/backoff.rs b/core/src/async_util/backoff.rs
index 4a0ab35..70e63b3 100644
--- a/core/src/async_util/backoff.rs
+++ b/core/src/async_util/backoff.rs
@@ -4,7 +4,7 @@ use std::{
time::Duration,
};
-use smol::Timer;
+use super::sleep;
/// Exponential backoff
/// <https://en.wikipedia.org/wiki/Exponential_backoff>
@@ -57,7 +57,7 @@ impl Backoff {
/// Retruns the delay value.
pub async fn sleep(&self) -> u64 {
if self.stop.load(Ordering::SeqCst) {
- Timer::after(Duration::from_millis(self.max_delay)).await;
+ sleep(Duration::from_millis(self.max_delay)).await;
return self.max_delay;
}
@@ -71,7 +71,7 @@ impl Backoff {
self.retries.store(retries + 1, Ordering::SeqCst);
- Timer::after(Duration::from_millis(delay)).await;
+ sleep(Duration::from_millis(delay)).await;
delay
}
@@ -84,15 +84,18 @@ impl Backoff {
#[cfg(test)]
mod tests {
- use super::*;
use std::sync::Arc;
+ use crate::async_runtime::{block_on, spawn};
+
+ use super::*;
+
#[test]
fn test_backoff() {
- smol::block_on(async move {
+ block_on(async move {
let backoff = Arc::new(Backoff::new(5, 15));
let backoff_c = backoff.clone();
- smol::spawn(async move {
+ spawn(async move {
let delay = backoff_c.sleep().await;
assert_eq!(delay, 5);
@@ -102,14 +105,16 @@ mod tests {
let delay = backoff_c.sleep().await;
assert_eq!(delay, 15);
})
- .await;
+ .await
+ .unwrap();
- smol::spawn(async move {
+ spawn(async move {
backoff.reset();
let delay = backoff.sleep().await;
assert_eq!(delay, 5);
})
- .await;
+ .await
+ .unwrap();
});
}
}
diff --git a/core/src/async_util/condvar.rs b/core/src/async_util/condvar.rs
index d3bc15b..c3f373d 100644
--- a/core/src/async_util/condvar.rs
+++ b/core/src/async_util/condvar.rs
@@ -6,9 +6,7 @@ use std::{
task::{Context, Poll, Waker},
};
-use smol::lock::MutexGuard;
-
-use crate::util::random_16;
+use crate::{async_runtime::lock::MutexGuard, util::random_16};
/// CondVar is an async version of <https://doc.rust-lang.org/std/sync/struct.Condvar.html>
///
@@ -17,9 +15,8 @@ use crate::util::random_16;
///```
/// use std::sync::Arc;
///
-/// use smol::lock::Mutex;
-///
/// use karyon_core::async_util::CondVar;
+/// use karyon_core::async_runtime::{spawn, lock::Mutex};
///
/// async {
///
@@ -28,7 +25,7 @@ use crate::util::random_16;
///
/// let val_cloned = val.clone();
/// let condvar_cloned = condvar.clone();
-/// smol::spawn(async move {
+/// spawn(async move {
/// let mut val = val_cloned.lock().await;
///
/// // While the boolean flag is false, wait for a signal.
@@ -40,7 +37,7 @@ use crate::util::random_16;
/// });
///
/// let condvar_cloned = condvar.clone();
-/// smol::spawn(async move {
+/// spawn(async move {
/// let mut val = val.lock().await;
///
/// // While the boolean flag is false, wait for a signal.
@@ -71,7 +68,10 @@ impl CondVar {
/// Blocks the current task until this condition variable receives a notification.
pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
+ #[cfg(feature = "smol")]
let m = MutexGuard::source(&g);
+ #[cfg(feature = "tokio")]
+ let m = MutexGuard::mutex(&g);
CondVarAwait::new(self, g).await;
@@ -206,8 +206,6 @@ impl Wakers {
#[cfg(test)]
mod tests {
- use super::*;
- use smol::lock::Mutex;
use std::{
collections::VecDeque,
sync::{
@@ -216,6 +214,10 @@ mod tests {
},
};
+ use crate::async_runtime::{block_on, lock::Mutex, spawn};
+
+ use super::*;
+
// The tests below demonstrate a solution to a problem in the Wikipedia
// explanation of condition variables:
// https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem.
@@ -243,7 +245,7 @@ mod tests {
#[test]
fn test_condvar_signal() {
- smol::block_on(async {
+ block_on(async {
let number_of_tasks = 30;
let queue = Arc::new(Mutex::new(Queue::new(5)));
@@ -254,7 +256,7 @@ mod tests {
let condvar_full_cloned = condvar_full.clone();
let condvar_empty_cloned = condvar_empty.clone();
- let _producer1 = smol::spawn(async move {
+ let _producer1 = spawn(async move {
for i in 1..number_of_tasks {
// Lock queue mtuex
let mut queue = queue_cloned.lock().await;
@@ -275,7 +277,7 @@ mod tests {
let queue_cloned = queue.clone();
let task_consumed = Arc::new(AtomicUsize::new(0));
let task_consumed_ = task_consumed.clone();
- let consumer = smol::spawn(async move {
+ let consumer = spawn(async move {
for _ in 1..number_of_tasks {
// Lock queue mtuex
let mut queue = queue_cloned.lock().await;
@@ -297,7 +299,7 @@ mod tests {
}
});
- consumer.await;
+ let _ = consumer.await;
assert!(queue.lock().await.is_empty());
assert_eq!(task_consumed.load(Ordering::Relaxed), 29);
});
@@ -305,7 +307,7 @@ mod tests {
#[test]
fn test_condvar_broadcast() {
- smol::block_on(async {
+ block_on(async {
let tasks = 30;
let queue = Arc::new(Mutex::new(Queue::new(5)));
@@ -313,7 +315,7 @@ mod tests {
let queue_cloned = queue.clone();
let condvar_cloned = condvar.clone();
- let _producer1 = smol::spawn(async move {
+ let _producer1 = spawn(async move {
for i in 1..tasks {
// Lock queue mtuex
let mut queue = queue_cloned.lock().await;
@@ -333,7 +335,7 @@ mod tests {
let queue_cloned = queue.clone();
let condvar_cloned = condvar.clone();
- let _producer2 = smol::spawn(async move {
+ let _producer2 = spawn(async move {
for i in 1..tasks {
// Lock queue mtuex
let mut queue = queue_cloned.lock().await;
@@ -355,7 +357,7 @@ mod tests {
let task_consumed = Arc::new(AtomicUsize::new(0));
let task_consumed_ = task_consumed.clone();
- let consumer = smol::spawn(async move {
+ let consumer = spawn(async move {
for _ in 1..((tasks * 2) - 1) {
{
// Lock queue mutex
@@ -379,7 +381,7 @@ mod tests {
}
});
- consumer.await;
+ let _ = consumer.await;
assert!(queue.lock().await.is_empty());
assert_eq!(task_consumed.load(Ordering::Relaxed), 58);
});
diff --git a/core/src/async_util/condwait.rs b/core/src/async_util/condwait.rs
index 6aa8a3c..76c6a05 100644
--- a/core/src/async_util/condwait.rs
+++ b/core/src/async_util/condwait.rs
@@ -1,6 +1,5 @@
-use smol::lock::Mutex;
-
use super::CondVar;
+use crate::async_runtime::lock::Mutex;
/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag.
///
@@ -10,11 +9,12 @@ use super::CondVar;
/// use std::sync::Arc;
///
/// use karyon_core::async_util::CondWait;
+/// use karyon_core::async_runtime::spawn;
///
/// async {
/// let cond_wait = Arc::new(CondWait::new());
/// let cond_wait_cloned = cond_wait.clone();
-/// let task = smol::spawn(async move {
+/// let task = spawn(async move {
/// cond_wait_cloned.wait().await;
/// // ...
/// });
@@ -76,21 +76,24 @@ impl Default for CondWait {
#[cfg(test)]
mod tests {
- use super::*;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
+ use crate::async_runtime::{block_on, spawn};
+
+ use super::*;
+
#[test]
fn test_cond_wait() {
- smol::block_on(async {
+ block_on(async {
let cond_wait = Arc::new(CondWait::new());
let count = Arc::new(AtomicUsize::new(0));
let cond_wait_cloned = cond_wait.clone();
let count_cloned = count.clone();
- let task = smol::spawn(async move {
+ let task = spawn(async move {
cond_wait_cloned.wait().await;
count_cloned.fetch_add(1, Ordering::Relaxed);
// do something
@@ -99,7 +102,7 @@ mod tests {
// Send a signal to the waiting task
cond_wait.signal().await;
- task.await;
+ let _ = task.await;
// Reset the boolean flag
cond_wait.reset().await;
@@ -108,7 +111,7 @@ mod tests {
let cond_wait_cloned = cond_wait.clone();
let count_cloned = count.clone();
- let task1 = smol::spawn(async move {
+ let task1 = spawn(async move {
cond_wait_cloned.wait().await;
count_cloned.fetch_add(1, Ordering::Relaxed);
// do something
@@ -116,7 +119,7 @@ mod tests {
let cond_wait_cloned = cond_wait.clone();
let count_cloned = count.clone();
- let task2 = smol::spawn(async move {
+ let task2 = spawn(async move {
cond_wait_cloned.wait().await;
count_cloned.fetch_add(1, Ordering::Relaxed);
// do something
@@ -125,8 +128,8 @@ mod tests {
// Broadcast a signal to all waiting tasks
cond_wait.broadcast().await;
- task1.await;
- task2.await;
+ let _ = task1.await;
+ let _ = task2.await;
assert_eq!(count.load(Ordering::Relaxed), 3);
});
}
diff --git a/core/src/async_util/executor.rs b/core/src/async_util/executor.rs
deleted file mode 100644
index 3e7aa06..0000000
--- a/core/src/async_util/executor.rs
+++ /dev/null
@@ -1,30 +0,0 @@
-use std::{panic::catch_unwind, sync::Arc, thread};
-
-use async_lock::OnceCell;
-use smol::Executor as SmolEx;
-
-static GLOBAL_EXECUTOR: OnceCell<Arc<smol::Executor<'_>>> = OnceCell::new();
-
-/// A pointer to an Executor
-pub type Executor<'a> = Arc<SmolEx<'a>>;
-
-/// Returns a single-threaded global executor
-pub(crate) fn global_executor() -> Executor<'static> {
- fn init_executor() -> Executor<'static> {
- let ex = smol::Executor::new();
- thread::Builder::new()
- .spawn(|| loop {
- catch_unwind(|| {
- smol::block_on(global_executor().run(smol::future::pending::<()>()))
- })
- .ok();
- })
- .expect("cannot spawn executor thread");
- // Prevent spawning another thread by running the process driver on this
- // thread. see https://github.com/smol-rs/smol/blob/master/src/spawn.rs
- ex.spawn(async_process::driver()).detach();
- Arc::new(ex)
- }
-
- GLOBAL_EXECUTOR.get_or_init_blocking(init_executor).clone()
-}
diff --git a/core/src/async_util/mod.rs b/core/src/async_util/mod.rs
index 2916118..54b9607 100644
--- a/core/src/async_util/mod.rs
+++ b/core/src/async_util/mod.rs
@@ -1,15 +1,15 @@
mod backoff;
mod condvar;
mod condwait;
-mod executor;
mod select;
+mod sleep;
mod task_group;
mod timeout;
pub use backoff::Backoff;
pub use condvar::CondVar;
pub use condwait::CondWait;
-pub use executor::Executor;
pub use select::{select, Either};
+pub use sleep::sleep;
pub use task_group::{TaskGroup, TaskResult};
pub use timeout::timeout;
diff --git a/core/src/async_util/select.rs b/core/src/async_util/select.rs
index 0977fa9..2008cb5 100644
--- a/core/src/async_util/select.rs
+++ b/core/src/async_util/select.rs
@@ -1,8 +1,8 @@
+use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use pin_project_lite::pin_project;
-use smol::future::Future;
/// Returns the result of the future that completes first, preferring future1
/// if both are ready.
@@ -75,14 +75,16 @@ where
#[cfg(test)]
mod tests {
- use super::{select, Either};
- use smol::Timer;
use std::future;
+ use crate::{async_runtime::block_on, async_util::sleep};
+
+ use super::{select, Either};
+
#[test]
fn test_async_select() {
- smol::block_on(async move {
- let fut = select(Timer::never(), future::ready(0 as u32)).await;
+ block_on(async move {
+ let fut = select(sleep(std::time::Duration::MAX), future::ready(0 as u32)).await;
assert!(matches!(fut, Either::Right(0)));
let fut1 = future::pending::<String>();
diff --git a/core/src/async_util/sleep.rs b/core/src/async_util/sleep.rs
new file mode 100644
index 0000000..f72b825
--- /dev/null
+++ b/core/src/async_util/sleep.rs
@@ -0,0 +1,6 @@
+pub async fn sleep(duration: std::time::Duration) {
+ #[cfg(feature = "smol")]
+ smol::Timer::after(duration).await;
+ #[cfg(feature = "tokio")]
+ tokio::time::sleep(duration).await;
+}
diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs
index 7f05696..5af75ed 100644
--- a/core/src/async_util/task_group.rs
+++ b/core/src/async_util/task_group.rs
@@ -1,8 +1,8 @@
use std::{future::Future, sync::Arc, sync::Mutex};
-use async_task::FallibleTask;
+use crate::async_runtime::{global_executor, Executor, Task};
-use super::{executor::global_executor, select, CondWait, Either, Executor};
+use super::{select, CondWait, Either};
/// TaskGroup A group that contains spawned tasks.
///
@@ -12,28 +12,25 @@ use super::{executor::global_executor, select, CondWait, Either, Executor};
///
/// use std::sync::Arc;
///
-/// use karyon_core::async_util::TaskGroup;
+/// use karyon_core::async_util::{TaskGroup, sleep};
///
/// async {
+/// let group = TaskGroup::new();
///
-/// let ex = Arc::new(smol::Executor::new());
-/// let group = TaskGroup::with_executor(ex);
-///
-/// group.spawn(smol::Timer::never(), |_| async {});
+/// group.spawn(sleep(std::time::Duration::MAX), |_| async {});
///
/// group.cancel().await;
///
/// };
///
/// ```
-///
-pub struct TaskGroup<'a> {
+pub struct TaskGroup {
tasks: Mutex<Vec<TaskHandler>>,
stop_signal: Arc<CondWait>,
- executor: Executor<'a>,
+ executor: Executor,
}
-impl TaskGroup<'static> {
+impl TaskGroup {
/// Creates a new TaskGroup without providing an executor
///
/// This will spawn a task onto a global executor (single-threaded by default).
@@ -44,11 +41,9 @@ impl TaskGroup<'static> {
executor: global_executor(),
}
}
-}
-impl<'a> TaskGroup<'a> {
/// Creates a new TaskGroup by providing an executor
- pub fn with_executor(executor: Executor<'a>) -> Self {
+ pub fn with_executor(executor: Executor) -> Self {
Self {
tasks: Mutex::new(Vec::new()),
stop_signal: Arc::new(CondWait::new()),
@@ -61,10 +56,10 @@ impl<'a> TaskGroup<'a> {
/// parameter, indicating whether the task completed or was canceled.
pub fn spawn<T, Fut, CallbackF, CallbackFut>(&self, fut: Fut, callback: CallbackF)
where
- T: Send + Sync + 'a,
- Fut: Future<Output = T> + Send + 'a,
- CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a,
- CallbackFut: Future<Output = ()> + Send + 'a,
+ T: Send + Sync + 'static,
+ Fut: Future<Output = T> + Send + 'static,
+ CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'static,
+ CallbackFut: Future<Output = ()> + Send + 'static,
{
let task = TaskHandler::new(
self.executor.clone(),
@@ -100,7 +95,7 @@ impl<'a> TaskGroup<'a> {
}
}
-impl Default for TaskGroup<'static> {
+impl Default for TaskGroup {
fn default() -> Self {
Self::new()
}
@@ -124,42 +119,40 @@ impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> {
/// TaskHandler
pub struct TaskHandler {
- task: FallibleTask<()>,
+ task: Task<()>,
cancel_flag: Arc<CondWait>,
}
impl<'a> TaskHandler {
/// Creates a new task handler
fn new<T, Fut, CallbackF, CallbackFut>(
- ex: Executor<'a>,
+ ex: Executor,
fut: Fut,
callback: CallbackF,
stop_signal: Arc<CondWait>,
) -> TaskHandler
where
- T: Send + Sync + 'a,
- Fut: Future<Output = T> + Send + 'a,
- CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a,
- CallbackFut: Future<Output = ()> + Send + 'a,
+ T: Send + Sync + 'static,
+ Fut: Future<Output = T> + Send + 'static,
+ CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'static,
+ CallbackFut: Future<Output = ()> + Send + 'static,
{
let cancel_flag = Arc::new(CondWait::new());
let cancel_flag_c = cancel_flag.clone();
- let task = ex
- .spawn(async move {
- // Waits for either the stop signal or the task to complete.
- let result = select(stop_signal.wait(), fut).await;
+ let task = ex.spawn(async move {
+ // Waits for either the stop signal or the task to complete.
+ let result = select(stop_signal.wait(), fut).await;
- let result = match result {
- Either::Left(_) => TaskResult::Cancelled,
- Either::Right(res) => TaskResult::Completed(res),
- };
+ let result = match result {
+ Either::Left(_) => TaskResult::Cancelled,
+ Either::Right(res) => TaskResult::Completed(res),
+ };
- // Call the callback
- callback(result).await;
+ // Call the callback
+ callback(result).await;
- cancel_flag_c.signal().await;
- })
- .fallible();
+ cancel_flag_c.signal().await;
+ });
TaskHandler { task, cancel_flag }
}
@@ -173,14 +166,52 @@ impl<'a> TaskHandler {
#[cfg(test)]
mod tests {
- use super::*;
use std::{future, sync::Arc};
+ use crate::async_runtime::block_on;
+ use crate::async_util::sleep;
+
+ use super::*;
+
+ #[cfg(feature = "tokio")]
+ #[test]
+ fn test_task_group_with_tokio_executor() {
+ let ex = Arc::new(tokio::runtime::Runtime::new().unwrap());
+ ex.clone().block_on(async move {
+ let group = Arc::new(TaskGroup::with_executor(ex.into()));
+
+ group.spawn(future::ready(0), |res| async move {
+ assert!(matches!(res, TaskResult::Completed(0)));
+ });
+
+ group.spawn(future::pending::<()>(), |res| async move {
+ assert!(matches!(res, TaskResult::Cancelled));
+ });
+
+ let groupc = group.clone();
+ group.spawn(
+ async move {
+ groupc.spawn(future::pending::<()>(), |res| async move {
+ assert!(matches!(res, TaskResult::Cancelled));
+ });
+ },
+ |res| async move {
+ assert!(matches!(res, TaskResult::Completed(_)));
+ },
+ );
+
+ // Do something
+ tokio::time::sleep(std::time::Duration::from_millis(50)).await;
+ group.cancel().await;
+ });
+ }
+
+ #[cfg(feature = "smol")]
#[test]
- fn test_task_group_with_executor() {
+ fn test_task_group_with_smol_executor() {
let ex = Arc::new(smol::Executor::new());
smol::block_on(ex.clone().run(async move {
- let group = Arc::new(TaskGroup::with_executor(ex));
+ let group = Arc::new(TaskGroup::with_executor(ex.into()));
group.spawn(future::ready(0), |res| async move {
assert!(matches!(res, TaskResult::Completed(0)));
@@ -210,7 +241,7 @@ mod tests {
#[test]
fn test_task_group() {
- smol::block_on(async {
+ block_on(async {
let group = Arc::new(TaskGroup::new());
group.spawn(future::ready(0), |res| async move {
@@ -234,7 +265,7 @@ mod tests {
);
// Do something
- smol::Timer::after(std::time::Duration::from_millis(50)).await;
+ sleep(std::time::Duration::from_millis(50)).await;
group.cancel().await;
});
}
diff --git a/core/src/async_util/timeout.rs b/core/src/async_util/timeout.rs
index cf3c490..9ac64c8 100644
--- a/core/src/async_util/timeout.rs
+++ b/core/src/async_util/timeout.rs
@@ -1,10 +1,9 @@
use std::{future::Future, time::Duration};
-use smol::Timer;
-
-use super::{select, Either};
use crate::{error::Error, Result};
+use super::{select, sleep, Either};
+
/// Waits for a future to complete or times out if it exceeds a specified
/// duration.
///
@@ -26,7 +25,7 @@ pub async fn timeout<T, F>(delay: Duration, future1: F) -> Result<T>
where
F: Future<Output = T>,
{
- let result = select(Timer::after(delay), future1).await;
+ let result = select(sleep(delay), future1).await;
match result {
Either::Left(_) => Err(Error::Timeout),
@@ -41,11 +40,11 @@ mod tests {
#[test]
fn test_timeout() {
- smol::block_on(async move {
+ crate::async_runtime::block_on(async move {
let fut = future::pending::<()>();
assert!(timeout(Duration::from_millis(10), fut).await.is_err());
- let fut = smol::Timer::after(Duration::from_millis(10));
+ let fut = sleep(Duration::from_millis(10));
assert!(timeout(Duration::from_millis(50), fut).await.is_ok())
});
}
diff --git a/core/src/error.rs b/core/src/error.rs
index cc60696..2b8f641 100644
--- a/core/src/error.rs
+++ b/core/src/error.rs
@@ -20,11 +20,15 @@ pub enum Error {
#[error(transparent)]
Ed25519(#[from] ed25519_dalek::ed25519::Error),
+ #[cfg(feature = "tokio")]
+ #[error(transparent)]
+ TokioJoinError(#[from] tokio::task::JoinError),
+
#[error("Channel Send Error: {0}")]
ChannelSend(String),
#[error(transparent)]
- ChannelRecv(#[from] smol::channel::RecvError),
+ ChannelRecv(#[from] async_channel::RecvError),
#[error(transparent)]
BincodeDecode(#[from] bincode::error::DecodeError),
@@ -33,8 +37,8 @@ pub enum Error {
BincodeEncode(#[from] bincode::error::EncodeError),
}
-impl<T> From<smol::channel::SendError<T>> for Error {
- fn from(error: smol::channel::SendError<T>) -> Self {
+impl<T> From<async_channel::SendError<T>> for Error {
+ fn from(error: async_channel::SendError<T>) -> Self {
Error::ChannelSend(error.to_string())
}
}
diff --git a/core/src/event.rs b/core/src/event.rs
index ef40205..e8692ef 100644
--- a/core/src/event.rs
+++ b/core/src/event.rs
@@ -5,14 +5,11 @@ use std::{
sync::{Arc, Weak},
};
+use async_channel::{Receiver, Sender};
use chrono::{DateTime, Utc};
use log::{error, trace};
-use smol::{
- channel::{Receiver, Sender},
- lock::Mutex,
-};
-use crate::{util::random_16, Result};
+use crate::{async_runtime::lock::Mutex, util::random_16, Result};
pub type ArcEventSys<T> = Arc<EventSys<T>>;
pub type WeakEventSys<T> = Weak<EventSys<T>>;
@@ -139,7 +136,7 @@ where
self: &Arc<Self>,
topic: &T,
) -> EventListener<T, E> {
- let chan = smol::channel::unbounded();
+ let chan = async_channel::unbounded();
let topics = &mut self.listeners.lock().await;
@@ -310,6 +307,8 @@ pub trait EventValueTopic: EventValueAny + EventValue {
#[cfg(test)]
mod tests {
+ use crate::async_runtime::block_on;
+
use super::*;
#[derive(Hash, PartialEq, Eq, Debug, Clone)]
@@ -337,11 +336,6 @@ mod tests {
}
#[derive(Clone, Debug, PartialEq)]
- struct D {
- d_value: usize,
- }
-
- #[derive(Clone, Debug, PartialEq)]
struct E {
e_value: usize,
}
@@ -369,12 +363,6 @@ mod tests {
}
}
- impl EventValue for D {
- fn id() -> &'static str {
- "D"
- }
- }
-
impl EventValue for E {
fn id() -> &'static str {
"E"
@@ -396,7 +384,7 @@ mod tests {
#[test]
fn test_event_sys() {
- smol::block_on(async move {
+ block_on(async move {
let event_sys = EventSys::<Topic>::new();
let a_listener = event_sys.register::<A>(&Topic::TopicA).await;
diff --git a/core/src/lib.rs b/core/src/lib.rs
index ae88188..62052a8 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -1,8 +1,13 @@
+#[cfg(all(feature = "smol", feature = "tokio"))]
+compile_error!("Only one async runtime feature should be enabled");
+
+#[cfg(not(any(feature = "smol", feature = "tokio")))]
+compile_error!("At least one async runtime feature must be enabled for this crate.");
+
/// A set of helper tools and functions.
pub mod util;
-/// A module containing async utilities that work with the
-/// [`smol`](https://github.com/smol-rs/smol) async runtime.
+/// A set of async utilities.
pub mod async_util;
/// Represents karyon's Core Error.
@@ -14,8 +19,12 @@ pub mod event;
/// A simple publish-subscribe system [`Read More`](./pubsub/struct.Publisher.html)
pub mod pubsub;
+/// A cross-compatible async runtime
+pub mod async_runtime;
+
#[cfg(feature = "crypto")]
+
/// Collects common cryptographic tools
pub mod crypto;
-use error::Result;
+pub use error::{Error, Result};
diff --git a/core/src/pubsub.rs b/core/src/pubsub.rs
index f5cb69b..bcc24ef 100644
--- a/core/src/pubsub.rs
+++ b/core/src/pubsub.rs
@@ -1,9 +1,8 @@
use std::{collections::HashMap, sync::Arc};
use log::error;
-use smol::lock::Mutex;
-use crate::{util::random_16, Result};
+use crate::{async_runtime::lock::Mutex, util::random_16, Result};
pub type ArcPublisher<T> = Arc<Publisher<T>>;
pub type SubscriptionID = u16;
@@ -28,7 +27,7 @@ pub type SubscriptionID = u16;
///
/// ```
pub struct Publisher<T> {
- subs: Mutex<HashMap<SubscriptionID, smol::channel::Sender<T>>>,
+ subs: Mutex<HashMap<SubscriptionID, async_channel::Sender<T>>>,
}
impl<T: Clone> Publisher<T> {
@@ -43,7 +42,7 @@ impl<T: Clone> Publisher<T> {
pub async fn subscribe(self: &Arc<Self>) -> Subscription<T> {
let mut subs = self.subs.lock().await;
- let chan = smol::channel::unbounded();
+ let chan = async_channel::unbounded();
let mut sub_id = random_16();
@@ -84,7 +83,7 @@ impl<T: Clone> Publisher<T> {
// Subscription
pub struct Subscription<T> {
id: SubscriptionID,
- recv_chan: smol::channel::Receiver<T>,
+ recv_chan: async_channel::Receiver<T>,
publisher: ArcPublisher<T>,
}
@@ -93,7 +92,7 @@ impl<T: Clone> Subscription<T> {
pub fn new(
id: SubscriptionID,
publisher: ArcPublisher<T>,
- recv_chan: smol::channel::Receiver<T>,
+ recv_chan: async_channel::Receiver<T>,
) -> Subscription<T> {
Self {
id,
diff --git a/core/src/util/encode.rs b/core/src/util/encode.rs
index 7d1061b..bf63671 100644
--- a/core/src/util/encode.rs
+++ b/core/src/util/encode.rs
@@ -1,15 +1,14 @@
use bincode::Encode;
-use crate::Result;
+use crate::{Error, Result};
/// Encode the given type `T` into a `Vec<u8>`.
-pub fn encode<T: Encode>(msg: &T) -> Result<Vec<u8>> {
- let vec = bincode::encode_to_vec(msg, bincode::config::standard())?;
+pub fn encode<T: Encode>(src: &T) -> Result<Vec<u8>> {
+ let vec = bincode::encode_to_vec(src, bincode::config::standard())?;
Ok(vec)
}
/// Encode the given type `T` into the given slice..
-pub fn encode_into_slice<T: Encode>(msg: &T, dst: &mut [u8]) -> Result<()> {
- bincode::encode_into_slice(msg, dst, bincode::config::standard())?;
- Ok(())
+pub fn encode_into_slice<T: Encode>(src: &T, dst: &mut [u8]) -> Result<usize> {
+ bincode::encode_into_slice(src, dst, bincode::config::standard()).map_err(Error::from)
}