From 9aa972dd83a85cec5da71e8e893eb6e07d5db8ca Mon Sep 17 00:00:00 2001 From: hozan23 Date: Fri, 21 Jun 2024 22:45:17 +0200 Subject: jsonrpc/client: fix subscription error when the subscriber cannot keep up Add a limit for receiving notifications for the subscription. If this limit is exceeded, the client will stop and raise an error. The limit is configurable when building a new client. --- jsonrpc/src/client/builder.rs | 102 ++++++++++++++++++++++------ jsonrpc/src/client/message_dispatcher.rs | 9 +++ jsonrpc/src/client/mod.rs | 82 +++++++++++++++-------- jsonrpc/src/client/subscriptions.rs | 69 ++++++++++++++----- jsonrpc/src/error.rs | 6 ++ jsonrpc/src/lib.rs | 1 + jsonrpc/src/server/builder.rs | 110 +++++++++++++++++++++++++++++-- jsonrpc/src/server/mod.rs | 2 +- 8 files changed, 308 insertions(+), 73 deletions(-) (limited to 'jsonrpc/src') diff --git a/jsonrpc/src/client/builder.rs b/jsonrpc/src/client/builder.rs index 510ce56..c34d2da 100644 --- a/jsonrpc/src/client/builder.rs +++ b/jsonrpc/src/client/builder.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{atomic::AtomicBool, Arc}; #[cfg(feature = "smol")] use futures_rustls::rustls; @@ -20,6 +20,8 @@ use super::{Client, MessageDispatcher, Subscriptions}; const DEFAULT_TIMEOUT: u64 = 3000; // 3s +const DEFAULT_MAX_SUBSCRIPTION_BUFFER_SIZE: usize = 20000; + impl Client { /// Creates a new [`ClientBuilder`] /// @@ -27,8 +29,13 @@ impl Client { /// /// # Example /// - /// ```ignore - /// let builder = Client::builder("ws://127.0.0.1:3000")?.build()?; + /// ``` + /// use karyon_jsonrpc::Client; + /// + /// async { + /// let builder = Client::builder("ws://127.0.0.1:3000").unwrap(); + /// let client = builder.build().await.unwrap(); + /// }; /// ``` pub fn builder(endpoint: impl ToEndpoint) -> Result { let endpoint = endpoint.to_endpoint()?; @@ -37,6 +44,7 @@ impl Client { timeout: Some(DEFAULT_TIMEOUT), tls_config: None, tcp_config: Default::default(), + subscription_buffer_size: DEFAULT_MAX_SUBSCRIPTION_BUFFER_SIZE, }) } } @@ -47,29 +55,66 @@ pub struct ClientBuilder { tls_config: Option<(rustls::ClientConfig, String)>, tcp_config: TcpConfig, timeout: Option, + subscription_buffer_size: usize, } impl ClientBuilder { /// Set timeout for receiving messages, in milliseconds. Requests will /// fail if it takes longer. /// - /// # Examples + /// # Example /// - /// ```ignore - /// let client = Client::builder()?.set_timeout(5000).build()?; + /// ``` + /// use karyon_jsonrpc::Client; + /// + /// async { + /// let client = Client::builder("ws://127.0.0.1:3000").unwrap() + /// .set_timeout(5000) + /// .build().await.unwrap(); + /// }; /// ``` pub fn set_timeout(mut self, timeout: u64) -> Self { self.timeout = Some(timeout); self } + /// Set max size for the subscription buffer. + /// + /// The client will stop when the subscriber cannot keep up. + /// When subscribing to a method, a new channel with the provided buffer + /// size is initialized. Once the buffer is full and the subscriber doesn't + /// process the messages in the buffer, the client will disconnect and + /// raise an error. + /// + /// # Example + /// + /// ``` + /// use karyon_jsonrpc::Client; + /// + /// async { + /// let client = Client::builder("ws://127.0.0.1:3000").unwrap() + /// .set_max_subscription_buffer_size(10000) + /// .build().await.unwrap(); + /// }; + /// ``` + pub fn set_max_subscription_buffer_size(mut self, size: usize) -> Self { + self.subscription_buffer_size = size; + self + } + /// Configure TCP settings for the client. /// /// # Example /// - /// ```ignore - /// let tcp_config = TcpConfig::default(); - /// let client = Client::builder()?.tcp_config(tcp_config)?.build()?; + /// ``` + /// use karyon_jsonrpc::{Client, TcpConfig}; + /// + /// async { + /// let tcp_config = TcpConfig::default(); + /// + /// let client = Client::builder("ws://127.0.0.1:3000").unwrap() + /// .tcp_config(tcp_config).unwrap().build().await.unwrap(); + /// }; /// ``` /// /// This function will return an error if the endpoint does not support TCP protocols. @@ -88,8 +133,16 @@ impl ClientBuilder { /// # Example /// /// ```ignore - /// let tls_config = rustls::ClientConfig::new(...); - /// let client = Client::builder()?.tls_config(tls_config, "example.com")?.build()?; + /// use karyon_jsonrpc::Client; + /// use futures_rustls::rustls; + /// + /// async { + /// let tls_config = rustls::ClientConfig::new(...); + /// + /// let client_builder = Client::builder("ws://127.0.0.1:3000").unwrap() + /// .tls_config(tls_config, "example.com").unwrap() + /// .build().await.unwrap(); + /// }; /// ``` /// /// This function will return an error if the endpoint does not support TLS protocols. @@ -110,13 +163,17 @@ impl ClientBuilder { /// /// # Example /// - /// ```ignore - /// let client = Client::builder(endpoint)? - /// .set_timeout(5000) - /// .tcp_config(tcp_config)? - /// .tls_config(tls_config, "example.com")? - /// .build() - /// .await?; + /// ``` + /// use karyon_jsonrpc::{Client, TcpConfig}; + /// + /// async { + /// let tcp_config = TcpConfig::default(); + /// let client = Client::builder("ws://127.0.0.1:3000").unwrap() + /// .tcp_config(tcp_config).unwrap() + /// .set_timeout(5000) + /// .build().await.unwrap(); + /// }; + /// /// ``` pub async fn build(self) -> Result> { let conn: Conn = match self.endpoint { @@ -168,14 +225,17 @@ impl ClientBuilder { _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())), }; + let send_chan = async_channel::bounded(10); + let client = Arc::new(Client { timeout: self.timeout, - conn, + disconnect: AtomicBool::new(false), + send_chan, message_dispatcher: MessageDispatcher::new(), - subscriptions: Subscriptions::new(), + subscriptions: Subscriptions::new(self.subscription_buffer_size), task_group: TaskGroup::new(), }); - client.start_background_receiving(); + client.start_background_loop(conn); Ok(client) } } diff --git a/jsonrpc/src/client/message_dispatcher.rs b/jsonrpc/src/client/message_dispatcher.rs index 14dcc71..f370985 100644 --- a/jsonrpc/src/client/message_dispatcher.rs +++ b/jsonrpc/src/client/message_dispatcher.rs @@ -34,6 +34,15 @@ impl MessageDispatcher { self.chans.lock().await.remove(id); } + /// Clear the registered channels. + pub(super) async fn clear(&self) { + let mut chans = self.chans.lock().await; + for (_, tx) in chans.iter() { + tx.close(); + } + chans.clear(); + } + /// Dispatches a response to the channel associated with the response's ID. /// /// If a channel is registered for the response's ID, the response is sent diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs index eddba19..80125b1 100644 --- a/jsonrpc/src/client/mod.rs +++ b/jsonrpc/src/client/mod.rs @@ -2,14 +2,21 @@ pub mod builder; mod message_dispatcher; mod subscriptions; -use std::{sync::Arc, time::Duration}; +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; +use async_channel::{Receiver, Sender}; use log::{debug, error}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::json; use karyon_core::{ - async_util::{timeout, TaskGroup, TaskResult}, + async_util::{select, timeout, Either, TaskGroup, TaskResult}, util::random_32, }; use karyon_net::Conn; @@ -27,10 +34,11 @@ type RequestID = u32; /// Represents an RPC client pub struct Client { - conn: Conn, timeout: Option, + disconnect: AtomicBool, message_dispatcher: MessageDispatcher, task_group: TaskGroup, + send_chan: (Sender, Receiver), subscriptions: Arc, } @@ -60,12 +68,12 @@ impl Client { /// /// This function sends a subscription request to the specified method /// with the given parameters. It waits for the response and returns a - /// tuple containing a `SubscriptionID` and a `Subscription` (channel receiver). + /// `Subscription`. pub async fn subscribe( &self, method: &str, params: T, - ) -> Result<(SubscriptionID, Subscription)> { + ) -> Result> { let response = self.send_request(method, params).await?; let sub_id = match response.result { @@ -75,7 +83,7 @@ impl Client { let sub = self.subscriptions.subscribe(sub_id).await; - Ok((sub_id, sub)) + Ok(sub) } /// Unsubscribes from the provided method, waits for the response, and returns the result. @@ -101,11 +109,8 @@ impl Client { params: Some(json!(params)), }; - let req_json = serde_json::to_value(&request)?; - - // Send the json request - self.conn.send(req_json).await?; - debug!("--> {request}"); + // Send the request + self.send(request).await?; // Register a new request let rx = self.message_dispatcher.register(id).await; @@ -131,37 +136,57 @@ impl Client { // It should be OK to unwrap here, as the message dispatcher checks // for the response id. - if *response.id.as_ref().unwrap() != request.id { + if *response.id.as_ref().unwrap() != id { return Err(Error::InvalidMsg("Invalid response id")); } Ok(response) } - fn start_background_receiving(self: &Arc) { + async fn send(&self, req: message::Request) -> Result<()> { + if self.disconnect.load(Ordering::Relaxed) { + return Err(Error::ClientDisconnected); + } + let req = serde_json::to_value(req)?; + self.send_chan.0.send(req).await?; + Ok(()) + } + + fn start_background_loop(self: &Arc, conn: Conn) { let selfc = self.clone(); let on_complete = |result: TaskResult>| async move { if let TaskResult::Completed(Err(err)) = result { - error!("Background receiving loop stopped: {err}"); + error!("Background loop stopped: {err}"); } - // Drop all subscription - selfc.subscriptions.drop_all().await; + selfc.disconnect.store(true, Ordering::Relaxed); + selfc.subscriptions.clear().await; + selfc.message_dispatcher.clear().await; }; let selfc = self.clone(); - // Spawn a new task for listing to new coming messages. - self.task_group.spawn( - async move { - loop { - let msg = selfc.conn.recv().await?; - if let Err(err) = selfc.handle_msg(msg).await { - let endpoint = selfc.conn.peer_endpoint()?; + // Spawn a new task + self.task_group + .spawn(selfc.background_loop(conn), on_complete); + } + + async fn background_loop(self: Arc, conn: Conn) -> Result<()> { + loop { + match select(self.send_chan.1.recv(), conn.recv()).await { + Either::Left(req) => { + conn.send(req?).await?; + } + Either::Right(msg) => match self.handle_msg(msg?).await { + Err(Error::SubscriptionBufferFull) => { + return Err(Error::SubscriptionBufferFull); + } + Err(err) => { + let endpoint = conn.peer_endpoint()?; error!("Handle a new msg from the endpoint {endpoint} : {err}",); } - } - }, - on_complete, - ); + Ok(_) => {} + }, + } + } } async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> { @@ -173,8 +198,7 @@ impl Client { } NewMsg::Notification(nt) => { debug!("<-- {nt}"); - self.subscriptions.notify(nt).await?; - Ok(()) + self.subscriptions.notify(nt).await } }, Err(_) => { diff --git a/jsonrpc/src/client/subscriptions.rs b/jsonrpc/src/client/subscriptions.rs index 9c8a9f4..f3d8cb2 100644 --- a/jsonrpc/src/client/subscriptions.rs +++ b/jsonrpc/src/client/subscriptions.rs @@ -1,7 +1,6 @@ use std::{collections::HashMap, sync::Arc}; use async_channel::{Receiver, Sender}; -use log::warn; use serde_json::json; use serde_json::Value; @@ -12,37 +11,77 @@ use crate::{ Error, Result, }; -/// Type alias for a subscription to receive notifications. -/// -/// The receiver channel is returned by the `subscribe` -pub type Subscription = Receiver; +/// A subscription established when the client's subscribe to a method +pub struct Subscription { + id: SubscriptionID, + rx: Receiver, + tx: Sender, +} + +impl Subscription { + fn new(id: SubscriptionID, buffer_size: usize) -> Arc { + let (tx, rx) = async_channel::bounded(buffer_size); + Arc::new(Self { tx, id, rx }) + } + + pub async fn recv(&self) -> Result { + self.rx.recv().await.map_err(Error::from) + } + + pub fn id(&self) -> SubscriptionID { + self.id + } + + async fn notify(&self, val: Value) -> Result<()> { + if self.tx.is_full() { + return Err(Error::SubscriptionBufferFull); + } + self.tx.send(val).await?; + Ok(()) + } + + fn close(&self) { + self.tx.close(); + } +} /// Manages subscriptions for the client. pub(super) struct Subscriptions { - subs: Mutex>>, + subs: Mutex>>, + sub_buffer_size: usize, } impl Subscriptions { - pub(super) fn new() -> Arc { + /// Creates a new [`Subscriptions`]. + pub(super) fn new(sub_buffer_size: usize) -> Arc { Arc::new(Self { subs: Mutex::new(HashMap::new()), + sub_buffer_size, }) } - pub(super) async fn subscribe(self: &Arc, id: SubscriptionID) -> Subscription { - let (ch_tx, ch_rx) = async_channel::unbounded(); - self.subs.lock().await.insert(id, ch_tx); - ch_rx + /// Returns a new [`Subscription`] + pub(super) async fn subscribe(&self, id: SubscriptionID) -> Arc { + let sub = Subscription::new(id, self.sub_buffer_size); + self.subs.lock().await.insert(id, sub.clone()); + sub } - pub(super) async fn drop_all(&self) { - self.subs.lock().await.clear(); + /// Closes subscription channels and clear the inner map. + pub(super) async fn clear(&self) { + let mut subs = self.subs.lock().await; + for (_, sub) in subs.iter() { + sub.close(); + } + subs.clear(); } + /// Unsubscribe from the provided subscription id. pub(super) async fn unsubscribe(&self, id: &SubscriptionID) { self.subs.lock().await.remove(id); } + /// Notifies the subscription about the given notification. pub(super) async fn notify(&self, nt: Notification) -> Result<()> { let nt_res: NotificationResult = match nt.params { Some(ref p) => serde_json::from_value(p.clone())?, @@ -50,9 +89,9 @@ impl Subscriptions { }; match self.subs.lock().await.get(&nt_res.subscription) { - Some(s) => s.send(nt_res.result.unwrap_or(json!(""))).await?, + Some(s) => s.notify(nt_res.result.unwrap_or(json!(""))).await?, None => { - warn!("Receive unknown notification {}", nt_res.subscription) + return Err(Error::InvalidMsg("Unknown notification")); } } diff --git a/jsonrpc/src/error.rs b/jsonrpc/src/error.rs index 89d0e2f..f409c8d 100644 --- a/jsonrpc/src/error.rs +++ b/jsonrpc/src/error.rs @@ -29,6 +29,12 @@ pub enum Error { #[error("Subscription not found: {0}")] SubscriptionNotFound(String), + #[error("Subscription exceeds the maximum buffer size")] + SubscriptionBufferFull, + + #[error("ClientDisconnected")] + ClientDisconnected, + #[error(transparent)] ChannelRecv(#[from] async_channel::RecvError), diff --git a/jsonrpc/src/lib.rs b/jsonrpc/src/lib.rs index d43783f..c72f067 100644 --- a/jsonrpc/src/lib.rs +++ b/jsonrpc/src/lib.rs @@ -8,6 +8,7 @@ mod server; pub use client::{builder::ClientBuilder, Client}; pub use error::{Error, RPCError, RPCResult, Result}; +pub use message::SubscriptionID; pub use server::{ builder::ServerBuilder, channel::{Channel, Subscription}, diff --git a/jsonrpc/src/server/builder.rs b/jsonrpc/src/server/builder.rs index 90024f3..ca6d1a7 100644 --- a/jsonrpc/src/server/builder.rs +++ b/jsonrpc/src/server/builder.rs @@ -29,12 +29,91 @@ pub struct ServerBuilder { impl ServerBuilder { /// Adds a new RPC service to the server. + /// + /// # Example + /// ``` + /// use std::sync::Arc; + /// + /// use serde_json::Value; + /// + /// use karyon_jsonrpc::{Server, rpc_impl, RPCError}; + /// + /// struct Ping {} + /// + /// #[rpc_impl] + /// impl Ping { + /// async fn ping(&self, _params: Value) -> Result { + /// Ok(serde_json::json!("Pong")) + /// } + /// } + /// + /// async { + /// let server = Server::builder("ws://127.0.0.1:3000").unwrap() + /// .service(Arc::new(Ping{})) + /// .build().await.unwrap(); + /// }; + /// + /// ``` pub fn service(mut self, service: Arc) -> Self { self.services.insert(service.name(), service); self } /// Adds a new PubSub RPC service to the server. + /// + /// # Example + /// ``` + /// use std::sync::Arc; + /// + /// use serde_json::Value; + /// + /// use karyon_jsonrpc::{ + /// Server, rpc_impl, rpc_pubsub_impl, RPCError, Channel, SubscriptionID, + /// }; + /// + /// struct Ping {} + /// + /// #[rpc_impl] + /// impl Ping { + /// async fn ping(&self, _params: Value) -> Result { + /// Ok(serde_json::json!("Pong")) + /// } + /// } + /// + /// #[rpc_pubsub_impl] + /// impl Ping { + /// async fn log_subscribe( + /// &self, + /// chan: Arc, + /// method: String, + /// _params: Value, + /// ) -> Result { + /// let sub = chan.new_subscription(&method).await; + /// let sub_id = sub.id.clone(); + /// Ok(serde_json::json!(sub_id)) + /// } + /// + /// async fn log_unsubscribe( + /// &self, + /// chan: Arc, + /// _method: String, + /// params: Value, + /// ) -> Result { + /// let sub_id: SubscriptionID = serde_json::from_value(params)?; + /// chan.remove_subscription(&sub_id).await; + /// Ok(serde_json::json!(true)) + /// } + /// } + /// + /// async { + /// let ping_service = Arc::new(Ping{}); + /// let server = Server::builder("ws://127.0.0.1:3000").unwrap() + /// .service(ping_service.clone()) + /// .pubsub_service(ping_service) + /// .build().await.unwrap(); + /// }; + /// + /// ``` pub fn pubsub_service(mut self, service: Arc) -> Self { self.pubsub_services.insert(service.name(), service); self @@ -44,9 +123,15 @@ impl ServerBuilder { /// /// # Example /// - /// ```ignore - /// let tcp_config = TcpConfig::default(); - /// let server = Server::builder()?.tcp_config(tcp_config)?.build()?; + /// ``` + /// use karyon_jsonrpc::{Server, TcpConfig}; + /// + /// async { + /// let tcp_config = TcpConfig::default(); + /// let server = Server::builder("ws://127.0.0.1:3000").unwrap() + /// .tcp_config(tcp_config).unwrap() + /// .build().await.unwrap(); + /// }; /// ``` /// /// This function will return an error if the endpoint does not support TCP protocols. @@ -65,8 +150,15 @@ impl ServerBuilder { /// # Example /// /// ```ignore - /// let tls_config = rustls::ServerConfig::new(...); - /// let server = Server::builder()?.tls_config(tls_config)?.build()?; + /// use karon_jsonrpc::Server; + /// use futures_rustls::rustls; + /// + /// async { + /// let tls_config = rustls::ServerConfig::new(...); + /// let server = Server::builder("ws://127.0.0.1:3000").unwrap() + /// .tls_config(tls_config).unwrap() + /// .build().await.unwrap(); + /// }; /// ``` /// /// This function will return an error if the endpoint does not support TLS protocols. @@ -157,8 +249,12 @@ impl Server { /// /// # Example /// - /// ```ignore - /// let builder = Server::builder("ws://127.0.0.1:3000")?.build()?; + /// ``` + /// use karyon_jsonrpc::Server; + /// async { + /// let server = Server::builder("ws://127.0.0.1:3000").unwrap() + /// .build().await.unwrap(); + /// }; /// ``` pub fn builder(endpoint: impl ToEndpoint) -> Result { let endpoint = endpoint.to_endpoint()?; diff --git a/jsonrpc/src/server/mod.rs b/jsonrpc/src/server/mod.rs index 6f539be..00b0fd2 100644 --- a/jsonrpc/src/server/mod.rs +++ b/jsonrpc/src/server/mod.rs @@ -126,7 +126,7 @@ impl Server { method: nt.method, params, }; - // debug!("--> {notification}"); + debug!("--> {notification}"); conn_cloned.send(serde_json::json!(notification)).await?; } } -- cgit v1.2.3