diff options
Diffstat (limited to 'p2p/src/peer_pool.rs')
| -rw-r--r-- | p2p/src/peer_pool.rs | 56 | 
1 files changed, 35 insertions, 21 deletions
diff --git a/p2p/src/peer_pool.rs b/p2p/src/peer_pool.rs index e2a9de7..a0079f2 100644 --- a/p2p/src/peer_pool.rs +++ b/p2p/src/peer_pool.rs @@ -100,18 +100,15 @@ impl PeerPool {      pub async fn listen_loop(self: Arc<Self>) {          loop {              let new_conn = self.conn_queue.next().await; -            let disconnect_signal = new_conn.disconnect_signal; +            let signal = new_conn.disconnect_signal;              let result = self -                .new_peer( -                    new_conn.conn, -                    &new_conn.direction, -                    disconnect_signal.clone(), -                ) +                .new_peer(new_conn.conn, &new_conn.direction, signal.clone())                  .await; +            // Only send a disconnect signal if there is an error when adding a peer.              if result.is_err() { -                let _ = disconnect_signal.send(()).await; +                let _ = signal.send(result).await;              }          }      } @@ -155,12 +152,12 @@ impl PeerPool {          self: &Arc<Self>,          conn: Conn,          conn_direction: &ConnDirection, -        disconnect_signal: Sender<()>, -    ) -> Result<PeerID> { +        disconnect_signal: Sender<Result<()>>, +    ) -> Result<()> {          let endpoint = conn.peer_endpoint()?;          let io_codec = IOCodec::new(conn); -        // Do a handshake with a connection before creating a new peer. +        // Do a handshake with the connection before creating a new peer.          let pid = self.do_handshake(&io_codec, conn_direction).await?;          // TODO: Consider restricting the subnet for inbound connections @@ -184,11 +181,11 @@ impl PeerPool {          let selfc = self.clone();          let pid_c = pid.clone();          let on_disconnect = |result| async move { -            if let TaskResult::Completed(_) = result { +            if let TaskResult::Completed(result) = result {                  if let Err(err) = selfc.remove_peer(&pid_c).await {                      error!("Failed to remove peer {pid_c}: {err}");                  } -                let _ = disconnect_signal.send(()).await; +                let _ = disconnect_signal.send(result).await;              }          }; @@ -200,7 +197,8 @@ impl PeerPool {          self.monitor              .notify(&PeerPoolEvent::NewPeer(pid.clone()).into())              .await; -        Ok(pid) + +        Ok(())      }      /// Checks if the peer list contains a peer with the given peer id @@ -244,10 +242,19 @@ impl PeerPool {      ) -> Result<PeerID> {          match conn_direction {              ConnDirection::Inbound => { -                let pid = self.wait_vermsg(io_codec).await?; -                self.send_verack(io_codec).await?; -                Ok(pid) +                let result = self.wait_vermsg(io_codec).await; +                match result { +                    Ok(_) => { +                        self.send_verack(io_codec, true).await?; +                    } +                    Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => { +                        self.send_verack(io_codec, false).await?; +                    } +                    _ => {} +                } +                result              } +              ConnDirection::Outbound => {                  self.send_vermsg(io_codec).await?;                  self.wait_verack(io_codec).await @@ -293,10 +300,13 @@ impl PeerPool {      }      /// Send a Verack message -    async fn send_verack(&self, io_codec: &IOCodec) -> Result<()> { -        let verack = VerAckMsg(self.id.clone()); +    async fn send_verack(&self, io_codec: &IOCodec, ack: bool) -> Result<()> { +        let verack = VerAckMsg { +            peer_id: self.id.clone(), +            ack, +        }; -        trace!("Send VerAckMsg"); +        trace!("Send VerAckMsg {:?}", verack);          io_codec.write(NetMsgCmd::Verack, &verack).await?;          Ok(())      } @@ -311,8 +321,12 @@ impl PeerPool {          let payload = get_msg_payload!(Verack, msg);          let (verack, _) = decode::<VerAckMsg>(&payload)?; -        trace!("Received VerAckMsg from: {}", verack.0); -        Ok(verack.0) +        if !verack.ack { +            return Err(Error::IncompatiblePeer); +        } + +        trace!("Received VerAckMsg from: {}", verack.peer_id); +        Ok(verack.peer_id)      }      /// Check if the new connection has compatible protocols.  | 
