From 598f9e2d47da80f2bec2ead9c2fe215eff157936 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Mon, 20 Nov 2023 23:15:10 +0300 Subject: jsonrpc: add Codec struct for reading from and writing to the connection --- jsonrpc/README.md | 8 ++-- jsonrpc/examples/client.rs | 5 ++- jsonrpc/examples/server.rs | 7 +++- jsonrpc/src/client.rs | 49 +++++++++++++--------- jsonrpc/src/codec.rs | 100 +++++++++++++++++++++++++++++++++++++++++++++ jsonrpc/src/lib.rs | 15 ++++--- jsonrpc/src/server.rs | 25 +++++++++--- jsonrpc/src/utils.rs | 63 ---------------------------- 8 files changed, 172 insertions(+), 100 deletions(-) create mode 100644 jsonrpc/src/codec.rs delete mode 100644 jsonrpc/src/utils.rs diff --git a/jsonrpc/README.md b/jsonrpc/README.md index 52a9146..d937071 100644 --- a/jsonrpc/README.md +++ b/jsonrpc/README.md @@ -9,7 +9,7 @@ use std::sync::Arc; use serde_json::Value; -use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service}; +use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service, ServerConfig, ClientConfig}; struct HelloWorld {} @@ -27,7 +27,8 @@ let ex = Arc::new(smol::Executor::new()); // Creates a new server let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); -let server = Server::new_with_endpoint(&endpoint, ex.clone()).await.unwrap(); +let config = ServerConfig::default(); +let server = Server::new_with_endpoint(&endpoint, config, ex.clone()).await.unwrap(); // Register the HelloWorld service register_service!(HelloWorld, say_hello); @@ -41,7 +42,8 @@ ex.run(server.start()); ////////////////// // Creates a new client let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); -let client = Client::new_with_endpoint(&endpoint, None).await.unwrap(); +let config = ClientConfig::default(); +let client = Client::new_with_endpoint(&endpoint, config).await.unwrap(); let result: String = client.call("HelloWorld.say_hello", "world".to_string()).await.unwrap(); diff --git a/jsonrpc/examples/client.rs b/jsonrpc/examples/client.rs index 0063098..6b60233 100644 --- a/jsonrpc/examples/client.rs +++ b/jsonrpc/examples/client.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use karyons_jsonrpc::Client; +use karyons_jsonrpc::{Client, ClientConfig}; #[derive(Deserialize, Serialize)] struct Req { @@ -15,7 +15,8 @@ fn main() { env_logger::init(); smol::future::block_on(async { let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); - let client = Client::new_with_endpoint(&endpoint, None).await.unwrap(); + let config = ClientConfig::default(); + let client = Client::new_with_endpoint(&endpoint, config).await.unwrap(); let params = Req { x: 10, y: 7 }; let result: u32 = client.call("Calc.add", params).await.unwrap(); diff --git a/jsonrpc/examples/server.rs b/jsonrpc/examples/server.rs index 367bfe9..512913a 100644 --- a/jsonrpc/examples/server.rs +++ b/jsonrpc/examples/server.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use serde_json::Value; -use karyons_jsonrpc::{register_service, JsonRPCError, Server}; +use karyons_jsonrpc::{register_service, JsonRPCError, Server, ServerConfig}; struct Calc { version: String, @@ -44,7 +44,10 @@ fn main() { smol::block_on(ex.clone().run(async { // Creates a new server let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); - let server = Server::new_with_endpoint(&endpoint, ex).await.unwrap(); + let config = ServerConfig::default(); + let server = Server::new_with_endpoint(&endpoint, config, ex) + .await + .unwrap(); // Register the Calc service register_service!(Calc, ping, add, sub, version); diff --git a/jsonrpc/src/client.rs b/jsonrpc/src/client.rs index 2863204..d5caebe 100644 --- a/jsonrpc/src/client.rs +++ b/jsonrpc/src/client.rs @@ -1,31 +1,46 @@ use log::debug; use serde::{de::DeserializeOwned, Serialize}; -use karyons_core::{async_utils::timeout, utils::random_32}; +use karyons_core::utils::random_32; use karyons_net::{dial, Conn, Endpoint}; use crate::{ - message, - utils::{read_until, write_all}, - Error, Result, JSONRPC_VERSION, + codec::{Codec, CodecConfig}, + message, Error, Result, JSONRPC_VERSION, }; +/// Represents client config +#[derive(Default)] +pub struct ClientConfig { + pub timeout: Option, +} + /// Represents an RPC client pub struct Client { - conn: Conn, - timeout: Option, + codec: Codec, + config: ClientConfig, } impl Client { /// Creates a new RPC client. - pub fn new(conn: Conn, timeout: Option) -> Self { - Self { conn, timeout } + pub fn new(conn: Conn, config: ClientConfig) -> Self { + let codec_config = CodecConfig { + max_allowed_msg_size: 0, + ..Default::default() + }; + let codec = Codec::new(conn, codec_config); + Self { codec, config } } /// Creates a new RPC client using the provided endpoint. - pub async fn new_with_endpoint(endpoint: &Endpoint, timeout: Option) -> Result { + pub async fn new_with_endpoint(endpoint: &Endpoint, config: ClientConfig) -> Result { let conn = dial(endpoint).await?; - Ok(Self { conn, timeout }) + let codec_config = CodecConfig { + max_allowed_msg_size: 0, + ..Default::default() + }; + let codec = Codec::new(conn, codec_config); + Ok(Self { codec, config }) } /// Calls the named method, waits for the response, and returns the result. @@ -44,19 +59,15 @@ impl Client { }; let payload = serde_json::to_vec(&request)?; - write_all(&self.conn, &payload).await?; + self.codec.write_all(&payload).await?; debug!("--> {request}"); let mut buffer = vec![]; - if let Some(t) = self.timeout { - timeout( - std::time::Duration::from_secs(t), - read_until(&self.conn, &mut buffer), - ) - .await? + if let Some(t) = self.config.timeout { + self.codec.read_until_timeout(&mut buffer, t).await?; } else { - read_until(&self.conn, &mut buffer).await - }?; + self.codec.read_until(&mut buffer).await?; + }; let response = serde_json::from_slice::(&buffer)?; debug!("<-- {response}"); diff --git a/jsonrpc/src/codec.rs b/jsonrpc/src/codec.rs new file mode 100644 index 0000000..ea97e54 --- /dev/null +++ b/jsonrpc/src/codec.rs @@ -0,0 +1,100 @@ +use memchr::memchr; + +use karyons_core::async_utils::timeout; +use karyons_net::Conn; + +use crate::{Error, Result}; + +const DEFAULT_BUFFER_SIZE: usize = 1024; +const DEFAULT_MAX_ALLOWED_MSG_SIZE: usize = 1024 * 1024; // 1MB + +// TODO: Add unit tests for Codec's functions. + +/// Represents Codec config +#[derive(Clone)] +pub struct CodecConfig { + pub default_buffer_size: usize, + /// The maximum allowed size to receive a message. If set to zero, there + /// will be no size limit. + pub max_allowed_msg_size: usize, +} + +impl Default for CodecConfig { + fn default() -> Self { + Self { + default_buffer_size: DEFAULT_BUFFER_SIZE, + max_allowed_msg_size: DEFAULT_MAX_ALLOWED_MSG_SIZE, + } + } +} + +pub struct Codec { + conn: Conn, + config: CodecConfig, +} + +impl Codec { + /// Creates a new Codec + pub fn new(conn: Conn, config: CodecConfig) -> Self { + Self { conn, config } + } + + /// Read all bytes into `buffer` until the `0x0` byte or EOF is + /// reached. + /// + /// If successful, this function will return the total number of bytes read. + pub async fn read_until(&self, buffer: &mut Vec) -> Result { + let delim = b'\0'; + + let mut read = 0; + + loop { + let mut tmp_buf = vec![0; self.config.default_buffer_size]; + let n = self.conn.read(&mut tmp_buf).await?; + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + + match memchr(delim, &tmp_buf) { + Some(i) => { + buffer.extend_from_slice(&tmp_buf[..i]); + read += i; + break; + } + None => { + buffer.extend_from_slice(&tmp_buf); + read += tmp_buf.len(); + } + } + + if self.config.max_allowed_msg_size != 0 + && buffer.len() == self.config.max_allowed_msg_size + { + return Err(Error::InvalidMsg( + "Message exceeds the maximum allowed size", + )); + } + } + + Ok(read) + } + + /// Writes an entire buffer into the given connection. + pub async fn write_all(&self, mut buf: &[u8]) -> Result<()> { + while !buf.is_empty() { + let n = self.conn.write(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) + } + + pub async fn read_until_timeout(&self, buffer: &mut Vec, t: u64) -> Result { + timeout(std::time::Duration::from_secs(t), self.read_until(buffer)).await? + } +} diff --git a/jsonrpc/src/lib.rs b/jsonrpc/src/lib.rs index 2ac89c9..f73b5e6 100644 --- a/jsonrpc/src/lib.rs +++ b/jsonrpc/src/lib.rs @@ -7,7 +7,7 @@ //! //! use serde_json::Value; //! -//! use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service}; +//! use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service, ServerConfig, ClientConfig}; //! //! struct HelloWorld {} //! @@ -24,7 +24,8 @@ //! //! // Creates a new server //! let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); -//! let server = Server::new_with_endpoint(&endpoint, ex.clone()).await.unwrap(); +//! let config = ServerConfig::default(); +//! let server = Server::new_with_endpoint(&endpoint, config, ex.clone()).await.unwrap(); //! //! // Register the HelloWorld service //! register_service!(HelloWorld, say_hello); @@ -39,7 +40,8 @@ //! //! // Creates a new client //! let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); -//! let client = Client::new_with_endpoint(&endpoint, None).await.unwrap(); +//! let config = ClientConfig::default(); +//! let client = Client::new_with_endpoint(&endpoint, config).await.unwrap(); //! //! let result: String = client.call("HelloWorld.say_hello", "world".to_string()).await.unwrap(); //! }; @@ -47,17 +49,18 @@ //! ``` mod client; +mod codec; mod error; pub mod message; mod server; mod service; -mod utils; pub const JSONRPC_VERSION: &str = "2.0"; use error::{Error, Result}; -pub use client::Client; +pub use client::{Client, ClientConfig}; +pub use codec::CodecConfig; pub use error::Error as JsonRPCError; -pub use server::Server; +pub use server::{Server, ServerConfig}; pub use service::{RPCMethod, RPCService}; diff --git a/jsonrpc/src/server.rs b/jsonrpc/src/server.rs index 6c01a96..133f261 100644 --- a/jsonrpc/src/server.rs +++ b/jsonrpc/src/server.rs @@ -10,36 +10,49 @@ use karyons_core::{ use karyons_net::{listen, Conn, Endpoint, Listener}; use crate::{ + codec::{Codec, CodecConfig}, message, service::RPCService, - utils::{read_until, write_all}, Error, Result, JSONRPC_VERSION, }; +/// Represents an RPC server +#[derive(Default)] +pub struct ServerConfig { + codec_config: CodecConfig, +} + /// Represents an RPC server pub struct Server<'a> { listener: Box, services: RwLock>>, task_group: TaskGroup<'a>, + config: ServerConfig, } impl<'a> Server<'a> { /// Creates a new RPC server. - pub fn new(listener: Box, ex: Executor<'a>) -> Arc { + pub fn new(listener: Box, config: ServerConfig, ex: Executor<'a>) -> Arc { Arc::new(Self { listener, services: RwLock::new(HashMap::new()), task_group: TaskGroup::new(ex), + config, }) } /// Creates a new RPC server using the provided endpoint. - pub async fn new_with_endpoint(endpoint: &Endpoint, ex: Executor<'a>) -> Result> { + pub async fn new_with_endpoint( + endpoint: &Endpoint, + config: ServerConfig, + ex: Executor<'a>, + ) -> Result> { let listener = listen(endpoint).await?; Ok(Arc::new(Self { listener, services: RwLock::new(HashMap::new()), task_group: TaskGroup::new(ex), + config, })) } @@ -77,15 +90,17 @@ impl<'a> Server<'a> { } }; + let codec = Codec::new(conn, self.config.codec_config.clone()); + let selfc = self.clone(); self.task_group.spawn( async move { loop { let mut buffer = vec![]; - read_until(&conn, &mut buffer).await?; + codec.read_until(&mut buffer).await?; let response = selfc.handle_request(&buffer).await; let payload = serde_json::to_vec(&response)?; - write_all(&conn, &payload).await?; + codec.write_all(&payload).await?; debug!("--> {response}"); } }, diff --git a/jsonrpc/src/utils.rs b/jsonrpc/src/utils.rs deleted file mode 100644 index 1f21b7a..0000000 --- a/jsonrpc/src/utils.rs +++ /dev/null @@ -1,63 +0,0 @@ -use memchr::memchr; - -use karyons_net::Conn; - -use crate::{Error, Result}; - -const DEFAULT_MSG_SIZE: usize = 1024; -const MAX_ALLOWED_MSG_SIZE: usize = 1024 * 1024; // 1MB - -// TODO: Add unit tests for these functions. - -/// Read all bytes into `buffer` until the `0x0` byte or EOF is -/// reached. -/// -/// If successful, this function will return the total number of bytes read. -pub async fn read_until(conn: &Conn, buffer: &mut Vec) -> Result { - let delim = b'\0'; - - let mut read = 0; - - loop { - let mut tmp_buf = [0; DEFAULT_MSG_SIZE]; - let n = conn.read(&mut tmp_buf).await?; - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - - match memchr(delim, &tmp_buf) { - Some(i) => { - buffer.extend_from_slice(&tmp_buf[..i]); - read += i; - break; - } - None => { - buffer.extend_from_slice(&tmp_buf); - read += tmp_buf.len(); - } - } - - if buffer.len() == MAX_ALLOWED_MSG_SIZE { - return Err(Error::InvalidMsg( - "Message exceeds the maximum allowed size", - )); - } - } - - Ok(read) -} - -/// Writes an entire buffer into the given connection. -pub async fn write_all(conn: &Conn, mut buf: &[u8]) -> Result<()> { - while !buf.is_empty() { - let n = conn.write(buf).await?; - let (_, rest) = std::mem::take(&mut buf).split_at(n); - buf = rest; - - if n == 0 { - return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); - } - } - - Ok(()) -} -- cgit v1.2.3