diff options
author | hozan23 <hozan23@proton.me> | 2023-11-08 13:03:27 +0300 |
---|---|---|
committer | hozan23 <hozan23@proton.me> | 2023-11-08 13:03:27 +0300 |
commit | 4fe665fc8bc6265baf5bfba6b6a5f3ee2dba63dc (patch) | |
tree | 77c7c40c9725539546e53b00f424deafe5ec81a8 |
first commit
63 files changed, 6692 insertions, 0 deletions
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..ef30d2d --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,26 @@ +name: Rust + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Build + run: cargo build --all --verbose + - name: Run tests + run: cargo test --all --verbose + - name: Run clippy + run: cargo clippy --all -- -D warnings + - name: Run fmt + run: cargo fmt --all -- --check diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..aa71954 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "karyons" +version = "0.1.0" +edition = "2021" + +[workspace] + +members = [ + "karyons_core", + "karyons_net", + "karyons_p2p", +] diff --git a/README.md b/README.md new file mode 100644 index 0000000..8eb1ca1 --- /dev/null +++ b/README.md @@ -0,0 +1,154 @@ +# karyons + +> :warning: **Warning: This Project is a Work in Progress** + +We are in the process of developing an infrastructure for peer-to-peer, +decentralized, and collaborative software. + +Join us on: + +- [Discord](https://discord.gg/NDSpDdck) +- [Matrix](https://matrix.to/#/#karyons:matrix.org) + +## karyons p2p + +karyons p2p serves as the foundational stack for the karyons project. It +offers a modular, lightweight, and customizable p2p network stack that +seamlessly integrates with any p2p project. + +### Architecture + +#### Discovery + +karyons p2p uses a customized version of the Kademlia for discovering new peers +in the network. This approach is based on Kademlia but with several significant +differences and optimizations. Some of the main changes: + +1. 
karyons p2p uses TCP for the lookup process, while UDP is used for + validating and refreshing the routing table. The reason for this choice is + that the lookup process is infrequent, and the work required to manage + messages with UDP is largely equivalent to using TCP for this purpose. + However, for the periodic and efficient sending of numerous Ping messages to + the entries in the routing table during refreshing, it makes sense to + use UDP. + +2. In contrast to traditional Kademlia, which often employs 160 buckets, + karyons p2p reduces the number of buckets to 32. This optimization is a + result of the observation that most nodes tend to map into the last few + buckets, with the majority of other buckets remaining empty. + +3. While Kademlia typically uses a 160-bit key to identify a peer, karyons p2p + uses a 256-bit key. + +> Despite criticisms of Kademlia's vulnerabilities, particularly concerning +> Sybil and Eclipse attacks [1](https://eprint.iacr.org/2018/236.pdf) +> [2](https://arxiv.org/abs/1908.10141), we chose to use Kademlia because our +> main goal is to build an infrastructure focused on sharing data. This choice +> may also assist us in supporting sharding in the future. However, we have made +> efforts to mitigate most of its vulnerabilities. Several projects, including +> BitTorrent, Ethereum, IPFS, and Storj, still rely on Kademlia. + +#### Peer IDs + +Peers in the karyons p2p network are identified by their 256-bit (32-byte) Peer IDs. + +#### Seeding + +At the network's initiation, the client populates the routing table with peers +closest to its key(PeerID) through a seeding process. Once this process is +complete, and the routing table is filled, the client selects a random peer +from the routing table and establishes an outbound connection. This process +continues until all outbound slots are occupied. + +The client can optionally provide a listening endpoint to accept inbound +connections. 
+ +#### Handshake + +When an inbound or outbound connection is established, the client initiates a +handshake with that connection. If the handshake is successful, the connection +is added to the PeerPool. + +#### Protocols + +In the karyons p2p network, we have two types of protocols: core protocols and +custom protocols. Core protocols are prebuilt into karyons p2p, such as the +Ping protocol used to maintain connections. Custom protocols, on the other +hand, are protocols that you define for your application to provide its core +functionality. + +Here's an example of a custom protocol: + +```rust +pub struct NewProtocol { + peer: ArcPeer, +} + +impl NewProtocol { + fn new(peer: ArcPeer) -> ArcProtocol { + Arc::new(Self { + peer, + }) + } +} + +#[async_trait] +impl Protocol for NewProtocol { + async fn start(self: Arc<Self>, ex: Arc<Executor<'_>>) -> Result<(), P2pError> { + let listener = self.peer.register_listener::<Self>().await; + loop { + let event = listener.recv().await.unwrap(); + + match event { + ProtocolEvent::Message(msg) => { + println!("{:?}", msg); + } + ProtocolEvent::Shutdown => { + break; + } + } + } + + listener.cancel().await; + Ok(()) + } + + fn version() -> Result<Version, P2pError> { + "0.2.0, >0.1.0".parse() + } + + fn id() -> ProtocolID { + "NEWPROTOCOLID".into() + } +} + +``` + +Whenever a new peer is added to the PeerPool, all the protocols, including +your custom protocols, will automatically start running with the newly connected peer. + +### Network Security + +It's obvious that connections in karyons p2p are not secure at the moment, as +it currently only supports TCP connections. However, we are currently working +on adding support for TLS connections. + +### Usage + +You can check out the examples [here](./karyons_p2p/examples). + +If you have tmux installed, you can run the network simulation script in the +examples directory to run 12 peers simultaneously. 
+ + RUST_LOG=karyons=debug ./net_simulation.sh + +## Contribution + +Feel free to open a pull request. We appreciate your help. + +## License + +All the code in this repository is licensed under the GNU General Public +License, version 3 (GPL-3.0). You can find a copy of the license in the +[LICENSE](./LICENSE) file. + diff --git a/karyons_core/Cargo.toml b/karyons_core/Cargo.toml new file mode 100644 index 0000000..712b7db --- /dev/null +++ b/karyons_core/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "karyons_core" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +smol = "1.3.0" +pin-project-lite = "0.2.13" +log = "0.4.20" +bincode = { version="2.0.0-rc.3", features = ["derive"]} +chrono = "0.4.30" +rand = "0.8.5" +thiserror = "1.0.47" +dirs = "5.0.1" diff --git a/karyons_core/src/async_utils/backoff.rs b/karyons_core/src/async_utils/backoff.rs new file mode 100644 index 0000000..f7e131d --- /dev/null +++ b/karyons_core/src/async_utils/backoff.rs @@ -0,0 +1,115 @@ +use std::{ + cmp::min, + sync::atomic::{AtomicBool, AtomicU32, Ordering}, + time::Duration, +}; + +use smol::Timer; + +/// Exponential backoff +/// <https://en.wikipedia.org/wiki/Exponential_backoff> +/// +/// # Examples +/// +/// ``` +/// use karyons_core::async_utils::Backoff; +/// +/// async { +/// let backoff = Backoff::new(300, 3000); +/// +/// loop { +/// backoff.sleep().await; +/// +/// // do something +/// break; +/// } +/// +/// backoff.reset(); +/// +/// // .... +/// }; +/// +/// ``` +/// +pub struct Backoff { + /// The base delay in milliseconds for the initial retry. + base_delay: u64, + /// The max delay in milliseconds allowed for a retry. + max_delay: u64, + /// Atomic counter + retries: AtomicU32, + /// Stop flag + stop: AtomicBool, +} + +impl Backoff { + /// Creates a new Backoff. 
+ pub fn new(base_delay: u64, max_delay: u64) -> Self { + Self { + base_delay, + max_delay, + retries: AtomicU32::new(0), + stop: AtomicBool::new(false), + } + } + + /// Sleep based on the current retry count and delay values. + /// Retruns the delay value. + pub async fn sleep(&self) -> u64 { + if self.stop.load(Ordering::SeqCst) { + Timer::after(Duration::from_millis(self.max_delay)).await; + return self.max_delay; + } + + let retries = self.retries.load(Ordering::SeqCst); + let delay = self.base_delay * (2_u64).pow(retries); + let delay = min(delay, self.max_delay); + + if delay == self.max_delay { + self.stop.store(true, Ordering::SeqCst); + } + + self.retries.store(retries + 1, Ordering::SeqCst); + + Timer::after(Duration::from_millis(delay)).await; + delay + } + + /// Reset the retry counter to 0. + pub fn reset(&self) { + self.retries.store(0, Ordering::SeqCst); + self.stop.store(false, Ordering::SeqCst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_backoff() { + smol::block_on(async move { + let backoff = Arc::new(Backoff::new(5, 15)); + let backoff_c = backoff.clone(); + smol::spawn(async move { + let delay = backoff_c.sleep().await; + assert_eq!(delay, 5); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 10); + + let delay = backoff_c.sleep().await; + assert_eq!(delay, 15); + }) + .await; + + smol::spawn(async move { + backoff.reset(); + let delay = backoff.sleep().await; + assert_eq!(delay, 5); + }) + .await; + }); + } +} diff --git a/karyons_core/src/async_utils/condvar.rs b/karyons_core/src/async_utils/condvar.rs new file mode 100644 index 0000000..814f78f --- /dev/null +++ b/karyons_core/src/async_utils/condvar.rs @@ -0,0 +1,387 @@ +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + sync::Mutex, + task::{Context, Poll, Waker}, +}; + +use smol::lock::MutexGuard; + +use crate::utils::random_16; + +/// CondVar is an async version of 
<https://doc.rust-lang.org/std/sync/struct.Condvar.html> +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use smol::lock::Mutex; +/// +/// use karyons_core::async_utils::CondVar; +/// +/// async { +/// +/// let val = Arc::new(Mutex::new(false)); +/// let condvar = Arc::new(CondVar::new()); +/// +/// let val_cloned = val.clone(); +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val_cloned.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// let condvar_cloned = condvar.clone(); +/// smol::spawn(async move { +/// let mut val = val.lock().await; +/// +/// // While the boolean flag is false, wait for a signal. +/// while !*val { +/// val = condvar_cloned.wait(val).await; +/// } +/// +/// // ... +/// }); +/// +/// // Wake up all waiting tasks on this condvar +/// condvar.broadcast(); +/// }; +/// +/// ``` + +pub struct CondVar { + inner: Mutex<Wakers>, +} + +impl CondVar { + /// Creates a new CondVar + pub fn new() -> Self { + Self { + inner: Mutex::new(Wakers::new()), + } + } + + /// Blocks the current task until this condition variable receives a notification. + pub async fn wait<'a, T>(&self, g: MutexGuard<'a, T>) -> MutexGuard<'a, T> { + let m = MutexGuard::source(&g); + + CondVarAwait::new(self, g).await; + + m.lock().await + } + + /// Wakes up one blocked task waiting on this condvar. + pub fn signal(&self) { + self.inner.lock().unwrap().wake(true); + } + + /// Wakes up all blocked tasks waiting on this condvar. 
+ pub fn broadcast(&self) { + self.inner.lock().unwrap().wake(false); + } +} + +impl Default for CondVar { + fn default() -> Self { + Self::new() + } +} + +struct CondVarAwait<'a, T> { + id: Option<u16>, + condvar: &'a CondVar, + guard: Option<MutexGuard<'a, T>>, +} + +impl<'a, T> CondVarAwait<'a, T> { + fn new(condvar: &'a CondVar, guard: MutexGuard<'a, T>) -> Self { + Self { + condvar, + guard: Some(guard), + id: None, + } + } +} + +impl<'a, T> Future for CondVarAwait<'a, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut inner = self.condvar.inner.lock().unwrap(); + + match self.guard.take() { + Some(_) => { + // the first pooll will release the Mutexguard + self.id = Some(inner.put(Some(cx.waker().clone()))); + Poll::Pending + } + None => { + // Return Ready if it has already been polled and removed + // from the waker list. + if self.id.is_none() { + return Poll::Ready(()); + } + + let i = self.id.as_ref().unwrap(); + match inner.wakers.get_mut(i).unwrap() { + Some(wk) => { + // This will prevent cloning again + if !wk.will_wake(cx.waker()) { + *wk = cx.waker().clone(); + } + Poll::Pending + } + None => { + inner.delete(i); + self.id = None; + Poll::Ready(()) + } + } + } + } + } +} + +impl<'a, T> Drop for CondVarAwait<'a, T> { + fn drop(&mut self) { + if let Some(id) = self.id { + let mut inner = self.condvar.inner.lock().unwrap(); + if let Some(wk) = inner.wakers.get_mut(&id).unwrap().take() { + wk.wake() + } + } + } +} + +/// Wakers is a helper struct to store the task wakers +struct Wakers { + wakers: HashMap<u16, Option<Waker>>, +} + +impl Wakers { + fn new() -> Self { + Self { + wakers: HashMap::new(), + } + } + + fn put(&mut self, waker: Option<Waker>) -> u16 { + let mut id: u16; + + id = random_16(); + while self.wakers.contains_key(&id) { + id = random_16(); + } + + self.wakers.insert(id, waker); + id + } + + fn delete(&mut self, id: &u16) -> Option<Option<Waker>> { + 
self.wakers.remove(id) + } + + fn wake(&mut self, signal: bool) { + for (_, wk) in self.wakers.iter_mut() { + match wk.take() { + Some(w) => { + w.wake(); + if signal { + break; + } + } + None => continue, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smol::lock::Mutex; + use std::{ + collections::VecDeque, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + }; + + // The tests below demonstrate a solution to a problem in the Wikipedia + // explanation of condition variables: + // https://en.wikipedia.org/wiki/Monitor_(synchronization)#Solving_the_bounded_producer/consumer_problem. + + struct Queue { + items: VecDeque<String>, + max_len: usize, + } + impl Queue { + fn new(max_len: usize) -> Self { + Self { + items: VecDeque::new(), + max_len, + } + } + + fn is_full(&self) -> bool { + self.items.len() == self.max_len + } + + fn is_empty(&self) -> bool { + self.items.is_empty() + } + } + + #[test] + fn test_condvar_signal() { + smol::block_on(async { + let number_of_tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar_full = Arc::new(CondVar::new()); + let condvar_empty = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_full_cloned = condvar_full.clone(); + let condvar_empty_cloned = condvar_empty.clone(); + + let _producer1 = smol::spawn(async move { + for i in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_full_cloned.wait(queue).await; + } + + queue.items.push_back(format!("task {i}")); + + // Wake up the consumer + condvar_empty_cloned.signal(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + let consumer = smol::spawn(async move { + for _ in 1..number_of_tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; 
+ + // Check if the queue is non-empty + while queue.is_empty() { + // Release queue mutex and sleep + queue = condvar_empty.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up the producer + condvar_full.signal(); + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 29); + }); + } + + #[test] + fn test_condvar_broadcast() { + smol::block_on(async { + let tasks = 30; + + let queue = Arc::new(Mutex::new(Queue::new(5))); + let condvar = Arc::new(CondVar::new()); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer1 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer1: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let condvar_cloned = condvar.clone(); + let _producer2 = smol::spawn(async move { + for i in 1..tasks { + // Lock queue mtuex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-full + while queue.is_full() { + // Release queue mutex and sleep + queue = condvar_cloned.wait(queue).await; + } + + queue.items.push_back(format!("producer2: task {i}")); + + // Wake up all producer and consumer tasks + condvar_cloned.broadcast(); + } + }); + + let queue_cloned = queue.clone(); + let task_consumed = Arc::new(AtomicUsize::new(0)); + let task_consumed_ = task_consumed.clone(); + + let consumer = smol::spawn(async move { + for _ in 1..((tasks * 2) - 1) { + { + // Lock queue mutex + let mut queue = queue_cloned.lock().await; + + // Check if the queue is non-empty + while queue.is_empty() { 
+ // Release queue mutex and sleep + queue = condvar.wait(queue).await; + } + + let _ = queue.items.pop_front().unwrap(); + + task_consumed_.fetch_add(1, Ordering::Relaxed); + + // Do something + + // Wake up all producer and consumer tasks + condvar.broadcast(); + } + } + }); + + consumer.await; + assert!(queue.lock().await.is_empty()); + assert_eq!(task_consumed.load(Ordering::Relaxed), 58); + }); + } +} diff --git a/karyons_core/src/async_utils/condwait.rs b/karyons_core/src/async_utils/condwait.rs new file mode 100644 index 0000000..f16a99e --- /dev/null +++ b/karyons_core/src/async_utils/condwait.rs @@ -0,0 +1,96 @@ +use smol::lock::Mutex; + +use super::CondVar; + +/// CondWait is a wrapper struct for CondVar with a Mutex boolean flag. +/// +/// # Example +/// +///``` +/// use std::sync::Arc; +/// +/// use karyons_core::async_utils::CondWait; +/// +/// async { +/// let cond_wait = Arc::new(CondWait::new()); +/// let cond_wait_cloned = cond_wait.clone(); +/// let task = smol::spawn(async move { +/// cond_wait_cloned.wait().await; +/// // ... +/// }); +/// +/// cond_wait.signal().await; +/// }; +/// +/// ``` +/// +pub struct CondWait { + /// The CondVar + condvar: CondVar, + /// Boolean flag + w: Mutex<bool>, +} + +impl CondWait { + /// Creates a new CondWait. + pub fn new() -> Self { + Self { + condvar: CondVar::new(), + w: Mutex::new(false), + } + } + + /// Waits for a signal or broadcast. + pub async fn wait(&self) { + let mut w = self.w.lock().await; + + // While the boolean flag is false, wait for a signal. + while !*w { + w = self.condvar.wait(w).await; + } + } + + /// Signal a waiting task. + pub async fn signal(&self) { + *self.w.lock().await = true; + self.condvar.signal(); + } + + /// Signal all waiting tasks. + pub async fn broadcast(&self) { + *self.w.lock().await = true; + self.condvar.broadcast(); + } + + /// Reset the boolean flag value to false. 
+ pub async fn reset(&self) { + *self.w.lock().await = false; + } +} + +impl Default for CondWait { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_cond_wait() { + smol::block_on(async { + let cond_wait = Arc::new(CondWait::new()); + let cond_wait_cloned = cond_wait.clone(); + let task = smol::spawn(async move { + cond_wait_cloned.wait().await; + true + }); + + cond_wait.signal().await; + assert!(task.await); + }); + } +} diff --git a/karyons_core/src/async_utils/mod.rs b/karyons_core/src/async_utils/mod.rs new file mode 100644 index 0000000..c871bad --- /dev/null +++ b/karyons_core/src/async_utils/mod.rs @@ -0,0 +1,13 @@ +mod backoff; +mod condvar; +mod condwait; +mod select; +mod task_group; +mod timeout; + +pub use backoff::Backoff; +pub use condvar::CondVar; +pub use condwait::CondWait; +pub use select::{select, Either}; +pub use task_group::{TaskGroup, TaskResult}; +pub use timeout::timeout; diff --git a/karyons_core/src/async_utils/select.rs b/karyons_core/src/async_utils/select.rs new file mode 100644 index 0000000..d61b355 --- /dev/null +++ b/karyons_core/src/async_utils/select.rs @@ -0,0 +1,99 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use pin_project_lite::pin_project; +use smol::future::Future; + +/// Returns the result of the future that completes first, preferring future1 +/// if both are ready. +/// +/// # Examples +/// +/// ``` +/// use std::future; +/// +/// use karyons_core::async_utils::{select, Either}; +/// +/// async { +/// let fut1 = future::pending::<String>(); +/// let fut2 = future::ready(0); +/// let res = select(fut1, fut2).await; +/// assert!(matches!(res, Either::Right(0))); +/// // .... +/// }; +/// +/// ``` +/// +pub fn select<T1, T2, F1, F2>(future1: F1, future2: F2) -> Select<F1, F2> +where + F1: Future<Output = T1>, + F2: Future<Output = T2>, +{ + Select { future1, future2 } +} + +pin_project! 
{ + #[derive(Debug)] + pub struct Select<F1, F2> { + #[pin] + future1: F1, + #[pin] + future2: F2, + } +} + +/// The return value from the `select` function, indicating which future +/// completed first. +#[derive(Debug)] +pub enum Either<T1, T2> { + Left(T1), + Right(T2), +} + +// Implement the Future trait for the Select struct. +impl<T1, T2, F1, F2> Future for Select<F1, F2> +where + F1: Future<Output = T1>, + F2: Future<Output = T2>, +{ + type Output = Either<T1, T2>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + + if let Poll::Ready(t) = this.future1.poll(cx) { + return Poll::Ready(Either::Left(t)); + } + + if let Poll::Ready(t) = this.future2.poll(cx) { + return Poll::Ready(Either::Right(t)); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::{select, Either}; + use smol::Timer; + use std::future; + + #[test] + fn test_async_select() { + smol::block_on(async move { + let fut = select(Timer::never(), future::ready(0 as u32)).await; + assert!(matches!(fut, Either::Right(0))); + + let fut1 = future::pending::<String>(); + let fut2 = future::ready(0); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Right(0))); + + let fut1 = future::ready(0); + let fut2 = future::pending::<String>(); + let res = select(fut1, fut2).await; + assert!(matches!(res, Either::Left(_))); + }); + } +} diff --git a/karyons_core/src/async_utils/task_group.rs b/karyons_core/src/async_utils/task_group.rs new file mode 100644 index 0000000..8707d0e --- /dev/null +++ b/karyons_core/src/async_utils/task_group.rs @@ -0,0 +1,197 @@ +use std::{future::Future, sync::Arc, sync::Mutex}; + +use smol::Task; + +use crate::Executor; + +use super::{select, CondWait, Either}; + +/// TaskGroup is a group of spawned tasks. 
+/// +/// # Example +/// +/// ``` +/// +/// use std::sync::Arc; +/// +/// use karyons_core::async_utils::TaskGroup; +/// +/// async { +/// +/// let ex = Arc::new(smol::Executor::new()); +/// let group = TaskGroup::new(); +/// +/// group.spawn(ex.clone(), smol::Timer::never(), |_| async {}); +/// +/// group.cancel().await; +/// +/// }; +/// +/// ``` +/// +pub struct TaskGroup { + tasks: Mutex<Vec<TaskHandler>>, + stop_signal: Arc<CondWait>, +} + +impl<'a> TaskGroup { + /// Creates a new task group + pub fn new() -> Self { + Self { + tasks: Mutex::new(Vec::new()), + stop_signal: Arc::new(CondWait::new()), + } + } + + /// Spawns a new task and calls the callback after it has completed + /// or been canceled. The callback will have the `TaskResult` as a + /// parameter, indicating whether the task completed or was canceled. + pub fn spawn<T, Fut, CallbackF, CallbackFut>( + &self, + executor: Executor<'a>, + fut: Fut, + callback: CallbackF, + ) where + T: Send + Sync + 'a, + Fut: Future<Output = T> + Send + 'a, + CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, + CallbackFut: Future<Output = ()> + Send + 'a, + { + let task = TaskHandler::new(executor.clone(), fut, callback, self.stop_signal.clone()); + self.tasks.lock().unwrap().push(task); + } + + /// Checks if the task group is empty. + pub fn is_empty(&self) -> bool { + self.tasks.lock().unwrap().is_empty() + } + + /// Get the number of the tasks in the group. + pub fn len(&self) -> usize { + self.tasks.lock().unwrap().len() + } + + /// Cancels all tasks in the group. + pub async fn cancel(&self) { + self.stop_signal.broadcast().await; + + loop { + let task = self.tasks.lock().unwrap().pop(); + if let Some(t) = task { + t.cancel().await + } else { + break; + } + } + } +} + +impl Default for TaskGroup { + fn default() -> Self { + Self::new() + } +} + +/// The result of a spawned task. 
+#[derive(Debug)] +pub enum TaskResult<T> { + Completed(T), + Cancelled, +} + +impl<T: std::fmt::Debug> std::fmt::Display for TaskResult<T> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TaskResult::Cancelled => write!(f, "Task cancelled"), + TaskResult::Completed(res) => write!(f, "Task completed: {:?}", res), + } + } +} + +/// TaskHandler +pub struct TaskHandler { + task: Task<()>, + cancel_flag: Arc<CondWait>, +} + +impl<'a> TaskHandler { + /// Creates a new task handle + fn new<T, Fut, CallbackF, CallbackFut>( + ex: Executor<'a>, + fut: Fut, + callback: CallbackF, + stop_signal: Arc<CondWait>, + ) -> TaskHandler + where + T: Send + Sync + 'a, + Fut: Future<Output = T> + Send + 'a, + CallbackF: FnOnce(TaskResult<T>) -> CallbackFut + Send + 'a, + CallbackFut: Future<Output = ()> + Send + 'a, + { + let cancel_flag = Arc::new(CondWait::new()); + let cancel_flag_c = cancel_flag.clone(); + let task = ex.spawn(async move { + //start_signal.signal().await; + // Waits for either the stop signal or the task to complete. + let result = select(stop_signal.wait(), fut).await; + + let result = match result { + Either::Left(_) => TaskResult::Cancelled, + Either::Right(res) => TaskResult::Completed(res), + }; + + // Call the callback with the result. + callback(result).await; + + cancel_flag_c.signal().await; + }); + + TaskHandler { task, cancel_flag } + } + + /// Cancels the task. 
+ async fn cancel(self) { + self.cancel_flag.wait().await; + self.task.cancel().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, sync::Arc}; + + #[test] + fn test_task_group() { + let ex = Arc::new(smol::Executor::new()); + smol::block_on(ex.clone().run(async move { + let group = Arc::new(TaskGroup::new()); + + group.spawn(ex.clone(), future::ready(0), |res| async move { + assert!(matches!(res, TaskResult::Completed(0))); + }); + + group.spawn(ex.clone(), future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + + let groupc = group.clone(); + let exc = ex.clone(); + group.spawn( + ex.clone(), + async move { + groupc.spawn(exc.clone(), future::pending::<()>(), |res| async move { + assert!(matches!(res, TaskResult::Cancelled)); + }); + }, + |res| async move { + assert!(matches!(res, TaskResult::Completed(_))); + }, + ); + + // Do something + smol::Timer::after(std::time::Duration::from_millis(50)).await; + group.cancel().await; + })); + } +} diff --git a/karyons_core/src/async_utils/timeout.rs b/karyons_core/src/async_utils/timeout.rs new file mode 100644 index 0000000..7c55e1b --- /dev/null +++ b/karyons_core/src/async_utils/timeout.rs @@ -0,0 +1,52 @@ +use std::{future::Future, time::Duration}; + +use smol::Timer; + +use super::{select, Either}; +use crate::{error::Error, Result}; + +/// Waits for a future to complete or times out if it exceeds a specified +/// duration. 
+/// +/// # Example +/// +/// ``` +/// use std::{future, time::Duration}; +/// +/// use karyons_core::async_utils::timeout; +/// +/// async { +/// let fut = future::pending::<()>(); +/// assert!(timeout(Duration::from_millis(100), fut).await.is_err()); +/// }; +/// +/// ``` +/// +pub async fn timeout<T, F>(delay: Duration, future1: F) -> Result<T> +where + F: Future<Output = T>, +{ + let result = select(Timer::after(delay), future1).await; + + match result { + Either::Left(_) => Err(Error::Timeout), + Either::Right(res) => Ok(res), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{future, time::Duration}; + + #[test] + fn test_timeout() { + smol::block_on(async move { + let fut = future::pending::<()>(); + assert!(timeout(Duration::from_millis(10), fut).await.is_err()); + + let fut = smol::Timer::after(Duration::from_millis(10)); + assert!(timeout(Duration::from_millis(50), fut).await.is_ok()) + }); + } +} diff --git a/karyons_core/src/error.rs b/karyons_core/src/error.rs new file mode 100644 index 0000000..15947c8 --- /dev/null +++ b/karyons_core/src/error.rs @@ -0,0 +1,51 @@ +use thiserror::Error as ThisError; + +pub type Result<T> = std::result::Result<T, Error>; + +#[derive(ThisError, Debug)] +pub enum Error { + #[error("IO Error: {0}")] + IO(#[from] std::io::Error), + + #[error("Timeout Error")] + Timeout, + + #[error("Path Not Found Error: {0}")] + PathNotFound(&'static str), + + #[error("Channel Send Error: {0}")] + ChannelSend(String), + + #[error("Channel Receive Error: {0}")] + ChannelRecv(String), + + #[error("Decode Error: {0}")] + Decode(String), + + #[error("Encode Error: {0}")] + Encode(String), +} + +impl<T> From<smol::channel::SendError<T>> for Error { + fn from(error: smol::channel::SendError<T>) -> Self { + Error::ChannelSend(error.to_string()) + } +} + +impl From<smol::channel::RecvError> for Error { + fn from(error: smol::channel::RecvError) -> Self { + Error::ChannelRecv(error.to_string()) + } +} + +impl 
From<bincode::error::DecodeError> for Error { + fn from(error: bincode::error::DecodeError) -> Self { + Error::Decode(error.to_string()) + } +} + +impl From<bincode::error::EncodeError> for Error { + fn from(error: bincode::error::EncodeError) -> Self { + Error::Encode(error.to_string()) + } +} diff --git a/karyons_core/src/event.rs b/karyons_core/src/event.rs new file mode 100644 index 0000000..b856385 --- /dev/null +++ b/karyons_core/src/event.rs @@ -0,0 +1,451 @@ +use std::{ + any::Any, + collections::HashMap, + marker::PhantomData, + sync::{Arc, Weak}, +}; + +use chrono::{DateTime, Utc}; +use log::{error, trace}; +use smol::{ + channel::{Receiver, Sender}, + lock::Mutex, +}; + +use crate::{utils::random_16, Result}; + +pub type ArcEventSys<T> = Arc<EventSys<T>>; +pub type WeakEventSys<T> = Weak<EventSys<T>>; +pub type EventListenerID = u16; + +type Listeners<T> = HashMap<T, HashMap<String, HashMap<EventListenerID, Sender<Event>>>>; + +/// EventSys supports event emission to registered listeners based on topics. 
+/// # Example +/// +/// ``` +/// use karyons_core::event::{EventSys, EventValueTopic, EventValue}; +/// +/// async { +/// let event_sys = EventSys::new(); +/// +/// #[derive(Hash, PartialEq, Eq, Debug, Clone)] +/// enum Topic { +/// TopicA, +/// TopicB, +/// } +/// +/// #[derive(Clone, Debug, PartialEq)] +/// struct A(usize); +/// +/// impl EventValue for A { +/// fn id() -> &'static str { +/// "A" +/// } +/// } +/// +/// let listener = event_sys.register::<A>(&Topic::TopicA).await; +/// +/// event_sys.emit_by_topic(&Topic::TopicA, &A(3)) .await; +/// let msg: A = listener.recv().await.unwrap(); +/// +/// #[derive(Clone, Debug, PartialEq)] +/// struct B(usize); +/// +/// impl EventValue for B { +/// fn id() -> &'static str { +/// "B" +/// } +/// } +/// +/// impl EventValueTopic for B { +/// type Topic = Topic; +/// fn topic() -> Self::Topic{ +/// Topic::TopicB +/// } +/// } +/// +/// let listener = event_sys.register::<B>(&Topic::TopicB).await; +/// +/// event_sys.emit(&B(3)) .await; +/// let msg: B = listener.recv().await.unwrap(); +/// +/// // .... +/// }; +/// +/// ``` +/// +pub struct EventSys<T> { + listeners: Mutex<Listeners<T>>, +} + +impl<T> EventSys<T> +where + T: std::hash::Hash + Eq + std::fmt::Debug + Clone, +{ + /// Creates a new `EventSys` + pub fn new() -> ArcEventSys<T> { + Arc::new(Self { + listeners: Mutex::new(HashMap::new()), + }) + } + + /// Emits an event to the listeners. + /// + /// The event must implement the `EventValueTopic` trait to indicate the + /// topic of the event. Otherwise, you can use `emit_by_topic()`. + pub async fn emit<E: EventValueTopic<Topic = T> + Clone>(&self, value: &E) { + let topic = E::topic(); + self.emit_by_topic(&topic, value).await; + } + + /// Emits an event to the listeners. 
+ pub async fn emit_by_topic<E: EventValueAny + EventValue + Clone>(&self, topic: &T, value: &E) { + let value: Arc<dyn EventValueAny> = Arc::new(value.clone()); + let event = Event::new(value); + + let mut topics = self.listeners.lock().await; + + if !topics.contains_key(topic) { + error!("Failed to emit an event to a non-existent topic"); + return; + } + + let event_ids = topics.get_mut(topic).unwrap(); + let event_id = E::id().to_string(); + + if !event_ids.contains_key(&event_id) { + error!("Failed to emit an event to a non-existent event id"); + return; + } + + let mut failed_listeners = vec![]; + + let listeners = event_ids.get_mut(&event_id).unwrap(); + for (listener_id, listener) in listeners.iter() { + if let Err(err) = listener.send(event.clone()).await { + trace!("Failed to emit event for topic {:?}: {}", topic, err); + failed_listeners.push(*listener_id); + } + } + + for listener_id in failed_listeners.iter() { + listeners.remove(listener_id); + } + } + + /// Registers a new event listener for the given topic. + pub async fn register<E: EventValueAny + EventValue + Clone>( + self: &Arc<Self>, + topic: &T, + ) -> EventListener<T, E> { + let chan = smol::channel::unbounded(); + + let topics = &mut self.listeners.lock().await; + + if !topics.contains_key(topic) { + topics.insert(topic.clone(), HashMap::new()); + } + + let event_ids = topics.get_mut(topic).unwrap(); + let event_id = E::id().to_string(); + + if !event_ids.contains_key(&event_id) { + event_ids.insert(event_id.clone(), HashMap::new()); + } + + let listeners = event_ids.get_mut(&event_id).unwrap(); + + let mut listener_id = random_16(); + while listeners.contains_key(&listener_id) { + listener_id = random_16(); + } + + let listener = + EventListener::new(listener_id, Arc::downgrade(self), chan.1, &event_id, topic); + + listeners.insert(listener_id, chan.0); + + listener + } + + /// Removes an event listener attached to the given topic. 
+ async fn remove(&self, topic: &T, event_id: &str, listener_id: &EventListenerID) { + let topics = &mut self.listeners.lock().await; + if !topics.contains_key(topic) { + error!("Failed to remove a non-existent topic"); + return; + } + + let event_ids = topics.get_mut(topic).unwrap(); + if !event_ids.contains_key(event_id) { + error!("Failed to remove a non-existent event id"); + return; + } + + let listeners = event_ids.get_mut(event_id).unwrap(); + if listeners.remove(listener_id).is_none() { + error!("Failed to remove a non-existent event listener"); + } + } +} + +/// EventListener listens for and receives events from the `EventSys`. +pub struct EventListener<T, E> { + id: EventListenerID, + recv_chan: Receiver<Event>, + event_sys: WeakEventSys<T>, + event_id: String, + topic: T, + phantom: PhantomData<E>, +} + +impl<T, E> EventListener<T, E> +where + T: std::hash::Hash + Eq + Clone + std::fmt::Debug, + E: EventValueAny + Clone + EventValue, +{ + /// Create a new event listener. + fn new( + id: EventListenerID, + event_sys: WeakEventSys<T>, + recv_chan: Receiver<Event>, + event_id: &str, + topic: &T, + ) -> EventListener<T, E> { + Self { + id, + recv_chan, + event_sys, + event_id: event_id.to_string(), + topic: topic.clone(), + phantom: PhantomData, + } + } + + /// Receive the next event. + pub async fn recv(&self) -> Result<E> { + match self.recv_chan.recv().await { + Ok(event) => match ((*event.value).value_as_any()).downcast_ref::<E>() { + Some(v) => Ok(v.clone()), + None => unreachable!("Error when attempting to downcast the event value."), + }, + Err(err) => { + error!("Failed to receive new event: {err}"); + self.cancel().await; + Err(err.into()) + } + } + } + + /// Cancels the listener and removes it from the `EventSys`. + pub async fn cancel(&self) { + self.event_sys() + .remove(&self.topic, &self.event_id, &self.id) + .await; + } + + /// Returns the topic for this event listener. 
+ pub async fn topic(&self) -> &T { + &self.topic + } + + /// Returns the event id for this event listener. + pub async fn event_id(&self) -> &String { + &self.event_id + } + + fn event_sys(&self) -> ArcEventSys<T> { + self.event_sys.upgrade().unwrap() + } +} + +/// An event within the `EventSys`. +#[derive(Clone, Debug)] +pub struct Event { + /// The time at which the event was created. + created_at: DateTime<Utc>, + /// The value of the Event. + value: Arc<dyn EventValueAny>, +} + +impl Event { + /// Creates a new Event. + pub fn new(value: Arc<dyn EventValueAny>) -> Self { + Self { + created_at: Utc::now(), + value, + } + } +} + +impl std::fmt::Display for Event { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}: {:?}", self.created_at, self.value) + } +} + +pub trait EventValueAny: Any + Send + Sync + std::fmt::Debug { + fn value_as_any(&self) -> &dyn Any; +} + +impl<T: Send + Sync + std::fmt::Debug + Any> EventValueAny for T { + fn value_as_any(&self) -> &dyn Any { + self + } +} + +pub trait EventValue: EventValueAny { + fn id() -> &'static str + where + Self: Sized; +} + +pub trait EventValueTopic: EventValueAny + EventValue { + type Topic; + fn topic() -> Self::Topic + where + Self: Sized; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Hash, PartialEq, Eq, Debug, Clone)] + enum Topic { + TopicA, + TopicB, + TopicC, + TopicD, + TopicE, + } + + #[derive(Clone, Debug, PartialEq)] + struct A { + a_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct B { + b_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct C { + c_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct D { + d_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct E { + e_value: usize, + } + + #[derive(Clone, Debug, PartialEq)] + struct F { + f_value: usize, + } + + impl EventValue for A { + fn id() -> &'static str { + "A" + } + } + + impl EventValue for B { + fn id() -> &'static str { + "B" + } + } 
+ + impl EventValue for C { + fn id() -> &'static str { + "C" + } + } + + impl EventValue for D { + fn id() -> &'static str { + "D" + } + } + + impl EventValue for E { + fn id() -> &'static str { + "E" + } + } + + impl EventValue for F { + fn id() -> &'static str { + "F" + } + } + + impl EventValueTopic for C { + type Topic = Topic; + fn topic() -> Self::Topic { + Topic::TopicC + } + } + + #[test] + fn test_event_sys() { + smol::block_on(async move { + let event_sys = EventSys::<Topic>::new(); + + let a_listener = event_sys.register::<A>(&Topic::TopicA).await; + let b_listener = event_sys.register::<B>(&Topic::TopicB).await; + + event_sys + .emit_by_topic(&Topic::TopicA, &A { a_value: 3 }) + .await; + event_sys + .emit_by_topic(&Topic::TopicB, &B { b_value: 5 }) + .await; + + let msg = a_listener.recv().await.unwrap(); + assert_eq!(msg, A { a_value: 3 }); + + let msg = b_listener.recv().await.unwrap(); + assert_eq!(msg, B { b_value: 5 }); + + // register the same event type to different topics + let c_listener = event_sys.register::<C>(&Topic::TopicC).await; + let d_listener = event_sys.register::<C>(&Topic::TopicD).await; + + event_sys.emit(&C { c_value: 10 }).await; + let msg = c_listener.recv().await.unwrap(); + assert_eq!(msg, C { c_value: 10 }); + + event_sys + .emit_by_topic(&Topic::TopicD, &C { c_value: 10 }) + .await; + let msg = d_listener.recv().await.unwrap(); + assert_eq!(msg, C { c_value: 10 }); + + // register different event types to the same topic + let e_listener = event_sys.register::<E>(&Topic::TopicE).await; + let f_listener = event_sys.register::<F>(&Topic::TopicE).await; + + event_sys + .emit_by_topic(&Topic::TopicE, &E { e_value: 5 }) + .await; + + let msg = e_listener.recv().await.unwrap(); + assert_eq!(msg, E { e_value: 5 }); + + event_sys + .emit_by_topic(&Topic::TopicE, &F { f_value: 5 }) + .await; + + let msg = f_listener.recv().await.unwrap(); + assert_eq!(msg, F { f_value: 5 }); + }); + } +} diff --git a/karyons_core/src/lib.rs 
b/karyons_core/src/lib.rs new file mode 100644 index 0000000..83af888 --- /dev/null +++ b/karyons_core/src/lib.rs @@ -0,0 +1,21 @@ +/// A set of helper tools and functions. +pub mod utils; + +/// A module containing async utilities that work with the `smol` async runtime. +pub mod async_utils; + +/// Represents Karyons's Core Error. +pub mod error; + +/// [`EventSys`](./event/struct.EventSys.html) Implementation +pub mod event; + +/// A simple publish-subscribe system. [`Read More`](./pubsub/struct.Publisher.html) +pub mod pubsub; + +use error::Result; +use smol::Executor as SmolEx; +use std::sync::Arc; + +/// A wrapper for smol::Executor +pub type Executor<'a> = Arc<SmolEx<'a>>; diff --git a/karyons_core/src/pubsub.rs b/karyons_core/src/pubsub.rs new file mode 100644 index 0000000..4cc0ab7 --- /dev/null +++ b/karyons_core/src/pubsub.rs @@ -0,0 +1,115 @@ +use std::{collections::HashMap, sync::Arc}; + +use log::error; +use smol::lock::Mutex; + +use crate::{utils::random_16, Result}; + +pub type ArcPublisher<T> = Arc<Publisher<T>>; +pub type SubscriptionID = u16; + +/// A simple publish-subscribe system. +/// # Example +/// +/// ``` +/// use karyons_core::pubsub::{Publisher}; +/// +/// async { +/// let publisher = Publisher::new(); +/// +/// let sub = publisher.subscribe().await; +/// +/// publisher.notify(&String::from("MESSAGE")).await; +/// +/// let msg = sub.recv().await; +/// +/// // .... 
+/// }; +/// +/// ``` +pub struct Publisher<T> { + subs: Mutex<HashMap<SubscriptionID, smol::channel::Sender<T>>>, +} + +impl<T: Clone> Publisher<T> { + /// Creates a new Publisher + pub fn new() -> ArcPublisher<T> { + Arc::new(Self { + subs: Mutex::new(HashMap::new()), + }) + } + + /// Subscribe and return a Subscription + pub async fn subscribe(self: &Arc<Self>) -> Subscription<T> { + let mut subs = self.subs.lock().await; + + let chan = smol::channel::unbounded(); + + let mut sub_id = random_16(); + + // While the SubscriptionID already exists, generate a new one + while subs.contains_key(&sub_id) { + sub_id = random_16(); + } + + let sub = Subscription::new(sub_id, self.clone(), chan.1); + subs.insert(sub_id, chan.0); + + sub + } + + /// Unsubscribe from the Publisher + pub async fn unsubscribe(self: &Arc<Self>, id: &SubscriptionID) { + self.subs.lock().await.remove(id); + } + + /// Notify all subscribers + pub async fn notify(self: &Arc<Self>, value: &T) { + let mut subs = self.subs.lock().await; + let mut closed_subs = vec![]; + + for (sub_id, sub) in subs.iter() { + if let Err(err) = sub.send(value.clone()).await { + error!("failed to notify {}: {}", sub_id, err); + closed_subs.push(*sub_id); + } + } + + for sub_id in closed_subs.iter() { + subs.remove(sub_id); + } + } +} + +// Subscription +pub struct Subscription<T> { + id: SubscriptionID, + recv_chan: smol::channel::Receiver<T>, + publisher: ArcPublisher<T>, +} + +impl<T: Clone> Subscription<T> { + /// Creates a new Subscription + pub fn new( + id: SubscriptionID, + publisher: ArcPublisher<T>, + recv_chan: smol::channel::Receiver<T>, + ) -> Subscription<T> { + Self { + id, + recv_chan, + publisher, + } + } + + /// Receive a message from the Publisher + pub async fn recv(&self) -> Result<T> { + let msg = self.recv_chan.recv().await?; + Ok(msg) + } + + /// Unsubscribe from the Publisher + pub async fn unsubscribe(&self) { + self.publisher.unsubscribe(&self.id).await; + } +} diff --git 
a/karyons_core/src/utils/decode.rs b/karyons_core/src/utils/decode.rs new file mode 100644 index 0000000..a8a6522 --- /dev/null +++ b/karyons_core/src/utils/decode.rs @@ -0,0 +1,10 @@ +use bincode::Decode; + +use crate::Result; + +/// Decodes a given type `T` from the given slice. Returns the decoded value +/// along with the number of bytes read. +pub fn decode<T: Decode>(src: &[u8]) -> Result<(T, usize)> { + let (result, bytes_read) = bincode::decode_from_slice(src, bincode::config::standard())?; + Ok((result, bytes_read)) +} diff --git a/karyons_core/src/utils/encode.rs b/karyons_core/src/utils/encode.rs new file mode 100644 index 0000000..7d1061b --- /dev/null +++ b/karyons_core/src/utils/encode.rs @@ -0,0 +1,15 @@ +use bincode::Encode; + +use crate::Result; + +/// Encode the given type `T` into a `Vec<u8>`. +pub fn encode<T: Encode>(msg: &T) -> Result<Vec<u8>> { + let vec = bincode::encode_to_vec(msg, bincode::config::standard())?; + Ok(vec) +} + +/// Encode the given type `T` into the given slice. +pub fn encode_into_slice<T: Encode>(msg: &T, dst: &mut [u8]) -> Result<()> { + bincode::encode_into_slice(msg, dst, bincode::config::standard())?; + Ok(()) +} diff --git a/karyons_core/src/utils/mod.rs b/karyons_core/src/utils/mod.rs new file mode 100644 index 0000000..a3c3f50 --- /dev/null +++ b/karyons_core/src/utils/mod.rs @@ -0,0 +1,19 @@ +mod decode; +mod encode; +mod path; + +pub use decode::decode; +pub use encode::{encode, encode_into_slice}; +pub use path::{home_dir, tilde_expand}; + +use rand::{rngs::OsRng, Rng}; + +/// Generates and returns a random u32 using `rand::rngs::OsRng`. +pub fn random_32() -> u32 { + OsRng.gen() +} + +/// Generates and returns a random u16 using `rand::rngs::OsRng`. 
+pub fn random_16() -> u16 { + OsRng.gen() +} diff --git a/karyons_core/src/utils/path.rs b/karyons_core/src/utils/path.rs new file mode 100644 index 0000000..2cd900a --- /dev/null +++ b/karyons_core/src/utils/path.rs @@ -0,0 +1,39 @@ +use std::path::PathBuf; + +use crate::{error::Error, Result}; + +/// Returns the user's home directory as a `PathBuf`. +#[allow(dead_code)] +pub fn home_dir() -> Result<PathBuf> { + dirs::home_dir().ok_or(Error::PathNotFound("Home dir not found")) +} + +/// Expands a tilde (~) in a path and returns the expanded `PathBuf`. +#[allow(dead_code)] +pub fn tilde_expand(path: &str) -> Result<PathBuf> { + match path { + "~" => home_dir(), + p if p.starts_with("~/") => Ok(home_dir()?.join(&path[2..])), + _ => Ok(PathBuf::from(path)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tilde_expand() { + let path = "~/src"; + let expanded_path = dirs::home_dir().unwrap().join("src"); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + + let path = "~"; + let expanded_path = dirs::home_dir().unwrap(); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + + let path = ""; + let expanded_path = PathBuf::from(""); + assert_eq!(tilde_expand(path).unwrap(), expanded_path); + } +} diff --git a/karyons_net/Cargo.toml b/karyons_net/Cargo.toml new file mode 100644 index 0000000..70e1917 --- /dev/null +++ b/karyons_net/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "karyons_net" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +karyons_core = { path = "../karyons_core" } + +smol = "1.3.0" +async-trait = "0.1.74" +pin-project-lite = "0.2.13" +log = "0.4.20" +bincode = { version="2.0.0-rc.3", features = ["derive"]} +chrono = "0.4.30" +rand = "0.8.5" +thiserror = "1.0.47" +dirs = "5.0.1" +url = "2.4.1" + diff --git a/karyons_net/src/connection.rs b/karyons_net/src/connection.rs new file mode 100644 index 
0000000..518ccfd --- /dev/null +++ b/karyons_net/src/connection.rs @@ -0,0 +1,57 @@ +use crate::{Endpoint, Result}; +use async_trait::async_trait; + +use crate::transports::{tcp, udp, unix}; + +/// Alias for `Box<dyn Connection>` +pub type Conn = Box<dyn Connection>; + +/// Connection is a generic network connection interface for +/// `UdpConn`, `TcpConn`, and `UnixConn`. +/// +/// If you are familiar with the Go language, this is similar to the `Conn` +/// interface <https://pkg.go.dev/net#Conn> +#[async_trait] +pub trait Connection: Send + Sync { + /// Returns the remote peer endpoint of this connection + fn peer_endpoint(&self) -> Result<Endpoint>; + + /// Returns the local socket endpoint of this connection + fn local_endpoint(&self) -> Result<Endpoint>; + + /// Reads data from this connection. + async fn recv(&self, buf: &mut [u8]) -> Result<usize>; + + /// Sends data to this connection + async fn send(&self, buf: &[u8]) -> Result<usize>; +} + +/// Connects to the provided endpoint. +/// +/// it only supports `tcp4/6`, `udp4/6` and `unix`. 
+/// +/// # Example +/// +/// ``` +/// use karyons_net::{Endpoint, dial}; +/// +/// async { +/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap(); +/// +/// let conn = dial(&endpoint).await.unwrap(); +/// +/// conn.send(b"MSG").await.unwrap(); +/// +/// let mut buffer = [0;32]; +/// conn.recv(&mut buffer).await.unwrap(); +/// }; +/// +/// ``` +/// +pub async fn dial(endpoint: &Endpoint) -> Result<Conn> { + match endpoint { + Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::dial_tcp(addr, port).await?)), + Endpoint::Udp(addr, port) => Ok(Box::new(udp::dial_udp(addr, port).await?)), + Endpoint::Unix(addr) => Ok(Box::new(unix::dial_unix(addr).await?)), + } +} diff --git a/karyons_net/src/endpoint.rs b/karyons_net/src/endpoint.rs new file mode 100644 index 0000000..50dfe6b --- /dev/null +++ b/karyons_net/src/endpoint.rs @@ -0,0 +1,223 @@ +use std::{ + net::{IpAddr, SocketAddr}, + os::unix::net::SocketAddr as UnixSocketAddress, + path::PathBuf, + str::FromStr, +}; + +use bincode::{Decode, Encode}; +use url::Url; + +use crate::{Error, Result}; + +/// Port defined as a u16. +pub type Port = u16; + +/// Endpoint defines generic network endpoints for karyons. 
+/// +/// # Example +/// +/// ``` +/// use std::net::SocketAddr; +/// +/// use karyons_net::Endpoint; +/// +/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap(); +/// +/// let socketaddr: SocketAddr = "127.0.0.1:3000".parse().unwrap(); +/// let endpoint = Endpoint::new_udp_addr(&socketaddr); +/// +/// ``` +/// +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Endpoint { + Udp(Addr, Port), + Tcp(Addr, Port), + Unix(String), +} + +impl std::fmt::Display for Endpoint { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Endpoint::Udp(ip, port) => { + write!(f, "udp://{}:{}", ip, port) + } + Endpoint::Tcp(ip, port) => { + write!(f, "tcp://{}:{}", ip, port) + } + Endpoint::Unix(path) => { + if path.is_empty() { + write!(f, "unix:/UNNAMED") + } else { + write!(f, "unix:/{}", path) + } + } + } + } +} + +impl TryFrom<Endpoint> for SocketAddr { + type Error = Error; + fn try_from(endpoint: Endpoint) -> std::result::Result<SocketAddr, Self::Error> { + match endpoint { + Endpoint::Udp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), + Endpoint::Tcp(ip, port) => Ok(SocketAddr::new(ip.try_into()?, port)), + Endpoint::Unix(_) => Err(Error::TryFromEndpointError), + } + } +} + +impl TryFrom<Endpoint> for PathBuf { + type Error = Error; + fn try_from(endpoint: Endpoint) -> std::result::Result<PathBuf, Self::Error> { + match endpoint { + Endpoint::Unix(path) => Ok(PathBuf::from(&path)), + _ => Err(Error::TryFromEndpointError), + } + } +} + +impl TryFrom<Endpoint> for UnixSocketAddress { + type Error = Error; + fn try_from(endpoint: Endpoint) -> std::result::Result<UnixSocketAddress, Self::Error> { + match endpoint { + Endpoint::Unix(a) => Ok(UnixSocketAddress::from_pathname(a)?), + _ => Err(Error::TryFromEndpointError), + } + } +} + +impl FromStr for Endpoint { + type Err = Error; + + fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { + let url: Url = match s.parse() { + Ok(u) => u, + Err(err) => return 
Err(Error::ParseEndpoint(err.to_string())), + }; + + if url.has_host() { + let host = url.host_str().unwrap(); + + let addr = match host.parse::<IpAddr>() { + Ok(addr) => Addr::Ip(addr), + Err(_) => Addr::Domain(host.to_string()), + }; + + let port = match url.port() { + Some(p) => p, + None => return Err(Error::ParseEndpoint(format!("port missing: {s}"))), + }; + + match url.scheme() { + "tcp" => Ok(Endpoint::Tcp(addr, port)), + "udp" => Ok(Endpoint::Udp(addr, port)), + _ => Err(Error::InvalidEndpoint(s.to_string())), + } + } else { + if url.path().is_empty() { + return Err(Error::InvalidEndpoint(s.to_string())); + } + + match url.scheme() { + "unix" => Ok(Endpoint::Unix(url.path().to_string())), + _ => Err(Error::InvalidEndpoint(s.to_string())), + } + } + } +} + +impl Endpoint { + /// Creates a new TCP endpoint from a `SocketAddr`. + pub fn new_tcp_addr(addr: &SocketAddr) -> Endpoint { + Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port()) + } + + /// Creates a new UDP endpoint from a `SocketAddr`. + pub fn new_udp_addr(addr: &SocketAddr) -> Endpoint { + Endpoint::Udp(Addr::Ip(addr.ip()), addr.port()) + } + + /// Creates a new Unix endpoint from a `UnixSocketAddress`. + pub fn new_unix_addr(addr: &UnixSocketAddress) -> Endpoint { + Endpoint::Unix( + addr.as_pathname() + .and_then(|a| a.to_str()) + .unwrap_or("") + .to_string(), + ) + } + + /// Returns the `Port` of the endpoint. + pub fn port(&self) -> Result<&Port> { + match self { + Endpoint::Tcp(_, port) => Ok(port), + Endpoint::Udp(_, port) => Ok(port), + _ => Err(Error::TryFromEndpointError), + } + } + + /// Returns the `Addr` of the endpoint. + pub fn addr(&self) -> Result<&Addr> { + match self { + Endpoint::Tcp(addr, _) => Ok(addr), + Endpoint::Udp(addr, _) => Ok(addr), + _ => Err(Error::TryFromEndpointError), + } + } +} + +/// Addr defines a type for an address, either IP or domain. 
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)] +pub enum Addr { + Ip(IpAddr), + Domain(String), +} + +impl TryFrom<Addr> for IpAddr { + type Error = Error; + fn try_from(addr: Addr) -> std::result::Result<IpAddr, Self::Error> { + match addr { + Addr::Ip(ip) => Ok(ip), + Addr::Domain(d) => Err(Error::InvalidAddress(d)), + } + } +} + +impl std::fmt::Display for Addr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Addr::Ip(ip) => { + write!(f, "{}", ip) + } + Addr::Domain(d) => { + write!(f, "{}", d) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::Ipv4Addr; + + #[test] + fn test_endpoint_from_str() { + let endpoint_str: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap(); + let endpoint = Endpoint::Tcp(Addr::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))), 3000); + assert_eq!(endpoint_str, endpoint); + + let endpoint_str: Endpoint = "udp://127.0.0.1:4000".parse().unwrap(); + let endpoint = Endpoint::Udp(Addr::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))), 4000); + assert_eq!(endpoint_str, endpoint); + + let endpoint_str: Endpoint = "tcp://example.com:3000".parse().unwrap(); + let endpoint = Endpoint::Tcp(Addr::Domain("example.com".to_string()), 3000); + assert_eq!(endpoint_str, endpoint); + + let endpoint_str = "unix:/home/x/s.socket".parse::<Endpoint>().unwrap(); + let endpoint = Endpoint::Unix("/home/x/s.socket".to_string()); + assert_eq!(endpoint_str, endpoint); + } +} diff --git a/karyons_net/src/error.rs b/karyons_net/src/error.rs new file mode 100644 index 0000000..a1c85db --- /dev/null +++ b/karyons_net/src/error.rs @@ -0,0 +1,45 @@ +use thiserror::Error as ThisError; + +pub type Result<T> = std::result::Result<T, Error>; + +#[derive(ThisError, Debug)] +pub enum Error { + #[error("IO Error: {0}")] + IO(#[from] std::io::Error), + + #[error("Try from endpoint Error")] + TryFromEndpointError, + + #[error("invalid address {0}")] + InvalidAddress(String), + + #[error("invalid endpoint {0}")] + 
InvalidEndpoint(String), + + #[error("Parse endpoint error {0}")] + ParseEndpoint(String), + + #[error("Timeout Error")] + Timeout, + + #[error("Channel Send Error: {0}")] + ChannelSend(String), + + #[error("Channel Receive Error: {0}")] + ChannelRecv(String), + + #[error("Karyons core error : {0}")] + KaryonsCore(#[from] karyons_core::error::Error), +} + +impl<T> From<smol::channel::SendError<T>> for Error { + fn from(error: smol::channel::SendError<T>) -> Self { + Error::ChannelSend(error.to_string()) + } +} + +impl From<smol::channel::RecvError> for Error { + fn from(error: smol::channel::RecvError) -> Self { + Error::ChannelRecv(error.to_string()) + } +} diff --git a/karyons_net/src/lib.rs b/karyons_net/src/lib.rs new file mode 100644 index 0000000..914c6d8 --- /dev/null +++ b/karyons_net/src/lib.rs @@ -0,0 +1,24 @@ +mod connection; +mod endpoint; +mod error; +mod listener; +mod transports; + +pub use { + connection::{dial, Conn, Connection}, + endpoint::{Addr, Endpoint, Port}, + listener::{listen, Listener}, + transports::{ + tcp::{dial_tcp, listen_tcp, TcpConn}, + udp::{dial_udp, listen_udp, UdpConn}, + unix::{dial_unix, listen_unix, UnixConn}, + }, +}; + +use error::{Error, Result}; + +/// Represents Karyons's Net Error +pub use error::Error as NetError; + +/// Represents Karyons's Net Result +pub use error::Result as NetResult; diff --git a/karyons_net/src/listener.rs b/karyons_net/src/listener.rs new file mode 100644 index 0000000..31a63ae --- /dev/null +++ b/karyons_net/src/listener.rs @@ -0,0 +1,39 @@ +use crate::{Endpoint, Error, Result}; +use async_trait::async_trait; + +use crate::{ + transports::{tcp, unix}, + Conn, +}; + +/// Listener is a generic network listener. +#[async_trait] +pub trait Listener: Send + Sync { + fn local_endpoint(&self) -> Result<Endpoint>; + async fn accept(&self) -> Result<Conn>; +} + +/// Listens to the provided endpoint. +/// +/// it only supports `tcp4/6` and `unix`. 
+/// +/// # Example +/// +/// ``` +/// use karyons_net::{Endpoint, listen}; +/// +/// async { +/// let endpoint: Endpoint = "tcp://127.0.0.1:3000".parse().unwrap(); +/// +/// let listener = listen(&endpoint).await.unwrap(); +/// let conn = listener.accept().await.unwrap(); +/// }; +/// +/// ``` +pub async fn listen(endpoint: &Endpoint) -> Result<Box<dyn Listener>> { + match endpoint { + Endpoint::Tcp(addr, port) => Ok(Box::new(tcp::listen_tcp(addr, port).await?)), + Endpoint::Unix(addr) => Ok(Box::new(unix::listen_unix(addr)?)), + _ => Err(Error::InvalidEndpoint(endpoint.to_string())), + } +} diff --git a/karyons_net/src/transports/mod.rs b/karyons_net/src/transports/mod.rs new file mode 100644 index 0000000..f399133 --- /dev/null +++ b/karyons_net/src/transports/mod.rs @@ -0,0 +1,3 @@ +pub mod tcp; +pub mod udp; +pub mod unix; diff --git a/karyons_net/src/transports/tcp.rs b/karyons_net/src/transports/tcp.rs new file mode 100644 index 0000000..5ff7b28 --- /dev/null +++ b/karyons_net/src/transports/tcp.rs @@ -0,0 +1,82 @@ +use async_trait::async_trait; + +use smol::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + lock::Mutex, + net::{TcpListener, TcpStream}, +}; + +use crate::{ + connection::Connection, + endpoint::{Addr, Endpoint, Port}, + listener::Listener, + Result, +}; + +/// TCP network connection implementations of the `Connection` trait. 
+pub struct TcpConn { + inner: TcpStream, + read: Mutex<ReadHalf<TcpStream>>, + write: Mutex<WriteHalf<TcpStream>>, +} + +impl TcpConn { + /// Creates a new TcpConn + pub fn new(conn: TcpStream) -> Self { + let (read, write) = split(conn.clone()); + Self { + inner: conn, + read: Mutex::new(read), + write: Mutex::new(write), + } + } +} + +#[async_trait] +impl Connection for TcpConn { + fn peer_endpoint(&self) -> Result<Endpoint> { + Ok(Endpoint::new_tcp_addr(&self.inner.peer_addr()?)) + } + + fn local_endpoint(&self) -> Result<Endpoint> { + Ok(Endpoint::new_tcp_addr(&self.inner.local_addr()?)) + } + + async fn recv(&self, buf: &mut [u8]) -> Result<usize> { + self.read.lock().await.read_exact(buf).await?; + Ok(buf.len()) + } + + async fn send(&self, buf: &[u8]) -> Result<usize> { + self.write.lock().await.write_all(buf).await?; + Ok(buf.len()) + } +} + +#[async_trait] +impl Listener for TcpListener { + fn local_endpoint(&self) -> Result<Endpoint> { + Ok(Endpoint::new_tcp_addr(&self.local_addr()?)) + } + + async fn accept(&self) -> Result<Box<dyn Connection>> { + let (conn, _) = self.accept().await?; + conn.set_nodelay(true)?; + Ok(Box::new(TcpConn::new(conn))) + } +} + +/// Connects to the given TCP address and port. +pub async fn dial_tcp(addr: &Addr, port: &Port) -> Result<TcpConn> { + let address = format!("{}:{}", addr, port); + let conn = TcpStream::connect(address).await?; + conn.set_nodelay(true)?; + Ok(TcpConn::new(conn)) +} + +/// Listens on the given TCP address and port. 
+pub async fn listen_tcp(addr: &Addr, port: &Port) -> Result<TcpListener> { + let address = format!("{}:{}", addr, port); + let listener = TcpListener::bind(address).await?; + Ok(listener) +} diff --git a/karyons_net/src/transports/udp.rs b/karyons_net/src/transports/udp.rs new file mode 100644 index 0000000..27fb9ae --- /dev/null +++ b/karyons_net/src/transports/udp.rs @@ -0,0 +1,77 @@ +use std::net::SocketAddr; + +use async_trait::async_trait; +use smol::net::UdpSocket; + +use crate::{ + connection::Connection, + endpoint::{Addr, Endpoint, Port}, + Result, +}; + +/// UDP network connection implementations of the `Connection` trait. +pub struct UdpConn { + inner: UdpSocket, +} + +impl UdpConn { + /// Creates a new UdpConn + pub fn new(conn: UdpSocket) -> Self { + Self { inner: conn } + } +} + +impl UdpConn { + /// Receives a single datagram message. Returns the number of bytes read + /// and the origin endpoint. + pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, Endpoint)> { + let (size, addr) = self.inner.recv_from(buf).await?; + Ok((size, Endpoint::new_udp_addr(&addr))) + } + + /// Sends data to the given address. Returns the number of bytes written. + pub async fn send_to(&self, buf: &[u8], addr: &Endpoint) -> Result<usize> { + let addr: SocketAddr = addr.clone().try_into()?; + let size = self.inner.send_to(buf, addr).await?; + Ok(size) + } +} + +#[async_trait] +impl Connection for UdpConn { + fn peer_endpoint(&self) -> Result<Endpoint> { + Ok(Endpoint::new_udp_addr(&self.inner.peer_addr()?)) + } + + fn local_endpoint(&self) -> Result<Endpoint> { + Ok(Endpoint::new_udp_addr(&self.inner.local_addr()?)) + } + + async fn recv(&self, buf: &mut [u8]) -> Result<usize> { + let size = self.inner.recv(buf).await?; + Ok(size) + } + + async fn send(&self, buf: &[u8]) -> Result<usize> { + let size = self.inner.send(buf).await?; + Ok(size) + } +} + +/// Connects to the given UDP address and port. 
+pub async fn dial_udp(addr: &Addr, port: &Port) -> Result<UdpConn> { + let address = format!("{}:{}", addr, port); + + // Let the operating system assign an available port to this socket + let conn = UdpSocket::bind("[::]:0").await?; + conn.connect(address).await?; + Ok(UdpConn::new(conn)) +} + +/// Listens on the given UDP address and port. +pub async fn listen_udp(addr: &Addr, port: &Port) -> Result<UdpConn> { + let address = format!("{}:{}", addr, port); + let conn = UdpSocket::bind(address).await?; + let udp_conn = UdpConn::new(conn); + Ok(udp_conn) +} diff --git a/karyons_net/src/transports/unix.rs b/karyons_net/src/transports/unix.rs new file mode 100644 index 0000000..c89832e --- /dev/null +++ b/karyons_net/src/transports/unix.rs @@ -0,0 +1,73 @@ +use async_trait::async_trait; + +use smol::{ + io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, + lock::Mutex, + net::unix::{UnixListener, UnixStream}, +}; + +use crate::{connection::Connection, endpoint::Endpoint, listener::Listener, Result}; + +/// Unix domain socket implementations of the `Connection` trait. 
+pub struct UnixConn { + inner: UnixStream, + read: Mutex<ReadHalf<UnixStream>>, + write: Mutex<WriteHalf<UnixStream>>, +} + +impl UnixConn { + /// Creates a new UnixConn + pub fn new(conn: UnixStream) -> Self { + let (read, write) = split(conn.clone()); + Self { + inner: conn, + read: Mutex::new(read), + write: Mutex::new(write), + } + } +} + +#[async_trait] +impl Connection for UnixConn { + fn peer_endpoint(&self) -> Result<Endpoint> { + Ok(Endpoint::new_unix_addr(&self.inner.peer_addr()?)) + } + + fn local_endpoint(&self) -> Result<Endpoint> { + Ok(Endpoint::new_unix_addr(&self.inner.local_addr()?)) + } + + async fn recv(&self, buf: &mut [u8]) -> Result<usize> { + self.read.lock().await.read_exact(buf).await?; + Ok(buf.len()) + } + + async fn send(&self, buf: &[u8]) -> Result<usize> { + self.write.lock().await.write_all(buf).await?; + Ok(buf.len()) + } +} + +#[async_trait] +impl Listener for UnixListener { + fn local_endpoint(&self) -> Result<Endpoint> { + Ok(Endpoint::new_unix_addr(&self.local_addr()?)) + } + + async fn accept(&self) -> Result<Box<dyn Connection>> { + let (conn, _) = self.accept().await?; + Ok(Box::new(UnixConn::new(conn))) + } +} + +/// Connects to the given Unix socket path. +pub async fn dial_unix(path: &String) -> Result<UnixConn> { + let conn = UnixStream::connect(path).await?; + Ok(UnixConn::new(conn)) +} + +/// Listens on the given Unix socket path. 
+pub fn listen_unix(path: &String) -> Result<UnixListener> { + let listener = UnixListener::bind(path)?; + Ok(listener) +} diff --git a/karyons_p2p/Cargo.toml b/karyons_p2p/Cargo.toml new file mode 100644 index 0000000..1c2a5aa --- /dev/null +++ b/karyons_p2p/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "karyons_p2p" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +karyons_core = { path = "../karyons_core" } +karyons_net = { path = "../karyons_net" } + +async-trait = "0.1.73" +smol = "1.3.0" +futures-util = {version = "0.3.5", features=["std"], default-features = false } +log = "0.4.20" +chrono = "0.4.30" +bincode = { version="2.0.0-rc.3", features = ["derive"]} +rand = "0.8.5" +thiserror = "1.0.47" +semver = "1.0.20" +sha2 = "0.10.8" + +[[example]] +name = "peer" +path = "examples/peer.rs" + +[[example]] +name = "chat" +path = "examples/chat.rs" + +[[example]] +name = "monitor" +path = "examples/monitor.rs" + +[dev-dependencies] +async-std = "1.12.0" +clap = { version = "4.4.6", features = ["derive"] } +ctrlc = "3.4.1" +easy-parallel = "3.3.1" +env_logger = "0.10.0" + diff --git a/karyons_p2p/examples/chat.rs b/karyons_p2p/examples/chat.rs new file mode 100644 index 0000000..4358362 --- /dev/null +++ b/karyons_p2p/examples/chat.rs @@ -0,0 +1,141 @@ +use std::sync::Arc; + +use async_std::io; +use async_trait::async_trait; +use clap::Parser; +use smol::{channel, future, Executor}; + +use karyons_net::{Endpoint, Port}; + +use karyons_p2p::{ + protocol::{ArcProtocol, Protocol, ProtocolEvent, ProtocolID}, + ArcPeer, Backend, Config, P2pError, PeerID, Version, +}; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Cli { + /// Optional list of bootstrap peers to start the seeding process. + #[arg(short)] + bootstrap_peers: Vec<Endpoint>, + + /// Optional list of peer endpoints for manual connections. 
+ #[arg(short)] + peer_endpoints: Vec<Endpoint>, + + /// Optional endpoint for accepting incoming connections. + #[arg(short)] + listen_endpoint: Option<Endpoint>, + + /// Optional TCP/UDP port for the discovery service. + #[arg(short)] + discovery_port: Option<Port>, + + /// Username + #[arg(long)] + username: String, +} + +pub struct ChatProtocol { + username: String, + peer: ArcPeer, +} + +impl ChatProtocol { + fn new(username: &str, peer: ArcPeer) -> ArcProtocol { + Arc::new(Self { + peer, + username: username.to_string(), + }) + } +} + +#[async_trait] +impl Protocol for ChatProtocol { + async fn start(self: Arc<Self>, ex: Arc<Executor<'_>>) -> Result<(), P2pError> { + let selfc = self.clone(); + let stdin = io::stdin(); + let task = ex.spawn(async move { + loop { + let mut input = String::new(); + stdin.read_line(&mut input).await.unwrap(); + let msg = format!("> {}: {}", selfc.username, input.trim()); + selfc.peer.broadcast(&Self::id(), &msg).await; + } + }); + + let listener = self.peer.register_listener::<Self>().await; + loop { + let event = listener.recv().await.unwrap(); + + match event { + ProtocolEvent::Message(msg) => { + let msg = String::from_utf8(msg).unwrap(); + println!("{msg}"); + } + ProtocolEvent::Shutdown => { + break; + } + } + } + + task.cancel().await; + listener.cancel().await; + Ok(()) + } + + fn version() -> Result<Version, P2pError> { + "0.1.0, 0.1.0".parse() + } + + fn id() -> ProtocolID { + "CHAT".into() + } +} + +fn main() { + env_logger::init(); + let cli = Cli::parse(); + + // Create a PeerID based on the username. + let peer_id = PeerID::new(cli.username.as_bytes()); + + // Create the configuration for the backend. 
+ let config = Config { + listen_endpoint: cli.listen_endpoint, + peer_endpoints: cli.peer_endpoints, + bootstrap_peers: cli.bootstrap_peers, + discovery_port: cli.discovery_port.unwrap_or(0), + ..Default::default() + }; + + // Create a new Backend + let backend = Backend::new(peer_id, config); + + let (ctrlc_s, ctrlc_r) = channel::unbounded(); + let handle = move || ctrlc_s.try_send(()).unwrap(); + ctrlc::set_handler(handle).unwrap(); + + // Create a new Executor + let ex = Arc::new(Executor::new()); + + let ex_cloned = ex.clone(); + let task = ex.spawn(async { + let username = cli.username; + + // Attach the ChatProtocol + let c = move |peer| ChatProtocol::new(&username, peer); + backend.attach_protocol::<ChatProtocol>(c).await.unwrap(); + + // Run the backend + backend.run(ex_cloned).await.unwrap(); + + // Wait for ctrlc signal + ctrlc_r.recv().await.unwrap(); + + // Shutdown the backend + backend.shutdown().await; + }); + + future::block_on(ex.run(task)); +} diff --git a/karyons_p2p/examples/chat_simulation.sh b/karyons_p2p/examples/chat_simulation.sh new file mode 100755 index 0000000..82bbe96 --- /dev/null +++ b/karyons_p2p/examples/chat_simulation.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# build +cargo build --release --example chat + +tmux new-session -d -s karyons_chat + +tmux send-keys -t karyons_chat "../../target/release/examples/chat --username 'user1'\ + -l 'tcp://127.0.0.1:40000' -d '40010'" Enter + +tmux split-window -h -t karyons_chat +tmux send-keys -t karyons_chat "../../target/release/examples/chat --username 'user2'\ + -l 'tcp://127.0.0.1:40001' -d '40011' -b 'tcp://127.0.0.1:40010 ' " Enter + +tmux split-window -h -t karyons_chat +tmux send-keys -t karyons_chat "../../target/release/examples/chat --username 'user3'\ + -l 'tcp://127.0.0.1:40002' -d '40012' -b 'tcp://127.0.0.1:40010'" Enter + +tmux split-window -h -t karyons_chat +tmux send-keys -t karyons_chat "../../target/release/examples/chat --username 'user4'\ + -b 'tcp://127.0.0.1:40010'" 
Enter + +tmux select-layout tiled + +tmux attach -t karyons_chat diff --git a/karyons_p2p/examples/monitor.rs b/karyons_p2p/examples/monitor.rs new file mode 100644 index 0000000..cd4defc --- /dev/null +++ b/karyons_p2p/examples/monitor.rs @@ -0,0 +1,93 @@ +use std::sync::Arc; + +use clap::Parser; +use easy_parallel::Parallel; +use smol::{channel, future, Executor}; + +use karyons_net::{Endpoint, Port}; + +use karyons_p2p::{Backend, Config, PeerID}; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Cli { + /// Optional list of bootstrap peers to start the seeding process. + #[arg(short)] + bootstrap_peers: Vec<Endpoint>, + + /// Optional list of peer endpoints for manual connections. + #[arg(short)] + peer_endpoints: Vec<Endpoint>, + + /// Optional endpoint for accepting incoming connections. + #[arg(short)] + listen_endpoint: Option<Endpoint>, + + /// Optional TCP/UDP port for the discovery service. + #[arg(short)] + discovery_port: Option<Port>, + + /// Optional user id + #[arg(long)] + userid: Option<String>, +} + +fn main() { + env_logger::init(); + let cli = Cli::parse(); + + let peer_id = match cli.userid { + Some(userid) => PeerID::new(userid.as_bytes()), + None => PeerID::random(), + }; + + // Create the configuration for the backend. 
+ let config = Config { + listen_endpoint: cli.listen_endpoint, + peer_endpoints: cli.peer_endpoints, + bootstrap_peers: cli.bootstrap_peers, + discovery_port: cli.discovery_port.unwrap_or(0), + ..Default::default() + }; + + // Create a new Backend + let backend = Backend::new(peer_id, config); + + let (ctrlc_s, ctrlc_r) = channel::unbounded(); + let handle = move || ctrlc_s.try_send(()).unwrap(); + ctrlc::set_handler(handle).unwrap(); + + let (signal, shutdown) = channel::unbounded::<()>(); + + // Create a new Executor + let ex = Arc::new(Executor::new()); + + let task = async { + let monitor = backend.monitor().await; + + let monitor_task = ex.spawn(async move { + loop { + let event = monitor.recv().await.unwrap(); + println!("{}", event); + } + }); + + // Run the backend + backend.run(ex.clone()).await.unwrap(); + + // Wait for ctrlc signal + ctrlc_r.recv().await.unwrap(); + + // Shutdown the backend + backend.shutdown().await; + + monitor_task.cancel().await; + + drop(signal); + }; + + // Run four executor threads. 
+ Parallel::new() + .each(0..4, |_| future::block_on(ex.run(shutdown.recv()))) + .finish(|| future::block_on(task)); +} diff --git a/karyons_p2p/examples/net_simulation.sh b/karyons_p2p/examples/net_simulation.sh new file mode 100755 index 0000000..b223b63 --- /dev/null +++ b/karyons_p2p/examples/net_simulation.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +# build +cargo build --release --example peer + +tmux new-session -d -s karyons_p2p + +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer1'\ + -l 'tcp://127.0.0.1:30000' -d '30010'" Enter + +tmux split-window -h -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer2'\ + -l 'tcp://127.0.0.1:30001' -d '30011' -b 'tcp://127.0.0.1:30010 ' " Enter + +tmux split-window -h -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer3'\ + -l 'tcp://127.0.0.1:30002' -d '30012' -b 'tcp://127.0.0.1:30010'" Enter + +tmux split-window -h -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer4'\ + -l 'tcp://127.0.0.1:30003' -d '30013' -b 'tcp://127.0.0.1:30010'" Enter + +tmux split-window -h -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer5'\ + -l 'tcp://127.0.0.1:30004' -d '30014' -b 'tcp://127.0.0.1:30010'" Enter + +tmux split-window -h -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer6'\ + -l 'tcp://127.0.0.1:30005' -d '30015' -b 'tcp://127.0.0.1:30010'" Enter + +tmux select-layout even-horizontal + +sleep 3; + +tmux select-pane -t karyons_p2p:0.0 + +tmux split-window -v -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer7'\ + -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30011'" Enter + +tmux select-pane -t karyons_p2p:0.2 + +tmux split-window -v -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer8'\ + 
-b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30012' -p 'tcp://127.0.0.1:30005'" Enter + +tmux select-pane -t karyons_p2p:0.4 + +tmux split-window -v -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer9'\ + -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30013'" Enter + +tmux select-pane -t karyons_p2p:0.6 + +tmux split-window -v -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer10'\ + -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30014'" Enter + +tmux select-pane -t karyons_p2p:0.8 + +tmux split-window -v -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer11'\ + -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30015'" Enter + +tmux select-pane -t karyons_p2p:0.10 + +tmux split-window -v -t karyons_p2p +tmux send-keys -t karyons_p2p "../../target/release/examples/peer --userid 'peer12'\ + -b 'tcp://127.0.0.1:30010' -b 'tcp://127.0.0.1:30015' -b 'tcp://127.0.0.1:30011'" Enter + +tmux set-window-option -t karyons_p2p synchronize-panes on + +tmux attach -t karyons_p2p diff --git a/karyons_p2p/examples/peer.rs b/karyons_p2p/examples/peer.rs new file mode 100644 index 0000000..f805d68 --- /dev/null +++ b/karyons_p2p/examples/peer.rs @@ -0,0 +1,82 @@ +use std::sync::Arc; + +use clap::Parser; +use easy_parallel::Parallel; +use smol::{channel, future, Executor}; + +use karyons_net::{Endpoint, Port}; + +use karyons_p2p::{Backend, Config, PeerID}; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Cli { + /// Optional list of bootstrap peers to start the seeding process. + #[arg(short)] + bootstrap_peers: Vec<Endpoint>, + + /// Optional list of peer endpoints for manual connections. + #[arg(short)] + peer_endpoints: Vec<Endpoint>, + + /// Optional endpoint for accepting incoming connections. + #[arg(short)] + listen_endpoint: Option<Endpoint>, + + /// Optional TCP/UDP port for the discovery service. 
+ #[arg(short)] + discovery_port: Option<Port>, + + /// Optional user id + #[arg(long)] + userid: Option<String>, +} + +fn main() { + env_logger::init(); + let cli = Cli::parse(); + + let peer_id = match cli.userid { + Some(userid) => PeerID::new(userid.as_bytes()), + None => PeerID::random(), + }; + + // Create the configuration for the backend. + let config = Config { + listen_endpoint: cli.listen_endpoint, + peer_endpoints: cli.peer_endpoints, + bootstrap_peers: cli.bootstrap_peers, + discovery_port: cli.discovery_port.unwrap_or(0), + ..Default::default() + }; + + // Create a new Backend + let backend = Backend::new(peer_id, config); + + let (ctrlc_s, ctrlc_r) = channel::unbounded(); + let handle = move || ctrlc_s.try_send(()).unwrap(); + ctrlc::set_handler(handle).unwrap(); + + let (signal, shutdown) = channel::unbounded::<()>(); + + // Create a new Executor + let ex = Arc::new(Executor::new()); + + let task = async { + // Run the backend + backend.run(ex.clone()).await.unwrap(); + + // Wait for ctrlc signal + ctrlc_r.recv().await.unwrap(); + + // Shutdown the backend + backend.shutdown().await; + + drop(signal); + }; + + // Run four executor threads. + Parallel::new() + .each(0..4, |_| future::block_on(ex.run(shutdown.recv()))) + .finish(|| future::block_on(task)); +} diff --git a/karyons_p2p/src/backend.rs b/karyons_p2p/src/backend.rs new file mode 100644 index 0000000..290e3e7 --- /dev/null +++ b/karyons_p2p/src/backend.rs @@ -0,0 +1,139 @@ +use std::sync::Arc; + +use log::info; + +use karyons_core::{pubsub::Subscription, Executor}; + +use crate::{ + config::Config, + discovery::{ArcDiscovery, Discovery}, + monitor::{Monitor, MonitorEvent}, + net::ConnQueue, + peer_pool::PeerPool, + protocol::{ArcProtocol, Protocol}, + ArcPeer, PeerID, Result, +}; + +pub type ArcBackend = Arc<Backend>; + +/// Backend serves as the central entry point for initiating and managing +/// the P2P network. 
+/// +/// +/// # Example +/// ``` +/// use std::sync::Arc; +/// +/// use easy_parallel::Parallel; +/// use smol::{channel as smol_channel, future, Executor}; +/// +/// use karyons_p2p::{Backend, Config, PeerID}; +/// +/// let peer_id = PeerID::random(); +/// +/// // Create the configuration for the backend. +/// let mut config = Config::default(); +/// +/// // Create a new Backend +/// let backend = Backend::new(peer_id, config); +/// +/// // Create a new Executor +/// let ex = Arc::new(Executor::new()); +/// +/// let task = async { +/// // Run the backend +/// backend.run(ex.clone()).await.unwrap(); +/// +/// // .... +/// +/// // Shutdown the backend +/// backend.shutdown().await; +/// }; +/// +/// future::block_on(ex.run(task)); +/// +/// ``` +pub struct Backend { + /// The Configuration for the P2P network. + config: Arc<Config>, + + /// Peer ID. + id: PeerID, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, + + /// Discovery instance. + discovery: ArcDiscovery, + + /// PeerPool instance. + peer_pool: Arc<PeerPool>, +} + +impl Backend { + /// Creates a new Backend. + pub fn new(id: PeerID, config: Config) -> ArcBackend { + let config = Arc::new(config); + let monitor = Arc::new(Monitor::new()); + + let conn_queue = ConnQueue::new(); + + let peer_pool = PeerPool::new(&id, conn_queue.clone(), config.clone(), monitor.clone()); + let discovery = Discovery::new(&id, conn_queue, config.clone(), monitor.clone()); + + Arc::new(Self { + id: id.clone(), + monitor, + discovery, + config, + peer_pool, + }) + } + + /// Run the Backend, starting the PeerPool and Discovery instances. 
+ pub async fn run(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + info!("Run the backend {}", self.id); + self.peer_pool.start(ex.clone()).await?; + self.discovery.start(ex.clone()).await?; + Ok(()) + } + + /// Attach a custom protocol to the network + pub async fn attach_protocol<P: Protocol>( + &self, + c: impl Fn(ArcPeer) -> ArcProtocol + Send + Sync + 'static, + ) -> Result<()> { + self.peer_pool.attach_protocol::<P>(Box::new(c)).await + } + + /// Returns the number of currently connected peers. + pub async fn peers(&self) -> usize { + self.peer_pool.peers_len().await + } + + /// Returns the `Config`. + pub fn config(&self) -> Arc<Config> { + self.config.clone() + } + + /// Returns the number of occupied inbound slots. + pub fn inbound_slots(&self) -> usize { + self.discovery.inbound_slots.load() + } + + /// Returns the number of occupied outbound slots. + pub fn outbound_slots(&self) -> usize { + self.discovery.outbound_slots.load() + } + + /// Subscribes to the monitor to receive network events. + pub async fn monitor(&self) -> Subscription<MonitorEvent> { + self.monitor.subscribe().await + } + + /// Shuts down the Backend. + pub async fn shutdown(&self) { + self.discovery.shutdown().await; + self.peer_pool.shutdown().await; + } +} diff --git a/karyons_p2p/src/config.rs b/karyons_p2p/src/config.rs new file mode 100644 index 0000000..ebecbf0 --- /dev/null +++ b/karyons_p2p/src/config.rs @@ -0,0 +1,105 @@ +use karyons_net::{Endpoint, Port}; + +use crate::utils::Version; + +/// the Configuration for the P2P network. +pub struct Config { + /// Represents the network version. + pub version: Version, + + ///////////////// + // PeerPool + //////////////// + /// Timeout duration for the handshake with new peers, in seconds. + pub handshake_timeout: u64, + /// Interval at which the ping protocol sends ping messages to a peer to + /// maintain connections, in seconds. 
+ pub ping_interval: u64, + /// Timeout duration for receiving the pong message corresponding to the + /// sent ping message, in seconds. + pub ping_timeout: u64, + /// The maximum number of retries for outbound connection establishment. + pub max_connect_retries: usize, + + ///////////////// + // DISCOVERY + //////////////// + /// A list of bootstrap peers for the seeding process. + pub bootstrap_peers: Vec<Endpoint>, + /// An optional listening endpoint to accept incoming connections. + pub listen_endpoint: Option<Endpoint>, + /// A list of endpoints representing peers that the `Discovery` will + /// manually connect to. + pub peer_endpoints: Vec<Endpoint>, + /// The number of available inbound slots for incoming connections. + pub inbound_slots: usize, + /// The number of available outbound slots for outgoing connections. + pub outbound_slots: usize, + /// TCP/UDP port for lookup and refresh processes. + pub discovery_port: Port, + /// Time interval, in seconds, at which the Discovery restarts the + /// seeding process. + pub seeding_interval: u64, + + ///////////////// + // LOOKUP + //////////////// + /// The number of available inbound slots for incoming connections during + /// the lookup process. + pub lookup_inbound_slots: usize, + /// The number of available outbound slots for outgoing connections during + /// the lookup process. + pub lookup_outbound_slots: usize, + /// Timeout duration for a peer response during the lookup process, in + /// seconds. + pub lookup_response_timeout: u64, + /// Maximum allowable time for a live connection with a peer during the + /// lookup process, in seconds. + pub lookup_connection_lifespan: u64, + /// The maximum number of retries for outbound connection establishment + /// during the lookup process. + pub lookup_connect_retries: usize, + + ///////////////// + // REFRESH + //////////////// + /// Interval at which the table refreshes its entries, in seconds. 
+ pub refresh_interval: u64, + /// Timeout duration for a peer response during the table refresh process, + /// in seconds. + pub refresh_response_timeout: u64, + /// The maximum number of retries for outbound connection establishment + /// during the refresh process. + pub refresh_connect_retries: usize, +} + +impl Default for Config { + fn default() -> Self { + Config { + version: "0.1.0".parse().unwrap(), + + handshake_timeout: 2, + ping_interval: 20, + ping_timeout: 2, + + bootstrap_peers: vec![], + listen_endpoint: None, + peer_endpoints: vec![], + inbound_slots: 12, + outbound_slots: 12, + max_connect_retries: 3, + discovery_port: 0, + seeding_interval: 60, + + lookup_inbound_slots: 20, + lookup_outbound_slots: 20, + lookup_response_timeout: 1, + lookup_connection_lifespan: 3, + lookup_connect_retries: 3, + + refresh_interval: 1800, + refresh_response_timeout: 1, + refresh_connect_retries: 3, + } + } +} diff --git a/karyons_p2p/src/discovery/lookup.rs b/karyons_p2p/src/discovery/lookup.rs new file mode 100644 index 0000000..f404133 --- /dev/null +++ b/karyons_p2p/src/discovery/lookup.rs @@ -0,0 +1,366 @@ +use std::{sync::Arc, time::Duration}; + +use futures_util::{stream::FuturesUnordered, StreamExt}; +use log::{error, trace}; +use rand::{rngs::OsRng, seq::SliceRandom, RngCore}; +use smol::lock::{Mutex, RwLock}; + +use karyons_core::{async_utils::timeout, utils::decode, Executor}; + +use karyons_net::{Conn, Endpoint}; + +use crate::{ + io_codec::IOCodec, + message::{ + get_msg_payload, FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, + ShutdownMsg, + }, + monitor::{ConnEvent, DiscoveryEvent, Monitor}, + net::{ConnectionSlots, Connector, Listener}, + routing_table::RoutingTable, + utils::version_match, + Config, Error, PeerID, Result, +}; + +/// Maximum number of peers that can be returned in a PeersMsg. 
+pub const MAX_PEERS_IN_PEERSMSG: usize = 10; + +pub struct LookupService { + /// Peer's ID + id: PeerID, + + /// Routing Table + table: Arc<Mutex<RoutingTable>>, + + /// Listener + listener: Arc<Listener>, + /// Connector + connector: Arc<Connector>, + + /// Outbound slots. + outbound_slots: Arc<ConnectionSlots>, + + /// Resolved listen endpoint + listen_endpoint: Option<RwLock<Endpoint>>, + + /// Holds the configuration for the P2P network. + config: Arc<Config>, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl LookupService { + /// Creates a new lookup service + pub fn new( + id: &PeerID, + table: Arc<Mutex<RoutingTable>>, + config: Arc<Config>, + monitor: Arc<Monitor>, + ) -> Self { + let inbound_slots = Arc::new(ConnectionSlots::new(config.lookup_inbound_slots)); + let outbound_slots = Arc::new(ConnectionSlots::new(config.lookup_outbound_slots)); + + let listener = Listener::new(inbound_slots.clone(), monitor.clone()); + let connector = Connector::new( + config.lookup_connect_retries, + outbound_slots.clone(), + monitor.clone(), + ); + + let listen_endpoint = config + .listen_endpoint + .as_ref() + .map(|endpoint| RwLock::new(endpoint.clone())); + + Self { + id: id.clone(), + table, + listener, + connector, + outbound_slots, + listen_endpoint, + config, + monitor, + } + } + + /// Start the lookup service. + pub async fn start(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + self.start_listener(ex).await?; + Ok(()) + } + + /// Set the resolved listen endpoint. + pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) { + if let Some(endpoint) = &self.listen_endpoint { + *endpoint.write().await = resolved_endpoint.clone(); + } + } + + /// Shuts down the lookup service. + pub async fn shutdown(&self) { + self.connector.shutdown().await; + self.listener.shutdown().await; + } + + /// Starts iterative lookup and populate the routing table. 
+ /// + /// This method begins by generating a random peer ID and connecting to the + /// provided endpoint. It then sends a FindPeer message containing the + /// randomly generated peer ID. Upon receiving peers from the initial lookup, + /// it starts connecting to these received peers and sends them a FindPeer + /// message that contains our own peer ID. + pub async fn start_lookup(&self, endpoint: &Endpoint) -> Result<()> { + trace!("Lookup started {endpoint}"); + self.monitor + .notify(&DiscoveryEvent::LookupStarted(endpoint.clone()).into()) + .await; + + let mut random_peers = vec![]; + if let Err(err) = self.random_lookup(endpoint, &mut random_peers).await { + self.monitor + .notify(&DiscoveryEvent::LookupFailed(endpoint.clone()).into()) + .await; + return Err(err); + }; + + let mut peer_buffer = vec![]; + self.self_lookup(&random_peers, &mut peer_buffer).await; + + while peer_buffer.len() < MAX_PEERS_IN_PEERSMSG { + match random_peers.pop() { + Some(p) => peer_buffer.push(p), + None => break, + } + } + + for peer in peer_buffer.iter() { + let mut table = self.table.lock().await; + let result = table.add_entry(peer.clone().into()); + trace!("Add entry {:?}", result); + } + + self.monitor + .notify(&DiscoveryEvent::LookupSucceeded(endpoint.clone(), peer_buffer.len()).into()) + .await; + + Ok(()) + } + + /// Starts a random lookup + /// + /// This will perfom lookup on a random generated PeerID + async fn random_lookup( + &self, + endpoint: &Endpoint, + random_peers: &mut Vec<PeerMsg>, + ) -> Result<()> { + for _ in 0..2 { + let peer_id = PeerID::random(); + let peers = self.connect(&peer_id, endpoint.clone()).await?; + for peer in peers { + if random_peers.contains(&peer) + || peer.peer_id == self.id + || self.table.lock().await.contains_key(&peer.peer_id.0) + { + continue; + } + + random_peers.push(peer); + } + } + + Ok(()) + } + + /// Starts a self lookup + async fn self_lookup(&self, random_peers: &Vec<PeerMsg>, peer_buffer: &mut Vec<PeerMsg>) { + let mut 
tasks = FuturesUnordered::new(); + for peer in random_peers.choose_multiple(&mut OsRng, random_peers.len()) { + let endpoint = Endpoint::Tcp(peer.addr.clone(), peer.discovery_port); + tasks.push(self.connect(&self.id, endpoint)) + } + + while let Some(result) = tasks.next().await { + match result { + Ok(peers) => peer_buffer.extend(peers), + Err(err) => { + error!("Failed to do self lookup: {err}"); + } + } + } + } + + /// Connects to the given endpoint + async fn connect(&self, peer_id: &PeerID, endpoint: Endpoint) -> Result<Vec<PeerMsg>> { + let conn = self.connector.connect(&endpoint).await?; + let io_codec = IOCodec::new(conn); + let result = self.handle_outbound(io_codec, peer_id).await; + + self.monitor + .notify(&ConnEvent::Disconnected(endpoint).into()) + .await; + self.outbound_slots.remove().await; + + result + } + + /// Handles outbound connection + async fn handle_outbound(&self, io_codec: IOCodec, peer_id: &PeerID) -> Result<Vec<PeerMsg>> { + trace!("Send Ping msg"); + self.send_ping_msg(&io_codec).await?; + + trace!("Send FindPeer msg"); + let peers = self.send_findpeer_msg(&io_codec, peer_id).await?; + + if peers.0.len() >= MAX_PEERS_IN_PEERSMSG { + return Err(Error::Lookup("Received too many peers in PeersMsg")); + } + + trace!("Send Peer msg"); + if let Some(endpoint) = &self.listen_endpoint { + self.send_peer_msg(&io_codec, endpoint.read().await.clone()) + .await?; + } + + trace!("Send Shutdown msg"); + self.send_shutdown_msg(&io_codec).await?; + + Ok(peers.0) + } + + /// Start a listener. 
+ async fn start_listener(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + let addr = match &self.listen_endpoint { + Some(a) => a.read().await.addr()?.clone(), + None => return Ok(()), + }; + + let endpoint = Endpoint::Tcp(addr, self.config.discovery_port); + + let selfc = self.clone(); + let callback = |conn: Conn| async move { + let t = Duration::from_secs(selfc.config.lookup_connection_lifespan); + timeout(t, selfc.handle_inbound(conn)).await??; + Ok(()) + }; + + self.listener.start(ex, endpoint.clone(), callback).await?; + Ok(()) + } + + /// Handles inbound connection + async fn handle_inbound(self: &Arc<Self>, conn: Conn) -> Result<()> { + let io_codec = IOCodec::new(conn); + loop { + let msg: NetMsg = io_codec.read().await?; + trace!("Receive msg {:?}", msg.header.command); + + if let NetMsgCmd::Shutdown = msg.header.command { + return Ok(()); + } + + match &msg.header.command { + NetMsgCmd::Ping => { + let (ping_msg, _) = decode::<PingMsg>(&msg.payload)?; + if !version_match(&self.config.version.req, &ping_msg.version) { + return Err(Error::IncompatibleVersion("system: {}".into())); + } + self.send_pong_msg(ping_msg.nonce, &io_codec).await?; + } + NetMsgCmd::FindPeer => { + let (findpeer_msg, _) = decode::<FindPeerMsg>(&msg.payload)?; + let peer_id = findpeer_msg.0; + self.send_peers_msg(&peer_id, &io_codec).await?; + } + NetMsgCmd::Peer => { + let (peer, _) = decode::<PeerMsg>(&msg.payload)?; + let result = self.table.lock().await.add_entry(peer.clone().into()); + trace!("Add entry result: {:?}", result); + } + c => return Err(Error::InvalidMsg(format!("Unexpected msg: {:?}", c))), + } + } + } + + /// Sends a Ping msg and wait to receive the Pong message. 
async fn send_ping_msg(&self, io_codec: &IOCodec) -> Result<()> {
        trace!("Send Ping msg");
+ async fn send_peer_msg(&self, io_codec: &IOCodec, endpoint: Endpoint) -> Result<()> { + trace!("Send Peer msg"); + let peer_msg = PeerMsg { + addr: endpoint.addr()?.clone(), + port: *endpoint.port()?, + discovery_port: self.config.discovery_port, + peer_id: self.id.clone(), + }; + io_codec.write(NetMsgCmd::Peer, &peer_msg).await?; + Ok(()) + } + + /// Sends a Shutdown msg. + async fn send_shutdown_msg(&self, io_codec: &IOCodec) -> Result<()> { + trace!("Send Shutdown msg"); + io_codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await?; + Ok(()) + } +} diff --git a/karyons_p2p/src/discovery/mod.rs b/karyons_p2p/src/discovery/mod.rs new file mode 100644 index 0000000..94b350b --- /dev/null +++ b/karyons_p2p/src/discovery/mod.rs @@ -0,0 +1,262 @@ +mod lookup; +mod refresh; + +use std::sync::Arc; + +use log::{error, info}; +use rand::{rngs::OsRng, seq::SliceRandom}; +use smol::lock::Mutex; + +use karyons_core::{ + async_utils::{Backoff, TaskGroup, TaskResult}, + Executor, +}; + +use karyons_net::{Conn, Endpoint}; + +use crate::{ + config::Config, + monitor::Monitor, + net::ConnQueue, + net::{ConnDirection, ConnectionSlots, Connector, Listener}, + routing_table::{ + Entry, EntryStatusFlag, RoutingTable, CONNECTED_ENTRY, DISCONNECTED_ENTRY, PENDING_ENTRY, + UNREACHABLE_ENTRY, UNSTABLE_ENTRY, + }, + Error, PeerID, Result, +}; + +use lookup::LookupService; +use refresh::RefreshService; + +pub type ArcDiscovery = Arc<Discovery>; + +pub struct Discovery { + /// Routing table + table: Arc<Mutex<RoutingTable>>, + + /// Lookup Service + lookup_service: Arc<LookupService>, + + /// Refresh Service + refresh_service: Arc<RefreshService>, + + /// Connector + connector: Arc<Connector>, + /// Listener + listener: Arc<Listener>, + + /// Connection queue + conn_queue: Arc<ConnQueue>, + + /// Inbound slots. + pub(crate) inbound_slots: Arc<ConnectionSlots>, + /// Outbound slots. + pub(crate) outbound_slots: Arc<ConnectionSlots>, + + /// Managing spawned tasks. 
+ task_group: TaskGroup, + + /// Holds the configuration for the P2P network. + config: Arc<Config>, +} + +impl Discovery { + /// Creates a new Discovery + pub fn new( + peer_id: &PeerID, + conn_queue: Arc<ConnQueue>, + config: Arc<Config>, + monitor: Arc<Monitor>, + ) -> ArcDiscovery { + let inbound_slots = Arc::new(ConnectionSlots::new(config.inbound_slots)); + let outbound_slots = Arc::new(ConnectionSlots::new(config.outbound_slots)); + + let table_key = peer_id.0; + let table = Arc::new(Mutex::new(RoutingTable::new(table_key))); + + let refresh_service = RefreshService::new(config.clone(), table.clone(), monitor.clone()); + let lookup_service = + LookupService::new(peer_id, table.clone(), config.clone(), monitor.clone()); + + let connector = Connector::new( + config.max_connect_retries, + outbound_slots.clone(), + monitor.clone(), + ); + let listener = Listener::new(inbound_slots.clone(), monitor.clone()); + + Arc::new(Self { + refresh_service: Arc::new(refresh_service), + lookup_service: Arc::new(lookup_service), + conn_queue, + table, + inbound_slots, + outbound_slots, + connector, + listener, + task_group: TaskGroup::new(), + config, + }) + } + + /// Start the Discovery + pub async fn start(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + // Check if the listen_endpoint is provided, and if so, start a listener. + if let Some(endpoint) = &self.config.listen_endpoint { + // Return an error if the discovery port is set to 0. + if self.config.discovery_port == 0 { + return Err(Error::Config( + "Please add a valid discovery port".to_string(), + )); + } + + let resolved_endpoint = self.start_listener(endpoint, ex.clone()).await?; + + if endpoint.addr()? != resolved_endpoint.addr()? 
{ + info!("Resolved listen endpoint: {resolved_endpoint}"); + self.lookup_service + .set_listen_endpoint(&resolved_endpoint) + .await; + self.refresh_service + .set_listen_endpoint(&resolved_endpoint) + .await; + } + } + + // Start the lookup service + self.lookup_service.start(ex.clone()).await?; + // Start the refresh service + self.refresh_service.start(ex.clone()).await?; + + // Attempt to manually connect to peer endpoints provided in the Config. + for endpoint in self.config.peer_endpoints.iter() { + let _ = self.connect(endpoint, None, ex.clone()).await; + } + + // Start connect loop + let selfc = self.clone(); + self.task_group + .spawn(ex.clone(), selfc.connect_loop(ex), |res| async move { + if let TaskResult::Completed(Err(err)) = res { + error!("Connect loop stopped: {err}"); + } + }); + + Ok(()) + } + + /// Shuts down the discovery + pub async fn shutdown(&self) { + self.task_group.cancel().await; + self.connector.shutdown().await; + self.listener.shutdown().await; + + self.refresh_service.shutdown().await; + self.lookup_service.shutdown().await; + } + + /// Start a listener and on success, return the resolved endpoint. + async fn start_listener( + self: &Arc<Self>, + endpoint: &Endpoint, + ex: Executor<'_>, + ) -> Result<Endpoint> { + let selfc = self.clone(); + let callback = |conn: Conn| async move { + selfc.conn_queue.handle(conn, ConnDirection::Inbound).await; + Ok(()) + }; + + let resolved_endpoint = self.listener.start(ex, endpoint.clone(), callback).await?; + Ok(resolved_endpoint) + } + + /// This method will attempt to connect to a peer in the routing table. + /// If the routing table is empty, it will start the seeding process for + /// finding new peers. + /// + /// This will perform a backoff to prevent getting stuck in the loop + /// if the seeding process couldn't find any peers. 
+ async fn connect_loop(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + let backoff = Backoff::new(500, self.config.seeding_interval * 1000); + loop { + let random_entry = self.random_entry(PENDING_ENTRY).await; + match random_entry { + Some(entry) => { + backoff.reset(); + let endpoint = Endpoint::Tcp(entry.addr, entry.port); + self.connect(&endpoint, Some(entry.key.into()), ex.clone()) + .await; + } + None => { + backoff.sleep().await; + self.start_seeding().await; + } + } + } + } + + /// Connect to the given endpoint using the connector + async fn connect(self: &Arc<Self>, endpoint: &Endpoint, pid: Option<PeerID>, ex: Executor<'_>) { + let selfc = self.clone(); + let pid_cloned = pid.clone(); + let cback = |conn: Conn| async move { + selfc.conn_queue.handle(conn, ConnDirection::Outbound).await; + if let Some(pid) = pid_cloned { + selfc.update_entry(&pid, DISCONNECTED_ENTRY).await; + } + Ok(()) + }; + + let res = self.connector.connect_with_cback(ex, endpoint, cback).await; + + if let Some(pid) = &pid { + match res { + Ok(_) => { + self.update_entry(pid, CONNECTED_ENTRY).await; + } + Err(_) => { + self.update_entry(pid, UNREACHABLE_ENTRY).await; + } + } + } + } + + /// Starts seeding process. + /// + /// This method randomly selects a peer from the routing table and + /// attempts to connect to that peer for the initial lookup. If the routing + /// table doesn't have an available entry, it will connect to one of the + /// provided bootstrap endpoints in the `Config` and initiate the lookup. 
+ async fn start_seeding(&self) { + match self.random_entry(PENDING_ENTRY | CONNECTED_ENTRY).await { + Some(entry) => { + let endpoint = Endpoint::Tcp(entry.addr, entry.discovery_port); + if let Err(err) = self.lookup_service.start_lookup(&endpoint).await { + self.update_entry(&entry.key.into(), UNSTABLE_ENTRY).await; + error!("Failed to do lookup: {endpoint}: {err}"); + } + } + None => { + let peers = &self.config.bootstrap_peers; + for endpoint in peers.choose_multiple(&mut OsRng, peers.len()) { + if let Err(err) = self.lookup_service.start_lookup(endpoint).await { + error!("Failed to do lookup: {endpoint}: {err}"); + } + } + } + } + } + + /// Returns a random entry from routing table. + async fn random_entry(&self, entry_flag: EntryStatusFlag) -> Option<Entry> { + self.table.lock().await.random_entry(entry_flag).cloned() + } + + /// Update the entry status + async fn update_entry(&self, pid: &PeerID, entry_flag: EntryStatusFlag) { + let table = &mut self.table.lock().await; + table.update_entry(&pid.0, entry_flag); + } +} diff --git a/karyons_p2p/src/discovery/refresh.rs b/karyons_p2p/src/discovery/refresh.rs new file mode 100644 index 0000000..7582c84 --- /dev/null +++ b/karyons_p2p/src/discovery/refresh.rs @@ -0,0 +1,289 @@ +use std::{sync::Arc, time::Duration}; + +use bincode::{Decode, Encode}; +use log::{error, info, trace}; +use rand::{rngs::OsRng, RngCore}; +use smol::{ + lock::{Mutex, RwLock}, + stream::StreamExt, + Timer, +}; + +use karyons_core::{ + async_utils::{timeout, Backoff, TaskGroup, TaskResult}, + utils::{decode, encode}, + Executor, +}; + +use karyons_net::{dial_udp, listen_udp, Addr, Connection, Endpoint, NetError, Port, UdpConn}; + +/// Maximum failures for an entry before removing it from the routing table. 
+pub const MAX_FAILURES: u32 = 3; + +/// Ping message size +const PINGMSG_SIZE: usize = 32; + +use crate::{ + monitor::{ConnEvent, DiscoveryEvent, Monitor}, + routing_table::{BucketEntry, Entry, RoutingTable, PENDING_ENTRY, UNREACHABLE_ENTRY}, + Config, Error, Result, +}; + +#[derive(Decode, Encode, Debug, Clone)] +pub struct PingMsg(pub [u8; 32]); + +#[derive(Decode, Encode, Debug)] +pub struct PongMsg(pub [u8; 32]); + +pub struct RefreshService { + /// Routing table + table: Arc<Mutex<RoutingTable>>, + + /// Resolved listen endpoint + listen_endpoint: Option<RwLock<Endpoint>>, + + /// Managing spawned tasks. + task_group: TaskGroup, + + /// Holds the configuration for the P2P network. + config: Arc<Config>, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl RefreshService { + /// Creates a new refresh service + pub fn new( + config: Arc<Config>, + table: Arc<Mutex<RoutingTable>>, + monitor: Arc<Monitor>, + ) -> Self { + let listen_endpoint = config + .listen_endpoint + .as_ref() + .map(|endpoint| RwLock::new(endpoint.clone())); + + Self { + table, + listen_endpoint, + task_group: TaskGroup::new(), + config, + monitor, + } + } + + /// Start the refresh service + pub async fn start(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + if let Some(endpoint) = &self.listen_endpoint { + let endpoint = endpoint.read().await; + let addr = endpoint.addr()?; + let port = self.config.discovery_port; + + let selfc = self.clone(); + self.task_group.spawn( + ex.clone(), + selfc.listen_loop(addr.clone(), port), + |res| async move { + if let TaskResult::Completed(Err(err)) = res { + error!("Listen loop stopped: {err}"); + } + }, + ); + } + + let selfc = self.clone(); + self.task_group.spawn( + ex.clone(), + selfc.refresh_loop(ex.clone()), + |res| async move { + if let TaskResult::Completed(Err(err)) = res { + error!("Refresh loop stopped: {err}"); + } + }, + ); + + Ok(()) + } + + /// Set the resolved listen endpoint. 
+ pub async fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) { + if let Some(endpoint) = &self.listen_endpoint { + *endpoint.write().await = resolved_endpoint.clone(); + } + } + + /// Shuts down the refresh service + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + /// Initiates periodic refreshing of the routing table. This function will + /// select 8 random entries from each bucket in the routing table and start + /// sending Ping messages to the entries. + async fn refresh_loop(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + let mut timer = Timer::interval(Duration::from_secs(self.config.refresh_interval)); + loop { + timer.next().await; + trace!("Start refreshing the routing table..."); + + self.monitor + .notify(&DiscoveryEvent::RefreshStarted.into()) + .await; + + let table = self.table.lock().await; + let mut entries: Vec<BucketEntry> = vec![]; + for bucket in table.iter() { + for entry in bucket.random_iter(8) { + entries.push(entry.clone()) + } + } + drop(table); + + self.clone().do_refresh(&entries, ex.clone()).await; + } + } + + /// Iterates over the entries and spawns a new task for each entry to + /// initiate a connection attempt to that entry. + async fn do_refresh(self: Arc<Self>, entries: &[BucketEntry], ex: Executor<'_>) { + for chunk in entries.chunks(16) { + let mut tasks = Vec::new(); + for bucket_entry in chunk { + if bucket_entry.is_connected() { + continue; + } + + if bucket_entry.failures >= MAX_FAILURES { + self.table + .lock() + .await + .remove_entry(&bucket_entry.entry.key); + return; + } + + tasks.push(ex.spawn(self.clone().refresh_entry(bucket_entry.clone()))) + } + + for task in tasks { + task.await; + } + } + } + + /// Initiates refresh for a specific entry within the routing table. It + /// updates the routing table according to the result. 
+ async fn refresh_entry(self: Arc<Self>, bucket_entry: BucketEntry) { + let key = &bucket_entry.entry.key; + match self.connect(&bucket_entry.entry).await { + Ok(_) => { + self.table.lock().await.update_entry(key, PENDING_ENTRY); + } + Err(err) => { + trace!("Failed to refresh entry {:?}: {err}", key); + let table = &mut self.table.lock().await; + if bucket_entry.failures >= MAX_FAILURES { + table.remove_entry(key); + return; + } + table.update_entry(key, UNREACHABLE_ENTRY); + } + } + } + + /// Initiates a UDP connection with the entry and attempts to send a Ping + /// message. If it fails, it retries according to the allowed retries + /// specified in the Config, with backoff between each retry. + async fn connect(&self, entry: &Entry) -> Result<()> { + let mut retry = 0; + let conn = dial_udp(&entry.addr, &entry.discovery_port).await?; + let backoff = Backoff::new(100, 5000); + while retry < self.config.refresh_connect_retries { + match self.send_ping_msg(&conn).await { + Ok(()) => return Ok(()), + Err(Error::KaryonsNet(NetError::Timeout)) => { + retry += 1; + backoff.sleep().await; + } + Err(err) => { + return Err(err); + } + } + } + + Err(NetError::Timeout.into()) + } + + /// Set up a UDP listener and start listening for Ping messages from other + /// peers. 
+ async fn listen_loop(self: Arc<Self>, addr: Addr, port: Port) -> Result<()> { + let endpoint = Endpoint::Udp(addr.clone(), port); + let conn = match listen_udp(&addr, &port).await { + Ok(c) => { + self.monitor + .notify(&ConnEvent::Listening(endpoint.clone()).into()) + .await; + c + } + Err(err) => { + self.monitor + .notify(&ConnEvent::ListenFailed(endpoint.clone()).into()) + .await; + return Err(err.into()); + } + }; + info!("Start listening on {endpoint}"); + + loop { + let res = self.listen_to_ping_msg(&conn).await; + if let Err(err) = res { + trace!("Failed to handle ping msg {err}"); + self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; + } + } + } + + /// Listen to receive a Ping message and respond with a Pong message. + async fn listen_to_ping_msg(&self, conn: &UdpConn) -> Result<()> { + let mut buf = [0; PINGMSG_SIZE]; + let (_, endpoint) = conn.recv_from(&mut buf).await?; + + self.monitor + .notify(&ConnEvent::Accepted(endpoint.clone()).into()) + .await; + + let (ping_msg, _) = decode::<PingMsg>(&buf)?; + + let pong_msg = PongMsg(ping_msg.0); + let buffer = encode(&pong_msg)?; + + conn.send_to(&buffer, &endpoint).await?; + + self.monitor + .notify(&ConnEvent::Disconnected(endpoint.clone()).into()) + .await; + Ok(()) + } + + /// Sends a Ping msg and wait to receive the Pong message. 
+ async fn send_ping_msg(&self, conn: &UdpConn) -> Result<()> { + let mut nonce: [u8; 32] = [0; 32]; + RngCore::fill_bytes(&mut OsRng, &mut nonce); + + let ping_msg = PingMsg(nonce); + let buffer = encode(&ping_msg)?; + conn.send(&buffer).await?; + + let buf = &mut [0; PINGMSG_SIZE]; + let t = Duration::from_secs(self.config.refresh_response_timeout); + timeout(t, conn.recv(buf)).await??; + + let (pong_msg, _) = decode::<PongMsg>(buf)?; + + if ping_msg.0 != pong_msg.0 { + return Err(Error::InvalidPongMsg); + } + + Ok(()) + } +} diff --git a/karyons_p2p/src/error.rs b/karyons_p2p/src/error.rs new file mode 100644 index 0000000..945e90a --- /dev/null +++ b/karyons_p2p/src/error.rs @@ -0,0 +1,82 @@ +use thiserror::Error as ThisError; + +pub type Result<T> = std::result::Result<T, Error>; + +/// Represents Karyons's p2p Error. +#[derive(ThisError, Debug)] +pub enum Error { + #[error("IO Error: {0}")] + IO(#[from] std::io::Error), + + #[error("Unsupported protocol error: {0}")] + UnsupportedProtocol(String), + + #[error("Invalid message error: {0}")] + InvalidMsg(String), + + #[error("Parse error: {0}")] + ParseError(String), + + #[error("Incompatible version error: {0}")] + IncompatibleVersion(String), + + #[error("Config error: {0}")] + Config(String), + + #[error("Peer shutdown")] + PeerShutdown, + + #[error("Invalid Pong Msg")] + InvalidPongMsg, + + #[error("Discovery error: {0}")] + Discovery(&'static str), + + #[error("Lookup error: {0}")] + Lookup(&'static str), + + #[error("Peer already connected")] + PeerAlreadyConnected, + + #[error("Channel Send Error: {0}")] + ChannelSend(String), + + #[error("Channel Receive Error: {0}")] + ChannelRecv(String), + + #[error("CORE::ERROR : {0}")] + KaryonsCore(#[from] karyons_core::error::Error), + + #[error("NET::ERROR : {0}")] + KaryonsNet(#[from] karyons_net::NetError), +} + +impl<T> From<smol::channel::SendError<T>> for Error { + fn from(error: smol::channel::SendError<T>) -> Self { + Error::ChannelSend(error.to_string()) 
+ } +} + +impl From<smol::channel::RecvError> for Error { + fn from(error: smol::channel::RecvError) -> Self { + Error::ChannelRecv(error.to_string()) + } +} + +impl From<std::num::ParseIntError> for Error { + fn from(error: std::num::ParseIntError) -> Self { + Error::ParseError(error.to_string()) + } +} + +impl From<std::num::ParseFloatError> for Error { + fn from(error: std::num::ParseFloatError) -> Self { + Error::ParseError(error.to_string()) + } +} + +impl From<semver::Error> for Error { + fn from(error: semver::Error) -> Self { + Error::ParseError(error.to_string()) + } +} diff --git a/karyons_p2p/src/io_codec.rs b/karyons_p2p/src/io_codec.rs new file mode 100644 index 0000000..4515832 --- /dev/null +++ b/karyons_p2p/src/io_codec.rs @@ -0,0 +1,102 @@ +use std::time::Duration; + +use bincode::{Decode, Encode}; + +use karyons_core::{ + async_utils::timeout, + utils::{decode, encode, encode_into_slice}, +}; + +use karyons_net::{Connection, NetError}; + +use crate::{ + message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, + Error, Result, +}; + +pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} +impl<T: Encode + Decode + std::fmt::Debug> CodecMsg for T {} + +/// I/O codec working with generic network connections. +/// +/// It is responsible for both decoding data received from the network and +/// encoding data before sending it. +pub struct IOCodec { + conn: Box<dyn Connection>, +} + +impl IOCodec { + /// Creates a new IOCodec. + pub fn new(conn: Box<dyn Connection>) -> Self { + Self { conn } + } + + /// Reads a message of type `NetMsg` from the connection. + /// + /// It reads the first 6 bytes as the header of the message, then reads + /// and decodes the remaining message data based on the determined header. 
+ pub async fn read(&self) -> Result<NetMsg> { + // Read 6 bytes to get the header of the incoming message + let mut buf = [0; MSG_HEADER_SIZE]; + self.conn.recv(&mut buf).await?; + + // Decode the header from bytes to NetMsgHeader + let (header, _) = decode::<NetMsgHeader>(&buf)?; + + if header.payload_size > MAX_ALLOWED_MSG_SIZE { + return Err(Error::InvalidMsg( + "Message exceeds the maximum allowed size".to_string(), + )); + } + + // Create a buffer to hold the message based on its length + let mut payload = vec![0; header.payload_size as usize]; + self.conn.recv(&mut payload).await?; + + Ok(NetMsg { header, payload }) + } + + /// Writes a message of type `T` to the connection. + /// + /// Before appending the actual message payload, it calculates the length of + /// the encoded message in bytes and appends this length to the message header. + pub async fn write<T: CodecMsg>(&self, command: NetMsgCmd, msg: &T) -> Result<()> { + let payload = encode(msg)?; + + // Create a buffer to hold the message header (6 bytes) + let header_buf = &mut [0; MSG_HEADER_SIZE]; + let header = NetMsgHeader { + command, + payload_size: payload.len() as u32, + }; + encode_into_slice(&header, header_buf)?; + + let mut buffer = vec![]; + // Append the header bytes to the buffer + buffer.extend_from_slice(header_buf); + // Append the message payload to the buffer + buffer.extend_from_slice(&payload); + + self.conn.send(&buffer).await?; + Ok(()) + } + + /// Reads a message of type `NetMsg` with the given timeout. + pub async fn read_timeout(&self, duration: Duration) -> Result<NetMsg> { + timeout(duration, self.read()) + .await + .map_err(|_| NetError::Timeout)? + } + + /// Writes a message of type `T` with the given timeout. + pub async fn write_timeout<T: CodecMsg>( + &self, + command: NetMsgCmd, + msg: &T, + duration: Duration, + ) -> Result<()> { + timeout(duration, self.write(command, msg)) + .await + .map_err(|_| NetError::Timeout)? 
+ } +} diff --git a/karyons_p2p/src/lib.rs b/karyons_p2p/src/lib.rs new file mode 100644 index 0000000..08ba059 --- /dev/null +++ b/karyons_p2p/src/lib.rs @@ -0,0 +1,27 @@ +mod backend; +mod config; +mod discovery; +mod error; +mod io_codec; +mod message; +mod net; +mod peer; +mod peer_pool; +mod protocols; +mod routing_table; +mod utils; + +/// Responsible for network and system monitoring. +/// [`Read More`](./monitor/struct.Monitor.html) +pub mod monitor; +/// Defines the protocol trait. +/// [`Read More`](./protocol/trait.Protocol.html) +pub mod protocol; + +pub use backend::{ArcBackend, Backend}; +pub use config::Config; +pub use error::Error as P2pError; +pub use peer::{ArcPeer, PeerID}; +pub use utils::Version; + +use error::{Error, Result}; diff --git a/karyons_p2p/src/message.rs b/karyons_p2p/src/message.rs new file mode 100644 index 0000000..833f6f4 --- /dev/null +++ b/karyons_p2p/src/message.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; + +use bincode::{Decode, Encode}; + +use karyons_net::{Addr, Port}; + +use crate::{protocol::ProtocolID, routing_table::Entry, utils::VersionInt, PeerID}; + +/// The size of the message header, in bytes. +pub const MSG_HEADER_SIZE: usize = 6; + +/// The maximum allowed size for a message in bytes. +pub const MAX_ALLOWED_MSG_SIZE: u32 = 1000000; + +/// Defines the main message in the Karyon P2P network. +/// +/// This message structure consists of a header and payload, where the header +/// typically contains essential information about the message, and the payload +/// contains the actual data being transmitted. +#[derive(Decode, Encode, Debug, Clone)] +pub struct NetMsg { + pub header: NetMsgHeader, + pub payload: Vec<u8>, +} + +/// Represents the header of a message. +#[derive(Decode, Encode, Debug, Clone)] +pub struct NetMsgHeader { + pub command: NetMsgCmd, + pub payload_size: u32, +} + +/// Defines message commands. 
/// Pong message with a nonce.
get_msg_payload { + ($a:ident, $b:expr) => { + if let NetMsgCmd::$a = $b.header.command { + $b.payload + } else { + return Err(Error::InvalidMsg(format!("Unexpected msg{:?}", $b))); + } + }; +} + +pub(super) use get_msg_payload; + +impl From<Entry> for PeerMsg { + fn from(entry: Entry) -> PeerMsg { + PeerMsg { + peer_id: PeerID(entry.key), + addr: entry.addr, + port: entry.port, + discovery_port: entry.discovery_port, + } + } +} + +impl From<PeerMsg> for Entry { + fn from(peer: PeerMsg) -> Entry { + Entry { + key: peer.peer_id.0, + addr: peer.addr, + port: peer.port, + discovery_port: peer.discovery_port, + } + } +} diff --git a/karyons_p2p/src/monitor.rs b/karyons_p2p/src/monitor.rs new file mode 100644 index 0000000..ee0bf44 --- /dev/null +++ b/karyons_p2p/src/monitor.rs @@ -0,0 +1,154 @@ +use std::fmt; + +use crate::PeerID; + +use karyons_core::pubsub::{ArcPublisher, Publisher, Subscription}; + +use karyons_net::Endpoint; + +/// Responsible for network and system monitoring. +/// +/// It use pub-sub pattern to notify the subscribers with new events. +/// +/// # Example +/// +/// ``` +/// use karyons_p2p::{Config, Backend, PeerID}; +/// async { +/// +/// let backend = Backend::new(PeerID::random(), Config::default()); +/// +/// // Create a new Subscription +/// let sub = backend.monitor().await; +/// +/// let event = sub.recv().await; +/// }; +/// ``` +pub struct Monitor { + inner: ArcPublisher<MonitorEvent>, +} + +impl Monitor { + /// Creates a new Monitor + pub(crate) fn new() -> Monitor { + Self { + inner: Publisher::new(), + } + } + + /// Sends a new monitor event to all subscribers. + pub async fn notify(&self, event: &MonitorEvent) { + self.inner.notify(event).await; + } + + /// Subscribes to listen to new events. + pub async fn subscribe(&self) -> Subscription<MonitorEvent> { + self.inner.subscribe().await + } +} + +/// Defines various type of event that can be monitored. 
+#[derive(Clone, Debug)] +pub enum MonitorEvent { + Conn(ConnEvent), + PeerPool(PeerPoolEvent), + Discovery(DiscoveryEvent), +} + +/// Defines connection-related events. +#[derive(Clone, Debug)] +pub enum ConnEvent { + Connected(Endpoint), + ConnectRetried(Endpoint), + ConnectFailed(Endpoint), + Accepted(Endpoint), + AcceptFailed, + Disconnected(Endpoint), + Listening(Endpoint), + ListenFailed(Endpoint), +} + +/// Defines `PeerPool` events. +#[derive(Clone, Debug)] +pub enum PeerPoolEvent { + NewPeer(PeerID), + RemovePeer(PeerID), +} + +/// Defines `Discovery` events. +#[derive(Clone, Debug)] +pub enum DiscoveryEvent { + LookupStarted(Endpoint), + LookupFailed(Endpoint), + LookupSucceeded(Endpoint, usize), + RefreshStarted, +} + +impl fmt::Display for MonitorEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let val = match self { + MonitorEvent::Conn(e) => format!("Connection Event: {e}"), + MonitorEvent::PeerPool(e) => format!("PeerPool Event: {e}"), + MonitorEvent::Discovery(e) => format!("Discovery Event: {e}"), + }; + write!(f, "{}", val) + } +} + +impl fmt::Display for ConnEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let val = match self { + ConnEvent::Connected(endpoint) => format!("Connected: {endpoint}"), + ConnEvent::ConnectFailed(endpoint) => format!("ConnectFailed: {endpoint}"), + ConnEvent::ConnectRetried(endpoint) => format!("ConnectRetried: {endpoint}"), + ConnEvent::AcceptFailed => "AcceptFailed".to_string(), + ConnEvent::Accepted(endpoint) => format!("Accepted: {endpoint}"), + ConnEvent::Disconnected(endpoint) => format!("Disconnected: {endpoint}"), + ConnEvent::Listening(endpoint) => format!("Listening: {endpoint}"), + ConnEvent::ListenFailed(endpoint) => format!("ListenFailed: {endpoint}"), + }; + write!(f, "{}", val) + } +} + +impl fmt::Display for PeerPoolEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let val = match self { + PeerPoolEvent::NewPeer(pid) => format!("NewPeer: {pid}"), + 
PeerPoolEvent::RemovePeer(pid) => format!("RemovePeer: {pid}"), + }; + write!(f, "{}", val) + } +} + +impl fmt::Display for DiscoveryEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let val = match self { + DiscoveryEvent::LookupStarted(endpoint) => format!("LookupStarted: {endpoint}"), + DiscoveryEvent::LookupFailed(endpoint) => format!("LookupFailed: {endpoint}"), + DiscoveryEvent::LookupSucceeded(endpoint, len) => { + format!("LookupSucceeded: {endpoint} {len}") + } + DiscoveryEvent::RefreshStarted => "RefreshStarted".to_string(), + }; + write!(f, "{}", val) + } +} + +impl From<ConnEvent> for MonitorEvent { + fn from(val: ConnEvent) -> Self { + MonitorEvent::Conn(val) + } +} + +impl From<PeerPoolEvent> for MonitorEvent { + fn from(val: PeerPoolEvent) -> Self { + MonitorEvent::PeerPool(val) + } +} + +impl From<DiscoveryEvent> for MonitorEvent { + fn from(val: DiscoveryEvent) -> Self { + MonitorEvent::Discovery(val) + } +} diff --git a/karyons_p2p/src/net/connection_queue.rs b/karyons_p2p/src/net/connection_queue.rs new file mode 100644 index 0000000..fbc4bfc --- /dev/null +++ b/karyons_p2p/src/net/connection_queue.rs @@ -0,0 +1,52 @@ +use std::sync::Arc; + +use smol::{channel::Sender, lock::Mutex}; + +use karyons_core::async_utils::CondVar; + +use karyons_net::Conn; + +use crate::net::ConnDirection; + +pub struct NewConn { + pub direction: ConnDirection, + pub conn: Conn, + pub disconnect_signal: Sender<()>, +} + +/// Connection queue +pub struct ConnQueue { + queue: Mutex<Vec<NewConn>>, + conn_available: CondVar, +} + +impl ConnQueue { + pub fn new() -> Arc<Self> { + Arc::new(Self { + queue: Mutex::new(Vec::new()), + conn_available: CondVar::new(), + }) + } + + /// Push a connection into the queue and wait for the disconnect signal + pub async fn handle(&self, conn: Conn, direction: ConnDirection) { + let (disconnect_signal, chan) = smol::channel::bounded(1); + let new_conn = NewConn { + direction, + conn, + disconnect_signal, + }; + 
self.queue.lock().await.push(new_conn); + self.conn_available.signal(); + let _ = chan.recv().await; + } + + /// Receive the next connection in the queue + pub async fn next(&self) -> NewConn { + let mut queue = self.queue.lock().await; + while queue.is_empty() { + queue = self.conn_available.wait(queue).await; + } + queue.pop().unwrap() + } +} diff --git a/karyons_p2p/src/net/connector.rs b/karyons_p2p/src/net/connector.rs new file mode 100644 index 0000000..72dc0d8 --- /dev/null +++ b/karyons_p2p/src/net/connector.rs @@ -0,0 +1,125 @@ +use std::{future::Future, sync::Arc}; + +use log::{trace, warn}; + +use karyons_core::{ + async_utils::{Backoff, TaskGroup, TaskResult}, + Executor, +}; +use karyons_net::{dial, Conn, Endpoint, NetError}; + +use crate::{ + monitor::{ConnEvent, Monitor}, + Result, +}; + +use super::slots::ConnectionSlots; + +/// Responsible for creating outbound connections with other peers. +pub struct Connector { + /// Managing spawned tasks. + task_group: TaskGroup, + + /// Manages available outbound slots. + connection_slots: Arc<ConnectionSlots>, + + /// The maximum number of retries allowed before successfully + /// establishing a connection. + max_retries: usize, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl Connector { + /// Creates a new Connector + pub fn new( + max_retries: usize, + connection_slots: Arc<ConnectionSlots>, + monitor: Arc<Monitor>, + ) -> Arc<Self> { + Arc::new(Self { + task_group: TaskGroup::new(), + monitor, + connection_slots, + max_retries, + }) + } + + /// Shuts down the connector + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + /// Establish a connection to the specified `endpoint`. If the connection + /// attempt fails, it performs a backoff and retries until the maximum allowed + /// number of retries is exceeded. On a successful connection, it returns a + /// `Conn` instance. + /// + /// This method will block until it finds an available slot. 
+ pub async fn connect(&self, endpoint: &Endpoint) -> Result<Conn> { + self.connection_slots.wait_for_slot().await; + self.connection_slots.add(); + + let mut retry = 0; + let backoff = Backoff::new(500, 2000); + while retry < self.max_retries { + let conn_result = dial(endpoint).await; + + if let Ok(conn) = conn_result { + self.monitor + .notify(&ConnEvent::Connected(endpoint.clone()).into()) + .await; + return Ok(conn); + } + + self.monitor + .notify(&ConnEvent::ConnectRetried(endpoint.clone()).into()) + .await; + + backoff.sleep().await; + + warn!("try to reconnect {endpoint}"); + retry += 1; + } + + self.monitor + .notify(&ConnEvent::ConnectFailed(endpoint.clone()).into()) + .await; + + self.connection_slots.remove().await; + Err(NetError::Timeout.into()) + } + + /// Establish a connection to the given `endpoint`. For each new connection, + /// it invokes the provided `callback`, and pass the connection to the callback. + pub async fn connect_with_cback<'a, Fut>( + self: &Arc<Self>, + ex: Executor<'a>, + endpoint: &Endpoint, + callback: impl FnOnce(Conn) -> Fut + Send + 'a, + ) -> Result<()> + where + Fut: Future<Output = Result<()>> + Send + 'a, + { + let conn = self.connect(endpoint).await?; + + let selfc = self.clone(); + let endpoint = endpoint.clone(); + let on_disconnect = |res| async move { + if let TaskResult::Completed(Err(err)) = res { + trace!("Outbound connection dropped: {err}"); + } + selfc + .monitor + .notify(&ConnEvent::Disconnected(endpoint.clone()).into()) + .await; + selfc.connection_slots.remove().await; + }; + + self.task_group + .spawn(ex.clone(), callback(conn), on_disconnect); + + Ok(()) + } +} diff --git a/karyons_p2p/src/net/listener.rs b/karyons_p2p/src/net/listener.rs new file mode 100644 index 0000000..d1a7bfb --- /dev/null +++ b/karyons_p2p/src/net/listener.rs @@ -0,0 +1,142 @@ +use std::{future::Future, sync::Arc}; + +use log::{error, info, trace}; + +use karyons_core::{ + async_utils::{TaskGroup, TaskResult}, + Executor, +}; + 
    /// Starts a listener on the given `endpoint`. For each incoming connection
    /// that is accepted, it invokes the provided `callback`, and pass the
    /// connection to the callback.
    ///
    /// Returns the resolved listening endpoint.
    pub async fn start<'a, Fut>(
        self: &Arc<Self>,
        ex: Executor<'a>,
        endpoint: Endpoint,
        // https://github.com/rust-lang/rfcs/pull/2132
        callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'a,
    ) -> Result<Endpoint>
    where
        Fut: Future<Output = Result<()>> + Send + 'a,
    {
        // Bind first; report success or failure to the monitor either way.
        let listener = match listen(&endpoint).await {
            Ok(listener) => {
                self.monitor
                    .notify(&ConnEvent::Listening(endpoint.clone()).into())
                    .await;
                listener
            }
            Err(err) => {
                error!("Failed to listen on {endpoint}: {err}");
                self.monitor
                    .notify(&ConnEvent::ListenFailed(endpoint).into())
                    .await;
                return Err(err.into());
            }
        };

        // Ask the listener for the endpoint it actually bound to — presumably
        // this surfaces an OS-assigned port when one wasn't given; verify
        // against karyons_net::Listener::local_endpoint.
        let resolved_endpoint = listener.local_endpoint()?;

        info!("Start listening on {endpoint}");

        // Run the accept loop as a background task; the completion handler
        // only logs, since a stopped listen loop ends this listener.
        let selfc = self.clone();
        self.task_group.spawn(
            ex.clone(),
            selfc.listen_loop(ex.clone(), listener, callback),
            |res| async move {
                if let TaskResult::Completed(Err(err)) = res {
                    error!("Listen loop stopped: {endpoint} {err}");
                }
            },
        );
        Ok(resolved_endpoint)
    }
