aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-06-17 16:17:17 +0200
committerhozan23 <hozan23@karyontech.net>2024-06-17 16:17:17 +0200
commit72accd61fad0eea312d868b283c6b26da4802ff8 (patch)
treea1b8e0df25df3ea6bc6be5b1fe6ebe1f282150c4
parent2d2925c3e21af8ee8f745aa00c0a59dcd9c95df9 (diff)
jsonrpc/client: use serde untagged enum for decoding Notifications and Responses
-rw-r--r--jsonrpc/src/client/builder.rs4
-rw-r--r--jsonrpc/src/client/mod.rs68
-rw-r--r--jsonrpc/src/client/subscriptions.rs (renamed from jsonrpc/src/client/subscriber.rs)37
3 files changed, 58 insertions, 51 deletions
diff --git a/jsonrpc/src/client/builder.rs b/jsonrpc/src/client/builder.rs
index 2263498..510ce56 100644
--- a/jsonrpc/src/client/builder.rs
+++ b/jsonrpc/src/client/builder.rs
@@ -16,7 +16,7 @@ use crate::codec::WsJsonCodec;
use crate::{codec::JsonCodec, Error, Result, TcpConfig};
-use super::{Client, MessageDispatcher, Subscriber};
+use super::{Client, MessageDispatcher, Subscriptions};
const DEFAULT_TIMEOUT: u64 = 3000; // 3s
@@ -172,7 +172,7 @@ impl ClientBuilder {
timeout: self.timeout,
conn,
message_dispatcher: MessageDispatcher::new(),
- subscriber: Subscriber::new(),
+ subscriptions: Subscriptions::new(),
task_group: TaskGroup::new(),
});
client.start_background_receiving();
diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs
index 9c07509..eddba19 100644
--- a/jsonrpc/src/client/mod.rs
+++ b/jsonrpc/src/client/mod.rs
@@ -1,11 +1,12 @@
pub mod builder;
mod message_dispatcher;
-mod subscriber;
+mod subscriptions;
+
+use std::{sync::Arc, time::Duration};
use log::{debug, error};
-use serde::{de::DeserializeOwned, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::json;
-use std::{sync::Arc, time::Duration};
use karyon_core::{
async_util::{timeout, TaskGroup, TaskResult},
@@ -19,8 +20,8 @@ use crate::{
};
use message_dispatcher::MessageDispatcher;
-use subscriber::Subscriber;
-pub use subscriber::Subscription;
+pub use subscriptions::Subscription;
+use subscriptions::Subscriptions;
type RequestID = u32;
@@ -30,7 +31,14 @@ pub struct Client {
timeout: Option<u64>,
message_dispatcher: MessageDispatcher,
task_group: TaskGroup,
- subscriber: Subscriber,
+ subscriptions: Arc<Subscriptions>,
+}
+
+#[derive(Serialize, Deserialize)]
+#[serde(untagged)]
+enum NewMsg {
+ Notification(message::Notification),
+ Response(message::Response),
}
impl Client {
@@ -65,9 +73,9 @@ impl Client {
None => return Err(Error::InvalidMsg("Invalid subscription id")),
};
- let rx = self.subscriber.subscribe(sub_id).await;
+ let sub = self.subscriptions.subscribe(sub_id).await;
- Ok((sub_id, rx))
+ Ok((sub_id, sub))
}
/// Unsubscribes from the provided method, waits for the response, and returns the result.
@@ -76,7 +84,7 @@ impl Client {
/// and subscription ID. It waits for the response to confirm the unsubscription.
pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> {
let _ = self.send_request(method, sub_id).await?;
- self.subscriber.unsubscribe(&sub_id).await;
+ self.subscriptions.unsubscribe(&sub_id).await;
Ok(())
}
@@ -134,11 +142,12 @@ impl Client {
let selfc = self.clone();
let on_complete = |result: TaskResult<Result<()>>| async move {
if let TaskResult::Completed(Err(err)) = result {
- error!("background receiving stopped: {err}");
+ error!("Background receiving loop stopped: {err}");
}
// Drop all subscription
- selfc.subscriber.drop_all().await;
+ selfc.subscriptions.drop_all().await;
};
+
let selfc = self.clone();
// Spawn a new task for listing to new coming messages.
self.task_group.spawn(
@@ -146,10 +155,8 @@ impl Client {
loop {
let msg = selfc.conn.recv().await?;
if let Err(err) = selfc.handle_msg(msg).await {
- error!(
- "Handle a msg from the endpoint {} : {err}",
- selfc.conn.peer_endpoint()?
- );
+ let endpoint = selfc.conn.peer_endpoint()?;
+ error!("Handle a new msg from the endpoint {endpoint} : {err}",);
}
}
},
@@ -158,21 +165,22 @@ impl Client {
}
async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> {
- // Check if the received message is of type Response
- if let Ok(res) = serde_json::from_value::<message::Response>(msg.clone()) {
- debug!("<-- {res}");
- self.message_dispatcher.dispatch(res).await?;
- return Ok(());
- }
-
- // Check if the received message is of type Notification
- if let Ok(nt) = serde_json::from_value::<message::Notification>(msg.clone()) {
- debug!("<-- {nt}");
- self.subscriber.notify(nt).await?;
- return Ok(());
+ match serde_json::from_value::<NewMsg>(msg.clone()) {
+ Ok(msg) => match msg {
+ NewMsg::Response(res) => {
+ debug!("<-- {res}");
+ self.message_dispatcher.dispatch(res).await
+ }
+ NewMsg::Notification(nt) => {
+ debug!("<-- {nt}");
+ self.subscriptions.notify(nt).await?;
+ Ok(())
+ }
+ },
+ Err(_) => {
+ error!("Receive unexpected msg: {msg}");
+ Err(Error::InvalidMsg("Unexpected msg"))
+ }
}
-
- error!("Receive unexpected msg: {msg}");
- Err(Error::InvalidMsg("Unexpected msg"))
}
}
diff --git a/jsonrpc/src/client/subscriber.rs b/jsonrpc/src/client/subscriptions.rs
index 168f16e..9c8a9f4 100644
--- a/jsonrpc/src/client/subscriber.rs
+++ b/jsonrpc/src/client/subscriptions.rs
@@ -1,8 +1,9 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
use async_channel::{Receiver, Sender};
use log::warn;
use serde_json::json;
+use serde_json::Value;
use karyon_core::async_runtime::lock::Mutex;
@@ -11,24 +12,24 @@ use crate::{
Error, Result,
};
-/// Manages subscriptions for the client.
-pub(super) struct Subscriber {
- subs: Mutex<HashMap<SubscriptionID, Sender<serde_json::Value>>>,
-}
-
/// Type alias for a subscription to receive notifications.
///
/// The receiver channel is returned by the `subscribe`
-pub type Subscription = Receiver<serde_json::Value>;
+pub type Subscription = Receiver<Value>;
+
+/// Manages subscriptions for the client.
+pub(super) struct Subscriptions {
+ subs: Mutex<HashMap<SubscriptionID, Sender<Value>>>,
+}
-impl Subscriber {
- pub(super) fn new() -> Self {
- Self {
+impl Subscriptions {
+ pub(super) fn new() -> Arc<Self> {
+ Arc::new(Self {
subs: Mutex::new(HashMap::new()),
- }
+ })
}
- pub(super) async fn subscribe(&self, id: SubscriptionID) -> Receiver<serde_json::Value> {
+ pub(super) async fn subscribe(self: &Arc<Self>, id: SubscriptionID) -> Subscription {
let (ch_tx, ch_rx) = async_channel::unbounded();
self.subs.lock().await.insert(id, ch_tx);
ch_rx
@@ -38,7 +39,6 @@ impl Subscriber {
self.subs.lock().await.clear();
}
- /// Unsubscribe
pub(super) async fn unsubscribe(&self, id: &SubscriptionID) {
self.subs.lock().await.remove(id);
}
@@ -48,15 +48,14 @@ impl Subscriber {
Some(ref p) => serde_json::from_value(p.clone())?,
None => return Err(Error::InvalidMsg("Invalid notification msg")),
};
+
match self.subs.lock().await.get(&nt_res.subscription) {
- Some(s) => {
- s.send(nt_res.result.unwrap_or(json!(""))).await?;
- Ok(())
- }
+ Some(s) => s.send(nt_res.result.unwrap_or(json!(""))).await?,
None => {
- warn!("Receive unknown notification {}", nt_res.subscription);
- Ok(())
+ warn!("Receive unknown notification {}", nt_res.subscription)
}
}
+
+ Ok(())
}
}