From 8fc494d2d508f0e0beefccda31d15a5e387a9791 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Mon, 17 Jun 2024 16:18:34 +0200 Subject: jsonrpc/server: use weak pointer for Channel in subscriptions --- jsonrpc/src/server/channel.rs | 18 +++++++++++++----- jsonrpc/src/server/mod.rs | 3 +-- 2 files changed, 14 insertions(+), 7 deletions(-) (limited to 'jsonrpc') diff --git a/jsonrpc/src/server/channel.rs b/jsonrpc/src/server/channel.rs index b5c9184..36896b4 100644 --- a/jsonrpc/src/server/channel.rs +++ b/jsonrpc/src/server/channel.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, Weak}; use karyon_core::{async_runtime::lock::Mutex, util::random_32}; @@ -15,7 +15,7 @@ pub(crate) struct NewNotification { #[derive(Clone)] pub struct Subscription { pub id: SubscriptionID, - parent: Arc, + parent: Weak, chan: async_channel::Sender, method: String, } @@ -23,7 +23,7 @@ pub struct Subscription { impl Subscription { /// Creates a new [`Subscription`] fn new( - parent: Arc, + parent: Weak, id: SubscriptionID, chan: async_channel::Sender, method: &str, @@ -38,7 +38,7 @@ impl Subscription { /// Sends a notification to the subscriber pub async fn notify(&self, res: serde_json::Value) -> Result<()> { - if self.parent.subs.lock().await.contains(&self.id) { + if self.still_subscribed().await { let nt = NewNotification { sub_id: self.id, result: res, @@ -50,6 +50,14 @@ impl Subscription { Err(Error::SubscriptionNotFound(self.id.to_string())) } } + + /// Checks from the partent if this subscription is still subscribed + pub async fn still_subscribed(&self) -> bool { + match self.parent.upgrade() { + Some(parent) => parent.subs.lock().await.contains(&self.id), + None => false, + } + } } /// Represents a connection channel for creating/removing subscriptions @@ -70,7 +78,7 @@ impl Channel { /// Creates a new [`Subscription`] pub async fn new_subscription(self: &Arc, method: &str) -> Subscription { let sub_id = random_32(); - let sub = Subscription::new(self.clone(), sub_id, self.chan.clone(), method); + let sub = Subscription::new(Arc::downgrade(self), sub_id, self.chan.clone(), method); self.subs.lock().await.push(sub_id); sub } diff --git a/jsonrpc/src/server/mod.rs b/jsonrpc/src/server/mod.rs index 7ff1a8c..6f539be 100644 --- a/jsonrpc/src/server/mod.rs +++ b/jsonrpc/src/server/mod.rs @@ -268,9 +268,8 @@ impl Server { if let Some(service) = self.pubsub_services.get(&req.srvc_name) { // Check if the method exists within the service if let Some(method) = service.get_pubsub_method(&req.method_name) { - let name = format!("{}.{}", service.name(), req.method_name); let params = req.msg.params.unwrap_or(serde_json::json!(())); - response.result = match method(channel, name, params).await { + response.result = match method(channel, req.msg.method, params).await { Ok(res) => Some(res), Err(err) => return err.to_response(Some(req.msg.id), None), }; -- cgit v1.2.3