listener + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + async fn listen_loop<'a, Fut>( + self: Arc<Self>, + ex: Executor<'a>, + listener: Box<dyn NetListener>, + callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'a, + ) -> Result<()> + where + Fut: Future<Output = Result<()>> + Send + 'a, + { + loop { + // Wait for an available inbound slot. + self.connection_slots.wait_for_slot().await; + let result = listener.accept().await; + + let conn = match result { + Ok(c) => { + self.monitor + .notify(&ConnEvent::Accepted(c.peer_endpoint()?).into()) + .await; + c + } + Err(err) => { + error!("Failed to accept a new connection: {err}"); + self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; + return Err(err.into()); + } + }; + + self.connection_slots.add(); + + let selfc = self.clone(); + let endpoint = conn.peer_endpoint()?; + let on_disconnect = |res| async move { + if let TaskResult::Completed(Err(err)) = res { + trace!("Inbound connection dropped: {err}"); + } + selfc + .monitor + .notify(&ConnEvent::Disconnected(endpoint).into()) + .await; + selfc.connection_slots.remove().await; + }; + + let callback = callback.clone(); + self.task_group + .spawn(ex.clone(), callback(conn), on_disconnect); + } + } +} diff --git a/karyons_p2p/src/net/mod.rs b/karyons_p2p/src/net/mod.rs new file mode 100644 index 0000000..9cdc748 --- /dev/null +++ b/karyons_p2p/src/net/mod.rs @@ -0,0 +1,27 @@ +mod connection_queue; +mod connector; +mod listener; +mod slots; + +pub use connection_queue::ConnQueue; +pub use connector::Connector; +pub use listener::Listener; +pub use slots::ConnectionSlots; + +use std::fmt; + +/// Defines the direction of a network connection. 
/// Defines the direction of a network connection.
#[derive(Clone, Debug)]
pub enum ConnDirection {
    Inbound,
    Outbound,
}

impl fmt::Display for ConnDirection {
    /// Writes the direction as a plain label ("Inbound" / "Outbound").
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::Inbound => "Inbound",
            Self::Outbound => "Outbound",
        };
        f.write_str(label)
    }
}
    pub async fn wait_for_slot(&self) {
        // Fast path: a slot is already available.
        if self.slots.load(Ordering::SeqCst) < self.max_slots {
            return;
        }

        // Wait for a signal
        // Blocks until remove() signals that a slot was freed, then resets
        // the condvar for the next waiter.
        // NOTE(review): the occupancy condition is not re-checked after
        // waking; if another task claims the freed slot first (or several
        // waiters are woken), occupancy can briefly exceed max_slots —
        // confirm this is acceptable or loop on the check.
        self.signal.wait().await;
        self.signal.reset().await;
    }
