aboutsummaryrefslogtreecommitdiff
path: root/jsonrpc/src
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-06-21 22:45:17 +0200
committerhozan23 <hozan23@karyontech.net>2024-06-21 22:45:17 +0200
commit9aa972dd83a85cec5da71e8e893eb6e07d5db8ca (patch)
treea227c66e3e75e018f480556e1d58d40306acb12e /jsonrpc/src
parent8fc494d2d508f0e0beefccda31d15a5e387a9791 (diff)
jsonrpc/client: fix subscription error when the subscriber cannot keep up
Add a limit for receiving notifications for the subscription. If this limit is exceeded, the client will stop and raise an error. The limit is configurable when building a new client.
Diffstat (limited to 'jsonrpc/src')
-rw-r--r--jsonrpc/src/client/builder.rs102
-rw-r--r--jsonrpc/src/client/message_dispatcher.rs9
-rw-r--r--jsonrpc/src/client/mod.rs82
-rw-r--r--jsonrpc/src/client/subscriptions.rs69
-rw-r--r--jsonrpc/src/error.rs6
-rw-r--r--jsonrpc/src/lib.rs1
-rw-r--r--jsonrpc/src/server/builder.rs110
-rw-r--r--jsonrpc/src/server/mod.rs2
8 files changed, 308 insertions, 73 deletions
diff --git a/jsonrpc/src/client/builder.rs b/jsonrpc/src/client/builder.rs
index 510ce56..c34d2da 100644
--- a/jsonrpc/src/client/builder.rs
+++ b/jsonrpc/src/client/builder.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::sync::{atomic::AtomicBool, Arc};
#[cfg(feature = "smol")]
use futures_rustls::rustls;
@@ -20,6 +20,8 @@ use super::{Client, MessageDispatcher, Subscriptions};
const DEFAULT_TIMEOUT: u64 = 3000; // 3s
+const DEFAULT_MAX_SUBSCRIPTION_BUFFER_SIZE: usize = 20000;
+
impl Client {
/// Creates a new [`ClientBuilder`]
///
@@ -27,8 +29,13 @@ impl Client {
///
/// # Example
///
- /// ```ignore
- /// let builder = Client::builder("ws://127.0.0.1:3000")?.build()?;
+ /// ```
+ /// use karyon_jsonrpc::Client;
+ ///
+ /// async {
+ /// let builder = Client::builder("ws://127.0.0.1:3000").unwrap();
+ /// let client = builder.build().await.unwrap();
+ /// };
/// ```
pub fn builder(endpoint: impl ToEndpoint) -> Result<ClientBuilder> {
let endpoint = endpoint.to_endpoint()?;
@@ -37,6 +44,7 @@ impl Client {
timeout: Some(DEFAULT_TIMEOUT),
tls_config: None,
tcp_config: Default::default(),
+ subscription_buffer_size: DEFAULT_MAX_SUBSCRIPTION_BUFFER_SIZE,
})
}
}
@@ -47,29 +55,66 @@ pub struct ClientBuilder {
tls_config: Option<(rustls::ClientConfig, String)>,
tcp_config: TcpConfig,
timeout: Option<u64>,
+ subscription_buffer_size: usize,
}
impl ClientBuilder {
/// Set timeout for receiving messages, in milliseconds. Requests will
/// fail if it takes longer.
///
- /// # Examples
+ /// # Example
///
- /// ```ignore
- /// let client = Client::builder()?.set_timeout(5000).build()?;
+ /// ```
+ /// use karyon_jsonrpc::Client;
+ ///
+ /// async {
+ /// let client = Client::builder("ws://127.0.0.1:3000").unwrap()
+ /// .set_timeout(5000)
+ /// .build().await.unwrap();
+ /// };
/// ```
pub fn set_timeout(mut self, timeout: u64) -> Self {
self.timeout = Some(timeout);
self
}
+ /// Set max size for the subscription buffer.
+ ///
+ /// The client will stop when the subscriber cannot keep up.
+ /// When subscribing to a method, a new channel with the provided buffer
+ /// size is initialized. Once the buffer is full and the subscriber doesn't
+ /// process the messages in the buffer, the client will disconnect and
+ /// raise an error.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use karyon_jsonrpc::Client;
+ ///
+ /// async {
+ /// let client = Client::builder("ws://127.0.0.1:3000").unwrap()
+ /// .set_max_subscription_buffer_size(10000)
+ /// .build().await.unwrap();
+ /// };
+ /// ```
+ pub fn set_max_subscription_buffer_size(mut self, size: usize) -> Self {
+ self.subscription_buffer_size = size;
+ self
+ }
+
/// Configure TCP settings for the client.
///
/// # Example
///
- /// ```ignore
- /// let tcp_config = TcpConfig::default();
- /// let client = Client::builder()?.tcp_config(tcp_config)?.build()?;
+ /// ```
+ /// use karyon_jsonrpc::{Client, TcpConfig};
+ ///
+ /// async {
+ /// let tcp_config = TcpConfig::default();
+ ///
+ /// let client = Client::builder("ws://127.0.0.1:3000").unwrap()
+ /// .tcp_config(tcp_config).unwrap().build().await.unwrap();
+ /// };
/// ```
///
/// This function will return an error if the endpoint does not support TCP protocols.
@@ -88,8 +133,16 @@ impl ClientBuilder {
/// # Example
///
/// ```ignore
- /// let tls_config = rustls::ClientConfig::new(...);
- /// let client = Client::builder()?.tls_config(tls_config, "example.com")?.build()?;
+ /// use karyon_jsonrpc::Client;
+ /// use futures_rustls::rustls;
+ ///
+ /// async {
+ /// let tls_config = rustls::ClientConfig::new(...);
+ ///
+ /// let client_builder = Client::builder("ws://127.0.0.1:3000").unwrap()
+ /// .tls_config(tls_config, "example.com").unwrap()
+ /// .build().await.unwrap();
+ /// };
/// ```
///
/// This function will return an error if the endpoint does not support TLS protocols.
@@ -110,13 +163,17 @@ impl ClientBuilder {
///
/// # Example
///
- /// ```ignore
- /// let client = Client::builder(endpoint)?
- /// .set_timeout(5000)
- /// .tcp_config(tcp_config)?
- /// .tls_config(tls_config, "example.com")?
- /// .build()
- /// .await?;
+ /// ```
+ /// use karyon_jsonrpc::{Client, TcpConfig};
+ ///
+ /// async {
+ /// let tcp_config = TcpConfig::default();
+ /// let client = Client::builder("ws://127.0.0.1:3000").unwrap()
+ /// .tcp_config(tcp_config).unwrap()
+ /// .set_timeout(5000)
+ /// .build().await.unwrap();
+ /// };
+ ///
/// ```
pub async fn build(self) -> Result<Arc<Client>> {
let conn: Conn<serde_json::Value> = match self.endpoint {
@@ -168,14 +225,17 @@ impl ClientBuilder {
_ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
};
+ let send_chan = async_channel::bounded(10);
+
let client = Arc::new(Client {
timeout: self.timeout,
- conn,
+ disconnect: AtomicBool::new(false),
+ send_chan,
message_dispatcher: MessageDispatcher::new(),
- subscriptions: Subscriptions::new(),
+ subscriptions: Subscriptions::new(self.subscription_buffer_size),
task_group: TaskGroup::new(),
});
- client.start_background_receiving();
+ client.start_background_loop(conn);
Ok(client)
}
}
diff --git a/jsonrpc/src/client/message_dispatcher.rs b/jsonrpc/src/client/message_dispatcher.rs
index 14dcc71..f370985 100644
--- a/jsonrpc/src/client/message_dispatcher.rs
+++ b/jsonrpc/src/client/message_dispatcher.rs
@@ -34,6 +34,15 @@ impl MessageDispatcher {
self.chans.lock().await.remove(id);
}
+ /// Clear the registered channels.
+ pub(super) async fn clear(&self) {
+ let mut chans = self.chans.lock().await;
+ for (_, tx) in chans.iter() {
+ tx.close();
+ }
+ chans.clear();
+ }
+
/// Dispatches a response to the channel associated with the response's ID.
///
/// If a channel is registered for the response's ID, the response is sent
diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs
index eddba19..80125b1 100644
--- a/jsonrpc/src/client/mod.rs
+++ b/jsonrpc/src/client/mod.rs
@@ -2,14 +2,21 @@ pub mod builder;
mod message_dispatcher;
mod subscriptions;
-use std::{sync::Arc, time::Duration};
+use std::{
+ sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc,
+ },
+ time::Duration,
+};
+use async_channel::{Receiver, Sender};
use log::{debug, error};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::json;
use karyon_core::{
- async_util::{timeout, TaskGroup, TaskResult},
+ async_util::{select, timeout, Either, TaskGroup, TaskResult},
util::random_32,
};
use karyon_net::Conn;
@@ -27,10 +34,11 @@ type RequestID = u32;
/// Represents an RPC client
pub struct Client {
- conn: Conn<serde_json::Value>,
timeout: Option<u64>,
+ disconnect: AtomicBool,
message_dispatcher: MessageDispatcher,
task_group: TaskGroup,
+ send_chan: (Sender<serde_json::Value>, Receiver<serde_json::Value>),
subscriptions: Arc<Subscriptions>,
}
@@ -60,12 +68,12 @@ impl Client {
///
/// This function sends a subscription request to the specified method
/// with the given parameters. It waits for the response and returns a
- /// tuple containing a `SubscriptionID` and a `Subscription` (channel receiver).
+ /// `Subscription`.
pub async fn subscribe<T: Serialize + DeserializeOwned>(
&self,
method: &str,
params: T,
- ) -> Result<(SubscriptionID, Subscription)> {
+ ) -> Result<Arc<Subscription>> {
let response = self.send_request(method, params).await?;
let sub_id = match response.result {
@@ -75,7 +83,7 @@ impl Client {
let sub = self.subscriptions.subscribe(sub_id).await;
- Ok((sub_id, sub))
+ Ok(sub)
}
/// Unsubscribes from the provided method, waits for the response, and returns the result.
@@ -101,11 +109,8 @@ impl Client {
params: Some(json!(params)),
};
- let req_json = serde_json::to_value(&request)?;
-
- // Send the json request
- self.conn.send(req_json).await?;
- debug!("--> {request}");
+ // Send the request
+ self.send(request).await?;
// Register a new request
let rx = self.message_dispatcher.register(id).await;
@@ -131,37 +136,57 @@ impl Client {
// It should be OK to unwrap here, as the message dispatcher checks
// for the response id.
- if *response.id.as_ref().unwrap() != request.id {
+ if *response.id.as_ref().unwrap() != id {
return Err(Error::InvalidMsg("Invalid response id"));
}
Ok(response)
}
- fn start_background_receiving(self: &Arc<Self>) {
+ async fn send(&self, req: message::Request) -> Result<()> {
+ if self.disconnect.load(Ordering::Relaxed) {
+ return Err(Error::ClientDisconnected);
+ }
+ let req = serde_json::to_value(req)?;
+ self.send_chan.0.send(req).await?;
+ Ok(())
+ }
+
+ fn start_background_loop(self: &Arc<Self>, conn: Conn<serde_json::Value>) {
let selfc = self.clone();
let on_complete = |result: TaskResult<Result<()>>| async move {
if let TaskResult::Completed(Err(err)) = result {
- error!("Background receiving loop stopped: {err}");
+ error!("Background loop stopped: {err}");
}
- // Drop all subscription
- selfc.subscriptions.drop_all().await;
+ selfc.disconnect.store(true, Ordering::Relaxed);
+ selfc.subscriptions.clear().await;
+ selfc.message_dispatcher.clear().await;
};
let selfc = self.clone();
- // Spawn a new task for listing to new coming messages.
- self.task_group.spawn(
- async move {
- loop {
- let msg = selfc.conn.recv().await?;
- if let Err(err) = selfc.handle_msg(msg).await {
- let endpoint = selfc.conn.peer_endpoint()?;
+ // Spawn a new task
+ self.task_group
+ .spawn(selfc.background_loop(conn), on_complete);
+ }
+
+ async fn background_loop(self: Arc<Self>, conn: Conn<serde_json::Value>) -> Result<()> {
+ loop {
+ match select(self.send_chan.1.recv(), conn.recv()).await {
+ Either::Left(req) => {
+ conn.send(req?).await?;
+ }
+ Either::Right(msg) => match self.handle_msg(msg?).await {
+ Err(Error::SubscriptionBufferFull) => {
+ return Err(Error::SubscriptionBufferFull);
+ }
+ Err(err) => {
+ let endpoint = conn.peer_endpoint()?;
error!("Handle a new msg from the endpoint {endpoint} : {err}",);
}
- }
- },
- on_complete,
- );
+ Ok(_) => {}
+ },
+ }
+ }
}
async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> {
@@ -173,8 +198,7 @@ impl Client {
}
NewMsg::Notification(nt) => {
debug!("<-- {nt}");
- self.subscriptions.notify(nt).await?;
- Ok(())
+ self.subscriptions.notify(nt).await
}
},
Err(_) => {
diff --git a/jsonrpc/src/client/subscriptions.rs b/jsonrpc/src/client/subscriptions.rs
index 9c8a9f4..f3d8cb2 100644
--- a/jsonrpc/src/client/subscriptions.rs
+++ b/jsonrpc/src/client/subscriptions.rs
@@ -1,7 +1,6 @@
use std::{collections::HashMap, sync::Arc};
use async_channel::{Receiver, Sender};
-use log::warn;
use serde_json::json;
use serde_json::Value;
@@ -12,37 +11,77 @@ use crate::{
Error, Result,
};
-/// Type alias for a subscription to receive notifications.
-///
-/// The receiver channel is returned by the `subscribe`
-pub type Subscription = Receiver<Value>;
+/// A subscription established when the client's subscribe to a method
+pub struct Subscription {
+ id: SubscriptionID,
+ rx: Receiver<Value>,
+ tx: Sender<Value>,
+}
+
+impl Subscription {
+ fn new(id: SubscriptionID, buffer_size: usize) -> Arc<Self> {
+ let (tx, rx) = async_channel::bounded(buffer_size);
+ Arc::new(Self { tx, id, rx })
+ }
+
+ pub async fn recv(&self) -> Result<Value> {
+ self.rx.recv().await.map_err(Error::from)
+ }
+
+ pub fn id(&self) -> SubscriptionID {
+ self.id
+ }
+
+ async fn notify(&self, val: Value) -> Result<()> {
+ if self.tx.is_full() {
+ return Err(Error::SubscriptionBufferFull);
+ }
+ self.tx.send(val).await?;
+ Ok(())
+ }
+
+ fn close(&self) {
+ self.tx.close();
+ }
+}
/// Manages subscriptions for the client.
pub(super) struct Subscriptions {
- subs: Mutex<HashMap<SubscriptionID, Sender<Value>>>,
+ subs: Mutex<HashMap<SubscriptionID, Arc<Subscription>>>,
+ sub_buffer_size: usize,
}
impl Subscriptions {
- pub(super) fn new() -> Arc<Self> {
+ /// Creates a new [`Subscriptions`].
+ pub(super) fn new(sub_buffer_size: usize) -> Arc<Self> {
Arc::new(Self {
subs: Mutex::new(HashMap::new()),
+ sub_buffer_size,
})
}
- pub(super) async fn subscribe(self: &Arc<Self>, id: SubscriptionID) -> Subscription {
- let (ch_tx, ch_rx) = async_channel::unbounded();
- self.subs.lock().await.insert(id, ch_tx);
- ch_rx
+ /// Returns a new [`Subscription`]
+ pub(super) async fn subscribe(&self, id: SubscriptionID) -> Arc<Subscription> {
+ let sub = Subscription::new(id, self.sub_buffer_size);
+ self.subs.lock().await.insert(id, sub.clone());
+ sub
}
- pub(super) async fn drop_all(&self) {
- self.subs.lock().await.clear();
+ /// Closes subscription channels and clear the inner map.
+ pub(super) async fn clear(&self) {
+ let mut subs = self.subs.lock().await;
+ for (_, sub) in subs.iter() {
+ sub.close();
+ }
+ subs.clear();
}
+ /// Unsubscribe from the provided subscription id.
pub(super) async fn unsubscribe(&self, id: &SubscriptionID) {
self.subs.lock().await.remove(id);
}
+ /// Notifies the subscription about the given notification.
pub(super) async fn notify(&self, nt: Notification) -> Result<()> {
let nt_res: NotificationResult = match nt.params {
Some(ref p) => serde_json::from_value(p.clone())?,
@@ -50,9 +89,9 @@ impl Subscriptions {
};
match self.subs.lock().await.get(&nt_res.subscription) {
- Some(s) => s.send(nt_res.result.unwrap_or(json!(""))).await?,
+ Some(s) => s.notify(nt_res.result.unwrap_or(json!(""))).await?,
None => {
- warn!("Receive unknown notification {}", nt_res.subscription)
+ return Err(Error::InvalidMsg("Unknown notification"));
}
}
diff --git a/jsonrpc/src/error.rs b/jsonrpc/src/error.rs
index 89d0e2f..f409c8d 100644
--- a/jsonrpc/src/error.rs
+++ b/jsonrpc/src/error.rs
@@ -29,6 +29,12 @@ pub enum Error {
#[error("Subscription not found: {0}")]
SubscriptionNotFound(String),
+ #[error("Subscription exceeds the maximum buffer size")]
+ SubscriptionBufferFull,
+
+ #[error("ClientDisconnected")]
+ ClientDisconnected,
+
#[error(transparent)]
ChannelRecv(#[from] async_channel::RecvError),
diff --git a/jsonrpc/src/lib.rs b/jsonrpc/src/lib.rs
index d43783f..c72f067 100644
--- a/jsonrpc/src/lib.rs
+++ b/jsonrpc/src/lib.rs
@@ -8,6 +8,7 @@ mod server;
pub use client::{builder::ClientBuilder, Client};
pub use error::{Error, RPCError, RPCResult, Result};
+pub use message::SubscriptionID;
pub use server::{
builder::ServerBuilder,
channel::{Channel, Subscription},
diff --git a/jsonrpc/src/server/builder.rs b/jsonrpc/src/server/builder.rs
index 90024f3..ca6d1a7 100644
--- a/jsonrpc/src/server/builder.rs
+++ b/jsonrpc/src/server/builder.rs
@@ -29,12 +29,91 @@ pub struct ServerBuilder {
impl ServerBuilder {
/// Adds a new RPC service to the server.
+ ///
+ /// # Example
+ /// ```
+ /// use std::sync::Arc;
+ ///
+ /// use serde_json::Value;
+ ///
+ /// use karyon_jsonrpc::{Server, rpc_impl, RPCError};
+ ///
+ /// struct Ping {}
+ ///
+ /// #[rpc_impl]
+ /// impl Ping {
+ /// async fn ping(&self, _params: Value) -> Result<Value, RPCError> {
+ /// Ok(serde_json::json!("Pong"))
+ /// }
+ /// }
+ ///
+ /// async {
+ /// let server = Server::builder("ws://127.0.0.1:3000").unwrap()
+ /// .service(Arc::new(Ping{}))
+ /// .build().await.unwrap();
+ /// };
+ ///
+ /// ```
pub fn service(mut self, service: Arc<dyn RPCService>) -> Self {
self.services.insert(service.name(), service);
self
}
/// Adds a new PubSub RPC service to the server.
+ ///
+ /// # Example
+ /// ```
+ /// use std::sync::Arc;
+ ///
+ /// use serde_json::Value;
+ ///
+ /// use karyon_jsonrpc::{
+ /// Server, rpc_impl, rpc_pubsub_impl, RPCError, Channel, SubscriptionID,
+ /// };
+ ///
+ /// struct Ping {}
+ ///
+ /// #[rpc_impl]
+ /// impl Ping {
+ /// async fn ping(&self, _params: Value) -> Result<Value, RPCError> {
+ /// Ok(serde_json::json!("Pong"))
+ /// }
+ /// }
+ ///
+ /// #[rpc_pubsub_impl]
+ /// impl Ping {
+ /// async fn log_subscribe(
+ /// &self,
+ /// chan: Arc<Channel>,
+ /// method: String,
+ /// _params: Value,
+ /// ) -> Result<Value, RPCError> {
+ /// let sub = chan.new_subscription(&method).await;
+ /// let sub_id = sub.id.clone();
+ /// Ok(serde_json::json!(sub_id))
+ /// }
+ ///
+ /// async fn log_unsubscribe(
+ /// &self,
+ /// chan: Arc<Channel>,
+ /// _method: String,
+ /// params: Value,
+ /// ) -> Result<Value, RPCError> {
+ /// let sub_id: SubscriptionID = serde_json::from_value(params)?;
+ /// chan.remove_subscription(&sub_id).await;
+ /// Ok(serde_json::json!(true))
+ /// }
+ /// }
+ ///
+ /// async {
+ /// let ping_service = Arc::new(Ping{});
+ /// let server = Server::builder("ws://127.0.0.1:3000").unwrap()
+ /// .service(ping_service.clone())
+ /// .pubsub_service(ping_service)
+ /// .build().await.unwrap();
+ /// };
+ ///
+ /// ```
pub fn pubsub_service(mut self, service: Arc<dyn PubSubRPCService>) -> Self {
self.pubsub_services.insert(service.name(), service);
self
@@ -44,9 +123,15 @@ impl ServerBuilder {
///
/// # Example
///
- /// ```ignore
- /// let tcp_config = TcpConfig::default();
- /// let server = Server::builder()?.tcp_config(tcp_config)?.build()?;
+ /// ```
+ /// use karyon_jsonrpc::{Server, TcpConfig};
+ ///
+ /// async {
+ /// let tcp_config = TcpConfig::default();
+ /// let server = Server::builder("ws://127.0.0.1:3000").unwrap()
+ /// .tcp_config(tcp_config).unwrap()
+ /// .build().await.unwrap();
+ /// };
/// ```
///
/// This function will return an error if the endpoint does not support TCP protocols.
@@ -65,8 +150,15 @@ impl ServerBuilder {
/// # Example
///
/// ```ignore
- /// let tls_config = rustls::ServerConfig::new(...);
- /// let server = Server::builder()?.tls_config(tls_config)?.build()?;
+ /// use karon_jsonrpc::Server;
+ /// use futures_rustls::rustls;
+ ///
+ /// async {
+ /// let tls_config = rustls::ServerConfig::new(...);
+ /// let server = Server::builder("ws://127.0.0.1:3000").unwrap()
+ /// .tls_config(tls_config).unwrap()
+ /// .build().await.unwrap();
+ /// };
/// ```
///
/// This function will return an error if the endpoint does not support TLS protocols.
@@ -157,8 +249,12 @@ impl Server {
///
/// # Example
///
- /// ```ignore
- /// let builder = Server::builder("ws://127.0.0.1:3000")?.build()?;
+ /// ```
+ /// use karyon_jsonrpc::Server;
+ /// async {
+ /// let server = Server::builder("ws://127.0.0.1:3000").unwrap()
+ /// .build().await.unwrap();
+ /// };
/// ```
pub fn builder(endpoint: impl ToEndpoint) -> Result<ServerBuilder> {
let endpoint = endpoint.to_endpoint()?;
diff --git a/jsonrpc/src/server/mod.rs b/jsonrpc/src/server/mod.rs
index 6f539be..00b0fd2 100644
--- a/jsonrpc/src/server/mod.rs
+++ b/jsonrpc/src/server/mod.rs
@@ -126,7 +126,7 @@ impl Server {
method: nt.method,
params,
};
- // debug!("--> {notification}");
+ debug!("--> {notification}");
conn_cloned.send(serde_json::json!(notification)).await?;
}
}