aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--net/src/endpoint.rs148
-rw-r--r--p2p/src/connection.rs3
-rw-r--r--p2p/src/connector.rs8
-rw-r--r--p2p/src/discovery/lookup.rs27
-rw-r--r--p2p/src/discovery/mod.rs1
-rw-r--r--p2p/src/error.rs3
-rw-r--r--p2p/src/listener.rs8
7 files changed, 131 insertions, 67 deletions
diff --git a/net/src/endpoint.rs b/net/src/endpoint.rs
index c3626ec..dff703d 100644
--- a/net/src/endpoint.rs
+++ b/net/src/endpoint.rs
@@ -45,6 +45,98 @@ pub enum Endpoint {
Unix(PathBuf),
}
+impl Endpoint {
+ /// Creates a new TCP endpoint from a `SocketAddr`.
+ pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint {
+ Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
+ }
+
+ /// Creates a new UDP endpoint from a `SocketAddr`.
+ pub fn new_udp_addr(addr: SocketAddr) -> Endpoint {
+ Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
+ }
+
+ /// Creates a new TLS endpoint from a `SocketAddr`.
+ pub fn new_tls_addr(addr: SocketAddr) -> Endpoint {
+ Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
+ }
+
+ /// Creates a new WS endpoint from a `SocketAddr`.
+ pub fn new_ws_addr(addr: SocketAddr) -> Endpoint {
+ Endpoint::Ws(Addr::Ip(addr.ip()), addr.port())
+ }
+
+ /// Creates a new WSS endpoint from a `SocketAddr`.
+ pub fn new_wss_addr(addr: SocketAddr) -> Endpoint {
+ Endpoint::Wss(Addr::Ip(addr.ip()), addr.port())
+ }
+
+ /// Creates a new Unix endpoint from a `UnixSocketAddr`.
+ pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint {
+ Endpoint::Unix(addr.to_path_buf())
+ }
+
+ #[inline]
+ /// Checks if the `Endpoint` is of type `Tcp`.
+ pub fn is_tcp(&self) -> bool {
+ matches!(self, Endpoint::Tcp(..))
+ }
+
+ #[inline]
+ /// Checks if the `Endpoint` is of type `Tls`.
+ pub fn is_tls(&self) -> bool {
+ matches!(self, Endpoint::Tls(..))
+ }
+
+ #[inline]
+ /// Checks if the `Endpoint` is of type `Ws`.
+ pub fn is_ws(&self) -> bool {
+ matches!(self, Endpoint::Ws(..))
+ }
+
+ #[inline]
+ /// Checks if the `Endpoint` is of type `Wss`.
+ pub fn is_wss(&self) -> bool {
+ matches!(self, Endpoint::Wss(..))
+ }
+
+ #[inline]
+ /// Checks if the `Endpoint` is of type `Udp`.
+ pub fn is_udp(&self) -> bool {
+ matches!(self, Endpoint::Udp(..))
+ }
+
+ #[inline]
+ /// Checks if the `Endpoint` is of type `Unix`.
+ pub fn is_unix(&self) -> bool {
+ matches!(self, Endpoint::Unix(..))
+ }
+
+ /// Returns the `Port` of the endpoint.
+ pub fn port(&self) -> Result<&Port> {
+ match self {
+ Endpoint::Tcp(_, port)
+ | Endpoint::Udp(_, port)
+ | Endpoint::Tls(_, port)
+ | Endpoint::Ws(_, port)
+ | Endpoint::Wss(_, port) => Ok(port),
+ _ => Err(Error::TryFromEndpoint),
+ }
+ }
+
+ /// Returns the `Addr` of the endpoint.
+ pub fn addr(&self) -> Result<&Addr> {
+ match self {
+ Endpoint::Tcp(addr, _)
+ | Endpoint::Udp(addr, _)
+ | Endpoint::Tls(addr, _)
+ | Endpoint::Ws(addr, _)
+ | Endpoint::Wss(addr, _) => Ok(addr),
+ _ => Err(Error::TryFromEndpoint),
+ }
+ }
+}
+
impl std::fmt::Display for Endpoint {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
@@ -152,62 +244,6 @@ impl FromStr for Endpoint {
}
}
-impl Endpoint {
- /// Creates a new TCP endpoint from a `SocketAddr`.
- pub fn new_tcp_addr(addr: SocketAddr) -> Endpoint {
- Endpoint::Tcp(Addr::Ip(addr.ip()), addr.port())
- }
-
- /// Creates a new UDP endpoint from a `SocketAddr`.
- pub fn new_udp_addr(addr: SocketAddr) -> Endpoint {
- Endpoint::Udp(Addr::Ip(addr.ip()), addr.port())
- }
-
- /// Creates a new TLS endpoint from a `SocketAddr`.
- pub fn new_tls_addr(addr: SocketAddr) -> Endpoint {
- Endpoint::Tls(Addr::Ip(addr.ip()), addr.port())
- }
-
- /// Creates a new WS endpoint from a `SocketAddr`.
- pub fn new_ws_addr(addr: SocketAddr) -> Endpoint {
- Endpoint::Ws(Addr::Ip(addr.ip()), addr.port())
- }
-
- /// Creates a new WSS endpoint from a `SocketAddr`.
- pub fn new_wss_addr(addr: SocketAddr) -> Endpoint {
- Endpoint::Wss(Addr::Ip(addr.ip()), addr.port())
- }
-
- /// Creates a new Unix endpoint from a `UnixSocketAddr`.
- pub fn new_unix_addr(addr: &std::path::Path) -> Endpoint {
- Endpoint::Unix(addr.to_path_buf())
- }
-
- /// Returns the `Port` of the endpoint.
- pub fn port(&self) -> Result<&Port> {
- match self {
- Endpoint::Tcp(_, port)
- | Endpoint::Udp(_, port)
- | Endpoint::Tls(_, port)
- | Endpoint::Ws(_, port)
- | Endpoint::Wss(_, port) => Ok(port),
- _ => Err(Error::TryFromEndpoint),
- }
- }
-
- /// Returns the `Addr` of the endpoint.
- pub fn addr(&self) -> Result<&Addr> {
- match self {
- Endpoint::Tcp(addr, _)
- | Endpoint::Udp(addr, _)
- | Endpoint::Tls(addr, _)
- | Endpoint::Ws(addr, _)
- | Endpoint::Wss(addr, _) => Ok(addr),
- _ => Err(Error::TryFromEndpoint),
- }
- }
-}
-
/// Addr defines a type for an address, either IP or domain.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Encode, Decode)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
diff --git a/p2p/src/connection.rs b/p2p/src/connection.rs
index c1a7a8c..f2e9d1e 100644
--- a/p2p/src/connection.rs
+++ b/p2p/src/connection.rs
@@ -74,8 +74,7 @@ impl Connection {
pub async fn recv<P: Protocol>(&self) -> Result<ProtocolEvent> {
match self.listeners.get(&P::id()) {
Some(l) => l.recv().await.map_err(Error::from),
- // TODO
- None => todo!(),
+ None => Err(Error::UnsupportedProtocol(P::id())),
}
}
diff --git a/p2p/src/connector.rs b/p2p/src/connector.rs
index cfa661b..98cdfc7 100644
--- a/p2p/src/connector.rs
+++ b/p2p/src/connector.rs
@@ -148,6 +148,10 @@ impl Connector {
async fn dial(&self, endpoint: &Endpoint, peer_id: &Option<PeerID>) -> Result<Conn<NetMsg>> {
if self.enable_tls {
+ if !endpoint.is_tcp() && !endpoint.is_tls() {
+ return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
+ }
+
let tls_config = tls::ClientTlsConfig {
tcp_config: Default::default(),
client_config: tls_client_config(&self.key_pair, peer_id.clone())?,
@@ -157,6 +161,10 @@ impl Connector {
.await
.map(|l| Box::new(l) as karyon_net::Conn<NetMsg>)
} else {
+ if !endpoint.is_tcp() {
+ return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
+ }
+
tcp::dial(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new())
.await
.map(|l| Box::new(l) as karyon_net::Conn<NetMsg>)
diff --git a/p2p/src/discovery/lookup.rs b/p2p/src/discovery/lookup.rs
index 47a1d09..ba15da3 100644
--- a/p2p/src/discovery/lookup.rs
+++ b/p2p/src/discovery/lookup.rs
@@ -41,6 +41,9 @@ pub struct LookupService {
/// Resolved listen endpoint
listen_endpoint: RwLock<Option<Endpoint>>,
+ /// Resolved discovery endpoint
+ discovery_endpoint: RwLock<Option<Endpoint>>,
+
/// Holds the configuration for the P2P network.
config: Arc<Config>,
@@ -52,7 +55,6 @@ impl LookupService {
/// Creates a new lookup service
pub fn new(
key_pair: &KeyPair,
- id: &PeerID,
table: Arc<RoutingTable>,
config: Arc<Config>,
monitor: Arc<Monitor>,
@@ -78,13 +80,18 @@ impl LookupService {
ex,
);
+ let id = key_pair
+ .public()
+ .try_into()
+ .expect("Get PeerID from KeyPair");
Self {
- id: id.clone(),
+ id,
table,
listener,
connector,
outbound_slots,
listen_endpoint: RwLock::new(None),
+ discovery_endpoint: RwLock::new(None),
config,
monitor,
}
@@ -98,19 +105,23 @@ impl LookupService {
/// Set the resolved listen endpoint.
pub fn set_listen_endpoint(&self, resolved_endpoint: &Endpoint) -> Result<()> {
- let resolved_endpoint = Endpoint::Tcp(
+ let discovery_endpoint = Endpoint::Tcp(
resolved_endpoint.addr()?.clone(),
self.config.discovery_port,
);
- *self.listen_endpoint.write() = Some(resolved_endpoint);
+ *self.listen_endpoint.write() = Some(resolved_endpoint.clone());
+ *self.discovery_endpoint.write() = Some(discovery_endpoint.clone());
Ok(())
}
- /// Get the listening endpoint.
pub fn listen_endpoint(&self) -> Option<Endpoint> {
self.listen_endpoint.read().clone()
}
+ pub fn discovery_endpoint(&self) -> Option<Endpoint> {
+ self.discovery_endpoint.read().clone()
+ }
+
/// Shuts down the lookup service.
pub async fn shutdown(&self) {
self.connector.shutdown().await;
@@ -278,7 +289,7 @@ impl LookupService {
trace!("Send Peer msg");
if let Some(endpoint) = self.listen_endpoint() {
- self.send_peer_msg(&conn, endpoint.clone()).await?;
+ self.send_peer_msg(&conn, endpoint).await?;
}
trace!("Send Shutdown msg");
@@ -289,8 +300,8 @@ impl LookupService {
/// Start a listener.
async fn start_listener(self: &Arc<Self>) -> Result<()> {
- let endpoint: Endpoint = match self.listen_endpoint() {
- Some(e) => e.clone(),
+ let endpoint: Endpoint = match self.discovery_endpoint() {
+ Some(e) => e,
None => return Ok(()),
};
diff --git a/p2p/src/discovery/mod.rs b/p2p/src/discovery/mod.rs
index a81a817..b4c20bf 100644
--- a/p2p/src/discovery/mod.rs
+++ b/p2p/src/discovery/mod.rs
@@ -80,7 +80,6 @@ impl Discovery {
let lookup_service = Arc::new(LookupService::new(
key_pair,
- peer_id,
table.clone(),
config.clone(),
monitor.clone(),
diff --git a/p2p/src/error.rs b/p2p/src/error.rs
index a490b57..cc30aff 100644
--- a/p2p/src/error.rs
+++ b/p2p/src/error.rs
@@ -11,6 +11,9 @@ pub enum Error {
#[error("Unsupported protocol error: {0}")]
UnsupportedProtocol(String),
+ #[error("Unsupported Endpoint: {0}")]
+ UnsupportedEndpoint(String),
+
#[error("PeerID try from PublicKey Error")]
PeerIDTryFromPublicKey,
diff --git a/p2p/src/listener.rs b/p2p/src/listener.rs
index 8a5deaa..347099d 100644
--- a/p2p/src/listener.rs
+++ b/p2p/src/listener.rs
@@ -157,6 +157,10 @@ impl Listener {
async fn listen(&self, endpoint: &Endpoint) -> Result<karyon_net::Listener<NetMsg>> {
if self.enable_tls {
+ if !endpoint.is_tcp() && !endpoint.is_tls() {
+ return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
+ }
+
let tls_config = tls::ServerTlsConfig {
tcp_config: Default::default(),
server_config: tls_server_config(&self.key_pair)?,
@@ -165,6 +169,10 @@ impl Listener {
.await
.map(|l| Box::new(l) as karyon_net::Listener<NetMsg>)
} else {
+ if !endpoint.is_tcp() {
+ return Err(Error::UnsupportedEndpoint(endpoint.to_string()));
+ }
+
tcp::listen(endpoint, tcp::TcpConfig::default(), NetMsgCodec::new())
.await
.map(|l| Box::new(l) as karyon_net::Listener<NetMsg>)