aboutsummaryrefslogtreecommitdiff
path: root/jsonrpc/src/client/mod.rs
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-06-21 22:45:17 +0200
committerhozan23 <hozan23@karyontech.net>2024-06-21 22:45:17 +0200
commit9aa972dd83a85cec5da71e8e893eb6e07d5db8ca (patch)
treea227c66e3e75e018f480556e1d58d40306acb12e /jsonrpc/src/client/mod.rs
parent8fc494d2d508f0e0beefccda31d15a5e387a9791 (diff)
jsonrpc/client: fix subscription error when the subscriber cannot keep up
Add a limit for receiving notifications for the subscription. If this limit is exceeded, the client will stop and raise an error. The limit is configurable when building a new client.
Diffstat (limited to 'jsonrpc/src/client/mod.rs')
-rw-r--r--jsonrpc/src/client/mod.rs82
1 files changed, 53 insertions, 29 deletions
diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs
index eddba19..80125b1 100644
--- a/jsonrpc/src/client/mod.rs
+++ b/jsonrpc/src/client/mod.rs
@@ -2,14 +2,21 @@ pub mod builder;
mod message_dispatcher;
mod subscriptions;
-use std::{sync::Arc, time::Duration};
+use std::{
+ sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc,
+ },
+ time::Duration,
+};
+use async_channel::{Receiver, Sender};
use log::{debug, error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::json;
use karyon_core::{
- async_util::{timeout, TaskGroup, TaskResult},
+ async_util::{select, timeout, Either, TaskGroup, TaskResult},
util::random_32,
};
use karyon_net::Conn;
@@ -27,10 +34,11 @@ type RequestID = u32;
/// Represents an RPC client
pub struct Client {
- conn: Conn<serde_json::Value>,
timeout: Option<u64>,
+ disconnect: AtomicBool,
message_dispatcher: MessageDispatcher,
task_group: TaskGroup,
+ send_chan: (Sender<serde_json::Value>, Receiver<serde_json::Value>),
subscriptions: Arc<Subscriptions>,
}
@@ -60,12 +68,12 @@ impl Client {
///
/// This function sends a subscription request to the specified method
/// with the given parameters. It waits for the response and returns a
- /// tuple containing a `SubscriptionID` and a `Subscription` (channel receiver).
+ /// `Subscription`.
pub async fn subscribe<T: Serialize + DeserializeOwned>(
&self,
method: &str,
params: T,
- ) -> Result<(SubscriptionID, Subscription)> {
+ ) -> Result<Arc<Subscription>> {
let response = self.send_request(method, params).await?;
let sub_id = match response.result {
@@ -75,7 +83,7 @@ impl Client {
let sub = self.subscriptions.subscribe(sub_id).await;
- Ok((sub_id, sub))
+ Ok(sub)
}
/// Unsubscribes from the provided method, waits for the response, and returns the result.
@@ -101,11 +109,8 @@ impl Client {
params: Some(json!(params)),
};
- let req_json = serde_json::to_value(&request)?;
-
- // Send the json request
- self.conn.send(req_json).await?;
- debug!("--> {request}");
+ // Send the request
+ self.send(request).await?;
// Register a new request
let rx = self.message_dispatcher.register(id).await;
@@ -131,37 +136,57 @@ impl Client {
// It should be OK to unwrap here, as the message dispatcher checks
// for the response id.
- if *response.id.as_ref().unwrap() != request.id {
+ if *response.id.as_ref().unwrap() != id {
return Err(Error::InvalidMsg("Invalid response id"));
}
Ok(response)
}
- fn start_background_receiving(self: &Arc<Self>) {
+ async fn send(&self, req: message::Request) -> Result<()> {
+ if self.disconnect.load(Ordering::Relaxed) {
+ return Err(Error::ClientDisconnected);
+ }
+ let req = serde_json::to_value(req)?;
+ self.send_chan.0.send(req).await?;
+ Ok(())
+ }
+
+ fn start_background_loop(self: &Arc<Self>, conn: Conn<serde_json::Value>) {
let selfc = self.clone();
let on_complete = |result: TaskResult<Result<()>>| async move {
if let TaskResult::Completed(Err(err)) = result {
- error!("Background receiving loop stopped: {err}");
+ error!("Background loop stopped: {err}");
}
- // Drop all subscription
- selfc.subscriptions.drop_all().await;
+ selfc.disconnect.store(true, Ordering::Relaxed);
+ selfc.subscriptions.clear().await;
+ selfc.message_dispatcher.clear().await;
};
let selfc = self.clone();
- // Spawn a new task for listing to new coming messages.
- self.task_group.spawn(
- async move {
- loop {
- let msg = selfc.conn.recv().await?;
- if let Err(err) = selfc.handle_msg(msg).await {
- let endpoint = selfc.conn.peer_endpoint()?;
+ // Spawn a new task
+ self.task_group
+ .spawn(selfc.background_loop(conn), on_complete);
+ }
+
+ async fn background_loop(self: Arc<Self>, conn: Conn<serde_json::Value>) -> Result<()> {
+ loop {
+ match select(self.send_chan.1.recv(), conn.recv()).await {
+ Either::Left(req) => {
+ conn.send(req?).await?;
+ }
+ Either::Right(msg) => match self.handle_msg(msg?).await {
+ Err(Error::SubscriptionBufferFull) => {
+ return Err(Error::SubscriptionBufferFull);
+ }
+ Err(err) => {
+ let endpoint = conn.peer_endpoint()?;
error!("Handle a new msg from the endpoint {endpoint} : {err}",);
}
- }
- },
- on_complete,
- );
+ Ok(_) => {}
+ },
+ }
+ }
}
async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> {
@@ -173,8 +198,7 @@ impl Client {
}
NewMsg::Notification(nt) => {
debug!("<-- {nt}");
- self.subscriptions.notify(nt).await?;
- Ok(())
+ self.subscriptions.notify(nt).await
}
},
Err(_) => {