+ task_group: TaskGroup, +} + +impl Peer { + /// Creates a new peer + pub fn new( + peer_pool: WeakPeerPool, + id: &PeerID, + io_codec: IOCodec, + remote_endpoint: Endpoint, + conn_direction: ConnDirection, + ) -> ArcPeer { + Arc::new(Peer { + id: id.clone(), + peer_pool, + io_codec, + protocol_ids: RwLock::new(Vec::new()), + remote_endpoint, + conn_direction, + protocol_events: EventSys::new(), + task_group: TaskGroup::new(), + stop_chan: channel::bounded(1), + }) + } + + /// Run the peer + pub async fn run(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + self.start_protocols(ex.clone()).await; + self.read_loop().await + } + + /// Send a message to the peer connection using the specified protocol. + pub async fn send<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> { + let payload = encode(msg)?; + + let proto_msg = ProtocolMsg { + protocol_id: protocol_id.to_string(), + payload: payload.to_vec(), + }; + + self.io_codec.write(NetMsgCmd::Protocol, &proto_msg).await?; + Ok(()) + } + + /// Broadcast a message to all connected peers using the specified protocol. + pub async fn broadcast<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) { + self.peer_pool().broadcast(protocol_id, msg).await; + } + + /// Shuts down the peer + pub async fn shutdown(&self) { + trace!("peer {} start shutting down", self.id); + + // Send shutdown event to all protocols + for protocol_id in self.protocol_ids.read().await.iter() { + self.protocol_events + .emit_by_topic(protocol_id, &ProtocolEvent::Shutdown) + .await; + } + + // Send a stop signal to the read loop + // + // No need to handle the error here; a dropped channel and + // sending a stop signal have the same effect. 
+ let _ = self.stop_chan.0.try_send(Ok(())); + + // No need to handle the error here + let _ = self + .io_codec + .write(NetMsgCmd::Shutdown, &ShutdownMsg(0)) + .await; + + // Force shutting down + self.task_group.cancel().await; + } + + /// Check if the connection is Inbound + #[inline] + pub fn is_inbound(&self) -> bool { + match self.conn_direction { + ConnDirection::Inbound => true, + ConnDirection::Outbound => false, + } + } + + /// Returns the direction of the connection, which can be either `Inbound` + /// or `Outbound`. + #[inline] + pub fn direction(&self) -> &ConnDirection { + &self.conn_direction + } + + /// Returns the remote endpoint for the peer + #[inline] + pub fn remote_endpoint(&self) -> &Endpoint { + &self.remote_endpoint + } + + /// Return the peer's ID + #[inline] + pub fn id(&self) -> &PeerID { + &self.id + } + + /// Returns the `Config` instance. + pub fn config(&self) -> Arc<Config> { + self.peer_pool().config.clone() + } + + /// Registers a listener for the given Protocol `P`. + pub async fn register_listener<P: Protocol>(&self) -> EventListener<ProtocolID, ProtocolEvent> { + self.protocol_events.register(&P::id()).await + } + + /// Start a read loop to handle incoming messages from the peer connection. 
+ async fn read_loop(&self) -> Result<()> { + loop { + let fut = select(self.stop_chan.1.recv(), self.io_codec.read()).await; + let result = match fut { + Either::Left(stop_signal) => { + trace!("Peer {} received a stop signal", self.id); + return stop_signal?; + } + Either::Right(result) => result, + }; + + let msg = result?; + + match msg.header.command { + NetMsgCmd::Protocol => { + let msg: ProtocolMsg = decode(&msg.payload)?.0; + + if !self.protocol_ids.read().await.contains(&msg.protocol_id) { + return Err(Error::UnsupportedProtocol(msg.protocol_id)); + } + + let proto_id = &msg.protocol_id; + let msg = ProtocolEvent::Message(msg.payload); + self.protocol_events.emit_by_topic(proto_id, &msg).await; + } + NetMsgCmd::Shutdown => { + return Err(Error::PeerShutdown); + } + command => return Err(Error::InvalidMsg(format!("Unexpected msg {:?}", command))), + } + } + } + + /// Start running the protocols for this peer connection. + async fn start_protocols(self: &Arc<Self>, ex: Executor<'_>) { + for (protocol_id, constructor) in self.peer_pool().protocols.read().await.iter() { + trace!("peer {} start protocol {protocol_id}", self.id); + let protocol = constructor(self.clone()); + + self.protocol_ids.write().await.push(protocol_id.clone()); + + let selfc = self.clone(); + let exc = ex.clone(); + let proto_idc = protocol_id.clone(); + + let on_failure = |result: TaskResult<Result<()>>| async move { + if let TaskResult::Completed(res) = result { + if res.is_err() { + error!("protocol {} stopped", proto_idc); + } + // Send a stop signal to read loop + let _ = selfc.stop_chan.0.try_send(res); + } + }; + + self.task_group + .spawn(ex.clone(), protocol.start(exc), on_failure); + } + } + + fn peer_pool(&self) -> ArcPeerPool { + self.peer_pool.upgrade().unwrap() + } +} diff --git a/karyons_p2p/src/peer/peer_id.rs b/karyons_p2p/src/peer/peer_id.rs new file mode 100644 index 0000000..c8aec7d --- /dev/null +++ b/karyons_p2p/src/peer/peer_id.rs @@ -0,0 +1,41 @@ +use 
bincode::{Decode, Encode}; +use rand::{rngs::OsRng, RngCore}; +use sha2::{Digest, Sha256}; + +/// Represents a unique identifier for a peer. +#[derive(Clone, Debug, Eq, PartialEq, Hash, Decode, Encode)] +pub struct PeerID(pub [u8; 32]); + +impl std::fmt::Display for PeerID { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let id = self.0[0..8] + .iter() + .map(|b| format!("{:x}", b)) + .collect::<Vec<String>>() + .join(""); + + write!(f, "{}", id) + } +} + +impl PeerID { + /// Creates a new PeerID. + pub fn new(src: &[u8]) -> Self { + let mut hasher = Sha256::new(); + hasher.update(src); + Self(hasher.finalize().into()) + } + + /// Generates a random PeerID. + pub fn random() -> Self { + let mut id: [u8; 32] = [0; 32]; + OsRng.fill_bytes(&mut id); + Self(id) + } +} + +impl From<[u8; 32]> for PeerID { + fn from(b: [u8; 32]) -> Self { + PeerID(b) + } +} diff --git a/karyons_p2p/src/peer_pool.rs b/karyons_p2p/src/peer_pool.rs new file mode 100644 index 0000000..eac4d3d --- /dev/null +++ b/karyons_p2p/src/peer_pool.rs @@ -0,0 +1,337 @@ +use std::{ + collections::HashMap, + sync::{Arc, Weak}, + time::Duration, +}; + +use log::{error, info, trace, warn}; +use smol::{ + channel::Sender, + lock::{Mutex, RwLock}, +}; + +use karyons_core::{ + async_utils::{TaskGroup, TaskResult}, + utils::decode, + Executor, +}; + +use karyons_net::Conn; + +use crate::{ + config::Config, + io_codec::{CodecMsg, IOCodec}, + message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, + monitor::{Monitor, PeerPoolEvent}, + net::ConnDirection, + net::ConnQueue, + peer::{ArcPeer, Peer, PeerID}, + protocol::{Protocol, ProtocolConstructor, ProtocolID}, + protocols::PingProtocol, + utils::{version_match, Version, VersionInt}, + Error, Result, +}; + +pub type ArcPeerPool = Arc<PeerPool>; +pub type WeakPeerPool = Weak<PeerPool>; + +pub struct PeerPool { + /// Peer's ID + pub id: PeerID, + + /// Connection queue + conn_queue: Arc<ConnQueue>, + + /// Holds the running peers. 
+ peers: Mutex<HashMap<PeerID, ArcPeer>>, + + /// Hashmap contains protocol constructors. + pub(crate) protocols: RwLock<HashMap<ProtocolID, Box<ProtocolConstructor>>>, + + /// Hashmap contains protocol IDs and their versions. + protocol_versions: Arc<RwLock<HashMap<ProtocolID, Version>>>, + + /// Managing spawned tasks. + task_group: TaskGroup, + + /// The Configuration for the P2P network. + pub config: Arc<Config>, + + /// Responsible for network and system monitoring. + monitor: Arc<Monitor>, +} + +impl PeerPool { + /// Creates a new PeerPool + pub fn new( + id: &PeerID, + conn_queue: Arc<ConnQueue>, + config: Arc<Config>, + monitor: Arc<Monitor>, + ) -> Arc<Self> { + let protocols = RwLock::new(HashMap::new()); + let protocol_versions = Arc::new(RwLock::new(HashMap::new())); + + Arc::new(Self { + id: id.clone(), + conn_queue, + peers: Mutex::new(HashMap::new()), + protocols, + protocol_versions, + task_group: TaskGroup::new(), + monitor, + config, + }) + } + + /// Start + pub async fn start(self: &Arc<Self>, ex: Executor<'_>) -> Result<()> { + self.setup_protocols().await?; + let selfc = self.clone(); + self.task_group + .spawn(ex.clone(), selfc.listen_loop(ex.clone()), |_| async {}); + Ok(()) + } + + /// Listens to a new connection from the connection queue + pub async fn listen_loop(self: Arc<Self>, ex: Executor<'_>) { + loop { + let new_conn = self.conn_queue.next().await; + let disconnect_signal = new_conn.disconnect_signal; + + let result = self + .new_peer( + new_conn.conn, + &new_conn.direction, + disconnect_signal.clone(), + ex.clone(), + ) + .await; + + if result.is_err() { + let _ = disconnect_signal.send(()).await; + } + } + } + + /// Shuts down + pub async fn shutdown(&self) { + for (_, peer) in self.peers.lock().await.iter() { + peer.shutdown().await; + } + + self.task_group.cancel().await; + } + + /// Attach a custom protocol to the network + pub async fn attach_protocol<P: Protocol>(&self, c: Box<ProtocolConstructor>) -> Result<()> { + let 
protocol_versions = &mut self.protocol_versions.write().await; + let protocols = &mut self.protocols.write().await; + + protocol_versions.insert(P::id(), P::version()?); + protocols.insert(P::id(), Box::new(c) as Box<ProtocolConstructor>); + Ok(()) + } + + /// Returns the number of currently connected peers. + pub async fn peers_len(&self) -> usize { + self.peers.lock().await.len() + } + + /// Broadcast a message to all connected peers using the specified protocol. + pub async fn broadcast<T: CodecMsg>(&self, proto_id: &ProtocolID, msg: &T) { + for (pid, peer) in self.peers.lock().await.iter() { + if let Err(err) = peer.send(proto_id, msg).await { + error!("failed to send msg to {pid}: {err}"); + continue; + } + } + } + + /// Add a new peer to the peer list. + pub async fn new_peer( + self: &Arc<Self>, + conn: Conn, + conn_direction: &ConnDirection, + disconnect_signal: Sender<()>, + ex: Executor<'_>, + ) -> Result<PeerID> { + let endpoint = conn.peer_endpoint()?; + let io_codec = IOCodec::new(conn); + + // Do a handshake with a connection before creating a new peer. 
+ let pid = self.do_handshake(&io_codec, conn_direction).await?; + + // TODO: Consider restricting the subnet for inbound connections + if self.contains_peer(&pid).await { + return Err(Error::PeerAlreadyConnected); + } + + // Create a new peer + let peer = Peer::new( + Arc::downgrade(self), + &pid, + io_codec, + endpoint.clone(), + conn_direction.clone(), + ); + + // Insert the new peer + self.peers.lock().await.insert(pid.clone(), peer.clone()); + + let selfc = self.clone(); + let pid_c = pid.clone(); + let on_disconnect = |result| async move { + if let TaskResult::Completed(_) = result { + if let Err(err) = selfc.remove_peer(&pid_c).await { + error!("Failed to remove peer {pid_c}: {err}"); + } + let _ = disconnect_signal.send(()).await; + } + }; + + self.task_group + .spawn(ex.clone(), peer.run(ex.clone()), on_disconnect); + + info!("Add new peer {pid}, direction: {conn_direction}, endpoint: {endpoint}"); + + self.monitor + .notify(&PeerPoolEvent::NewPeer(pid.clone()).into()) + .await; + Ok(pid) + } + + /// Checks if the peer list contains a peer with the given peer id + pub async fn contains_peer(&self, pid: &PeerID) -> bool { + self.peers.lock().await.contains_key(pid) + } + + /// Shuts down the peer and remove it from the peer list. + async fn remove_peer(&self, pid: &PeerID) -> Result<()> { + let mut peers = self.peers.lock().await; + let result = peers.remove(pid); + + drop(peers); + + let peer = match result { + Some(p) => p, + None => return Ok(()), + }; + + peer.shutdown().await; + + self.monitor + .notify(&PeerPoolEvent::RemovePeer(pid.clone()).into()) + .await; + + let endpoint = peer.remote_endpoint(); + let direction = peer.direction(); + + warn!("Peer {pid} removed, direction: {direction}, endpoint: {endpoint}",); + Ok(()) + } + + /// Attach the core protocols. + async fn setup_protocols(&self) -> Result<()> { + self.attach_protocol::<PingProtocol>(Box::new(PingProtocol::new)) + .await + } + + /// Initiate a handshake with a connection. 
+ async fn do_handshake( + &self, + io_codec: &IOCodec, + conn_direction: &ConnDirection, + ) -> Result<PeerID> { + match conn_direction { + ConnDirection::Inbound => { + let pid = self.wait_vermsg(io_codec).await?; + self.send_verack(io_codec).await?; + Ok(pid) + } + ConnDirection::Outbound => { + self.send_vermsg(io_codec).await?; + self.wait_verack(io_codec).await + } + } + } + + /// Send a Version message + async fn send_vermsg(&self, io_codec: &IOCodec) -> Result<()> { + let pids = self.protocol_versions.read().await; + let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); + drop(pids); + + let vermsg = VerMsg { + peer_id: self.id.clone(), + protocols, + version: self.config.version.v.clone(), + }; + + trace!("Send VerMsg"); + io_codec.write(NetMsgCmd::Version, &vermsg).await?; + Ok(()) + } + + /// Wait for a Version message + /// + /// Returns the peer's ID upon successfully receiving the Version message. + async fn wait_vermsg(&self, io_codec: &IOCodec) -> Result<PeerID> { + let timeout = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = io_codec.read_timeout(timeout).await?; + + let payload = get_msg_payload!(Version, msg); + let (vermsg, _) = decode::<VerMsg>(&payload)?; + + if !version_match(&self.config.version.req, &vermsg.version) { + return Err(Error::IncompatibleVersion("system: {}".into())); + } + + self.protocols_match(&vermsg.protocols).await?; + + trace!("Received VerMsg from: {}", vermsg.peer_id); + Ok(vermsg.peer_id) + } + + /// Send a Verack message + async fn send_verack(&self, io_codec: &IOCodec) -> Result<()> { + let verack = VerAckMsg(self.id.clone()); + + trace!("Send VerAckMsg"); + io_codec.write(NetMsgCmd::Verack, &verack).await?; + Ok(()) + } + + /// Wait for a Verack message + /// + /// Returns the peer's ID upon successfully receiving the Verack message. 
+ async fn wait_verack(&self, io_codec: &IOCodec) -> Result<PeerID> { + let timeout = Duration::from_secs(self.config.handshake_timeout); + let msg: NetMsg = io_codec.read_timeout(timeout).await?; + + let payload = get_msg_payload!(Verack, msg); + let (verack, _) = decode::<VerAckMsg>(&payload)?; + + trace!("Received VerAckMsg from: {}", verack.0); + Ok(verack.0) + } + + /// Check if the new connection has compatible protocols. + async fn protocols_match(&self, protocols: &HashMap<ProtocolID, VersionInt>) -> Result<()> { + for (n, pv) in protocols.iter() { + let pids = self.protocol_versions.read().await; + + match pids.get(n) { + Some(v) => { + if !version_match(&v.req, pv) { + return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}"))); + } + } + None => { + return Err(Error::UnsupportedProtocol(n.to_string())); + } + } + } + Ok(()) + } +} diff --git a/karyons_p2p/src/protocol.rs b/karyons_p2p/src/protocol.rs new file mode 100644 index 0000000..515efc6 --- /dev/null +++ b/karyons_p2p/src/protocol.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use karyons_core::{event::EventValue, Executor}; + +use crate::{peer::ArcPeer, utils::Version, Result}; + +pub type ArcProtocol = Arc<dyn Protocol>; + +pub type ProtocolConstructor = dyn Fn(ArcPeer) -> Arc<dyn Protocol> + Send + Sync; + +pub type ProtocolID = String; + +/// Protocol event +#[derive(Debug, Clone)] +pub enum ProtocolEvent { + /// Message event, contains a vector of bytes. + Message(Vec<u8>), + /// Shutdown event signals the protocol to gracefully shut down. + Shutdown, +} + +impl EventValue for ProtocolEvent { + fn id() -> &'static str { + "ProtocolEvent" + } +} + +/// The Protocol trait defines the interface for core protocols +/// and custom protocols. 
+/// +/// # Example +/// ``` +/// use std::sync::Arc; +/// +/// use async_trait::async_trait; +/// use smol::Executor; +/// +/// use karyons_p2p::{ +/// protocol::{ArcProtocol, Protocol, ProtocolID, ProtocolEvent}, +/// Backend, PeerID, Config, Version, P2pError, ArcPeer}; +/// +/// pub struct NewProtocol { +/// peer: ArcPeer, +/// } +/// +/// impl NewProtocol { +/// fn new(peer: ArcPeer) -> ArcProtocol { +/// Arc::new(Self { +/// peer, +/// }) +/// } +/// } +/// +/// #[async_trait] +/// impl Protocol for NewProtocol { +/// async fn start(self: Arc<Self>, ex: Arc<Executor<'_>>) -> Result<(), P2pError> { +/// let listener = self.peer.register_listener::<Self>().await; +/// loop { +/// let event = listener.recv().await.unwrap(); +/// +/// match event { +/// ProtocolEvent::Message(msg) => { +/// println!("{:?}", msg); +/// } +/// ProtocolEvent::Shutdown => { +/// break; +/// } +/// } +/// } +/// +/// listener.cancel().await; +/// Ok(()) +/// } +/// +/// fn version() -> Result<Version, P2pError> { +/// "0.2.0, >0.1.0".parse() +/// } +/// +/// fn id() -> ProtocolID { +/// "NEWPROTOCOLID".into() +/// } +/// } +/// +/// async { +/// let peer_id = PeerID::random(); +/// let config = Config::default(); +/// +/// // Create a new Backend +/// let backend = Backend::new(peer_id, config); +/// +/// // Attach the NewProtocol +/// let c = move |peer| NewProtocol::new(peer); +/// backend.attach_protocol::<NewProtocol>(c).await.unwrap(); +/// }; +/// +/// ``` +#[async_trait] +pub trait Protocol: Send + Sync { + /// Start the protocol + async fn start(self: Arc<Self>, ex: Executor<'_>) -> Result<()>; + + /// Returns the version of the protocol. + fn version() -> Result<Version> + where + Self: Sized; + + /// Returns the unique ProtocolID associated with the protocol. 
+ fn id() -> ProtocolID + where + Self: Sized; +} diff --git a/karyons_p2p/src/protocols/mod.rs b/karyons_p2p/src/protocols/mod.rs new file mode 100644 index 0000000..4a8f6b9 --- /dev/null +++ b/karyons_p2p/src/protocols/mod.rs @@ -0,0 +1,3 @@ +mod ping; + +pub use ping::PingProtocol; diff --git a/karyons_p2p/src/protocols/ping.rs b/karyons_p2p/src/protocols/ping.rs new file mode 100644 index 0000000..b337494 --- /dev/null +++ b/karyons_p2p/src/protocols/ping.rs @@ -0,0 +1,173 @@ +use std::{sync::Arc, time::Duration}; + +use async_trait::async_trait; +use bincode::{Decode, Encode}; +use log::trace; +use rand::{rngs::OsRng, RngCore}; +use smol::{ + channel, + channel::{Receiver, Sender}, + stream::StreamExt, + Timer, +}; + +use karyons_core::{ + async_utils::{select, timeout, Either, TaskGroup, TaskResult}, + event::EventListener, + utils::decode, + Executor, +}; + +use karyons_net::NetError; + +use crate::{ + peer::ArcPeer, + protocol::{ArcProtocol, Protocol, ProtocolEvent, ProtocolID}, + utils::Version, + Result, +}; + +const MAX_FAILUERS: u32 = 3; + +#[derive(Clone, Debug, Encode, Decode)] +enum PingProtocolMsg { + Ping([u8; 32]), + Pong([u8; 32]), +} + +pub struct PingProtocol { + peer: ArcPeer, + ping_interval: u64, + ping_timeout: u64, + task_group: TaskGroup, +} + +impl PingProtocol { + #[allow(clippy::new_ret_no_self)] + pub fn new(peer: ArcPeer) -> ArcProtocol { + let ping_interval = peer.config().ping_interval; + let ping_timeout = peer.config().ping_timeout; + Arc::new(Self { + peer, + ping_interval, + ping_timeout, + task_group: TaskGroup::new(), + }) + } + + async fn recv_loop( + &self, + listener: &EventListener<ProtocolID, ProtocolEvent>, + pong_chan: Sender<[u8; 32]>, + ) -> Result<()> { + loop { + let event = listener.recv().await?; + let msg_payload = match event.clone() { + ProtocolEvent::Message(m) => m, + ProtocolEvent::Shutdown => { + break; + } + }; + + let (msg, _) = decode::<PingProtocolMsg>(&msg_payload)?; + + match msg { + 
PingProtocolMsg::Ping(nonce) => { + trace!("Received Ping message {:?}", nonce); + self.peer + .send(&Self::id(), &PingProtocolMsg::Pong(nonce)) + .await?; + trace!("Send back Pong message {:?}", nonce); + } + PingProtocolMsg::Pong(nonce) => { + pong_chan.send(nonce).await?; + } + } + } + Ok(()) + } + + async fn ping_loop(self: Arc<Self>, chan: Receiver<[u8; 32]>) -> Result<()> { + let mut timer = Timer::interval(Duration::from_secs(self.ping_interval)); + let rng = &mut OsRng; + let mut retry = 0; + + while retry < MAX_FAILUERS { + timer.next().await; + + let mut ping_nonce: [u8; 32] = [0; 32]; + rng.fill_bytes(&mut ping_nonce); + + trace!("Send Ping message {:?}", ping_nonce); + self.peer + .send(&Self::id(), &PingProtocolMsg::Ping(ping_nonce)) + .await?; + + let d = Duration::from_secs(self.ping_timeout); + + // Wait for Pong message + let pong_msg = match timeout(d, chan.recv()).await { + Ok(m) => m?, + Err(_) => { + retry += 1; + continue; + } + }; + + trace!("Received Pong message {:?}", pong_msg); + + if pong_msg != ping_nonce { + retry += 1; + continue; + } + } + + Err(NetError::Timeout.into()) + } +} + +#[async_trait] +impl Protocol for PingProtocol { + async fn start(self: Arc<Self>, ex: Executor<'_>) -> Result<()> { + trace!("Start Ping protocol"); + let (pong_chan, pong_chan_recv) = channel::bounded(1); + let (stop_signal_s, stop_signal) = channel::bounded::<Result<()>>(1); + + let selfc = self.clone(); + self.task_group.spawn( + ex.clone(), + selfc.clone().ping_loop(pong_chan_recv.clone()), + |res| async move { + if let TaskResult::Completed(result) = res { + let _ = stop_signal_s.send(result).await; + } + }, + ); + + let listener = self.peer.register_listener::<Self>().await; + + let result = select(self.recv_loop(&listener, pong_chan), stop_signal.recv()).await; + listener.cancel().await; + self.task_group.cancel().await; + + match result { + Either::Left(res) => { + trace!("Receive loop stopped {:?}", res); + res + } + Either::Right(res) => { + let 
res = res?; + trace!("Ping loop stopped {:?}", res); + res + } + } + } + + fn version() -> Result<Version> { + "0.1.0".parse() + } + + fn id() -> ProtocolID { + "PING".into() + } +} diff --git a/karyons_p2p/src/routing_table/bucket.rs b/karyons_p2p/src/routing_table/bucket.rs new file mode 100644 index 0000000..13edd24 --- /dev/null +++ b/karyons_p2p/src/routing_table/bucket.rs @@ -0,0 +1,123 @@ +use super::{Entry, Key}; + +use rand::{rngs::OsRng, seq::SliceRandom}; + +/// BITFLAGS represent the status of an Entry within a bucket. +pub type EntryStatusFlag = u16; + +/// The entry is connected. +pub const CONNECTED_ENTRY: EntryStatusFlag = 0b00001; + +/// The entry is disconnected. This will increase the failure counter. +pub const DISCONNECTED_ENTRY: EntryStatusFlag = 0b00010; + +/// The entry is ready to reconnect, meaning it has either been added and +/// has no connection attempts, or it has been refreshed. +pub const PENDING_ENTRY: EntryStatusFlag = 0b00100; + +/// The entry is unreachable. This will increase the failure counter. +pub const UNREACHABLE_ENTRY: EntryStatusFlag = 0b01000; + +/// The entry is unstable. This will increase the failure counter. +pub const UNSTABLE_ENTRY: EntryStatusFlag = 0b10000; + +#[allow(dead_code)] +pub const ALL_ENTRY: EntryStatusFlag = 0b11111; + +/// A BucketEntry represents a peer in the routing table. +#[derive(Clone, Debug)] +pub struct BucketEntry { + pub status: EntryStatusFlag, + pub entry: Entry, + pub failures: u32, + pub last_seen: i64, +} + +impl BucketEntry { + pub fn is_connected(&self) -> bool { + self.status ^ CONNECTED_ENTRY == 0 + } + + pub fn is_unreachable(&self) -> bool { + self.status ^ UNREACHABLE_ENTRY == 0 + } + + pub fn is_unstable(&self) -> bool { + self.status ^ UNSTABLE_ENTRY == 0 + } +} + +/// The number of entries that can be stored within a single bucket. +pub const BUCKET_SIZE: usize = 20; + +/// A Bucket represents a group of entries in the routing table. 
+#[derive(Debug, Clone)] +pub struct Bucket { + entries: Vec<BucketEntry>, +} + +impl Bucket { + /// Creates a new empty Bucket + pub fn new() -> Self { + Self { + entries: Vec::with_capacity(BUCKET_SIZE), + } + } + + /// Add an entry to the bucket. + pub fn add(&mut self, entry: &Entry) { + self.entries.push(BucketEntry { + status: PENDING_ENTRY, + entry: entry.clone(), + failures: 0, + last_seen: chrono::Utc::now().timestamp(), + }) + } + + /// Get the number of entries in the bucket. + pub fn len(&self) -> usize { + self.entries.len() + } + + /// Returns an iterator over the entries in the bucket. + pub fn iter(&self) -> impl Iterator<Item = &BucketEntry> { + self.entries.iter() + } + + /// Remove an entry. + pub fn remove(&mut self, key: &Key) { + let position = self.entries.iter().position(|e| &e.entry.key == key); + if let Some(i) = position { + self.entries.remove(i); + } + } + + /// Returns an iterator of entries in random order. + pub fn random_iter(&self, amount: usize) -> impl Iterator<Item = &BucketEntry> { + self.entries.choose_multiple(&mut OsRng, amount) + } + + /// Updates the status of an entry in the bucket identified by the given key. + /// + /// If the key is not found in the bucket, no action is taken. + /// + /// This will also update the last_seen field and increase the failures + /// counter for the bucket entry according to the new status. + pub fn update_entry(&mut self, key: &Key, entry_flag: EntryStatusFlag) { + if let Some(e) = self.entries.iter_mut().find(|e| &e.entry.key == key) { + e.status = entry_flag; + if e.is_unreachable() || e.is_unstable() { + e.failures += 1; + } + + if !e.is_unreachable() { + e.last_seen = chrono::Utc::now().timestamp(); + } + } + } + + /// Check if the bucket contains the given key. 
pub fn contains_key(&self, key: &Key) -> bool {
    self.entries.iter().any(|e| &e.entry.key == key)
}
}
diff --git a/karyons_p2p/src/routing_table/entry.rs b/karyons_p2p/src/routing_table/entry.rs
new file mode 100644
index 0000000..b3f219f
--- /dev/null
+++ b/karyons_p2p/src/routing_table/entry.rs
@@ -0,0 +1,41 @@
use bincode::{Decode, Encode};

