aboutsummaryrefslogtreecommitdiff
path: root/jsonrpc/src/client
diff options
context:
space:
mode:
Diffstat (limited to 'jsonrpc/src/client')
-rw-r--r--jsonrpc/src/client/builder.rs10
-rw-r--r--jsonrpc/src/client/message_dispatcher.rs54
-rw-r--r--jsonrpc/src/client/mod.rs107
-rw-r--r--jsonrpc/src/client/subscriber.rs64
4 files changed, 162 insertions, 73 deletions
diff --git a/jsonrpc/src/client/builder.rs b/jsonrpc/src/client/builder.rs
index a287070..2263498 100644
--- a/jsonrpc/src/client/builder.rs
+++ b/jsonrpc/src/client/builder.rs
@@ -1,11 +1,11 @@
-use std::{collections::HashMap, sync::Arc};
+use std::sync::Arc;
#[cfg(feature = "smol")]
use futures_rustls::rustls;
#[cfg(feature = "tokio")]
use tokio_rustls::rustls;
-use karyon_core::{async_runtime::lock::Mutex, async_util::TaskGroup};
+use karyon_core::async_util::TaskGroup;
use karyon_net::{tls::ClientTlsConfig, Conn, Endpoint, ToEndpoint};
#[cfg(feature = "ws")]
@@ -16,7 +16,7 @@ use crate::codec::WsJsonCodec;
use crate::{codec::JsonCodec, Error, Result, TcpConfig};
-use super::Client;
+use super::{Client, MessageDispatcher, Subscriber};
const DEFAULT_TIMEOUT: u64 = 3000; // 3s
@@ -171,8 +171,8 @@ impl ClientBuilder {
let client = Arc::new(Client {
timeout: self.timeout,
conn,
- chans: Mutex::new(HashMap::new()),
- subscriptions: Mutex::new(HashMap::new()),
+ message_dispatcher: MessageDispatcher::new(),
+ subscriber: Subscriber::new(),
task_group: TaskGroup::new(),
});
client.start_background_receiving();
diff --git a/jsonrpc/src/client/message_dispatcher.rs b/jsonrpc/src/client/message_dispatcher.rs
new file mode 100644
index 0000000..a803f6e
--- /dev/null
+++ b/jsonrpc/src/client/message_dispatcher.rs
@@ -0,0 +1,54 @@
+use std::collections::HashMap;
+
+use async_channel::{Receiver, Sender};
+
+use karyon_core::async_runtime::lock::Mutex;
+
+use crate::{message, Error, Result};
+
+use super::RequestID;
+
+const CHANNEL_CAP: usize = 10;
+
+/// Manages client requests
+pub(super) struct MessageDispatcher {
+ chans: Mutex<HashMap<RequestID, Sender<message::Response>>>,
+}
+
+impl MessageDispatcher {
+ /// Creates a new MessageDispatcher
+ pub(super) fn new() -> Self {
+ Self {
+ chans: Mutex::new(HashMap::new()),
+ }
+ }
+
+ /// Registers a new request with a given ID and returns a Receiver channel
+ /// to wait for the response.
+ pub(super) async fn register(&self, id: RequestID) -> Receiver<message::Response> {
+ let (tx, rx) = async_channel::bounded(CHANNEL_CAP);
+ self.chans.lock().await.insert(id, tx);
+ rx
+ }
+
+ /// Unregisters the request with the provided ID
+ pub(super) async fn unregister(&self, id: &RequestID) {
+ self.chans.lock().await.remove(id);
+ }
+
+ /// Dispatches a response to the channel associated with the response's ID.
+ ///
+ /// If a channel is registered for the response's ID, the response is sent
+ /// through that channel. If no channel is found for the ID, returns an error.
+ pub(super) async fn dispatch(&self, res: message::Response) -> Result<()> {
+ if res.id.is_none() {
+ return Err(Error::InvalidMsg("Response id is none"));
+ }
+ let id: RequestID = serde_json::from_value(res.id.clone().unwrap())?;
+ let val = self.chans.lock().await.remove(&id);
+ match val {
+ Some(tx) => tx.send(res).await.map_err(Error::from),
+ None => Err(Error::InvalidMsg("Receive unknown message")),
+ }
+ }
+}
diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs
index 3a4505c..95354d3 100644
--- a/jsonrpc/src/client/mod.rs
+++ b/jsonrpc/src/client/mod.rs
@@ -1,13 +1,13 @@
pub mod builder;
+mod message_dispatcher;
+mod subscriber;
-use std::{collections::HashMap, sync::Arc, time::Duration};
-
-use log::{debug, error, warn};
+use log::{debug, error};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::json;
+use std::{sync::Arc, time::Duration};
use karyon_core::{
- async_runtime::lock::Mutex,
async_util::{timeout, TaskGroup, TaskResult},
util::random_32,
};
@@ -18,21 +18,19 @@ use crate::{
Error, Result,
};
-const CHANNEL_CAP: usize = 10;
+use message_dispatcher::MessageDispatcher;
+use subscriber::Subscriber;
+pub use subscriber::Subscription;
-/// Type alias for a subscription to receive notifications.
-///
-/// The receiver channel is returned by the `subscribe` method to receive
-/// notifications from the server.
-pub type Subscription = async_channel::Receiver<serde_json::Value>;
+type RequestID = u32;
/// Represents an RPC client
pub struct Client {
conn: Conn<serde_json::Value>,
timeout: Option<u64>,
- chans: Mutex<HashMap<u32, async_channel::Sender<message::Response>>>,
- subscriptions: Mutex<HashMap<SubscriptionID, async_channel::Sender<serde_json::Value>>>,
+ message_dispatcher: MessageDispatcher,
task_group: TaskGroup,
+ subscriber: Subscriber,
}
impl Client {
@@ -67,10 +65,9 @@ impl Client {
None => return Err(Error::InvalidMsg("Invalid subscription id")),
};
- let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP);
- self.subscriptions.lock().await.insert(sub_id, ch_tx);
+ let rx = self.subscriber.subscribe(sub_id).await;
- Ok((sub_id, ch_rx))
+ Ok((sub_id, rx))
}
/// Unsubscribes from the provided method, waits for the response, and returns the result.
@@ -79,7 +76,7 @@ impl Client {
/// and subscription ID. It waits for the response to confirm the unsubscription.
pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> {
let _ = self.send_request(method, sub_id).await?;
- self.subscriptions.lock().await.remove(&sub_id);
+ self.subscriber.unsubscribe(&sub_id).await;
Ok(())
}
@@ -88,7 +85,7 @@ impl Client {
method: &str,
params: T,
) -> Result<message::Response> {
- let id = random_32();
+ let id: RequestID = random_32();
let request = message::Request {
jsonrpc: message::JSONRPC_VERSION.to_string(),
id: json!(id),
@@ -98,16 +95,24 @@ impl Client {
let req_json = serde_json::to_value(&request)?;
+ // Send the json request
self.conn.send(req_json).await?;
- let (tx, rx) = async_channel::bounded(CHANNEL_CAP);
- self.chans.lock().await.insert(id, tx);
+ // Register a new request
+ let rx = self.message_dispatcher.register(id).await;
+
+ // Wait for the message dispatcher to send the response
+ let result = match self.timeout {
+ Some(t) => timeout(Duration::from_millis(t), rx.recv()).await?,
+ None => rx.recv().await,
+ };
- let response = match self.wait_for_response(rx).await {
+ let response = match result {
Ok(r) => r,
Err(err) => {
- self.chans.lock().await.remove(&id);
- return Err(err);
+ // Unregister the request if an error occurs
+ self.message_dispatcher.unregister(&id).await;
+ return Err(err.into());
}
};
@@ -115,6 +120,8 @@ impl Client {
return Err(Error::SubscribeError(error.code, error.message));
}
+ // It should be OK to unwrap here, as the message dispatcher checks
+ // for the response id.
if *response.id.as_ref().unwrap() != request.id {
return Err(Error::InvalidMsg("Invalid response id"));
}
@@ -123,28 +130,17 @@ impl Client {
Ok(response)
}
- async fn wait_for_response(
- &self,
- rx: async_channel::Receiver<message::Response>,
- ) -> Result<message::Response> {
- match self.timeout {
- Some(t) => timeout(Duration::from_millis(t), rx.recv())
- .await?
- .map_err(Error::from),
- None => rx.recv().await.map_err(Error::from),
- }
- }
-
fn start_background_receiving(self: &Arc<Self>) {
let selfc = self.clone();
- let on_failure = |result: TaskResult<Result<()>>| async move {
+ let on_complete = |result: TaskResult<Result<()>>| async move {
if let TaskResult::Completed(Err(err)) = result {
error!("background receiving stopped: {err}");
}
- // drop all subscription channels
- selfc.subscriptions.lock().await.clear();
+ // Drop all subscription
+ selfc.subscriber.drop_all().await;
};
let selfc = self.clone();
+ // Spawn a new task for listing to new coming messages.
self.task_group.spawn(
async move {
loop {
@@ -157,48 +153,23 @@ impl Client {
}
}
},
- on_failure,
+ on_complete,
);
}
async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> {
+ // Check if the received message is of type Response
if let Ok(res) = serde_json::from_value::<message::Response>(msg.clone()) {
debug!("<-- {res}");
- if res.id.is_none() {
- return Err(Error::InvalidMsg("Response id is none"));
- }
-
- let id: u32 = serde_json::from_value(res.id.clone().unwrap())?;
- match self.chans.lock().await.remove(&id) {
- Some(tx) => tx.send(res).await?,
- None => return Err(Error::InvalidMsg("Receive unkown message")),
- }
-
+ self.message_dispatcher.dispatch(res).await?;
return Ok(());
}
+ // Check if the received message is of type Notification
if let Ok(nt) = serde_json::from_value::<message::Notification>(msg.clone()) {
debug!("<-- {nt}");
- let sub_result: message::NotificationResult = match nt.params {
- Some(ref p) => serde_json::from_value(p.clone())?,
- None => return Err(Error::InvalidMsg("Invalid notification msg")),
- };
-
- match self
- .subscriptions
- .lock()
- .await
- .get(&sub_result.subscription)
- {
- Some(s) => {
- s.send(sub_result.result.unwrap_or(json!(""))).await?;
- return Ok(());
- }
- None => {
- warn!("Receive unknown notification {}", sub_result.subscription);
- return Ok(());
- }
- }
+ self.subscriber.notify(nt).await?;
+ return Ok(());
}
error!("Receive unexpected msg: {msg}");
diff --git a/jsonrpc/src/client/subscriber.rs b/jsonrpc/src/client/subscriber.rs
new file mode 100644
index 0000000..d47cc2a
--- /dev/null
+++ b/jsonrpc/src/client/subscriber.rs
@@ -0,0 +1,64 @@
+use std::collections::HashMap;
+
+use async_channel::{Receiver, Sender};
+use log::warn;
+use serde_json::json;
+
+use karyon_core::async_runtime::lock::Mutex;
+
+use crate::{
+ message::{Notification, NotificationResult, SubscriptionID},
+ Error, Result,
+};
+
+const CHANNEL_CAP: usize = 10;
+
+/// Manages subscriptions for the client.
+pub(super) struct Subscriber {
+ subs: Mutex<HashMap<SubscriptionID, Sender<serde_json::Value>>>,
+}
+
+/// Type alias for a subscription to receive notifications.
+///
+/// The receiver channel is returned by the `subscribe`
+pub type Subscription = Receiver<serde_json::Value>;
+
+impl Subscriber {
+ pub(super) fn new() -> Self {
+ Self {
+ subs: Mutex::new(HashMap::new()),
+ }
+ }
+
+ pub(super) async fn subscribe(&self, id: SubscriptionID) -> Receiver<serde_json::Value> {
+ let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP);
+ self.subs.lock().await.insert(id, ch_tx);
+ ch_rx
+ }
+
+ pub(super) async fn drop_all(&self) {
+ self.subs.lock().await.clear();
+ }
+
+ /// Unsubscribe
+ pub(super) async fn unsubscribe(&self, id: &SubscriptionID) {
+ self.subs.lock().await.remove(id);
+ }
+
+ pub(super) async fn notify(&self, nt: Notification) -> Result<()> {
+ let nt_res: NotificationResult = match nt.params {
+ Some(ref p) => serde_json::from_value(p.clone())?,
+ None => return Err(Error::InvalidMsg("Invalid notification msg")),
+ };
+ match self.subs.lock().await.get(&nt_res.subscription) {
+ Some(s) => {
+ s.send(nt_res.result.unwrap_or(json!(""))).await?;
+ Ok(())
+ }
+ None => {
+ warn!("Receive unknown notification {}", nt_res.subscription);
+ Ok(())
+ }
+ }
+ }
+}