diff options
author | hozan23 <hozan23@karyontech.net> | 2024-05-21 02:20:45 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-22 15:02:06 +0200 |
commit | 028940fe3e0a87cdc421a6d07f1ecfb6c208b9d0 (patch) | |
tree | 3272d5c71cafb098e548cb9811e8f9ddc260ef2f /jsonrpc | |
parent | 0f0cefb62ee8b641dcabcc0a2a1cf019c1de4843 (diff) |
jsonrpc: support pubsub
Diffstat (limited to 'jsonrpc')
-rw-r--r-- | jsonrpc/Cargo.toml | 1 | ||||
-rw-r--r-- | jsonrpc/README.md | 62 | ||||
-rw-r--r-- | jsonrpc/examples/pubsub_client.rs | 47 | ||||
-rw-r--r-- | jsonrpc/examples/pubsub_server.rs | 69 | ||||
-rw-r--r-- | jsonrpc/examples/server.rs | 4 | ||||
-rw-r--r-- | jsonrpc/examples/tokio_server/Cargo.lock | 1 | ||||
-rw-r--r-- | jsonrpc/examples/tokio_server/src/main.rs | 4 | ||||
-rw-r--r-- | jsonrpc/jsonrpc_macro/src/lib.rs | 38 | ||||
-rw-r--r-- | jsonrpc/src/client.rs | 158 | ||||
-rw-r--r-- | jsonrpc/src/client/mod.rs | 374 | ||||
-rw-r--r-- | jsonrpc/src/error.rs | 18 | ||||
-rw-r--r-- | jsonrpc/src/lib.rs | 83 | ||||
-rw-r--r-- | jsonrpc/src/message.rs | 36 | ||||
-rw-r--r-- | jsonrpc/src/server.rs | 282 | ||||
-rw-r--r-- | jsonrpc/src/server/channel.rs | 69 | ||||
-rw-r--r-- | jsonrpc/src/server/mod.rs | 454 | ||||
-rw-r--r-- | jsonrpc/src/server/pubsub_service.rs | 67 | ||||
-rw-r--r-- | jsonrpc/src/server/service.rs (renamed from jsonrpc/src/service.rs) | 0 |
18 files changed, 1236 insertions, 531 deletions
diff --git a/jsonrpc/Cargo.toml b/jsonrpc/Cargo.toml index be3176b..40779fe 100644 --- a/jsonrpc/Cargo.toml +++ b/jsonrpc/Cargo.toml @@ -43,6 +43,7 @@ serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" thiserror = "1.0.58" async-trait = "0.1.77" +async-channel = "2.3.1" async-tungstenite = { version = "0.25.0", default-features = false, optional = true } diff --git a/jsonrpc/README.md b/jsonrpc/README.md index 0883bec..3322671 100644 --- a/jsonrpc/README.md +++ b/jsonrpc/README.md @@ -5,10 +5,13 @@ A fast and lightweight async implementation of [JSON-RPC features: - Supports TCP, TLS, WebSocket, and Unix protocols. -- Uses smol(async-std) as the async runtime, but also supports tokio via the +- Uses `smol`(async-std) as the async runtime, but also supports `tokio` via the `tokio` feature. - Allows registration of multiple services (structs) of different types on a single server. +- Supports pub/sub +- Allows passing an `async_executors::Executor` or tokio's `Runtime` when building + the server. ## Example @@ -16,8 +19,11 @@ features: use std::sync::Arc; use serde_json::Value; +use smol::stream::StreamExt; -use karyon_jsonrpc::{Error, Server, Client, rpc_impl}; +use karyon_jsonrpc::{ + Error, Server, Client, rpc_impl, rpc_pubsub_impl, SubscriptionID, ArcChannel +}; struct HelloWorld {} @@ -37,12 +43,42 @@ impl HelloWorld { } } +#[rpc_pubsub_impl] +impl HelloWorld { + async fn log_subscribe(&self, chan: ArcChannel, _params: Value) -> Result<Value, Error> { + let sub = chan.new_subscription().await; + let sub_id = sub.id.clone(); + smol::spawn(async move { + loop { + smol::Timer::after(std::time::Duration::from_secs(1)).await; + if let Err(err) = sub.notify(serde_json::json!("Hello")).await { + println!("Error send notification {err}"); + break; + } + } + }) + .detach(); + + Ok(serde_json::json!(sub_id)) + } + + async fn log_unsubscribe(&self, chan: ArcChannel, params: Value) -> Result<Value, Error> { + let sub_id: SubscriptionID = serde_json::from_value(params)?; + chan.remove_subscription(&sub_id).await; + Ok(serde_json::json!(true)) + } +} + + // Server async { + let service = Arc::new(HelloWorld {}); // Creates a new server + let server = Server::builder("tcp://127.0.0.1:60000") .expect("create new server builder") - .service(HelloWorld{}) + .service(service.clone()) + .pubsub_service(service) .build() .await .expect("build the server"); @@ -63,6 +99,26 @@ async { let result: String = client.call("HelloWorld.say_hello", "world".to_string()) .await .expect("send a request"); + + let (sub_id, sub) = client + .subscribe("Calc.log_subscribe", ()) + .await + .expect("Subscribe to log_subscribe method"); + + smol::spawn(async move { + sub.for_each(|m| { + println!("Receive new notification: {m}"); + }) + .await + }) + .detach(); + + smol::Timer::after(std::time::Duration::from_secs(5)).await; + + client + .unsubscribe("Calc.log_unsubscribe", sub_id) + .await + .expect("Unsubscribe from log_unsubscirbe method"); }; ``` diff --git a/jsonrpc/examples/pubsub_client.rs b/jsonrpc/examples/pubsub_client.rs new file mode 100644 index 0000000..fee2a26 --- /dev/null +++ b/jsonrpc/examples/pubsub_client.rs @@ -0,0 +1,47 @@ +use serde::{Deserialize, Serialize}; +use smol::stream::StreamExt; + +use karyon_jsonrpc::Client; + +#[derive(Deserialize, Serialize, Debug)] +struct Pong {} + +fn main() { + env_logger::init(); + smol::future::block_on(async { + let client = Client::builder("tcp://127.0.0.1:6000") + .expect("Create client builder") + .build() + .await + .expect("Build a client"); + + let result: Pong = client + .call("Calc.ping", ()) + .await + .expect("Send ping request"); + + println!("receive pong msg: {:?}", result); + + let (sub_id, sub) = client + .subscribe("Calc.log_subscribe", ()) + .await + .expect("Subscribe to log_subscribe method"); + + smol::spawn(async move { + sub.for_each(|m| { + println!("Receive new notification: {m}"); + }) + .await + }) + .detach(); + + smol::Timer::after(std::time::Duration::from_secs(5)).await; + + client + .unsubscribe("Calc.log_unsubscribe", sub_id) + .await + .expect("Unsubscribe from log_unsubscirbe method"); + + smol::Timer::after(std::time::Duration::from_secs(2)).await; + }); +} diff --git a/jsonrpc/examples/pubsub_server.rs b/jsonrpc/examples/pubsub_server.rs new file mode 100644 index 0000000..739e6d5 --- /dev/null +++ b/jsonrpc/examples/pubsub_server.rs @@ -0,0 +1,69 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use karyon_jsonrpc::{rpc_impl, rpc_pubsub_impl, ArcChannel, Error, Server, SubscriptionID}; + +struct Calc {} + +#[derive(Deserialize, Serialize)] +struct Req { + x: u32, + y: u32, +} + +#[derive(Deserialize, Serialize)] +struct Pong {} + +#[rpc_impl] +impl Calc { + async fn ping(&self, _params: Value) -> Result<Value, Error> { + Ok(serde_json::json!(Pong {})) + } +} + +#[rpc_pubsub_impl] +impl Calc { + async fn log_subscribe(&self, chan: ArcChannel, _params: Value) -> Result<Value, Error> { + let sub = chan.new_subscription().await; + let sub_id = sub.id.clone(); + smol::spawn(async move { + loop { + smol::Timer::after(std::time::Duration::from_secs(1)).await; + if let Err(err) = sub.notify(serde_json::json!("Hello")).await { + println!("Error send notification {err}"); + break; + } + } + }) + .detach(); + + Ok(serde_json::json!(sub_id)) + } + + async fn log_unsubscribe(&self, chan: ArcChannel, params: Value) -> Result<Value, Error> { + let sub_id: SubscriptionID = serde_json::from_value(params)?; + chan.remove_subscription(&sub_id).await; + Ok(serde_json::json!(true)) + } +} + +fn main() { + env_logger::init(); + smol::block_on(async { + let calc = Arc::new(Calc {}); + + // Creates a new server + let server = Server::builder("tcp://127.0.0.1:6000") + .expect("Create a new server builder") + .service(calc.clone()) + .pubsub_service(calc) + .build() + .await + .expect("Build a new server"); + + // Start the server + server.start().await.expect("Start the server"); + }); +} diff --git a/jsonrpc/examples/server.rs b/jsonrpc/examples/server.rs index 841e276..5b951cd 100644 --- a/jsonrpc/examples/server.rs +++ b/jsonrpc/examples/server.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -48,7 +50,7 @@ fn main() { // Creates a new server let server = Server::builder("tcp://127.0.0.1:6000") .expect("Create a new server builder") - .service(calc) + .service(Arc::new(calc)) .build() .await .expect("start a new server"); diff --git a/jsonrpc/examples/tokio_server/Cargo.lock b/jsonrpc/examples/tokio_server/Cargo.lock index a7fdb0b..ab39fcd 100644 --- a/jsonrpc/examples/tokio_server/Cargo.lock +++ b/jsonrpc/examples/tokio_server/Cargo.lock @@ -681,6 +681,7 @@ dependencies = [ name = "karyon_jsonrpc" version = "0.1.0" dependencies = [ + "async-channel", "async-trait", "async-tungstenite", "karyon_core", diff --git a/jsonrpc/examples/tokio_server/src/main.rs b/jsonrpc/examples/tokio_server/src/main.rs index 978c90a..ce77cd3 100644 --- a/jsonrpc/examples/tokio_server/src/main.rs +++ b/jsonrpc/examples/tokio_server/src/main.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -48,7 +50,7 @@ async fn main() { // Creates a new server let server = Server::builder("tcp://127.0.0.1:6000") .expect("Create a new server builder") - .service(calc) + .service(Arc::new(calc)) .build() .await .expect("start a new server"); diff --git a/jsonrpc/jsonrpc_macro/src/lib.rs b/jsonrpc/jsonrpc_macro/src/lib.rs index c3d51e8..5acfa5e 100644 --- a/jsonrpc/jsonrpc_macro/src/lib.rs +++ b/jsonrpc/jsonrpc_macro/src/lib.rs @@ -45,3 +45,41 @@ pub fn rpc_impl(_attr: TokenStream, item: TokenStream) -> TokenStream { quoted.into() } + +// TODO remove duplicate code +#[proc_macro_attribute] +pub fn rpc_pubsub_impl(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut methods: Vec<Ident> = vec![]; + + let item2 = item.clone(); + let parsed_input = parse_macro_input!(item2 as ItemImpl); + + let self_ty = match *parsed_input.self_ty { + Type::Path(p) => p, + _ => err!( + parsed_input.span(), + "implementing the trait `RPCService` on this type is unsupported" + ), + }; + + if parsed_input.items.is_empty() { + err!(self_ty.span(), "At least one method should be implemented"); + } + + for item in parsed_input.items { + match item { + ImplItem::Method(method) => { + methods.push(method.sig.ident); + } + _ => err!(item.span(), "unexpected item"), + } + } + + let item2: TokenStream2 = item.into(); + let quoted = quote! { + karyon_jsonrpc::impl_pubsub_rpc_service!(#self_ty, #(#methods),*); + #item2 + }; + + quoted.into() +} diff --git a/jsonrpc/src/client.rs b/jsonrpc/src/client.rs deleted file mode 100644 index b55943e..0000000 --- a/jsonrpc/src/client.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::time::Duration; - -use log::debug; -use serde::{de::DeserializeOwned, Serialize}; - -#[cfg(feature = "smol")] -use futures_rustls::rustls; -#[cfg(feature = "tokio")] -use tokio_rustls::rustls; - -use karyon_core::{async_util::timeout, util::random_32}; -use karyon_net::{tls::ClientTlsConfig, Conn, Endpoint, ToEndpoint}; - -#[cfg(feature = "ws")] -use karyon_net::ws::{ClientWsConfig, ClientWssConfig}; - -#[cfg(feature = "ws")] -use crate::codec::WsJsonCodec; - -use crate::{codec::JsonCodec, message, Error, Result}; - -/// Represents an RPC client -pub struct Client { - conn: Conn<serde_json::Value>, - timeout: Option<u64>, -} - -impl Client { - /// Calls the provided method, waits for the response, and returns the result. - pub async fn call<T: Serialize + DeserializeOwned, V: DeserializeOwned>( - &self, - method: &str, - params: T, - ) -> Result<V> { - let id = serde_json::json!(random_32()); - - let request = message::Request { - jsonrpc: message::JSONRPC_VERSION.to_string(), - id, - method: method.to_string(), - params: serde_json::json!(params), - }; - - let req_json = serde_json::to_value(&request)?; - match self.timeout { - Some(s) => { - let dur = Duration::from_secs(s); - timeout(dur, self.conn.send(req_json)).await??; - } - None => { - self.conn.send(req_json).await?; - } - } - debug!("--> {request}"); - - let msg = self.conn.recv().await?; - let response = serde_json::from_value::<message::Response>(msg)?; - debug!("<-- {response}"); - - if response.id.is_none() || response.id.unwrap() != request.id { - return Err(Error::InvalidMsg("Invalid response id")); - } - - if let Some(error) = response.error { - return Err(Error::CallError(error.code, error.message)); - } - - match response.result { - Some(result) => Ok(serde_json::from_value::<V>(result)?), - None => Err(Error::InvalidMsg("Invalid response result")), - } - } -} - -pub struct ClientBuilder { - endpoint: Endpoint, - tls_config: Option<(rustls::ClientConfig, String)>, - timeout: Option<u64>, -} - -impl ClientBuilder { - pub fn with_timeout(mut self, timeout: u64) -> Self { - self.timeout = Some(timeout); - self - } - - pub fn tls_config(mut self, config: rustls::ClientConfig, dns_name: &str) -> Result<Self> { - match self.endpoint { - Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { - self.tls_config = Some((config, dns_name.to_string())); - Ok(self) - } - _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), - } - } - - pub async fn build(self) -> Result<Client> { - let conn: Conn<serde_json::Value> = match self.endpoint { - Endpoint::Tcp(..) | Endpoint::Tls(..) => match self.tls_config { - Some((conf, dns_name)) => Box::new( - karyon_net::tls::dial( - &self.endpoint, - ClientTlsConfig { - dns_name, - client_config: conf, - tcp_config: Default::default(), - }, - JsonCodec {}, - ) - .await?, - ), - None => Box::new( - karyon_net::tcp::dial(&self.endpoint, Default::default(), JsonCodec {}).await?, - ), - }, - #[cfg(feature = "ws")] - Endpoint::Ws(..) | Endpoint::Wss(..) => match self.tls_config { - Some((conf, dns_name)) => Box::new( - karyon_net::ws::dial( - &self.endpoint, - ClientWsConfig { - tcp_config: Default::default(), - wss_config: Some(ClientWssConfig { - dns_name, - client_config: conf, - }), - }, - WsJsonCodec {}, - ) - .await?, - ), - None => Box::new( - karyon_net::ws::dial(&self.endpoint, Default::default(), WsJsonCodec {}) - .await?, - ), - }, - #[cfg(all(feature = "unix", target_family = "unix"))] - Endpoint::Unix(..) => Box::new( - karyon_net::unix::dial(&self.endpoint, Default::default(), JsonCodec {}).await?, - ), - _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())), - }; - Ok(Client { - timeout: self.timeout, - conn, - }) - } -} -impl Client { - pub fn builder(endpoint: impl ToEndpoint) -> Result<ClientBuilder> { - let endpoint = endpoint.to_endpoint()?; - Ok(ClientBuilder { - endpoint, - timeout: None, - tls_config: None, - }) - } -} diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs new file mode 100644 index 0000000..c9253fc --- /dev/null +++ b/jsonrpc/src/client/mod.rs @@ -0,0 +1,374 @@ +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use log::{debug, error, warn}; +use serde::{de::DeserializeOwned, Serialize}; +use serde_json::json; + +#[cfg(feature = "smol")] +use futures_rustls::rustls; +#[cfg(feature = "tokio")] +use tokio_rustls::rustls; + +use karyon_core::{ + async_runtime::lock::Mutex, + async_util::{timeout, TaskGroup, TaskResult}, + util::random_32, +}; +use karyon_net::{tls::ClientTlsConfig, Conn, Endpoint, ToEndpoint}; + +#[cfg(feature = "ws")] +use karyon_net::ws::{ClientWsConfig, ClientWssConfig}; + +#[cfg(feature = "ws")] +use crate::codec::WsJsonCodec; + +use crate::{codec::JsonCodec, message, Error, Result, SubscriptionID, TcpConfig}; + +const CHANNEL_CAP: usize = 10; + +const DEFAULT_TIMEOUT: u64 = 1000; // 1s + +/// Type alias for a subscription to receive notifications. +/// +/// The receiver channel is returned by the `subscribe` method to receive +/// notifications from the server. +pub type Subscription = async_channel::Receiver<serde_json::Value>; + +/// Represents an RPC client +pub struct Client { + conn: Conn<serde_json::Value>, + chan_tx: async_channel::Sender<message::Response>, + chan_rx: async_channel::Receiver<message::Response>, + timeout: Option<u64>, + subscriptions: Mutex<HashMap<SubscriptionID, async_channel::Sender<serde_json::Value>>>, + task_group: TaskGroup, +} + +impl Client { + /// Calls the provided method, waits for the response, and returns the result. + pub async fn call<T: Serialize + DeserializeOwned, V: DeserializeOwned>( + &self, + method: &str, + params: T, + ) -> Result<V> { + let request = self.send_request(method, params, None).await?; + debug!("--> {request}"); + + let response = self.chan_rx.recv().await?; + debug!("<-- {response}"); + + if let Some(error) = response.error { + return Err(Error::CallError(error.code, error.message)); + } + + if response.id.is_none() || response.id.unwrap() != request.id { + return Err(Error::InvalidMsg("Invalid response id")); + } + + match response.result { + Some(result) => Ok(serde_json::from_value::<V>(result)?), + None => Err(Error::InvalidMsg("Invalid response result")), + } + } + + /// Subscribes to the provided method, waits for the response, and returns the result. + /// + /// This function sends a subscription request to the specified method + /// with the given parameters. It waits for the response and returns a + /// tuple containing a `SubscriptionID` and a `Subscription` (channel receiver). + pub async fn subscribe<T: Serialize + DeserializeOwned>( + &self, + method: &str, + params: T, + ) -> Result<(SubscriptionID, Subscription)> { + let request = self.send_request(method, params, Some(json!(true))).await?; + debug!("--> {request}"); + + let response = self.chan_rx.recv().await?; + debug!("<-- {response}"); + + if let Some(error) = response.error { + return Err(Error::SubscribeError(error.code, error.message)); + } + + if response.id.is_none() || response.id.unwrap() != request.id { + return Err(Error::InvalidMsg("Invalid response id")); + } + + let sub_id = match response.subscription { + Some(result) => serde_json::from_value::<SubscriptionID>(result)?, + None => return Err(Error::InvalidMsg("Invalid subscription id")), + }; + + let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP); + self.subscriptions.lock().await.insert(sub_id, ch_tx); + + Ok((sub_id, ch_rx)) + } + + /// Unsubscribes from the provided method, waits for the response, and returns the result. + /// + /// This function sends an unsubscription request for the specified method + /// and subscription ID. It waits for the response to confirm the unsubscription. + pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> { + let request = self + .send_request(method, json!(sub_id), Some(json!(true))) + .await?; + debug!("--> {request}"); + + let response = self.chan_rx.recv().await?; + debug!("<-- {response}"); + + if let Some(error) = response.error { + return Err(Error::SubscribeError(error.code, error.message)); + } + + if response.id.is_none() || response.id.unwrap() != request.id { + return Err(Error::InvalidMsg("Invalid response id")); + } + + self.subscriptions.lock().await.remove(&sub_id); + Ok(()) + } + + async fn send_request<T: Serialize + DeserializeOwned>( + &self, + method: &str, + params: T, + subscriber: Option<serde_json::Value>, + ) -> Result<message::Request> { + let id = json!(random_32()); + + let request = message::Request { + jsonrpc: message::JSONRPC_VERSION.to_string(), + id, + method: method.to_string(), + params: json!(params), + subscriber, + }; + + let req_json = serde_json::to_value(&request)?; + + match self.timeout { + Some(s) => { + let dur = Duration::from_millis(s); + timeout(dur, self.conn.send(req_json)).await??; + } + None => { + self.conn.send(req_json).await?; + } + } + + Ok(request) + } + + fn start_background_receiving(self: &Arc<Self>) { + let selfc = self.clone(); + let on_failure = |result: TaskResult<Result<()>>| async move { + if let TaskResult::Completed(Err(err)) = result { + error!("background receiving stopped: {err}"); + } + // drop all subscription channels + selfc.subscriptions.lock().await.clear(); + }; + let selfc = self.clone(); + self.task_group.spawn( + async move { + loop { + let msg = selfc.conn.recv().await?; + + if let Ok(res) = serde_json::from_value::<message::Response>(msg.clone()) { + selfc.chan_tx.send(res).await?; + continue; + } + + if let Ok(nt) = serde_json::from_value::<message::Notification>(msg.clone()) { + let sub_id = match nt.subscription.clone() { + Some(id) => serde_json::from_value::<SubscriptionID>(id)?, + None => { + return Err(Error::InvalidMsg( + "Invalid notification msg: subscription id not found", + )) + } + }; + + match selfc.subscriptions.lock().await.get(&sub_id) { + Some(s) => { + s.send(nt.params.unwrap_or(json!(""))).await?; + continue; + } + None => { + warn!("Receive unknown notification {sub_id}"); + continue; + } + } + } + + error!("Receive unexpected msg: {msg}"); + return Err(Error::InvalidMsg("Unexpected msg")); + } + }, + on_failure, + ); + } +} + +/// Builder for constructing an RPC [`Client`]. +pub struct ClientBuilder { + endpoint: Endpoint, + tls_config: Option<(rustls::ClientConfig, String)>, + tcp_config: TcpConfig, + timeout: Option<u64>, +} + +impl ClientBuilder { + /// Set timeout for requests, in milliseconds. + /// + /// # Examples + /// + /// ```ignore + /// let client = Client::builder()?.set_timeout(5000).build()?; + /// ``` + pub fn set_timeout(mut self, timeout: u64) -> Self { + self.timeout = Some(timeout); + self + } + + /// Configure TCP settings for the client. + /// + /// # Example + /// + /// ```ignore + /// let tcp_config = TcpConfig::default(); + /// let client = Client::builder()?.tcp_config(tcp_config)?.build()?; + /// ``` + /// + /// This function will return an error if the endpoint does not support TCP protocols. + pub fn tcp_config(mut self, config: TcpConfig) -> Result<Self> { + match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { + self.tcp_config = config; + Ok(self) + } + _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + } + } + + /// Configure TLS settings for the client. + /// + /// # Example + /// + /// ```ignore + /// let tls_config = rustls::ClientConfig::new(...); + /// let client = Client::builder()?.tls_config(tls_config, "example.com")?.build()?; + /// ``` + /// + /// This function will return an error if the endpoint does not support TLS protocols. + pub fn tls_config(mut self, config: rustls::ClientConfig, dns_name: &str) -> Result<Self> { + match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { + self.tls_config = Some((config, dns_name.to_string())); + Ok(self) + } + _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + } + } + + /// Build RPC client from [`ClientBuilder`]. + /// + /// This function creates a new RPC client using the configurations + /// specified in the `ClientBuilder`. It returns a `Arc<Client>` on success. + /// + /// # Example + /// + /// ```ignore + /// let client = Client::builder(endpoint)? + /// .set_timeout(5000) + /// .tcp_config(tcp_config)? + /// .tls_config(tls_config, "example.com")? + /// .build() + /// .await?; + /// ``` + pub async fn build(self) -> Result<Arc<Client>> { + let conn: Conn<serde_json::Value> = match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => match self.tls_config { + Some((conf, dns_name)) => Box::new( + karyon_net::tls::dial( + &self.endpoint, + ClientTlsConfig { + dns_name, + client_config: conf, + tcp_config: self.tcp_config, + }, + JsonCodec {}, + ) + .await?, + ), + None => Box::new( + karyon_net::tcp::dial(&self.endpoint, self.tcp_config, JsonCodec {}).await?, + ), + }, + #[cfg(feature = "ws")] + Endpoint::Ws(..) | Endpoint::Wss(..) => match self.tls_config { + Some((conf, dns_name)) => Box::new( + karyon_net::ws::dial( + &self.endpoint, + ClientWsConfig { + tcp_config: self.tcp_config, + wss_config: Some(ClientWssConfig { + dns_name, + client_config: conf, + }), + }, + WsJsonCodec {}, + ) + .await?, + ), + None => { + let config = ClientWsConfig { + tcp_config: self.tcp_config, + wss_config: None, + }; + Box::new(karyon_net::ws::dial(&self.endpoint, config, WsJsonCodec {}).await?) + } + }, + #[cfg(all(feature = "unix", target_family = "unix"))] + Endpoint::Unix(..) => Box::new( + karyon_net::unix::dial(&self.endpoint, Default::default(), JsonCodec {}).await?, + ), + _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + }; + + let (tx, rx) = async_channel::bounded(CHANNEL_CAP); + let client = Arc::new(Client { + timeout: self.timeout, + conn, + chan_tx: tx, + chan_rx: rx, + subscriptions: Mutex::new(HashMap::new()), + task_group: TaskGroup::new(), + }); + client.start_background_receiving(); + Ok(client) + } +} +impl Client { + /// Creates a new [`ClientBuilder`] + /// + /// This function initializes a `ClientBuilder` with the specified endpoint. + /// + /// # Example + /// + /// ```ignore + /// let builder = Client::builder("ws://127.0.0.1:3000")?.build()?; + /// ``` + pub fn builder(endpoint: impl ToEndpoint) -> Result<ClientBuilder> { + let endpoint = endpoint.to_endpoint()?; + Ok(ClientBuilder { + endpoint, + timeout: Some(DEFAULT_TIMEOUT), + tls_config: None, + tcp_config: Default::default(), + }) + } +} diff --git a/jsonrpc/src/error.rs b/jsonrpc/src/error.rs index 7f89729..d68e169 100644 --- a/jsonrpc/src/error.rs +++ b/jsonrpc/src/error.rs @@ -11,6 +11,9 @@ pub enum Error { #[error("Call Error: code: {0} msg: {1}")] CallError(i32, String), + #[error("Subscribe Error: code: {0} msg: {1}")] + SubscribeError(i32, String), + #[error("RPC Method Error: code: {0} msg: {1}")] RPCMethodError(i32, &'static str), @@ -29,6 +32,15 @@ pub enum Error { #[error("Unsupported protocol: {0}")] UnsupportedProtocol(String), + #[error("Subscription not found: {0}")] + SubscriptionNotFound(String), + + #[error(transparent)] + ChannelRecv(#[from] async_channel::RecvError), + + #[error("Channel broadcast Error: {0}")] + ChannelBroadcast(String), + #[error("Unexpected Error: {0}")] General(&'static str), @@ -38,3 +50,9 @@ pub enum Error { #[error(transparent)] KaryonNet(#[from] karyon_net::Error), } + +impl<T> From<async_channel::SendError<T>> for Error { + fn from(error: async_channel::SendError<T>) -> Self { + Error::ChannelBroadcast(error.to_string()) + } +} diff --git a/jsonrpc/src/lib.rs b/jsonrpc/src/lib.rs index 187b1ad..7573c4d 100644 --- a/jsonrpc/src/lib.rs +++ b/jsonrpc/src/lib.rs @@ -1,81 +1,20 @@ -//! A fast and lightweight async implementation of [JSON-RPC -//! 2.0](https://www.jsonrpc.org/specification). -//! -//! features: -//! - Supports TCP, TLS, WebSocket, and Unix protocols. -//! - Uses smol(async-std) as the async runtime, but also supports tokio via -//! the `tokio` feature. -//! - Allows registration of multiple services (structs) of different types on a -//! single server. -//! -//! # Example -//! -//! ``` -//! use std::sync::Arc; -//! -//! use serde_json::Value; -//! -//! use karyon_jsonrpc::{Error, Server, Client, rpc_impl}; -//! -//! struct HelloWorld {} -//! -//! #[rpc_impl] -//! impl HelloWorld { -//! async fn say_hello(&self, params: Value) -> Result<Value, Error> { -//! let msg: String = serde_json::from_value(params)?; -//! Ok(serde_json::json!(format!("Hello {msg}!"))) -//! } -//! -//! async fn foo(&self, params: Value) -> Result<Value, Error> { -//! Ok(serde_json::json!("foo!")) -//! } -//! -//! async fn bar(&self, params: Value) -> Result<Value, Error> { -//! Ok(serde_json::json!("bar!")) -//! } -//! } -//! -//! // Server -//! async { -//! // Creates a new server -//! let server = Server::builder("tcp://127.0.0.1:60000") -//! .expect("create new server builder") -//! .service(HelloWorld{}) -//! .build() -//! .await -//! .expect("build the server"); -//! -//! // Starts the server -//! server.start().await.expect("start the server"); -//! }; -//! -//! // Client -//! async { -//! // Creates a new client -//! let client = Client::builder("tcp://127.0.0.1:60000") -//! .expect("create new client builder") -//! .build() -//! .await -//! .expect("build the client"); -//! -//! let result: String = client.call("HelloWorld.say_hello", "world".to_string()) -//! .await -//! .expect("send a request"); -//! }; -//! -//! ``` +#![doc = include_str!("../README.md")] mod client; mod codec; mod error; pub mod message; mod server; -mod service; pub use client::Client; -pub use server::Server; - pub use error::{Error, Result}; -pub use karyon_jsonrpc_macro::rpc_impl; -pub use karyon_net::Endpoint; -pub use service::{RPCMethod, RPCService}; +pub use server::{ + channel::{ArcChannel, Channel, Subscription, SubscriptionID}, + pubsub_service::{PubSubRPCMethod, PubSubRPCService}, + service::{RPCMethod, RPCService}, + Server, +}; + +pub use karyon_jsonrpc_macro::{rpc_impl, rpc_pubsub_impl}; + +pub use karyon_net::{tcp::TcpConfig, Endpoint}; diff --git a/jsonrpc/src/message.rs b/jsonrpc/src/message.rs index f4bf490..9c89362 100644 --- a/jsonrpc/src/message.rs +++ b/jsonrpc/src/message.rs @@ -23,9 +23,12 @@ pub struct Request { pub method: String, pub params: serde_json::Value, pub id: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub subscriber: Option<serde_json::Value>, } #[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct Response { pub jsonrpc: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -34,30 +37,35 @@ pub struct Response { pub error: Option<Error>, #[serde(skip_serializing_if = "Option::is_none")] pub id: Option<serde_json::Value>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Error { - pub code: i32, - pub message: String, #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option<serde_json::Value>, + pub subscription: Option<serde_json::Value>, } #[derive(Debug, Serialize, Deserialize)] pub struct Notification { pub jsonrpc: String, - pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub method: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] pub params: Option<serde_json::Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub subscription: Option<serde_json::Value>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Error { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, } impl std::fmt::Display for Request { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "{{jsonrpc: {}, method: {}, params: {:?}, id: {:?}}}", - self.jsonrpc, self.method, self.params, self.id + "{{jsonrpc: {}, method: {}, params: {:?}, id: {:?}, subscribe: {:?}}}", + self.jsonrpc, self.method, self.params, self.id, self.subscriber ) } } @@ -66,8 +74,8 @@ impl std::fmt::Display for Response { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "{{jsonrpc: {}, result': {:?}, error: {:?} , id: {:?}}}", - self.jsonrpc, self.result, self.error, self.id + "{{jsonrpc: {}, result': {:?}, error: {:?} , id: {:?}, subscription: {:?}}}", + self.jsonrpc, self.result, self.error, self.id, self.subscription ) } } @@ -86,8 +94,8 @@ impl std::fmt::Display for Notification { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "{{jsonrpc: {}, method: {}, params: {:?}}}", - self.jsonrpc, self.method, self.params + "{{jsonrpc: {}, method: {:?}, params: {:?}, subscription: {:?}}}", + self.jsonrpc, self.method, self.params, self.subscription ) } } diff --git a/jsonrpc/src/server.rs b/jsonrpc/src/server.rs deleted file mode 100644 index 2155295..0000000 --- a/jsonrpc/src/server.rs +++ /dev/null @@ -1,282 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use log::{debug, error, warn}; - -#[cfg(feature = "smol")] -use futures_rustls::rustls; -#[cfg(feature = "tokio")] -use tokio_rustls::rustls; - -use karyon_core::async_runtime::Executor; -use karyon_core::async_util::{TaskGroup, TaskResult}; - -use karyon_net::{Conn, Endpoint, Listener, ToEndpoint}; - -#[cfg(feature = "ws")] -use crate::codec::WsJsonCodec; - -use crate::{codec::JsonCodec, message, Error, RPCService, Result}; - -pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request"; -pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse"; -pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found"; -pub const INTERNAL_ERROR_MSG: &str = "Internal error"; - -fn pack_err_res(code: i32, msg: &str, id: Option<serde_json::Value>) -> message::Response { - let err = message::Error { - code, - message: msg.to_string(), - data: None, - }; - - message::Response { - jsonrpc: message::JSONRPC_VERSION.to_string(), - error: Some(err), - result: None, - id, - } -} - -/// Represents an RPC server -pub struct Server { - listener: Listener<serde_json::Value>, - task_group: TaskGroup, - services: HashMap<String, Box<dyn RPCService + 'static>>, -} - -impl Server { - /// Returns the local endpoint. - pub fn local_endpoint(&self) -> Result<Endpoint> { - self.listener.local_endpoint().map_err(Error::from) - } - - /// Starts the RPC server - pub async fn start(self: Arc<Self>) -> Result<()> { - loop { - match self.listener.accept().await { - Ok(conn) => { - if let Err(err) = self.handle_conn(conn).await { - error!("Failed to handle a new conn: {err}") - } - } - Err(err) => { - error!("Failed to accept a new conn: {err}") - } - } - } - } - - /// Shuts down the RPC server - pub async fn shutdown(&self) { - self.task_group.cancel().await; - } - - /// Handles a new connection - async fn handle_conn(self: &Arc<Self>, conn: Conn<serde_json::Value>) -> Result<()> { - let endpoint = conn.peer_endpoint().expect("get peer endpoint"); - debug!("Handle a new connection {endpoint}"); - - let on_failure = |result: TaskResult<Result<()>>| async move { - if let TaskResult::Completed(Err(err)) = result { - error!("Connection {} dropped: {}", endpoint, err); - } else { - warn!("Connection {} dropped", endpoint); - } - }; - - let selfc = self.clone(); - self.task_group.spawn( - async move { - loop { - let msg = conn.recv().await?; - let response = selfc.handle_request(msg).await; - let response = serde_json::to_value(response)?; - debug!("--> {response}"); - conn.send(response).await?; - } - }, - on_failure, - ); - - Ok(()) - } - - /// Handles a request - async fn handle_request(&self, msg: serde_json::Value) -> message::Response { - let rpc_msg = match serde_json::from_value::<message::Request>(msg) { - Ok(m) => m, - Err(_) => { - return pack_err_res(message::PARSE_ERROR_CODE, FAILED_TO_PARSE_ERROR_MSG, None); - } - }; - debug!("<-- {rpc_msg}"); - - let srvc_method: Vec<&str> = rpc_msg.method.split('.').collect(); - if srvc_method.len() != 2 { - return pack_err_res( - message::INVALID_REQUEST_ERROR_CODE, - INVALID_REQUEST_ERROR_MSG, - Some(rpc_msg.id), - ); - } - - let srvc_name = srvc_method[0]; - let method_name = srvc_method[1]; - - let service = match self.services.get(srvc_name) { - Some(s) => s, - None => { - return pack_err_res( - message::METHOD_NOT_FOUND_ERROR_CODE, - METHOD_NOT_FOUND_ERROR_MSG, - Some(rpc_msg.id), - ); - } - }; - - let method = match service.get_method(method_name) { - Some(m) => m, - None => { - return pack_err_res( - message::METHOD_NOT_FOUND_ERROR_CODE, - METHOD_NOT_FOUND_ERROR_MSG, - Some(rpc_msg.id), - ); - } - }; - - let result = match method(rpc_msg.params.clone()).await { - Ok(res) => res, - Err(Error::ParseJSON(_)) => { - return pack_err_res( - message::PARSE_ERROR_CODE, - FAILED_TO_PARSE_ERROR_MSG, - Some(rpc_msg.id), - ); - } - Err(Error::InvalidParams(msg)) => { - return pack_err_res(message::INVALID_PARAMS_ERROR_CODE, msg, Some(rpc_msg.id)); - } - Err(Error::InvalidRequest(msg)) => { - return pack_err_res(message::INVALID_REQUEST_ERROR_CODE, msg, Some(rpc_msg.id)); - } - Err(Error::RPCMethodError(code, msg)) => { - return pack_err_res(code, msg, Some(rpc_msg.id)); - } - Err(_) => { - return pack_err_res( - message::INTERNAL_ERROR_CODE, - INTERNAL_ERROR_MSG, - Some(rpc_msg.id), - ); - } - }; - - message::Response { - jsonrpc: message::JSONRPC_VERSION.to_string(), - error: None, - result: Some(result), - id: Some(rpc_msg.id), - } - } -} - -pub struct ServerBuilder { - endpoint: Endpoint, - tls_config: Option<rustls::ServerConfig>, - services: HashMap<String, Box<dyn RPCService + 'static>>, -} - -impl ServerBuilder { - pub fn service(mut self, service: impl RPCService + 'static) -> Self { - self.services.insert(service.name(), Box::new(service)); - self - } - - pub fn tls_config(mut self, config: rustls::ServerConfig) -> Result<ServerBuilder> { - match self.endpoint { - Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { - self.tls_config = Some(config); - Ok(self) - } - _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), - } - } - - pub async fn build(self) -> Result<Arc<Server>> { - self._build(TaskGroup::new()).await - } - - pub async fn build_with_executor(self, ex: Executor) -> Result<Arc<Server>> { - self._build(TaskGroup::with_executor(ex)).await - } - - async fn _build(self, task_group: TaskGroup) -> Result<Arc<Server>> { - let listener: Listener<serde_json::Value> = match self.endpoint { - Endpoint::Tcp(..) | Endpoint::Tls(..) => match &self.tls_config { - Some(conf) => Box::new( - karyon_net::tls::listen( - &self.endpoint, - karyon_net::tls::ServerTlsConfig { - server_config: conf.clone(), - tcp_config: Default::default(), - }, - JsonCodec {}, - ) - .await?, - ), - None => Box::new( - karyon_net::tcp::listen(&self.endpoint, Default::default(), JsonCodec {}) - .await?, - ), - }, - #[cfg(feature = "ws")] - Endpoint::Ws(..) | Endpoint::Wss(..) => match &self.tls_config { - Some(conf) => Box::new( - karyon_net::ws::listen( - &self.endpoint, - karyon_net::ws::ServerWsConfig { - tcp_config: Default::default(), - wss_config: Some(karyon_net::ws::ServerWssConfig { - server_config: conf.clone(), - }), - }, - WsJsonCodec {}, - ) - .await?, - ), - None => Box::new( - karyon_net::ws::listen(&self.endpoint, Default::default(), WsJsonCodec {}) - .await?, - ), - }, - #[cfg(all(feature = "unix", target_family = "unix"))] - Endpoint::Unix(..) => Box::new(karyon_net::unix::listen( - &self.endpoint, - Default::default(), - JsonCodec {}, - )?), - - _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())), - }; - - Ok(Arc::new(Server { - listener, - task_group, - services: self.services, - })) - } -} - -impl ServerBuilder {} - -impl Server { - pub fn builder(endpoint: impl ToEndpoint) -> Result<ServerBuilder> { - let endpoint = endpoint.to_endpoint()?; - Ok(ServerBuilder { - endpoint, - services: HashMap::new(), - tls_config: None, - }) - } -} diff --git a/jsonrpc/src/server/channel.rs b/jsonrpc/src/server/channel.rs new file mode 100644 index 0000000..1498825 --- /dev/null +++ b/jsonrpc/src/server/channel.rs @@ -0,0 +1,69 @@ +use std::sync::Arc; + +use karyon_core::{async_runtime::lock::Mutex, util::random_32}; + +use crate::{Error, Result}; + +pub type SubscriptionID = u32; +pub type ArcChannel = Arc<Channel>; + +/// Represents a new subscription +pub struct Subscription { + pub id: SubscriptionID, + parent: Arc<Channel>, + chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, +} + +impl Subscription { + /// Creates a new `Subscription` + fn new( + parent: Arc<Channel>, + id: SubscriptionID, + chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, + ) -> Self { + Self { parent, id, chan } + } + + /// Sends a notification to the subscriber + pub async fn notify(&self, res: serde_json::Value) -> Result<()> { + if self.parent.subs.lock().await.contains(&self.id) { + self.chan.send((self.id, res)).await?; + Ok(()) + } else { + Err(Error::SubscriptionNotFound(self.id.to_string())) + } + } +} + +/// Represents a channel for creating/removing subscriptions +pub struct Channel { + chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, + subs: Mutex<Vec<SubscriptionID>>, +} + +impl Channel { + /// Creates a new `Channel` + pub fn new(chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>) -> ArcChannel { + Arc::new(Self { + chan, + subs: Mutex::new(Vec::new()), + }) + } + + /// Creates a new subscription + pub async fn new_subscription(self: &Arc<Self>) -> Subscription { + let sub_id = random_32(); + let sub = Subscription::new(self.clone(), sub_id, self.chan.clone()); + self.subs.lock().await.push(sub_id); + sub + } + + /// Removes a subscription + pub async fn remove_subscription(self: &Arc<Self>, id: &SubscriptionID) { + let i = match self.subs.lock().await.iter().position(|i| i == id) { + Some(i) => i, + None => return, + }; + self.subs.lock().await.remove(i); + } +} diff --git a/jsonrpc/src/server/mod.rs b/jsonrpc/src/server/mod.rs new file mode 100644 index 0000000..4ebab10 --- /dev/null +++ b/jsonrpc/src/server/mod.rs @@ -0,0 +1,454 @@ +pub mod channel; +pub mod pubsub_service; +pub mod service; + +use std::{collections::HashMap, sync::Arc}; + +use log::{debug, error, warn}; + +#[cfg(feature = "smol")] +use futures_rustls::rustls; +#[cfg(feature = "tokio")] +use tokio_rustls::rustls; + +use karyon_core::{ + async_runtime::Executor, + async_util::{select, Either, TaskGroup, TaskResult}, +}; + +use karyon_net::{Conn, Endpoint, Listener, ToEndpoint}; + +#[cfg(feature = "ws")] +use crate::codec::WsJsonCodec; + +#[cfg(feature = "ws")] +use karyon_net::ws::ServerWsConfig; + +use crate::{codec::JsonCodec, message, Error, PubSubRPCService, RPCService, Result, TcpConfig}; + +use channel::{ArcChannel, Channel}; + +const CHANNEL_CAP: usize = 10; + +pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request"; +pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse"; +pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found"; +pub const INTERNAL_ERROR_MSG: &str = "Internal error"; + +fn pack_err_res(code: i32, msg: &str, id: Option<serde_json::Value>) -> message::Response { + let err = message::Error { + code, + message: msg.to_string(), + data: None, + }; + + message::Response { + jsonrpc: message::JSONRPC_VERSION.to_string(), + error: Some(err), + result: None, + id, + subscription: None, + } +} + +struct NewRequest { + srvc_name: String, + method_name: String, + msg: message::Request, +} + +enum SanityCheckResult { + NewReq(NewRequest), + ErrRes(message::Response), +} + +/// Represents an RPC server +pub struct Server { + listener: Listener<serde_json::Value>, + task_group: TaskGroup, + services: HashMap<String, Arc<dyn RPCService + 'static>>, + pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>, +} + +impl Server { + /// Returns the local endpoint. + pub fn local_endpoint(&self) -> Result<Endpoint> { + self.listener.local_endpoint().map_err(Error::from) + } + + /// Starts the RPC server + pub async fn start(self: Arc<Self>) -> Result<()> { + loop { + match self.listener.accept().await { + Ok(conn) => { + if let Err(err) = self.handle_conn(conn).await { + error!("Failed to handle a new conn: {err}") + } + } + Err(err) => { + error!("Failed to accept a new conn: {err}") + } + } + } + } + + /// Shuts down the RPC server + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + /// Handles a new connection + async fn handle_conn(self: &Arc<Self>, conn: Conn<serde_json::Value>) -> Result<()> { + let endpoint = conn.peer_endpoint().expect("get peer endpoint"); + debug!("Handle a new connection {endpoint}"); + + let on_failure = |result: TaskResult<Result<()>>| async move { + if let TaskResult::Completed(Err(err)) = result { + error!("Connection {} dropped: {}", endpoint, err); + } else { + warn!("Connection {} dropped", endpoint); + } + }; + + let selfc = self.clone(); + let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP); + let channel = Channel::new(ch_tx); + self.task_group.spawn( + async move { + loop { + match select(conn.recv(), ch_rx.recv()).await { + Either::Left(msg) => { + // TODO spawn a task + let response = selfc.handle_request(channel.clone(), msg?).await; + debug!("--> {response}"); + conn.send(serde_json::to_value(response)?).await?; + } + Either::Right(msg) => { + let (sub_id, result) = msg?; + let response = message::Notification { + jsonrpc: message::JSONRPC_VERSION.to_string(), + method: None, + params: Some(result), + subscription: Some(sub_id.into()), + }; + debug!("--> {response}"); + conn.send(serde_json::to_value(response)?).await?; + } + } + } + }, + on_failure, + ); + + Ok(()) + } + + fn sanity_check(&self, request: serde_json::Value) -> SanityCheckResult { + let rpc_msg = match serde_json::from_value::<message::Request>(request) { + Ok(m) => m, + Err(_) => { + return SanityCheckResult::ErrRes(pack_err_res( + message::PARSE_ERROR_CODE, + FAILED_TO_PARSE_ERROR_MSG, + None, + )); + } + }; + debug!("<-- {rpc_msg}"); + + let srvc_method_str = rpc_msg.method.clone(); + let srvc_method: Vec<&str> = srvc_method_str.split('.').collect(); + if srvc_method.len() < 2 { + return SanityCheckResult::ErrRes(pack_err_res( + message::INVALID_REQUEST_ERROR_CODE, + INVALID_REQUEST_ERROR_MSG, + Some(rpc_msg.id), + )); + } + + let srvc_name = srvc_method[0].to_string(); + let method_name = srvc_method[1].to_string(); + + SanityCheckResult::NewReq(NewRequest { + srvc_name, + method_name, + msg: rpc_msg, + }) + } + + /// Handles a new request + async fn handle_request( + &self, + channel: ArcChannel, + msg: serde_json::Value, + ) -> message::Response { + let req = match self.sanity_check(msg) { + SanityCheckResult::NewReq(req) => req, + SanityCheckResult::ErrRes(res) => return res, + }; + + if req.msg.subscriber.is_some() { + match self.pubsub_services.get(&req.srvc_name) { + Some(s) => { + self.handle_pubsub_request(channel, s, &req.method_name, req.msg) + .await + } + None => pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + METHOD_NOT_FOUND_ERROR_MSG, + Some(req.msg.id), + ), + } + } else { + match self.services.get(&req.srvc_name) { + Some(s) => self.handle_call_request(s, &req.method_name, req.msg).await, + None => pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + METHOD_NOT_FOUND_ERROR_MSG, + Some(req.msg.id), + ), + } + } + } + + /// Handles a call request + async fn handle_call_request( + &self, + service: &Arc<dyn RPCService + 'static>, + method_name: &str, + rpc_msg: message::Request, + ) -> message::Response { + let method = match service.get_method(method_name) { + Some(m) => m, + None => { + return pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + METHOD_NOT_FOUND_ERROR_MSG, + Some(rpc_msg.id), + ); + } + }; + + let result = match method(rpc_msg.params.clone()).await { + Ok(res) => res, + Err(err) => return self.handle_error(err, rpc_msg.id), + }; + + message::Response { + jsonrpc: message::JSONRPC_VERSION.to_string(), + error: None, + result: Some(result), + id: Some(rpc_msg.id), + subscription: None, + } + } + + /// Handles a pubsub request + async fn handle_pubsub_request( + &self, + channel: ArcChannel, + service: &Arc<dyn PubSubRPCService + 'static>, + method_name: &str, + rpc_msg: message::Request, + ) -> message::Response { + let method = match service.get_pubsub_method(method_name) { + Some(m) => m, + None => { + return pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + METHOD_NOT_FOUND_ERROR_MSG, + Some(rpc_msg.id), + ); + } + }; + + let result = match method(channel, rpc_msg.params.clone()).await { + Ok(res) => res, + Err(err) => return self.handle_error(err, rpc_msg.id), + }; + + message::Response { + jsonrpc: message::JSONRPC_VERSION.to_string(), + error: None, + result: None, + id: Some(rpc_msg.id), + subscription: Some(result), + } + } + + fn handle_error(&self, err: Error, msg_id: serde_json::Value) -> message::Response { + match err { + Error::ParseJSON(_) => pack_err_res( + message::PARSE_ERROR_CODE, + FAILED_TO_PARSE_ERROR_MSG, + Some(msg_id), + ), + Error::InvalidParams(msg) => { + pack_err_res(message::INVALID_PARAMS_ERROR_CODE, msg, Some(msg_id)) + } + Error::InvalidRequest(msg) => { + pack_err_res(message::INVALID_REQUEST_ERROR_CODE, msg, Some(msg_id)) + } + Error::RPCMethodError(code, msg) => pack_err_res(code, msg, Some(msg_id)), + _ => pack_err_res( + message::INTERNAL_ERROR_CODE, + INTERNAL_ERROR_MSG, + Some(msg_id), + ), + } + } +} + +/// Builder for constructing an RPC [`Server`]. +pub struct ServerBuilder { + endpoint: Endpoint, + tcp_config: TcpConfig, + tls_config: Option<rustls::ServerConfig>, + services: HashMap<String, Arc<dyn RPCService + 'static>>, + pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>, +} + +impl ServerBuilder { + /// Adds a new RPC service to the server. + pub fn service(mut self, service: Arc<dyn RPCService>) -> Self { + self.services.insert(service.name(), service); + self + } + + /// Adds a new PubSub RPC service to the server. + pub fn pubsub_service(mut self, service: Arc<dyn PubSubRPCService>) -> Self { + self.pubsub_services.insert(service.name(), service); + self + } + + /// Configure TCP settings for the server. + /// + /// # Example + /// + /// ```ignore + /// let tcp_config = TcpConfig::default(); + /// let server = Server::builder()?.tcp_config(tcp_config)?.build()?; + /// ``` + /// + /// This function will return an error if the endpoint does not support TCP protocols. + pub fn tcp_config(mut self, config: TcpConfig) -> Result<ServerBuilder> { + match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { + self.tcp_config = config; + Ok(self) + } + _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + } + } + + /// Configure TLS settings for the server. + /// + /// # Example + /// + /// ```ignore + /// let tls_config = rustls::ServerConfig::new(...); + /// let server = Server::builder()?.tls_config(tls_config)?.build()?; + /// ``` + /// + /// This function will return an error if the endpoint does not support TLS protocols. + pub fn tls_config(mut self, config: rustls::ServerConfig) -> Result<ServerBuilder> { + match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { + self.tls_config = Some(config); + Ok(self) + } + _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + } + } + + /// Builds the server with the configured options. + pub async fn build(self) -> Result<Arc<Server>> { + self._build(TaskGroup::new()).await + } + + /// Builds the server with the configured options and an executor. + pub async fn build_with_executor(self, ex: Executor) -> Result<Arc<Server>> { + self._build(TaskGroup::with_executor(ex)).await + } + + async fn _build(self, task_group: TaskGroup) -> Result<Arc<Server>> { + let listener: Listener<serde_json::Value> = match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => match &self.tls_config { + Some(conf) => Box::new( + karyon_net::tls::listen( + &self.endpoint, + karyon_net::tls::ServerTlsConfig { + server_config: conf.clone(), + tcp_config: self.tcp_config, + }, + JsonCodec {}, + ) + .await?, + ), + None => Box::new( + karyon_net::tcp::listen(&self.endpoint, self.tcp_config, JsonCodec {}).await?, + ), + }, + #[cfg(feature = "ws")] + Endpoint::Ws(..) | Endpoint::Wss(..) => match &self.tls_config { + Some(conf) => Box::new( + karyon_net::ws::listen( + &self.endpoint, + ServerWsConfig { + tcp_config: self.tcp_config, + wss_config: Some(karyon_net::ws::ServerWssConfig { + server_config: conf.clone(), + }), + }, + WsJsonCodec {}, + ) + .await?, + ), + None => { + let config = ServerWsConfig { + tcp_config: self.tcp_config, + wss_config: None, + }; + Box::new(karyon_net::ws::listen(&self.endpoint, config, WsJsonCodec {}).await?) + } + }, + #[cfg(all(feature = "unix", target_family = "unix"))] + Endpoint::Unix(..) => Box::new(karyon_net::unix::listen( + &self.endpoint, + Default::default(), + JsonCodec {}, + )?), + + _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + }; + + Ok(Arc::new(Server { + listener, + task_group, + services: self.services, + pubsub_services: self.pubsub_services, + })) + } +} + +impl Server { + /// Creates a new [`ServerBuilder`] + /// + /// This function initializes a `ServerBuilder` with the specified endpoint. + /// + /// # Example + /// + /// ```ignore + /// let builder = Server::builder("ws://127.0.0.1:3000")?.build()?; + /// ``` + pub fn builder(endpoint: impl ToEndpoint) -> Result<ServerBuilder> { + let endpoint = endpoint.to_endpoint()?; + Ok(ServerBuilder { + endpoint, + services: HashMap::new(), + pubsub_services: HashMap::new(), + tcp_config: Default::default(), + tls_config: None, + }) + } +} diff --git a/jsonrpc/src/server/pubsub_service.rs b/jsonrpc/src/server/pubsub_service.rs new file mode 100644 index 0000000..5b4bf9a --- /dev/null +++ b/jsonrpc/src/server/pubsub_service.rs @@ -0,0 +1,67 @@ +use std::{future::Future, pin::Pin}; + +use crate::Result; + +use super::channel::ArcChannel; + +/// Represents the RPC method +pub type PubSubRPCMethod<'a> = + Box<dyn Fn(ArcChannel, serde_json::Value) -> PubSubRPCMethodOutput<'a> + Send + 'a>; +type PubSubRPCMethodOutput<'a> = + Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + Sync + 'a>>; + +/// Defines the interface for an RPC service. +pub trait PubSubRPCService: Sync + Send { + fn get_pubsub_method<'a>(&'a self, name: &'a str) -> Option<PubSubRPCMethod>; + fn name(&self) -> String; +} + +/// Implements the [`PubSubRPCService`] trait for a provided type. +/// +/// # Example +/// +/// ``` +/// use serde_json::Value; +/// +/// use karyon_jsonrpc::{Error, impl_rpc_service}; +/// +/// struct Hello {} +/// +/// impl Hello { +/// async fn foo(&self, params: Value) -> Result<Value, Error> { +/// Ok(serde_json::json!("foo!")) +/// } +/// +/// async fn bar(&self, params: Value) -> Result<Value, Error> { +/// Ok(serde_json::json!("bar!")) +/// } +/// } +/// +/// impl_rpc_service!(Hello, foo, bar); +/// +/// ``` +#[macro_export] +macro_rules! impl_pubsub_rpc_service { + ($t:ty, $($m:ident),*) => { + impl karyon_jsonrpc::PubSubRPCService for $t { + fn get_pubsub_method<'a>( + &'a self, + name: &'a str + ) -> Option<karyon_jsonrpc::PubSubRPCMethod> { + match name { + $( + stringify!($m) => { + Some(Box::new(move |chan: karyon_jsonrpc::ArcChannel, params: serde_json::Value| Box::pin(self.$m(chan, params)))) + } + )* + _ => None, + } + + + } + fn name(&self) -> String{ + stringify!($t).to_string() + } + } + }; +} diff --git a/jsonrpc/src/service.rs b/jsonrpc/src/server/service.rs index 4c8c4b8..4c8c4b8 100644 --- a/jsonrpc/src/service.rs +++ b/jsonrpc/src/server/service.rs |