aboutsummaryrefslogtreecommitdiff
path: root/p2p/src/peer/mod.rs
blob: 85cd558053e3dce827d42f12a4920b1894f3ac58 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
mod peer_id;

pub use peer_id::PeerID;

use std::sync::Arc;

use log::{error, trace};
use smol::{
    channel::{self, Receiver, Sender},
    lock::RwLock,
};

use karyons_core::{
    async_utils::{select, Either, TaskGroup, TaskResult},
    event::{ArcEventSys, EventListener, EventSys},
    utils::{decode, encode},
    GlobalExecutor,
};

use karyons_net::Endpoint;

use crate::{
    connection::ConnDirection,
    io_codec::{CodecMsg, IOCodec},
    message::{NetMsgCmd, ProtocolMsg, ShutdownMsg},
    peer_pool::{ArcPeerPool, WeakPeerPool},
    protocol::{Protocol, ProtocolEvent, ProtocolID},
    Config, Error, Result,
};

/// Shared, reference-counted handle to a [`Peer`]; this is what
/// [`Peer::new`] returns.
pub type ArcPeer = Arc<Peer>;

/// State kept for a single peer connection: its identity, the codec used to
/// talk to it, the protocols started on it, and the tasks running them.
pub struct Peer {
    /// Peer's ID
    id: PeerID,

    /// A weak pointer to `PeerPool`.
    ///
    /// It is upgraded on demand by `peer_pool()`, which panics if the
    /// upgrade fails.
    peer_pool: WeakPeerPool,

    /// Holds the IOCodec for the peer connection
    io_codec: IOCodec,

    /// Remote endpoint for the peer
    remote_endpoint: Endpoint,

    /// The direction of the connection, either `Inbound` or `Outbound`
    conn_direction: ConnDirection,

    /// A list of protocol IDs.
    ///
    /// Filled in by `start_protocols`; `read_loop` rejects any incoming
    /// protocol message whose ID is not in this list.
    protocol_ids: RwLock<Vec<ProtocolID>>,

    /// `EventSys` responsible for sending events to the protocols.
    protocol_events: ArcEventSys<ProtocolID>,

    /// This channel is used to send a stop signal to the read loop.
    ///
    /// It is created with a capacity of one and written with `try_send`, so
    /// a signal sent while one is already pending is silently dropped.
    stop_chan: (Sender<Result<()>>, Receiver<Result<()>>),

    /// Managing spawned tasks.
    task_group: TaskGroup<'static>,
}

impl Peer {
    /// Creates a new peer
    pub fn new(
        peer_pool: WeakPeerPool,
        id: &PeerID,
        io_codec: IOCodec,
        remote_endpoint: Endpoint,
        conn_direction: ConnDirection,
        ex: GlobalExecutor,
    ) -> ArcPeer {
        Arc::new(Peer {
            id: id.clone(),
            peer_pool,
            io_codec,
            protocol_ids: RwLock::new(Vec::new()),
            remote_endpoint,
            conn_direction,
            protocol_events: EventSys::new(),
            task_group: TaskGroup::new(ex),
            stop_chan: channel::bounded(1),
        })
    }

    /// Run the peer.
    ///
    /// Starts the protocols for this peer (see `start_protocols`) and then
    /// awaits the read loop. The returned value is whatever `read_loop`
    /// returns.
    pub async fn run(self: Arc<Self>, ex: GlobalExecutor) -> Result<()> {
        self.start_protocols(ex).await;
        self.read_loop().await
    }

