diff options
Diffstat (limited to 'jsonrpc/src')
-rw-r--r-- | jsonrpc/src/client.rs | 77 | ||||
-rw-r--r-- | jsonrpc/src/error.rs | 34 | ||||
-rw-r--r-- | jsonrpc/src/lib.rs | 61 | ||||
-rw-r--r-- | jsonrpc/src/message.rs | 91 | ||||
-rw-r--r-- | jsonrpc/src/server.rs | 263 | ||||
-rw-r--r-- | jsonrpc/src/utils.rs | 63 |
6 files changed, 589 insertions, 0 deletions
diff --git a/jsonrpc/src/client.rs b/jsonrpc/src/client.rs new file mode 100644 index 0000000..2863204 --- /dev/null +++ b/jsonrpc/src/client.rs @@ -0,0 +1,77 @@ +use log::debug; +use serde::{de::DeserializeOwned, Serialize}; + +use karyons_core::{async_utils::timeout, utils::random_32}; +use karyons_net::{dial, Conn, Endpoint}; + +use crate::{ + message, + utils::{read_until, write_all}, + Error, Result, JSONRPC_VERSION, +}; + +/// Represents an RPC client +pub struct Client { + conn: Conn, + timeout: Option<u64>, +} + +impl Client { + /// Creates a new RPC client. + pub fn new(conn: Conn, timeout: Option<u64>) -> Self { + Self { conn, timeout } + } + + /// Creates a new RPC client using the provided endpoint. + pub async fn new_with_endpoint(endpoint: &Endpoint, timeout: Option<u64>) -> Result<Self> { + let conn = dial(endpoint).await?; + Ok(Self { conn, timeout }) + } + + /// Calls the named method, waits for the response, and returns the result. + pub async fn call<T: Serialize + DeserializeOwned, V: DeserializeOwned>( + &self, + method: &str, + params: T, + ) -> Result<V> { + let id = serde_json::json!(random_32()); + + let request = message::Request { + jsonrpc: JSONRPC_VERSION.to_string(), + id, + method: method.to_string(), + params: serde_json::json!(params), + }; + + let payload = serde_json::to_vec(&request)?; + write_all(&self.conn, &payload).await?; + debug!("--> {request}"); + + let mut buffer = vec![]; + if let Some(t) = self.timeout { + timeout( + std::time::Duration::from_secs(t), + read_until(&self.conn, &mut buffer), + ) + .await? + } else { + read_until(&self.conn, &mut buffer).await + }?; + + let response = serde_json::from_slice::<message::Response>(&buffer)?; + debug!("<-- {response}"); + + if let Some(error) = response.error { + return Err(Error::CallError(error.code, error.message)); + } + + if response.id.is_none() || response.id.unwrap() != request.id { + return Err(Error::InvalidMsg("Invalid response id")); + } + + match response.result { + Some(result) => Ok(serde_json::from_value::<V>(result)?), + None => Err(Error::InvalidMsg("Invalid response result")), + } + } +} diff --git a/jsonrpc/src/error.rs b/jsonrpc/src/error.rs new file mode 100644 index 0000000..5437c6d --- /dev/null +++ b/jsonrpc/src/error.rs @@ -0,0 +1,34 @@ +use thiserror::Error as ThisError; + +pub type Result<T> = std::result::Result<T, Error>; + +/// Represents karyons's jsonrpc Error. +#[derive(ThisError, Debug)] +pub enum Error { + #[error(transparent)] + IO(#[from] std::io::Error), + + #[error("Call Error: code: {0} msg: {1}")] + CallError(i32, String), + + #[error("RPC Method Error: code: {0} msg: {1}")] + RPCMethodError(i32, &'static str), + + #[error("Invalid Params: {0}")] + InvalidParams(&'static str), + + #[error("Invalid Request: {0}")] + InvalidRequest(&'static str), + + #[error(transparent)] + ParseJSON(#[from] serde_json::Error), + + #[error("Invalid Message Error: {0}")] + InvalidMsg(&'static str), + + #[error(transparent)] + KaryonsCore(#[from] karyons_core::error::Error), + + #[error(transparent)] + KaryonsNet(#[from] karyons_net::NetError), +} diff --git a/jsonrpc/src/lib.rs b/jsonrpc/src/lib.rs new file mode 100644 index 0000000..8a547d9 --- /dev/null +++ b/jsonrpc/src/lib.rs @@ -0,0 +1,61 @@ +//! A fast and lightweight async [JSONRPC 2.0](https://www.jsonrpc.org/specification) implementation. +//! +//! # Example +//! +//! ``` +//! use std::sync::Arc; +//! +//! use serde_json::Value; +//! +//! use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service}; +//! +//! struct HelloWorld {} +//! +//! impl HelloWorld { +//! async fn say_hello(&self, params: Value) -> Result<Value, JsonRPCError> { +//! let msg: String = serde_json::from_value(params)?; +//! Ok(serde_json::json!(format!("Hello {msg}!"))) +//! } +//! } +//! +//! // Server +//! async { +//! let ex = Arc::new(smol::Executor::new()); +//! +//! // Creates a new server +//! let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); +//! let server = Server::new_with_endpoint(&endpoint, ex.clone()).await.unwrap(); +//! +//! // Register the HelloWorld service +//! register_service!(HelloWorld, say_hello); +//! server.attach_service(HelloWorld{}); +//! +//! // Starts the server +//! ex.run(server.start()); +//! }; +//! +//! // Client +//! async { +//! +//! // Creates a new client +//! let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); +//! let client = Client::new_with_endpoint(&endpoint, None).await.unwrap(); +//! +//! let result: String = client.call("HelloWorld.say_hello", "world".to_string()).await.unwrap(); +//! }; +//! +//! ``` + +mod client; +mod error; +pub mod message; +mod server; +mod utils; + +pub const JSONRPC_VERSION: &str = "2.0"; + +use error::{Error, Result}; + +pub use client::Client; +pub use error::Error as JsonRPCError; +pub use server::{RPCMethod, RPCService, Server}; diff --git a/jsonrpc/src/message.rs b/jsonrpc/src/message.rs new file mode 100644 index 0000000..89ef613 --- /dev/null +++ b/jsonrpc/src/message.rs @@ -0,0 +1,91 @@ +use serde::{Deserialize, Serialize}; + +/// Parse error: Invalid JSON was received by the server. +pub const PARSE_ERROR_CODE: i32 = -32700; + +/// Invalid request: The JSON sent is not a valid Request object. +pub const INVALID_REQUEST_ERROR_CODE: i32 = -32600; + +/// Method not found: The method does not exist / is not available. +pub const METHOD_NOT_FOUND_ERROR_CODE: i32 = -32601; + +/// Invalid params: Invalid method parameter(s). +pub const INVALID_PARAMS_ERROR_CODE: i32 = -32602; + +/// Internal error: Internal JSON-RPC error. +pub const INTERNAL_ERROR_CODE: i32 = -32603; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + pub jsonrpc: String, + pub method: String, + pub params: serde_json::Value, + pub id: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Response { + pub jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option<serde_json::Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<Error>, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option<serde_json::Value>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Error { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Notification { + pub jsonrpc: String, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option<serde_json::Value>, +} + +impl std::fmt::Display for Request { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{jsonrpc: {}, method: {}, params: {:?}, id: {:?}}}", + self.jsonrpc, self.method, self.params, self.id + ) + } +} + +impl std::fmt::Display for Response { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{jsonrpc: {}, result': {:?}, error: {:?} , id: {:?}}}", + self.jsonrpc, self.result, self.error, self.id + ) + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "RpcError {{ code: {}, message: {}, data: {:?} }} ", + self.code, self.message, self.data + ) + } +} + +impl std::fmt::Display for Notification { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{jsonrpc: {}, method: {}, params: {:?}}}", + self.jsonrpc, self.method, self.params + ) + } +} diff --git a/jsonrpc/src/server.rs b/jsonrpc/src/server.rs new file mode 100644 index 0000000..9642381 --- /dev/null +++ b/jsonrpc/src/server.rs @@ -0,0 +1,263 @@ +use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; + +use log::{debug, error, warn}; +use smol::lock::RwLock; + +use karyons_core::{ + async_utils::{TaskGroup, TaskResult}, + Executor, +}; +use karyons_net::{listen, Conn, Endpoint, Listener}; + +use crate::{ + message, + utils::{read_until, write_all}, + Error, Result, JSONRPC_VERSION, +}; + +/// Represents an RPC server +pub struct Server<'a> { + listener: Box<dyn Listener>, + services: RwLock<HashMap<String, Box<dyn RPCService + 'a>>>, + task_group: TaskGroup<'a>, +} + +impl<'a> Server<'a> { + /// Creates a new RPC server. + pub fn new(listener: Box<dyn Listener>, ex: Executor<'a>) -> Arc<Self> { + Arc::new(Self { + listener, + services: RwLock::new(HashMap::new()), + task_group: TaskGroup::new(ex), + }) + } + + /// Creates a new RPC server using the provided endpoint. + pub async fn new_with_endpoint(endpoint: &Endpoint, ex: Executor<'a>) -> Result<Arc<Self>> { + let listener = listen(endpoint).await?; + Ok(Arc::new(Self { + listener, + services: RwLock::new(HashMap::new()), + task_group: TaskGroup::new(ex), + })) + } + + /// Starts the RPC server + pub async fn start(self: Arc<Self>) -> Result<()> { + loop { + let conn = self.listener.accept().await?; + self.handle_conn(conn).await?; + } + } + + /// Attach a new service to the RPC server + pub async fn attach_service(&self, service: impl RPCService + 'a) { + self.services + .write() + .await + .insert(service.name(), Box::new(service)); + } + + /// Shuts down the RPC server + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + /// Handles a new connection + async fn handle_conn(self: &Arc<Self>, conn: Conn) -> Result<()> { + let endpoint = conn.peer_endpoint()?; + debug!("Handle a new connection {endpoint}"); + + let on_failure = |result: TaskResult<Result<()>>| async move { + if let TaskResult::Completed(Err(err)) = result { + error!("Connection {} dropped: {}", endpoint, err); + } else { + warn!("Connection {} dropped", endpoint); + } + }; + + let selfc = self.clone(); + self.task_group.spawn( + async move { + loop { + let mut buffer = vec![]; + read_until(&conn, &mut buffer).await?; + let response = selfc.handle_request(&buffer).await; + let payload = serde_json::to_vec(&response)?; + write_all(&conn, &payload).await?; + debug!("--> {response}"); + } + }, + on_failure, + ); + + Ok(()) + } + + /// Handles a request + async fn handle_request(&self, buffer: &[u8]) -> message::Response { + let rpc_msg = match serde_json::from_slice::<message::Request>(buffer) { + Ok(m) => m, + Err(_) => { + return self.pack_err_res(message::PARSE_ERROR_CODE, "Failed to parse", None); + } + }; + + debug!("<-- {rpc_msg}"); + + let srvc_method: Vec<&str> = rpc_msg.method.split('.').collect(); + if srvc_method.len() != 2 { + return self.pack_err_res( + message::INVALID_REQUEST_ERROR_CODE, + "Invalid request", + Some(rpc_msg.id), + ); + } + + let srvc_name = srvc_method[0]; + let method_name = srvc_method[1]; + + let services = self.services.read().await; + + let service = match services.get(srvc_name) { + Some(s) => s, + None => { + return self.pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + "Method not found", + Some(rpc_msg.id), + ); + } + }; + + let method = match service.get_method(method_name) { + Some(m) => m, + None => { + return self.pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + "Method not found", + Some(rpc_msg.id), + ); + } + }; + + let result = match method(rpc_msg.params.clone()).await { + Ok(res) => res, + Err(Error::ParseJSON(_)) => { + return self.pack_err_res( + message::PARSE_ERROR_CODE, + "Failed to parse", + Some(rpc_msg.id), + ); + } + Err(Error::InvalidParams(msg)) => { + return self.pack_err_res( + message::INVALID_PARAMS_ERROR_CODE, + msg, + Some(rpc_msg.id), + ); + } + Err(Error::InvalidRequest(msg)) => { + return self.pack_err_res( + message::INVALID_REQUEST_ERROR_CODE, + msg, + Some(rpc_msg.id), + ); + } + Err(Error::RPCMethodError(code, msg)) => { + return self.pack_err_res(code, msg, Some(rpc_msg.id)); + } + Err(_) => { + return self.pack_err_res( + message::INTERNAL_ERROR_CODE, + "Internal error", + Some(rpc_msg.id), + ); + } + }; + + message::Response { + jsonrpc: JSONRPC_VERSION.to_string(), + error: None, + result: Some(result), + id: Some(rpc_msg.id), + } + } + + fn pack_err_res( + &self, + code: i32, + msg: &str, + id: Option<serde_json::Value>, + ) -> message::Response { + let err = message::Error { + code, + message: msg.to_string(), + data: None, + }; + + message::Response { + jsonrpc: JSONRPC_VERSION.to_string(), + error: Some(err), + result: None, + id, + } + } +} + +/// Represents the RPC method +pub type RPCMethod<'a> = Box<dyn Fn(serde_json::Value) -> RPCMethodOutput<'a> + Send + 'a>; +type RPCMethodOutput<'a> = + Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + Sync + 'a>>; + +/// Defines the interface for an RPC service. +pub trait RPCService: Sync + Send { + fn get_method<'a>(&'a self, name: &'a str) -> Option<RPCMethod>; + fn name(&self) -> String; +} + +/// Implements the `RPCService` trait for a provided type. +/// +/// # Example +/// +/// ``` +/// use serde_json::Value; +/// +/// use karyons_jsonrpc::{JsonRPCError, register_service}; +/// +/// struct Hello {} +/// +/// impl Hello { +/// async fn say_hello(&self, params: Value) -> Result<Value, JsonRPCError> { +/// Ok(serde_json::json!("hello!")) +/// } +/// } +/// +/// register_service!(Hello, say_hello); +/// +/// ``` +#[macro_export] +macro_rules! register_service { + ($t:ty, $($m:ident),*) => { + impl karyons_jsonrpc::RPCService for $t { + fn get_method<'a>( + &'a self, + name: &'a str + ) -> Option<karyons_jsonrpc::RPCMethod> { + match name { + $( + stringify!($m) => { + Some(Box::new(move |params: serde_json::Value| Box::pin(self.$m(params)))) + } + )* + _ => None, + } + + + } + fn name(&self) -> String{ + stringify!($t).to_string() + } + } + }; +} diff --git a/jsonrpc/src/utils.rs b/jsonrpc/src/utils.rs new file mode 100644 index 0000000..1f21b7a --- /dev/null +++ b/jsonrpc/src/utils.rs @@ -0,0 +1,63 @@ +use memchr::memchr; + +use karyons_net::Conn; + +use crate::{Error, Result}; + +const DEFAULT_MSG_SIZE: usize = 1024; +const MAX_ALLOWED_MSG_SIZE: usize = 1024 * 1024; // 1MB + +// TODO: Add unit tests for these functions. + +/// Read all bytes into `buffer` until the `0x0` byte or EOF is +/// reached. +/// +/// If successful, this function will return the total number of bytes read. +pub async fn read_until(conn: &Conn, buffer: &mut Vec<u8>) -> Result<usize> { + let delim = b'\0'; + + let mut read = 0; + + loop { + let mut tmp_buf = [0; DEFAULT_MSG_SIZE]; + let n = conn.read(&mut tmp_buf).await?; + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + + match memchr(delim, &tmp_buf) { + Some(i) => { + buffer.extend_from_slice(&tmp_buf[..i]); + read += i; + break; + } + None => { + buffer.extend_from_slice(&tmp_buf); + read += tmp_buf.len(); + } + } + + if buffer.len() == MAX_ALLOWED_MSG_SIZE { + return Err(Error::InvalidMsg( + "Message exceeds the maximum allowed size", + )); + } + } + + Ok(read) +} + +/// Writes an entire buffer into the given connection. +pub async fn write_all(conn: &Conn, mut buf: &[u8]) -> Result<()> { + while !buf.is_empty() { + let n = conn.write(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) +} |