use karyons_net::{Addr, Port};

/// Specifies the size of the key, in bytes.
pub const KEY_SIZE: usize = 32;

/// An Entry represents a peer in the routing table.
#[derive(Encode, Decode, Clone, Debug)]
pub struct Entry {
    /// The unique key identifying the peer.
    pub key: Key,
    /// The IP address of the peer.
    pub addr: Addr,
    /// TCP port
    pub port: Port,
    /// UDP/TCP port
    pub discovery_port: Port,
}

impl PartialEq for Entry {
    fn eq(&self, other: &Self) -> bool {
        // XXX this should also compare both addresses (the self.addr == other.addr)
        self.key == other.key
    }
}

/// The unique key identifying the peer.
pub type Key = [u8; KEY_SIZE];

/// Calculates the XOR distance between two provided keys.
///
/// The XOR distance is a metric used in Kademlia to measure the closeness
/// of keys.
pub fn xor_distance(key: &Key, other: &Key) -> Key {
    // FIX: use KEY_SIZE instead of the hard-coded 32 so the buffer always
    // matches the Key type if KEY_SIZE ever changes.
    let mut res = [0; KEY_SIZE];
    for (i, (k, o)) in key.iter().zip(other.iter()).enumerate() {
        res[i] = k ^ o;
    }
    res
}
diff --git a/karyons_p2p/src/routing_table/mod.rs b/karyons_p2p/src/routing_table/mod.rs
new file mode 100644
index 0000000..abf9a08
--- /dev/null
+++ b/karyons_p2p/src/routing_table/mod.rs
@@ -0,0 +1,461 @@
mod bucket;
mod entry;

