diff options
Diffstat (limited to 'jsonrpc/src/server')
-rw-r--r-- | jsonrpc/src/server/channel.rs | 69 | ||||
-rw-r--r-- | jsonrpc/src/server/mod.rs | 454 | ||||
-rw-r--r-- | jsonrpc/src/server/pubsub_service.rs | 67 | ||||
-rw-r--r-- | jsonrpc/src/server/service.rs | 64 |
4 files changed, 654 insertions, 0 deletions
diff --git a/jsonrpc/src/server/channel.rs b/jsonrpc/src/server/channel.rs new file mode 100644 index 0000000..1498825 --- /dev/null +++ b/jsonrpc/src/server/channel.rs @@ -0,0 +1,69 @@ +use std::sync::Arc; + +use karyon_core::{async_runtime::lock::Mutex, util::random_32}; + +use crate::{Error, Result}; + +pub type SubscriptionID = u32; +pub type ArcChannel = Arc<Channel>; + +/// Represents a new subscription +pub struct Subscription { + pub id: SubscriptionID, + parent: Arc<Channel>, + chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, +} + +impl Subscription { + /// Creates a new `Subscription` + fn new( + parent: Arc<Channel>, + id: SubscriptionID, + chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, + ) -> Self { + Self { parent, id, chan } + } + + /// Sends a notification to the subscriber + pub async fn notify(&self, res: serde_json::Value) -> Result<()> { + if self.parent.subs.lock().await.contains(&self.id) { + self.chan.send((self.id, res)).await?; + Ok(()) + } else { + Err(Error::SubscriptionNotFound(self.id.to_string())) + } + } +} + +/// Represents a channel for creating/removing subscriptions +pub struct Channel { + chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>, + subs: Mutex<Vec<SubscriptionID>>, +} + +impl Channel { + /// Creates a new `Channel` + pub fn new(chan: async_channel::Sender<(SubscriptionID, serde_json::Value)>) -> ArcChannel { + Arc::new(Self { + chan, + subs: Mutex::new(Vec::new()), + }) + } + + /// Creates a new subscription + pub async fn new_subscription(self: &Arc<Self>) -> Subscription { + let sub_id = random_32(); + let sub = Subscription::new(self.clone(), sub_id, self.chan.clone()); + self.subs.lock().await.push(sub_id); + sub + } + + /// Removes a subscription + pub async fn remove_subscription(self: &Arc<Self>, id: &SubscriptionID) { + let i = match self.subs.lock().await.iter().position(|i| i == id) { + Some(i) => i, + None => return, + }; + self.subs.lock().await.remove(i); + } +} diff --git a/jsonrpc/src/server/mod.rs b/jsonrpc/src/server/mod.rs new file mode 100644 index 0000000..4ebab10 --- /dev/null +++ b/jsonrpc/src/server/mod.rs @@ -0,0 +1,454 @@ +pub mod channel; +pub mod pubsub_service; +pub mod service; + +use std::{collections::HashMap, sync::Arc}; + +use log::{debug, error, warn}; + +#[cfg(feature = "smol")] +use futures_rustls::rustls; +#[cfg(feature = "tokio")] +use tokio_rustls::rustls; + +use karyon_core::{ + async_runtime::Executor, + async_util::{select, Either, TaskGroup, TaskResult}, +}; + +use karyon_net::{Conn, Endpoint, Listener, ToEndpoint}; + +#[cfg(feature = "ws")] +use crate::codec::WsJsonCodec; + +#[cfg(feature = "ws")] +use karyon_net::ws::ServerWsConfig; + +use crate::{codec::JsonCodec, message, Error, PubSubRPCService, RPCService, Result, TcpConfig}; + +use channel::{ArcChannel, Channel}; + +const CHANNEL_CAP: usize = 10; + +pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request"; +pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse"; +pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found"; +pub const INTERNAL_ERROR_MSG: &str = "Internal error"; + +fn pack_err_res(code: i32, msg: &str, id: Option<serde_json::Value>) -> message::Response { + let err = message::Error { + code, + message: msg.to_string(), + data: None, + }; + + message::Response { + jsonrpc: message::JSONRPC_VERSION.to_string(), + error: Some(err), + result: None, + id, + subscription: None, + } +} + +struct NewRequest { + srvc_name: String, + method_name: String, + msg: message::Request, +} + +enum SanityCheckResult { + NewReq(NewRequest), + ErrRes(message::Response), +} + +/// Represents an RPC server +pub struct Server { + listener: Listener<serde_json::Value>, + task_group: TaskGroup, + services: HashMap<String, Arc<dyn RPCService + 'static>>, + pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>, +} + +impl Server { + /// Returns the local endpoint. + pub fn local_endpoint(&self) -> Result<Endpoint> { + self.listener.local_endpoint().map_err(Error::from) + } + + /// Starts the RPC server + pub async fn start(self: Arc<Self>) -> Result<()> { + loop { + match self.listener.accept().await { + Ok(conn) => { + if let Err(err) = self.handle_conn(conn).await { + error!("Failed to handle a new conn: {err}") + } + } + Err(err) => { + error!("Failed to accept a new conn: {err}") + } + } + } + } + + /// Shuts down the RPC server + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + /// Handles a new connection + async fn handle_conn(self: &Arc<Self>, conn: Conn<serde_json::Value>) -> Result<()> { + let endpoint = conn.peer_endpoint().expect("get peer endpoint"); + debug!("Handle a new connection {endpoint}"); + + let on_failure = |result: TaskResult<Result<()>>| async move { + if let TaskResult::Completed(Err(err)) = result { + error!("Connection {} dropped: {}", endpoint, err); + } else { + warn!("Connection {} dropped", endpoint); + } + }; + + let selfc = self.clone(); + let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP); + let channel = Channel::new(ch_tx); + self.task_group.spawn( + async move { + loop { + match select(conn.recv(), ch_rx.recv()).await { + Either::Left(msg) => { + // TODO spawn a task + let response = selfc.handle_request(channel.clone(), msg?).await; + debug!("--> {response}"); + conn.send(serde_json::to_value(response)?).await?; + } + Either::Right(msg) => { + let (sub_id, result) = msg?; + let response = message::Notification { + jsonrpc: message::JSONRPC_VERSION.to_string(), + method: None, + params: Some(result), + subscription: Some(sub_id.into()), + }; + debug!("--> {response}"); + conn.send(serde_json::to_value(response)?).await?; + } + } + } + }, + on_failure, + ); + + Ok(()) + } + + fn sanity_check(&self, request: serde_json::Value) -> SanityCheckResult { + let rpc_msg = match serde_json::from_value::<message::Request>(request) { + Ok(m) => m, + Err(_) => { + return SanityCheckResult::ErrRes(pack_err_res( + message::PARSE_ERROR_CODE, + FAILED_TO_PARSE_ERROR_MSG, + None, + )); + } + }; + debug!("<-- {rpc_msg}"); + + let srvc_method_str = rpc_msg.method.clone(); + let srvc_method: Vec<&str> = srvc_method_str.split('.').collect(); + if srvc_method.len() < 2 { + return SanityCheckResult::ErrRes(pack_err_res( + message::INVALID_REQUEST_ERROR_CODE, + INVALID_REQUEST_ERROR_MSG, + Some(rpc_msg.id), + )); + } + + let srvc_name = srvc_method[0].to_string(); + let method_name = srvc_method[1].to_string(); + + SanityCheckResult::NewReq(NewRequest { + srvc_name, + method_name, + msg: rpc_msg, + }) + } + + /// Handles a new request + async fn handle_request( + &self, + channel: ArcChannel, + msg: serde_json::Value, + ) -> message::Response { + let req = match self.sanity_check(msg) { + SanityCheckResult::NewReq(req) => req, + SanityCheckResult::ErrRes(res) => return res, + }; + + if req.msg.subscriber.is_some() { + match self.pubsub_services.get(&req.srvc_name) { + Some(s) => { + self.handle_pubsub_request(channel, s, &req.method_name, req.msg) + .await + } + None => pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + METHOD_NOT_FOUND_ERROR_MSG, + Some(req.msg.id), + ), + } + } else { + match self.services.get(&req.srvc_name) { + Some(s) => self.handle_call_request(s, &req.method_name, req.msg).await, + None => pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + METHOD_NOT_FOUND_ERROR_MSG, + Some(req.msg.id), + ), + } + } + } + + /// Handles a call request + async fn handle_call_request( + &self, + service: &Arc<dyn RPCService + 'static>, + method_name: &str, + rpc_msg: message::Request, + ) -> message::Response { + let method = match service.get_method(method_name) { + Some(m) => m, + None => { + return pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + METHOD_NOT_FOUND_ERROR_MSG, + Some(rpc_msg.id), + ); + } + }; + + let result = match method(rpc_msg.params.clone()).await { + Ok(res) => res, + Err(err) => return self.handle_error(err, rpc_msg.id), + }; + + message::Response { + jsonrpc: message::JSONRPC_VERSION.to_string(), + error: None, + result: Some(result), + id: Some(rpc_msg.id), + subscription: None, + } + } + + /// Handles a pubsub request + async fn handle_pubsub_request( + &self, + channel: ArcChannel, + service: &Arc<dyn PubSubRPCService + 'static>, + method_name: &str, + rpc_msg: message::Request, + ) -> message::Response { + let method = match service.get_pubsub_method(method_name) { + Some(m) => m, + None => { + return pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + METHOD_NOT_FOUND_ERROR_MSG, + Some(rpc_msg.id), + ); + } + }; + + let result = match method(channel, rpc_msg.params.clone()).await { + Ok(res) => res, + Err(err) => return self.handle_error(err, rpc_msg.id), + }; + + message::Response { + jsonrpc: message::JSONRPC_VERSION.to_string(), + error: None, + result: None, + id: Some(rpc_msg.id), + subscription: Some(result), + } + } + + fn handle_error(&self, err: Error, msg_id: serde_json::Value) -> message::Response { + match err { + Error::ParseJSON(_) => pack_err_res( + message::PARSE_ERROR_CODE, + FAILED_TO_PARSE_ERROR_MSG, + Some(msg_id), + ), + Error::InvalidParams(msg) => { + pack_err_res(message::INVALID_PARAMS_ERROR_CODE, msg, Some(msg_id)) + } + Error::InvalidRequest(msg) => { + pack_err_res(message::INVALID_REQUEST_ERROR_CODE, msg, Some(msg_id)) + } + Error::RPCMethodError(code, msg) => pack_err_res(code, msg, Some(msg_id)), + _ => pack_err_res( + message::INTERNAL_ERROR_CODE, + INTERNAL_ERROR_MSG, + Some(msg_id), + ), + } + } +} + +/// Builder for constructing an RPC [`Server`]. +pub struct ServerBuilder { + endpoint: Endpoint, + tcp_config: TcpConfig, + tls_config: Option<rustls::ServerConfig>, + services: HashMap<String, Arc<dyn RPCService + 'static>>, + pubsub_services: HashMap<String, Arc<dyn PubSubRPCService + 'static>>, +} + +impl ServerBuilder { + /// Adds a new RPC service to the server. + pub fn service(mut self, service: Arc<dyn RPCService>) -> Self { + self.services.insert(service.name(), service); + self + } + + /// Adds a new PubSub RPC service to the server. + pub fn pubsub_service(mut self, service: Arc<dyn PubSubRPCService>) -> Self { + self.pubsub_services.insert(service.name(), service); + self + } + + /// Configure TCP settings for the server. + /// + /// # Example + /// + /// ```ignore + /// let tcp_config = TcpConfig::default(); + /// let server = Server::builder()?.tcp_config(tcp_config)?.build()?; + /// ``` + /// + /// This function will return an error if the endpoint does not support TCP protocols. + pub fn tcp_config(mut self, config: TcpConfig) -> Result<ServerBuilder> { + match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { + self.tcp_config = config; + Ok(self) + } + _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + } + } + + /// Configure TLS settings for the server. + /// + /// # Example + /// + /// ```ignore + /// let tls_config = rustls::ServerConfig::new(...); + /// let server = Server::builder()?.tls_config(tls_config)?.build()?; + /// ``` + /// + /// This function will return an error if the endpoint does not support TLS protocols. + pub fn tls_config(mut self, config: rustls::ServerConfig) -> Result<ServerBuilder> { + match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => { + self.tls_config = Some(config); + Ok(self) + } + _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + } + } + + /// Builds the server with the configured options. + pub async fn build(self) -> Result<Arc<Server>> { + self._build(TaskGroup::new()).await + } + + /// Builds the server with the configured options and an executor. + pub async fn build_with_executor(self, ex: Executor) -> Result<Arc<Server>> { + self._build(TaskGroup::with_executor(ex)).await + } + + async fn _build(self, task_group: TaskGroup) -> Result<Arc<Server>> { + let listener: Listener<serde_json::Value> = match self.endpoint { + Endpoint::Tcp(..) | Endpoint::Tls(..) => match &self.tls_config { + Some(conf) => Box::new( + karyon_net::tls::listen( + &self.endpoint, + karyon_net::tls::ServerTlsConfig { + server_config: conf.clone(), + tcp_config: self.tcp_config, + }, + JsonCodec {}, + ) + .await?, + ), + None => Box::new( + karyon_net::tcp::listen(&self.endpoint, self.tcp_config, JsonCodec {}).await?, + ), + }, + #[cfg(feature = "ws")] + Endpoint::Ws(..) | Endpoint::Wss(..) => match &self.tls_config { + Some(conf) => Box::new( + karyon_net::ws::listen( + &self.endpoint, + ServerWsConfig { + tcp_config: self.tcp_config, + wss_config: Some(karyon_net::ws::ServerWssConfig { + server_config: conf.clone(), + }), + }, + WsJsonCodec {}, + ) + .await?, + ), + None => { + let config = ServerWsConfig { + tcp_config: self.tcp_config, + wss_config: None, + }; + Box::new(karyon_net::ws::listen(&self.endpoint, config, WsJsonCodec {}).await?) + } + }, + #[cfg(all(feature = "unix", target_family = "unix"))] + Endpoint::Unix(..) => Box::new(karyon_net::unix::listen( + &self.endpoint, + Default::default(), + JsonCodec {}, + )?), + + _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())), + }; + + Ok(Arc::new(Server { + listener, + task_group, + services: self.services, + pubsub_services: self.pubsub_services, + })) + } +} + +impl Server { + /// Creates a new [`ServerBuilder`] + /// + /// This function initializes a `ServerBuilder` with the specified endpoint. + /// + /// # Example + /// + /// ```ignore + /// let builder = Server::builder("ws://127.0.0.1:3000")?.build()?; + /// ``` + pub fn builder(endpoint: impl ToEndpoint) -> Result<ServerBuilder> { + let endpoint = endpoint.to_endpoint()?; + Ok(ServerBuilder { + endpoint, + services: HashMap::new(), + pubsub_services: HashMap::new(), + tcp_config: Default::default(), + tls_config: None, + }) + } +} diff --git a/jsonrpc/src/server/pubsub_service.rs b/jsonrpc/src/server/pubsub_service.rs new file mode 100644 index 0000000..5b4bf9a --- /dev/null +++ b/jsonrpc/src/server/pubsub_service.rs @@ -0,0 +1,67 @@ +use std::{future::Future, pin::Pin}; + +use crate::Result; + +use super::channel::ArcChannel; + +/// Represents the RPC method +pub type PubSubRPCMethod<'a> = + Box<dyn Fn(ArcChannel, serde_json::Value) -> PubSubRPCMethodOutput<'a> + Send + 'a>; +type PubSubRPCMethodOutput<'a> = + Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + Sync + 'a>>; + +/// Defines the interface for an RPC service. +pub trait PubSubRPCService: Sync + Send { + fn get_pubsub_method<'a>(&'a self, name: &'a str) -> Option<PubSubRPCMethod>; + fn name(&self) -> String; +} + +/// Implements the [`PubSubRPCService`] trait for a provided type. +/// +/// # Example +/// +/// ``` +/// use serde_json::Value; +/// +/// use karyon_jsonrpc::{Error, impl_rpc_service}; +/// +/// struct Hello {} +/// +/// impl Hello { +/// async fn foo(&self, params: Value) -> Result<Value, Error> { +/// Ok(serde_json::json!("foo!")) +/// } +/// +/// async fn bar(&self, params: Value) -> Result<Value, Error> { +/// Ok(serde_json::json!("bar!")) +/// } +/// } +/// +/// impl_rpc_service!(Hello, foo, bar); +/// +/// ``` +#[macro_export] +macro_rules! impl_pubsub_rpc_service { + ($t:ty, $($m:ident),*) => { + impl karyon_jsonrpc::PubSubRPCService for $t { + fn get_pubsub_method<'a>( + &'a self, + name: &'a str + ) -> Option<karyon_jsonrpc::PubSubRPCMethod> { + match name { + $( + stringify!($m) => { + Some(Box::new(move |chan: karyon_jsonrpc::ArcChannel, params: serde_json::Value| Box::pin(self.$m(chan, params)))) + } + )* + _ => None, + } + + + } + fn name(&self) -> String{ + stringify!($t).to_string() + } + } + }; +} diff --git a/jsonrpc/src/server/service.rs b/jsonrpc/src/server/service.rs new file mode 100644 index 0000000..4c8c4b8 --- /dev/null +++ b/jsonrpc/src/server/service.rs @@ -0,0 +1,64 @@ +use std::{future::Future, pin::Pin}; + +use crate::Result; + +/// Represents the RPC method +pub type RPCMethod<'a> = Box<dyn Fn(serde_json::Value) -> RPCMethodOutput<'a> + Send + 'a>; +type RPCMethodOutput<'a> = + Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + Sync + 'a>>; + +/// Defines the interface for an RPC service. +pub trait RPCService: Sync + Send { + fn get_method<'a>(&'a self, name: &'a str) -> Option<RPCMethod>; + fn name(&self) -> String; +} + +/// Implements the [`RPCService`] trait for a provided type. +/// +/// # Example +/// +/// ``` +/// use serde_json::Value; +/// +/// use karyon_jsonrpc::{Error, impl_rpc_service}; +/// +/// struct Hello {} +/// +/// impl Hello { +/// async fn foo(&self, params: Value) -> Result<Value, Error> { +/// Ok(serde_json::json!("foo!")) +/// } +/// +/// async fn bar(&self, params: Value) -> Result<Value, Error> { +/// Ok(serde_json::json!("bar!")) +/// } +/// } +/// +/// impl_rpc_service!(Hello, foo, bar); +/// +/// ``` +#[macro_export] +macro_rules! impl_rpc_service { + ($t:ty, $($m:ident),*) => { + impl karyon_jsonrpc::RPCService for $t { + fn get_method<'a>( + &'a self, + name: &'a str + ) -> Option<karyon_jsonrpc::RPCMethod> { + match name { + $( + stringify!($m) => { + Some(Box::new(move |params: serde_json::Value| Box::pin(self.$m(params)))) + } + )* + _ => None, + } + + + } + fn name(&self) -> String{ + stringify!($t).to_string() + } + } + }; +} |