From 9aa972dd83a85cec5da71e8e893eb6e07d5db8ca Mon Sep 17 00:00:00 2001 From: hozan23 Date: Fri, 21 Jun 2024 22:45:17 +0200 Subject: jsonrpc/client: fix subscription error when the subscriber cannot keep up Add a limit for receiving notifications for the subscription. If this limit is exceeded, the client will stop and raise an error. The limit is configurable when building a new client. --- jsonrpc/src/client/mod.rs | 82 ++++++++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 29 deletions(-) (limited to 'jsonrpc/src/client/mod.rs') diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs index eddba19..80125b1 100644 --- a/jsonrpc/src/client/mod.rs +++ b/jsonrpc/src/client/mod.rs @@ -2,14 +2,21 @@ pub mod builder; mod message_dispatcher; mod subscriptions; -use std::{sync::Arc, time::Duration}; +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; +use async_channel::{Receiver, Sender}; use log::{debug, error}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::json; use karyon_core::{ - async_util::{timeout, TaskGroup, TaskResult}, + async_util::{select, timeout, Either, TaskGroup, TaskResult}, util::random_32, }; use karyon_net::Conn; @@ -27,10 +34,11 @@ type RequestID = u32; /// Represents an RPC client pub struct Client { - conn: Conn, timeout: Option, + disconnect: AtomicBool, message_dispatcher: MessageDispatcher, task_group: TaskGroup, + send_chan: (Sender, Receiver), subscriptions: Arc, } @@ -60,12 +68,12 @@ impl Client { /// /// This function sends a subscription request to the specified method /// with the given parameters. It waits for the response and returns a - /// tuple containing a `SubscriptionID` and a `Subscription` (channel receiver). + /// `Subscription`. pub async fn subscribe( &self, method: &str, params: T, - ) -> Result<(SubscriptionID, Subscription)> { + ) -> Result> { let response = self.send_request(method, params).await?; let sub_id = match response.result { @@ -75,7 +83,7 @@ impl Client { let sub = self.subscriptions.subscribe(sub_id).await; - Ok((sub_id, sub)) + Ok(sub) } /// Unsubscribes from the provided method, waits for the response, and returns the result. @@ -101,11 +109,8 @@ impl Client { params: Some(json!(params)), }; - let req_json = serde_json::to_value(&request)?; - - // Send the json request - self.conn.send(req_json).await?; - debug!("--> {request}"); + // Send the request + self.send(request).await?; // Register a new request let rx = self.message_dispatcher.register(id).await; @@ -131,37 +136,57 @@ impl Client { // It should be OK to unwrap here, as the message dispatcher checks // for the response id. - if *response.id.as_ref().unwrap() != request.id { + if *response.id.as_ref().unwrap() != id { return Err(Error::InvalidMsg("Invalid response id")); } Ok(response) } - fn start_background_receiving(self: &Arc) { + async fn send(&self, req: message::Request) -> Result<()> { + if self.disconnect.load(Ordering::Relaxed) { + return Err(Error::ClientDisconnected); + } + let req = serde_json::to_value(req)?; + self.send_chan.0.send(req).await?; + Ok(()) + } + + fn start_background_loop(self: &Arc, conn: Conn) { let selfc = self.clone(); let on_complete = |result: TaskResult>| async move { if let TaskResult::Completed(Err(err)) = result { - error!("Background receiving loop stopped: {err}"); + error!("Background loop stopped: {err}"); } - // Drop all subscription - selfc.subscriptions.drop_all().await; + selfc.disconnect.store(true, Ordering::Relaxed); + selfc.subscriptions.clear().await; + selfc.message_dispatcher.clear().await; }; let selfc = self.clone(); - // Spawn a new task for listing to new coming messages. - self.task_group.spawn( - async move { - loop { - let msg = selfc.conn.recv().await?; - if let Err(err) = selfc.handle_msg(msg).await { - let endpoint = selfc.conn.peer_endpoint()?; + // Spawn a new task + self.task_group + .spawn(selfc.background_loop(conn), on_complete); + } + + async fn background_loop(self: Arc, conn: Conn) -> Result<()> { + loop { + match select(self.send_chan.1.recv(), conn.recv()).await { + Either::Left(req) => { + conn.send(req?).await?; + } + Either::Right(msg) => match self.handle_msg(msg?).await { + Err(Error::SubscriptionBufferFull) => { + return Err(Error::SubscriptionBufferFull); + } + Err(err) => { + let endpoint = conn.peer_endpoint()?; error!("Handle a new msg from the endpoint {endpoint} : {err}",); } - } - }, - on_complete, - ); + Ok(_) => {} + }, + } + } } async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> { @@ -173,8 +198,7 @@ impl Client { } NewMsg::Notification(nt) => { debug!("<-- {nt}"); - self.subscriptions.notify(nt).await?; - Ok(()) + self.subscriptions.notify(nt).await } }, Err(_) => { -- cgit v1.2.3