pub use bucket::{
    Bucket, BucketEntry, EntryStatusFlag, CONNECTED_ENTRY, DISCONNECTED_ENTRY, PENDING_ENTRY,
    UNREACHABLE_ENTRY, UNSTABLE_ENTRY,
};
pub use entry::{xor_distance, Entry, Key};

use rand::{rngs::OsRng, seq::SliceRandom};

use crate::utils::subnet_match;

use bucket::BUCKET_SIZE;
use entry::KEY_SIZE;

/// The total number of buckets in the routing table.
const TABLE_SIZE: usize = 32;

/// The distance limit for the closest buckets.
const DISTANCE_LIMIT: usize = 32;

/// The maximum number of matched subnets allowed within a single bucket.
const MAX_MATCHED_SUBNET_IN_BUCKET: usize = 1;

/// The maximum number of matched subnets across the entire routing table.
const MAX_MATCHED_SUBNET_IN_TABLE: usize = 6;

/// Represents the possible result when adding a new entry.
#[derive(Debug)]
pub enum AddEntryResult {
    /// The entry is added.
    Added,
    /// The entry already exists.
    Exists,
    /// The entry is ignored.
    Ignored,
    /// The entry is restricted and not allowed.
    Restricted,
}

/// This is a modified version of the Kademlia Distributed Hash Table (DHT).
/// https://en.wikipedia.org/wiki/Kademlia
#[derive(Debug)]
pub struct RoutingTable {
    // The local node's key; all distances are measured relative to it.
    key: Key,
    // Fixed set of TABLE_SIZE buckets, indexed by `bucket_index`.
    buckets: Vec<Bucket>,
}