    /// Encodes `msg`, wraps it in a `ProtocolMsg` tagged with `protocol_id`,
    /// and writes it to the peer connection as a `NetMsgCmd::Protocol`
    /// message.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding `msg` fails or if the write to the
    /// connection fails.
    pub async fn send<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) -> Result<()> {
        // `payload` is listed first so that encoding still happens (and can
        // fail) before anything else is built.
        let envelope = ProtocolMsg {
            payload: encode(msg)?.to_vec(),
            protocol_id: protocol_id.to_string(),
        };

        self.io_codec.write(NetMsgCmd::Protocol, &envelope).await?;
        Ok(())
    }

    /// Forwards `msg` to the peer pool's `broadcast` for the given protocol.
    ///
    /// # Panics
    ///
    /// Panics if the peer pool is no longer alive (see `peer_pool`).
    pub async fn broadcast<T: CodecMsg>(&self, protocol_id: &ProtocolID, msg: &T) {
        let pool = self.peer_pool();
        pool.broadcast(protocol_id, msg).await;
    }

    /// Shuts the peer down.
    ///
    /// In order, this:
    /// 1. emits `ProtocolEvent::Shutdown` to every started protocol,
    /// 2. sends a stop signal to the read loop,
    /// 3. writes a `NetMsgCmd::Shutdown` message to the connection,
    /// 4. cancels the task group.
    ///
    /// Failures in steps 2 and 3 are deliberately ignored.
    pub async fn shutdown(&self) {
        trace!("peer {} start shutting down", self.id);

        // Notify every started protocol. The read guard is released as soon
        // as the loop is done.
        {
            let started = self.protocol_ids.read().await;
            for protocol_id in started.iter() {
                self.protocol_events
                    .emit_by_topic(protocol_id, &ProtocolEvent::Shutdown)
                    .await;
            }
        }

        // Ask the read loop to stop. The result is ignored on purpose: a
        // dropped channel and a delivered stop signal have the same effect.
        let (stop_tx, _) = &self.stop_chan;
        let _ = stop_tx.try_send(Ok(()));

        // Best effort: the outcome of this write is not needed.
        let farewell = ShutdownMsg(0);
        let _ = self.io_codec.write(NetMsgCmd::Shutdown, &farewell).await;

        // Force whatever is still running to stop.
        self.task_group.cancel().await;
    }

    /// Returns `true` when the connection direction is
    /// `ConnDirection::Inbound`, and `false` otherwise.
    #[inline]
    pub fn is_inbound(&self) -> bool {
        matches!(self.conn_direction, ConnDirection::Inbound)
    }

    /// Returns a reference to the direction of the connection, which can be
    /// either `Inbound` or `Outbound`.
    ///
    /// See also [`Peer::is_inbound`] for a boolean check.
    #[inline]
    pub fn direction(&self) -> &ConnDirection {
        &self.conn_direction
    }

    /// Returns a reference to the remote endpoint stored for this peer.
    #[inline]
    pub fn remote_endpoint(&self) -> &Endpoint {
        &self.remote_endpoint
    }

    /// Returns a reference to the peer's ID.
    #[inline]
    pub fn id(&self) -> &PeerID {
        &self.id
    }

    /// Returns a shared handle to the `Config` held by the peer pool.
    ///
    /// # Panics
    ///
    /// Panics if the peer pool is no longer alive (see `peer_pool`).
    pub fn config(&self) -> Arc<Config> {
        let pool = self.peer_pool();
        Arc::clone(&pool.config)
    }

    /// Registers an event listener on this peer's protocol event system,
    /// using the ID of protocol `P` as the topic.
    pub async fn register_listener<P: Protocol>(&self) -> EventListener<ProtocolID, ProtocolEvent> {
        let topic = P::id();
        self.protocol_events.register(&topic).await
    }

    /// Start a read loop to handle incoming messages from the peer connection.
    async fn read_loop(&self) -> Result<()> {
        loop {
            let fut = select(self.stop_chan.1.recv(), self.io_codec.read()).await;
            let result = match fut {
                Either::Left(stop_signal) => {
                    trace!("Peer {} received a stop signal", self.id);
                    return stop_signal?;
                }
                Either::Right(result) => result,
            };

            let msg = result?;

            match msg.header.command {
                NetMsgCmd::Protocol => {
                    let msg: ProtocolMsg = decode(&msg.payload)?.0;

                    if !self.protocol_ids.read().await.contains(&msg.protocol_id) {
                        return Err(Error::UnsupportedProtocol(msg.protocol_id));
                    }

                    let proto_id = &msg.protocol_id;
                    let msg = ProtocolEvent::Message(msg.payload);
                    self.protocol_events.emit_by_topic(proto_id, &msg).await;
                }
                NetMsgCmd::Shutdown => {
                    return Err(Error::PeerShutdown);
                }
                command => return Err(Error::InvalidMsg(format!("Unexpected msg {:?}", command))),
            }
        }
    }

    /// Start running the protocols for this peer connection.
    async fn start_protocols(self: &Arc<Self>, ex: GlobalExecutor) {
        for (protocol_id, constructor) in self.peer_pool().protocols.read().await.iter() {
            trace!("peer {} start protocol {protocol_id}", self.id);
            let protocol = constructor(self.clone());

            self.protocol_ids.write().await.push(protocol_id.clone());

            let selfc = self.clone();
            let proto_idc = protocol_id.clone();

            let on_failure = |result: TaskResult<Result<()>>| async move {
                if let TaskResult::Completed(res) = result {
                    if res.is_err() {
                        error!("protocol {} stopped", proto_idc);
                    }
                    // Send a stop signal to read loop
                    let _ = selfc.stop_chan.0.try_send(res);
                }
            };

            self.task_group
                .spawn(protocol.start(ex.clone()), on_failure);
        }
    }

    /// Upgrades the weak `PeerPool` pointer and returns the strong handle.
    ///
    /// # Panics
    ///
    /// Panics if the upgrade fails, i.e. the peer pool has already been
    /// dropped while this peer is still in use.
    fn peer_pool(&self) -> ArcPeerPool {
        self.peer_pool
            .upgrade()
            .expect("peer pool was dropped while a peer still references it")
    }
}