aboutsummaryrefslogtreecommitdiff
path: root/jsonrpc/src/client/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'jsonrpc/src/client/mod.rs')
-rw-r--r--jsonrpc/src/client/mod.rs324
1 files changed, 78 insertions, 246 deletions
diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs
index b614c95..0666ee0 100644
--- a/jsonrpc/src/client/mod.rs
+++ b/jsonrpc/src/client/mod.rs
@@ -1,33 +1,22 @@
+pub mod builder;
+
use std::{collections::HashMap, sync::Arc, time::Duration};
use log::{debug, error, warn};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::json;
-#[cfg(feature = "smol")]
-use futures_rustls::rustls;
-#[cfg(feature = "tokio")]
-use tokio_rustls::rustls;
-
use karyon_core::{
async_runtime::lock::Mutex,
async_util::{timeout, TaskGroup, TaskResult},
- util::random_64,
+ util::random_32,
};
-use karyon_net::{tls::ClientTlsConfig, Conn, Endpoint, ToEndpoint};
-
-#[cfg(feature = "ws")]
-use karyon_net::ws::{ClientWsConfig, ClientWssConfig};
+use karyon_net::Conn;
-#[cfg(feature = "ws")]
-use crate::codec::WsJsonCodec;
-
-use crate::{codec::JsonCodec, message, Error, Result, SubscriptionID, TcpConfig};
+use crate::{message, Error, Result, SubscriptionID};
const CHANNEL_CAP: usize = 10;
-const DEFAULT_TIMEOUT: u64 = 3000; // 3s
-
/// Type alias for a subscription to receive notifications.
///
/// The receiver channel is returned by the `subscribe` method to receive
@@ -37,9 +26,8 @@ pub type Subscription = async_channel::Receiver<serde_json::Value>;
/// Represents an RPC client
pub struct Client {
conn: Conn<serde_json::Value>,
- chan_tx: async_channel::Sender<message::Response>,
- chan_rx: async_channel::Receiver<message::Response>,
timeout: Option<u64>,
+ chans: Mutex<HashMap<u32, async_channel::Sender<message::Response>>>,
subscriptions: Mutex<HashMap<SubscriptionID, async_channel::Sender<serde_json::Value>>>,
task_group: TaskGroup,
}
@@ -51,20 +39,7 @@ impl Client {
method: &str,
params: T,
) -> Result<V> {
- let request = self.send_request(method, params).await?;
-
- let response = match self.timeout {
- Some(t) => timeout(Duration::from_millis(t), self.chan_rx.recv()).await??,
- None => self.chan_rx.recv().await?,
- };
-
- if let Some(error) = response.error {
- return Err(Error::CallError(error.code, error.message));
- }
-
- if response.id.is_none() || response.id.unwrap() != request.id {
- return Err(Error::InvalidMsg("Invalid response id"));
- }
+ let response = self.send_request(method, params).await?;
match response.result {
Some(result) => Ok(serde_json::from_value::<V>(result)?),
@@ -82,20 +57,7 @@ impl Client {
method: &str,
params: T,
) -> Result<(SubscriptionID, Subscription)> {
- let request = self.send_request(method, params).await?;
-
- let response = match self.timeout {
- Some(t) => timeout(Duration::from_millis(t), self.chan_rx.recv()).await??,
- None => self.chan_rx.recv().await?,
- };
-
- if let Some(error) = response.error {
- return Err(Error::SubscribeError(error.code, error.message));
- }
-
- if response.id.is_none() || response.id.unwrap() != request.id {
- return Err(Error::InvalidMsg("Invalid response id"));
- }
+ let response = self.send_request(method, params).await?;
let sub_id = match response.result {
Some(result) => serde_json::from_value::<SubscriptionID>(result)?,
@@ -113,21 +75,7 @@ impl Client {
/// This function sends an unsubscription request for the specified method
/// and subscription ID. It waits for the response to confirm the unsubscription.
pub async fn unsubscribe(&self, method: &str, sub_id: SubscriptionID) -> Result<()> {
- let request = self.send_request(method, sub_id).await?;
-
- let response = match self.timeout {
- Some(t) => timeout(Duration::from_millis(t), self.chan_rx.recv()).await??,
- None => self.chan_rx.recv().await?,
- };
-
- if let Some(error) = response.error {
- return Err(Error::SubscribeError(error.code, error.message));
- }
-
- if response.id.is_none() || response.id.unwrap() != request.id {
- return Err(Error::InvalidMsg("Invalid response id"));
- }
-
+ let _ = self.send_request(method, sub_id).await?;
self.subscriptions.lock().await.remove(&sub_id);
Ok(())
}
@@ -136,8 +84,8 @@ impl Client {
&self,
method: &str,
params: T,
- ) -> Result<message::Request> {
- let id = random_64();
+ ) -> Result<message::Response> {
+ let id = random_32();
let request = message::Request {
jsonrpc: message::JSONRPC_VERSION.to_string(),
id: json!(id),
@@ -157,8 +105,39 @@ impl Client {
}
}
+ let (tx, rx) = async_channel::bounded(CHANNEL_CAP);
+ self.chans.lock().await.insert(id, tx);
+
+ let response = match self.wait_for_response(rx).await {
+ Ok(r) => r,
+ Err(err) => {
+ self.chans.lock().await.remove(&id);
+ return Err(err);
+ }
+ };
+
+ if let Some(error) = response.error {
+ return Err(Error::SubscribeError(error.code, error.message));
+ }
+
+ if *response.id.as_ref().unwrap() != request.id {
+ return Err(Error::InvalidMsg("Invalid response id"));
+ }
+
debug!("--> {request}");
- Ok(request)
+ Ok(response)
+ }
+
+ async fn wait_for_response(
+ &self,
+ rx: async_channel::Receiver<message::Response>,
+ ) -> Result<message::Response> {
+ match self.timeout {
+ Some(t) => timeout(Duration::from_millis(t), rx.recv())
+ .await?
+ .map_err(Error::from),
+ None => rx.recv().await.map_err(Error::from),
+ }
}
fn start_background_receiving(self: &Arc<Self>) {
@@ -175,201 +154,54 @@ impl Client {
async move {
loop {
let msg = selfc.conn.recv().await?;
- if let Ok(res) = serde_json::from_value::<message::Response>(msg.clone()) {
- debug!("<-- {res}");
- selfc.chan_tx.send(res).await?;
- continue;
- }
-
- if let Ok(nt) = serde_json::from_value::<message::Notification>(msg.clone()) {
- debug!("<-- {nt}");
- let sub_result: message::NotificationResult = match nt.params {
- Some(ref p) => serde_json::from_value(p.clone())?,
- None => return Err(Error::InvalidMsg("Invalid notification msg")),
- };
-
- match selfc
- .subscriptions
- .lock()
- .await
- .get(&sub_result.subscription)
- {
- Some(s) => {
- s.send(sub_result.result.unwrap_or(json!(""))).await?;
- continue;
- }
- None => {
- warn!("Receive unknown notification {}", sub_result.subscription);
- continue;
- }
- }
- }
-
- error!("Receive unexpected msg: {msg}");
- return Err(Error::InvalidMsg("Unexpected msg"));
+ selfc.handle_msg(msg).await?;
}
},
on_failure,
);
}
-}
-
-/// Builder for constructing an RPC [`Client`].
-pub struct ClientBuilder {
- endpoint: Endpoint,
- tls_config: Option<(rustls::ClientConfig, String)>,
- tcp_config: TcpConfig,
- timeout: Option<u64>,
-}
-
-impl ClientBuilder {
- /// Set timeout for sending and receiving messages, in milliseconds.
- ///
- /// # Examples
- ///
- /// ```ignore
- /// let client = Client::builder()?.set_timeout(5000).build()?;
- /// ```
- pub fn set_timeout(mut self, timeout: u64) -> Self {
- self.timeout = Some(timeout);
- self
- }
- /// Configure TCP settings for the client.
- ///
- /// # Example
- ///
- /// ```ignore
- /// let tcp_config = TcpConfig::default();
- /// let client = Client::builder()?.tcp_config(tcp_config)?.build()?;
- /// ```
- ///
- /// This function will return an error if the endpoint does not support TCP protocols.
- pub fn tcp_config(mut self, config: TcpConfig) -> Result<Self> {
- match self.endpoint {
- Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
- self.tcp_config = config;
- Ok(self)
+ async fn handle_msg(&self, msg: serde_json::Value) -> Result<()> {
+ if let Ok(res) = serde_json::from_value::<message::Response>(msg.clone()) {
+ debug!("<-- {res}");
+ if res.id.is_none() {
+ return Err(Error::InvalidMsg("Response id is none"));
}
- _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
- }
- }
- /// Configure TLS settings for the client.
- ///
- /// # Example
- ///
- /// ```ignore
- /// let tls_config = rustls::ClientConfig::new(...);
- /// let client = Client::builder()?.tls_config(tls_config, "example.com")?.build()?;
- /// ```
- ///
- /// This function will return an error if the endpoint does not support TLS protocols.
- pub fn tls_config(mut self, config: rustls::ClientConfig, dns_name: &str) -> Result<Self> {
- match self.endpoint {
- Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
- self.tls_config = Some((config, dns_name.to_string()));
- Ok(self)
+ let id: u32 = serde_json::from_value(res.id.clone().unwrap())?;
+ match self.chans.lock().await.remove(&id) {
+ Some(tx) => tx.send(res).await?,
+ None => return Err(Error::InvalidMsg("Receive unkown message")),
}
- _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+
+ return Ok(());
}
- }
- /// Build RPC client from [`ClientBuilder`].
- ///
- /// This function creates a new RPC client using the configurations
- /// specified in the `ClientBuilder`. It returns a `Arc<Client>` on success.
- ///
- /// # Example
- ///
- /// ```ignore
- /// let client = Client::builder(endpoint)?
- /// .set_timeout(5000)
- /// .tcp_config(tcp_config)?
- /// .tls_config(tls_config, "example.com")?
- /// .build()
- /// .await?;
- /// ```
- pub async fn build(self) -> Result<Arc<Client>> {
- let conn: Conn<serde_json::Value> = match self.endpoint {
- Endpoint::Tcp(..) | Endpoint::Tls(..) => match self.tls_config {
- Some((conf, dns_name)) => Box::new(
- karyon_net::tls::dial(
- &self.endpoint,
- ClientTlsConfig {
- dns_name,
- client_config: conf,
- tcp_config: self.tcp_config,
- },
- JsonCodec {},
- )
- .await?,
- ),
- None => Box::new(
- karyon_net::tcp::dial(&self.endpoint, self.tcp_config, JsonCodec {}).await?,
- ),
- },
- #[cfg(feature = "ws")]
- Endpoint::Ws(..) | Endpoint::Wss(..) => match self.tls_config {
- Some((conf, dns_name)) => Box::new(
- karyon_net::ws::dial(
- &self.endpoint,
- ClientWsConfig {
- tcp_config: self.tcp_config,
- wss_config: Some(ClientWssConfig {
- dns_name,
- client_config: conf,
- }),
- },
- WsJsonCodec {},
- )
- .await?,
- ),
+ if let Ok(nt) = serde_json::from_value::<message::Notification>(msg.clone()) {
+ debug!("<-- {nt}");
+ let sub_result: message::NotificationResult = match nt.params {
+ Some(ref p) => serde_json::from_value(p.clone())?,
+ None => return Err(Error::InvalidMsg("Invalid notification msg")),
+ };
+
+ match self
+ .subscriptions
+ .lock()
+ .await
+ .get(&sub_result.subscription)
+ {
+ Some(s) => {
+ s.send(sub_result.result.unwrap_or(json!(""))).await?;
+ return Ok(());
+ }
None => {
- let config = ClientWsConfig {
- tcp_config: self.tcp_config,
- wss_config: None,
- };
- Box::new(karyon_net::ws::dial(&self.endpoint, config, WsJsonCodec {}).await?)
+ warn!("Receive unknown notification {}", sub_result.subscription);
+ return Ok(());
}
- },
- #[cfg(all(feature = "unix", target_family = "unix"))]
- Endpoint::Unix(..) => Box::new(
- karyon_net::unix::dial(&self.endpoint, Default::default(), JsonCodec {}).await?,
- ),
- _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
- };
+ }
+ }
- let (tx, rx) = async_channel::bounded(CHANNEL_CAP);
- let client = Arc::new(Client {
- timeout: self.timeout,
- conn,
- chan_tx: tx,
- chan_rx: rx,
- subscriptions: Mutex::new(HashMap::new()),
- task_group: TaskGroup::new(),
- });
- client.start_background_receiving();
- Ok(client)
- }
-}
-impl Client {
- /// Creates a new [`ClientBuilder`]
- ///
- /// This function initializes a `ClientBuilder` with the specified endpoint.
- ///
- /// # Example
- ///
- /// ```ignore
- /// let builder = Client::builder("ws://127.0.0.1:3000")?.build()?;
- /// ```
- pub fn builder(endpoint: impl ToEndpoint) -> Result<ClientBuilder> {
- let endpoint = endpoint.to_endpoint()?;
- Ok(ClientBuilder {
- endpoint,
- timeout: Some(DEFAULT_TIMEOUT),
- tls_config: None,
- tcp_config: Default::default(),
- })
+ error!("Receive unexpected msg: {msg}");
+ Err(Error::InvalidMsg("Unexpected msg"))
}
}