impl RoutingTable {
    /// Creates a new RoutingTable
    pub fn new(key: Key) -> Self {
        let buckets: Vec<Bucket> = (0..TABLE_SIZE).map(|_| Bucket::new()).collect();
        Self { key, buckets }
    }

    /// Adds a new entry to the table and returns a result indicating success,
    /// failure, or restrictions.
    pub fn add_entry(&mut self, entry: Entry) -> AddEntryResult {
        // Determine the index of the bucket where the entry should be placed.
        // `bucket_index` returns None when the key equals the local key
        // (zero distance), so the local node is never added to its own table.
        let bucket_idx = match self.bucket_index(&entry.key) {
            Some(i) => i,
            None => return AddEntryResult::Ignored,
        };

        let bucket = &self.buckets[bucket_idx];

        // Check if the entry already exists in the bucket.
        if bucket.contains_key(&entry.key) {
            return AddEntryResult::Exists;
        }

        // Check if the entry is restricted.
        if self.subnet_restricted(bucket_idx, &entry) {
            return AddEntryResult::Restricted;
        }

        let bucket = &mut self.buckets[bucket_idx];

        // If the bucket has free space, add the entry and return success.
        if bucket.len() < BUCKET_SIZE {
            bucket.add(&entry);
            return AddEntryResult::Added;
        }

        // If the bucket is full, the entry is ignored.
        AddEntryResult::Ignored
    }

