From 2d1a8aea0b9330cd2eaad26eb187644adad6bed9 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 23 May 2024 00:19:58 +0200 Subject: jsonrpc: spawn task when handle new request --- core/src/util/mod.rs | 5 ++ jsonrpc/README.md | 12 +-- jsonrpc/examples/pubsub_server.rs | 23 ++++-- jsonrpc/examples/server.rs | 7 +- jsonrpc/examples/tokio_server/src/main.rs | 6 +- jsonrpc/src/client/mod.rs | 45 +++++++---- jsonrpc/src/codec.rs | 14 +++- jsonrpc/src/error.rs | 3 + jsonrpc/src/message.rs | 29 ++++--- jsonrpc/src/server/channel.rs | 39 +++++++--- jsonrpc/src/server/mod.rs | 124 +++++++++++++++++++++--------- jsonrpc/src/server/pubsub_service.rs | 6 +- 12 files changed, 220 insertions(+), 93 deletions(-) diff --git a/core/src/util/mod.rs b/core/src/util/mod.rs index a3c3f50..ec59e1c 100644 --- a/core/src/util/mod.rs +++ b/core/src/util/mod.rs @@ -13,6 +13,11 @@ pub fn random_32() -> u32 { OsRng.gen() } +/// Generates and returns a random u64 using `rand::rngs::OsRng`. +pub fn random_64() -> u64 { + OsRng.gen() +} + /// Generates and returns a random u16 using `rand::rngs::OsRng`. pub fn random_16() -> u16 { OsRng.gen() diff --git a/jsonrpc/README.md b/jsonrpc/README.md index 091af99..03f5ace 100644 --- a/jsonrpc/README.md +++ b/jsonrpc/README.md @@ -16,7 +16,7 @@ features: ## Example ```rust -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use serde_json::Value; use smol::stream::StreamExt; @@ -45,8 +45,8 @@ impl HelloWorld { #[rpc_pubsub_impl] impl HelloWorld { - async fn log_subscribe(&self, chan: ArcChannel, _params: Value) -> Result { - let sub = chan.new_subscription().await; + async fn log_subscribe(&self, chan: ArcChannel, method: String, _params: Value) -> Result { + let sub = chan.new_subscription(&method).await; let sub_id = sub.id.clone(); smol::spawn(async move { loop { @@ -62,7 +62,7 @@ impl HelloWorld { Ok(serde_json::json!(sub_id)) } - async fn log_unsubscribe(&self, chan: ArcChannel, params: Value) -> Result { + async fn log_unsubscribe(&self, chan: ArcChannel, method: String, params: Value) -> Result { let sub_id: SubscriptionID = serde_json::from_value(params)?; chan.remove_subscription(&sub_id).await; Ok(serde_json::json!(true)) @@ -84,7 +84,9 @@ async { .expect("build the server"); // Starts the server - server.start().await.expect("start the server"); + server.start().await; + + smol::Timer::after(Duration::MAX).await; }; // Client diff --git a/jsonrpc/examples/pubsub_server.rs b/jsonrpc/examples/pubsub_server.rs index 739e6d5..4b77c45 100644 --- a/jsonrpc/examples/pubsub_server.rs +++ b/jsonrpc/examples/pubsub_server.rs @@ -1,8 +1,9 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use serde::{Deserialize, Serialize}; use serde_json::Value; +use karyon_core::async_util::sleep; use karyon_jsonrpc::{rpc_impl, rpc_pubsub_impl, ArcChannel, Error, Server, SubscriptionID}; struct Calc {} @@ -25,8 +26,13 @@ impl Calc { #[rpc_pubsub_impl] impl Calc { - async fn log_subscribe(&self, chan: ArcChannel, _params: Value) -> Result { - let sub = chan.new_subscription().await; + async fn log_subscribe( + &self, + chan: ArcChannel, + method: String, + _params: Value, + ) -> Result { + let sub = chan.new_subscription(&method).await; let sub_id = sub.id.clone(); smol::spawn(async move { loop { @@ -42,7 +48,12 @@ impl Calc { Ok(serde_json::json!(sub_id)) } - async fn log_unsubscribe(&self, chan: ArcChannel, params: Value) -> Result { + async fn log_unsubscribe( + &self, + chan: ArcChannel, + _method: String, + params: Value, + ) -> Result { let sub_id: SubscriptionID = serde_json::from_value(params)?; chan.remove_subscription(&sub_id).await; Ok(serde_json::json!(true)) @@ -64,6 +75,8 @@ fn main() { .expect("Build a new server"); // Start the server - server.start().await.expect("Start the server"); + server.start().await; + + sleep(Duration::MAX).await; }); } diff --git a/jsonrpc/examples/server.rs b/jsonrpc/examples/server.rs index 5b951cd..470bd02 100644 --- a/jsonrpc/examples/server.rs +++ b/jsonrpc/examples/server.rs @@ -1,8 +1,9 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use serde::{Deserialize, Serialize}; use serde_json::Value; +use karyon_core::async_util::sleep; use karyon_jsonrpc::{rpc_impl, Error, Server}; struct Calc { @@ -56,6 +57,8 @@ fn main() { .expect("start a new server"); // Start the server - server.start().await.unwrap(); + server.start().await; + + sleep(Duration::MAX).await; }); } diff --git a/jsonrpc/examples/tokio_server/src/main.rs b/jsonrpc/examples/tokio_server/src/main.rs index ce77cd3..d70a46a 100644 --- a/jsonrpc/examples/tokio_server/src/main.rs +++ b/jsonrpc/examples/tokio_server/src/main.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -56,5 +56,7 @@ async fn main() { .expect("start a new server"); // Start the server - server.start().await.unwrap(); + server.start().await; + + tokio::time::sleep(Duration::MAX).await; } diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs index c9253fc..0d8ccb8 100644 --- a/jsonrpc/src/client/mod.rs +++ b/jsonrpc/src/client/mod.rs @@ -12,7 +12,7 @@ use tokio_rustls::rustls; use karyon_core::{ async_runtime::lock::Mutex, async_util::{timeout, TaskGroup, TaskResult}, - util::random_32, + util::random_64, }; use karyon_net::{tls::ClientTlsConfig, Conn, Endpoint, ToEndpoint}; @@ -54,7 +54,10 @@ impl Client { let request = self.send_request(method, params, None).await?; debug!("--> {request}"); - let response = self.chan_rx.recv().await?; + let response = match self.timeout { + Some(t) => timeout(Duration::from_millis(t), self.chan_rx.recv()).await??, + None => self.chan_rx.recv().await?, + }; debug!("<-- {response}"); if let Some(error) = response.error { @@ -84,7 +87,10 @@ impl Client { let request = self.send_request(method, params, Some(json!(true))).await?; debug!("--> {request}"); - let response = self.chan_rx.recv().await?; + let response = match self.timeout { + Some(t) => timeout(Duration::from_millis(t), self.chan_rx.recv()).await??, + None => self.chan_rx.recv().await?, + }; debug!("<-- {response}"); if let Some(error) = response.error { @@ -95,7 +101,7 @@ impl Client { return Err(Error::InvalidMsg("Invalid response id")); } - let sub_id = match response.subscription { + let sub_id = match response.result { Some(result) => serde_json::from_value::(result)?, None => return Err(Error::InvalidMsg("Invalid subscription id")), }; @@ -116,7 +122,10 @@ impl Client { .await?; debug!("--> {request}"); - let response = self.chan_rx.recv().await?; + let response = match self.timeout { + Some(t) => timeout(Duration::from_millis(t), self.chan_rx.recv()).await??, + None => self.chan_rx.recv().await?, + }; debug!("<-- {response}"); if let Some(error) = response.error { @@ -137,11 +146,11 @@ impl Client { params: T, subscriber: Option, ) -> Result { - let id = json!(random_32()); + let id = random_64(); let request = message::Request { jsonrpc: message::JSONRPC_VERSION.to_string(), - id, + id: json!(id), method: method.to_string(), params: json!(params), subscriber, @@ -150,9 +159,9 @@ impl Client { let req_json = serde_json::to_value(&request)?; match self.timeout { - Some(s) => { - let dur = Duration::from_millis(s); - timeout(dur, self.conn.send(req_json)).await??; + Some(ms) => { + let t = Duration::from_millis(ms); + timeout(t, self.conn.send(req_json)).await??; } None => { self.conn.send(req_json).await?; @@ -176,15 +185,14 @@ impl Client { async move { loop { let msg = selfc.conn.recv().await?; - if let Ok(res) = serde_json::from_value::(msg.clone()) { selfc.chan_tx.send(res).await?; continue; } if let Ok(nt) = serde_json::from_value::(msg.clone()) { - let sub_id = match nt.subscription.clone() { - Some(id) => serde_json::from_value::(id)?, + let sub_result: message::NotificationResult = match nt.params { + Some(p) => serde_json::from_value(p)?, None => { return Err(Error::InvalidMsg( "Invalid notification msg: subscription id not found", @@ -192,13 +200,18 @@ impl Client { } }; - match selfc.subscriptions.lock().await.get(&sub_id) { + match selfc + .subscriptions + .lock() + .await + .get(&sub_result.subscription) + { Some(s) => { - s.send(nt.params.unwrap_or(json!(""))).await?; + s.send(sub_result.result.unwrap_or(json!(""))).await?; continue; } None => { - warn!("Receive unknown notification {sub_id}"); + warn!("Receive unknown notification {}", sub_result.subscription); continue; } } diff --git a/jsonrpc/src/codec.rs b/jsonrpc/src/codec.rs index 29c6f13..cc11602 100644 --- a/jsonrpc/src/codec.rs +++ b/jsonrpc/src/codec.rs @@ -38,7 +38,7 @@ impl Decoder for JsonCodec { let item = match iter.next() { Some(Ok(item)) => item, Some(Err(ref e)) if e.is_eof() => return Ok(None), - Some(Err(e)) => return Err(Error::Encode(e.to_string())), + Some(Err(e)) => return Err(Error::Decode(e.to_string())), None => return Ok(None), }; @@ -70,13 +70,21 @@ impl WebSocketEncoder for WsJsonCodec { #[cfg(feature = "ws")] impl WebSocketDecoder for WsJsonCodec { type DeItem = serde_json::Value; - fn decode(&self, src: &Message) -> Result { + fn decode(&self, src: &Message) -> Result> { match src { Message::Text(s) => match serde_json::from_str(s) { + Ok(m) => Ok(Some(m)), + Err(err) => Err(Error::Decode(err.to_string())), + }, + Message::Binary(s) => match serde_json::from_slice(s) { Ok(m) => Ok(m), Err(err) => Err(Error::Decode(err.to_string())), }, - _ => Err(Error::Decode("Receive wrong message".to_string())), + Message::Close(_) => Err(Error::IO(std::io::ErrorKind::ConnectionAborted.into())), + m => Err(Error::Decode(format!( + "Receive unexpected message: {:?}", + m + ))), } } } diff --git a/jsonrpc/src/error.rs b/jsonrpc/src/error.rs index d68e169..e1cb071 100644 --- a/jsonrpc/src/error.rs +++ b/jsonrpc/src/error.rs @@ -32,6 +32,9 @@ pub enum Error { #[error("Unsupported protocol: {0}")] UnsupportedProtocol(String), + #[error("Receive close message from connection: {0}")] + CloseConnection(String), + #[error("Subscription not found: {0}")] SubscriptionNotFound(String), diff --git a/jsonrpc/src/message.rs b/jsonrpc/src/message.rs index 9c89362..2cf28b1 100644 --- a/jsonrpc/src/message.rs +++ b/jsonrpc/src/message.rs @@ -1,5 +1,9 @@ use serde::{Deserialize, Serialize}; +use crate::SubscriptionID; + +pub type ID = u64; + pub const JSONRPC_VERSION: &str = "2.0"; /// Parse error: Invalid JSON was received by the server. @@ -32,24 +36,25 @@ pub struct Request { pub struct Response { pub jsonrpc: String, #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub result: Option, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub subscription: Option, } #[derive(Debug, Serialize, Deserialize)] pub struct Notification { pub jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub method: Option, + pub method: String, #[serde(skip_serializing_if = "Option::is_none")] pub params: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub subscription: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct NotificationResult { + pub result: Option, + pub subscription: SubscriptionID, } #[derive(Debug, Serialize, Deserialize)] @@ -74,8 +79,8 @@ impl std::fmt::Display for Response { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "{{jsonrpc: {}, result': {:?}, error: {:?} , id: {:?}, subscription: {:?}}}", - self.jsonrpc, self.result, self.error, self.id, self.subscription + "{{jsonrpc: {}, result': {:?}, error: {:?} , id: {:?}}}", + self.jsonrpc, self.result, self.error, self.id, ) } } @@ -94,8 +99,8 @@ impl std::fmt::Display for Notification { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "{{jsonrpc: {}, method: {:?}, params: {:?}, subscription: {:?}}}", - self.jsonrpc, self.method, self.params, self.subscription + "{{jsonrpc: {}, method: {:?}, params: {:?}}}", + self.jsonrpc, self.method, self.params ) } } diff --git a/jsonrpc/src/server/channel.rs b/jsonrpc/src/server/channel.rs index 1498825..f14c1dd 100644 --- a/jsonrpc/src/server/channel.rs +++ b/jsonrpc/src/server/channel.rs @@ -7,11 +7,18 @@ use crate::{Error, Result}; pub type SubscriptionID = u32; pub type ArcChannel = Arc; +pub(crate) struct NewNotification { + pub sub_id: SubscriptionID, + pub result: serde_json::Value, + pub method: String, +} + /// Represents a new subscription pub struct Subscription { pub id: SubscriptionID, parent: Arc, - chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, + chan: async_channel::Sender, + method: String, } impl Subscription { @@ -19,15 +26,26 @@ impl Subscription { fn new( parent: Arc, id: SubscriptionID, - chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, + chan: async_channel::Sender, + method: &str, ) -> Self { - Self { parent, id, chan } + Self { + parent, + id, + chan, + method: method.to_string(), + } } /// Sends a notification to the subscriber pub async fn notify(&self, res: serde_json::Value) -> Result<()> { if self.parent.subs.lock().await.contains(&self.id) { - self.chan.send((self.id, res)).await?; + let nt = NewNotification { + sub_id: self.id, + result: res, + method: self.method.clone(), + }; + self.chan.send(nt).await?; Ok(()) } else { Err(Error::SubscriptionNotFound(self.id.to_string())) @@ -37,13 +55,13 @@ impl Subscription { /// Represents a channel for creating/removing subscriptions pub struct Channel { - chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, + chan: async_channel::Sender, subs: Mutex>, } impl Channel { /// Creates a new `Channel` - pub fn new(chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>) -> ArcChannel { + pub(crate) fn new(chan: async_channel::Sender) -> ArcChannel { Arc::new(Self { chan, subs: Mutex::new(Vec::new()), @@ -51,19 +69,20 @@ impl Channel { } /// Creates a new subscription - pub async fn new_subscription(self: &Arc) -> Subscription { + pub async fn new_subscription(self: &Arc, method: &str) -> Subscription { let sub_id = random_32(); - let sub = Subscription::new(self.clone(), sub_id, self.chan.clone()); + let sub = Subscription::new(self.clone(), sub_id, self.chan.clone(), method); self.subs.lock().await.push(sub_id); sub } /// Removes a subscription pub async fn remove_subscription(self: &Arc, id: &SubscriptionID) { - let i = match self.subs.lock().await.iter().position(|i| i == id) { + let mut subs = self.subs.lock().await; + let i = match subs.iter().position(|i| i == id) { Some(i) => i, None => return, }; - self.subs.lock().await.remove(i); + subs.remove(i); } } diff --git a/jsonrpc/src/server/mod.rs b/jsonrpc/src/server/mod.rs index 4ebab10..29b1a10 100644 --- a/jsonrpc/src/server/mod.rs +++ b/jsonrpc/src/server/mod.rs @@ -47,7 +47,6 @@ fn pack_err_res(code: i32, msg: &str, id: Option) -> message: error: Some(err), result: None, id, - subscription: None, } } @@ -77,19 +76,31 @@ impl Server { } /// Starts the RPC server - pub async fn start(self: Arc) -> Result<()> { - loop { - match self.listener.accept().await { - Ok(conn) => { - if let Err(err) = self.handle_conn(conn).await { - error!("Failed to handle a new conn: {err}") + pub async fn start(self: &Arc) { + let on_failure = |result: TaskResult>| async move { + if let TaskResult::Completed(Err(err)) = result { + error!("Accept loop stopped: {err}"); + } + }; + + let selfc = self.clone(); + self.task_group.spawn( + async move { + loop { + match selfc.listener.accept().await { + Ok(conn) => { + if let Err(err) = selfc.handle_conn(conn).await { + error!("Failed to handle a new conn: {err}") + } + } + Err(err) => { + error!("Failed to accept a new conn: {err}") + } } } - Err(err) => { - error!("Failed to accept a new conn: {err}") - } - } - } + }, + on_failure, + ); } /// Shuts down the RPC server @@ -102,6 +113,40 @@ impl Server { let endpoint = conn.peer_endpoint().expect("get peer endpoint"); debug!("Handle a new connection {endpoint}"); + // TODO Avoid depending on channels + let (tx, rx) = async_channel::bounded::(CHANNEL_CAP); + + let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP); + let channel = Channel::new(ch_tx); + + let on_failure = |result: TaskResult>| async move { + if let TaskResult::Completed(Err(err)) = result { + debug!("Notification loop stopped: {err}"); + } + }; + + let selfc = self.clone(); + let txc = tx.clone(); + self.task_group.spawn( + async move { + loop { + let nt = ch_rx.recv().await?; + let params = Some(serde_json::json!(message::NotificationResult { + subscription: nt.sub_id, + result: Some(nt.result), + })); + let response = message::Notification { + jsonrpc: message::JSONRPC_VERSION.to_string(), + method: nt.method, + params, + }; + debug!("--> {response}"); + txc.send(serde_json::to_value(response)?).await?; + } + }, + on_failure, + ); + let on_failure = |result: TaskResult>| async move { if let TaskResult::Completed(Err(err)) = result { error!("Connection {} dropped: {}", endpoint, err); @@ -110,30 +155,14 @@ impl Server { } }; - let selfc = self.clone(); - let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP); - let channel = Channel::new(ch_tx); self.task_group.spawn( async move { loop { - match select(conn.recv(), ch_rx.recv()).await { + match select(conn.recv(), rx.recv()).await { Either::Left(msg) => { - // TODO spawn a task - let response = selfc.handle_request(channel.clone(), msg?).await; - debug!("--> {response}"); - conn.send(serde_json::to_value(response)?).await?; - } - Either::Right(msg) => { - let (sub_id, result) = msg?; - let response = message::Notification { - jsonrpc: message::JSONRPC_VERSION.to_string(), - method: None, - params: Some(result), - subscription: Some(sub_id.into()), - }; - debug!("--> {response}"); - conn.send(serde_json::to_value(response)?).await?; + selfc.new_request(tx.clone(), channel.clone(), msg?).await; } + Either::Right(msg) => conn.send(msg?).await?, } } }, @@ -176,8 +205,32 @@ impl Server { }) } + /// Spawns a new task for handling a new request + async fn new_request( + self: &Arc, + sender: async_channel::Sender, + channel: ArcChannel, + msg: serde_json::Value, + ) { + let on_failure = |result: TaskResult>| async move { + if let TaskResult::Completed(Err(err)) = result { + error!("Failed to handle a request: {err}"); + } + }; + let selfc = self.clone(); + self.task_group.spawn( + async move { + let response = selfc._handle_request(channel, msg).await; + debug!("--> {response}"); + sender.send(serde_json::json!(response)).await?; + Ok(()) + }, + on_failure, + ); + } + /// Handles a new request - async fn handle_request( + async fn _handle_request( &self, channel: ArcChannel, msg: serde_json::Value, @@ -239,7 +292,6 @@ impl Server { error: None, result: Some(result), id: Some(rpc_msg.id), - subscription: None, } } @@ -262,7 +314,8 @@ impl Server { } }; - let result = match method(channel, rpc_msg.params.clone()).await { + let name = format!("{}.{}", service.name(), method_name); + let result = match method(channel, name, rpc_msg.params.clone()).await { Ok(res) => res, Err(err) => return self.handle_error(err, rpc_msg.id), }; @@ -270,9 +323,8 @@ impl Server { message::Response { jsonrpc: message::JSONRPC_VERSION.to_string(), error: None, - result: None, + result: Some(result), id: Some(rpc_msg.id), - subscription: Some(result), } } diff --git a/jsonrpc/src/server/pubsub_service.rs b/jsonrpc/src/server/pubsub_service.rs index 5b4bf9a..5b3b50b 100644 --- a/jsonrpc/src/server/pubsub_service.rs +++ b/jsonrpc/src/server/pubsub_service.rs @@ -6,7 +6,7 @@ use super::channel::ArcChannel; /// Represents the RPC method pub type PubSubRPCMethod<'a> = - Box PubSubRPCMethodOutput<'a> + Send + 'a>; + Box PubSubRPCMethodOutput<'a> + Send + 'a>; type PubSubRPCMethodOutput<'a> = Pin> + Send + Sync + 'a>>; @@ -51,7 +51,9 @@ macro_rules! impl_pubsub_rpc_service { match name { $( stringify!($m) => { - Some(Box::new(move |chan: karyon_jsonrpc::ArcChannel, params: serde_json::Value| Box::pin(self.$m(chan, params)))) + Some(Box::new(move |chan: karyon_jsonrpc::ArcChannel, method: String, params: serde_json::Value| { + Box::pin(self.$m(chan, method, params)) + })) } )* _ => None, -- cgit v1.2.3