aboutsummaryrefslogtreecommitdiff
path: root/p2p/src/protocols/handshake.rs
blob: b3fe989e2ba2ce4e6033991ce52a502abb7ef389 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use std::{collections::HashMap, sync::Arc, time::Duration};

use async_trait::async_trait;
use log::trace;

use karyon_core::{async_util::timeout, util::decode};

use crate::{
    message::{NetMsg, NetMsgCmd, VerAckMsg, VerMsg},
    peer::Peer,
    protocol::{InitProtocol, ProtocolID},
    version::{version_match, VersionInt},
    Error, PeerID, Result, Version,
};

/// Handshake run on a freshly established connection.
///
/// Exchanges `Version` / `Verack` messages with the remote peer to agree on
/// the system version and on the set of supported protocols.
pub struct HandshakeProtocol {
    /// The peer (and its underlying connection) the handshake runs over.
    peer: Arc<Peer>,
    /// Locally registered protocols and their versions. Used both to build
    /// the advertised `Version` message and to validate the protocols the
    /// remote advertises.
    protocols: HashMap<ProtocolID, Version>,
}

#[async_trait]
impl InitProtocol for HandshakeProtocol {
    type T = Result<PeerID>;

    /// Initiate a handshake with a connection.
    ///
    /// The outbound side sends a `Version` message and then waits for the
    /// remote's `Verack`. The inbound side waits for the remote's `Version`
    /// message, validates it, and answers with a `Verack` whose `ack` flag
    /// reports whether the version/protocols were accepted.
    ///
    /// Returns the remote peer's ID on success.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidMsg` if the received message is not the one expected
    ///   for this side of the connection.
    /// - Errors from `timeout` / `recv_inner` if no message arrives within
    ///   `handshake_timeout` seconds or receiving fails.
    /// - Validation errors for the received message (e.g.
    ///   `Error::IncompatibleVersion`, `Error::UnsupportedProtocol`,
    ///   `Error::IncompatiblePeer`, or a decode error).
    /// - Errors from sending our own `Version` / `Verack` message.
    async fn init(self: Arc<Self>) -> Self::T {
        trace!("Init Handshake: {}", self.peer.remote_endpoint());

        let inbound = self.peer.is_inbound();
        if !inbound {
            self.send_vermsg().await?;
        }

        let t = Duration::from_secs(self.peer.config().handshake_timeout);
        let msg: NetMsg = timeout(t, self.peer.conn.recv_inner()).await??;
        match msg.header.command {
            // Only the inbound side should receive a Version message; the
            // outbound side has already sent its own and awaits a Verack.
            NetMsgCmd::Version if inbound => {
                let result = self.validate_version_msg(&msg).await;
                match result {
                    Ok(_) => self.send_verack(true).await?,
                    // Let the remote know it was rejected before the
                    // connection is dropped.
                    Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => {
                        self.send_verack(false).await?
                    }
                    _ => {}
                };
                result
            }
            // Only the outbound side should receive a Verack. Accepting one on
            // an inbound connection would let the remote skip version and
            // protocol validation entirely.
            NetMsgCmd::Verack if !inbound => self.validate_verack_msg(&msg).await,
            cmd => Err(Error::InvalidMsg(format!("unexpected msg found {:?}", cmd))),
        }
    }
}

impl HandshakeProtocol {
    pub fn new(peer: Arc<Peer>, protocols: HashMap<ProtocolID, Version>) -> Arc<Self> {
        Arc::new(Self { peer, protocols })
    }

    /// Sends a Version message
    async fn send_vermsg(&self) -> Result<()> {
        let protocols = self
            .protocols
            .clone()
            .into_iter()
            .map(|p| (p.0, p.1.v))
            .collect();

        let vermsg = VerMsg {
            peer_id: self.peer.own_id().clone(),
            protocols,
            version: self.peer.config().version.v.clone(),
        };

        trace!("Send VerMsg");
        self.peer
            .conn
            .send_inner(NetMsg::new(NetMsgCmd::Version, &vermsg)?)
            .await?;
        Ok(())
    }

    /// Sends a Verack message
    async fn send_verack(&self, ack: bool) -> Result<()> {
        let verack = VerAckMsg {
            peer_id: self.peer.own_id().clone(),
            ack,
        };

        trace!("Send VerAckMsg {:?}", verack);
        self.peer
            .conn
            .send_inner(NetMsg::new(NetMsgCmd::Verack, &verack)?)
            .await?;
        Ok(())
    }

    /// Validates the given version msg
    async fn validate_version_msg(&self, msg: &NetMsg) -> Result<PeerID> {
        let (vermsg, _) = decode::<VerMsg>(&msg.payload)?;

        if !version_match(&self.peer.config().version.req, &vermsg.version) {
            return Err(Error::IncompatibleVersion("system: {}".into()));
        }

        self.protocols_match(&vermsg.protocols).await?;

        trace!("Received VerMsg from: {}", vermsg.peer_id);
        Ok(vermsg.peer_id)
    }

    /// Validates the given verack msg
    async fn validate_verack_msg(&self, msg: &NetMsg) -> Result<PeerID> {
        let (verack, _) = decode::<VerAckMsg>(&msg.payload)?;

        if !verack.ack {
            return Err(Error::IncompatiblePeer);
        }

        trace!("Received VerAckMsg from: {}", verack.peer_id);
        Ok(verack.peer_id)
    }

    /// Check if the new connection has compatible protocols.
    async fn protocols_match(&self, protocols: &HashMap<ProtocolID, VersionInt>) -> Result<()> {
        for (n, pv) in protocols.iter() {
            match self.protocols.get(n) {
                Some(v) => {
                    if !version_match(&v.req, pv) {
                        return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}")));
                    }
                }
                None => {
                    return Err(Error::UnsupportedProtocol(n.to_string()));
                }
            }
        }
        Ok(())
    }
}