diff options
Diffstat (limited to 'p2p/src/peer_pool.rs')
-rw-r--r-- | p2p/src/peer_pool.rs | 56 |
1 files changed, 35 insertions, 21 deletions
diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index e2a9de7..a0079f2 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -100,18 +100,15 @@ impl PeerPool { pub async fn listen_loop(self: Arc<Self>) { loop { let new_conn = self.conn_queue.next().await; - let disconnect_signal = new_conn.disconnect_signal; + let signal = new_conn.disconnect_signal; let result = self - .new_peer( - new_conn.conn, - &new_conn.direction, - disconnect_signal.clone(), - ) + .new_peer(new_conn.conn, &new_conn.direction, signal.clone()) .await; + // Only send a disconnect signal if there is an error when adding a peer. if result.is_err() { - let _ = disconnect_signal.send(()).await; + let _ = signal.send(result).await; } } } @@ -155,12 +152,12 @@ impl PeerPool { self: &Arc<Self>, conn: Conn, conn_direction: &ConnDirection, - disconnect_signal: Sender<()>, - ) -> Result<PeerID> { + disconnect_signal: Sender<Result<()>>, + ) -> Result<()> { let endpoint = conn.peer_endpoint()?; let io_codec = IOCodec::new(conn); - // Do a handshake with a connection before creating a new peer. + // Do a handshake with the connection before creating a new peer. let pid = self.do_handshake(&io_codec, conn_direction).await?; // TODO: Consider restricting the subnet for inbound connections @@ -184,11 +181,11 @@ impl PeerPool { let selfc = self.clone(); let pid_c = pid.clone(); let on_disconnect = |result| async move { - if let TaskResult::Completed(_) = result { + if let TaskResult::Completed(result) = result { if let Err(err) = selfc.remove_peer(&pid_c).await { error!("Failed to remove peer {pid_c}: {err}"); } - let _ = disconnect_signal.send(()).await; + let _ = disconnect_signal.send(result).await; } }; @@ -200,7 +197,8 @@ impl PeerPool { self.monitor .notify(&PeerPoolEvent::NewPeer(pid.clone()).into()) .await; - Ok(pid) + + Ok(()) } /// Checks if the peer list contains a peer with the given peer id @@ -244,10 +242,19 @@ impl PeerPool { ) -> Result<PeerID> { match conn_direction { ConnDirection::Inbound => { - let pid = self.wait_vermsg(io_codec).await?; - self.send_verack(io_codec).await?; - Ok(pid) + let result = self.wait_vermsg(io_codec).await; + match result { + Ok(_) => { + self.send_verack(io_codec, true).await?; + } + Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { + self.send_verack(io_codec, false).await?; + } + _ => {} + } + result } + ConnDirection::Outbound => { self.send_vermsg(io_codec).await?; self.wait_verack(io_codec).await @@ -293,10 +300,13 @@ impl PeerPool { } /// Send a Verack message - async fn send_verack(&self, io_codec: &IOCodec) -> Result<()> { - let verack = VerAckMsg(self.id.clone()); + async fn send_verack(&self, io_codec: &IOCodec, ack: bool) -> Result<()> { + let verack = VerAckMsg { + peer_id: self.id.clone(), + ack, + }; - trace!("Send VerAckMsg"); + trace!("Send VerAckMsg {:?}", verack); io_codec.write(NetMsgCmd::Verack, &verack).await?; Ok(()) } @@ -311,8 +321,12 @@ impl PeerPool { let payload = get_msg_payload!(Verack, msg); let (verack, _) = decode::<VerAckMsg>(&payload)?; - trace!("Received VerAckMsg from: {}", verack.0); - Ok(verack.0) + if !verack.ack { + return Err(Error::IncompatiblePeer); + } + + trace!("Received VerAckMsg from: {}", verack.peer_id); + Ok(verack.peer_id) } /// Check if the new connection has compatible protocols. |