diff options
Diffstat (limited to 'jsonrpc/src/client')
-rw-r--r-- | jsonrpc/src/client/builder.rs | 10 | ||||
-rw-r--r-- | jsonrpc/src/client/message_dispatcher.rs | 54 | ||||
-rw-r--r-- | jsonrpc/src/client/mod.rs | 107 | ||||
-rw-r--r-- | jsonrpc/src/client/subscriber.rs | 64 |
4 files changed, 162 insertions, 73 deletions
diff --git a/jsonrpc/src/client/builder.rs b/jsonrpc/src/client/builder.rs index a287070..2263498 100644 --- a/jsonrpc/src/client/builder.rs +++ b/jsonrpc/src/client/builder.rs @@ -1,11 +1,11 @@ -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; #[cfg(feature = "smol")] use futures_rustls::rustls; #[cfg(feature = "tokio")] use tokio_rustls::rustls; -use karyon_core::{async_runtime::lock::Mutex, async_util::TaskGroup}; +use karyon_core::async_util::TaskGroup; use karyon_net::{tls::ClientTlsConfig, Conn, Endpoint, ToEndpoint}; #[cfg(feature = "ws")] @@ -16,7 +16,7 @@ use crate::codec::WsJsonCodec; use crate::{codec::JsonCodec, Error, Result, TcpConfig}; -use super::Client; +use super::{Client, MessageDispatcher, Subscriber}; const DEFAULT_TIMEOUT: u64 = 3000; // 3s @@ -171,8 +171,8 @@ impl ClientBuilder { let client = Arc::new(Client { timeout: self.timeout, conn, - chans: Mutex::new(HashMap::new()), - subscriptions: Mutex::new(HashMap::new()), + message_dispatcher: MessageDispatcher::new(), + subscriber: Subscriber::new(), task_group: TaskGroup::new(), }); client.start_background_receiving(); diff --git a/jsonrpc/src/client/message_dispatcher.rs b/jsonrpc/src/client/message_dispatcher.rs new file mode 100644 index 0000000..a803f6e --- /dev/null +++ b/jsonrpc/src/client/message_dispatcher.rs @@ -0,0 +1,54 @@ +use std::collections::HashMap; + +use async_channel::{Receiver, Sender}; + +use karyon_core::async_runtime::lock::Mutex; + +use crate::{message, Error, Result}; + +use super::RequestID; + +const CHANNEL_CAP: usize = 10; + +/// Manages client requests +pub(super) struct MessageDispatcher { + chans: Mutex<HashMap<RequestID, Sender<message::Response>>>, +} + +impl MessageDispatcher { + /// Creates a new MessageDispatcher + pub(super) fn new() -> Self { + Self { + chans: Mutex::new(HashMap::new()), + } + } + + /// Registers a new request with a given ID and returns a Receiver channel + /// to wait for the response. + pub(super) async fn register(&self, id: RequestID) -> Receiver<message::Response> { + let (tx, rx) = async_channel::bounded(CHANNEL_CAP); + self.chans.lock().await.insert(id, tx); + rx + } + + /// Unregisters the request with the provided ID + pub(super) async fn unregister(&self, id: &RequestID) { + self.chans.lock().await.remove(id); + } + + /// Dispatches a response to the channel associated with the response's ID. + /// + /// If a channel is registered for the response's ID, the response is sent + /// through that channel. If no channel is found for the ID, returns an error. + pub(super) async fn dispatch(&self, res: message::Response) -> Result<()> { + if res.id.is_none() { + return Err(Error::InvalidMsg("Response id is none")); + } + let id: RequestID = serde_json::from_value(res.id.clone().unwrap())?; + let val = self.chans.lock().await.remove(&id); + match val { + Some(tx) => tx.send(res).await.map_err(Error::from), + None => Err(Error::InvalidMsg("Receive unknown message")), + } + } +} diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs index 3a4505c..95354d3 100644 --- a/jsonrpc/src/client/mod.rs +++ b/jsonrpc/src/client/mod.rs @@ -1,13 +1,13 @@ pub mod builder; +mod message_dispatcher; +mod subscriber; -use std::{collections::HashMap, sync::Arc, time::Duration}; - -use log::{debug, error, warn}; +use log::{debug, error}; use serde::{de::DeserializeOwned, Serialize}; use serde_json::json; +use std::{sync::Arc, time::Duration}; use karyon_core::{ - async_runtime::lock::Mutex, async_util::{timeout, TaskGroup, TaskResult}, util::random_32, }; @@ -18,21 +18,19 @@ use crate::{ Error, Result, }; -const CHANNEL_CAP: usize = 10; +use message_dispatcher::MessageDispatcher; +use subscriber::Subscriber; +pub use subscriber::Subscription; -/// Type alias for a subscription to receive notifications. -/// -/// The receiver channel is returned by the `subscribe` method to receive -/// notifications from the server. -pub type Subscription = async_channel::Receiver<serde_json::Value>; +type RequestID = u32; /// Represents an RPC client pub struct Client { conn: Conn<serde_json::Value>, timeout: Option<u64>, - chans: Mutex<HashMap<u32, async_channel::Sender<message::Response>>>, - subscriptions: Mutex<HashMap<SubscriptionID, async_channel::Sender<serde_json::Value>>>, + message_dispatcher: MessageDispatcher, task_group: TaskGroup, + subscriber: Subscriber, } impl Client { @@ -67,10 +65,9 @@ impl Client { None => return Err(Error::InvalidMsg("Invalid subscription id")), }; - let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP); - self.subscriptions.lock().await.insert(sub_id, ch_tx); + let rx = self.subscriber.subscribe(sub_id).await; - Ok((sub_id, ch_rx)) + Ok((sub_id, rx)) } /// Unsubscribes from the provided method, waits for the response, and returns the result. @@ -79,7 +76,7 @@ impl Client { /// and subscription ID. It waits for the response to confirm the unsubscription. pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> { let _ = self.send_request(method, sub_id).await?; - self.subscriptions.lock().await.remove(&sub_id); + self.subscriber.unsubscribe(&sub_id).await; Ok(()) } @@ -88,7 +85,7 @@ impl Client { method: &str, params: T, ) -> Result<message::Response> { - let id = random_32(); + let id: RequestID = random_32(); let request = message::Request { jsonrpc: message::JSONRPC_VERSION.to_string(), id: json!(id), @@ -98,16 +95,24 @@ impl Client { let req_json = serde_json::to_value(&request)?; + // Send the json request self.conn.send(req_json).await?; - let (tx, rx) = async_channel::bounded(CHANNEL_CAP); - self.chans.lock().await.insert(id, tx); + // Register a new request + let rx = self.message_dispatcher.register(id).await; + + // Wait for the message dispatcher to send the response + let result = match self.timeout { + Some(t) => timeout(Duration::from_millis(t), rx.recv()).await?, + None => rx.recv().await, + }; - let response = match self.wait_for_response(rx).await { + let response = match result { Ok(r) => r, Err(err) => { - self.chans.lock().await.remove(&id); - return Err(err); + // Unregister the request if an error occurs + self.message_dispatcher.unregister(&id).await; + return Err(err.into()); } }; @@ -115,6 +120,8 @@ impl Client { return Err(Error::SubscribeError(error.code, error.message)); } + // It should be OK to unwrap here, as the message dispatcher checks + // for the response id. if *response.id.as_ref().unwrap() != request.id { return Err(Error::InvalidMsg("Invalid response id")); } @@ -123,28 +130,17 @@ impl Client { Ok(response) } - async fn wait_for_response( - &self, - rx: async_channel::Receiver<message::Response>, - ) -> Result<message::Response> { - match self.timeout { - Some(t) => timeout(Duration::from_millis(t), rx.recv()) - .await? - .map_err(Error::from), - None => rx.recv().await.map_err(Error::from), - } - } - fn start_background_receiving(self: &Arc<Self>) { let selfc = self.clone(); - let on_failure = |result: TaskResult<Result<()>>| async move { + let on_complete = |result: TaskResult<Result<()>>| async move { if let TaskResult::Completed(Err(err)) = result { error!("background receiving stopped: {err}"); } - // drop all subscription channels - selfc.subscriptions.lock().await.clear(); + // Drop all subscription + selfc.subscriber.drop_all().await; }; let selfc = self.clone(); + // Spawn a new task for listing to new coming messages. self.task_group.spawn( async move { loop { @@ -157,48 +153,23 @@ impl Client { } } }, - on_failure, + on_complete, ); } async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> { + // Check if the received message is of type Response if let Ok(res) = serde_json::from_value::<message::Response>(msg.clone()) { debug!("<-- {res}"); - if res.id.is_none() { - return Err(Error::InvalidMsg("Response id is none")); - } - - let id: u32 = serde_json::from_value(res.id.clone().unwrap())?; - match self.chans.lock().await.remove(&id) { - Some(tx) => tx.send(res).await?, - None => return Err(Error::InvalidMsg("Receive unkown message")), - } - + self.message_dispatcher.dispatch(res).await?; return Ok(()); } + // Check if the received message is of type Notification if let Ok(nt) = serde_json::from_value::<message::Notification>(msg.clone()) { debug!("<-- {nt}"); - let sub_result: message::NotificationResult = match nt.params { - Some(ref p) => serde_json::from_value(p.clone())?, - None => return Err(Error::InvalidMsg("Invalid notification msg")), - }; - - match self - .subscriptions - .lock() - .await - .get(&sub_result.subscription) - { - Some(s) => { - s.send(sub_result.result.unwrap_or(json!(""))).await?; - return Ok(()); - } - None => { - warn!("Receive unknown notification {}", sub_result.subscription); - return Ok(()); - } - } + self.subscriber.notify(nt).await?; + return Ok(()); } error!("Receive unexpected msg: {msg}"); diff --git a/jsonrpc/src/client/subscriber.rs b/jsonrpc/src/client/subscriber.rs new file mode 100644 index 0000000..d47cc2a --- /dev/null +++ b/jsonrpc/src/client/subscriber.rs @@ -0,0 +1,64 @@ +use std::collections::HashMap; + +use async_channel::{Receiver, Sender}; +use log::warn; +use serde_json::json; + +use karyon_core::async_runtime::lock::Mutex; + +use crate::{ + message::{Notification, NotificationResult, SubscriptionID}, + Error, Result, +}; + +const CHANNEL_CAP: usize = 10; + +/// Manages subscriptions for the client. +pub(super) struct Subscriber { + subs: Mutex<HashMap<SubscriptionID, Sender<serde_json::Value>>>, +} + +/// Type alias for a subscription to receive notifications. +/// +/// The receiver channel is returned by the `subscribe` +pub type Subscription = Receiver<serde_json::Value>; + +impl Subscriber { + pub(super) fn new() -> Self { + Self { + subs: Mutex::new(HashMap::new()), + } + } + + pub(super) async fn subscribe(&self, id: SubscriptionID) -> Receiver<serde_json::Value> { + let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP); + self.subs.lock().await.insert(id, ch_tx); + ch_rx + } + + pub(super) async fn drop_all(&self) { + self.subs.lock().await.clear(); + } + + /// Unsubscribe + pub(super) async fn unsubscribe(&self, id: &SubscriptionID) { + self.subs.lock().await.remove(id); + } + + pub(super) async fn notify(&self, nt: Notification) -> Result<()> { + let nt_res: NotificationResult = match nt.params { + Some(ref p) => serde_json::from_value(p.clone())?, + None => return Err(Error::InvalidMsg("Invalid notification msg")), + }; + match self.subs.lock().await.get(&nt_res.subscription) { + Some(s) => { + s.send(nt_res.result.unwrap_or(json!(""))).await?; + Ok(()) + } + None => { + warn!("Receive unknown notification {}", nt_res.subscription); + Ok(()) + } + } + } +} |