From 72accd61fad0eea312d868b283c6b26da4802ff8 Mon Sep 17 00:00:00 2001
From: hozan23 <hozan23@karyontech.net>
Date: Mon, 17 Jun 2024 16:17:17 +0200
Subject: jsonrpc/client: use serde untagged enum for decoding Notifications
 and Responses

---
 jsonrpc/src/client/builder.rs       |  4 +--
 jsonrpc/src/client/mod.rs           | 68 +++++++++++++++++++++----------------
 jsonrpc/src/client/subscriber.rs    | 62 ---------------------------------
 jsonrpc/src/client/subscriptions.rs | 61 +++++++++++++++++++++++++++++++++
 4 files changed, 101 insertions(+), 94 deletions(-)
 delete mode 100644 jsonrpc/src/client/subscriber.rs
 create mode 100644 jsonrpc/src/client/subscriptions.rs

(limited to 'jsonrpc/src')

diff --git a/jsonrpc/src/client/builder.rs b/jsonrpc/src/client/builder.rs
index 2263498..510ce56 100644
--- a/jsonrpc/src/client/builder.rs
+++ b/jsonrpc/src/client/builder.rs
@@ -16,7 +16,7 @@ use crate::codec::WsJsonCodec;
 
 use crate::{codec::JsonCodec, Error, Result, TcpConfig};
 
-use super::{Client, MessageDispatcher, Subscriber};
+use super::{Client, MessageDispatcher, Subscriptions};
 
 const DEFAULT_TIMEOUT: u64 = 3000; // 3s
 
@@ -172,7 +172,7 @@ impl ClientBuilder {
             timeout: self.timeout,
             conn,
             message_dispatcher: MessageDispatcher::new(),
-            subscriber: Subscriber::new(),
+            subscriptions: Subscriptions::new(),
             task_group: TaskGroup::new(),
         });
         client.start_background_receiving();
diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs
index 9c07509..eddba19 100644
--- a/jsonrpc/src/client/mod.rs
+++ b/jsonrpc/src/client/mod.rs
@@ -1,11 +1,12 @@
 pub mod builder;
 mod message_dispatcher;
-mod subscriber;
+mod subscriptions;
+
+use std::{sync::Arc, time::Duration};
 
 use log::{debug, error};
