aboutsummaryrefslogtreecommitdiff
path: root/jsonrpc
diff options
context:
space:
mode:
Diffstat (limited to 'jsonrpc')
-rw-r--r--jsonrpc/Cargo.toml1
-rw-r--r--jsonrpc/README.md62
-rw-r--r--jsonrpc/examples/pubsub_client.rs47
-rw-r--r--jsonrpc/examples/pubsub_server.rs69
-rw-r--r--jsonrpc/examples/server.rs4
-rw-r--r--jsonrpc/examples/tokio_server/Cargo.lock1
-rw-r--r--jsonrpc/examples/tokio_server/src/main.rs4
-rw-r--r--jsonrpc/jsonrpc_macro/src/lib.rs38
-rw-r--r--jsonrpc/src/client.rs158
-rw-r--r--jsonrpc/src/client/mod.rs374
-rw-r--r--jsonrpc/src/error.rs18
-rw-r--r--jsonrpc/src/lib.rs83
-rw-r--r--jsonrpc/src/message.rs36
-rw-r--r--jsonrpc/src/server.rs282
-rw-r--r--jsonrpc/src/server/channel.rs69
-rw-r--r--jsonrpc/src/server/mod.rs454
-rw-r--r--jsonrpc/src/server/pubsub_service.rs67
-rw-r--r--jsonrpc/src/server/service.rs (renamed from jsonrpc/src/service.rs)0
18 files changed, 1236 insertions, 531 deletions
diff --git a/jsonrpc/Cargo.toml b/jsonrpc/Cargo.toml
index be3176b..40779fe 100644
--- a/jsonrpc/Cargo.toml
+++ b/jsonrpc/Cargo.toml
@@ -43,6 +43,7 @@ serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
thiserror = "1.0.58"
async-trait = "0.1.77"
+async-channel = "2.3.1"
async-tungstenite = { version = "0.25.0", default-features = false, optional = true }
diff --git a/jsonrpc/README.md b/jsonrpc/README.md
index 0883bec..3322671 100644
--- a/jsonrpc/README.md
+++ b/jsonrpc/README.md
@@ -5,10 +5,13 @@ A fast and lightweight async implementation of [JSON-RPC
features:
- Supports TCP, TLS, WebSocket, and Unix protocols.
-- Uses smol(async-std) as the async runtime, but also supports tokio via the
+- Uses `smol`(async-std) as the async runtime, but also supports `tokio` via the
`tokio` feature.
- Allows registration of multiple services (structs) of different types on a
single server.
+- Supports pub/sub
+- Allows passing an `async_executors::Executor` or tokio's `Runtime` when building
+ the server.
## Example
@@ -16,8 +19,11 @@ features:
use std::sync::Arc;
use serde_json::Value;
+use smol::stream::StreamExt;
-use karyon_jsonrpc::{Error, Server, Client, rpc_impl};
+use karyon_jsonrpc::{
+ Error, Server, Client, rpc_impl, rpc_pubsub_impl, SubscriptionID, ArcChannel
+};
struct HelloWorld {}
@@ -37,12 +43,42 @@ impl HelloWorld {
}
}
+#[rpc_pubsub_impl]
+impl HelloWorld {
+ async fn log_subscribe(&self, chan: ArcChannel, _params: Value) -> Result<Value, Error> {
+ let sub = chan.new_subscription().await;
+ let sub_id = sub.id.clone();
+ smol::spawn(async move {
+ loop {
+ smol::Timer::after(std::time::Duration::from_secs(1)).await;
+ if let Err(err) = sub.notify(serde_json::json!("Hello")).await {
+ println!("Error send notification {err}");
+ break;
+ }
+ }
+ })
+ .detach();
+
+ Ok(serde_json::json!(sub_id))
+ }
+
+ async fn log_unsubscribe(&self, chan: ArcChannel, params: Value) -> Result<Value, Error> {
+ let sub_id: SubscriptionID = serde_json::from_value(params)?;
+ chan.remove_subscription(&sub_id).await;
+ Ok(serde_json::json!(true))
+ }
+}
+
+
// Server
async {
+ let service = Arc::new(HelloWorld {});
// Creates a new server
+
let server = Server::builder("tcp://127.0.0.1:60000")
.expect("create new server builder")
- .service(HelloWorld{})
+ .service(service.clone())
+ .pubsub_service(service)
.build()
.await
.expect("build the server");
@@ -63,6 +99,26 @@ async {
let result: String = client.call("HelloWorld.say_hello", "world".to_string())
.await
.expect("send a request");
+
+ let (sub_id, sub) = client
+ .subscribe("Calc.log_subscribe", ())
+ .await
+ .expect("Subscribe to log_subscribe method");
+
+ smol::spawn(async move {
+ sub.for_each(|m| {
+ println!("Receive new notification: {m}");
+ })
+ .await
+ })
+ .detach();
+
+ smol::Timer::after(std::time::Duration::from_secs(5)).await;
+
+ client
+ .unsubscribe("Calc.log_unsubscribe", sub_id)
+ .await
+ .expect("Unsubscribe from log_unsubscirbe method");
};
```
diff --git a/jsonrpc/examples/pubsub_client.rs b/jsonrpc/examples/pubsub_client.rs
new file mode 100644
index 0000000..fee2a26
--- /dev/null
+++ b/jsonrpc/examples/pubsub_client.rs
@@ -0,0 +1,47 @@
+use serde::{Deserialize, Serialize};
+use smol::stream::StreamExt;
+
+use karyon_jsonrpc::Client;
+
+#[derive(Deserialize, Serialize, Debug)]
+struct Pong {}
+
+fn main() {
+ env_logger::init();
+ smol::future::block_on(async {
+ let client = Client::builder("tcp://127.0.0.1:6000")
+ .expect("Create client builder")
+ .build()
+ .await
+ .expect("Build a client");
+
+ let result: Pong = client
+ .call("Calc.ping", ())
+ .await
+ .expect("Send ping request");
+
+ println!("receive pong msg: {:?}", result);
+
+ let (sub_id, sub) = client
+ .subscribe("Calc.log_subscribe", ())
+ .await
+ .expect("Subscribe to log_subscribe method");
+
+ smol::spawn(async move {
+ sub.for_each(|m| {
+ println!("Receive new notification: {m}");
+ })
+ .await
+ })
+ .detach();
+
+ smol::Timer::after(std::time::Duration::from_secs(5)).await;
+
+ client
+ .unsubscribe("Calc.log_unsubscribe", sub_id)
+ .await
+ .expect("Unsubscribe from log_unsubscirbe method");
+
+ smol::Timer::after(std::time::Duration::from_secs(2)).await;
+ });
+}
diff --git a/jsonrpc/examples/pubsub_server.rs b/jsonrpc/examples/pubsub_server.rs
new file mode 100644
index 0000000..739e6d5
--- /dev/null
+++ b/jsonrpc/examples/pubsub_server.rs
@@ -0,0 +1,69 @@
+use std::sync::Arc;
+
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+
+use karyon_jsonrpc::{rpc_impl, rpc_pubsub_impl, ArcChannel, Error, Server, SubscriptionID};
+
+struct Calc {}
+
+#[derive(Deserialize, Serialize)]
+struct Req {
+ x: u32,
+ y: u32,
+}
+
+#[derive(Deserialize, Serialize)]
+struct Pong {}
+
+#[rpc_impl]
+impl Calc {
+ async fn ping(&self, _params: Value) -> Result<Value, Error> {
+ Ok(serde_json::json!(Pong {}))
+ }
+}
+
+#[rpc_pubsub_impl]
+impl Calc {
+ async fn log_subscribe(&self, chan: ArcChannel, _params: Value) -> Result<Value, Error> {
+ let sub = chan.new_subscription().await;
+ let sub_id = sub.id.clone();
+ smol::spawn(async move {
+ loop {
+ smol::Timer::after(std::time::Duration::from_secs(1)).await;
+ if let Err(err) = sub.notify(serde_json::json!("Hello")).await {
+ println!("Error send notification {err}");
+ break;
+ }
+ }
+ })
+ .detach();
+
+ Ok(serde_json::json!(sub_id))
+ }
+
+ async fn log_unsubscribe(&self, chan: ArcChannel, params: Value) -> Result<Value, Error> {
+ let sub_id: SubscriptionID = serde_json::from_value(params)?;
+ chan.remove_subscription(&sub_id).await;
+ Ok(serde_json::json!(true))
+ }
+}
+
+fn main() {
+ env_logger::init();
+ smol::block_on(async {
+ let calc = Arc::new(Calc {});
+
+ // Creates a new server
+ let server = Server::builder("tcp://127.0.0.1:6000")
+ .expect("Create a new server builder")
+ .service(calc.clone())
+ .pubsub_service(calc)
+ .build()
+ .await
+ .expect("Build a new server");
+
+ // Start the server
+ server.start().await.expect("Start the server");
+ });
+}
diff --git a/jsonrpc/examples/server.rs b/jsonrpc/examples/server.rs
index 841e276..5b951cd 100644
--- a/jsonrpc/examples/server.rs
+++ b/jsonrpc/examples/server.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -48,7 +50,7 @@ fn main() {
// Creates a new server
let server = Server::builder("tcp://127.0.0.1:6000")
.expect("Create a new server builder")
- .service(calc)
+ .service(Arc::new(calc))
.build()
.await
.expect("start a new server");
diff --git a/jsonrpc/examples/tokio_server/Cargo.lock b/jsonrpc/examples/tokio_server/Cargo.lock
index a7fdb0b..ab39fcd 100644
--- a/jsonrpc/examples/tokio_server/Cargo.lock
+++ b/jsonrpc/examples/tokio_server/Cargo.lock
@@ -681,6 +681,7 @@ dependencies = [
name = "karyon_jsonrpc"
version = "0.1.0"
dependencies = [
+ "async-channel",
"async-trait",
"async-tungstenite",
"karyon_core",
diff --git a/jsonrpc/examples/tokio_server/src/main.rs b/jsonrpc/examples/tokio_server/src/main.rs
index 978c90a..ce77cd3 100644
--- a/jsonrpc/examples/tokio_server/src/main.rs
+++ b/jsonrpc/examples/tokio_server/src/main.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -48,7 +50,7 @@ async fn main() {
// Creates a new server
let server = Server::builder("tcp://127.0.0.1:6000")
.expect("Create a new server builder")
- .service(calc)
+ .service(Arc::new(calc))
.build()
.await
.expect("start a new server");
diff --git a/jsonrpc/jsonrpc_macro/src/lib.rs b/jsonrpc/jsonrpc_macro/src/lib.rs
index c3d51e8..5acfa5e 100644
--- a/jsonrpc/jsonrpc_macro/src/lib.rs
+++ b/jsonrpc/jsonrpc_macro/src/lib.rs
@@ -45,3 +45,41 @@ pub fn rpc_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
quoted.into()
}
+
+// TODO remove duplicate code
+#[proc_macro_attribute]
+pub fn rpc_pubsub_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
+ let mut methods: Vec<Ident> = vec![];
+
+ let item2 = item.clone();
+ let parsed_input = parse_macro_input!(item2 as ItemImpl);
+
+ let self_ty = match *parsed_input.self_ty {
+ Type::Path(p) => p,
+ _ => err!(
+ parsed_input.span(),
+ "implementing the trait `RPCService` on this type is unsupported"
+ ),
+ };
+
+ if parsed_input.items.is_empty() {
+ err!(self_ty.span(), "At least one method should be implemented");
+ }
+
+ for item in parsed_input.items {
+ match item {
+ ImplItem::Method(method) => {
+ methods.push(method.sig.ident);
+ }
+ _ => err!(item.span(), "unexpected item"),
+ }
+ }
+
+ let item2: TokenStream2 = item.into();
+ let quoted = quote! {
+ karyon_jsonrpc::impl_pubsub_rpc_service!(#self_ty, #(#methods),*);
+ #item2
+ };
+
+ quoted.into()
+}
diff --git a/jsonrpc/src/client.rs b/jsonrpc/src/client.rs
deleted file mode 100644
index b55943e..0000000
--- a/jsonrpc/src/client.rs
+++ /dev/null
@@ -1,158 +0,0 @@
-use std::time::Duration;
-
-use log::debug;
-use serde::{de::DeserializeOwned, Serialize};
-
-#[cfg(feature = "smol")]
-use futures_rustls::rustls;
-#[cfg(feature = "tokio")]
-use tokio_rustls::rustls;
-
-use karyon_core::{async_util::timeout, util::random_32};
-use karyon_net::{tls::ClientTlsConfig, Conn, Endpoint, ToEndpoint};
-
-#[cfg(feature = "ws")]
-use karyon_net::ws::{ClientWsConfig, ClientWssConfig};
-
-#[cfg(feature = "ws")]
-use crate::codec::WsJsonCodec;
-
-use crate::{codec::JsonCodec, message, Error, Result};
-
-/// Represents an RPC client
-pub struct Client {
- conn: Conn<serde_json::Value>,
- timeout: Option<u64>,
-}
-
-impl Client {
- /// Calls the provided method, waits for the response, and returns the result.
- pub async fn call<T: Serialize + DeserializeOwned, V: DeserializeOwned>(
- &self,
- method: &str,
- params: T,
- ) -> Result<V> {
- let id = serde_json::json!(random_32());
-
- let request = message::Request {
- jsonrpc: message::JSONRPC_VERSION.to_string(),
- id,
- method: method.to_string(),
- params: serde_json::json!(params),
- };
-
- let req_json = serde_json::to_value(&request)?;
- match self.timeout {
- Some(s) => {
- let dur = Duration::from_secs(s);
- timeout(dur, self.conn.send(req_json)).await??;
- }
- None => {
- self.conn.send(req_json).await?;
- }
- }
- debug!("--> {request}");
-
- let msg = self.conn.recv().await?;
- let response = serde_json::from_value::<message::Response>(msg)?;
- debug!("<-- {response}");
-
- if response.id.is_none() || response.id.unwrap() != request.id {
- return Err(Error::InvalidMsg("Invalid response id"));
- }
-
- if let Some(error) = response.error {
- return Err(Error::CallError(error.code, error.message));
- }
-
- match response.result {
- Some(result) => Ok(serde_json::from_value::<V>(result)?),
- None => Err(Error::InvalidMsg("Invalid response result")),
- }
- }
-}
-
-pub struct ClientBuilder {
- endpoint: Endpoint,
- tls_config: Option<(rustls::ClientConfig, String)>,
- timeout: Option<u64>,
-}
-
-impl ClientBuilder {
- pub fn with_timeout(mut self, timeout: u64) -> Self {
- self.timeout = Some(timeout);
- self
- }
-
- pub fn tls_config(mut self, config: rustls::ClientConfig, dns_name: &str) -> Result<Self> {
- match self.endpoint {
- Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
- self.tls_config = Some((config, dns_name.to_string()));
- Ok(self)
- }
- _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
- }
- }
-
- pub async fn build(self) -> Result<Client> {
- let conn: Conn<serde_json::Value> = match self.endpoint {
- Endpoint::Tcp(..) | Endpoint::Tls(..) => match self.tls_config {
- Some((conf, dns_name)) => Box::new(
- karyon_net::tls::dial(
- &self.endpoint,
- ClientTlsConfig {
- dns_name,
- client_config: conf,
- tcp_config: Default::default(),
- },
- JsonCodec {},
- )
- .await?,
- ),
- None => Box::new(
- karyon_net::tcp::dial(&self.endpoint, Default::default(), JsonCodec {}).await?,
- ),
- },
- #[cfg(feature = "ws")]
- Endpoint::Ws(..) | Endpoint::Wss(..) => match self.tls_config {
- Some((conf, dns_name)) => Box::new(
- karyon_net::ws::dial(
- &self.endpoint,
- ClientWsConfig {
- tcp_config: Default::default(),
- wss_config: Some(ClientWssConfig {
- dns_name,
- client_config: conf,
- }),
- },
- WsJsonCodec {},
- )
- .await?,
- ),
- None => Box::new(
- karyon_net::ws::dial(&self.endpoint, Default::default(), WsJsonCodec {})
- .await?,
- ),
- },
- #[cfg(all(feature = "unix", target_family = "unix"))]
- Endpoint::Unix(..) => Box::new(
- karyon_net::unix::dial(&self.endpoint, Default::default(), JsonCodec {}).await?,
- ),
- _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
- };
- Ok(Client {
- timeout: self.timeout,
- conn,
- })
- }
-}
-impl Client {
- pub fn builder(endpoint: impl ToEndpoint) -> Result<ClientBuilder> {
- let endpoint = endpoint.to_endpoint()?;
- Ok(ClientBuilder {
- endpoint,
- timeout: None,
- tls_config: None,
- })
- }
-}
diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs
new file mode 100644
index 0000000..c9253fc
--- /dev/null
+++ b/jsonrpc/src/client/mod.rs
@@ -0,0 +1,374 @@
+use std::{collections::HashMap, sync::Arc, time::Duration};
+
+use log::{debug, error, warn};
+use serde::{de::DeserializeOwned, Serialize};
+use serde_json::json;
+
+#[cfg(feature = "smol")]
+use futures_rustls::rustls;
+#[cfg(feature = "tokio")]
+use tokio_rustls::rustls;
+
+use karyon_core::{
+ async_runtime::lock::Mutex,
+ async_util::{timeout, TaskGroup, TaskResult},
+ util::random_32,
+};
+use karyon_net::{tls::ClientTlsConfig, Conn, Endpoint, ToEndpoint};
+
+#[cfg(feature = "ws")]
+use karyon_net::ws::{ClientWsConfig, ClientWssConfig};
+
+#[cfg(feature = "ws")]
+use crate::codec::WsJsonCodec;
+
+use crate::{codec::JsonCodec, message, Error, Result, SubscriptionID, TcpConfig};
+
+const CHANNEL_CAP: usize = 10;
+
+const DEFAULT_TIMEOUT: u64 = 1000; // 1s
+
+/// Type alias for a subscription to receive notifications.
+///
+/// The receiver channel is returned by the `subscribe` method to receive
+/// notifications from the server.
+pub type Subscription = async_channel::Receiver<serde_json::Value>;
+
+/// Represents an RPC client
+pub struct Client {
+ conn: Conn<serde_json::Value>,
+ chan_tx: async_channel::Sender<message::Response>,
+ chan_rx: async_channel::Receiver<message::Response>,
+ timeout: Option<u64>,
+ subscriptions: Mutex<HashMap<SubscriptionID, async_channel::Sender<serde_json::Value>>>,
+ task_group: TaskGroup,
+}
+
+impl Client {
+ /// Calls the provided method, waits for the response, and returns the result.
+ pub async fn call<T: Serialize + DeserializeOwned, V: DeserializeOwned>(
+ &self,
+ method: &str,
+ params: T,
+ ) -> Result<V> {
+ let request = self.send_request(method, params, None).await?;
+ debug!("--> {request}");
+
+ let response = self.chan_rx.recv().await?;
+ debug!("<-- {response}");
+
+ if let Some(error) = response.error {
+ return Err(Error::CallError(error.code, error.message));
+ }
+
+ if response.id.is_none() || response.id.unwrap() != request.id {
+ return Err(Error::InvalidMsg("Invalid response id"));
+ }
+
+ match response.result {
+ Some(result) => Ok(serde_json::from_value::<V>(result)?),
+ None => Err(Error::InvalidMsg("Invalid response result")),
+ }
+ }
+
+ /// Subscribes to the provided method, waits for the response, and returns the result.
+ ///
+ /// This function sends a subscription request to the specified method
+ /// with the given parameters. It waits for the response and returns a
+ /// tuple containing a `SubscriptionID` and a `Subscription` (channel receiver).
+ pub async fn subscribe<T: Serialize + DeserializeOwned>(
+ &self,
+ method: &str,
+ params: T,
+ ) -> Result<(SubscriptionID, Subscription)> {
+ let request = self.send_request(method, params, Some(json!(true))).await?;
+ debug!("--> {request}");
+
+ let response = self.chan_rx.recv().await?;
+ debug!("<-- {response}");
+
+ if let Some(error) = response.error {
+ return Err(Error::SubscribeError(error.code, error.message));
+ }
+
+ if response.id.is_none() || response.id.unwrap() != request.id {
+ return Err(Error::InvalidMsg("Invalid response id"));
+ }
+
+ let sub_id = match response.subscription {
+ Some(result) => serde_json::from_value::<SubscriptionID>(result)?,
+ None => return Err(Error::InvalidMsg("Invalid subscription id")),
+ };
+
+ let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP);
+ self.subscriptions.lock().await.insert(sub_id, ch_tx);
+
+ Ok((sub_id, ch_rx))
+ }
+
+ /// Unsubscribes from the provided method, waits for the response, and returns the result.
+ ///
+ /// This function sends an unsubscription request for the specified method
+ /// and subscription ID. It waits for the response to confirm the unsubscription.
+ pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> {
+ let request = self
+ .send_request(method, json!(sub_id), Some(json!(true)))
+ .await?;
+ debug!("--> {request}");
+
+ let response = self.chan_rx.recv().await?;
+ debug!("<-- {response}");
+
+ if let Some(error) = response.error {
+ return Err(Error::SubscribeError(error.code, error.message));
+ }
+
+ if response.id.is_none() || response.id.unwrap() != request.id {
+ return Err(Error::InvalidMsg("Invalid response id"));
+ }
+
+ self.subscriptions.lock().await.remove(&sub_id);
+ Ok(())
+ }
+
+ async fn send_request<T: Serialize + DeserializeOwned>(
+ &self,
+ method: &str,
+ params: T,
+ subscriber: Option<serde_json::Value>,
+ ) -> Result<message::Request> {
+ let id = json!(random_32());
+
+ let request = message::Request {
+ jsonrpc: message::JSONRPC_VERSION.to_string(),
+ id,
+ method: method.to_string(),
+ params: json!(params),
+ subscriber,
+ };
+
+ let req_json = serde_json::to_value(&request)?;
+
+ match self.timeout {
+ Some(s) => {
+ let dur = Duration::from_millis(s);
+ timeout(dur, self.conn.send(req_json)).await??;
+ }
+ None => {
+ self.conn.send(req_json).await?;
+ }
+ }
+
+ Ok(request)
+ }
+
+ fn start_background_receiving(self: &Arc<Self>) {
+ let selfc = self.clone();
+ let on_failure = |result: TaskResult<Result<()>>| async move {
+ if let TaskResult::Completed(Err(err)) = result {
+ error!("background receiving stopped: {err}");
+ }
+ // drop all subscription channels
+ selfc.subscriptions.lock().await.clear();
+ };
+ let selfc = self.clone();
+ self.task_group.spawn(
+ async move {
+ loop {
+ let msg = selfc.conn.recv().await?;
+
+ if let Ok(res) = serde_json::from_value::<message::Response>(msg.clone()) {
+ selfc.chan_tx.send(res).await?;
+ continue;
+ }
+
+ if let Ok(nt) = serde_json::from_value::<message::Notification>(msg.clone()) {
+ let sub_id = match nt.subscription.clone() {
+ Some(id) => serde_json::from_value::<SubscriptionID>(id)?,
+ None => {
+ return Err(Error::InvalidMsg(
+ "Invalid notification msg: subscription id not found",
+ ))
+ }
+ };
+
+ match selfc.subscriptions.lock().await.get(&sub_id) {
+ Some(s) => {
+ s.send(nt.params.unwrap_or(json!(""))).await?;
+ continue;
+ }
+ None => {
+ warn!("Receive unknown notification {sub_id}");
+ continue;
+ }
+ }
+ }
+
+ error!("Receive unexpected msg: {msg}");
+ return Err(Error::InvalidMsg("Unexpected msg"));
+ }
+ },
+ on_failure,
+ );
+ }
+}
+
+/// Builder for constructing an RPC [`Client`].
+pub struct ClientBuilder {
+ endpoint: Endpoint,
+ tls_config: Option<(rustls::ClientConfig, String)>,
+ tcp_config: TcpConfig,
+ timeout: Option<u64>,
+}
+
+impl ClientBuilder {
+ /// Set timeout for requests, in milliseconds.
+ ///
+ /// # Examples
+ ///
+ /// ```ignore
+ /// let client = Client::builder()?.set_timeout(5000).build()?;
+ /// ```
+ pub fn set_timeout(mut self, timeout: u64) -> Self {
+ self.timeout = Some(timeout);
+ self
+ }
+
+ /// Configure TCP settings for the client.
+ ///
+ /// # Example
+ ///
+ /// ```ignore
+ /// let tcp_config = TcpConfig::default();
+ /// let client = Client::builder()?.tcp_config(tcp_config)?.build()?;
+ /// ```
+ ///
+ /// This function will return an error if the endpoint does not support TCP protocols.
+ pub fn tcp_config(mut self, config: TcpConfig) -> Result<Self> {
+ match self.endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
+ self.tcp_config = config;
+ Ok(self)
+ }
+ _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+ }
+ }
+
+ /// Configure TLS settings for the client.
+ ///
+ /// # Example
+ ///
+ /// ```ignore
+ /// let tls_config = rustls::ClientConfig::new(...);
+ /// let client = Client::builder()?.tls_config(tls_config, "example.com")?.build()?;
+ /// ```
+ ///
+ /// This function will return an error if the endpoint does not support TLS protocols.
+ pub fn tls_config(mut self, config: rustls::ClientConfig, dns_name: &str) -> Result<Self> {
+ match self.endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
+ self.tls_config = Some((config, dns_name.to_string()));
+ Ok(self)
+ }
+ _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+ }
+ }
+
+ /// Build RPC client from [`ClientBuilder`].
+ ///
+ /// This function creates a new RPC client using the configurations
+ /// specified in the `ClientBuilder`. It returns a `Arc<Client>` on success.
+ ///
+ /// # Example
+ ///
+ /// ```ignore
+ /// let client = Client::builder(endpoint)?
+ /// .set_timeout(5000)
+ /// .tcp_config(tcp_config)?
+ /// .tls_config(tls_config, "example.com")?
+ /// .build()
+ /// .await?;
+ /// ```
+ pub async fn build(self) -> Result<Arc<Client>> {
+ let conn: Conn<serde_json::Value> = match self.endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) => match self.tls_config {
+ Some((conf, dns_name)) => Box::new(
+ karyon_net::tls::dial(
+ &self.endpoint,
+ ClientTlsConfig {
+ dns_name,
+ client_config: conf,
+ tcp_config: self.tcp_config,
+ },
+ JsonCodec {},
+ )
+ .await?,
+ ),
+ None => Box::new(
+ karyon_net::tcp::dial(&self.endpoint, self.tcp_config, JsonCodec {}).await?,
+ ),
+ },
+ #[cfg(feature = "ws")]
+ Endpoint::Ws(..) | Endpoint::Wss(..) => match self.tls_config {
+ Some((conf, dns_name)) => Box::new(
+ karyon_net::ws::dial(
+ &self.endpoint,
+ ClientWsConfig {
+ tcp_config: self.tcp_config,
+ wss_config: Some(ClientWssConfig {
+ dns_name,
+ client_config: conf,
+ }),
+ },
+ WsJsonCodec {},
+ )
+ .await?,
+ ),
+ None => {
+ let config = ClientWsConfig {
+ tcp_config: self.tcp_config,
+ wss_config: None,
+ };
+ Box::new(karyon_net::ws::dial(&self.endpoint, config, WsJsonCodec {}).await?)
+ }
+ },
+ #[cfg(all(feature = "unix", target_family = "unix"))]
+ Endpoint::Unix(..) => Box::new(
+ karyon_net::unix::dial(&self.endpoint, Default::default(), JsonCodec {}).await?,
+ ),
+ _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+ };
+
+ let (tx, rx) = async_channel::bounded(CHANNEL_CAP);
+ let client = Arc::new(Client {
+ timeout: self.timeout,
+ conn,
+ chan_tx: tx,
+ chan_rx: rx,
+ subscriptions: Mutex::new(HashMap::new()),
+ task_group: TaskGroup::new(),
+ });
+ client.start_background_receiving();
+ Ok(client)
+ }
+}
+impl Client {
+ /// Creates a new [`ClientBuilder`]
+ ///
+ /// This function initializes a `ClientBuilder` with the specified endpoint.
+ ///
+ /// # Example
+ ///
+ /// ```ignore
+ /// let builder = Client::builder("ws://127.0.0.1:3000")?.build()?;
+ /// ```
+ pub fn builder(endpoint: impl ToEndpoint) -> Result<ClientBuilder> {
+ let endpoint = endpoint.to_endpoint()?;
+ Ok(ClientBuilder {
+ endpoint,
+ timeout: Some(DEFAULT_TIMEOUT),
+ tls_config: None,
+ tcp_config: Default::default(),
+ })
+ }
+}
diff --git a/jsonrpc/src/error.rs b/jsonrpc/src/error.rs
index 7f89729..d68e169 100644
--- a/jsonrpc/src/error.rs
+++ b/jsonrpc/src/error.rs
@@ -11,6 +11,9 @@ pub enum Error {
#[error("Call Error: code: {0} msg: {1}")]
CallError(i32, String),
+ #[error("Subscribe Error: code: {0} msg: {1}")]
+ SubscribeError(i32, String),
+
#[error("RPC Method Error: code: {0} msg: {1}")]
RPCMethodError(i32, &'static str),
@@ -29,6 +32,15 @@ pub enum Error {
#[error("Unsupported protocol: {0}")]
UnsupportedProtocol(String),
+ #[error("Subscription not found: {0}")]
+ SubscriptionNotFound(String),
+
+ #[error(transparent)]
+ ChannelRecv(#[from] async_channel::RecvError),
+
+ #[error("Channel broadcast Error: {0}")]
+ ChannelBroadcast(String),
+
#[error("Unexpected Error: {0}")]
General(&'static str),
@@ -38,3 +50,9 @@ pub enum Error {
#[error(transparent)]
KaryonNet(#[from] karyon_net::Error),
}
+
+impl<T> From<async_channel::SendError<T>> for Error {
+ fn from(error: async_channel::SendError<T>) -> Self {
+ Error::ChannelBroadcast(error.to_string())
+ }
+}
diff --git a/jsonrpc/src/lib.rs b/jsonrpc/src/lib.rs
index 187b1ad..7573c4d 100644
--- a/jsonrpc/src/lib.rs
+++ b/jsonrpc/src/lib.rs
@@ -1,81 +1,20 @@
-//! A fast and lightweight async implementation of [JSON-RPC
-//! 2.0](https://www.jsonrpc.org/specification).
-//!
-//! features:
-//! - Supports TCP, TLS, WebSocket, and Unix protocols.
-//! - Uses smol(async-std) as the async runtime, but also supports tokio via
-//! the `tokio` feature.
-//! - Allows registration of multiple services (structs) of different types on a
-//! single server.
-//!
-//! # Example
-//!
-//! ```
-//! use std::sync::Arc;
-//!
-//! use serde_json::Value;
-//!
-//! use karyon_jsonrpc::{Error, Server, Client, rpc_impl};
-//!
-//! struct HelloWorld {}
-//!
-//! #[rpc_impl]
-//! impl HelloWorld {
-//! async fn say_hello(&self, params: Value) -> Result<Value, Error> {
-//! let msg: String = serde_json::from_value(params)?;
-//! Ok(serde_json::json!(format!("Hello {msg}!")))
-//! }
-//!
-//! async fn foo(&self, params: Value) -> Result<Value, Error> {
-//! Ok(serde_json::json!("foo!"))
-//! }
-//!
-//! async fn bar(&self, params: Value) -> Result<Value, Error> {
-//! Ok(serde_json::json!("bar!"))
-//! }
-//! }
-//!
-//! // Server
-//! async {
-//! // Creates a new server
-//! let server = Server::builder("tcp://127.0.0.1:60000")
-//! .expect("create new server builder")
-//! .service(HelloWorld{})
-//! .build()
-//! .await
-//! .expect("build the server");
-//!
-//! // Starts the server
-//! server.start().await.expect("start the server");
-//! };
-//!
-//! // Client
-//! async {
-//! // Creates a new client
-//! let client = Client::builder("tcp://127.0.0.1:60000")
-//! .expect("create new client builder")
-//! .build()
-//! .await
-//! .expect("build the client");
-//!
-//! let result: String = client.call("HelloWorld.say_hello", "world".to_string())
-//! .await
-//! .expect("send a request");
-//! };
-//!
-//! ```
+#![doc = include_str!("../README.md")]
mod client;
mod codec;
mod error;
pub mod message;
mod server;
-mod service;
pub use client::Client;
-pub use server::Server;
-
pub use error::{Error, Result};
-pub use karyon_jsonrpc_macro::rpc_impl;
-pub use karyon_net::Endpoint;
-pub use service::{RPCMethod, RPCService};
+pub use server::{
+ channel::{ArcChannel, Channel, Subscription, SubscriptionID},
+ pubsub_service::{PubSubRPCMethod, PubSubRPCService},
+ service::{RPCMethod, RPCService},
+ Server,
+};
+
+pub use karyon_jsonrpc_macro::{rpc_impl, rpc_pubsub_impl};
+
+pub use karyon_net::{tcp::TcpConfig, Endpoint};
diff --git a/jsonrpc/src/message.rs b/jsonrpc/src/message.rs
index f4bf490..9c89362 100644
--- a/jsonrpc/src/message.rs
+++ b/jsonrpc/src/message.rs
@@ -23,9 +23,12 @@ pub struct Request {
pub method: String,
pub params: serde_json::Value,
pub id: serde_json::Value,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub subscriber: Option<serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
+#[serde(deny_unknown_fields)]
pub struct Response {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -34,30 +37,35 @@ pub struct Response {
pub error: Option<Error>,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<serde_json::Value>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct Error {
- pub code: i32,
- pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
- pub data: Option<serde_json::Value>,
+ pub subscription: Option<serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Notification {
pub jsonrpc: String,
- pub method: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub method: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub subscription: Option<serde_json::Value>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Error {
+ pub code: i32,
+ pub message: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub data: Option<serde_json::Value>,
}
impl std::fmt::Display for Request {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
- "{{jsonrpc: {}, method: {}, params: {:?}, id: {:?}}}",
- self.jsonrpc, self.method, self.params, self.id
+ "{{jsonrpc: {}, method: {}, params: {:?}, id: {:?}, subscribe: {:?}}}",
+ self.jsonrpc, self.method, self.params, self.id, self.subscriber
)
}
}
@@ -66,8 +74,8 @@ impl std::fmt::Display for Response {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
- "{{jsonrpc: {}, result': {:?}, error: {:?} , id: {:?}}}",
- self.jsonrpc, self.result, self.error, self.id
+ "{{jsonrpc: {}, result': {:?}, error: {:?} , id: {:?}, subscription: {:?}}}",
+ self.jsonrpc, self.result, self.error, self.id, self.subscription
)
}
}
@@ -86,8 +94,8 @@ impl std::fmt::Display for Notification {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
- "{{jsonrpc: {}, method: {}, params: {:?}}}",
- self.jsonrpc, self.method, self.params
+ "{{jsonrpc: {}, method: {:?}, params: {:?}, subscription: {:?}}}",
+ self.jsonrpc, self.method, self.params, self.subscription
)
}
}
diff --git a/jsonrpc/src/server.rs b/jsonrpc/src/server.rs
deleted file mode 100644
index 2155295..0000000
--- a/jsonrpc/src/server.rs
+++ /dev/null
@@ -1,282 +0,0 @@
-use std::{collections::HashMap, sync::Arc};
-
-use log::{debug, error, warn};
-
-#[cfg(feature = "smol")]
-use futures_rustls::rustls;
-#[cfg(feature = "tokio")]
-use tokio_rustls::rustls;
-
-use karyon_core::async_runtime::Executor;
-use karyon_core::async_util::{TaskGroup, TaskResult};
-
-use karyon_net::{Conn, Endpoint, Listener, ToEndpoint};
-
-#[cfg(feature = "ws")]
-use crate::codec::WsJsonCodec;
-
-use crate::{codec::JsonCodec, message, Error, RPCService, Result};
-
-pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request";
-pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse";
-pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found";
-pub const INTERNAL_ERROR_MSG: &str = "Internal error";
-
-fn pack_err_res(code: i32, msg: &str, id: Option<serde_json::Value>) -> message::Response {
- let err = message::Error {
- code,
- message: msg.to_string(),
- data: None,
- };
-
- message::Response {
- jsonrpc: message::JSONRPC_VERSION.to_string(),
- error: Some(err),
- result: None,
- id,
- }
-}
-
-/// Represents an RPC server
-pub struct Server {
- listener: Listener<serde_json::Value>,
- task_group: TaskGroup,
- services: HashMap<String, Box<dyn RPCService + 'static>>,
-}
-
-impl Server {
- /// Returns the local endpoint.
- pub fn local_endpoint(&self) -> Result<Endpoint> {
- self.listener.local_endpoint().map_err(Error::from)
- }
-
- /// Starts the RPC server
- pub async fn start(self: Arc<Self>) -> Result<()> {
- loop {
- match self.listener.accept().await {
- Ok(conn) => {
- if let Err(err) = self.handle_conn(conn).await {
- error!("Failed to handle a new conn: {err}")
- }
- }
- Err(err) => {
- error!("Failed to accept a new conn: {err}")
- }
- }
- }
- }
-
- /// Shuts down the RPC server
- pub async fn shutdown(&self) {
- self.task_group.cancel().await;
- }
-
- /// Handles a new connection
- async fn handle_conn(self: &Arc<Self>, conn: Conn<serde_json::Value>) -> Result<()> {
- let endpoint = conn.peer_endpoint().expect("get peer endpoint");
- debug!("Handle a new connection {endpoint}");
-
- let on_failure = |result: TaskResult<Result<()>>| async move {
- if let TaskResult::Completed(Err(err)) = result {
- error!("Connection {} dropped: {}", endpoint, err);
- } else {
- warn!("Connection {} dropped", endpoint);
- }
- };
-
- let selfc = self.clone();
- self.task_group.spawn(
- async move {
- loop {
- let msg = conn.recv().await?;
- let response = selfc.handle_request(msg).await;
- let response = serde_json::to_value(response)?;
- debug!("--> {response}");
- conn.send(response).await?;
- }
- },
- on_failure,
- );
-
- Ok(())
- }
-
- /// Handles a request
- async fn handle_request(&self, msg: serde_json::Value) -> message::Response {
- let rpc_msg = match serde_json::from_value::<message::Request>(msg) {
- Ok(m) => m,
- Err(_) => {
- return pack_err_res(message::PARSE_ERROR_CODE, FAILED_TO_PARSE_ERROR_MSG, None);
- }
- };
- debug!("<-- {rpc_msg}");
-
- let srvc_method: Vec<&str> = rpc_msg.method.split('.').collect();
- if srvc_method.len() != 2 {
- return pack_err_res(
- message::INVALID_REQUEST_ERROR_CODE,
- INVALID_REQUEST_ERROR_MSG,
- Some(rpc_msg.id),
- );
- }
-
- let srvc_name = srvc_method[0];
- let method_name = srvc_method[1];
-
- let service = match self.services.get(srvc_name) {
- Some(s) => s,
- None => {
- return pack_err_res(
- message::METHOD_NOT_FOUND_ERROR_CODE,
- METHOD_NOT_FOUND_ERROR_MSG,
- Some(rpc_msg.id),
- );
- }
- };
-
- let method = match service.get_method(method_name) {
- Some(m) => m,
- None => {
- return pack_err_res(
- message::METHOD_NOT_FOUND_ERROR_CODE,
- METHOD_NOT_FOUND_ERROR_MSG,
- Some(rpc_msg.id),
- );
- }
- };
-
- let result = match method(rpc_msg.params.clone()).await {
- Ok(res) => res,
- Err(Error::ParseJSON(_)) => {
- return pack_err_res(
- message::PARSE_ERROR_CODE,
- FAILED_TO_PARSE_ERROR_MSG,
- Some(rpc_msg.id),
- );
- }
- Err(Error::InvalidParams(msg)) => {
- return pack_err_res(message::INVALID_PARAMS_ERROR_CODE, msg, Some(rpc_msg.id));
- }
- Err(Error::InvalidRequest(msg)) => {
- return pack_err_res(message::INVALID_REQUEST_ERROR_CODE, msg, Some(rpc_msg.id));
- }
- Err(Error::RPCMethodError(code, msg)) => {
- return pack_err_res(code, msg, Some(rpc_msg.id));
- }
- Err(_) => {
- return pack_err_res(
- message::INTERNAL_ERROR_CODE,
- INTERNAL_ERROR_MSG,
- Some(rpc_msg.id),
- );
- }
- };
-
- message::Response {
- jsonrpc: message::JSONRPC_VERSION.to_string(),
- error: None,
- result: Some(result),
- id: Some(rpc_msg.id),
- }
- }
-}
-
-pub struct ServerBuilder {
- endpoint: Endpoint,
- tls_config: Option<rustls::ServerConfig>,
- services: HashMap<String, Box<dyn RPCService + 'static>>,
-}
-
-impl ServerBuilder {
- pub fn service(mut self, service: impl RPCService + 'static) -> Self {
- self.services.insert(service.name(), Box::new(service));
- self
- }
-
- pub fn tls_config(mut self, config: rustls::ServerConfig) -> Result<ServerBuilder> {
- match self.endpoint {
- Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
- self.tls_config = Some(config);
- Ok(self)
- }
- _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
- }
- }
-
- pub async fn build(self) -> Result<Arc<Server>> {
- self._build(TaskGroup::new()).await
- }
-
- pub async fn build_with_executor(self, ex: Executor) -> Result<Arc<Server>> {
- self._build(TaskGroup::with_executor(ex)).await
- }
-
- async fn _build(self, task_group: TaskGroup) -> Result<Arc<Server>> {
- let listener: Listener<serde_json::Value> = match self.endpoint {
- Endpoint::Tcp(..) | Endpoint::Tls(..) => match &self.tls_config {
- Some(conf) => Box::new(
- karyon_net::tls::listen(
- &self.endpoint,
- karyon_net::tls::ServerTlsConfig {
- server_config: conf.clone(),
- tcp_config: Default::default(),
- },
- JsonCodec {},
- )
- .await?,
- ),
- None => Box::new(
- karyon_net::tcp::listen(&self.endpoint, Default::default(), JsonCodec {})
- .await?,
- ),
- },
- #[cfg(feature = "ws")]
- Endpoint::Ws(..) | Endpoint::Wss(..) => match &self.tls_config {
- Some(conf) => Box::new(
- karyon_net::ws::listen(
- &self.endpoint,
- karyon_net::ws::ServerWsConfig {
- tcp_config: Default::default(),
- wss_config: Some(karyon_net::ws::ServerWssConfig {
- server_config: conf.clone(),
- }),
- },
- WsJsonCodec {},
- )
- .await?,
- ),
- None => Box::new(
- karyon_net::ws::listen(&self.endpoint, Default::default(), WsJsonCodec {})
- .await?,
- ),
- },
- #[cfg(all(feature = "unix", target_family = "unix"))]
- Endpoint::Unix(..) => Box::new(karyon_net::unix::listen(
- &self.endpoint,
- Default::default(),
- JsonCodec {},
- )?),
-
- _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
- };
-
- Ok(Arc::new(Server {
- listener,
- task_group,
- services: self.services,
- }))
- }
-}
-
-impl ServerBuilder {}
-
-impl Server {
- pub fn builder(endpoint: impl ToEndpoint) -> Result<ServerBuilder> {
- let endpoint = endpoint.to_endpoint()?;
- Ok(ServerBuilder {
- endpoint,
- services: HashMap::new(),
- tls_config: None,
- })
- }
-}
diff --git a/jsonrpc/src/server/channel.rs b/jsonrpc/src/server/channel.rs
new file mode 100644
index 0000000..1498825
--- /dev/null
+++ b/jsonrpc/src/server/channel.rs
@@ -0,0 +1,69 @@
+use std::sync::Arc;
+
+use karyon_core::{async_runtime::lock::Mutex, util::random_32};
+
+use crate::{Error, Result};
+
+pub type SubscriptionID = u32;
+pub type ArcChannel = Arc<Channel>;
+
+/// Represents a new subscription
+pub struct Subscription {
+ pub id: SubscriptionID,
+ parent: Arc<Channel>,
+ chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>,
+}
+
+impl Subscription {
+ /// Creates a new `Subscription`
+ fn new(
+ parent: Arc<Channel>,
+ id: SubscriptionID,
+ chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>,
+ ) -> Self {
+ Self { parent, id, chan }
+ }
+
+ /// Sends a notification to the subscriber
+ pub async fn notify(&self, res: serde_json::Value) -> Result<()> {
+ if self.parent.subs.lock().await.contains(&self.id) {
+ self.chan.send((self.id, res)).await?;
+ Ok(())
+ } else {
+ Err(Error::SubscriptionNotFound(self.id.to_string()))
+ }
+ }
+}
+
+/// Represents a channel for creating/removing subscriptions
+pub struct Channel {
+ chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>,
+ subs: Mutex<Vec<SubscriptionID>>,
+}
+
+impl Channel {
+ /// Creates a new `Channel`
+ pub fn new(chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>) -> ArcChannel {
+ Arc::new(Self {
+ chan,
+ subs: Mutex::new(Vec::new()),
+ })
+ }
+
+ /// Creates a new subscription
+ pub async fn new_subscription(self: &Arc<Self>) -> Subscription {
+ let sub_id = random_32();
+ let sub = Subscription::new(self.clone(), sub_id, self.chan.clone());
+ self.subs.lock().await.push(sub_id);
+ sub
+ }
+
+ /// Removes a subscription
+ pub async fn remove_subscription(self: &Arc<Self>, id: &SubscriptionID) {
+ let i = match self.subs.lock().await.iter().position(|i| i == id) {
+ Some(i) => i,
+ None => return,
+ };
+ self.subs.lock().await.remove(i);
+ }
+}
diff --git a/jsonrpc/src/server/mod.rs b/jsonrpc/src/server/mod.rs
new file mode 100644
index 0000000..4ebab10
--- /dev/null
+++ b/jsonrpc/src/server/mod.rs
@@ -0,0 +1,454 @@
+pub mod channel;
+pub mod pubsub_service;
+pub mod service;
+
+use std::{collections::HashMap, sync::Arc};
+
+use log::{debug, error, warn};
+
+#[cfg(feature = "smol")]
+use futures_rustls::rustls;
+#[cfg(feature = "tokio")]
+use tokio_rustls::rustls;
+
+use karyon_core::{
+ async_runtime::Executor,
+ async_util::{select, Either, TaskGroup, TaskResult},
+};
+
+use karyon_net::{Conn, Endpoint, Listener, ToEndpoint};
+
+#[cfg(feature = "ws")]
+use crate::codec::WsJsonCodec;
+
+#[cfg(feature = "ws")]
+use karyon_net::ws::ServerWsConfig;
+
+use crate::{codec::JsonCodec, message, Error, PubSubRPCService, RPCService, Result, TcpConfig};
+
+use channel::{ArcChannel, Channel};
+
+const CHANNEL_CAP: usize = 10;
+
+pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request";
+pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse";
+pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found";
+pub const INTERNAL_ERROR_MSG: &str = "Internal error";
+
+fn pack_err_res(code: i32, msg: &str, id: Option<serde_json::Value>) -> message::Response {
+ let err = message::Error {
+ code,
+ message: msg.to_string(),
+ data: None,
+ };
+
+ message::Response {
+ jsonrpc: message::JSONRPC_VERSION.to_string(),
+ error: Some(err),
+ result: None,
+ id,
+ subscription: None,
+ }
+}
+
+struct NewRequest {
+ srvc_name: String,
+ method_name: String,
+ msg: message::Request,
+}
+
+enum SanityCheckResult {
+ NewReq(NewRequest),
+ ErrRes(message::Response),
+}
+
+/// Represents an RPC server
+pub struct Server {
+ listener: Listener<serde_json::Value>,
+ task_group: TaskGroup,
+ services: HashMap<String, Arc<dyn RPCService + 'static>>,
+ pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>,
+}
+
+impl Server {
+ /// Returns the local endpoint.
+ pub fn local_endpoint(&self) -> Result<Endpoint> {
+ self.listener.local_endpoint().map_err(Error::from)
+ }
+
+ /// Starts the RPC server
+ pub async fn start(self: Arc<Self>) -> Result<()> {
+ loop {
+ match self.listener.accept().await {
+ Ok(conn) => {
+ if let Err(err) = self.handle_conn(conn).await {
+ error!("Failed to handle a new conn: {err}")
+ }
+ }
+ Err(err) => {
+ error!("Failed to accept a new conn: {err}")
+ }
+ }
+ }
+ }
+
+ /// Shuts down the RPC server
+ pub async fn shutdown(&self) {
+ self.task_group.cancel().await;
+ }
+
+ /// Handles a new connection
+ async fn handle_conn(self: &Arc<Self>, conn: Conn<serde_json::Value>) -> Result<()> {
+ let endpoint = conn.peer_endpoint().expect("get peer endpoint");
+ debug!("Handle a new connection {endpoint}");
+
+ let on_failure = |result: TaskResult<Result<()>>| async move {
+ if let TaskResult::Completed(Err(err)) = result {
+ error!("Connection {} dropped: {}", endpoint, err);
+ } else {
+ warn!("Connection {} dropped", endpoint);
+ }
+ };
+
+ let selfc = self.clone();
+ let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP);
+ let channel = Channel::new(ch_tx);
+ self.task_group.spawn(
+ async move {
+ loop {
+ match select(conn.recv(), ch_rx.recv()).await {
+ Either::Left(msg) => {
+ // TODO spawn a task
+ let response = selfc.handle_request(channel.clone(), msg?).await;
+ debug!("--> {response}");
+ conn.send(serde_json::to_value(response)?).await?;
+ }
+ Either::Right(msg) => {
+ let (sub_id, result) = msg?;
+ let response = message::Notification {
+ jsonrpc: message::JSONRPC_VERSION.to_string(),
+ method: None,
+ params: Some(result),
+ subscription: Some(sub_id.into()),
+ };
+ debug!("--> {response}");
+ conn.send(serde_json::to_value(response)?).await?;
+ }
+ }
+ }
+ },
+ on_failure,
+ );
+
+ Ok(())
+ }
+
+ fn sanity_check(&self, request: serde_json::Value) -> SanityCheckResult {
+ let rpc_msg = match serde_json::from_value::<message::Request>(request) {
+ Ok(m) => m,
+ Err(_) => {
+ return SanityCheckResult::ErrRes(pack_err_res(
+ message::PARSE_ERROR_CODE,
+ FAILED_TO_PARSE_ERROR_MSG,
+ None,
+ ));
+ }
+ };
+ debug!("<-- {rpc_msg}");
+
+ let srvc_method_str = rpc_msg.method.clone();
+ let srvc_method: Vec<&str> = srvc_method_str.split('.').collect();
+ if srvc_method.len() < 2 {
+ return SanityCheckResult::ErrRes(pack_err_res(
+ message::INVALID_REQUEST_ERROR_CODE,
+ INVALID_REQUEST_ERROR_MSG,
+ Some(rpc_msg.id),
+ ));
+ }
+
+ let srvc_name = srvc_method[0].to_string();
+ let method_name = srvc_method[1].to_string();
+
+ SanityCheckResult::NewReq(NewRequest {
+ srvc_name,
+ method_name,
+ msg: rpc_msg,
+ })
+ }
+
+ /// Handles a new request
+ async fn handle_request(
+ &self,
+ channel: ArcChannel,
+ msg: serde_json::Value,
+ ) -> message::Response {
+ let req = match self.sanity_check(msg) {
+ SanityCheckResult::NewReq(req) => req,
+ SanityCheckResult::ErrRes(res) => return res,
+ };
+
+ if req.msg.subscriber.is_some() {
+ match self.pubsub_services.get(&req.srvc_name) {
+ Some(s) => {
+ self.handle_pubsub_request(channel, s, &req.method_name, req.msg)
+ .await
+ }
+ None => pack_err_res(
+ message::METHOD_NOT_FOUND_ERROR_CODE,
+ METHOD_NOT_FOUND_ERROR_MSG,
+ Some(req.msg.id),
+ ),
+ }
+ } else {
+ match self.services.get(&req.srvc_name) {
+ Some(s) => self.handle_call_request(s, &req.method_name, req.msg).await,
+ None => pack_err_res(
+ message::METHOD_NOT_FOUND_ERROR_CODE,
+ METHOD_NOT_FOUND_ERROR_MSG,
+ Some(req.msg.id),
+ ),
+ }
+ }
+ }
+
+ /// Handles a call request
+ async fn handle_call_request(
+ &self,
+ service: &Arc<dyn RPCService + 'static>,
+ method_name: &str,
+ rpc_msg: message::Request,
+ ) -> message::Response {
+ let method = match service.get_method(method_name) {
+ Some(m) => m,
+ None => {
+ return pack_err_res(
+ message::METHOD_NOT_FOUND_ERROR_CODE,
+ METHOD_NOT_FOUND_ERROR_MSG,
+ Some(rpc_msg.id),
+ );
+ }
+ };
+
+ let result = match method(rpc_msg.params.clone()).await {
+ Ok(res) => res,
+ Err(err) => return self.handle_error(err, rpc_msg.id),
+ };
+
+ message::Response {
+ jsonrpc: message::JSONRPC_VERSION.to_string(),
+ error: None,
+ result: Some(result),
+ id: Some(rpc_msg.id),
+ subscription: None,
+ }
+ }
+
+ /// Handles a pubsub request
+ async fn handle_pubsub_request(
+ &self,
+ channel: ArcChannel,
+ service: &Arc<dyn PubSubRPCService + 'static>,
+ method_name: &str,
+ rpc_msg: message::Request,
+ ) -> message::Response {
+ let method = match service.get_pubsub_method(method_name) {
+ Some(m) => m,
+ None => {
+ return pack_err_res(
+ message::METHOD_NOT_FOUND_ERROR_CODE,
+ METHOD_NOT_FOUND_ERROR_MSG,
+ Some(rpc_msg.id),
+ );
+ }
+ };
+
+ let result = match method(channel, rpc_msg.params.clone()).await {
+ Ok(res) => res,
+ Err(err) => return self.handle_error(err, rpc_msg.id),
+ };
+
+ message::Response {
+ jsonrpc: message::JSONRPC_VERSION.to_string(),
+ error: None,
+ result: None,
+ id: Some(rpc_msg.id),
+ subscription: Some(result),
+ }
+ }
+
+ fn handle_error(&self, err: Error, msg_id: serde_json::Value) -> message::Response {
+ match err {
+ Error::ParseJSON(_) => pack_err_res(
+ message::PARSE_ERROR_CODE,
+ FAILED_TO_PARSE_ERROR_MSG,
+ Some(msg_id),
+ ),
+ Error::InvalidParams(msg) => {
+ pack_err_res(message::INVALID_PARAMS_ERROR_CODE, msg, Some(msg_id))
+ }
+ Error::InvalidRequest(msg) => {
+ pack_err_res(message::INVALID_REQUEST_ERROR_CODE, msg, Some(msg_id))
+ }
+ Error::RPCMethodError(code, msg) => pack_err_res(code, msg, Some(msg_id)),
+ _ => pack_err_res(
+ message::INTERNAL_ERROR_CODE,
+ INTERNAL_ERROR_MSG,
+ Some(msg_id),
+ ),
+ }
+ }
+}
+
+/// Builder for constructing an RPC [`Server`].
+pub struct ServerBuilder {
+ endpoint: Endpoint,
+ tcp_config: TcpConfig,
+ tls_config: Option<rustls::ServerConfig>,
+ services: HashMap<String, Arc<dyn RPCService + 'static>>,
+ pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>,
+}
+
+impl ServerBuilder {
+ /// Adds a new RPC service to the server.
+ pub fn service(mut self, service: Arc<dyn RPCService>) -> Self {
+ self.services.insert(service.name(), service);
+ self
+ }
+
+ /// Adds a new PubSub RPC service to the server.
+ pub fn pubsub_service(mut self, service: Arc<dyn PubSubRPCService>) -> Self {
+ self.pubsub_services.insert(service.name(), service);
+ self
+ }
+
+ /// Configure TCP settings for the server.
+ ///
+ /// # Example
+ ///
+ /// ```ignore
+ /// let tcp_config = TcpConfig::default();
+ /// let server = Server::builder()?.tcp_config(tcp_config)?.build()?;
+ /// ```
+ ///
+ /// This function will return an error if the endpoint does not support TCP protocols.
+ pub fn tcp_config(mut self, config: TcpConfig) -> Result<ServerBuilder> {
+ match self.endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
+ self.tcp_config = config;
+ Ok(self)
+ }
+ _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+ }
+ }
+
+ /// Configure TLS settings for the server.
+ ///
+ /// # Example
+ ///
+ /// ```ignore
+ /// let tls_config = rustls::ServerConfig::new(...);
+ /// let server = Server::builder()?.tls_config(tls_config)?.build()?;
+ /// ```
+ ///
+ /// This function will return an error if the endpoint does not support TLS protocols.
+ pub fn tls_config(mut self, config: rustls::ServerConfig) -> Result<ServerBuilder> {
+ match self.endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
+ self.tls_config = Some(config);
+ Ok(self)
+ }
+ _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+ }
+ }
+
+ /// Builds the server with the configured options.
+ pub async fn build(self) -> Result<Arc<Server>> {
+ self._build(TaskGroup::new()).await
+ }
+
+ /// Builds the server with the configured options and an executor.
+ pub async fn build_with_executor(self, ex: Executor) -> Result<Arc<Server>> {
+ self._build(TaskGroup::with_executor(ex)).await
+ }
+
+ async fn _build(self, task_group: TaskGroup) -> Result<Arc<Server>> {
+ let listener: Listener<serde_json::Value> = match self.endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) => match &self.tls_config {
+ Some(conf) => Box::new(
+ karyon_net::tls::listen(
+ &self.endpoint,
+ karyon_net::tls::ServerTlsConfig {
+ server_config: conf.clone(),
+ tcp_config: self.tcp_config,
+ },
+ JsonCodec {},
+ )
+ .await?,
+ ),
+ None => Box::new(
+ karyon_net::tcp::listen(&self.endpoint, self.tcp_config, JsonCodec {}).await?,
+ ),
+ },
+ #[cfg(feature = "ws")]
+ Endpoint::Ws(..) | Endpoint::Wss(..) => match &self.tls_config {
+ Some(conf) => Box::new(
+ karyon_net::ws::listen(
+ &self.endpoint,
+ ServerWsConfig {
+ tcp_config: self.tcp_config,
+ wss_config: Some(karyon_net::ws::ServerWssConfig {
+ server_config: conf.clone(),
+ }),
+ },
+ WsJsonCodec {},
+ )
+ .await?,
+ ),
+ None => {
+ let config = ServerWsConfig {
+ tcp_config: self.tcp_config,
+ wss_config: None,
+ };
+ Box::new(karyon_net::ws::listen(&self.endpoint, config, WsJsonCodec {}).await?)
+ }
+ },
+ #[cfg(all(feature = "unix", target_family = "unix"))]
+ Endpoint::Unix(..) => Box::new(karyon_net::unix::listen(
+ &self.endpoint,
+ Default::default(),
+ JsonCodec {},
+ )?),
+
+ _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+ };
+
+ Ok(Arc::new(Server {
+ listener,
+ task_group,
+ services: self.services,
+ pubsub_services: self.pubsub_services,
+ }))
+ }
+}
+
+impl Server {
+ /// Creates a new [`ServerBuilder`]
+ ///
+ /// This function initializes a `ServerBuilder` with the specified endpoint.
+ ///
+ /// # Example
+ ///
+ /// ```ignore
+ /// let builder = Server::builder("ws://127.0.0.1:3000")?.build()?;
+ /// ```
+ pub fn builder(endpoint: impl ToEndpoint) -> Result<ServerBuilder> {
+ let endpoint = endpoint.to_endpoint()?;
+ Ok(ServerBuilder {
+ endpoint,
+ services: HashMap::new(),
+ pubsub_services: HashMap::new(),
+ tcp_config: Default::default(),
+ tls_config: None,
+ })
+ }
+}
diff --git a/jsonrpc/src/server/pubsub_service.rs b/jsonrpc/src/server/pubsub_service.rs
new file mode 100644
index 0000000..5b4bf9a
--- /dev/null
+++ b/jsonrpc/src/server/pubsub_service.rs
@@ -0,0 +1,67 @@
+use std::{future::Future, pin::Pin};
+
+use crate::Result;
+
+use super::channel::ArcChannel;
+
+/// Represents the RPC method
+pub type PubSubRPCMethod<'a> =
+ Box<dyn Fn(ArcChannel, serde_json::Value) -> PubSubRPCMethodOutput<'a> + Send + 'a>;
+type PubSubRPCMethodOutput<'a> =
+ Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + Sync + 'a>>;
+
+/// Defines the interface for an RPC service.
+pub trait PubSubRPCService: Sync + Send {
+ fn get_pubsub_method<'a>(&'a self, name: &'a str) -> Option<PubSubRPCMethod>;
+ fn name(&self) -> String;
+}
+
+/// Implements the [`PubSubRPCService`] trait for a provided type.
+///
+/// # Example
+///
+/// ```
+/// use serde_json::Value;
+///
+/// use karyon_jsonrpc::{Error, impl_rpc_service};
+///
+/// struct Hello {}
+///
+/// impl Hello {
+/// async fn foo(&self, params: Value) -> Result<Value, Error> {
+/// Ok(serde_json::json!("foo!"))
+/// }
+///
+/// async fn bar(&self, params: Value) -> Result<Value, Error> {
+/// Ok(serde_json::json!("bar!"))
+/// }
+/// }
+///
+/// impl_rpc_service!(Hello, foo, bar);
+///
+/// ```
+#[macro_export]
+macro_rules! impl_pubsub_rpc_service {
+ ($t:ty, $($m:ident),*) => {
+ impl karyon_jsonrpc::PubSubRPCService for $t {
+ fn get_pubsub_method<'a>(
+ &'a self,
+ name: &'a str
+ ) -> Option<karyon_jsonrpc::PubSubRPCMethod> {
+ match name {
+ $(
+ stringify!($m) => {
+ Some(Box::new(move |chan: karyon_jsonrpc::ArcChannel, params: serde_json::Value| Box::pin(self.$m(chan, params))))
+ }
+ )*
+ _ => None,
+ }
+
+
+ }
+ fn name(&self) -> String{
+ stringify!($t).to_string()
+ }
+ }
+ };
+}
diff --git a/jsonrpc/src/service.rs b/jsonrpc/src/server/service.rs
index 4c8c4b8..4c8c4b8 100644
--- a/jsonrpc/src/service.rs
+++ b/jsonrpc/src/server/service.rs