From 98a1de91a2dae06323558422c239e5a45fc86e7b Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 28 Nov 2023 22:41:33 +0300 Subject: implement TLS for inbound and outbound connections --- p2p/src/backend.rs | 43 ++++++--- p2p/src/codec.rs | 120 +++++++++++++++++++++++ p2p/src/config.rs | 7 +- p2p/src/connection.rs | 2 +- p2p/src/connector.rs | 56 ++++++++--- p2p/src/discovery/lookup.rs | 71 +++++++++----- p2p/src/discovery/mod.rs | 30 ++++-- p2p/src/discovery/refresh.rs | 4 +- p2p/src/error.rs | 18 ++++ p2p/src/io_codec.rs | 132 ------------------------- p2p/src/lib.rs | 13 +-- p2p/src/listener.rs | 65 +++++++++---- p2p/src/message.rs | 2 +- p2p/src/monitor.rs | 4 +- p2p/src/peer/mod.rs | 23 ++--- p2p/src/peer/peer_id.rs | 17 ++++ p2p/src/peer_pool.rs | 46 ++++----- p2p/src/protocol.rs | 7 +- p2p/src/protocols/ping.rs | 6 +- p2p/src/routing_table/entry.rs | 2 +- p2p/src/routing_table/mod.rs | 19 +++- p2p/src/slots.rs | 2 +- p2p/src/tls_config.rs | 214 +++++++++++++++++++++++++++++++++++++++++ p2p/src/utils/mod.rs | 21 ---- p2p/src/utils/version.rs | 93 ------------------ p2p/src/version.rs | 93 ++++++++++++++++++ 26 files changed, 730 insertions(+), 380 deletions(-) create mode 100644 p2p/src/codec.rs delete mode 100644 p2p/src/io_codec.rs create mode 100644 p2p/src/tls_config.rs delete mode 100644 p2p/src/utils/mod.rs delete mode 100644 p2p/src/utils/version.rs create mode 100644 p2p/src/version.rs (limited to 'p2p/src') diff --git a/p2p/src/backend.rs b/p2p/src/backend.rs index 2e34f47..56d79f7 100644 --- a/p2p/src/backend.rs +++ b/p2p/src/backend.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use log::info; -use karyons_core::{pubsub::Subscription, GlobalExecutor}; +use karyons_core::{key_pair::KeyPair, pubsub::Subscription, GlobalExecutor}; use crate::{ config::Config, @@ -22,8 +22,8 @@ pub struct Backend { /// The Configuration for the P2P network. config: Arc, - /// Peer ID. - id: PeerID, + /// Identity Key pair + key_pair: KeyPair, /// Responsible for network and system monitoring. monitor: Arc, @@ -37,17 +37,34 @@ pub struct Backend { impl Backend { /// Creates a new Backend. - pub fn new(id: PeerID, config: Config, ex: GlobalExecutor) -> ArcBackend { + pub fn new(key_pair: &KeyPair, config: Config, ex: GlobalExecutor) -> ArcBackend { let config = Arc::new(config); let monitor = Arc::new(Monitor::new()); - let cq = ConnQueue::new(); - - let peer_pool = PeerPool::new(&id, cq.clone(), config.clone(), monitor.clone(), ex.clone()); - - let discovery = Discovery::new(&id, cq, config.clone(), monitor.clone(), ex); + let conn_queue = ConnQueue::new(); + + let peer_id = PeerID::try_from(key_pair.public()) + .expect("Derive a peer id from the provided key pair."); + info!("PeerID: {}", peer_id); + + let peer_pool = PeerPool::new( + &peer_id, + conn_queue.clone(), + config.clone(), + monitor.clone(), + ex.clone(), + ); + + let discovery = Discovery::new( + key_pair, + &peer_id, + conn_queue, + config.clone(), + monitor.clone(), + ex, + ); Arc::new(Self { - id: id.clone(), + key_pair: key_pair.clone(), monitor, discovery, config, @@ -57,7 +74,6 @@ impl Backend { /// Run the Backend, starting the PeerPool and Discovery instances. pub async fn run(self: &Arc) -> Result<()> { - info!("Run the backend {}", self.id); self.peer_pool.start().await?; self.discovery.start().await?; Ok(()) @@ -81,6 +97,11 @@ impl Backend { self.config.clone() } + /// Returns the `KeyPair`. + pub async fn key_pair(&self) -> &KeyPair { + &self.key_pair + } + /// Returns the number of occupied inbound slots. pub fn inbound_slots(&self) -> usize { self.discovery.inbound_slots.load() diff --git a/p2p/src/codec.rs b/p2p/src/codec.rs new file mode 100644 index 0000000..e521824 --- /dev/null +++ b/p2p/src/codec.rs @@ -0,0 +1,120 @@ +use std::time::Duration; + +use bincode::{Decode, Encode}; + +use karyons_core::{ + async_util::timeout, + util::{decode, encode, encode_into_slice}, +}; + +use karyons_net::{Connection, NetError}; + +use crate::{ + message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, + Error, Result, +}; + +pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} +impl CodecMsg for T {} + +/// A Codec working with generic network connections. +/// +/// It is responsible for both decoding data received from the network and +/// encoding data before sending it. +pub struct Codec { + conn: Box, +} + +impl Codec { + /// Creates a new Codec. + pub fn new(conn: Box) -> Self { + Self { conn } + } + + /// Reads a message of type `NetMsg` from the connection. + /// + /// It reads the first 6 bytes as the header of the message, then reads + /// and decodes the remaining message data based on the determined header. + pub async fn read(&self) -> Result { + // Read 6 bytes to get the header of the incoming message + let mut buf = [0; MSG_HEADER_SIZE]; + self.read_exact(&mut buf).await?; + + // Decode the header from bytes to NetMsgHeader + let (header, _) = decode::(&buf)?; + + if header.payload_size > MAX_ALLOWED_MSG_SIZE { + return Err(Error::InvalidMsg( + "Message exceeds the maximum allowed size".to_string(), + )); + } + + // Create a buffer to hold the message based on its length + let mut payload = vec![0; header.payload_size as usize]; + self.read_exact(&mut payload).await?; + + Ok(NetMsg { header, payload }) + } + + /// Writes a message of type `T` to the connection. + /// + /// Before appending the actual message payload, it calculates the length of + /// the encoded message in bytes and appends this length to the message header. + pub async fn write(&self, command: NetMsgCmd, msg: &T) -> Result<()> { + let payload = encode(msg)?; + + // Create a buffer to hold the message header (6 bytes) + let header_buf = &mut [0; MSG_HEADER_SIZE]; + let header = NetMsgHeader { + command, + payload_size: payload.len() as u32, + }; + encode_into_slice(&header, header_buf)?; + + let mut buffer = vec![]; + // Append the header bytes to the buffer + buffer.extend_from_slice(header_buf); + // Append the message payload to the buffer + buffer.extend_from_slice(&payload); + + self.write_all(&buffer).await?; + Ok(()) + } + + /// Reads a message of type `NetMsg` with the given timeout. + pub async fn read_timeout(&self, duration: Duration) -> Result { + timeout(duration, self.read()) + .await + .map_err(|_| NetError::Timeout)? + } + + /// Reads the exact number of bytes required to fill `buf`. + async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.read(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } + + /// Writes an entire buffer into the connection. + async fn write_all(&self, mut buf: &[u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.write(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } +} diff --git a/p2p/src/config.rs b/p2p/src/config.rs index ebecbf0..2c5d5ec 100644 --- a/p2p/src/config.rs +++ b/p2p/src/config.rs @@ -1,6 +1,6 @@ use karyons_net::{Endpoint, Port}; -use crate::utils::Version; +use crate::Version; /// the Configuration for the P2P network. pub struct Config { @@ -71,6 +71,9 @@ pub struct Config { /// The maximum number of retries for outbound connection establishment /// during the refresh process. pub refresh_connect_retries: usize, + + /// Enables TLS for all connections. + pub enable_tls: bool, } impl Default for Config { @@ -100,6 +103,8 @@ impl Default for Config { refresh_interval: 1800, refresh_response_timeout: 1, refresh_connect_retries: 3, + + enable_tls: false, } } } diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs index 8ec2617..e0a3bbd 100644 --- a/p2p/src/connection.rs +++ b/p2p/src/connection.rs @@ -2,7 +2,7 @@ use std::{collections::VecDeque, fmt, sync::Arc}; use smol::{channel::Sender, lock::Mutex}; -use karyons_core::async_utils::CondVar; +use karyons_core::async_util::CondVar; use karyons_net::Conn; use crate::Result; diff --git a/p2p/src/connector.rs b/p2p/src/connector.rs index f41ab57..6fc5734 100644 --- a/p2p/src/connector.rs +++ b/p2p/src/connector.rs @@ -1,21 +1,28 @@ use std::{future::Future, sync::Arc}; -use log::{trace, warn}; +use log::{error, trace, warn}; use karyons_core::{ - async_utils::{Backoff, TaskGroup, TaskResult}, + async_util::{Backoff, TaskGroup, TaskResult}, + key_pair::KeyPair, GlobalExecutor, }; -use karyons_net::{dial, Conn, Endpoint, NetError}; +use karyons_net::{dial, tls, Conn, Endpoint, NetError}; use crate::{ monitor::{ConnEvent, Monitor}, slots::ConnectionSlots, - Result, + tls_config::tls_client_config, + Error, PeerID, Result, }; +static DNS_NAME: &str = "karyons.org"; + /// Responsible for creating outbound connections with other peers. pub struct Connector { + /// Identity Key pair + key_pair: KeyPair, + /// Managing spawned tasks. task_group: TaskGroup<'static>, @@ -26,6 +33,9 @@ pub struct Connector { /// establishing a connection. max_retries: usize, + /// Enables secure connection. + enable_tls: bool, + /// Responsible for network and system monitoring. monitor: Arc, } @@ -33,16 +43,20 @@ pub struct Connector { impl Connector { /// Creates a new Connector pub fn new( + key_pair: &KeyPair, max_retries: usize, connection_slots: Arc, + enable_tls: bool, monitor: Arc, ex: GlobalExecutor, ) -> Arc { Arc::new(Self { + key_pair: key_pair.clone(), + max_retries, task_group: TaskGroup::new(ex), monitor, connection_slots, - max_retries, + enable_tls, }) } @@ -57,20 +71,23 @@ impl Connector { /// `Conn` instance. /// /// This method will block until it finds an available slot. - pub async fn connect(&self, endpoint: &Endpoint) -> Result { + pub async fn connect(&self, endpoint: &Endpoint, peer_id: &Option) -> Result { self.connection_slots.wait_for_slot().await; self.connection_slots.add(); let mut retry = 0; let backoff = Backoff::new(500, 2000); while retry < self.max_retries { - let conn_result = dial(endpoint).await; - - if let Ok(conn) = conn_result { - self.monitor - .notify(&ConnEvent::Connected(endpoint.clone()).into()) - .await; - return Ok(conn); + match self.dial(endpoint, peer_id).await { + Ok(conn) => { + self.monitor + .notify(&ConnEvent::Connected(endpoint.clone()).into()) + .await; + return Ok(conn); + } + Err(err) => { + error!("Failed to establish a connection to {endpoint}: {err}"); + } } self.monitor @@ -96,12 +113,13 @@ impl Connector { pub async fn connect_with_cback( self: &Arc, endpoint: &Endpoint, + peer_id: &Option, callback: impl FnOnce(Conn) -> Fut + Send + 'static, ) -> Result<()> where Fut: Future> + Send + 'static, { - let conn = self.connect(endpoint).await?; + let conn = self.connect(endpoint, peer_id).await?; let selfc = self.clone(); let endpoint = endpoint.clone(); @@ -120,4 +138,14 @@ impl Connector { Ok(()) } + + async fn dial(&self, endpoint: &Endpoint, peer_id: &Option) -> Result { + if self.enable_tls { + let tls_config = tls_client_config(&self.key_pair, peer_id.clone())?; + tls::dial(endpoint, tls_config, DNS_NAME).await + } else { + dial(endpoint).await + } + .map_err(Error::KaryonsNet) + } } diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs index 0138068..60d8635 100644 --- a/p2p/src/discovery/lookup.rs +++ b/p2p/src/discovery/lookup.rs @@ -5,13 +5,13 @@ use log::{error, trace}; use rand::{rngs::OsRng, seq::SliceRandom, RngCore}; use smol::lock::{Mutex, RwLock}; -use karyons_core::{async_utils::timeout, utils::decode, GlobalExecutor}; +use karyons_core::{async_util::timeout, key_pair::KeyPair, util::decode, GlobalExecutor}; use karyons_net::{Conn, Endpoint}; use crate::{ + codec::Codec, connector::Connector, - io_codec::IOCodec, listener::Listener, message::{ get_msg_payload, FindPeerMsg, NetMsg, NetMsgCmd, PeerMsg, PeersMsg, PingMsg, PongMsg, @@ -20,7 +20,7 @@ use crate::{ monitor::{ConnEvent, DiscoveryEvent, Monitor}, routing_table::RoutingTable, slots::ConnectionSlots, - utils::version_match, + version::version_match, Config, Error, PeerID, Result, }; @@ -55,6 +55,7 @@ pub struct LookupService { impl LookupService { /// Creates a new lookup service pub fn new( + key_pair: &KeyPair, id: &PeerID, table: Arc>, config: Arc, @@ -64,11 +65,19 @@ impl LookupService { let inbound_slots = Arc::new(ConnectionSlots::new(config.lookup_inbound_slots)); let outbound_slots = Arc::new(ConnectionSlots::new(config.lookup_outbound_slots)); - let listener = Listener::new(inbound_slots.clone(), monitor.clone(), ex.clone()); + let listener = Listener::new( + key_pair, + inbound_slots.clone(), + config.enable_tls, + monitor.clone(), + ex.clone(), + ); let connector = Connector::new( + key_pair, config.lookup_connect_retries, outbound_slots.clone(), + config.enable_tls, monitor.clone(), ex, ); @@ -116,14 +125,17 @@ impl LookupService { /// randomly generated peer ID. Upon receiving peers from the initial lookup, /// it starts connecting to these received peers and sends them a FindPeer /// message that contains our own peer ID. - pub async fn start_lookup(&self, endpoint: &Endpoint) -> Result<()> { + pub async fn start_lookup(&self, endpoint: &Endpoint, peer_id: Option) -> Result<()> { trace!("Lookup started {endpoint}"); self.monitor .notify(&DiscoveryEvent::LookupStarted(endpoint.clone()).into()) .await; let mut random_peers = vec![]; - if let Err(err) = self.random_lookup(endpoint, &mut random_peers).await { + if let Err(err) = self + .random_lookup(endpoint, peer_id, &mut random_peers) + .await + { self.monitor .notify(&DiscoveryEvent::LookupFailed(endpoint.clone()).into()) .await; @@ -160,11 +172,14 @@ impl LookupService { async fn random_lookup( &self, endpoint: &Endpoint, + peer_id: Option, random_peers: &mut Vec, ) -> Result<()> { for _ in 0..2 { - let peer_id = PeerID::random(); - let peers = self.connect(&peer_id, endpoint.clone()).await?; + let random_peer_id = PeerID::random(); + let peers = self + .connect(endpoint.clone(), peer_id.clone(), &random_peer_id) + .await?; let table = self.table.lock().await; for peer in peers { @@ -187,7 +202,7 @@ impl LookupService { let mut tasks = FuturesUnordered::new(); for peer in random_peers.choose_multiple(&mut OsRng, random_peers.len()) { let endpoint = Endpoint::Tcp(peer.addr.clone(), peer.discovery_port); - tasks.push(self.connect(&self.id, endpoint)) + tasks.push(self.connect(endpoint, Some(peer.peer_id.clone()), &self.id)) } while let Some(result) = tasks.next().await { @@ -200,11 +215,17 @@ impl LookupService { } } - /// Connects to the given endpoint - async fn connect(&self, peer_id: &PeerID, endpoint: Endpoint) -> Result> { - let conn = self.connector.connect(&endpoint).await?; - let io_codec = IOCodec::new(conn); - let result = self.handle_outbound(io_codec, peer_id).await; + /// Connects to the given endpoint and initiates a lookup process for the + /// provided peer ID. + async fn connect( + &self, + endpoint: Endpoint, + peer_id: Option, + target_peer_id: &PeerID, + ) -> Result> { + let conn = self.connector.connect(&endpoint, &peer_id).await?; + let io_codec = Codec::new(conn); + let result = self.handle_outbound(io_codec, target_peer_id).await; self.monitor .notify(&ConnEvent::Disconnected(endpoint).into()) @@ -215,12 +236,16 @@ impl LookupService { } /// Handles outbound connection - async fn handle_outbound(&self, io_codec: IOCodec, peer_id: &PeerID) -> Result> { + async fn handle_outbound( + &self, + io_codec: Codec, + target_peer_id: &PeerID, + ) -> Result> { trace!("Send Ping msg"); self.send_ping_msg(&io_codec).await?; trace!("Send FindPeer msg"); - let peers = self.send_findpeer_msg(&io_codec, peer_id).await?; + let peers = self.send_findpeer_msg(&io_codec, target_peer_id).await?; if peers.0.len() >= MAX_PEERS_IN_PEERSMSG { return Err(Error::Lookup("Received too many peers in PeersMsg")); @@ -260,7 +285,7 @@ impl LookupService { /// Handles inbound connection async fn handle_inbound(self: &Arc, conn: Conn) -> Result<()> { - let io_codec = IOCodec::new(conn); + let io_codec = Codec::new(conn); loop { let msg: NetMsg = io_codec.read().await?; trace!("Receive msg {:?}", msg.header.command); @@ -293,7 +318,7 @@ impl LookupService { } /// Sends a Ping msg and wait to receive the Pong message. - async fn send_ping_msg(&self, io_codec: &IOCodec) -> Result<()> { + async fn send_ping_msg(&self, io_codec: &Codec) -> Result<()> { trace!("Send Pong msg"); let mut nonce: [u8; 32] = [0; 32]; @@ -319,14 +344,14 @@ impl LookupService { } /// Sends a Pong msg - async fn send_pong_msg(&self, nonce: [u8; 32], io_codec: &IOCodec) -> Result<()> { + async fn send_pong_msg(&self, nonce: [u8; 32], io_codec: &Codec) -> Result<()> { trace!("Send Pong msg"); io_codec.write(NetMsgCmd::Pong, &PongMsg(nonce)).await?; Ok(()) } /// Sends a FindPeer msg and wait to receivet the Peers msg. - async fn send_findpeer_msg(&self, io_codec: &IOCodec, peer_id: &PeerID) -> Result { + async fn send_findpeer_msg(&self, io_codec: &Codec, peer_id: &PeerID) -> Result { trace!("Send FindPeer msg"); io_codec .write(NetMsgCmd::FindPeer, &FindPeerMsg(peer_id.clone())) @@ -342,7 +367,7 @@ impl LookupService { } /// Sends a Peers msg. - async fn send_peers_msg(&self, peer_id: &PeerID, io_codec: &IOCodec) -> Result<()> { + async fn send_peers_msg(&self, peer_id: &PeerID, io_codec: &Codec) -> Result<()> { trace!("Send Peers msg"); let table = self.table.lock().await; let entries = table.closest_entries(&peer_id.0, MAX_PEERS_IN_PEERSMSG); @@ -354,7 +379,7 @@ impl LookupService { } /// Sends a Peer msg. - async fn send_peer_msg(&self, io_codec: &IOCodec, endpoint: Endpoint) -> Result<()> { + async fn send_peer_msg(&self, io_codec: &Codec, endpoint: Endpoint) -> Result<()> { trace!("Send Peer msg"); let peer_msg = PeerMsg { addr: endpoint.addr()?.clone(), @@ -367,7 +392,7 @@ impl LookupService { } /// Sends a Shutdown msg. - async fn send_shutdown_msg(&self, io_codec: &IOCodec) -> Result<()> { + async fn send_shutdown_msg(&self, io_codec: &Codec) -> Result<()> { trace!("Send Shutdown msg"); io_codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await?; Ok(()) diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs index 7f55309..2c1bcd8 100644 --- a/p2p/src/discovery/mod.rs +++ b/p2p/src/discovery/mod.rs @@ -8,7 +8,8 @@ use rand::{rngs::OsRng, seq::SliceRandom}; use smol::lock::Mutex; use karyons_core::{ - async_utils::{Backoff, TaskGroup, TaskResult}, + async_util::{Backoff, TaskGroup, TaskResult}, + key_pair::KeyPair, GlobalExecutor, }; @@ -66,6 +67,7 @@ pub struct Discovery { impl Discovery { /// Creates a new Discovery pub fn new( + key_pair: &KeyPair, peer_id: &PeerID, conn_queue: Arc, config: Arc, @@ -81,6 +83,7 @@ impl Discovery { let refresh_service = RefreshService::new(config.clone(), table.clone(), monitor.clone(), ex.clone()); let lookup_service = LookupService::new( + key_pair, peer_id, table.clone(), config.clone(), @@ -89,12 +92,21 @@ impl Discovery { ); let connector = Connector::new( + key_pair, config.max_connect_retries, outbound_slots.clone(), + config.enable_tls, + monitor.clone(), + ex.clone(), + ); + + let listener = Listener::new( + key_pair, + inbound_slots.clone(), + config.enable_tls, monitor.clone(), ex.clone(), ); - let listener = Listener::new(inbound_slots.clone(), monitor.clone(), ex.clone()); Arc::new(Self { refresh_service: Arc::new(refresh_service), @@ -222,7 +234,7 @@ impl Discovery { selfc.update_entry(&pid, INCOMPATIBLE_ENTRY).await; } Err(Error::PeerAlreadyConnected) => { - // TODO + // TODO: Use the appropriate status. selfc.update_entry(&pid, DISCONNECTED_ENTRY).await; } Err(_) => { @@ -236,10 +248,13 @@ impl Discovery { Ok(()) }; - let res = self.connector.connect_with_cback(endpoint, cback).await; + let result = self + .connector + .connect_with_cback(endpoint, &pid, cback) + .await; if let Some(pid) = &pid { - match res { + match result { Ok(_) => { self.update_entry(pid, CONNECTED_ENTRY).await; } @@ -260,7 +275,8 @@ impl Discovery { match self.random_entry(PENDING_ENTRY | CONNECTED_ENTRY).await { Some(entry) => { let endpoint = Endpoint::Tcp(entry.addr, entry.discovery_port); - if let Err(err) = self.lookup_service.start_lookup(&endpoint).await { + let peer_id = Some(entry.key.into()); + if let Err(err) = self.lookup_service.start_lookup(&endpoint, peer_id).await { self.update_entry(&entry.key.into(), UNSTABLE_ENTRY).await; error!("Failed to do lookup: {endpoint}: {err}"); } @@ -268,7 +284,7 @@ impl Discovery { None => { let peers = &self.config.bootstrap_peers; for endpoint in peers.choose_multiple(&mut OsRng, peers.len()) { - if let Err(err) = self.lookup_service.start_lookup(endpoint).await { + if let Err(err) = self.lookup_service.start_lookup(endpoint, None).await { error!("Failed to do lookup: {endpoint}: {err}"); } } diff --git a/p2p/src/discovery/refresh.rs b/p2p/src/discovery/refresh.rs index d095f19..f797c71 100644 --- a/p2p/src/discovery/refresh.rs +++ b/p2p/src/discovery/refresh.rs @@ -10,8 +10,8 @@ use smol::{ }; use karyons_core::{ - async_utils::{timeout, Backoff, TaskGroup, TaskResult}, - utils::{decode, encode}, + async_util::{timeout, Backoff, TaskGroup, TaskResult}, + util::{decode, encode}, GlobalExecutor, }; diff --git a/p2p/src/error.rs b/p2p/src/error.rs index 0c1d50c..6274d4c 100644 --- a/p2p/src/error.rs +++ b/p2p/src/error.rs @@ -11,6 +11,9 @@ pub enum Error { #[error("Unsupported protocol error: {0}")] UnsupportedProtocol(String), + #[error("Try from public key Error: {0}")] + TryFromPublicKey(&'static str), + #[error("Invalid message error: {0}")] InvalidMsg(String), @@ -50,6 +53,21 @@ pub enum Error { #[error("Peer already connected")] PeerAlreadyConnected, + #[error("Yasna Error: {0}")] + Yasna(#[from] yasna::ASN1Error), + + #[error("X509 Parser Error: {0}")] + X509Parser(#[from] x509_parser::error::X509Error), + + #[error("Rcgen Error: {0}")] + Rcgen(#[from] rcgen::RcgenError), + + #[error("Tls Error: {0}")] + Rustls(#[from] async_rustls::rustls::Error), + + #[error("Invalid DNS Name: {0}")] + InvalidDnsNameError(#[from] async_rustls::rustls::client::InvalidDnsNameError), + #[error("Channel Send Error: {0}")] ChannelSend(String), diff --git a/p2p/src/io_codec.rs b/p2p/src/io_codec.rs deleted file mode 100644 index ea62666..0000000 --- a/p2p/src/io_codec.rs +++ /dev/null @@ -1,132 +0,0 @@ -use std::time::Duration; - -use bincode::{Decode, Encode}; - -use karyons_core::{ - async_utils::timeout, - utils::{decode, encode, encode_into_slice}, -}; - -use karyons_net::{Connection, NetError}; - -use crate::{ - message::{NetMsg, NetMsgCmd, NetMsgHeader, MAX_ALLOWED_MSG_SIZE, MSG_HEADER_SIZE}, - Error, Result, -}; - -pub trait CodecMsg: Decode + Encode + std::fmt::Debug {} -impl CodecMsg for T {} - -/// I/O codec working with generic network connections. -/// -/// It is responsible for both decoding data received from the network and -/// encoding data before sending it. -pub struct IOCodec { - conn: Box, -} - -impl IOCodec { - /// Creates a new IOCodec. - pub fn new(conn: Box) -> Self { - Self { conn } - } - - /// Reads a message of type `NetMsg` from the connection. - /// - /// It reads the first 6 bytes as the header of the message, then reads - /// and decodes the remaining message data based on the determined header. - pub async fn read(&self) -> Result { - // Read 6 bytes to get the header of the incoming message - let mut buf = [0; MSG_HEADER_SIZE]; - self.read_exact(&mut buf).await?; - - // Decode the header from bytes to NetMsgHeader - let (header, _) = decode::(&buf)?; - - if header.payload_size > MAX_ALLOWED_MSG_SIZE { - return Err(Error::InvalidMsg( - "Message exceeds the maximum allowed size".to_string(), - )); - } - - // Create a buffer to hold the message based on its length - let mut payload = vec![0; header.payload_size as usize]; - self.read_exact(&mut payload).await?; - - Ok(NetMsg { header, payload }) - } - - /// Writes a message of type `T` to the connection. - /// - /// Before appending the actual message payload, it calculates the length of - /// the encoded message in bytes and appends this length to the message header. - pub async fn write(&self, command: NetMsgCmd, msg: &T) -> Result<()> { - let payload = encode(msg)?; - - // Create a buffer to hold the message header (6 bytes) - let header_buf = &mut [0; MSG_HEADER_SIZE]; - let header = NetMsgHeader { - command, - payload_size: payload.len() as u32, - }; - encode_into_slice(&header, header_buf)?; - - let mut buffer = vec![]; - // Append the header bytes to the buffer - buffer.extend_from_slice(header_buf); - // Append the message payload to the buffer - buffer.extend_from_slice(&payload); - - self.write_all(&buffer).await?; - Ok(()) - } - - /// Reads a message of type `NetMsg` with the given timeout. - pub async fn read_timeout(&self, duration: Duration) -> Result { - timeout(duration, self.read()) - .await - .map_err(|_| NetError::Timeout)? - } - - /// Writes a message of type `T` with the given timeout. - pub async fn write_timeout( - &self, - command: NetMsgCmd, - msg: &T, - duration: Duration, - ) -> Result<()> { - timeout(duration, self.write(command, msg)) - .await - .map_err(|_| NetError::Timeout)? - } - - /// Reads the exact number of bytes required to fill `buf`. - async fn read_exact(&self, mut buf: &mut [u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.read(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at_mut(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } - - Ok(()) - } - - /// Writes an entire buffer into the connection. - async fn write_all(&self, mut buf: &[u8]) -> Result<()> { - while !buf.is_empty() { - let n = self.conn.write(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } - - Ok(()) - } -} diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs index c0a3b5b..6585287 100644 --- a/p2p/src/lib.rs +++ b/p2p/src/lib.rs @@ -7,19 +7,19 @@ //! use easy_parallel::Parallel; //! use smol::{channel as smol_channel, future, Executor}; //! +//! use karyons_core::key_pair::{KeyPair, KeyPairType}; //! use karyons_p2p::{Backend, Config, PeerID}; //! -//! let peer_id = PeerID::random(); +//! let key_pair = KeyPair::generate(&KeyPairType::Ed25519); //! //! // Create the configuration for the backend. //! let mut config = Config::default(); //! -//! //! // Create a new Executor //! let ex = Arc::new(Executor::new()); //! //! // Create a new Backend -//! let backend = Backend::new(peer_id, config, ex.clone()); +//! let backend = Backend::new(&key_pair, config, ex.clone()); //! //! let task = async { //! // Run the backend @@ -36,12 +36,12 @@ //! ``` //! mod backend; +mod codec; mod config; mod connection; mod connector; mod discovery; mod error; -mod io_codec; mod listener; mod message; mod peer; @@ -49,7 +49,8 @@ mod peer_pool; mod protocols; mod routing_table; mod slots; -mod utils; +mod tls_config; +mod version; /// Responsible for network and system monitoring. /// [`Read More`](./monitor/struct.Monitor.html) @@ -62,6 +63,6 @@ pub use backend::{ArcBackend, Backend}; pub use config::Config; pub use error::Error as P2pError; pub use peer::{ArcPeer, PeerID}; -pub use utils::Version; +pub use version::Version; use error::{Error, Result}; diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs index f2391f7..58a0931 100644 --- a/p2p/src/listener.rs +++ b/p2p/src/listener.rs @@ -1,28 +1,36 @@ use std::{future::Future, sync::Arc}; -use log::{error, info, trace}; +use log::{debug, error, info}; use karyons_core::{ - async_utils::{TaskGroup, TaskResult}, + async_util::{TaskGroup, TaskResult}, + key_pair::KeyPair, GlobalExecutor, }; -use karyons_net::{listen, Conn, Endpoint, Listener as NetListener}; +use karyons_net::{listen, tls, Conn, Endpoint, Listener as NetListener}; use crate::{ monitor::{ConnEvent, Monitor}, slots::ConnectionSlots, - Result, + tls_config::tls_server_config, + Error, Result, }; /// Responsible for creating inbound connections with other peers. pub struct Listener { + /// Identity Key pair + key_pair: KeyPair, + /// Managing spawned tasks. task_group: TaskGroup<'static>, /// Manages available inbound slots. connection_slots: Arc, + /// Enables secure connection. + enable_tls: bool, + /// Responsible for network and system monitoring. monitor: Arc, } @@ -30,13 +38,17 @@ pub struct Listener { impl Listener { /// Creates a new Listener pub fn new( + key_pair: &KeyPair, connection_slots: Arc, + enable_tls: bool, monitor: Arc, ex: GlobalExecutor, ) -> Arc { Arc::new(Self { + key_pair: key_pair.clone(), connection_slots, task_group: TaskGroup::new(ex), + enable_tls, monitor, }) } @@ -55,7 +67,7 @@ impl Listener { where Fut: Future> + Send + 'static, { - let listener = match listen(&endpoint).await { + let listener = match self.listend(&endpoint).await { Ok(listener) => { self.monitor .notify(&ConnEvent::Listening(endpoint.clone()).into()) @@ -67,21 +79,17 @@ impl Listener { self.monitor .notify(&ConnEvent::ListenFailed(endpoint).into()) .await; - return Err(err.into()); + return Err(err); } }; let resolved_endpoint = listener.local_endpoint()?; - info!("Start listening on {endpoint}"); + info!("Start listening on {resolved_endpoint}"); let selfc = self.clone(); self.task_group - .spawn(selfc.listen_loop(listener, callback), |res| async move { - if let TaskResult::Completed(Err(err)) = res { - error!("Listen loop stopped: {endpoint} {err}"); - } - }); + .spawn(selfc.listen_loop(listener, callback), |_| async {}); Ok(resolved_endpoint) } @@ -94,8 +102,7 @@ impl Listener { self: Arc, listener: Box, callback: impl FnOnce(Conn) -> Fut + Clone + Send + 'static, - ) -> Result<()> - where + ) where Fut: Future> + Send + 'static, { loop { @@ -103,27 +110,35 @@ impl Listener { self.connection_slots.wait_for_slot().await; let result = listener.accept().await; - let conn = match result { + let (conn, endpoint) = match result { Ok(c) => { + let endpoint = match c.peer_endpoint() { + Ok(e) => e, + Err(err) => { + self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; + error!("Failed to accept a new connection: {err}"); + continue; + } + }; + self.monitor - .notify(&ConnEvent::Accepted(c.peer_endpoint()?).into()) + .notify(&ConnEvent::Accepted(endpoint.clone()).into()) .await; - c + (c, endpoint) } Err(err) => { error!("Failed to accept a new connection: {err}"); self.monitor.notify(&ConnEvent::AcceptFailed.into()).await; - return Err(err.into()); + continue; } }; self.connection_slots.add(); let selfc = self.clone(); - let endpoint = conn.peer_endpoint()?; let on_disconnect = |res| async move { if let TaskResult::Completed(Err(err)) = res { - trace!("Inbound connection dropped: {err}"); + debug!("Inbound connection dropped: {err}"); } selfc .monitor @@ -136,4 +151,14 @@ impl Listener { self.task_group.spawn(callback(conn), on_disconnect); } } + + async fn listend(&self, endpoint: &Endpoint) -> Result> { + if self.enable_tls { + let tls_config = tls_server_config(&self.key_pair)?; + tls::listen(endpoint, tls_config).await + } else { + listen(endpoint).await + } + .map_err(Error::KaryonsNet) + } } diff --git a/p2p/src/message.rs b/p2p/src/message.rs index 3779cc1..6b23322 100644 --- a/p2p/src/message.rs +++ b/p2p/src/message.rs @@ -4,7 +4,7 @@ use bincode::{Decode, Encode}; use karyons_net::{Addr, Port}; -use crate::{protocol::ProtocolID, routing_table::Entry, utils::VersionInt, PeerID}; +use crate::{protocol::ProtocolID, routing_table::Entry, version::VersionInt, PeerID}; /// The size of the message header, in bytes. pub const MSG_HEADER_SIZE: usize = 6; diff --git a/p2p/src/monitor.rs b/p2p/src/monitor.rs index fbbf43f..1f74503 100644 --- a/p2p/src/monitor.rs +++ b/p2p/src/monitor.rs @@ -17,6 +17,7 @@ use karyons_net::Endpoint; /// /// use smol::Executor; /// +/// use karyons_core::key_pair::{KeyPair, KeyPairType}; /// use karyons_p2p::{Config, Backend, PeerID}; /// /// async { @@ -24,7 +25,8 @@ use karyons_net::Endpoint; /// // Create a new Executor /// let ex = Arc::new(Executor::new()); /// -/// let backend = Backend::new(PeerID::random(), Config::default(), ex); +/// let key_pair = KeyPair::generate(&KeyPairType::Ed25519); +/// let backend = Backend::new(&key_pair, Config::default(), ex); /// /// // Create a new Subscription /// let sub = backend.monitor().await; diff --git a/p2p/src/peer/mod.rs b/p2p/src/peer/mod.rs index 85cd558..6ed0dd8 100644 --- a/p2p/src/peer/mod.rs +++ b/p2p/src/peer/mod.rs @@ -11,17 +11,17 @@ use smol::{ }; use karyons_core::{ - async_utils::{select, Either, TaskGroup, TaskResult}, + async_util::{select, Either, TaskGroup, TaskResult}, event::{ArcEventSys, EventListener, EventSys}, - utils::{decode, encode}, + util::{decode, encode}, GlobalExecutor, }; use karyons_net::Endpoint; use crate::{ + codec::{Codec, CodecMsg}, connection::ConnDirection, - io_codec::{CodecMsg, IOCodec}, message::{NetMsgCmd, ProtocolMsg, ShutdownMsg}, peer_pool::{ArcPeerPool, WeakPeerPool}, protocol::{Protocol, ProtocolEvent, ProtocolID}, @@ -37,8 +37,8 @@ pub struct Peer { /// A weak pointer to `PeerPool` peer_pool: WeakPeerPool, - /// Holds the IOCodec for the peer connection - io_codec: IOCodec, + /// Holds the Codec for the peer connection + codec: Codec, /// Remote endpoint for the peer remote_endpoint: Endpoint, @@ -64,7 +64,7 @@ impl Peer { pub fn new( peer_pool: WeakPeerPool, id: &PeerID, - io_codec: IOCodec, + codec: Codec, remote_endpoint: Endpoint, conn_direction: ConnDirection, ex: GlobalExecutor, @@ -72,7 +72,7 @@ impl Peer { Arc::new(Peer { id: id.clone(), peer_pool, - io_codec, + codec, protocol_ids: RwLock::new(Vec::new()), remote_endpoint, conn_direction, @@ -97,7 +97,7 @@ impl Peer { payload: payload.to_vec(), }; - self.io_codec.write(NetMsgCmd::Protocol, &proto_msg).await?; + self.codec.write(NetMsgCmd::Protocol, &proto_msg).await?; Ok(()) } @@ -124,10 +124,7 @@ impl Peer { let _ = self.stop_chan.0.try_send(Ok(())); // No need to handle the error here - let _ = self - .io_codec - .write(NetMsgCmd::Shutdown, &ShutdownMsg(0)) - .await; + let _ = self.codec.write(NetMsgCmd::Shutdown, &ShutdownMsg(0)).await; // Force shutting down self.task_group.cancel().await; @@ -174,7 +171,7 @@ impl Peer { /// Start a read loop to handle incoming messages from the peer connection. async fn read_loop(&self) -> Result<()> { loop { - let fut = select(self.stop_chan.1.recv(), self.io_codec.read()).await; + let fut = select(self.stop_chan.1.recv(), self.codec.read()).await; let result = match fut { Either::Left(stop_signal) => { trace!("Peer {} received a stop signal", self.id); diff --git a/p2p/src/peer/peer_id.rs b/p2p/src/peer/peer_id.rs index c8aec7d..903d827 100644 --- a/p2p/src/peer/peer_id.rs +++ b/p2p/src/peer/peer_id.rs @@ -2,6 +2,10 @@ use bincode::{Decode, Encode}; use rand::{rngs::OsRng, RngCore}; use sha2::{Digest, Sha256}; +use karyons_core::key_pair::PublicKey; + +use crate::Error; + /// Represents a unique identifier for a peer. #[derive(Clone, Debug, Eq, PartialEq, Hash, Decode, Encode)] pub struct PeerID(pub [u8; 32]); @@ -39,3 +43,16 @@ impl From<[u8; 32]> for PeerID { PeerID(b) } } + +impl TryFrom for PeerID { + type Error = Error; + + fn try_from(pk: PublicKey) -> Result { + let pk: [u8; 32] = pk + .as_bytes() + .try_into() + .map_err(|_| Error::TryFromPublicKey("Failed to convert public key to [u8;32]"))?; + + Ok(PeerID(pk)) + } +} diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index a0079f2..dd7e669 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -11,23 +11,23 @@ use smol::{ }; use karyons_core::{ - async_utils::{TaskGroup, TaskResult}, - utils::decode, + async_util::{TaskGroup, TaskResult}, + util::decode, GlobalExecutor, }; use karyons_net::Conn; use crate::{ + codec::{Codec, CodecMsg}, config::Config, connection::{ConnDirection, ConnQueue}, - io_codec::{CodecMsg, IOCodec}, message::{get_msg_payload, NetMsg, NetMsgCmd, VerAckMsg, VerMsg}, monitor::{Monitor, PeerPoolEvent}, peer::{ArcPeer, Peer, PeerID}, protocol::{Protocol, ProtocolConstructor, ProtocolID}, protocols::PingProtocol, - utils::{version_match, Version, VersionInt}, + version::{version_match, Version, VersionInt}, Error, Result, }; @@ -155,10 +155,10 @@ impl PeerPool { disconnect_signal: Sender>, ) -> Result<()> { let endpoint = conn.peer_endpoint()?; - let io_codec = IOCodec::new(conn); + let codec = Codec::new(conn); // Do a handshake with the connection before creating a new peer. - let pid = self.do_handshake(&io_codec, conn_direction).await?; + let pid = self.do_handshake(&codec, conn_direction).await?; // TODO: Consider restricting the subnet for inbound connections if self.contains_peer(&pid).await { @@ -169,7 +169,7 @@ impl PeerPool { let peer = Peer::new( Arc::downgrade(self), &pid, - io_codec, + codec, endpoint.clone(), conn_direction.clone(), self.executor.clone(), @@ -235,20 +235,16 @@ impl PeerPool { } /// Initiate a handshake with a connection. - async fn do_handshake( - &self, - io_codec: &IOCodec, - conn_direction: &ConnDirection, - ) -> Result { + async fn do_handshake(&self, codec: &Codec, conn_direction: &ConnDirection) -> Result { match conn_direction { ConnDirection::Inbound => { - let result = self.wait_vermsg(io_codec).await; + let result = self.wait_vermsg(codec).await; match result { Ok(_) => { - self.send_verack(io_codec, true).await?; + self.send_verack(codec, true).await?; } Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { - self.send_verack(io_codec, false).await?; + self.send_verack(codec, false).await?; } _ => {} } @@ -256,14 +252,14 @@ impl PeerPool { } ConnDirection::Outbound => { - self.send_vermsg(io_codec).await?; - self.wait_verack(io_codec).await + self.send_vermsg(codec).await?; + self.wait_verack(codec).await } } } /// Send a Version message - async fn send_vermsg(&self, io_codec: &IOCodec) -> Result<()> { + async fn send_vermsg(&self, codec: &Codec) -> Result<()> { let pids = self.protocol_versions.read().await; let protocols = pids.iter().map(|p| (p.0.clone(), p.1.v.clone())).collect(); drop(pids); @@ -275,16 +271,16 @@ impl PeerPool { }; trace!("Send VerMsg"); - io_codec.write(NetMsgCmd::Version, &vermsg).await?; + codec.write(NetMsgCmd::Version, &vermsg).await?; Ok(()) } /// Wait for a Version message /// /// Returns the peer's ID upon successfully receiving the Version message. - async fn wait_vermsg(&self, io_codec: &IOCodec) -> Result { + async fn wait_vermsg(&self, codec: &Codec) -> Result { let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = io_codec.read_timeout(timeout).await?; + let msg: NetMsg = codec.read_timeout(timeout).await?; let payload = get_msg_payload!(Version, msg); let (vermsg, _) = decode::(&payload)?; @@ -300,23 +296,23 @@ impl PeerPool { } /// Send a Verack message - async fn send_verack(&self, io_codec: &IOCodec, ack: bool) -> Result<()> { + async fn send_verack(&self, codec: &Codec, ack: bool) -> Result<()> { let verack = VerAckMsg { peer_id: self.id.clone(), ack, }; trace!("Send VerAckMsg {:?}", verack); - io_codec.write(NetMsgCmd::Verack, &verack).await?; + codec.write(NetMsgCmd::Verack, &verack).await?; Ok(()) } /// Wait for a Verack message /// /// Returns the peer's ID upon successfully receiving the Verack message. - async fn wait_verack(&self, io_codec: &IOCodec) -> Result { + async fn wait_verack(&self, codec: &Codec) -> Result { let timeout = Duration::from_secs(self.config.handshake_timeout); - let msg: NetMsg = io_codec.read_timeout(timeout).await?; + let msg: NetMsg = codec.read_timeout(timeout).await?; let payload = get_msg_payload!(Verack, msg); let (verack, _) = decode::(&payload)?; diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index 770b695..8ddc685 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use karyons_core::{event::EventValue, Executor}; -use crate::{peer::ArcPeer, utils::Version, Result}; +use crate::{peer::ArcPeer, version::Version, Result}; pub type ArcProtocol = Arc; @@ -37,6 +37,7 @@ impl EventValue for ProtocolEvent { /// use async_trait::async_trait; /// use smol::Executor; /// +/// use karyons_core::key_pair::{KeyPair, KeyPairType}; /// use karyons_p2p::{ /// protocol::{ArcProtocol, Protocol, ProtocolID, ProtocolEvent}, /// Backend, PeerID, Config, Version, P2pError, ArcPeer}; @@ -84,14 +85,14 @@ impl EventValue for ProtocolEvent { /// } /// /// async { -/// let peer_id = PeerID::random(); +/// let key_pair = KeyPair::generate(&KeyPairType::Ed25519); /// let config = Config::default(); /// /// // Create a new Executor /// let ex = Arc::new(Executor::new()); /// /// // Create a new Backend -/// let backend = Backend::new(peer_id, config, ex); +/// let backend = Backend::new(&key_pair, config, ex); /// /// // Attach the NewProtocol /// let c = move |peer| NewProtocol::new(peer); diff --git a/p2p/src/protocols/ping.rs b/p2p/src/protocols/ping.rs index dc1b9a1..0a5488d 100644 --- a/p2p/src/protocols/ping.rs +++ b/p2p/src/protocols/ping.rs @@ -12,9 +12,9 @@ use smol::{ }; use karyons_core::{ - async_utils::{select, timeout, Either, TaskGroup, TaskResult}, + async_util::{select, timeout, Either, TaskGroup, TaskResult}, event::EventListener, - utils::decode, + util::decode, Executor, }; @@ -23,7 +23,7 @@ use karyons_net::NetError; use crate::{ peer::ArcPeer, protocol::{ArcProtocol, Protocol, ProtocolEvent, ProtocolID}, - utils::Version, + version::Version, Result, }; diff --git a/p2p/src/routing_table/entry.rs b/p2p/src/routing_table/entry.rs index b3f219f..c5fa65d 100644 --- a/p2p/src/routing_table/entry.rs +++ b/p2p/src/routing_table/entry.rs @@ -20,7 +20,7 @@ pub struct Entry { impl PartialEq for Entry { fn eq(&self, other: &Self) -> bool { - // XXX this should also compare both addresses (the self.addr == other.addr) + // TODO: this should also compare both addresses (the self.addr == other.addr) self.key == other.key } } diff --git a/p2p/src/routing_table/mod.rs b/p2p/src/routing_table/mod.rs index 5277c0a..cfc3128 100644 --- a/p2p/src/routing_table/mod.rs +++ b/p2p/src/routing_table/mod.rs @@ -1,5 +1,8 @@ +use std::net::IpAddr; + mod bucket; mod entry; + pub use bucket::{ Bucket, BucketEntry, EntryStatusFlag, CONNECTED_ENTRY, DISCONNECTED_ENTRY, INCOMPATIBLE_ENTRY, PENDING_ENTRY, UNREACHABLE_ENTRY, UNSTABLE_ENTRY, @@ -8,7 +11,7 @@ pub use entry::{xor_distance, Entry, Key}; use rand::{rngs::OsRng, seq::SliceRandom}; -use crate::utils::subnet_match; +use karyons_net::Addr; use bucket::BUCKET_SIZE; use entry::KEY_SIZE; @@ -262,6 +265,20 @@ impl RoutingTable { } } +/// Check if two addresses belong to the same subnet. +pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { + match (addr, other_addr) { + (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => { + // TODO: Consider moving this to a different place + if other_ip.is_loopback() && ip.is_loopback() { + return false; + } + ip.octets()[0..3] == other_ip.octets()[0..3] + } + _ => false, + } +} + #[cfg(test)] mod tests { use super::bucket::ALL_ENTRY; diff --git a/p2p/src/slots.rs b/p2p/src/slots.rs index 99f0a78..d3a1d0a 100644 --- a/p2p/src/slots.rs +++ b/p2p/src/slots.rs @@ -1,6 +1,6 @@ use std::sync::atomic::{AtomicUsize, Ordering}; -use karyons_core::async_utils::CondWait; +use karyons_core::async_util::CondWait; /// Manages available inbound and outbound slots. pub struct ConnectionSlots { diff --git a/p2p/src/tls_config.rs b/p2p/src/tls_config.rs new file mode 100644 index 0000000..f3b231a --- /dev/null +++ b/p2p/src/tls_config.rs @@ -0,0 +1,214 @@ +use std::sync::Arc; + +use async_rustls::rustls::{ + self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, client::ServerCertVerifier, + server::ClientCertVerifier, Certificate, CertificateError, Error::InvalidCertificate, + PrivateKey, SupportedCipherSuite, SupportedKxGroup, SupportedProtocolVersion, +}; +use log::error; +use x509_parser::{certificate::X509Certificate, parse_x509_certificate}; + +use karyons_core::key_pair::{KeyPair, KeyPairType, PublicKey}; + +use crate::{PeerID, Result}; + +// NOTE: This code needs a comprehensive audit. + +static PROTOCOL_VERSIONS: &[&SupportedProtocolVersion] = &[&rustls::version::TLS13]; +static CIPHER_SUITES: &[SupportedCipherSuite] = &[TLS13_CHACHA20_POLY1305_SHA256]; +static KX_GROUPS: &[&SupportedKxGroup] = &[&rustls::kx_group::X25519]; + +const BAD_SIGNATURE_ERR: rustls::Error = InvalidCertificate(CertificateError::BadSignature); +const BAD_ENCODING_ERR: rustls::Error = InvalidCertificate(CertificateError::BadEncoding); + +/// Returns a TLS client configuration. +pub fn tls_client_config( + key_pair: &KeyPair, + peer_id: Option, +) -> Result { + let (cert, private_key) = generate_cert(key_pair)?; + let server_verifier = SrvrCertVerifier { peer_id }; + let client_config = rustls::ClientConfig::builder() + .with_cipher_suites(CIPHER_SUITES) + .with_kx_groups(KX_GROUPS) + .with_protocol_versions(PROTOCOL_VERSIONS)? + .with_custom_certificate_verifier(Arc::new(server_verifier)) + .with_client_auth_cert(vec![cert], private_key)?; + + Ok(client_config) +} + +/// Returns a TLS server configuration. +pub fn tls_server_config(key_pair: &KeyPair) -> Result { + let (cert, private_key) = generate_cert(key_pair)?; + let client_verifier = CliCertVerifier {}; + let server_config = rustls::ServerConfig::builder() + .with_cipher_suites(CIPHER_SUITES) + .with_kx_groups(KX_GROUPS) + .with_protocol_versions(PROTOCOL_VERSIONS)? + .with_client_cert_verifier(Arc::new(client_verifier)) + .with_single_cert(vec![cert], private_key)?; + + Ok(server_config) +} + +/// Generates a certificate and returns both the certificate and the private key. +fn generate_cert(key_pair: &KeyPair) -> Result<(Certificate, PrivateKey)> { + let cert_key_pair = rcgen::KeyPair::generate(&rcgen::PKCS_ED25519)?; + let private_key = rustls::PrivateKey(cert_key_pair.serialize_der()); + + // Add a custom extension to the certificate: + // - Sign the certificate's public key with the provided key pair's public key + // - Append both the signature and the key pair's public key to the extension + let signature = key_pair.sign(&cert_key_pair.public_key_der()); + let ext_content = yasna::encode_der(&(key_pair.public().as_bytes().to_vec(), signature)); + // XXX: Not sure about the oid number ??? + let mut ext = rcgen::CustomExtension::from_oid_content(&[0, 0, 0, 0], ext_content); + ext.set_criticality(true); + + let mut params = rcgen::CertificateParams::new(vec![]); + params.alg = &rcgen::PKCS_ED25519; + params.key_pair = Some(cert_key_pair); + params.custom_extensions.push(ext); + + let cert = rustls::Certificate(rcgen::Certificate::from_params(params)?.serialize_der()?); + Ok((cert, private_key)) +} + +/// Verifies the given certification. +fn verify_cert(end_entity: &Certificate) -> std::result::Result { + // Parse the certificate. + let cert = parse_cert(end_entity)?; + + match cert.extensions().first() { + Some(ext) => { + // Extract the peer id (public key) and the signature from the extension. + let (public_key, signature): (Vec, Vec) = + yasna::decode_der(ext.value).map_err(|_| BAD_ENCODING_ERR)?; + + // Use the peer id (public key) to verify the extracted signature. + let public_key = PublicKey::from_bytes(&KeyPairType::Ed25519, &public_key) + .map_err(|_| BAD_ENCODING_ERR)?; + public_key + .verify(cert.public_key().raw, &signature) + .map_err(|_| BAD_SIGNATURE_ERR)?; + + // Verify the certificate signature. + verify_cert_signature( + &cert, + cert.tbs_certificate.as_ref(), + cert.signature_value.as_ref(), + )?; + + PeerID::try_from(public_key).map_err(|_| BAD_ENCODING_ERR) + } + None => Err(BAD_ENCODING_ERR), + } +} + +/// Parses the given x509 certificate. +fn parse_cert(end_entity: &Certificate) -> std::result::Result { + let (_, cert) = parse_x509_certificate(end_entity.as_ref()).map_err(|_| BAD_ENCODING_ERR)?; + + if !cert.validity().is_valid() { + return Err(InvalidCertificate(CertificateError::NotValidYet)); + } + + Ok(cert) +} + +/// Verifies the signature of the given certificate. +fn verify_cert_signature( + cert: &X509Certificate, + message: &[u8], + signature: &[u8], +) -> std::result::Result<(), rustls::Error> { + let public_key = PublicKey::from_bytes( + &KeyPairType::Ed25519, + cert.tbs_certificate.subject_pki.subject_public_key.as_ref(), + ) + .map_err(|_| BAD_ENCODING_ERR)?; + + public_key + .verify(message, signature) + .map_err(|_| BAD_SIGNATURE_ERR) +} + +struct SrvrCertVerifier { + peer_id: Option, +} + +impl ServerCertVerifier for SrvrCertVerifier { + fn verify_server_cert( + &self, + end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> std::result::Result { + let peer_id = match verify_cert(end_entity) { + Ok(pid) => pid, + Err(err) => { + error!("Failed to verify cert: {err}"); + return Err(err); + } + }; + + // Verify that the peer id in the certificate's extension matches the + // one the client intends to connect to. + // Both should be equal for establishing a fully secure connection. + if let Some(pid) = &self.peer_id { + if pid != &peer_id { + return Err(InvalidCertificate( + CertificateError::ApplicationVerificationFailure, + )); + } + } + + Ok(rustls::client::ServerCertVerified::assertion()) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &Certificate, + dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + let cert = parse_cert(cert)?; + verify_cert_signature(&cert, message, dss.signature())?; + Ok(rustls::client::HandshakeSignatureValid::assertion()) + } +} + +struct CliCertVerifier {} +impl ClientCertVerifier for CliCertVerifier { + fn verify_client_cert( + &self, + end_entity: &Certificate, + _intermediates: &[Certificate], + _now: std::time::SystemTime, + ) -> std::result::Result { + if let Err(err) = verify_cert(end_entity) { + error!("Failed to verify cert: {err}"); + return Err(err); + }; + Ok(rustls::server::ClientCertVerified::assertion()) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &Certificate, + dss: &rustls::DigitallySignedStruct, + ) -> std::result::Result { + let cert = parse_cert(cert)?; + verify_cert_signature(&cert, message, dss.signature())?; + Ok(rustls::client::HandshakeSignatureValid::assertion()) + } + + fn client_auth_root_subjects(&self) -> &[rustls::DistinguishedName] { + &[] + } +} diff --git a/p2p/src/utils/mod.rs b/p2p/src/utils/mod.rs deleted file mode 100644 index e8ff9d0..0000000 --- a/p2p/src/utils/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -mod version; - -pub use version::{version_match, Version, VersionInt}; - -use std::net::IpAddr; - -use karyons_net::Addr; - -/// Check if two addresses belong to the same subnet. -pub fn subnet_match(addr: &Addr, other_addr: &Addr) -> bool { - match (addr, other_addr) { - (Addr::Ip(IpAddr::V4(ip)), Addr::Ip(IpAddr::V4(other_ip))) => { - // XXX Consider moving this to a different location - if other_ip.is_loopback() && ip.is_loopback() { - return false; - } - ip.octets()[0..3] == other_ip.octets()[0..3] - } - _ => false, - } -} diff --git a/p2p/src/utils/version.rs b/p2p/src/utils/version.rs deleted file mode 100644 index a101b28..0000000 --- a/p2p/src/utils/version.rs +++ /dev/null @@ -1,93 +0,0 @@ -use std::str::FromStr; - -use bincode::{Decode, Encode}; -use semver::VersionReq; - -use crate::{Error, Result}; - -/// Represents the network version and protocol version used in karyons p2p. -/// -/// # Example -/// -/// ``` -/// use karyons_p2p::Version; -/// -/// let version: Version = "0.2.0, >0.1.0".parse().unwrap(); -/// -/// let version: Version = "0.2.0".parse().unwrap(); -/// -/// ``` -#[derive(Debug, Clone)] -pub struct Version { - pub v: VersionInt, - pub req: VersionReq, -} - -impl Version { - /// Creates a new Version - pub fn new(v: VersionInt, req: VersionReq) -> Self { - Self { v, req } - } -} - -#[derive(Debug, Decode, Encode, Clone)] -pub struct VersionInt { - major: u64, - minor: u64, - patch: u64, -} - -impl FromStr for Version { - type Err = Error; - - fn from_str(s: &str) -> Result { - let v: Vec<&str> = s.split(", ").collect(); - if v.is_empty() || v.len() > 2 { - return Err(Error::ParseError(format!("Invalid version{s}"))); - } - - let version: VersionInt = v[0].parse()?; - let req: VersionReq = if v.len() > 1 { v[1] } else { v[0] }.parse()?; - - Ok(Self { v: version, req }) - } -} - -impl FromStr for VersionInt { - type Err = Error; - - fn from_str(s: &str) -> Result { - let v: Vec<&str> = s.split('.').collect(); - if v.len() < 2 || v.len() > 3 { - return Err(Error::ParseError(format!("Invalid version{s}"))); - } - - let major = v[0].parse::()?; - let minor = v[1].parse::()?; - let patch = v.get(2).unwrap_or(&"0").parse::()?; - - Ok(Self { - major, - minor, - patch, - }) - } -} - -impl std::fmt::Display for VersionInt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}.{}.{}", self.major, self.minor, self.patch) - } -} - -impl From for semver::Version { - fn from(v: VersionInt) -> Self { - semver::Version::new(v.major, v.minor, v.patch) - } -} - -/// Check if a version satisfies a version request. -pub fn version_match(version_req: &VersionReq, version: &VersionInt) -> bool { - let version: semver::Version = version.clone().into(); - version_req.matches(&version) -} diff --git a/p2p/src/version.rs b/p2p/src/version.rs new file mode 100644 index 0000000..a101b28 --- /dev/null +++ b/p2p/src/version.rs @@ -0,0 +1,93 @@ +use std::str::FromStr; + +use bincode::{Decode, Encode}; +use semver::VersionReq; + +use crate::{Error, Result}; + +/// Represents the network version and protocol version used in karyons p2p. +/// +/// # Example +/// +/// ``` +/// use karyons_p2p::Version; +/// +/// let version: Version = "0.2.0, >0.1.0".parse().unwrap(); +/// +/// let version: Version = "0.2.0".parse().unwrap(); +/// +/// ``` +#[derive(Debug, Clone)] +pub struct Version { + pub v: VersionInt, + pub req: VersionReq, +} + +impl Version { + /// Creates a new Version + pub fn new(v: VersionInt, req: VersionReq) -> Self { + Self { v, req } + } +} + +#[derive(Debug, Decode, Encode, Clone)] +pub struct VersionInt { + major: u64, + minor: u64, + patch: u64, +} + +impl FromStr for Version { + type Err = Error; + + fn from_str(s: &str) -> Result { + let v: Vec<&str> = s.split(", ").collect(); + if v.is_empty() || v.len() > 2 { + return Err(Error::ParseError(format!("Invalid version{s}"))); + } + + let version: VersionInt = v[0].parse()?; + let req: VersionReq = if v.len() > 1 { v[1] } else { v[0] }.parse()?; + + Ok(Self { v: version, req }) + } +} + +impl FromStr for VersionInt { + type Err = Error; + + fn from_str(s: &str) -> Result { + let v: Vec<&str> = s.split('.').collect(); + if v.len() < 2 || v.len() > 3 { + return Err(Error::ParseError(format!("Invalid version{s}"))); + } + + let major = v[0].parse::()?; + let minor = v[1].parse::()?; + let patch = v.get(2).unwrap_or(&"0").parse::()?; + + Ok(Self { + major, + minor, + patch, + }) + } +} + +impl std::fmt::Display for VersionInt { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.major, self.minor, self.patch) + } +} + +impl From for semver::Version { + fn from(v: VersionInt) -> Self { + semver::Version::new(v.major, v.minor, v.patch) + } +} + +/// Check if a version satisfies a version request. +pub fn version_match(version_req: &VersionReq, version: &VersionInt) -> bool { + let version: semver::Version = version.clone().into(); + version_req.matches(&version) +} -- cgit v1.2.3