diff options
Diffstat (limited to 'jsonrpc/src/client')
-rw-r--r-- | jsonrpc/src/client/builder.rs | 131 | ||||
-rw-r--r-- | jsonrpc/src/client/mod.rs | 119 | ||||
-rw-r--r-- | jsonrpc/src/client/subscriptions.rs | 2 |
3 files changed, 139 insertions, 113 deletions
diff --git a/jsonrpc/src/client/builder.rs b/jsonrpc/src/client/builder.rs index 5a7936c..d1e3b67 100644 --- a/jsonrpc/src/client/builder.rs +++ b/jsonrpc/src/client/builder.rs @@ -1,26 +1,17 @@ -use std::sync::{atomic::AtomicBool, Arc}; +use std::sync::Arc; -use karyon_core::async_util::TaskGroup; -use karyon_net::{Conn, Endpoint, ToEndpoint}; +#[cfg(feature = "tcp")] +use karyon_net::Endpoint; +use karyon_net::ToEndpoint; #[cfg(feature = "tls")] -use karyon_net::{async_rustls::rustls, tls::ClientTlsConfig}; - -#[cfg(feature = "ws")] -use karyon_net::ws::ClientWsConfig; - -#[cfg(all(feature = "ws", feature = "tls"))] -use karyon_net::ws::ClientWssConfig; - -#[cfg(feature = "ws")] -use crate::codec::WsJsonCodec; +use karyon_net::async_rustls::rustls; +use crate::Result; #[cfg(feature = "tcp")] -use crate::TcpConfig; +use crate::{Error, TcpConfig}; -use crate::{codec::JsonCodec, Error, Result}; - -use super::{Client, MessageDispatcher, Subscriptions}; +use super::{Client, ClientConfig}; const DEFAULT_TIMEOUT: u64 = 3000; // 3s @@ -44,26 +35,22 @@ impl Client { pub fn builder(endpoint: impl ToEndpoint) -> Result<ClientBuilder> { let endpoint = endpoint.to_endpoint()?; Ok(ClientBuilder { - endpoint, - timeout: Some(DEFAULT_TIMEOUT), - #[cfg(feature = "tcp")] - tcp_config: Default::default(), - #[cfg(feature = "tls")] - tls_config: None, - subscription_buffer_size: DEFAULT_MAX_SUBSCRIPTION_BUFFER_SIZE, + inner: ClientConfig { + endpoint, + timeout: Some(DEFAULT_TIMEOUT), + #[cfg(feature = "tcp")] + tcp_config: Default::default(), + #[cfg(feature = "tls")] + tls_config: None, + subscription_buffer_size: DEFAULT_MAX_SUBSCRIPTION_BUFFER_SIZE, + }, }) } } /// Builder for constructing an RPC [`Client`]. pub struct ClientBuilder { - endpoint: Endpoint, - #[cfg(feature = "tcp")] - tcp_config: TcpConfig, - #[cfg(feature = "tls")] - tls_config: Option<(rustls::ClientConfig, String)>, - timeout: Option<u64>, - subscription_buffer_size: usize, + inner: ClientConfig, } impl ClientBuilder { @@ -82,7 +69,7 @@ impl ClientBuilder { /// }; /// ``` pub fn set_timeout(mut self, timeout: u64) -> Self { - self.timeout = Some(timeout); + self.inner.timeout = Some(timeout); self } @@ -106,7 +93,7 @@ impl ClientBuilder { /// }; /// ``` pub fn set_max_subscription_buffer_size(mut self, size: usize) -> Self { - self.subscription_buffer_size = size; + self.inner.subscription_buffer_size = size; self } @@ -128,12 +115,12 @@ impl ClientBuilder { /// This function will return an error if the endpoint does not support TCP protocols. #[cfg(feature = "tcp")] pub fn tcp_config(mut self, config: TcpConfig) -> Result<Self> { - match self.endpoint { + match self.inner.endpoint { Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { - self.tcp_config = config; + self.inner.tcp_config = config; Ok(self) } - _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + _ => Err(Error::UnsupportedProtocol(self.inner.endpoint.to_string())), } } @@ -157,14 +144,14 @@ impl ClientBuilder { /// This function will return an error if the endpoint does not support TLS protocols. #[cfg(feature = "tls")] pub fn tls_config(mut self, config: rustls::ClientConfig, dns_name: &str) -> Result<Self> { - match self.endpoint { + match self.inner.endpoint { Endpoint::Tls(..) | Endpoint::Wss(..) => { - self.tls_config = Some((config, dns_name.to_string())); + self.inner.tls_config = Some((config, dns_name.to_string())); Ok(self) } _ => Err(Error::UnsupportedProtocol(format!( "Invalid tls config for endpoint: {}", - self.endpoint + self.inner.endpoint ))), } } @@ -189,71 +176,7 @@ impl ClientBuilder { /// /// ``` pub async fn build(self) -> Result<Arc<Client>> { - let conn: Conn<serde_json::Value> = match self.endpoint { - #[cfg(feature = "tcp")] - Endpoint::Tcp(..) => Box::new( - karyon_net::tcp::dial(&self.endpoint, self.tcp_config, JsonCodec {}).await?, - ), - #[cfg(feature = "tls")] - Endpoint::Tls(..) => match self.tls_config { - Some((conf, dns_name)) => Box::new( - karyon_net::tls::dial( - &self.endpoint, - ClientTlsConfig { - dns_name, - client_config: conf, - tcp_config: self.tcp_config, - }, - JsonCodec {}, - ) - .await?, - ), - None => return Err(Error::TLSConfigRequired), - }, - #[cfg(feature = "ws")] - Endpoint::Ws(..) => { - let config = ClientWsConfig { - tcp_config: self.tcp_config, - wss_config: None, - }; - Box::new(karyon_net::ws::dial(&self.endpoint, config, WsJsonCodec {}).await?) - } - #[cfg(all(feature = "ws", feature = "tls"))] - Endpoint::Wss(..) => match self.tls_config { - Some((conf, dns_name)) => Box::new( - karyon_net::ws::dial( - &self.endpoint, - ClientWsConfig { - tcp_config: self.tcp_config, - wss_config: Some(ClientWssConfig { - dns_name, - client_config: conf, - }), - }, - WsJsonCodec {}, - ) - .await?, - ), - None => return Err(Error::TLSConfigRequired), - }, - #[cfg(all(feature = "unix", target_family = "unix"))] - Endpoint::Unix(..) => Box::new( - karyon_net::unix::dial(&self.endpoint, Default::default(), JsonCodec {}).await?, - ), - _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())), - }; - - let send_chan = async_channel::bounded(10); - - let client = Arc::new(Client { - timeout: self.timeout, - disconnect: AtomicBool::new(false), - send_chan, - message_dispatcher: MessageDispatcher::new(), - subscriptions: Subscriptions::new(self.subscription_buffer_size), - task_group: TaskGroup::new(), - }); - client.start_background_loop(conn); + let client = Client::init(self.inner).await?; Ok(client) } } diff --git a/jsonrpc/src/client/mod.rs b/jsonrpc/src/client/mod.rs index 80125b1..51f0233 100644 --- a/jsonrpc/src/client/mod.rs +++ b/jsonrpc/src/client/mod.rs @@ -15,13 +15,26 @@ use log::{debug, error}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::json; +#[cfg(feature = "tcp")] +use karyon_net::tcp::TcpConfig; +#[cfg(feature = "ws")] +use karyon_net::ws::ClientWsConfig; +#[cfg(all(feature = "ws", feature = "tls"))] +use karyon_net::ws::ClientWssConfig; +#[cfg(feature = "tls")] +use karyon_net::{async_rustls::rustls, tls::ClientTlsConfig}; +use karyon_net::{Conn, Endpoint}; + use karyon_core::{ async_util::{select, timeout, Either, TaskGroup, TaskResult}, util::random_32, }; -use karyon_net::Conn; + +#[cfg(feature = "ws")] +use crate::codec::WsJsonCodec; use crate::{ + codec::JsonCodec, message::{self, SubscriptionID}, Error, Result, }; @@ -32,14 +45,24 @@ use subscriptions::Subscriptions; type RequestID = u32; +struct ClientConfig { + endpoint: Endpoint, + #[cfg(feature = "tcp")] + tcp_config: TcpConfig, + #[cfg(feature = "tls")] + tls_config: Option<(rustls::ClientConfig, String)>, + timeout: Option<u64>, + subscription_buffer_size: usize, +} + /// Represents an RPC client pub struct Client { - timeout: Option<u64>, disconnect: AtomicBool, message_dispatcher: MessageDispatcher, - task_group: TaskGroup, - send_chan: (Sender<serde_json::Value>, Receiver<serde_json::Value>), subscriptions: Arc<Subscriptions>, + send_chan: (Sender<serde_json::Value>, Receiver<serde_json::Value>), + task_group: TaskGroup, + config: ClientConfig, } #[derive(Serialize, Deserialize)] @@ -96,6 +119,11 @@ impl Client { Ok(()) } + /// Disconnect the client + pub async fn stop(&self) { + self.task_group.cancel().await; + } + async fn send_request<T: Serialize + DeserializeOwned>( &self, method: &str, @@ -116,7 +144,7 @@ impl Client { let rx = self.message_dispatcher.register(id).await; // Wait for the message dispatcher to send the response - let result = match self.timeout { + let result = match self.config.timeout { Some(t) => timeout(Duration::from_millis(t), rx.recv()).await?, None => rx.recv().await, }; @@ -152,11 +180,86 @@ impl Client { Ok(()) } + async fn init(config: ClientConfig) -> Result<Arc<Self>> { + let client = Arc::new(Client { + disconnect: AtomicBool::new(false), + subscriptions: Subscriptions::new(config.subscription_buffer_size), + send_chan: async_channel::bounded(10), + message_dispatcher: MessageDispatcher::new(), + task_group: TaskGroup::new(), + config, + }); + + let conn = client.connect().await?; + client.start_background_loop(conn); + Ok(client) + } + + async fn connect(self: &Arc<Self>) -> Result<Conn<serde_json::Value>> { + let endpoint = self.config.endpoint.clone(); + let conn: Conn<serde_json::Value> = match endpoint { + #[cfg(feature = "tcp")] + Endpoint::Tcp(..) => Box::new( + karyon_net::tcp::dial(&endpoint, self.config.tcp_config.clone(), JsonCodec {}) + .await?, + ), + #[cfg(feature = "tls")] + Endpoint::Tls(..) => match &self.config.tls_config { + Some((conf, dns_name)) => Box::new( + karyon_net::tls::dial( + &self.config.endpoint, + ClientTlsConfig { + dns_name: dns_name.to_string(), + client_config: conf.clone(), + tcp_config: self.config.tcp_config.clone(), + }, + JsonCodec {}, + ) + .await?, + ), + None => return Err(Error::TLSConfigRequired), + }, + #[cfg(feature = "ws")] + Endpoint::Ws(..) => { + let config = ClientWsConfig { + tcp_config: self.config.tcp_config.clone(), + wss_config: None, + }; + Box::new(karyon_net::ws::dial(&endpoint, config, WsJsonCodec {}).await?) + } + #[cfg(all(feature = "ws", feature = "tls"))] + Endpoint::Wss(..) => match &self.config.tls_config { + Some((conf, dns_name)) => Box::new( + karyon_net::ws::dial( + &endpoint, + ClientWsConfig { + tcp_config: self.config.tcp_config.clone(), + wss_config: Some(ClientWssConfig { + dns_name: dns_name.clone(), + client_config: conf.clone(), + }), + }, + WsJsonCodec {}, + ) + .await?, + ), + None => return Err(Error::TLSConfigRequired), + }, + #[cfg(all(feature = "unix", target_family = "unix"))] + Endpoint::Unix(..) => { + Box::new(karyon_net::unix::dial(&endpoint, Default::default(), JsonCodec {}).await?) + } + _ => return Err(Error::UnsupportedProtocol(endpoint.to_string())), + }; + + Ok(conn) + } + fn start_background_loop(self: &Arc<Self>, conn: Conn<serde_json::Value>) { let selfc = self.clone(); let on_complete = |result: TaskResult<Result<()>>| async move { if let TaskResult::Completed(Err(err)) = result { - error!("Background loop stopped: {err}"); + error!("Client stopped: {err}"); } selfc.disconnect.store(true, Ordering::Relaxed); selfc.subscriptions.clear().await; @@ -201,8 +304,8 @@ impl Client { self.subscriptions.notify(nt).await } }, - Err(_) => { - error!("Receive unexpected msg: {msg}"); + Err(err) => { + error!("Receive unexpected msg {msg}: {err}"); Err(Error::InvalidMsg("Unexpected msg")) } } diff --git a/jsonrpc/src/client/subscriptions.rs b/jsonrpc/src/client/subscriptions.rs index f3d8cb2..fe66f96 100644 --- a/jsonrpc/src/client/subscriptions.rs +++ b/jsonrpc/src/client/subscriptions.rs @@ -25,7 +25,7 @@ impl Subscription { } pub async fn recv(&self) -> Result<Value> { - self.rx.recv().await.map_err(Error::from) + self.rx.recv().await.map_err(|_| Error::SubscriptionClosed) } pub fn id(&self) -> SubscriptionID { |