    /// Check if the table contains the given key.
    pub fn contains_key(&self, key: &Key) -> bool {
        // Determine the bucket index for the given key.
        let bucket_idx = match self.bucket_index(key) {
            Some(bi) => bi,
            None => return false,
        };

        let bucket = &self.buckets[bucket_idx];
        bucket.contains_key(key)
    }

    /// Updates the status of an entry in the routing table identified
    /// by the given key.
    ///
    /// If the key is not found, no action is taken.
    pub fn update_entry(&mut self, key: &Key, entry_flag: EntryStatusFlag) {
        // Determine the bucket index for the given key.
        let bucket_idx = match self.bucket_index(key) {
            Some(bi) => bi,
            None => return,
        };

        let bucket = &mut self.buckets[bucket_idx];
        bucket.update_entry(key, entry_flag);
    }

    /// Returns a list of bucket indexes that are closest to the given target key.
    pub fn bucket_indexes(&self, target_key: &Key) -> Vec<usize> {
        let mut indexes = vec![];

        // Determine the primary bucket index for the target key.
        let bucket_idx = self.bucket_index(target_key).unwrap_or(0);

        indexes.push(bucket_idx);

        // Add additional bucket indexes within a certain distance limit,
        // alternating below/above the primary index.
        //
        // NOTE(review): `bucket_idx - i >= 1` never yields bucket 0, and
        // `bucket_idx + i < (TABLE_SIZE - 1)` never yields the last bucket
        // (index TABLE_SIZE - 1) — verify these exclusions are intentional;
        // the closest-entries tests below pin the current behavior.
        for i in 1..DISTANCE_LIMIT {
            if bucket_idx >= i && bucket_idx - i >= 1 {
                indexes.push(bucket_idx - i);
            }

            if bucket_idx + i < (TABLE_SIZE - 1) {
                indexes.push(bucket_idx + i);
            }
        }

        indexes
    }

