From 98a1de91a2dae06323558422c239e5a45fc86e7b Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 28 Nov 2023 22:41:33 +0300 Subject: implement TLS for inbound and outbound connections --- README.md | 6 +- core/Cargo.toml | 6 +- core/src/async_util/backoff.rs | 115 +++++++++++ core/src/async_util/condvar.rs | 387 +++++++++++++++++++++++++++++++++++++ core/src/async_util/condwait.rs | 133 +++++++++++++ core/src/async_util/mod.rs | 13 ++ core/src/async_util/select.rs | 99 ++++++++++ core/src/async_util/task_group.rs | 194 +++++++++++++++++++ core/src/async_util/timeout.rs | 52 +++++ core/src/async_utils/backoff.rs | 115 ----------- core/src/async_utils/condvar.rs | 387 ------------------------------------- core/src/async_utils/condwait.rs | 133 ------------- core/src/async_utils/mod.rs | 13 -- core/src/async_utils/select.rs | 99 ---------- core/src/async_utils/task_group.rs | 194 ------------------- core/src/async_utils/timeout.rs | 52 ----- core/src/error.rs | 6 + core/src/event.rs | 2 +- core/src/key_pair.rs | 189 ++++++++++++++++++ core/src/lib.rs | 9 +- core/src/pubsub.rs | 2 +- core/src/util/decode.rs | 10 + core/src/util/encode.rs | 15 ++ core/src/util/mod.rs | 19 ++ core/src/util/path.rs | 39 ++++ core/src/utils/decode.rs | 10 - core/src/utils/encode.rs | 15 -- core/src/utils/mod.rs | 19 -- core/src/utils/path.rs | 39 ---- jsonrpc/src/client.rs | 2 +- jsonrpc/src/codec.rs | 2 +- jsonrpc/src/server.rs | 2 +- net/Cargo.toml | 3 +- net/src/connection.rs | 9 +- net/src/endpoint.rs | 70 +++++-- net/src/error.rs | 8 +- net/src/lib.rs | 1 + net/src/listener.rs | 5 +- net/src/transports/mod.rs | 1 + net/src/transports/tcp.rs | 2 +- net/src/transports/tls.rs | 140 ++++++++++++++ net/src/transports/udp.rs | 2 +- net/src/transports/unix.rs | 2 +- p2p/Cargo.toml | 6 + p2p/README.md | 11 +- p2p/examples/chat.rs | 7 +- p2p/examples/monitor.rs | 14 +- p2p/examples/net_simulation.sh | 24 +-- p2p/examples/peer.rs | 14 +- p2p/src/backend.rs | 43 +++-- p2p/src/codec.rs | 120 ++++++++++++ p2p/src/config.rs | 7 +- p2p/src/connection.rs | 2 +- p2p/src/connector.rs | 56 ++++-- p2p/src/discovery/lookup.rs | 71 ++++--- p2p/src/discovery/mod.rs | 30 ++- p2p/src/discovery/refresh.rs | 4 +- p2p/src/error.rs | 18 ++ p2p/src/io_codec.rs | 132 ------------- p2p/src/lib.rs | 13 +- p2p/src/listener.rs | 65 +++++-- p2p/src/message.rs | 2 +- p2p/src/monitor.rs | 4 +- p2p/src/peer/mod.rs | 23 +-- p2p/src/peer/peer_id.rs | 17 ++ p2p/src/peer_pool.rs | 46 ++--- p2p/src/protocol.rs | 7 +- p2p/src/protocols/ping.rs | 6 +- p2p/src/routing_table/entry.rs | 2 +- p2p/src/routing_table/mod.rs | 19 +- p2p/src/slots.rs | 2 +- p2p/src/tls_config.rs | 214 ++++++++++++++++++++ p2p/src/utils/mod.rs | 21 -- p2p/src/utils/version.rs | 93 --------- p2p/src/version.rs | 93 +++++++++ 75 files changed, 2275 insertions(+), 1532 deletions(-) create mode 100644 core/src/async_util/backoff.rs create mode 100644 core/src/async_util/condvar.rs create mode 100644 core/src/async_util/condwait.rs create mode 100644 core/src/async_util/mod.rs create mode 100644 core/src/async_util/select.rs create mode 100644 core/src/async_util/task_group.rs create mode 100644 core/src/async_util/timeout.rs delete mode 100644 core/src/async_utils/backoff.rs delete mode 100644 core/src/async_utils/condvar.rs delete mode 100644 core/src/async_utils/condwait.rs delete mode 100644 core/src/async_utils/mod.rs delete mode 100644 core/src/async_utils/select.rs delete mode 100644 core/src/async_utils/task_group.rs delete mode 100644 core/src/async_utils/timeout.rs create mode 100644 core/src/key_pair.rs create mode 100644 core/src/util/decode.rs create mode 100644 core/src/util/encode.rs create mode 100644 core/src/util/mod.rs create mode 100644 core/src/util/path.rs delete mode 100644 core/src/utils/decode.rs delete mode 100644 core/src/utils/encode.rs delete mode 100644 core/src/utils/mod.rs delete mode 100644 core/src/utils/path.rs create mode 100644 net/src/transports/tls.rs create mode 100644 p2p/src/codec.rs delete mode 100644 p2p/src/io_codec.rs create mode 100644 p2p/src/tls_config.rs delete mode 100644 p2p/src/utils/mod.rs delete mode 100644 p2p/src/utils/version.rs create mode 100644 p2p/src/version.rs diff --git a/README.md b/README.md index 3343f85..68ef672 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,9 @@ implementation for building collaborative software. ## Status This project is a work in progress. The current focus is on shipping karyons -crdt and karyons store, along with major changes to the network stack, -including TLS implementation. You can check the -[issues](https://github.com/karyons/karyons/issues) for updates on ongoing tasks. +crdt and karyons store, along with major changes to the network stack. You can +check the [issues](https://github.com/karyons/karyons/issues) for updates on +ongoing tasks. ## Docs diff --git a/core/Cargo.toml b/core/Cargo.toml index ab05288..5a99e2d 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -10,9 +10,13 @@ edition.workspace = true smol = "1.3.0" pin-project-lite = "0.2.13" log = "0.4.20" -bincode = { version="2.0.0-rc.3", features = ["derive"]} +bincode = "2.0.0-rc.3" chrono = "0.4.30" rand = "0.8.5" thiserror = "1.0.47" dirs = "5.0.1" async-task = "4.5.0" +ed25519-dalek = { version = "2.1.0", features = ["rand_core"]} + + + diff --git a/core/src/async_util/backoff.rs b/core/src/async_util/backoff.rs new file mode 100644 index 0000000..a231229 --- /dev/null +++ b/core/src/async_util/backoff.rs @@ -0,0 +1,115 @@ +use std::{ + cmp::min, + sync::atomic::{AtomicBool, AtomicU32, Ordering}, + time::Duration, +}; + +use smol::Timer; + +/// Exponential backoff +/// +/// +/// # Examples +/// +/// ``` +/// use karyons_core::async_util::Backoff; +/// +/// async { +/// let backoff = Backoff::new(300, 3000); +/// +/// loop { +/// backoff.sleep().await; +/// +/// // do something +/// break; +/// } +/// +/// backoff.reset(); +/// +/// // .... +/// }; +/// +/// ``` +/// +pub struct Backoff { + /// The base delay in milliseconds for the initial retry. + base_delay: u64, + /// The max delay in milliseconds allowed for a retry. + max_delay: u64, + /// Atomic counter + retries: AtomicU32, + /// Stop flag + stop: AtomicBool, +} + +impl Backoff { + /// Creates a new Backoff. + pub fn new(base_delay: u64, max_delay: u64) -> Self { + Self { + base_delay, + max_delay, + retries: AtomicU32::new(0), + stop: AtomicBool::new(false), + } + } + + /// Sleep based on the current retry count and delay values. + /// Retruns the delay value. + pub async fn sleep(&self) -> u64 { + if self.stop.load(Ordering::SeqCst) { + Timer::after(Duration::from_millis(self.max_delay)).await; + return self.max_delay; + } + + let retries = self.retries.load(Ordering::SeqCst); + let delay = self.base_delay * (2_u64).pow(retries); + let delay = min(delay, self.max_delay); + + if delay == self.max_delay { + self.stop.store(true, Ordering::SeqCst); + } + + self.retries.store(retries + 1, Ordering::SeqCst); + + Timer::after(Duration::from_millis(delay)).await; + delay + } + + /// Reset the retry counter to 0. + pub fn reset(&self) { + self.retries.store(0, Ordering::SeqCst); + self.stop.store(false, Ordering::SeqCst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_backoff() { + smol::block_on(async move { + let backoff = Arc::new(Backoff::new(5, 15)); + let backoff_c = backoff.clone(); + smol::spawn(async move { + let delay = backoff_c.sleep().await; + assert_eq!(delay, 5); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 10); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 15); + }) + .await; + + smol::spawn(async move { + backoff.reset(); + let delay = backoff.sleep().await; + assert_eq!(delay, 5); + }) + .await; + }); + } +} diff --git a/core/src/async_util/condvar.rs b/core/src/async_util/condvar.rs new file mode 100644 index 0000000..7396d0d --- /dev/null +++ b/core/src/async_util/condvar.rs @@ -0,0 +1,387 @@ +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + sync::Mutex, + task::{Context, Poll, Waker}, +}; + +use smol::lock::MutexGuard; + +use crate::util::random_16; + +/// CondVar is an async version of +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use smol::lock::Mutex; +/// +/// use karyons_core::async_util::CondVar; +/// +/// async { +/// +/// let val = Arc::new(Mutex::new(false)); +/// let condvar = Arc::new(CondVar::new()); +/// +/// let val_cloned = val.clone(); +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val_cloned.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// // Wake up all waiting tasks on this condvar +/// condvar.broadcast(); +/// }; +/// +/// ``` + +pub struct CondVar { + inner: Mutex, +} + +impl CondVar { + /// Creates a new CondVar + pub fn new() -> Self { + Self { + inner: Mutex::new(Wakers::new()), + } + } + + /// Blocks the current task until this condition variable receives a notification. + pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { + let m = MutexGuard::source(&g); + + CondVarAwait::new(self, g).await; + + m.lock().await + } + + /// Wakes up one blocked task waiting on this condvar. + pub fn signal(&self) { + self.inner.lock().unwrap().wake(true); + } + + /// Wakes up all blocked tasks waiting on this condvar. + pub fn broadcast(&self) { + self.inner.lock().unwrap().wake(false); + } +} + +impl Default for CondVar { + fn default() -> Self { + Self::new() + } +} + +struct CondVarAwait<'a, T> { + id: Option, + condvar: &'a CondVar, + guard: Option>, +} + +impl<'a, T> CondVarAwait<'a, T> { + fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self { + Self { + condvar, + guard: Some(guard), + id: None, + } + } +} + +impl<'a, T> Future for CondVarAwait<'a, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut inner = self.condvar.inner.lock().unwrap(); + + match self.guard.take() { + Some(_) => { + // the first pooll will release the Mutexguard + self.id = Some(inner.put(Some(cx.waker().clone()))); + Poll::Pending + } + None => { + // Return Ready if it has already been polled and removed + // from the waker list. + if self.id.is_none() { + return Poll::Ready(()); + } + + let i = self.id.as_ref().unwrap(); + match inner.wakers.get_mut(i).unwrap() { + Some(wk) => { + // This will prevent cloning again + if !wk.will_wake(cx.waker()) { + *wk = cx.waker().clone(); + } + Poll::Pending + } + None => { + inner.delete(i); + self.id = None; + Poll::Ready(()) + } + } + } + } + } +} + +impl<'a, T> Drop for CondVarAwait<'a, T> { + fn drop(&mut self) { + if let Some(id) = self.id { + let mut inner = self.condvar.inner.lock().unwrap(); + if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() { + wk.wake() + } + } + } +} + +/// Wakers is a helper struct to store the task wakers +struct Wakers { + wakers: HashMap>, +} + +impl Wakers { + fn new() -> Self { + Self { + wakers: HashMap::new(), + } + } + + fn put(&mut self, waker: Option) -> u16 { + let mut id: u16; + + id = random_16(); + while self.wakers.contains_key(&id) { + id = random_16(); + } + + self.wakers.insert(id, waker); + id + } + + fn delete(&mut self, id: &u16) -> Option> { + self.wakers.remove(id) + } + + fn wake(&mut self, signal: bool) { + for (_, wk) in self.wakers.iter_mut() { + match wk.take() { + Some(w) => { + w.wake(); + if signal { + break; + } + } + None => continue, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smol::lock::Mutex; + use std::{ + collections::VecDeque, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + }; + + // The tests below demonstrate a solution to a problem in the Wikipedia + // explanation of condition variables: + // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. + + struct Queue { + items: VecDeque, + max_len: usize, + } + impl Queue { + fn new(max_len: usize) -> Self { + Self { + items: VecDeque::new(), + max_len, + } + } + + fn is_full(&self) -> bool { + self.items.len() == self.max_len + } + + fn is_empty(&self) -> bool { + self.items.is_empty() + } + } + + #[test] + fn test_condvar_signal() { + smol::block_on(async { + let number_of_tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar_full = Arc::new(CondVar::new()); + let condvar_empty = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_full_cloned = condvar_full.clone(); + let condvar_empty_cloned = condvar_empty.clone(); + + let _producer1 = smol::spawn(async move { + for i in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_full_cloned.wait(queue).await; + } + + queue.items.push_back(format!("task {i}")); + + // Wake up the consumer + condvar_empty_cloned.signal(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + let consumer = smol::spawn(async move { + for _ in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar_empty.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up the producer + condvar_full.signal(); + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 29); + }); + } + + #[test] + fn test_condvar_broadcast() { + smol::block_on(async { + let tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer1 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer1: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer2 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer2: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + + let consumer = smol::spawn(async move { + for _ in 1..((tasks * 2) - 1) { + { + // Lock queue mutex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up all producer and consumer tasks + condvar.broadcast(); + } + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 58); + }); + } +} diff --git a/core/src/async_util/condwait.rs b/core/src/async_util/condwait.rs new file mode 100644 index 0000000..cd4b269 --- /dev/null +++ b/core/src/async_util/condwait.rs @@ -0,0 +1,133 @@ +use smol::lock::Mutex; + +use super::CondVar; + +/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use karyons_core::async_util::CondWait; +/// +/// async { +/// let cond_wait = Arc::new(CondWait::new()); +/// let cond_wait_cloned = cond_wait.clone(); +/// let task = smol::spawn(async move { +/// cond_wait_cloned.wait().await; +/// // ... +/// }); +/// +/// cond_wait.signal().await; +/// }; +/// +/// ``` +/// +pub struct CondWait { + /// The CondVar + condvar: CondVar, + /// Boolean flag + w: Mutex, +} + +impl CondWait { + /// Creates a new CondWait. + pub fn new() -> Self { + Self { + condvar: CondVar::new(), + w: Mutex::new(false), + } + } + + /// Waits for a signal or broadcast. + pub async fn wait(&self) { + let mut w = self.w.lock().await; + + // While the boolean flag is false, wait for a signal. + while !*w { + w = self.condvar.wait(w).await; + } + } + + /// Signal a waiting task. + pub async fn signal(&self) { + *self.w.lock().await = true; + self.condvar.signal(); + } + + /// Signal all waiting tasks. + pub async fn broadcast(&self) { + *self.w.lock().await = true; + self.condvar.broadcast(); + } + + /// Reset the boolean flag value to false. + pub async fn reset(&self) { + *self.w.lock().await = false; + } +} + +impl Default for CondWait { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; + + #[test] + fn test_cond_wait() { + smol::block_on(async { + let cond_wait = Arc::new(CondWait::new()); + let count = Arc::new(AtomicUsize::new(0)); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + // Send a signal to the waiting task + cond_wait.signal().await; + + task.await; + + // Reset the boolean flag + cond_wait.reset().await; + + assert_eq!(count.load(Ordering::Relaxed), 1); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task1 = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + let cond_wait_cloned = cond_wait.clone(); + let count_cloned = count.clone(); + let task2 = smol::spawn(async move { + cond_wait_cloned.wait().await; + count_cloned.fetch_add(1, Ordering::Relaxed); + // do something + }); + + // Broadcast a signal to all waiting tasks + cond_wait.broadcast().await; + + task1.await; + task2.await; + assert_eq!(count.load(Ordering::Relaxed), 3); + }); + } +} diff --git a/core/src/async_util/mod.rs b/core/src/async_util/mod.rs new file mode 100644 index 0000000..c871bad --- /dev/null +++ b/core/src/async_util/mod.rs @@ -0,0 +1,13 @@ +mod backoff; +mod condvar; +mod condwait; +mod select; +mod task_group; +mod timeout; + +pub use backoff::Backoff; +pub use condvar::CondVar; +pub use condwait::CondWait; +pub use select::{select, Either}; +pub use task_group::{TaskGroup, TaskResult}; +pub use timeout::timeout; diff --git a/core/src/async_util/select.rs b/core/src/async_util/select.rs new file mode 100644 index 0000000..8f2f7f6 --- /dev/null +++ b/core/src/async_util/select.rs @@ -0,0 +1,99 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use pin_project_lite::pin_project; +use smol::future::Future; + +/// Returns the result of the future that completes first, preferring future1 +/// if both are ready. +/// +/// # Examples +/// +/// ``` +/// use std::future; +/// +/// use karyons_core::async_util::{select, Either}; +/// +/// async { +/// let fut1 = future::pending::(); +/// let fut2 = future::ready(0); +/// let res = select(fut1, fut2).await; +/// assert!(matches!(res, Either::Right(0))); +/// // .... +/// }; +/// +/// ``` +/// +pub fn select(future1: F1, future2: F2) -> Select +where + F1: Future, + F2: Future, +{ + Select { future1, future2 } +} + +pin_project! { + #[derive(Debug)] + pub struct Select { + #[pin] + future1: F1, + #[pin] + future2: F2, + } +} + +/// The return value from the [`select`] function, indicating which future +/// completed first. +#[derive(Debug)] +pub enum Either { + Left(T1), + Right(T2), +} + +// Implement the Future trait for the Select struct. +impl Future for Select +where + F1: Future, + F2: Future, +{ + type Output = Either; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Poll::Ready(t) = this.future1.poll(cx) { + return Poll::Ready(Either::Left(t)); + } + + if let Poll::Ready(t) = this.future2.poll(cx) { + return Poll::Ready(Either::Right(t)); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::{select, Either}; + use smol::Timer; + use std::future; + + #[test] + fn test_async_select() { + smol::block_on(async move { + let fut = select(Timer::never(), future::ready(0 as u32)).await; + assert!(matches!(fut, Either::Right(0))); + + let fut1 = future::pending::(); + let fut2 = future::ready(0); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Right(0))); + + let fut1 = future::ready(0); + let fut2 = future::pending::(); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Left(_))); + }); + } +} diff --git a/core/src/async_util/task_group.rs b/core/src/async_util/task_group.rs new file mode 100644 index 0000000..3fc0cb7 --- /dev/null +++ b/core/src/async_util/task_group.rs @@ -0,0 +1,194 @@ +use std::{future::Future, sync::Arc, sync::Mutex}; + +use async_task::FallibleTask; + +use crate::Executor; + +use super::{select, CondWait, Either}; + +/// TaskGroup is a group of spawned tasks. +/// +/// # Example +/// +/// ``` +/// +/// use std::sync::Arc; +/// +/// use karyons_core::async_util::TaskGroup; +/// +/// async { +/// +/// let ex = Arc::new(smol::Executor::new()); +/// let group = TaskGroup::new(ex); +/// +/// group.spawn(smol::Timer::never(), |_| async {}); +/// +/// group.cancel().await; +/// +/// }; +/// +/// ``` +/// +pub struct TaskGroup<'a> { + tasks: Mutex>, + stop_signal: Arc, + executor: Executor<'a>, +} + +impl<'a> TaskGroup<'a> { + /// Creates a new task group + pub fn new(executor: Executor<'a>) -> Self { + Self { + tasks: Mutex::new(Vec::new()), + stop_signal: Arc::new(CondWait::new()), + executor, + } + } + + /// Spawns a new task and calls the callback after it has completed + /// or been canceled. The callback will have the `TaskResult` as a + /// parameter, indicating whether the task completed or was canceled. + pub fn spawn(&self, fut: Fut, callback: CallbackF) + where + T: Send + Sync + 'a, + Fut: Future + Send + 'a, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, + CallbackFut: Future + Send + 'a, + { + let task = TaskHandler::new( + self.executor.clone(), + fut, + callback, + self.stop_signal.clone(), + ); + self.tasks.lock().unwrap().push(task); + } + + /// Checks if the task group is empty. + pub fn is_empty(&self) -> bool { + self.tasks.lock().unwrap().is_empty() + } + + /// Get the number of the tasks in the group. + pub fn len(&self) -> usize { + self.tasks.lock().unwrap().len() + } + + /// Cancels all tasks in the group. + pub async fn cancel(&self) { + self.stop_signal.broadcast().await; + + loop { + let task = self.tasks.lock().unwrap().pop(); + if let Some(t) = task { + t.cancel().await + } else { + break; + } + } + } +} + +/// The result of a spawned task. +#[derive(Debug)] +pub enum TaskResult { + Completed(T), + Cancelled, +} + +impl std::fmt::Display for TaskResult { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TaskResult::Cancelled => write!(f, "Task cancelled"), + TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), + } + } +} + +/// TaskHandler +pub struct TaskHandler { + task: FallibleTask<()>, + cancel_flag: Arc, +} + +impl<'a> TaskHandler { + /// Creates a new task handle + fn new( + ex: Executor<'a>, + fut: Fut, + callback: CallbackF, + stop_signal: Arc, + ) -> TaskHandler + where + T: Send + Sync + 'a, + Fut: Future + Send + 'a, + CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, + CallbackFut: Future + Send + 'a, + { + let cancel_flag = Arc::new(CondWait::new()); + let cancel_flag_c = cancel_flag.clone(); + let task = ex + .spawn(async move { + //start_signal.signal().await; + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; + + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; + + // Call the callback with the result. + callback(result).await; + + cancel_flag_c.signal().await; + }) + .fallible(); + + TaskHandler { task, cancel_flag } + } + + /// Cancels the task. + async fn cancel(self) { + self.cancel_flag.wait().await; + self.task.cancel().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, sync::Arc}; + + #[test] + fn test_task_group() { + let ex = Arc::new(smol::Executor::new()); + smol::block_on(ex.clone().run(async move { + let group = Arc::new(TaskGroup::new(ex)); + + group.spawn(future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + group.spawn( + async move { + groupc.spawn(future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + smol::Timer::after(std::time::Duration::from_millis(50)).await; + group.cancel().await; + })); + } +} diff --git a/core/src/async_util/timeout.rs b/core/src/async_util/timeout.rs new file mode 100644 index 0000000..6ab35c4 --- /dev/null +++ b/core/src/async_util/timeout.rs @@ -0,0 +1,52 @@ +use std::{future::Future, time::Duration}; + +use smol::Timer; + +use super::{select, Either}; +use crate::{error::Error, Result}; + +/// Waits for a future to complete or times out if it exceeds a specified +/// duration. +/// +/// # Example +/// +/// ``` +/// use std::{future, time::Duration}; +/// +/// use karyons_core::async_util::timeout; +/// +/// async { +/// let fut = future::pending::<()>(); +/// assert!(timeout(Duration::from_millis(100), fut).await.is_err()); +/// }; +/// +/// ``` +/// +pub async fn timeout(delay: Duration, future1: F) -> Result +where + F: Future, +{ + let result = select(Timer::after(delay), future1).await; + + match result { + Either::Left(_) => Err(Error::Timeout), + Either::Right(res) => Ok(res), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, time::Duration}; + + #[test] + fn test_timeout() { + smol::block_on(async move { + let fut = future::pending::<()>(); + assert!(timeout(Duration::from_millis(10), fut).await.is_err()); + + let fut = smol::Timer::after(Duration::from_millis(10)); + assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) + }); + } +} diff --git a/core/src/async_utils/backoff.rs b/core/src/async_utils/backoff.rs deleted file mode 100644 index f7e131d..0000000 --- a/core/src/async_utils/backoff.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::{ - cmp::min, - sync::atomic::{AtomicBool, AtomicU32, Ordering}, - time::Duration, -}; - -use smol::Timer; - -/// Exponential backoff -/// -/// -/// # Examples -/// -/// ``` -/// use karyons_core::async_utils::Backoff; -/// -/// async { -/// let backoff = Backoff::new(300, 3000); -/// -/// loop { -/// backoff.sleep().await; -/// -/// // do something -/// break; -/// } -/// -/// backoff.reset(); -/// -/// // .... -/// }; -/// -/// ``` -/// -pub struct Backoff { - /// The base delay in milliseconds for the initial retry. - base_delay: u64, - /// The max delay in milliseconds allowed for a retry. - max_delay: u64, - /// Atomic counter - retries: AtomicU32, - /// Stop flag - stop: AtomicBool, -} - -impl Backoff { - /// Creates a new Backoff. - pub fn new(base_delay: u64, max_delay: u64) -> Self { - Self { - base_delay, - max_delay, - retries: AtomicU32::new(0), - stop: AtomicBool::new(false), - } - } - - /// Sleep based on the current retry count and delay values. - /// Retruns the delay value. - pub async fn sleep(&self) -> u64 { - if self.stop.load(Ordering::SeqCst) { - Timer::after(Duration::from_millis(self.max_delay)).await; - return self.max_delay; - } - - let retries = self.retries.load(Ordering::SeqCst); - let delay = self.base_delay * (2_u64).pow(retries); - let delay = min(delay, self.max_delay); - - if delay == self.max_delay { - self.stop.store(true, Ordering::SeqCst); - } - - self.retries.store(retries + 1, Ordering::SeqCst); - - Timer::after(Duration::from_millis(delay)).await; - delay - } - - /// Reset the retry counter to 0. - pub fn reset(&self) { - self.retries.store(0, Ordering::SeqCst); - self.stop.store(false, Ordering::SeqCst); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - #[test] - fn test_backoff() { - smol::block_on(async move { - let backoff = Arc::new(Backoff::new(5, 15)); - let backoff_c = backoff.clone(); - smol::spawn(async move { - let delay = backoff_c.sleep().await; - assert_eq!(delay, 5); - - let delay = backoff_c.sleep().await; - assert_eq!(delay, 10); - - let delay = backoff_c.sleep().await; - assert_eq!(delay, 15); - }) - .await; - - smol::spawn(async move { - backoff.reset(); - let delay = backoff.sleep().await; - assert_eq!(delay, 5); - }) - .await; - }); - } -} diff --git a/core/src/async_utils/condvar.rs b/core/src/async_utils/condvar.rs deleted file mode 100644 index 814f78f..0000000 --- a/core/src/async_utils/condvar.rs +++ /dev/null @@ -1,387 +0,0 @@ -use std::{ - collections::HashMap, - future::Future, - pin::Pin, - sync::Mutex, - task::{Context, Poll, Waker}, -}; - -use smol::lock::MutexGuard; - -use crate::utils::random_16; - -/// CondVar is an async version of -/// -/// # Example -/// -///``` -/// use std::sync::Arc; -/// -/// use smol::lock::Mutex; -/// -/// use karyons_core::async_utils::CondVar; -/// -/// async { -/// -/// let val = Arc::new(Mutex::new(false)); -/// let condvar = Arc::new(CondVar::new()); -/// -/// let val_cloned = val.clone(); -/// let condvar_cloned = condvar.clone(); -/// smol::spawn(async move { -/// let mut val = val_cloned.lock().await; -/// -/// // While the boolean flag is false, wait for a signal. -/// while !*val { -/// val = condvar_cloned.wait(val).await; -/// } -/// -/// // ... -/// }); -/// -/// let condvar_cloned = condvar.clone(); -/// smol::spawn(async move { -/// let mut val = val.lock().await; -/// -/// // While the boolean flag is false, wait for a signal. -/// while !*val { -/// val = condvar_cloned.wait(val).await; -/// } -/// -/// // ... -/// }); -/// -/// // Wake up all waiting tasks on this condvar -/// condvar.broadcast(); -/// }; -/// -/// ``` - -pub struct CondVar { - inner: Mutex, -} - -impl CondVar { - /// Creates a new CondVar - pub fn new() -> Self { - Self { - inner: Mutex::new(Wakers::new()), - } - } - - /// Blocks the current task until this condition variable receives a notification. - pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { - let m = MutexGuard::source(&g); - - CondVarAwait::new(self, g).await; - - m.lock().await - } - - /// Wakes up one blocked task waiting on this condvar. - pub fn signal(&self) { - self.inner.lock().unwrap().wake(true); - } - - /// Wakes up all blocked tasks waiting on this condvar. - pub fn broadcast(&self) { - self.inner.lock().unwrap().wake(false); - } -} - -impl Default for CondVar { - fn default() -> Self { - Self::new() - } -} - -struct CondVarAwait<'a, T> { - id: Option, - condvar: &'a CondVar, - guard: Option>, -} - -impl<'a, T> CondVarAwait<'a, T> { - fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self { - Self { - condvar, - guard: Some(guard), - id: None, - } - } -} - -impl<'a, T> Future for CondVarAwait<'a, T> { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut inner = self.condvar.inner.lock().unwrap(); - - match self.guard.take() { - Some(_) => { - // the first pooll will release the Mutexguard - self.id = Some(inner.put(Some(cx.waker().clone()))); - Poll::Pending - } - None => { - // Return Ready if it has already been polled and removed - // from the waker list. - if self.id.is_none() { - return Poll::Ready(()); - } - - let i = self.id.as_ref().unwrap(); - match inner.wakers.get_mut(i).unwrap() { - Some(wk) => { - // This will prevent cloning again - if !wk.will_wake(cx.waker()) { - *wk = cx.waker().clone(); - } - Poll::Pending - } - None => { - inner.delete(i); - self.id = None; - Poll::Ready(()) - } - } - } - } - } -} - -impl<'a, T> Drop for CondVarAwait<'a, T> { - fn drop(&mut self) { - if let Some(id) = self.id { - let mut inner = self.condvar.inner.lock().unwrap(); - if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() { - wk.wake() - } - } - } -} - -/// Wakers is a helper struct to store the task wakers -struct Wakers { - wakers: HashMap>, -} - -impl Wakers { - fn new() -> Self { - Self { - wakers: HashMap::new(), - } - } - - fn put(&mut self, waker: Option) -> u16 { - let mut id: u16; - - id = random_16(); - while self.wakers.contains_key(&id) { - id = random_16(); - } - - self.wakers.insert(id, waker); - id - } - - fn delete(&mut self, id: &u16) -> Option> { - self.wakers.remove(id) - } - - fn wake(&mut self, signal: bool) { - for (_, wk) in self.wakers.iter_mut() { - match wk.take() { - Some(w) => { - w.wake(); - if signal { - break; - } - } - None => continue, - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use smol::lock::Mutex; - use std::{ - collections::VecDeque, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - }; - - // The tests below demonstrate a solution to a problem in the Wikipedia - // explanation of condition variables: - // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. - - struct Queue { - items: VecDeque, - max_len: usize, - } - impl Queue { - fn new(max_len: usize) -> Self { - Self { - items: VecDeque::new(), - max_len, - } - } - - fn is_full(&self) -> bool { - self.items.len() == self.max_len - } - - fn is_empty(&self) -> bool { - self.items.is_empty() - } - } - - #[test] - fn test_condvar_signal() { - smol::block_on(async { - let number_of_tasks = 30; - - let queue = Arc::new(Mutex::new(Queue::new(5))); - let condvar_full = Arc::new(CondVar::new()); - let condvar_empty = Arc::new(CondVar::new()); - - let queue_cloned = queue.clone(); - let condvar_full_cloned = condvar_full.clone(); - let condvar_empty_cloned = condvar_empty.clone(); - - let _producer1 = smol::spawn(async move { - for i in 1..number_of_tasks { - // Lock queue mtuex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-full - while queue.is_full() { - // Release queue mutex and sleep - queue = condvar_full_cloned.wait(queue).await; - } - - queue.items.push_back(format!("task {i}")); - - // Wake up the consumer - condvar_empty_cloned.signal(); - } - }); - - let queue_cloned = queue.clone(); - let task_consumed = Arc::new(AtomicUsize::new(0)); - let task_consumed_ = task_consumed.clone(); - let consumer = smol::spawn(async move { - for _ in 1..number_of_tasks { - // Lock queue mtuex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-empty - while queue.is_empty() { - // Release queue mutex and sleep - queue = condvar_empty.wait(queue).await; - } - - let _ = queue.items.pop_front().unwrap(); - - task_consumed_.fetch_add(1, Ordering::Relaxed); - - // Do something - - // Wake up the producer - condvar_full.signal(); - } - }); - - consumer.await; - assert!(queue.lock().await.is_empty()); - assert_eq!(task_consumed.load(Ordering::Relaxed), 29); - }); - } - - #[test] - fn test_condvar_broadcast() { - smol::block_on(async { - let tasks = 30; - - let queue = Arc::new(Mutex::new(Queue::new(5))); - let condvar = Arc::new(CondVar::new()); - - let queue_cloned = queue.clone(); - let condvar_cloned = condvar.clone(); - let _producer1 = smol::spawn(async move { - for i in 1..tasks { - // Lock queue mtuex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-full - while queue.is_full() { - // Release queue mutex and sleep - queue = condvar_cloned.wait(queue).await; - } - - queue.items.push_back(format!("producer1: task {i}")); - - // Wake up all producer and consumer tasks - condvar_cloned.broadcast(); - } - }); - - let queue_cloned = queue.clone(); - let condvar_cloned = condvar.clone(); - let _producer2 = smol::spawn(async move { - for i in 1..tasks { - // Lock queue mtuex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-full - while queue.is_full() { - // Release queue mutex and sleep - queue = condvar_cloned.wait(queue).await; - } - - queue.items.push_back(format!("producer2: task {i}")); - - // Wake up all producer and consumer tasks - condvar_cloned.broadcast(); - } - }); - - let queue_cloned = queue.clone(); - let task_consumed = Arc::new(AtomicUsize::new(0)); - let task_consumed_ = task_consumed.clone(); - - let consumer = smol::spawn(async move { - for _ in 1..((tasks * 2) - 1) { - { - // Lock queue mutex - let mut queue = queue_cloned.lock().await; - - // Check if the queue is non-empty - while queue.is_empty() { - // Release queue mutex and sleep - queue = condvar.wait(queue).await; - } - - let _ = queue.items.pop_front().unwrap(); - - task_consumed_.fetch_add(1, Ordering::Relaxed); - - // Do something - - // Wake up all producer and consumer tasks - condvar.broadcast(); - } - } - }); - - consumer.await; - assert!(queue.lock().await.is_empty()); - assert_eq!(task_consumed.load(Ordering::Relaxed), 58); - }); - } -} diff --git a/core/src/async_utils/condwait.rs b/core/src/async_utils/condwait.rs deleted file mode 100644 index e31fac3..0000000 --- a/core/src/async_utils/condwait.rs +++ /dev/null @@ -1,133 +0,0 @@ -use smol::lock::Mutex; - -use super::CondVar; - -/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. -/// -/// # Example -/// -///``` -/// use std::sync::Arc; -/// -/// use karyons_core::async_utils::CondWait; -/// -/// async { -/// let cond_wait = Arc::new(CondWait::new()); -/// let cond_wait_cloned = cond_wait.clone(); -/// let task = smol::spawn(async move { -/// cond_wait_cloned.wait().await; -/// // ... -/// }); -/// -/// cond_wait.signal().await; -/// }; -/// -/// ``` -/// -pub struct CondWait { - /// The CondVar - condvar: CondVar, - /// Boolean flag - w: Mutex, -} - -impl CondWait { - /// Creates a new CondWait. - pub fn new() -> Self { - Self { - condvar: CondVar::new(), - w: Mutex::new(false), - } - } - - /// Waits for a signal or broadcast. - pub async fn wait(&self) { - let mut w = self.w.lock().await; - - // While the boolean flag is false, wait for a signal. - while !*w { - w = self.condvar.wait(w).await; - } - } - - /// Signal a waiting task. - pub async fn signal(&self) { - *self.w.lock().await = true; - self.condvar.signal(); - } - - /// Signal all waiting tasks. - pub async fn broadcast(&self) { - *self.w.lock().await = true; - self.condvar.broadcast(); - } - - /// Reset the boolean flag value to false. - pub async fn reset(&self) { - *self.w.lock().await = false; - } -} - -impl Default for CondWait { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }; - - #[test] - fn test_cond_wait() { - smol::block_on(async { - let cond_wait = Arc::new(CondWait::new()); - let count = Arc::new(AtomicUsize::new(0)); - - let cond_wait_cloned = cond_wait.clone(); - let count_cloned = count.clone(); - let task = smol::spawn(async move { - cond_wait_cloned.wait().await; - count_cloned.fetch_add(1, Ordering::Relaxed); - // do something - }); - - // Send a signal to the waiting task - cond_wait.signal().await; - - task.await; - - // Reset the boolean flag - cond_wait.reset().await; - - assert_eq!(count.load(Ordering::Relaxed), 1); - - let cond_wait_cloned = cond_wait.clone(); - let count_cloned = count.clone(); - let task1 = smol::spawn(async move { - cond_wait_cloned.wait().await; - count_cloned.fetch_add(1, Ordering::Relaxed); - // do something - }); - - let cond_wait_cloned = cond_wait.clone(); - let count_cloned = count.clone(); - let task2 = smol::spawn(async move { - cond_wait_cloned.wait().await; - count_cloned.fetch_add(1, Ordering::Relaxed); - // do something - }); - - // Broadcast a signal to all waiting tasks - cond_wait.broadcast().await; - - task1.await; - task2.await; - assert_eq!(count.load(Ordering::Relaxed), 3); - }); - } -} diff --git a/core/src/async_utils/mod.rs b/core/src/async_utils/mod.rs deleted file mode 100644 index c871bad..0000000 --- a/core/src/async_utils/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -mod backoff; -mod condvar; -mod condwait; -mod select; -mod task_group; -mod timeout; - -pub use backoff::Backoff; -pub use condvar::CondVar; -pub use condwait::CondWait; -pub use select::{select, Either}; -pub use task_group::{TaskGroup, TaskResult}; -pub use timeout::timeout; diff --git a/core/src/async_utils/select.rs b/core/src/async_utils/select.rs deleted file mode 100644 index 9fe3c77..0000000 --- a/core/src/async_utils/select.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::pin::Pin; -use std::task::{Context, Poll}; - -use pin_project_lite::pin_project; -use smol::future::Future; - -/// Returns the result of the future that completes first, preferring future1 -/// if both are ready. -/// -/// # Examples -/// -/// ``` -/// use std::future; -/// -/// use karyons_core::async_utils::{select, Either}; -/// -/// async { -/// let fut1 = future::pending::(); -/// let fut2 = future::ready(0); -/// let res = select(fut1, fut2).await; -/// assert!(matches!(res, Either::Right(0))); -/// // .... -/// }; -/// -/// ``` -/// -pub fn select(future1: F1, future2: F2) -> Select -where - F1: Future, - F2: Future, -{ - Select { future1, future2 } -} - -pin_project! { - #[derive(Debug)] - pub struct Select { - #[pin] - future1: F1, - #[pin] - future2: F2, - } -} - -/// The return value from the [`select`] function, indicating which future -/// completed first. -#[derive(Debug)] -pub enum Either { - Left(T1), - Right(T2), -} - -// Implement the Future trait for the Select struct. -impl Future for Select -where - F1: Future, - F2: Future, -{ - type Output = Either; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - if let Poll::Ready(t) = this.future1.poll(cx) { - return Poll::Ready(Either::Left(t)); - } - - if let Poll::Ready(t) = this.future2.poll(cx) { - return Poll::Ready(Either::Right(t)); - } - - Poll::Pending - } -} - -#[cfg(test)] -mod tests { - use super::{select, Either}; - use smol::Timer; - use std::future; - - #[test] - fn test_async_select() { - smol::block_on(async move { - let fut = select(Timer::never(), future::ready(0 as u32)).await; - assert!(matches!(fut, Either::Right(0))); - - let fut1 = future::pending::(); - let fut2 = future::ready(0); - let res = select(fut1, fut2).await; - assert!(matches!(res, Either::Right(0))); - - let fut1 = future::ready(0); - let fut2 = future::pending::(); - let res = select(fut1, fut2).await; - assert!(matches!(res, Either::Left(_))); - }); - } -} diff --git a/core/src/async_utils/task_group.rs b/core/src/async_utils/task_group.rs deleted file mode 100644 index afc9648..0000000 --- a/core/src/async_utils/task_group.rs +++ /dev/null @@ -1,194 +0,0 @@ -use std::{future::Future, sync::Arc, sync::Mutex}; - -use async_task::FallibleTask; - -use crate::Executor; - -use super::{select, CondWait, Either}; - -/// TaskGroup is a group of spawned tasks. -/// -/// # Example -/// -/// ``` -/// -/// use std::sync::Arc; -/// -/// use karyons_core::async_utils::TaskGroup; -/// -/// async { -/// -/// let ex = Arc::new(smol::Executor::new()); -/// let group = TaskGroup::new(ex); -/// -/// group.spawn(smol::Timer::never(), |_| async {}); -/// -/// group.cancel().await; -/// -/// }; -/// -/// ``` -/// -pub struct TaskGroup<'a> { - tasks: Mutex>, - stop_signal: Arc, - executor: Executor<'a>, -} - -impl<'a> TaskGroup<'a> { - /// Creates a new task group - pub fn new(executor: Executor<'a>) -> Self { - Self { - tasks: Mutex::new(Vec::new()), - stop_signal: Arc::new(CondWait::new()), - executor, - } - } - - /// Spawns a new task and calls the callback after it has completed - /// or been canceled. The callback will have the `TaskResult` as a - /// parameter, indicating whether the task completed or was canceled. - pub fn spawn(&self, fut: Fut, callback: CallbackF) - where - T: Send + Sync + 'a, - Fut: Future + Send + 'a, - CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, - CallbackFut: Future + Send + 'a, - { - let task = TaskHandler::new( - self.executor.clone(), - fut, - callback, - self.stop_signal.clone(), - ); - self.tasks.lock().unwrap().push(task); - } - - /// Checks if the task group is empty. - pub fn is_empty(&self) -> bool { - self.tasks.lock().unwrap().is_empty() - } - - /// Get the number of the tasks in the group. - pub fn len(&self) -> usize { - self.tasks.lock().unwrap().len() - } - - /// Cancels all tasks in the group. - pub async fn cancel(&self) { - self.stop_signal.broadcast().await; - - loop { - let task = self.tasks.lock().unwrap().pop(); - if let Some(t) = task { - t.cancel().await - } else { - break; - } - } - } -} - -/// The result of a spawned task. -#[derive(Debug)] -pub enum TaskResult { - Completed(T), - Cancelled, -} - -impl std::fmt::Display for TaskResult { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - TaskResult::Cancelled => write!(f, "Task cancelled"), - TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), - } - } -} - -/// TaskHandler -pub struct TaskHandler { - task: FallibleTask<()>, - cancel_flag: Arc, -} - -impl<'a> TaskHandler { - /// Creates a new task handle - fn new( - ex: Executor<'a>, - fut: Fut, - callback: CallbackF, - stop_signal: Arc, - ) -> TaskHandler - where - T: Send + Sync + 'a, - Fut: Future + Send + 'a, - CallbackF: FnOnce(TaskResult) -> CallbackFut + Send + 'a, - CallbackFut: Future + Send + 'a, - { - let cancel_flag = Arc::new(CondWait::new()); - let cancel_flag_c = cancel_flag.clone(); - let task = ex - .spawn(async move { - //start_signal.signal().await; - // Waits for either the stop signal or the task to complete. - let result = select(stop_signal.wait(), fut).await; - - let result = match result { - Either::Left(_) => TaskResult::Cancelled, - Either::Right(res) => TaskResult::Completed(res), - }; - - // Call the callback with the result. - callback(result).await; - - cancel_flag_c.signal().await; - }) - .fallible(); - - TaskHandler { task, cancel_flag } - } - - /// Cancels the task. - async fn cancel(self) { - self.cancel_flag.wait().await; - self.task.cancel().await; - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::{future, sync::Arc}; - - #[test] - fn test_task_group() { - let ex = Arc::new(smol::Executor::new()); - smol::block_on(ex.clone().run(async move { - let group = Arc::new(TaskGroup::new(ex)); - - group.spawn(future::ready(0), |res| async move { - assert!(matches!(res, TaskResult::Completed(0))); - }); - - group.spawn(future::pending::<()>(), |res| async move { - assert!(matches!(res, TaskResult::Cancelled)); - }); - - let groupc = group.clone(); - group.spawn( - async move { - groupc.spawn(future::pending::<()>(), |res| async move { - assert!(matches!(res, TaskResult::Cancelled)); - }); - }, - |res| async move { - assert!(matches!(res, TaskResult::Completed(_))); - }, - ); - - // Do something - smol::Timer::after(std::time::Duration::from_millis(50)).await; - group.cancel().await; - })); - } -} diff --git a/core/src/async_utils/timeout.rs b/core/src/async_utils/timeout.rs deleted file mode 100644 index 7c55e1b..0000000 --- a/core/src/async_utils/timeout.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::{future::Future, time::Duration}; - -use smol::Timer; - -use super::{select, Either}; -use crate::{error::Error, Result}; - -/// Waits for a future to complete or times out if it exceeds a specified -/// duration. -/// -/// # Example -/// -/// ``` -/// use std::{future, time::Duration}; -/// -/// use karyons_core::async_utils::timeout; -/// -/// async { -/// let fut = future::pending::<()>(); -/// assert!(timeout(Duration::from_millis(100), fut).await.is_err()); -/// }; -/// -/// ``` -/// -pub async fn timeout(delay: Duration, future1: F) -> Result -where - F: Future, -{ - let result = select(Timer::after(delay), future1).await; - - match result { - Either::Left(_) => Err(Error::Timeout), - Either::Right(res) => Ok(res), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::{future, time::Duration}; - - #[test] - fn test_timeout() { - smol::block_on(async move { - let fut = future::pending::<()>(); - assert!(timeout(Duration::from_millis(10), fut).await.is_err()); - - let fut = smol::Timer::after(Duration::from_millis(10)); - assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) - }); - } -} diff --git a/core/src/error.rs b/core/src/error.rs index 63b45d3..7c547c4 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -7,12 +7,18 @@ pub enum Error { #[error(transparent)] IO(#[from] std::io::Error), + #[error("TryInto Error: {0}")] + TryInto(&'static str), + #[error("Timeout Error")] Timeout, #[error("Path Not Found Error: {0}")] PathNotFound(&'static str), + #[error(transparent)] + Ed25519(#[from] ed25519_dalek::ed25519::Error), + #[error("Channel Send Error: {0}")] ChannelSend(String), diff --git a/core/src/event.rs b/core/src/event.rs index 0503e88..f2c5510 100644 --- a/core/src/event.rs +++ b/core/src/event.rs @@ -12,7 +12,7 @@ use smol::{ lock::Mutex, }; -use crate::{utils::random_16, Result}; +use crate::{util::random_16, Result}; pub type ArcEventSys = Arc>; pub type WeakEventSys = Weak>; diff --git a/core/src/key_pair.rs b/core/src/key_pair.rs new file mode 100644 index 0000000..4016351 --- /dev/null +++ b/core/src/key_pair.rs @@ -0,0 +1,189 @@ +use ed25519_dalek::{Signer as _, Verifier as _}; +use rand::rngs::OsRng; + +use crate::{error::Error, Result}; + +/// key cryptography type +pub enum KeyPairType { + Ed25519, +} + +/// A Public key +pub struct PublicKey(PublicKeyInner); + +/// A Secret key +pub struct SecretKey(Vec); + +impl PublicKey { + pub fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() + } + + /// Verify a signature on a message with this public key. + pub fn verify(&self, msg: &[u8], signature: &[u8]) -> Result<()> { + self.0.verify(msg, signature) + } +} + +impl PublicKey { + pub fn from_bytes(kp_type: &KeyPairType, pk: &[u8]) -> Result { + Ok(Self(PublicKeyInner::from_bytes(kp_type, pk)?)) + } +} + +/// A KeyPair. +#[derive(Clone)] +pub struct KeyPair(KeyPairInner); + +impl KeyPair { + /// Generate a new random keypair. + pub fn generate(kp_type: &KeyPairType) -> Self { + Self(KeyPairInner::generate(kp_type)) + } + + /// Sign a message using the private key. + pub fn sign(&self, msg: &[u8]) -> Vec { + self.0.sign(msg) + } + + /// Get the public key of this keypair. + pub fn public(&self) -> PublicKey { + self.0.public() + } + + /// Get the secret key of this keypair. + pub fn secret(&self) -> SecretKey { + self.0.secret() + } +} + +/// An extension trait, adding essential methods to all [`KeyPair`] types. +trait KeyPairExt { + /// Sign a message using the private key. + fn sign(&self, msg: &[u8]) -> Vec; + + /// Get the public key of this keypair. + fn public(&self) -> PublicKey; + + /// Get the secret key of this keypair. + fn secret(&self) -> SecretKey; +} + +#[derive(Clone)] +enum KeyPairInner { + Ed25519(Ed25519KeyPair), +} + +impl KeyPairInner { + fn generate(kp_type: &KeyPairType) -> Self { + match kp_type { + KeyPairType::Ed25519 => Self::Ed25519(Ed25519KeyPair::generate()), + } + } +} + +impl KeyPairExt for KeyPairInner { + fn sign(&self, msg: &[u8]) -> Vec { + match self { + KeyPairInner::Ed25519(kp) => kp.sign(msg), + } + } + + fn public(&self) -> PublicKey { + match self { + KeyPairInner::Ed25519(kp) => kp.public(), + } + } + + fn secret(&self) -> SecretKey { + match self { + KeyPairInner::Ed25519(kp) => kp.secret(), + } + } +} + +#[derive(Clone)] +struct Ed25519KeyPair(ed25519_dalek::SigningKey); + +impl Ed25519KeyPair { + fn generate() -> Self { + Self(ed25519_dalek::SigningKey::generate(&mut OsRng)) + } +} + +impl KeyPairExt for Ed25519KeyPair { + fn sign(&self, msg: &[u8]) -> Vec { + self.0.sign(msg).to_bytes().to_vec() + } + + fn public(&self) -> PublicKey { + PublicKey(PublicKeyInner::Ed25519(Ed25519PublicKey( + self.0.verifying_key(), + ))) + } + + fn secret(&self) -> SecretKey { + SecretKey(self.0.to_bytes().to_vec()) + } +} + +/// An extension trait, adding essential methods to all [`PublicKey`] types. +trait PublicKeyExt { + fn as_bytes(&self) -> &[u8]; + + /// Verify a signature on a message with this public key. + fn verify(&self, msg: &[u8], signature: &[u8]) -> Result<()>; +} + +enum PublicKeyInner { + Ed25519(Ed25519PublicKey), +} + +impl PublicKeyInner { + pub fn from_bytes(kp_type: &KeyPairType, pk: &[u8]) -> Result { + match kp_type { + KeyPairType::Ed25519 => Ok(Self::Ed25519(Ed25519PublicKey::from_bytes(pk)?)), + } + } +} + +impl PublicKeyExt for PublicKeyInner { + fn as_bytes(&self) -> &[u8] { + match self { + Self::Ed25519(pk) => pk.as_bytes(), + } + } + + fn verify(&self, msg: &[u8], signature: &[u8]) -> Result<()> { + match self { + Self::Ed25519(pk) => pk.verify(msg, signature), + } + } +} + +struct Ed25519PublicKey(ed25519_dalek::VerifyingKey); + +impl Ed25519PublicKey { + pub fn from_bytes(pk: &[u8]) -> Result { + let pk_bytes: [u8; 32] = pk + .try_into() + .map_err(|_| Error::TryInto("Failed to convert slice to [u8; 32]"))?; + + Ok(Self(ed25519_dalek::VerifyingKey::from_bytes(&pk_bytes)?)) + } +} + +impl PublicKeyExt for Ed25519PublicKey { + fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() + } + + fn verify(&self, msg: &[u8], signature: &[u8]) -> Result<()> { + let sig_bytes: [u8; 64] = signature + .try_into() + .map_err(|_| Error::TryInto("Failed to convert slice to [u8; 64]"))?; + self.0 + .verify(msg, &ed25519_dalek::Signature::from_bytes(&sig_bytes))?; + Ok(()) + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index 67e6610..276ed89 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,19 +1,22 @@ /// A set of helper tools and functions. -pub mod utils; +pub mod util; /// A module containing async utilities that work with the /// [`smol`](https://github.com/smol-rs/smol) async runtime. -pub mod async_utils; +pub mod async_util; /// Represents karyons's Core Error. pub mod error; -/// [`event::EventSys`] Implementation +/// [`event::EventSys`] implementation. pub mod event; /// A simple publish-subscribe system [`Read More`](./pubsub/struct.Publisher.html) pub mod pubsub; +/// A cryptographic key pair +pub mod key_pair; + use smol::Executor as SmolEx; use std::sync::Arc; diff --git a/core/src/pubsub.rs b/core/src/pubsub.rs index 4cc0ab7..306d42f 100644 --- a/core/src/pubsub.rs +++ b/core/src/pubsub.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use log::error; use smol::lock::Mutex; -use crate::{utils::random_16, Result}; +use crate::{util::random_16, Result}; pub type ArcPublisher = Arc>; pub type SubscriptionID = u16; diff --git a/core/src/util/decode.rs b/core/src/util/decode.rs new file mode 100644 index 0000000..a8a6522 --- /dev/null +++ b/core/src/util/decode.rs @@ -0,0 +1,10 @@ +use bincode::Decode; + +use crate::Result; + +/// Decodes a given type `T` from the given slice. returns the decoded value +/// along with the number of bytes read. +pub fn decode(src: &[u8]) -> Result<(T, usize)> { + let (result, bytes_read) = bincode::decode_from_slice(src, bincode::config::standard())?; + Ok((result, bytes_read)) +} diff --git a/core/src/util/encode.rs b/core/src/util/encode.rs new file mode 100644 index 0000000..7d1061b --- /dev/null +++ b/core/src/util/encode.rs @@ -0,0 +1,15 @@ +use bincode::Encode; + +use crate::Result; + +/// Encode the given type `T` into a `Vec`. +pub fn encode(msg: &T) -> Result> { + let vec = bincode::encode_to_vec(msg, bincode::config::standard())?; + Ok(vec) +} + +/// Encode the given type `T` into the given slice.. +pub fn encode_into_slice(msg: &T, dst: &mut [u8]) -> Result<()> { + bincode::encode_into_slice(msg, dst, bincode::config::standard())?; + Ok(()) +} diff --git a/core/src/util/mod.rs b/core/src/util/mod.rs new file mode 100644 index 0000000..a3c3f50 --- /dev/null +++ b/core/src/util/mod.rs @@ -0,0 +1,19 @@ +mod decode; +mod encode; +mod path; + +pub use decode::decode; +pub use encode::{encode, encode_into_slice}; +pub use path::{home_dir, tilde_expand}; + +use rand::{rngs::OsRng, Rng}; + +/// Generates and returns a random u32 using `rand::rngs::OsRng`. +pub fn random_32() -> u32 { + OsRng.gen() +} + +/// Generates and returns a random u16 using `rand::rngs::OsRng`. +pub fn random_16() -> u16 { + OsRng.gen() +} diff --git a/core/src/util/path.rs b/core/src/util/path.rs new file mode 100644 index 0000000..2cd900a --- /dev/null +++ b/core/src/util/path.rs @@ -0,0 +1,39 @@ +use std::path::PathBuf; + +use crate::{error::Error, Result}; + +/// Returns the user's home directory as a `PathBuf`. +#[allow(dead_code)] +pub fn home_dir() -> Result { + dirs::home_dir().ok_or(Error::PathNotFound("Home dir not found")) +} + +/// Expands a tilde (~) in a path and returns the expanded `PathBuf`. +#[allow(dead_code)] +pub fn tilde_expand(path: &str) -> Result { + match path { + "~" => home_dir(), + p if p.starts_with("~/") => Ok(home_dir()?.join(&path[2..])), + _ => Ok(PathBuf::from(path)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tilde_expand() { + let path = "~/src"; + let expanded_path = dirs::home_dir().unwrap().join("src"); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + + let path = "~"; + let expanded_path = dirs::home_dir().unwrap(); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + + let path = ""; + let expanded_path = PathBuf::from(""); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + } +} diff --git a/core/src/utils/decode.rs b/core/src/utils/decode.rs deleted file mode 100644 index a8a6522..0000000 --- a/core/src/utils/decode.rs +++ /dev/null @@ -1,10 +0,0 @@ -use bincode::Decode; - -use crate::Result; - -/// Decodes a given type `T` from the given slice. returns the decoded value -/// along with the number of bytes read. -pub fn decode(src: &[u8]) -> Result<(T, usize)> { - let (result, bytes_read) = bincode::decode_from_slice(src, bincode::config::standard())?; - Ok((result, bytes_read)) -} diff --git a/core/src/utils/encode.rs b/core/src/utils/encode.rs deleted file mode 100644 index 7d1061b..0000000 --- a/core/src/utils/encode.rs +++ /dev/null @@ -1,15 +0,0 @@ -use bincode::Encode; - -use crate::Result; - -/// Encode the given type `T` into a `Vec`. -pub fn encode(msg: &T) -> Result> { - let vec = bincode::encode_to_vec(msg, bincode::config::standard())?; - Ok(vec) -} - -/// Encode the given type `T` into the given slice.. -pub fn encode_into_slice(msg: &T, dst: &mut [u8]) -> Result<()> { - bincode::encode_into_slice(msg, dst, bincode::config::standard())?; - Ok(()) -} diff --git a/core/src/utils/mod.rs b/core/src/utils/mod.rs deleted file mode 100644 index a3c3f50..0000000 --- a/core/src/utils/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -mod decode; -mod encode; -mod path; - -pub use decode::decode; -pub use encode::{encode, encode_into_slice}; -pub use path::{home_dir, tilde_expand}; - -use rand::{rngs::OsRng, Rng}; - -/// Generates and returns a random u32 using `rand::rngs::OsRng`. -pub fn random_32() -> u32 { - OsRng.gen() -} - -/// Generates and returns a random u16 using `rand::rngs::OsRng`. -pub fn random_16() -> u16 { - OsRng.gen() -} diff --git a/core/src/utils/path.rs b/core/src/utils/path.rs deleted file mode 100644 index 2cd900a..0000000 --- a/core/src/utils/path.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::path::PathBuf; - -use crate::{error::Error, Result}; - -/// Returns the user's home directory as a `PathBuf`. -#[allow(dead_code)] -pub fn home_dir() -> Result { - dirs::home_dir().ok_or(Error::PathNotFound("Home dir not found")) -} - -/// Expands a tilde (~) in a path and returns the expanded `PathBuf`. -#[allow(dead_code)] -pub fn tilde_expand(path: &str) -> Result { - match path { - "~" => home_dir(), - p if p.starts_with("~/") => Ok(home_dir()?.join(&path[2..])), - _ => Ok(PathBuf::from(path)), - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_tilde_expand() { - let path = "~/src"; - let expanded_path = dirs::home_dir().unwrap().join("src"); - assert_eq!(tilde_expand(path).unwrap(), expanded_path); - - let path = "~"; - let expanded_path = dirs::home_dir().unwrap(); - assert_eq!(tilde_expand(path).unwrap(), expanded_path); - - let path = ""; - let expanded_path = PathBuf::from(""); - assert_eq!(tilde_expand(path).unwrap(), expanded_path); - } -} diff --git a/jsonrpc/src/client.rs b/jsonrpc/src/client.rs index f5277aa..939d177 100644 --- a/jsonrpc/src/client.rs +++ b/jsonrpc/src/client.rs @@ -1,7 +1,7 @@ use log::debug; use serde::{de::DeserializeOwned, Serialize}; -use karyons_core::utils::random_32; +use karyons_core::util::random_32; use karyons_net::{dial, Conn, Endpoint}; use crate::{ diff --git a/jsonrpc/src/codec.rs b/jsonrpc/src/codec.rs index e198a6e..5dac8da 100644 --- a/jsonrpc/src/codec.rs +++ b/jsonrpc/src/codec.rs @@ -1,6 +1,6 @@ use memchr::memchr; -use karyons_core::async_utils::timeout; +use karyons_core::async_util::timeout; use karyons_net::Conn; use crate::{Error, Result}; diff --git a/jsonrpc/src/server.rs b/jsonrpc/src/server.rs index 5b9b799..05ef7da 100644 --- a/jsonrpc/src/server.rs +++ b/jsonrpc/src/server.rs @@ -4,7 +4,7 @@ use log::{debug, error, warn}; use smol::lock::RwLock; use karyons_core::{ - async_utils::{TaskGroup, TaskResult}, + async_util::{TaskGroup, TaskResult}, Executor, }; use karyons_net::{listen, Conn, Endpoint, Listener}; diff --git a/net/Cargo.toml b/net/Cargo.toml index de9b33b..863a250 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -11,6 +11,7 @@ karyons_core.workspace = true smol = "1.3.0" async-trait = "0.1.73" log = "0.4.20" -bincode = { version="2.0.0-rc.3", features = ["derive"]} +bincode = "2.0.0-rc.3" thiserror = "1.0.47" url = "2.4.1" +async-rustls = { version = "0.4.1", features = ["dangerous_configuration"] } diff --git a/net/src/connection.rs b/net/src/connection.rs index d8ec0a3..b1d7550 100644 --- a/net/src/connection.rs +++ b/net/src/connection.rs @@ -1,7 +1,9 @@ -use crate::{Endpoint, Result}; use async_trait::async_trait; -use crate::transports::{tcp, udp, unix}; +use crate::{ + transports::{tcp, udp, unix}, + Endpoint, Error, Result, +}; /// Alias for `Box` pub type Conn = Box; @@ -28,7 +30,7 @@ pub trait Connection: Send + Sync { /// Connects to the provided endpoint. /// -/// it only supports `tcp4/6`, `udp4/6` and `unix`. +/// it only supports `tcp4/6`, `udp4/6`, and `unix`. /// /// #Example /// @@ -53,5 +55,6 @@ pub async fn dial(endpoint: &Endpoint) -> Result { Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::dial_tcp(addr, port).await?)), Endpoint::Udp(addr, port) => Ok(Box::new(udp::dial_udp(addr, port).await?)), Endpoint::Unix(addr) => Ok(Box::new(unix::dial_unix(addr).await?)), + _ => Err(Error::InvalidEndpoint(endpoint.to_string())), } } diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs index 50dfe6b..720eea3 100644 --- a/net/src/endpoint.rs +++ b/net/src/endpoint.rs @@ -5,7 +5,7 @@ use std::{ str::FromStr, }; -use bincode::{Decode, Encode}; +use bincode::{impl_borrow_decode, Decode, Encode}; use url::Url; use crate::{Error, Result}; @@ -33,6 +33,7 @@ pub type Port = u16; pub enum Endpoint { Udp(Addr, Port), Tcp(Addr, Port), + Tls(Addr, Port), Unix(String), } @@ -45,6 +46,9 @@ impl std::fmt::Display for Endpoint { Endpoint::Tcp(ip, port) => { write!(f, "tcp://{}:{}", ip, port) } + Endpoint::Tls(ip, port) => { + write!(f, "tls://{}:{}", ip, port) + } Endpoint::Unix(path) => { if path.is_empty() { write!(f, "unix:/UNNAMED") @@ -60,9 +64,10 @@ impl TryFrom for SocketAddr { type Error = Error; fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { - Endpoint::Udp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), - Endpoint::Tcp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), - Endpoint::Unix(_) => Err(Error::TryFromEndpointError), + Endpoint::Udp(ip, port) | Endpoint::Tcp(ip, port) | Endpoint::Tls(ip, port) => { + Ok(SocketAddr::new(ip.try_into()?, port)) + } + Endpoint::Unix(_) => Err(Error::TryFromEndpoint), } } } @@ -72,7 +77,7 @@ impl TryFrom for PathBuf { fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { Endpoint::Unix(path) => Ok(PathBuf::from(&path)), - _ => Err(Error::TryFromEndpointError), + _ => Err(Error::TryFromEndpoint), } } } @@ -82,7 +87,7 @@ impl TryFrom for UnixSocketAddress { fn try_from(endpoint: Endpoint) -> std::result::Result { match endpoint { Endpoint::Unix(a) => Ok(UnixSocketAddress::from_pathname(a)?), - _ => Err(Error::TryFromEndpointError), + _ => Err(Error::TryFromEndpoint), } } } @@ -112,6 +117,7 @@ impl FromStr for Endpoint { match url.scheme() { "tcp" => Ok(Endpoint::Tcp(addr, port)), "udp" => Ok(Endpoint::Udp(addr, port)), + "tls" => Ok(Endpoint::Tls(addr, port)), _ => Err(Error::InvalidEndpoint(s.to_string())), } } else { @@ -133,6 +139,11 @@ impl Endpoint { Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port()) } + /// Creates a new TLS endpoint from a `SocketAddr`. + pub fn new_tls_addr(addr: &SocketAddr) -> Endpoint { + Endpoint::Tls(Addr::Ip(addr.ip()), addr.port()) + } + /// Creates a new UDP endpoint from a `SocketAddr`. pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint { Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) @@ -151,29 +162,62 @@ impl Endpoint { /// Returns the `Port` of the endpoint. pub fn port(&self) -> Result<&Port> { match self { - Endpoint::Tcp(_, port) => Ok(port), - Endpoint::Udp(_, port) => Ok(port), - _ => Err(Error::TryFromEndpointError), + Endpoint::Udp(_, port) | Endpoint::Tcp(_, port) | Endpoint::Tls(_, port) => Ok(port), + _ => Err(Error::TryFromEndpoint), } } /// Returns the `Addr` of the endpoint. pub fn addr(&self) -> Result<&Addr> { match self { - Endpoint::Tcp(addr, _) => Ok(addr), - Endpoint::Udp(addr, _) => Ok(addr), - _ => Err(Error::TryFromEndpointError), + Endpoint::Udp(addr, _) | Endpoint::Tcp(addr, _) | Endpoint::Tls(addr, _) => Ok(addr), + _ => Err(Error::TryFromEndpoint), } } } /// Addr defines a type for an address, either IP or domain. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Addr { Ip(IpAddr), Domain(String), } +impl Encode for Addr { + fn encode( + &self, + encoder: &mut E, + ) -> std::result::Result<(), bincode::error::EncodeError> { + match self { + Addr::Ip(addr) => { + 0u32.encode(encoder)?; + addr.encode(encoder) + } + Addr::Domain(domain) => { + 1u32.encode(encoder)?; + domain.encode(encoder) + } + } + } +} + +impl Decode for Addr { + fn decode( + decoder: &mut D, + ) -> std::result::Result { + match u32::decode(decoder)? { + 0 => Ok(Addr::Ip(IpAddr::decode(decoder)?)), + 1 => Ok(Addr::Domain(String::decode(decoder)?)), + found => Err(bincode::error::DecodeError::UnexpectedVariant { + allowed: &bincode::error::AllowedEnumVariants::Range { min: 0, max: 1 }, + found, + type_name: core::any::type_name::(), + }), + } + } +} +impl_borrow_decode!(Addr); + impl TryFrom for IpAddr { type Error = Error; fn try_from(addr: Addr) -> std::result::Result { diff --git a/net/src/error.rs b/net/src/error.rs index 346184a..5dd6348 100644 --- a/net/src/error.rs +++ b/net/src/error.rs @@ -8,7 +8,7 @@ pub enum Error { IO(#[from] std::io::Error), #[error("Try from endpoint Error")] - TryFromEndpointError, + TryFromEndpoint, #[error("invalid address {0}")] InvalidAddress(String), @@ -28,6 +28,12 @@ pub enum Error { #[error(transparent)] ChannelRecv(#[from] smol::channel::RecvError), + #[error("Tls Error: {0}")] + Rustls(#[from] async_rustls::rustls::Error), + + #[error("Invalid DNS Name: {0}")] + InvalidDnsNameError(#[from] async_rustls::rustls::client::InvalidDnsNameError), + #[error(transparent)] KaryonsCore(#[from] karyons_core::error::Error), } diff --git a/net/src/lib.rs b/net/src/lib.rs index 0e4c361..61069ef 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -10,6 +10,7 @@ pub use { listener::{listen, Listener}, transports::{ tcp::{dial_tcp, listen_tcp, TcpConn}, + tls, udp::{dial_udp, listen_udp, UdpConn}, unix::{dial_unix, listen_unix, UnixConn}, }, diff --git a/net/src/listener.rs b/net/src/listener.rs index 31a63ae..c6c3d94 100644 --- a/net/src/listener.rs +++ b/net/src/listener.rs @@ -1,9 +1,8 @@ -use crate::{Endpoint, Error, Result}; use async_trait::async_trait; use crate::{ transports::{tcp, unix}, - Conn, + Conn, Endpoint, Error, Result, }; /// Listener is a generic network listener. @@ -15,7 +14,7 @@ pub trait Listener: Send + Sync { /// Listens to the provided endpoint. /// -/// it only supports `tcp4/6` and `unix`. +/// it only supports `tcp4/6`, and `unix`. /// /// #Example /// diff --git a/net/src/transports/mod.rs b/net/src/transports/mod.rs index f399133..ac23021 100644 --- a/net/src/transports/mod.rs +++ b/net/src/transports/mod.rs @@ -1,3 +1,4 @@ pub mod tcp; +pub mod tls; pub mod udp; pub mod unix; diff --git a/net/src/transports/tcp.rs b/net/src/transports/tcp.rs index 84aa980..37f00a7 100644 --- a/net/src/transports/tcp.rs +++ b/net/src/transports/tcp.rs @@ -13,7 +13,7 @@ use crate::{ Error, Result, }; -/// TCP network connection implementations of the [`Connection`] trait. +/// TCP network connection implementation of the [`Connection`] trait. pub struct TcpConn { inner: TcpStream, read: Mutex>, diff --git a/net/src/transports/tls.rs b/net/src/transports/tls.rs new file mode 100644 index 0000000..01bb5aa --- /dev/null +++ b/net/src/transports/tls.rs @@ -0,0 +1,140 @@ +use std::sync::Arc; + +use async_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; +use async_trait::async_trait; +use smol::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + lock::Mutex, + net::{TcpListener, TcpStream}, +}; + +use crate::{ + connection::Connection, + endpoint::{Addr, Endpoint, Port}, + listener::Listener, + Error, Result, +}; + +/// TLS network connection implementation of the [`Connection`] trait. +pub struct TlsConn { + inner: TcpStream, + read: Mutex>>, + write: Mutex>>, +} + +impl TlsConn { + /// Creates a new TlsConn + pub fn new(sock: TcpStream, conn: TlsStream) -> Self { + let (read, write) = split(conn); + Self { + inner: sock, + read: Mutex::new(read), + write: Mutex::new(write), + } + } +} + +#[async_trait] +impl Connection for TlsConn { + fn peer_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.inner.peer_addr()?)) + } + + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.inner.local_addr()?)) + } + + async fn read(&self, buf: &mut [u8]) -> Result { + self.read.lock().await.read(buf).await.map_err(Error::from) + } + + async fn write(&self, buf: &[u8]) -> Result { + self.write + .lock() + .await + .write(buf) + .await + .map_err(Error::from) + } +} + +/// Connects to the given TLS address and port. +pub async fn dial_tls( + addr: &Addr, + port: &Port, + config: rustls::ClientConfig, + dns_name: &str, +) -> Result { + let address = format!("{}:{}", addr, port); + + let connector = TlsConnector::from(Arc::new(config)); + + let sock = TcpStream::connect(&address).await?; + sock.set_nodelay(true)?; + + let altname = rustls::ServerName::try_from(dns_name)?; + let conn = connector.connect(altname, sock.clone()).await?; + Ok(TlsConn::new(sock, TlsStream::Client(conn))) +} + +/// Connects to the given TLS endpoint, returns `Conn` ([`Connection`]). +pub async fn dial( + endpoint: &Endpoint, + config: rustls::ClientConfig, + dns_name: &str, +) -> Result> { + match endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => {} + _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), + } + + dial_tls(endpoint.addr()?, endpoint.port()?, config, dns_name) + .await + .map(|c| Box::new(c) as Box) +} +/// Tls network listener implementation of the [`Listener`] trait. +pub struct TlsListener { + acceptor: TlsAcceptor, + listener: TcpListener, +} + +#[async_trait] +impl Listener for TlsListener { + fn local_endpoint(&self) -> Result { + Ok(Endpoint::new_tls_addr(&self.listener.local_addr()?)) + } + + async fn accept(&self) -> Result> { + let (sock, _) = self.listener.accept().await?; + sock.set_nodelay(true)?; + let conn = self.acceptor.accept(sock.clone()).await?; + Ok(Box::new(TlsConn::new(sock, TlsStream::Server(conn)))) + } +} + +/// Listens on the given TLS address and port. +pub async fn listen_tls( + addr: &Addr, + port: &Port, + config: rustls::ServerConfig, +) -> Result { + let address = format!("{}:{}", addr, port); + let acceptor = TlsAcceptor::from(Arc::new(config)); + let listener = TcpListener::bind(&address).await?; + Ok(TlsListener { acceptor, listener }) +} + +/// Listens on the given TLS endpoint, returns [`Listener`]. +pub async fn listen( + endpoint: &Endpoint, + config: rustls::ServerConfig, +) -> Result> { + match endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => {} + _ => return Err(Error::InvalidEndpoint(endpoint.to_string())), + } + + listen_tls(endpoint.addr()?, endpoint.port()?, config) + .await + .map(|l| Box::new(l) as Box) +} diff --git a/net/src/transports/udp.rs b/net/src/transports/udp.rs index ca5b94d..8a2fbec 100644 --- a/net/src/transports/udp.rs +++ b/net/src/transports/udp.rs @@ -9,7 +9,7 @@ use crate::{ Error, Result, }; -/// UDP network connection implementations of the [`Connection`] trait. +/// UDP network connection implementation of the [`Connection`] trait. pub struct UdpConn { inner: UdpSocket, } diff --git a/net/src/transports/unix.rs b/net/src/transports/unix.rs index a720d91..e504934 100644 --- a/net/src/transports/unix.rs +++ b/net/src/transports/unix.rs @@ -8,7 +8,7 @@ use smol::{ use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Error, Result}; -/// Unix domain socket implementations of the [`Connection`] trait. +/// Unix domain socket implementation of the [`Connection`] trait. pub struct UnixConn { inner: UnixStream, read: Mutex>, diff --git a/p2p/Cargo.toml b/p2p/Cargo.toml index 98b700e..315983b 100644 --- a/p2p/Cargo.toml +++ b/p2p/Cargo.toml @@ -20,6 +20,12 @@ thiserror = "1.0.47" semver = "1.0.20" sha2 = "0.10.8" +# tls +async-rustls = { version = "0.4.1", features = ["dangerous_configuration"] } +rcgen = "0.11.3" +yasna = "0.5.2" +x509-parser = "0.15.1" + [[example]] name = "peer" path = "examples/peer.rs" diff --git a/p2p/README.md b/p2p/README.md index 5bdaf63..edc5fcd 100644 --- a/p2p/README.md +++ b/p2p/README.md @@ -115,11 +115,12 @@ impl Protocol for NewProtocol { Whenever a new peer is added to the PeerPool, all the protocols, including your custom protocols, will automatically start running with the newly connected peer. -## Network Security +## Network Security -It's obvious that connections in karyons p2p are not secure at the moment, as -it currently only supports TCP connections. However, we are currently working -on adding support for TLS connections. +Using TLS is possible for all inbound and outbound connections by enabling the +boolean `enable_tls` field in the configuration. However, implementing TLS for +a P2P network is not trivial and is still unstable, requiring a comprehensive +audit. ## Usage @@ -129,5 +130,5 @@ If you have tmux installed, you can run the network simulation script in the examples directory to run 12 peers simultaneously. ```bash -$ RUST_LOG=karyons=debug ./net_simulation.sh +$ RUST_LOG=karyons=info ./net_simulation.sh ``` diff --git a/p2p/examples/chat.rs b/p2p/examples/chat.rs index 907ba06..d94bca4 100644 --- a/p2p/examples/chat.rs +++ b/p2p/examples/chat.rs @@ -7,11 +7,12 @@ use async_trait::async_trait; use clap::Parser; use smol::{channel, Executor}; +use karyons_core::key_pair::{KeyPair, KeyPairType}; use karyons_net::{Endpoint, Port}; use karyons_p2p::{ protocol::{ArcProtocol, Protocol, ProtocolEvent, ProtocolID}, - ArcPeer, Backend, Config, P2pError, PeerID, Version, + ArcPeer, Backend, Config, P2pError, Version, }; use shared::run_executor; @@ -102,7 +103,7 @@ fn main() { let cli = Cli::parse(); // Create a PeerID based on the username. - let peer_id = PeerID::new(cli.username.as_bytes()); + let key_pair = KeyPair::generate(&KeyPairType::Ed25519); // Create the configuration for the backend. let config = Config { @@ -117,7 +118,7 @@ fn main() { let ex = Arc::new(Executor::new()); // Create a new Backend - let backend = Backend::new(peer_id, config, ex.clone()); + let backend = Backend::new(&key_pair, config, ex.clone()); let (ctrlc_s, ctrlc_r) = channel::unbounded(); let handle = move || ctrlc_s.try_send(()).unwrap(); diff --git a/p2p/examples/monitor.rs b/p2p/examples/monitor.rs index fc48c2f..530d2d5 100644 --- a/p2p/examples/monitor.rs +++ b/p2p/examples/monitor.rs @@ -5,9 +5,10 @@ use std::sync::Arc; use clap::Parser; use smol::{channel, Executor}; +use karyons_core::key_pair::{KeyPair, KeyPairType}; use karyons_net::{Endpoint, Port}; -use karyons_p2p::{Backend, Config, PeerID}; +use karyons_p2p::{Backend, Config}; use shared::run_executor; @@ -29,20 +30,13 @@ struct Cli { /// Optional TCP/UDP port for the discovery service. #[arg(short)] discovery_port: Option, - - /// Optional user id - #[arg(long)] - userid: Option, } fn main() { env_logger::init(); let cli = Cli::parse(); - let peer_id = match cli.userid { - Some(userid) => PeerID::new(userid.as_bytes()), - None => PeerID::random(), - }; + let key_pair = KeyPair::generate(&KeyPairType::Ed25519); // Create the configuration for the backend. let config = Config { @@ -57,7 +51,7 @@ fn main() { let ex = Arc::new(Executor::new()); // Create a new Backend - let backend = Backend::new(peer_id, config, ex.clone()); + let backend = Backend::new(&key_pair, config, ex.clone()); let (ctrlc_s, ctrlc_r) = channel::unbounded(); let handle = move || ctrlc_s.try_send(()).unwrap(); diff --git a/p2p/examples/net_simulation.sh b/p2p/examples/net_simulation.sh index 1a05adf..dd489e5 100755 --- a/p2p/examples/net_simulation.sh +++ b/p2p/examples/net_simulation.sh @@ -5,27 +5,27 @@ cargo build --release --example peer tmux new-session -d -s karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer1'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -l 'tcp://127.0.0.1:30000' -d '30010'" Enter tmux split-window -h -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer2'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -l 'tcp://127.0.0.1:30001' -d '30011' -b 'tcp://127.0.0.1:30010 ' " Enter tmux split-window -h -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer3'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -l 'tcp://127.0.0.1:30002' -d '30012' -b 'tcp://127.0.0.1:30010'" Enter tmux split-window -h -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer4'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -l 'tcp://127.0.0.1:30003' -d '30013' -b 'tcp://127.0.0.1:30010'" Enter tmux split-window -h -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer5'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -l 'tcp://127.0.0.1:30004' -d '30014' -b 'tcp://127.0.0.1:30010'" Enter tmux split-window -h -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer6'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -l 'tcp://127.0.0.1:30005' -d '30015' -b 'tcp://127.0.0.1:30010'" Enter tmux select-layout even-horizontal @@ -35,37 +35,37 @@ sleep 3; tmux select-pane -t karyons_p2p:0.0 tmux split-window -v -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer7'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30011'" Enter tmux select-pane -t karyons_p2p:0.2 tmux split-window -v -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer8'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30012' -p 'tcp://127.0.0.1:30005'" Enter tmux select-pane -t karyons_p2p:0.4 tmux split-window -v -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer9'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30013'" Enter tmux select-pane -t karyons_p2p:0.6 tmux split-window -v -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer10'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30014'" Enter tmux select-pane -t karyons_p2p:0.8 tmux split-window -v -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer11'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30015'" Enter tmux select-pane -t karyons_p2p:0.10 tmux split-window -v -t karyons_p2p -tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer12'\ +tmux send-keys -t karyons_p2p "../../target/release/examples/peer\ -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30015' -b 'tcp://127.0.0.1:30011'" Enter # tmux set-window-option -t karyons_p2p synchronize-panes on diff --git a/p2p/examples/peer.rs b/p2p/examples/peer.rs index 5ff365d..b595b4a 100644 --- a/p2p/examples/peer.rs +++ b/p2p/examples/peer.rs @@ -5,9 +5,10 @@ use std::sync::Arc; use clap::Parser; use smol::{channel, Executor}; +use karyons_core::key_pair::{KeyPair, KeyPairType}; use karyons_net::{Endpoint, Port}; -use karyons_p2p::{Backend, Config, PeerID}; +use karyons_p2p::{Backend, Config}; use shared::run_executor; @@ -29,20 +30,13 @@ struct Cli { /// Optional TCP/UDP port for the discovery service. #[arg(short)] discovery_port: Option, - - /// Optional user id - #[arg(long)] - userid: Option, } fn main() { env_logger::init(); let cli = Cli::parse(); - let peer_id = match cli.userid { - Some(userid) => PeerID::new(userid.as_bytes()), - None => PeerID::random(), - }; + let key_pair = KeyPair::generate(&KeyPairType::Ed25519); // Create the configuration for the backend. let config = Config { @@ -57,7 +51,7 @@ fn main() { let ex = Arc::new(Executor::new()); // Create a new Backend - let backend = Backend::new(peer_id, config, ex.clone()); + let backend = Backend::new(&key_pair, config, ex.clone()); let (ctrlc_s, ctrlc_r) = channel::unbounded(); let handle = move || ctrlc_s.try_send(()).unwrap(); diff --git a/p2p/src/backend.rs b/p2p/src/backend.rs index 2e34f47..56d79f7 100644 --- a/p2p/src/backend.rs +++ b/p2p/src/backend.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use log::info; -use karyons_core::{pubsub::Subscription, GlobalExecutor}; +use karyons_core::{key_pair::KeyPair, pubsub::Subscription, GlobalExecutor}; use crate::{ config::Config, @@ -22,8 +22,8 @@ pub struct Backend { /// The Configuration for the P2P network. config: Arc, - /// Peer ID. - id: PeerID, + /// Identity Key pair + key_pair: KeyPair, /// Responsible for network and system monitoring. monitor: Arc, @@ -37,17 +37,34 @@ pub struct Backend { impl Backend { /// Creates a new Backend. - pub fn new(id: PeerID, config: Config, ex: GlobalExecutor) -> ArcBackend { + pub fn new(key_pair: &KeyPair, config: Config, ex: GlobalExecutor) -> ArcBackend { let config = Arc::new(config); let monitor = Arc::new(Monitor::new()); - let cq = ConnQueue::new(); - - let peer_pool = PeerPool::new(&id, cq.clone(), config.clone(), monitor.clone(), ex.clone()); - - let discovery = Discovery::new(&id, cq, config.clone(), monitor.clone(), ex); + let conn_queue = ConnQueue::new(); + + let peer_id = PeerID::try_from(key_pair.public()) + .expect("Derive a peer id from the provided key pair."); + info!("PeerID: {}", peer_id); + + let peer_pool = PeerPool::new( + &peer_id, + conn_queue.clone(), + config.clone(), + monitor.clone(), + ex.clone(), + ); + + let discovery = Discovery::new( + key_pair, + &peer_id, + conn_queue, + config.clone(), + monitor.clone(), + ex, + ); Arc::new(Self { - id: id.clone(), + key_pair: key_pair.clone(), monitor, discovery, config, @@ -57,7 +74,6 @@ impl Backend { /// Run the Backend, starting the PeerPool and Discovery instances. pub async fn run(self: &Arc) -> Result<()> { - info!("Run the backend {}", self.id); self.peer_pool.start().await?; self.discovery.start().await?; Ok(()) @@ -81,6 +97,11 @@ impl Backend { self.config.clone() } + /// Returns the `KeyPair`. + pub async fn key_pair(&self) -> &KeyPair { + &self.key_pair + } + /// Returns the number of occupied inbound slots. pub fn inbound_slots(&self) -> usize { self.discovery.inbound_slots.load() diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs new file mode 100644 index 0000000..e521824 --- /dev/null +++ b/p2p/src/codec.rs @@ -0,0 +1,120 @@ +use std::time::Duration; + +use bincode::{Decode, Encode}; + +use karyons_core::{ + async_util::timeout, + util::{decode, encode, encode_into_slice}, +}; + +use karyons_net::{Connection, NetError}; + +use crate::{ + message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, + Error, Result, +}; + +pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} +impl CodecMsg for T {} + +/// A Codec working with generic network connections. +/// +/// It is responsible for both decoding data received from the network and +/// encoding data before sending it. +pub struct Codec { + conn: Box, +} + +impl Codec { + /// Creates a new Codec. + pub fn new(conn: Box) -> Self { + Self { conn } + } + + /// Reads a message of type `NetMsg` from the connection. + /// + /// It reads the first 6 bytes as the header of the message, then reads + /// and decodes the remaining message data based on the determined header. + pub async fn read(&self) -> Result { + // Read 6 bytes to get the header of the incoming message + let mut buf = [0; MSG_HEADER_SIZE]; + self.read_exact(&mut buf).await?; + + // Decode the header from bytes to NetMsgHeader + let (header, _) = decode::(&buf)?; + + if header.payload_size > MAX_ALLOWED_MSG_SIZE { + return Err(Error::InvalidMsg( + "Message exceeds the maximum allowed size".to_string(), + )); + } + + // Create a buffer to hold the message based on its length + let mut payload = vec![0; header.payload_size as usize]; + self.read_exact(&mut payload).await?; + + Ok(NetMsg { header, payload }) + } + + /// Writes a message of type `T` to the connection. + /// + /// Before appending the actual message payload, it calculates the length of + /// the encoded message in bytes and appends this length to the message header. + pub async fn write(&self, command: NetMsgCmd, msg: &T) -> Result<()> { + let payload = encode(msg)?; + + // Create a buffer to hold the message header (6 bytes) + let header_buf = &mut [0; MSG_HEADER_SIZE]; + let header = NetMsgHeader { + command, + payload_size: payload.len() as u32, + }; + encode_into_slice(&header, header_buf)?; + + let mut buffer = vec![]; + // Append the header bytes to the buffer + buffer.extend_from_slice(header_buf); + // Append the message payload to the buffer + buffer.extend_from_slice(&payload); + + self.write_all(&buffer).await?; + Ok(()) + } + + /// Reads a message of type `NetMsg` with the given timeout. + pub async fn read_timeout(&self, duration: Duration) -> Result { + timeout(duration, self.read()) + .await + .map_err(|_| NetError::Timeout)? + } + + /// Reads the exact number of bytes required to fill `buf`. + async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.read(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } + + /// Writes an entire buffer into the connection. + async fn write_all(&self, mut buf: &[u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.write(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } +} diff --git a/p2p/src/config.rs b/p2p/src/config.rs index ebecbf0..2c5d5ec 100644 --- a/p2p/src/config.rs +++ b/p2p/src/config.rs @@ -1,6 +1,6 @@ use karyons_net::{Endpoint, Port}; -use crate::utils::Version; +use crate::Version; /// the Configuration for the P2P network. pub struct Config { @@ -71,6 +71,9 @@ pub struct Config { /// The maximum number of retries for outbound connection establishment /// during the refresh process. pub refresh_connect_retries: usize, + + /// Enables TLS for all connections. + pub enable_tls: bool, } impl Default for Config { @@ -100,6 +103,8 @@ impl Default for Config { refresh_interval: 1800, refresh_response_timeout: 1, refresh_connect_retries: 3, + + enable_tls: false, } } } diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs index 8ec2617..e0a3bbd 100644 --- a/p2p/src/connection.rs +++ b/p2p/src/connection.rs @@ -2,7 +2,7 @@ use std::{collections::VecDeque, fmt, sync::Arc}; use smol::{channel::Sender, lock::Mutex}; -use karyons_core::async_utils::CondVar; +use karyons_core::async_util::CondVar; use karyons_net::Conn; use crate::Result; diff --git a/p2p/src/connector.rs b/p2p/src/connector.rs index f41ab57..6fc5734 100644 --- a/p2p/src/connector.rs +++ b/p2p/src/connector.rs @@ -1,21 +1,28 @@ use std::{future::Future, sync::Arc}; -use log::{trace, warn}; +use log::{error, trace, warn}; use karyons_core::{ - async_utils::{Backoff, TaskGroup, TaskResult}, + async_util::{Backoff, TaskGroup, TaskResult}, + key_pair::KeyPair, GlobalExecutor, }; -use karyons_net::{dial, Conn, Endpoint, NetError}; +use karyons_net::{dial, tls, Conn, Endpoint, NetError}; use crate::{ monitor::{ConnEvent, Monitor}, slots::ConnectionSlots, - Result, + tls_config::tls_client_config, + Error, PeerID, Result, }; +static DNS_NAME: &str = "karyons.org"; + /// Responsible for creating outbound connections with other peers. pub struct Connector { + /// Identity Key pair + key_pair: KeyPair, + /// Managing spawned tasks. task_group: TaskGroup<'static>, @@ -26,6 +33,9 @@ pub struct Connector { /// establishing a connection. max_retries: usize, + /// Enables secure connection. + enable_tls: bool, + /// Responsible for network and system monitoring. monitor: Arc, } @@ -33,16 +43,20 @@ pub struct Connector { impl Connector { /// Creates a new Connector pub fn new( + key_pair: &KeyPair, max_retries: usize, connection_slots: Arc, + enable_tls: bool, monitor: Arc, ex: GlobalExecutor, ) -> Arc { Arc::new(Self { + key_pair: key_pair.clone(), + max_retries, task_group: TaskGroup::new(ex), monitor, connection_slots, - max_retries, + enable_tls, }) } @@ -57,20 +71,23 @@ impl Connector { /// `Conn` instance. /// /// This method will block until it finds an available slot. - pub async fn connect(&self, endpoint: &Endpoint) -> Result { + pub async fn connect(&self, endpoint: &Endpoint, peer_id: &Option) -> Result { self.connection_slots.wait_for_slot().await; self.connection_slots.add(); let mut retry = 0; let backoff = Backoff::new(500, 2000); while retry < self.max_retries { - let conn_result = dial(endpoint).await; - - if let Ok(conn) = conn_result { - self.monitor - .notify(&ConnEvent::Connected(endpoint.clone()).into()) - .await; - return Ok(conn); + match self.dial(endpoint, peer_id).await { + Ok(conn) => { + self.monitor + .notify(&ConnEvent::Connected(endpoint.clone()).into()) + .await; + return Ok(conn); + } + Err(err) => { + error!("Failed to establish a connection to {endpoint}: {err}"); + } } self.monitor @@ -96,12 +113,13 @@ impl Connector { pub async fn connect_with_cback( self: &Arc, endpoint: &Endpoint, + peer_id: &Option, callback: impl FnOnce(Conn) -> Fut + Send + 'static, ) -> Result<()> where Fut: Future> + Send + 'static, { - let conn = self.connect(endpoint).await?; + let conn = self.connect(endpoint, peer_id).await?; let selfc = self.clone(); let endpoint = endpoint.clone(); @@ -120,4 +138,14 @@ impl Connector { Ok(()) } + + async fn dial(&self, endpoint: &Endpoint, peer_id: &Option) -> Result { + if self.enable_tls { + let tls_config = tls_client_config(&self.key_pair, peer_id.clone())?; + tls::dial(endpoint, tls_config, DNS_NAME).await + } else { + dial(endpoint).await + } + .map_err(Error::KaryonsNet) + } } diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs index 0138068..60d8635 100644 --- a/p2p/src/discovery/lookup.rs +++ b/p2p/src/discovery/lookup.rs @@ -5,13 +5,13 @@ use log::{error, trace}; use rand::{rngs::OsRng, seq::SliceRandom, RngCore}; use smol::lock::{Mutex, RwLock}; -use karyons_core::{async_utils::timeout, utils::decode, GlobalExecutor}; +use karyons_core::{async_util::timeout, key_pair::KeyPair, util::decode, GlobalExecutor}; use karyons_net::{Conn, Endpoint}; use crate::{ + codec::Codec, connector::Connector, - io_codec::IOCodec, listener::Listener, message::{ get_msg_payload, FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, @@ -20,7 +20,7 @@ use crate::{ monitor::{ConnEvent, DiscoveryEvent, Monitor}, routing_table::RoutingTable, slots::ConnectionSlots, - utils::version_match, + version::version_match, Config, Error, PeerID, Result, }; @@ -55,6 +55,7 @@ pub struct LookupService { impl LookupService { /// Creates a new lookup service pub fn new( + key_pair: &KeyPair, id: &PeerID, table: Arc>, config: Arc, @@ -64,11 +65,19 @@ impl LookupService { let inbound_slots = Arc::new(ConnectionSlots::new(config.lookup_inbound_slots)); let outbound_slots = Arc::new(ConnectionSlots::new(config.lookup_outbound_slots)); - let listener = Listener::new(inbound_slots.clone(), monitor.clone(), ex.clone()); + let listener = Listener::new( + key_pair, + inbound_slots.clone(), + config.enable_tls, + monitor.clone(), + ex.clone(), + ); let connector = Connector::new( + key_pair, config.lookup_connect_retries, outbound_slots.clone(), + config.enable_tls, monitor.clone(), ex, ); @@ -116,14 +125,17 @@ impl LookupService { /// randomly generated peer ID. Upon receiving peers from the initial lookup, /// it starts connecting to these received peers and sends them a FindPeer /// message that contains our own peer ID. - pub async fn start_lookup(&self, endpoint: &Endpoint) -> Result<()> { + pub async fn start_lookup(&self, endpoint: &Endpoint, peer_id: Option) -> Result<()> { trace!("Lookup started {endpoint}"); self.monitor .notify(&DiscoveryEvent::LookupStarted(endpoint.clone()).into()) .await; let mut random_peers = vec![]; - if let Err(err) = self.random_lookup(endpoint, &mut random_peers).await { + if let Err(err) = self + .random_lookup(endpoint, peer_id, &mut random_peers) + .await + { self.monitor .notify(&DiscoveryEvent::LookupFailed(endpoint.clone()).into()) .await; @@ -160,11 +172,14 @@ impl LookupService { async fn random_lookup( &self, endpoint: &Endpoint, + peer_id: Option, random_peers: &mut Vec, ) -> Result<()> { for _ in 0..2 { - let peer_id = PeerID::random(); - let peers = self.connect(&peer_id, endpoint.clone()).await?; + let random_peer_id = PeerID::random(); + let peers = self + .connect(endpoint.clone(), peer_id.clone(), &random_peer_id) + .await?; let table = self.table.lock().await; for peer in peers { @@ -187,7 +202,7 @@ impl LookupService { let mut tasks = FuturesUnordered::new(); for peer in random_peers.choose_multiple(&mut OsRng, random_peers.len()) { let endpoint = Endpoint::Tcp(peer.addr.clone(), peer.discovery_port); - tasks.push(self.connect(&self.id, endpoint)) + tasks.push(self.connect(endpoint, Some(peer.peer_id.clone()), &self.id)) } while let Some(result) = tasks.next().await { @@ -200,11 +215,17 @@ impl LookupService { } } - /// Connects to the given endpoint - async fn connect(&self, peer_id: &PeerID, endpoint: Endpoint) -> Result> { - let conn = self.connector.connect(&endpoint).await?; - let io_codec = IOCodec::new(conn); - let result = self.handle_outbound(io_codec, peer_id).await; + /// Connects to the given endpoint and initiates a lookup process for the + /// provided peer ID. + async fn connect( + &self, + endpoint: Endpoint, + peer_id: Option, + target_peer_id: &PeerID, + ) -> Result> { + let conn = self.connector.connect(&endpoint, &peer_id).await?; + let io_codec = Codec::new(conn); + let result = self.handle_outbound(io_codec, target_peer_id).await; self.monitor .notify(&ConnEvent::Disconnected(endpoint).into()) @@ -215,12 +236,16 @@ impl LookupService { } /// Handles outbound connection - async fn handle_outbound(&self, io_codec: IOCodec, peer_id: &PeerID) -> Result> { + async fn handle_outbound( + &self, + io_codec: Codec, + target_peer_id: &PeerID, + ) -> Result> { trace!("Send Ping msg"); self.send_ping_msg(&io_codec).await?; trace!("Send FindPeer msg"); - let peers = self.send_findpeer_msg(&io_codec, peer_id).await?; + let peers = self.send_findpeer_msg(&io_codec, target_peer_id).await?; if peers.0.len() >= MAX_PEERS_IN_PEERSMSG { return Err(Error::Lookup("Received too many peers in PeersMsg")); @@ -260,7 +285,7 @@ impl LookupService { /// Handles inbound connection async fn handle_inbound(self: &Arc, conn: Conn) -> Result<()> { - let io_codec = IOCodec::new(conn); + let io_codec = Codec::new(conn); loop { let msg: NetMsg = io_codec.read().await?; trace!("Receive msg {:?}", msg.header.command); @@ -293,7 +318,7 @@ impl LookupService { } /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, io_codec: &IOCodec) -> Result<()> { + async fn send_ping_msg(&self, io_codec: &Codec) -> Result<()> { trace!("Send Pong msg"); let mut nonce: [u8; 32] = [0; 32]; @@ -319,14 +344,14 @@ impl LookupService { } /// Sends a Pong msg - async fn send_pong_msg(&self, nonce: [u8; 32], io_codec: &IOCodec) -> Result<()> { + async fn send_pong_msg(&self, nonce: [u8; 32], io_codec: &Codec) -> Result<()> { trace!("Send Pong msg"); io_codec.write(NetMsgCmd::Pong, &PongMsg(nonce)).await?; Ok(()) } /// Sends a FindPeer msg and wait to receivet the Peers msg. - async fn send_findpeer_msg(&self, io_codec: &IOCodec, peer_id: &PeerID) -> Result { + async fn send_findpeer_msg(&self, io_codec: &Codec, peer_id: &PeerID) -> Result { trace!("Send FindPeer msg"); io_codec .write(NetMsgCmd::FindPeer, &FindPeerMsg(peer_id.clone())) @@ -342,7 +367,7 @@ impl LookupService { } /// Sends a Peers msg. - async fn send_peers_msg(&self, peer_id: &PeerID, io_codec: &IOCodec) -> Result<()> { + async fn send_peers_msg(&self, peer_id: &PeerID, io_codec: &Codec) -> Result<()> { trace!("Send Peers msg"); let table = self.table.lock().await; let entries = table.closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG); @@ -354,7 +379,7 @@ impl LookupService { } /// Sends a Peer msg. - async fn send_peer_msg(&self, io_codec: &IOCodec, endpoint: Endpoint) -> Result<()> { + async fn send_peer_msg(&self, io_codec: &Codec, endpoint: Endpoint) -> Result<()> { trace!("Send Peer msg"); let peer_msg = PeerMsg { addr: endpoint.addr()?.clone(), @@ -367,7 +392,7 @@ impl LookupService { } /// Sends a Shutdown msg. - async fn send_shutdown_msg(&self, io_codec: &IOCodec) -> Result<()> { + async fn send_shutdown_msg(&self, io_codec: &Codec) -> Result<()> { trace!("Send Shutdown msg"); io_codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await?; Ok(()) diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs index 7f55309..2c1bcd8 100644 --- a/p2p/src/discovery/mod.rs +++ b/p2p/src/discovery/mod.rs @@ -8,7 +8,8 @@ use rand::{rngs::OsRng, seq::SliceRandom}; use smol::lock::Mutex; use karyons_core::{ - async_utils::{Backoff, TaskGroup, TaskResult}, + async_util::{Backoff, TaskGroup, TaskResult}, + key_pair::KeyPair, GlobalExecutor, }; @@ -66,6 +67,7 @@ pub struct Discovery { impl Discovery { /// Creates a new Discovery pub fn new( + key_pair: &KeyPair, peer_id: &PeerID, conn_queue: Arc, config: Arc, @@ -81,6 +83,7 @@ impl Discovery { let refresh_service = RefreshService::new(config.clone(), table.clone(), monitor.clone(), ex.clone()); let lookup_service = LookupService::new( + key_pair, peer_id, table.clone(), config.clone(), @@ -89,12 +92,21 @@ impl Discovery { ); let connector = Connector::new( + key_pair, config.max_connect_retries, outbound_slots.clone(), + config.enable_tls, + monitor.clone(), + ex.clone(), + ); + + let listener = Listener::new( + key_pair, + inbound_slots.clone(), + config.enable_tls, monitor.clone(), ex.clone(), ); - let listener = Listener::new(inbound_slots.clone(), monitor.clone(), ex.clone()); Arc::new(Self { refresh_service: Arc::new(refresh_service), @@ -222,7 +234,7 @@ impl Discovery { selfc.update_entry(&pid, INCOMPATIBLE_ENTRY).await; } Err(Error::PeerAlreadyConnected) => { - // TODO + // TODO: Use the appropriate status. selfc.update_entry(&pid, DISCONNECTED_ENTRY).await; } Err(_) => { @@ -236,10 +248,13 @@ impl Discovery { Ok(()) }; - let res = self.connector.connect_with_cback(endpoint, cback).await; + let result = self + .connector + .connect_with_cback(endpoint, &pid, cback) + .await; if let Some(pid) = &pid { - match res { + match result { Ok(_) => { self.update_entry(pid, CONNECTED_ENTRY).await; } @@ -260,7 +275,8 @@ impl Discovery { match self.random_entry(PENDING_ENTRY | CONNECTED_ENTRY).await { Some(entry) => { let endpoint = Endpoint::Tcp(entry.addr, entry.discovery_port); - if let Err(err) = self.lookup_service.start_lookup(&endpoint).await { + let peer_id = Some(entry.key.into()); + if let Err(err) = self.lookup_service.start_lookup(&endpoint, peer_id).await { self.update_entry(&entry.key.into(), UNSTABLE_ENTRY).await; error!("Failed to do lookup: {endpoint}: {err}"); } @@ -268,7 +284,7 @@ impl Discovery { None => { let peers = &self.config.bootstrap_peers; for endpoint in peers.choose_multiple(&mut OsRng, peers.len()) { - if let Err(err) = self.lookup_service.start_lookup(endpoint).await { + if let Err(err) = self.lookup_service.start_lookup(endpoint, None).await { error!("Failed to do lookup: {endpoint}: {err}"); } } diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index d095f19..f797c71 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -10,8 +10,8 @@ use smol::{ }; use karyons_core::{ - async_utils::{timeout, Backoff, TaskGroup, TaskResult}, - utils::{decode, encode}, + async_util::{timeout, Backoff, TaskGroup, TaskResult}, + util::{decode, encode}, GlobalExecutor, }; diff --git a/p2p/src/error.rs b/p2p/src/error.rs index 0c1d50c..6274d4c 100644 --- a/p2p/src/error.rs +++ b/p2p/src/error.rs @@ -11,6 +11,9 @@ pub enum Error { #[error("Unsupported protocol error: {0}")] UnsupportedProtocol(String), + #[error("Try from public key Error: {0}")] + TryFromPublicKey(&'static str), + #[error("Invalid message error: {0}")] InvalidMsg(String), @@ -50,6 +53,21 @@ pub enum Error { #[error("Peer already connected")] PeerAlreadyConnected, + #[error("Yasna Error: {0}")] + Yasna(#[from] yasna::ASN1Error), + + #[error("X509 Parser Error: {0}")] + X509Parser(#[from] x509_parser::error::X509Error), + + #[error("Rcgen Error: {0}")] + Rcgen(#[from] rcgen::RcgenError), + + #[error("Tls Error: {0}")] + Rustls(#[from] async_rustls::rustls::Error), + + #[error("Invalid DNS Name: {0}")] + InvalidDnsNameError(#[from] async_rustls::rustls::client::InvalidDnsNameError), + #[error("Channel Send Error: {0}")] ChannelSend(String), diff --git a/p2p/src/io_codec.rs b/p2p/src/io_codec.rs deleted file mode 100644 index ea62666..0000000 --- a/p2p/src/io_codec.rs +++ /dev/null @@ -1,132 +0,0 @@ -use std::time::Duration; - -use bincode::{Decode, Encode}; - -use karyons_core::{ - async_utils::timeout, - utils::{decode, encode, encode_into_slice}, -}; - -use karyons_net::{Connection, NetError}; - -use crate::{ - message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, - Error, Result, -}; - -pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} -impl CodecMsg for T {} - -/// I/O codec working with generic network connections. -/// -/// It is responsible for both decoding data received from the network and -/// encoding data before sending it. -pub struct IOCodec { - conn: Box, -} - -impl IOCodec { - /// Creates a new IOCodec. - pub fn new(conn: Box) -> Self { - Self { conn } - } - - /// Reads a message of type `NetMsg` from the connection. - /// - /// It reads the first 6 bytes as the header of the message, then reads - /// and decodes the remaining message data based on the determined header. - pub async fn read(&self) -> Result { - // Read 6 bytes to get the header of the incoming message - let mut buf = [0; MSG_HEADER_SIZE]; - self.read_exact(&mut buf).await?; - - // Decode the header from bytes to NetMsgHeader - let (header, _) = decode::(&buf)?; - - if header.payload_size > MAX_ALLOWED_MSG_SIZE { - return Err(Error::InvalidMsg( - "Message exceeds the maximum allowed size".to_string(), - )); - } - - // Create a buffer to hold the message based on its length - let mut payload = vec![0; header.payload_size as usize]; - self.read_exact(&mut payload).await?; - - Ok(NetMsg { header, payload }) - } - - /// Writes a message of type `T` to the connection. - /// - /// Before appending the actual message payload, it calculates the length of - /// the encoded message in bytes and appends this length to the message header. - pub async fn write(&self, command: NetMsgCmd, msg: &T) -> Result<()> { - let payload = encode(msg)?; - - // Create a buffer to hold the message header (6 bytes) - let header_buf = &mut [0; MSG_HEADER_SIZE]; - let header = NetMsgHeader { - command, - payload_size: payload.len() as u32, - }; - encode_into_slice(&header, header_buf)?; - - let mut buffer = vec![]; - // Append the header bytes to the buffer - buffer.extend_from_slice(header_buf); - // Append the message payload to the buffer - buffer.extend_from_slice(&payload); - - self.write_all(&buffer).await?; - Ok(()) - } - - /// Reads a message of type `NetMsg` with the given timeout. - pub async fn read_timeout(&self, duration: Duration) -> Result { - timeout(duration, self.read()) - .await - .map_err(|_| NetError::Timeout)? - } - - /// Writes a message of type `T` with the given timeout. - pub async fn write_timeout( - &self, - command: NetMsgCmd, - msg: &T, - duration: Duration, - ) -> Result<()> { - timeout(duration, self.write(command, msg)) - .await - .map_err(|_| NetError::Timeout)? - } - - /// Reads the exact number of bytes required to fill `buf`. - async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.read(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } - - Ok(()) - } - - /// Writes an entire buffer into the connection. - async fn write_all(&self, mut buf: &[u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.write(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } - - Ok(()) - } -} diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index c0a3b5b..6585287 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -7,19 +7,19 @@ //! use easy_parallel::Parallel; //! use smol::{channel as smol_channel, future, Executor}; //! +//! use karyons_core::key_pair::{KeyPair, KeyPairType}; //! use karyons_p2p::{Backend, Config, PeerID}; //! -//! let peer_id = PeerID::random(); +//! let key_pair = KeyPair::generate(&KeyPairType::Ed25519); //! //! // Create the configuration for the backend. //! let mut config = Config::default(); //! -//! //! // Create a new Executor //! let ex = Arc::new(Executor::new()); //! //! // Create a new Backend -//! let backend = Backend::new(peer_id, config, ex.clone()); +//! let backend = Backend::new(&key_pair, config, ex.clone()); //! //! let task = async { //! // Run the backend @@ -36,12 +36,12 @@ //! ``` //! mod backend; +mod codec; mod config; mod connection; mod connector; mod discovery; mod error; -mod io_codec; mod listener; mod message; mod peer; @@ -49,7 +49,8 @@ mod peer_pool; mod protocols; mod routing_table; mod slots; -mod utils; +mod tls_config; +mod version; /// Responsible for network and system monitoring. /// [`Read More`](./monitor/struct.Monitor.html) @@ -62,6 +63,6 @@ pub use backend::{ArcBackend, Backend}; pub use config::Config; pub use error::Error as P2pError; pub use peer::{ArcPeer, PeerID}; -pub use utils::Version; +pub use version::Version; use error::{Error, Result}; diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs index f2391f7..58a0931 100644 --- a/p2p/src/listener.rs +++ b/p2p/src/listener.rs @@ -1,28 +1,36 @@ use std::{future::Future, sync::Arc}; -use log::{error, info, trace}; +use log::{debug, error, info}; use karyons_core::{ - async_utils::{TaskGroup, TaskResult}, + async_util::{TaskGroup, TaskResult}, + key_pair::KeyPair, GlobalExecutor, }; -use karyons_net::{listen, Conn, Endpoint, Listener as NetListener}; +use karyons_net::{listen, tls, Conn, Endpoint, Listener as NetListener}; use crate::{ monitor::{ConnEvent, Monitor}, slots::ConnectionSlots, - Result, + tls_config::tls_server_config, + Error, Result, }; /// Responsible for creating inbound connections with other peers. pub struct Listener { + /// Identity Key pair + key_pair: KeyPair, + /// Managing spawned tasks. task_group: TaskGroup<'static>, /// Manages available inbound slots. connection_slots: Arc, + /// Enables secure connection. + enable_tls: bool, + /// Responsible for network and system monitoring. monitor: Arc, } @@ -30,13 +38,17 @@ pub struct Listener { impl Listener { /// Creates a new Listener pub fn new( + key_pair: &KeyPair, connection_slots: Arc, + enable_tls: bool, monitor: Arc, ex: GlobalExecutor, ) -> Arc { Arc::new(Self { + key_pair: key_pair.clone(), connection_slots, task_group: TaskGroup::new(ex), + enable_tls, monitor, }) } @@ -55,7 +67,7 @@ impl Listener { where Fut: Future> + Send + 'static, { - let listener = match listen(&endpoint).await { + let listener = match self.listend(&endpoint).await { Ok(listener) => { self.monitor .notify(&ConnEvent::Listening(endpoint.clone()).into()) @@ -67,21 +79,17 @@ impl Listener { self.monitor .notify(&ConnEvent::ListenFailed(endpoint).into()) .await; - return Err(err.into()); + return Err(err); } }; let resolved_endpoint = listener.local_endpoint()?; - info!("Start listening on {endpoint}"); + info!("Start listening on {resolved_endpoint}"); let selfc = self.clone(); self.task_group - .spawn(selfc.listen_loop(listener, callback), |res| async move { - if let TaskResult::Completed(Err(err)) = res { - error!("Listen loop stopped: {endpoint} {err}"); - } - }); + .spawn(selfc.listen_loop(listener, callback), |_| async {}); Ok(resolved_endpoint) } @@ -94,8 +102,7 @@ impl Listener { self: Arc, listener: Box, callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static, - ) -> Result<()> - where + ) where Fut: Future> + Send + 'static, { loop { @@ -103,27 +110,35 @@ impl Listener { self.connection_slots.wait_for_slot().await; let result = listener.accept().await; - let conn = match result { + let (conn, endpoint) = match result { Ok(c) => { + let endpoint = match c.peer_endpoint() { + Ok(e) => e, + Err(err) => { + self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; + error!("Failed to accept a new connection: {err}"); + continue; + } + }; + self.monitor - .notify(&ConnEvent::Accepted(c.peer_endpoint()?).into()) + .notify(&ConnEvent::Accepted(endpoint.clone()).into()) .await; - c + (c, endpoint) } Err(err) => { error!("Failed to accept a new connection: {err}"); self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; - return Err(err.into()); + continue; } }; self.connection_slots.add(); let selfc = self.clone(); - let endpoint = conn.peer_endpoint()?; let on_disconnect = |res| async move { if let TaskResult::Completed(Err(err)) = res { - trace!("Inbound connection dropped: {err}"); + debug!("Inbound connection dropped: {err}"); } selfc .monitor @@ -136,4 +151,14 @@ impl Listener { self.task_group.spawn(callback(conn), on_disconnect); } } + + async fn listend(&self, endpoint: &Endpoint) -> Result> { + if self.enable_tls { + let tls_config = tls_server_config(&self.key_pair)?; + tls::listen(endpoint, tls_config).await + } else { + listen(endpoint).await + } + .map_err(Error::KaryonsNet) + } } diff --git a/p2p/src/message.rs b/p2p/src/message.rs index 3779cc1..6b23322 100644 --- a/p2p/src/message.rs +++ b/p2p/src/message.rs @@ -4,7 +4,7 @@ use bincode::{Decode, Encode}; use karyons_net::{Addr, Port}; -use crate::{protocol::ProtocolID, routing_table::Entry, utils::VersionInt, PeerID}; +use crate::{protocol::ProtocolID, routing_table::Entry, version::VersionInt, PeerID}; /// The size of the message header, in bytes. pub const MSG_HEADER_SIZE: usize = 6; diff --git a/p2p/src/monitor.rs b/p2p/src/monitor.rs index fbbf43f..1f74503 100644 --- a/p2p/src/monitor.rs +++ b/p2p/src/monitor.rs @@ -17,6 +17,7 @@ use karyons_net::Endpoint; /// /// use smol::Executor; /// +/// use karyons_core::key_pair::{KeyPair, KeyPairType}; /// use karyons_p2p::{Config, Backend, PeerID}; /// /// async { @@ -24,7 +25,8 @@ use karyons_net::Endpoint; /// // Create a new Executor /// let ex = Arc::new(Executor::new()); /// -/// let backend = Backend::new(PeerID::random(), Config::default(), ex); +/// let key_pair = KeyPair::generate(&KeyPairType::Ed25519); +/// let backend = Backend::new(&key_pair, Config::default(), ex); /// /// // Create a new Subscription /// let sub = backend.monitor().await; diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs index 85cd558..6ed0dd8 100644 --- a/p2p/src/peer/mod.rs +++ b/p2p/src/peer/mod.rs @@ -11,17 +11,17 @@ use smol::{ }; use karyons_core::{ - async_utils::{select, Either, TaskGroup, TaskResult}, + async_util::{select, Either, TaskGroup, TaskResult}, event::{ArcEventSys, EventListener, EventSys}, - utils::{decode, encode}, + util::{decode, encode}, GlobalExecutor, }; use karyons_net::Endpoint; use crate::{ + codec::{Codec, CodecMsg}, connection::ConnDirection, - io_codec::{CodecMsg, IOCodec}, message::{NetMsgCmd, ProtocolMsg, ShutdownMsg}, peer_pool::{ArcPeerPool, WeakPeerPool}, protocol::{Protocol, ProtocolEvent, ProtocolID}, @@ -37,8 +37,8 @@ pub struct Peer { /// A weak pointer to `PeerPool` peer_pool: WeakPeerPool, - /// Holds the IOCodec for the peer connection - io_codec: IOCodec, + /// Holds the Codec for the peer connection + codec: Codec, /// Remote endpoint for the peer remote_endpoint: Endpoint, @@ -64,7 +64,7 @@ impl Peer { pub fn new( peer_pool: WeakPeerPool, id: &PeerID, - io_codec: IOCodec, + codec: Codec, remote_endpoint: Endpoint, conn_direction: ConnDirection, ex: GlobalExecutor, @@ -72,7 +72,7 @@ impl Peer { Arc::new(Peer { id: id.clone(), peer_pool, - io_codec, + codec, protocol_ids: RwLock::new(Vec::new()), remote_endpoint, conn_direction, @@ -97,7 +97,7 @@ impl Peer { payload: payload.to_vec(), }; - self.io_codec.write(NetMsgCmd::Protocol, &proto_msg).await?; + self.codec.write(NetMsgCmd::Protocol, &proto_msg).await?; Ok(()) } @@ -124,10 +124,7 @@ impl Peer { let _ = self.stop_chan.0.try_send(Ok(())); // No need to handle the error here - let _ = self - .io_codec - .write(NetMsgCmd::Shutdown, &ShutdownMsg(0)) - .await; + let _ = self.codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await; // Force shutting down self.task_group.cancel().await; @@ -174,7 +171,7 @@ impl Peer { /// Start a read loop to handle incoming messages from the peer connection. async fn read_loop(&self) -> Result<()> { loop { - let fut = select(self.stop_chan.1.recv(), self.io_codec.read()).await; + let fut = select(self.stop_chan.1.recv(), self.codec.read()).await; let result = match fut { Either::Left(stop_signal) => { trace!("Peer {} received a stop signal", self.id); diff --git a/p2p/src/peer/peer_id.rs b/p2p/src/peer/peer_id.rs index c8aec7d..903d827 100644 --- a/p2p/src/peer/peer_id.rs +++ b/p2p/src/peer/peer_id.rs @@ -2,6 +2,10 @@ use bincode::{Decode, Encode}; use rand::{rngs::OsRng, RngCore}; use sha2::{Digest, Sha256}; +use karyons_core::key_pair::PublicKey; + +use crate::Error; + /// Represents a unique identifier for a peer. #[derive(Clone, Debug, Eq, PartialEq, Hash, Decode, Encode)] pub struct PeerID(pub [u8; 32]); @@ -39,3 +43,16 @@ impl From<[u8; 32]> for PeerID { PeerID(b) } } + +impl TryFrom for PeerID { + type Error = Error; + + fn try_from(pk: PublicKey) -> Result { + let pk: [u8; 32] = pk + .as_bytes() + .try_into() + .map_err(|_| Error::TryFromPublicKey("Failed to convert public key to [u8;32]"))?; + + Ok(PeerID(pk)) + } +} diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index a0079f2..dd7e669 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -11,23 +11,23 @@ use smol::{ }; use karyons_core::{ - async_utils::{TaskGroup, TaskResult}, - utils::decode, + async_util::{TaskGroup, TaskResult}, + util::decode, GlobalExecutor, }; use karyons_net::Conn; use crate::{ + codec::{Codec, CodecMsg}, config::Config, connection::{ConnDirection, ConnQueue}, - io_codec::{CodecMsg, IOCodec}, message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, monitor::{Monitor, PeerPoolEvent}, peer::{ArcPeer, Peer, PeerID}, protocol::{Protocol, ProtocolConstructor, ProtocolID}, protocols::PingProtocol, - utils::{version_match, Version, VersionInt}, + version::{version_match, Version, VersionInt}, Error, Result, }; @@ -155,10 +155,10 @@ impl PeerPool { disconnect_signal: Sender>, ) -> Result<()> { let endpoint = conn.peer_endpoint()?; - let io_codec = IOCodec::new(conn); + let codec = Codec::new(conn); // Do a handshake with the connection before creating a new peer. - let pid = self.do_handshake(&io_codec, conn_direction).await?; + let pid = self.do_handshake(&codec, conn_direction).await?; // TODO: Consider restricting the subnet for inbound connections if self.contains_peer(&pid).await { @@ -169,7 +169,7 @@ impl PeerPool { let peer = Peer::new( Arc::downgrade(self), &pid, - io_codec, + codec, endpoint.clone(), conn_direction.clone(), self.executor.clone(), @@ -235,20 +235,16 @@ impl PeerPool { } /// Initiate a handshake with a connection. - async fn do_handshake( - &self, - io_codec: &IOCodec, - conn_direction: &ConnDirection, - ) -> Result { + async fn do_handshake(&self, codec: &Codec, conn_direction: &ConnDirection) -> Result { match conn_direction { ConnDirection::Inbound => { - let result = self.wait_vermsg(io_codec).await; + let result = self.wait_vermsg(codec).await; match result { Ok(_) => { - self.send_verack(io_codec, true).await?; + self.send_verack(codec, true).await?; } Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { - self.send_verack(io_codec, false).await?; + self.send_verack(codec, false).await?; } _ => {} } @@ -256,14 +252,14 @@ impl PeerPool { } ConnDirection::Outbound => { - self.send_vermsg(io_codec).await?; - self.wait_verack(io_codec).await + self.send_vermsg(codec).await?; + self.wait_verack(codec).await } } } /// Send a Version message - async fn send_vermsg(&self, io_codec: &IOCodec) -> Result<()> { + async fn send_vermsg(&self, codec: &Codec) -> Result<()> { let pids = self.protocol_versions.read().await; let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); drop(pids); @@ -275,16 +271,16 @@ impl PeerPool { }; trace!("Send VerMsg"); - io_codec.write(NetMsgCmd::Version, &vermsg).await?; + codec.write(NetMsgCmd::Version, &vermsg).await?; Ok(()) } /// Wait for a Version message /// /// Returns the peer's ID upon successfully receiving the Version message. - async fn wait_vermsg(&self, io_codec: &IOCodec) -> Result { + async fn wait_vermsg(&self, codec: &Codec) -> Result { let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = io_codec.read_timeout(timeout).await?; + let msg: NetMsg = codec.read_timeout(timeout).await?; let payload = get_msg_payload!(Version, msg); let (vermsg, _) = decode::(&payload)?; @@ -300,23 +296,23 @@ impl PeerPool { } /// Send a Verack message - async fn send_verack(&self, io_codec: &IOCodec, ack: bool) -> Result<()> { + async fn send_verack(&self, codec: &Codec, ack: bool) -> Result<()> { let verack = VerAckMsg { peer_id: self.id.clone(), ack, }; trace!("Send VerAckMsg {:?}", verack); - io_codec.write(NetMsgCmd::Verack, &verack).await?; + codec.write(NetMsgCmd::Verack, &verack).await?; Ok(()) } /// Wait for a Verack message /// /// Returns the peer's ID upon successfully receiving the Verack message. - async fn wait_verack(&self, io_codec: &IOCodec) -> Result { + async fn wait_verack(&self, codec: &Codec) -> Result { let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = io_codec.read_timeout(timeout).await?; + let msg: NetMsg = codec.read_timeout(timeout).await?; let payload = get_msg_payload!(Verack, msg); let (verack, _) = decode::(&payload)?; diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index 770b695..8ddc685 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use karyons_core::{event::EventValue, Executor}; -use crate::{peer::ArcPeer, utils::Version, Result}; +use crate::{peer::ArcPeer, version::Version, Result}; pub type ArcProtocol = Arc; @@ -37,6 +37,7 @@ impl EventValue for ProtocolEvent { /// use async_trait::async_trait; /// use smol::Executor; /// +/// use karyons_core::key_pair::{KeyPair, KeyPairType}; /// use karyons_p2p::{ /// protocol::{ArcProtocol, Protocol, ProtocolID, ProtocolEvent}, /// Backend, PeerID, Config, Version, P2pError, ArcPeer}; @@ -84,14 +85,14 @@ impl EventValue for ProtocolEvent { /// } /// /// async { -/// let peer_id = PeerID::random(); +/// let key_pair = KeyPair::generate(&KeyPairType::Ed25519); /// let config = Config::default(); /// /// // Create a new Executor /// let ex = Arc::new(Executor::new()); /// /// // Create a new Backend -/// let backend = Backend::new(peer_id, config, ex); +/// let backend = Backend::new(&key_pair, config, ex); /// /// // Attach the NewProtocol /// let c = move |peer| NewProtocol::new(peer); diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs index dc1b9a1..0a5488d 100644 --- a/p2p/src/protocols/ping.rs +++ b/p2p/src/protocols/ping.rs @@ -12,9 +12,9 @@ use smol::{ }; use karyons_core::{ - async_utils::{select, timeout, Either, TaskGroup, TaskResult}, + async_util::{select, timeout, Either, TaskGroup, TaskResult}, event::EventListener, - utils::decode, + util::decode, Executor, }; @@ -23,7 +23,7 @@ use karyons_net::NetError; use crate::{ peer::ArcPeer, protocol::{ArcProtocol, Protocol, ProtocolEvent, ProtocolID}, - utils::Version, + version::Version, Result, }; diff --git a/p2p/src/routing_table/entry.rs b/p2p/src/routing_table/entry.rs index b3f219f..c5fa65d 100644 --- a/p2p/src/routing_table/entry.rs +++ b/p2p/src/routing_table/entry.rs @@ -20,7 +20,7 @@ pub struct Entry { impl PartialEq for Entry { fn eq(&self, other: &Self) -> bool { - // XXX this should also compare both addresses (the self.addr == other.addr) + // TODO: this should also compare both addresses (the self.addr == other.addr) self.key == other.key } } diff --git a/p2p/src/routing_table/mod.rs b/p2p/src/routing_table/mod.rs index 5277c0a..cfc3128 100644 --- a/p2p/src/routing_table/mod.rs +++ b/p2p/src/routing_table/mod.rs @@ -1,5 +1,8 @@ +use std::net::IpAddr; + mod bucket; mod entry; + pub use bucket::{ Bucket, BucketEntry, EntryStatusFlag, CONNECTED_ENTRY, DISCONNECTED_ENTRY, INCOMPATIBLE_ENTRY, PENDING_ENTRY, UNREACHABLE_ENTRY, UNSTABLE_ENTRY, @@ -8,7 +11,7 @@ pub use entry::{xor_distance, Entry, Key}; use rand::{rngs::OsRng, seq::SliceRandom}; -use crate::utils::subnet_match; +use karyons_net::Addr; use bucket::BUCKET_SIZE; use entry::KEY_SIZE; @@ -262,6 +265,20 @@ impl RoutingTable { } } +/// Check if two addresses belong to the same subnet. +pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { + match (addr, other_addr) { + (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => { + // TODO: Consider moving this to a different place + if other_ip.is_loopback() && ip.is_loopback() { + return false; + } + ip.octets()[0..3] == other_ip.octets()[0..3] + } + _ => false, + } +} + #[cfg(test)] mod tests { use super::bucket::ALL_ENTRY; diff --git a/p2p/src/slots.rs b/p2p/src/slots.rs index 99f0a78..d3a1d0a 100644 --- a/p2p/src/slots.rs +++ b/p2p/src/slots.rs @@ -1,6 +1,6 @@ use std::sync::atomic::{AtomicUsize, Ordering}; -use karyons_core::async_utils::CondWait; +use karyons_core::async_util::CondWait; /// Manages available inbound and outbound slots. pub struct ConnectionSlots { diff --git a/p2p/src/tls_config.rs b/p2p/src/tls_config.rs new file mode 100644 index 0000000..f3b231a --- /dev/null +++ b/p2p/src/tls_config.rs @@ -0,0 +1,214 @@ +use std::sync::Arc; + +use async_rustls::rustls::{ + self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, client::ServerCertVerifier, + server::ClientCertVerifier, Certificate, CertificateError, Error::InvalidCertificate, + PrivateKey, SupportedCipherSuite, SupportedKxGroup, SupportedProtocolVersion, +}; +use log::error; +use x509_parser::{certificate::X509Certificate, parse_x509_certificate}; + +use karyons_core::key_pair::{KeyPair, KeyPairType, PublicKey}; + +use crate::{PeerID, Result}; + +// NOTE: This code needs a comprehensive audit. + +static PROTOCOL_VERSIONS: &[&SupportedProtocolVersion] = &[&rustls::version::TLS13]; +static CIPHER_SUITES: &[SupportedCipherSuite] = &[TLS13_CHACHA20_POLY1305_SHA256]; +static KX_GROUPS: &[&SupportedKxGroup] = &[&rustls::kx_group::X25519]; + +const BAD_SIGNATURE_ERR: rustls::Error = InvalidCertificate(CertificateError::BadSignature); +const BAD_ENCODING_ERR: rustls::Error = InvalidCertificate(CertificateError::BadEncoding); + +/// Returns a TLS client configuration. +pub fn tls_client_config( + key_pair: &KeyPair, + peer_id: Option, +) -> Result { + let (cert, private_key) = generate_cert(key_pair)?; + let server_verifier = SrvrCertVerifier { peer_id }; + let client_config = rustls::ClientConfig::builder() + .with_cipher_suites(CIPHER_SUITES) + .with_kx_groups(KX_GROUPS) + .with_protocol_versions(PROTOCOL_VERSIONS)? + .with_custom_certificate_verifier(Arc::new(server_verifier)) + .with_client_auth_cert(vec![cert], private_key)?; + + Ok(client_config) +} + +/// Returns a TLS server configuration. +pub fn tls_server_config(key_pair: &KeyPair) -> Result { + let (cert, private_key) = generate_cert(key_pair)?; + let client_verifier = CliCertVerifier {}; + let server_config = rustls::ServerConfig::builder() + .with_cipher_suites(CIPHER_SUITES) + .with_kx_groups(KX_GROUPS) + .with_protocol_versions(PROTOCOL_VERSIONS)? + .with_client_cert_verifier(Arc::new(client_verifier)) + .with_single_cert(vec![cert], private_key)?; + + Ok(server_config) +} + +/// Generates a certificate and returns both the certificate and the private key. +fn generate_cert(key_pair: &KeyPair) -> Result<(Certificate, PrivateKey)> { + let cert_key_pair = rcgen::KeyPair::generate(&rcgen::PKCS_ED25519)?; + let private_key = rustls::PrivateKey(cert_key_pair.serialize_der()); + + // Add a custom extension to the certificate: + // - Sign the certificate's public key with the provided key pair's public key + // - Append both the signature and the key pair's public key to the extension + let signature = key_pair.sign(&cert_key_pair.public_key_der()); + let ext_content = yasna::encode_der(&(key_pair.public().as_bytes().to_vec(), signature)); + // XXX: Not sure about the oid number ??? + let mut ext = rcgen::CustomExtension::from_oid_content(&[0, 0, 0, 0], ext_content); + ext.set_criticality(true); + + let mut params = rcgen::CertificateParams::new(vec![]); + params.alg = &rcgen::PKCS_ED25519; + params.key_pair = Some(cert_key_pair); + params.custom_extensions.push(ext); + + let cert = rustls::Certificate(rcgen::Certificate::from_params(params)?.serialize_der()?); + Ok((cert, private_key)) +} + +/// Verifies the given certification. +fn verify_cert(end_entity: &Certificate) -> std::result::Result { + // Parse the certificate. + let cert = parse_cert(end_entity)?; + + match cert.extensions().first() { + Some(ext) => { + // Extract the peer id (public key) and the signature from the extension. + let (public_key, signature): (Vec, Vec) = + yasna::decode_der(ext.value).map_err(|_| BAD_ENCODING_ERR)?; + + // Use the peer id (public key) to verify the extracted signature. + let public_key = PublicKey::from_bytes(&KeyPairType::Ed25519, &public_key) + .map_err(|_| BAD_ENCODING_ERR)?; + public_key + .verify(cert.public_key().raw, &signature) + .map_err(|_| BAD_SIGNATURE_ERR)?; + + // Verify the certificate signature. + verify_cert_signature( + &cert, + cert.tbs_certificate.as_ref(), + cert.signature_value.as_ref(), + )?; + + PeerID::try_from(public_key).map_err(|_| BAD_ENCODING_ERR) + } + None => Err(BAD_ENCODING_ERR), + } +} + +/// Parses the given x509 certificate. +fn parse_cert(end_entity: &Certificate) -> std::result::Result { + let (_, cert) = parse_x509_certificate(end_entity.as_ref()).map_err(|_| BAD_ENCODING_ERR)?; + + if !cert.validity().is_valid() { + return Err(InvalidCertificate(CertificateError::NotValidYet)); + } + + Ok(cert) +} + +/// Verifies the signature of the given certificate. +fn verify_cert_signature( + cert: &X509Certificate, + message: &[u8], + signature: &[u8], +) -> std::result::Result<(), rustls::Error> { + let public_key = PublicKey::from_bytes( + &KeyPairType::Ed25519, + cert.tbs_certificate.subject_pki.subject_public_key.as_ref(), + ) + .map_err(|_| BAD_ENCODING_ERR)?; + + public_key + .verify(message, signature) + .map_err(|_| BAD_SIGNATURE_ERR) +} + +struct SrvrCertVerifier { + peer_id: Option, +} + +impl ServerCertVerifier for SrvrCertVerifier { + fn verify_server_cert( + &self, + end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> std::result::Result { + let peer_id = match verify_cert(end_entity) { + Ok(pid) => pid, + Err(err) => { + error!("Failed to verify cert: {err}"); + return Err(err); + } + }; + + // Verify that the peer id in the certificate's extension matches the + // one the client intends to connect to. + // Both should be equal for establishing a fully secure connection. + if let Some(pid) = &self.peer_id { + if pid != &peer_id { + return Err(InvalidCertificate( + CertificateError::ApplicationVerificationFailure, + )); + } + } + + Ok(rustls::client::ServerCertVerified::assertion()) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &Certificate, + dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + let cert = parse_cert(cert)?; + verify_cert_signature(&cert, message, dss.signature())?; + Ok(rustls::client::HandshakeSignatureValid::assertion()) + } +} + +struct CliCertVerifier {} +impl ClientCertVerifier for CliCertVerifier { + fn verify_client_cert( + &self, + end_entity: &Certificate, + _intermediates: &[Certificate], + _now: std::time::SystemTime, + ) -> std::result::Result { + if let Err(err) = verify_cert(end_entity) { + error!("Failed to verify cert: {err}"); + return Err(err); + }; + Ok(rustls::server::ClientCertVerified::assertion()) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &Certificate, + dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + let cert = parse_cert(cert)?; + verify_cert_signature(&cert, message, dss.signature())?; + Ok(rustls::client::HandshakeSignatureValid::assertion()) + } + + fn client_auth_root_subjects(&self) -> &[rustls::DistinguishedName] { + &[] + } +} diff --git a/p2p/src/utils/mod.rs b/p2p/src/utils/mod.rs deleted file mode 100644 index e8ff9d0..0000000 --- a/p2p/src/utils/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -mod version; - -pub use version::{version_match, Version, VersionInt}; - -use std::net::IpAddr; - -use karyons_net::Addr; - -/// Check if two addresses belong to the same subnet. -pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { - match (addr, other_addr) { - (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => { - // XXX Consider moving this to a different location - if other_ip.is_loopback() && ip.is_loopback() { - return false; - } - ip.octets()[0..3] == other_ip.octets()[0..3] - } - _ => false, - } -} diff --git a/p2p/src/utils/version.rs b/p2p/src/utils/version.rs deleted file mode 100644 index a101b28..0000000 --- a/p2p/src/utils/version.rs +++ /dev/null @@ -1,93 +0,0 @@ -use std::str::FromStr; - -use bincode::{Decode, Encode}; -use semver::VersionReq; - -use crate::{Error, Result}; - -/// Represents the network version and protocol version used in karyons p2p. -/// -/// # Example -/// -/// ``` -/// use karyons_p2p::Version; -/// -/// let version: Version = "0.2.0, >0.1.0".parse().unwrap(); -/// -/// let version: Version = "0.2.0".parse().unwrap(); -/// -/// ``` -#[derive(Debug, Clone)] -pub struct Version { - pub v: VersionInt, - pub req: VersionReq, -} - -impl Version { - /// Creates a new Version - pub fn new(v: VersionInt, req: VersionReq) -> Self { - Self { v, req } - } -} - -#[derive(Debug, Decode, Encode, Clone)] -pub struct VersionInt { - major: u64, - minor: u64, - patch: u64, -} - -impl FromStr for Version { - type Err = Error; - - fn from_str(s: &str) -> Result { - let v: Vec<&str> = s.split(", ").collect(); - if v.is_empty() || v.len() > 2 { - return Err(Error::ParseError(format!("Invalid version{s}"))); - } - - let version: VersionInt = v[0].parse()?; - let req: VersionReq = if v.len() > 1 { v[1] } else { v[0] }.parse()?; - - Ok(Self { v: version, req }) - } -} - -impl FromStr for VersionInt { - type Err = Error; - - fn from_str(s: &str) -> Result { - let v: Vec<&str> = s.split('.').collect(); - if v.len() < 2 || v.len() > 3 { - return Err(Error::ParseError(format!("Invalid version{s}"))); - } - - let major = v[0].parse::()?; - let minor = v[1].parse::()?; - let patch = v.get(2).unwrap_or(&"0").parse::()?; - - Ok(Self { - major, - minor, - patch, - }) - } -} - -impl std::fmt::Display for VersionInt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}.{}", self.major, self.minor, self.patch) - } -} - -impl From for semver::Version { - fn from(v: VersionInt) -> Self { - semver::Version::new(v.major, v.minor, v.patch) - } -} - -/// Check if a version satisfies a version request. -pub fn version_match(version_req: &VersionReq, version: &VersionInt) -> bool { - let version: semver::Version = version.clone().into(); - version_req.matches(&version) -} diff --git a/p2p/src/version.rs b/p2p/src/version.rs new file mode 100644 index 0000000..a101b28 --- /dev/null +++ b/p2p/src/version.rs @@ -0,0 +1,93 @@ +use std::str::FromStr; + +use bincode::{Decode, Encode}; +use semver::VersionReq; + +use crate::{Error, Result}; + +/// Represents the network version and protocol version used in karyons p2p. +/// +/// # Example +/// +/// ``` +/// use karyons_p2p::Version; +/// +/// let version: Version = "0.2.0, >0.1.0".parse().unwrap(); +/// +/// let version: Version = "0.2.0".parse().unwrap(); +/// +/// ``` +#[derive(Debug, Clone)] +pub struct Version { + pub v: VersionInt, + pub req: VersionReq, +} + +impl Version { + /// Creates a new Version + pub fn new(v: VersionInt, req: VersionReq) -> Self { + Self { v, req } + } +} + +#[derive(Debug, Decode, Encode, Clone)] +pub struct VersionInt { + major: u64, + minor: u64, + patch: u64, +} + +impl FromStr for Version { + type Err = Error; + + fn from_str(s: &str) -> Result { + let v: Vec<&str> = s.split(", ").collect(); + if v.is_empty() || v.len() > 2 { + return Err(Error::ParseError(format!("Invalid version{s}"))); + } + + let version: VersionInt = v[0].parse()?; + let req: VersionReq = if v.len() > 1 { v[1] } else { v[0] }.parse()?; + + Ok(Self { v: version, req }) + } +} + +impl FromStr for VersionInt { + type Err = Error; + + fn from_str(s: &str) -> Result { + let v: Vec<&str> = s.split('.').collect(); + if v.len() < 2 || v.len() > 3 { + return Err(Error::ParseError(format!("Invalid version{s}"))); + } + + let major = v[0].parse::()?; + let minor = v[1].parse::()?; + let patch = v.get(2).unwrap_or(&"0").parse::()?; + + Ok(Self { + major, + minor, + patch, + }) + } +} + +impl std::fmt::Display for VersionInt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.major, self.minor, self.patch) + } +} + +impl From for semver::Version { + fn from(v: VersionInt) -> Self { + semver::Version::new(v.major, v.minor, v.patch) + } +} + +/// Check if a version satisfies a version request. +pub fn version_match(version_req: &VersionReq, version: &VersionInt) -> bool { + let version: semver::Version = version.clone().into(); + version_req.matches(&version) +} -- cgit v1.2.3