-use serde::{de::DeserializeOwned, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::json;
-use std::{sync::Arc, time::Duration};
 
 use karyon_core::{
     async_util::{timeout, TaskGroup, TaskResult},
@@ -19,8 +20,8 @@ use crate::{
 };
 
 use message_dispatcher::MessageDispatcher;
-use subscriber::Subscriber;
-pub use subscriber::Subscription;
+pub use subscriptions::Subscription;
+use subscriptions::Subscriptions;
 
 type RequestID = u32;
 
@@ -30,7 +31,14 @@ pub struct Client {
     timeout: Option<u64>,
     message_dispatcher: MessageDispatcher,
     task_group: TaskGroup,
-    subscriber: Subscriber,
+    subscriptions: Arc<Subscriptions>,
+}
+
+#[derive(Serialize, Deserialize)]
+#[serde(untagged)]
+enum NewMsg {
+    Notification(message::Notification),
+    Response(message::Response),
 }
 
 impl Client {
@@ -65,9 +73,9 @@ impl Client {
             None => return Err(Error::InvalidMsg("Invalid subscription id")),
         };
 
-        let rx = self.subscriber.subscribe(sub_id).await;
+        let sub = self.subscriptions.subscribe(sub_id).await;
 
-        Ok((sub_id, rx))
+        Ok((sub_id, sub))
     }
 
     /// Unsubscribes from the provided method, waits for the response, and returns the result.
@@ -76,7 +84,7 @@ impl Client {
     /// and subscription ID. It waits for the response to confirm the unsubscription.
     pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> {
         let _ = self.send_request(method, sub_id).await?;
-        self.subscriber.unsubscribe(&sub_id).await;
+        self.subscriptions.unsubscribe(&sub_id).await;
         Ok(())
     }
 
@@ -134,11 +142,12 @@ impl Client {
         let selfc = self.clone();
         let on_complete = |result: TaskResult<Result<()>>| async move {
             if let TaskResult::Completed(Err(err)) = result {
-                error!("background receiving stopped: {err}");
+                error!("Background receiving loop stopped: {err}");
             }
             // Drop all subscription
-            selfc.subscriber.drop_all().await;
+            selfc.subscriptions.drop_all().await;
         };
+
         let selfc = self.clone();
         // Spawn a new task for listing to new coming messages.
         self.task_group.spawn(
@@ -146,10 +155,8 @@ impl Client {
                 loop {
                     let msg = selfc.conn.recv().await?;
                     if let Err(err) = selfc.handle_msg(msg).await {
-                        error!(
-                            "Handle a msg from the endpoint {} : {err}",
-                            selfc.conn.peer_endpoint()?
-                        );
+                        let endpoint = selfc.conn.peer_endpoint()?;
+                        error!("Handle a new msg from the endpoint {endpoint} : {err}",);
                     }
                 }
             },
@@ -158,21 +165,22 @@ impl Client {
     }
 
     async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> {
-        // Check if the received message is of type Response
-        if let Ok(res) = serde_json::from_value::<message::Response>(msg.clone()) {
-            debug!("<-- {res}");
-            self.message_dispatcher.dispatch(res).await?;
-            return Ok(());
-        }
-
-        // Check if the received message is of type Notification
-        if let Ok(nt) = serde_json::from_value::<message::Notification>(msg.clone()) {
-            debug!("<-- {nt}");
-            self.subscriber.notify(nt).await?;
-            return Ok(());
+        match serde_json::from_value::<NewMsg>(msg.clone()) {
+            Ok(msg) => match msg {
+                NewMsg::Response(res) => {
+                    debug!("<-- {res}");
+                    self.message_dispatcher.dispatch(res).await
+                }
+                NewMsg::Notification(nt) => {
+                    debug!("<-- {nt}");
+                    self.subscriptions.notify(nt).await?;
+                    Ok(())
+                }
+            },
+            Err(_) => {
+                error!("Receive unexpected msg: {msg}");
+                Err(Error::InvalidMsg("Unexpected msg"))
+            }
         }
-
-        error!("Receive unexpected msg: {msg}");
-        Err(Error::InvalidMsg("Unexpected msg"))
     }
 }
diff --git a/jsonrpc/src/client/subscriber.rs b/jsonrpc/src/client/subscriber.rs
deleted file mode 100644
index 168f16e..0000000
--- a/jsonrpc/src/client/subscriber.rs
+++ /dev/null
@@ -1,62 +0,0 @@
-use std::collections::HashMap;
-
-use async_channel::{Receiver, Sender};
-use log::warn;
-use serde_json::json;
-
-use karyon_core::async_runtime::lock::Mutex;
-
-use crate::{
-    message::{Notification, NotificationResult, SubscriptionID},
-    Error, Result,
-};
-
-/// Manages subscriptions for the client.
-pub(super) struct Subscriber {
-    subs: Mutex<HashMap<SubscriptionID, Sender<serde_json::Value>>>,
-}
-
-/// Type alias for a subscription to receive notifications.
-///
-/// The receiver channel is returned by the `subscribe`
-pub type Subscription = Receiver<serde_json::Value>;
-
-impl Subscriber {
-    pub(super) fn new() -> Self {
-        Self {
-            subs: Mutex::new(HashMap::new()),
-        }
-    }
-
-    pub(super) async fn subscribe(&self, id: SubscriptionID) -> Receiver<serde_json::Value> {
-        let (ch_tx, ch_rx) = async_channel::unbounded();
-        self.subs.lock().await.insert(id, ch_tx);
-        ch_rx
-    }
-
-    pub(super) async fn drop_all(&self) {
-        self.subs.lock().await.clear();
-    }
-
-    /// Unsubscribe
-    pub(super) async fn unsubscribe(&self, id: &SubscriptionID) {
-        self.subs.lock().await.remove(id);
-    }
-
-    pub(super) async fn notify(&self, nt: Notification) -> Result<()> {
-        let nt_res: NotificationResult = match nt.params {
-            Some(ref p) => serde_json::from_value(p.clone())?,
-            None => return Err(Error::InvalidMsg("Invalid notification msg")),
-        };
-        match self.subs.lock().await.get(&nt_res.subscription) {
-            Some(s) => {
-                s.send(nt_res.result.unwrap_or(json!(""))).await?;
-                Ok(())
-            }
-            None => {
-                warn!("Receive unknown notification {}", nt_res.subscription);
-                Ok(())
-            }
-        }
-    }
-}
diff --git a/jsonrpc/src/client/subscriptions.rs b/jsonrpc/src/client/subscriptions.rs
new file mode 100644
index 0000000..9c8a9f4
--- /dev/null
+++ b/jsonrpc/src/client/subscriptions.rs
@@ -0,0 +1,61 @@
+use std::{collections::HashMap, sync::Arc};
+
+use async_channel::{Receiver, Sender};
+use log::warn;
+use serde_json::json;
+use serde_json::Value;
+
+use karyon_core::async_runtime::lock::Mutex;
+
+use crate::{
+    message::{Notification, NotificationResult, SubscriptionID},
+    Error, Result,
+};
+
+/// Type alias for a subscription to receive notifications.
+///
+/// The receiver channel is returned by the `subscribe`
+pub type Subscription = Receiver<Value>;
+
+/// Manages subscriptions for the client.
+pub(super) struct Subscriptions {
+    subs: Mutex<HashMap<SubscriptionID, Sender<Value>>>,
+}
+
+impl Subscriptions {
+    pub(super) fn new() -> Arc<Self> {
+        Arc::new(Self {
+            subs: Mutex::new(HashMap::new()),
+        })
+    }
+
+    pub(super) async fn subscribe(self: &Arc<Self>, id: SubscriptionID) -> Subscription {
+        let (ch_tx, ch_rx) = async_channel::unbounded();
+        self.subs.lock().await.insert(id, ch_tx);
+        ch_rx
+    }
+
+    pub(super) async fn drop_all(&self) {
+        self.subs.lock().await.clear();
+    }
+
+    pub(super) async fn unsubscribe(&self, id: &SubscriptionID) {
+        self.subs.lock().await.remove(id);
+    }
+
+    pub(super) async fn notify(&self, nt: Notification) -> Result<()> {
+        let nt_res: NotificationResult = match nt.params {
+            Some(ref p) => serde_json::from_value(p.clone())?,
+            None => return Err(Error::InvalidMsg("Invalid notification msg")),
+        };
+
+        match self.subs.lock().await.get(&nt_res.subscription) {
+            Some(s) => s.send(nt_res.result.unwrap_or(json!(""))).await?,
+            None => {
+                warn!("Receive unknown notification {}", nt_res.subscription)
+            }
+        }
+
+        Ok(())
+    }
+}
-- 
cgit v1.2.3