    /// Returns a list of the closest entries to the given target key, limited by max_entries.
pub fn closest_entries(&self, target_key: &Key, max_entries: usize) -> Vec<Entry> {
        let mut entries: Vec<Entry> = vec![];

        // Collect entries, skipping peers marked unreachable or unstable.
        //
        // NOTE(review): collection stops at max_entries BEFORE sorting, so
        // the result is the closest among the first max_entries encountered
        // in bucket-iteration order, not necessarily the globally closest —
        // confirm this is the intended trade-off.
        'outer: for idx in self.bucket_indexes(target_key) {
            let bucket = &self.buckets[idx];
            for bucket_entry in bucket.iter() {
                if bucket_entry.is_unreachable() || bucket_entry.is_unstable() {
                    continue;
                }

                entries.push(bucket_entry.entry.clone());
                if entries.len() == max_entries {
                    break 'outer;
                }
            }
        }

        // Sort the entries by their distance to the target key.
        // Key arrays compare lexicographically, which orders XOR distances
        // from most-significant byte down.
        entries.sort_by(|a, b| {
            xor_distance(target_key, &a.key).cmp(&xor_distance(target_key, &b.key))
        });

        entries
    }

    /// Removes an entry with the given key from the routing table, if it exists.
    pub fn remove_entry(&mut self, key: &Key) {
        // Determine the bucket index for the given key.
        let bucket_idx = match self.bucket_index(key) {
            Some(bi) => bi,
            None => return,
        };

        let bucket = &mut self.buckets[bucket_idx];
        bucket.remove(key);
    }

    /// Returns an iterator of entries.
    pub fn iter(&self) -> impl Iterator<Item = &Bucket> {
        self.buckets.iter()
    }

    /// Returns a random entry from the routing table.
    ///
    /// Unlike the `BucketEntry::is_*` helpers, `entry_flag` here is a mask:
    /// any entry whose status overlaps the mask qualifies (so ALL_ENTRY
    /// matches every status).
    pub fn random_entry(&self, entry_flag: EntryStatusFlag) -> Option<&Entry> {
        for bucket in self.buckets.choose_multiple(&mut OsRng, self.buckets.len()) {
            for entry in bucket.random_iter(bucket.len()) {
                if entry.status & entry_flag == 0 {
                    continue;
                }
                return Some(&entry.entry);
            }
        }

        None
    }

    // Returns the bucket index for a given key in the table.
    //
    // Returns None when the key equals the local key (all-zero distance).
    fn bucket_index(&self, key: &Key) -> Option<usize> {
        // Calculate the XOR distance between the self key and the provided key.
        let distance = xor_distance(&self.key, key);

        // Find the first nonzero byte; its position gives the number of
        // leading zero bits in the distance, which determines the bucket.
        for (i, b) in distance.iter().enumerate() {
            if *b != 0 {
                let lz = i * 8 + b.leading_zeros() as usize;
                let bits = KEY_SIZE * 8 - 1;
                // NOTE(review): dividing by 8 maps each run of 8 bit-length
                // values onto one bucket, yielding indexes 0..TABLE_SIZE.
                let idx = (bits - lz) / 8;
                return Some(idx);
            }
        }
        None
    }

    /// This function iterate through the routing table and counts how many
    /// entries in the same subnet as the given Entry are already present.
    ///
    /// If the number of matching entries in the same bucket exceeds a
    /// threshold (MAX_MATCHED_SUBNET_IN_BUCKET), or if the total count of
    /// matching entries in the entire table exceeds a threshold
    /// (MAX_MATCHED_SUBNET_IN_TABLE), the addition of the Entry
    /// is considered restricted and returns true.
    fn subnet_restricted(&self, idx: usize, entry: &Entry) -> bool {
        let mut bucket_count = 0;
        let mut table_count = 0;

        // Iterate through the routing table's buckets and entries to check
        // for subnet matches.
        for (i, bucket) in self.buckets.iter().enumerate() {
            for e in bucket.iter() {
                // If there is a subnet match, update the counts.
                let matched = subnet_match(&e.entry.addr, &entry.addr);
                if matched {
                    if i == idx {
                        bucket_count += 1;
                    }
                    table_count += 1;
                }

                // If the number of matched entries in the same bucket exceeds
                // the limit, return true
                if bucket_count >= MAX_MATCHED_SUBNET_IN_BUCKET {
                    return true;
                }
            }

            // If the total matched entries in the table exceed the limit,
            // return true.
            if table_count >= MAX_MATCHED_SUBNET_IN_TABLE {
                return true;
            }
        }

        // If no subnet restrictions are encountered, return false.
false
    }
}

#[cfg(test)]
mod tests {
    use super::bucket::ALL_ENTRY;
    use super::*;

    use karyons_net::Addr;

    // Test fixture: a local key plus peer keys crafted so their XOR
    // distances land in known buckets (pinned by test_bucket_index).
    struct Setup {
        local_key: Key,
        keys: Vec<Key>,
    }

    // Convenience constructor for an Entry.
    fn new_entry(key: &Key, addr: &Addr, port: u16, discovery_port: u16) -> Entry {
        Entry {
            key: key.clone(),
            addr: addr.clone(),
            port,
            discovery_port,
        }
    }

    impl Setup {
        fn new() -> Self {
            let keys = vec![
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 1,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    1, 1, 0, 1, 1, 2,
                ],
                [
                    0, 0, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 3,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 1, 18, 0, 0, 0,
                    0, 0, 0, 0, 0, 4,
                ],
                [
                    223, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 5,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 50, 1, 18, 0, 0, 0,
                    0, 0, 0, 0, 0, 6,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 50, 1, 18, 0, 0,
                    0, 0, 0, 0, 0, 0, 7,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 50, 1, 18, 0, 0,
                    0, 0, 0, 0, 0, 0, 8,
                ],
                [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 10, 50, 1, 18, 0, 0,
                    0, 0, 0, 0, 0, 0, 9,
                ],
            ];

            Self {
                local_key: [
                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0, 0, 0,
                ],
                keys,
            }
        }

        // One entry per key; each gets a distinct /24 (127.0.{i}.x) so the
        // subnet restriction never triggers during setup.
        fn entries(&self) -> Vec<Entry> {
            let mut entries = vec![];
            for (i, key) in self.keys.iter().enumerate() {
                entries.push(new_entry(
                    key,
                    &Addr::Ip(format!("127.0.{i}.1").parse().unwrap()),
                    3000,
                    3010,
                ));
            }
            entries
        }

        // Builds a table pre-populated with all fixture entries.
        fn table(&self) -> RoutingTable {
            let mut table = RoutingTable::new(self.local_key.clone());

            for entry in self.entries() {
                let res = table.add_entry(entry);
                assert!(matches!(res, AddEntryResult::Added));
            }

            table
        }
    }

    #[test]
    fn test_bucket_index() {
        let setup = Setup::new();
        let table = setup.table();

        // The local key has zero distance to itself -> no bucket.
        assert_eq!(table.bucket_index(&setup.local_key), None);
        assert_eq!(table.bucket_index(&setup.keys[0]), Some(0));
        assert_eq!(table.bucket_index(&setup.keys[1]), Some(5));
        assert_eq!(table.bucket_index(&setup.keys[2]), Some(26));
        assert_eq!(table.bucket_index(&setup.keys[3]), Some(11));
        assert_eq!(table.bucket_index(&setup.keys[4]), Some(31));
        assert_eq!(table.bucket_index(&setup.keys[5]), Some(11));
        assert_eq!(table.bucket_index(&setup.keys[6]), Some(12));
        assert_eq!(table.bucket_index(&setup.keys[7]), Some(13));
        assert_eq!(table.bucket_index(&setup.keys[8]), Some(14));
    }

    #[test]
    fn test_closest_entries() {
        let setup = Setup::new();
        let table = setup.table();
        let entries = setup.entries();

        assert_eq!(
            table.closest_entries(&setup.keys[5], 8),
            vec![
                entries[5].clone(),
                entries[3].clone(),
                entries[1].clone(),
                entries[6].clone(),
                entries[7].clone(),
                entries[8].clone(),
                entries[2].clone(),
            ]
        );

        assert_eq!(
            table.closest_entries(&setup.keys[4], 2),
            vec![entries[4].clone(), entries[2].clone()]
        );
    }

    #[test]
    fn test_random_entry() {
        let setup = Setup::new();
        let mut table = setup.table();
        let entries = setup.entries();

        // Any status matches ALL_ENTRY.
        let entry = table.random_entry(ALL_ENTRY);
        assert!(matches!(entry, Some(&_)));

        // Fresh entries are PENDING, never CONNECTED.
        let entry = table.random_entry(CONNECTED_ENTRY);
        assert!(matches!(entry, None));

        for entry in entries {
            table.remove_entry(&entry.key);
        }

        let entry = table.random_entry(ALL_ENTRY);
        assert!(matches!(entry, None));
    }

    #[test]
    fn test_add_entries() {
        let setup = Setup::new();
        let mut table = setup.table();

        let key = [
            0, 0, 0, 0, 0, 0, 0, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 5,
        ];

        let key2 = [
            0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 5,
        ];

        let entry1 = new_entry(&key, &Addr::Ip("240.120.3.1".parse().unwrap()), 3000, 3010);
        assert!(matches!(
            table.add_entry(entry1.clone()),
            AddEntryResult::Added
        ));

        assert!(matches!(table.add_entry(entry1), AddEntryResult::Exists));

        // Same /24 as entry1 -> rejected by the subnet restriction.
        let entry2 = new_entry(&key2, &Addr::Ip("240.120.3.2".parse().unwrap()), 3000, 3010);
        assert!(matches!(
            table.add_entry(entry2),
            AddEntryResult::Restricted
        ));

        let mut key: [u8; 32] = [0; 32];

        // Fill one bucket to capacity.
        for i in 0..BUCKET_SIZE {
            key[i] += 1;
            let entry = new_entry(
                &key,
                &Addr::Ip(format!("127.0.{i}.1").parse().unwrap()),
                3000,
                3010,
            );
            table.add_entry(entry);
        }

        // One more entry destined for the now-full bucket is ignored.
        key[BUCKET_SIZE] += 1;
        let entry = new_entry(&key, &Addr::Ip("125.20.0.1".parse().unwrap()), 3000, 3010);
        assert!(matches!(table.add_entry(entry), AddEntryResult::Ignored));
    }
}
diff --git a/karyons_p2p/src/utils/mod.rs b/karyons_p2p/src/utils/mod.rs
new file mode 100644
index 0000000..e8ff9d0
--- /dev/null
+++ b/karyons_p2p/src/utils/mod.rs
@@ -0,0 +1,21 @@
mod version;

pub use version::{version_match, Version, VersionInt};

use std::net::IpAddr;

use karyons_net::Addr;

/// Check if two addresses belong to the same subnet.
///
/// Only IPv4 addresses are compared (a /24 prefix match); any other
/// combination — including IPv6 — returns false.
pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool {
    match (addr, other_addr) {
        (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => {
            // XXX Consider moving this to a different location
            // Two loopback addresses never count as the same subnet, so many
            // local nodes (e.g. in tests) can coexist.
            if other_ip.is_loopback() && ip.is_loopback() {
                return false;
            }
            ip.octets()[0..3] == other_ip.octets()[0..3]
        }
        _ => false,
    }
}
diff --git a/karyons_p2p/src/utils/version.rs b/karyons_p2p/src/utils/version.rs
new file mode 100644
index 0000000..4986495
--- /dev/null
+++ b/karyons_p2p/src/utils/version.rs
@@ -0,0 +1,93 @@
use std::str::FromStr;

use bincode::{Decode, Encode};
use semver::VersionReq;

use crate::{Error, Result};

/// Represents the network version and protocol version used in Karyons p2p.
+/// +/// # Example +/// +/// ``` +/// use karyons_p2p::Version; +/// +/// let version: Version = "0.2.0, >0.1.0".parse().unwrap(); +/// +/// let version: Version = "0.2.0".parse().unwrap(); +/// +/// ``` +#[derive(Debug, Clone)] +pub struct Version { + pub v: VersionInt, + pub req: VersionReq, +} + +impl Version { + /// Creates a new Version + pub fn new(v: VersionInt, req: VersionReq) -> Self { + Self { v, req } + } +} + +#[derive(Debug, Decode, Encode, Clone)] +pub struct VersionInt { + major: u64, + minor: u64, + patch: u64, +} + +impl FromStr for Version { + type Err = Error; + + fn from_str(s: &str) -> Result<Self> { + let v: Vec<&str> = s.split(", ").collect(); + if v.is_empty() || v.len() > 2 { + return Err(Error::ParseError(format!("Invalid version{s}"))); + } + + let version: VersionInt = v[0].parse()?; + let req: VersionReq = if v.len() > 1 { v[1] } else { v[0] }.parse()?; + + Ok(Self { v: version, req }) + } +} + +impl FromStr for VersionInt { + type Err = Error; + + fn from_str(s: &str) -> Result<Self> { + let v: Vec<&str> = s.split('.').collect(); + if v.len() < 2 || v.len() > 3 { + return Err(Error::ParseError(format!("Invalid version{s}"))); + } + + let major = v[0].parse::<u64>()?; + let minor = v[1].parse::<u64>()?; + let patch = v.get(2).unwrap_or(&"0").parse::<u64>()?; + + Ok(Self { + major, + minor, + patch, + }) + } +} + +impl std::fmt::Display for VersionInt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.major, self.minor, self.patch) + } +} + +impl From<VersionInt> for semver::Version { + fn from(v: VersionInt) -> Self { + semver::Version::new(v.major, v.minor, v.patch) + } +} + +/// Check if a version satisfies a version request. 
pub fn version_match(version_req: &VersionReq, version: &VersionInt) -> bool {
    // Convert our VersionInt into a semver::Version so the requirement can
    // be evaluated.
    // NOTE(review): the clone could be avoided by adding a
    // `From<&VersionInt> for semver::Version` impl.
    let version: semver::Version = version.clone().into();
    version_req.matches(&version)
}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1 @@
