aboutsummaryrefslogtreecommitdiff
path: root/jsonrpc/src
diff options
context:
space:
mode:
Diffstat (limited to 'jsonrpc/src')
-rw-r--r--jsonrpc/src/client.rs49
-rw-r--r--jsonrpc/src/codec.rs100
-rw-r--r--jsonrpc/src/lib.rs15
-rw-r--r--jsonrpc/src/server.rs25
-rw-r--r--jsonrpc/src/utils.rs63
5 files changed, 159 insertions, 93 deletions
diff --git a/jsonrpc/src/client.rs b/jsonrpc/src/client.rs
index 2863204..d5caebe 100644
--- a/jsonrpc/src/client.rs
+++ b/jsonrpc/src/client.rs
@@ -1,31 +1,46 @@
use log::debug;
use serde::{de::DeserializeOwned, Serialize};
-use karyons_core::{async_utils::timeout, utils::random_32};
+use karyons_core::utils::random_32;
use karyons_net::{dial, Conn, Endpoint};
use crate::{
- message,
- utils::{read_until, write_all},
- Error, Result, JSONRPC_VERSION,
+ codec::{Codec, CodecConfig},
+ message, Error, Result, JSONRPC_VERSION,
};
+/// Represents client config
+#[derive(Default)]
+pub struct ClientConfig {
+ pub timeout: Option<u64>,
+}
+
/// Represents an RPC client
pub struct Client {
- conn: Conn,
- timeout: Option<u64>,
+ codec: Codec,
+ config: ClientConfig,
}
impl Client {
/// Creates a new RPC client.
- pub fn new(conn: Conn, timeout: Option<u64>) -> Self {
- Self { conn, timeout }
+ pub fn new(conn: Conn, config: ClientConfig) -> Self {
+ let codec_config = CodecConfig {
+ max_allowed_msg_size: 0,
+ ..Default::default()
+ };
+ let codec = Codec::new(conn, codec_config);
+ Self { codec, config }
}
/// Creates a new RPC client using the provided endpoint.
- pub async fn new_with_endpoint(endpoint: &Endpoint, timeout: Option<u64>) -> Result<Self> {
+ pub async fn new_with_endpoint(endpoint: &Endpoint, config: ClientConfig) -> Result<Self> {
let conn = dial(endpoint).await?;
- Ok(Self { conn, timeout })
+ let codec_config = CodecConfig {
+ max_allowed_msg_size: 0,
+ ..Default::default()
+ };
+ let codec = Codec::new(conn, codec_config);
+ Ok(Self { codec, config })
}
/// Calls the named method, waits for the response, and returns the result.
@@ -44,19 +59,15 @@ impl Client {
};
let payload = serde_json::to_vec(&request)?;
- write_all(&self.conn, &payload).await?;
+ self.codec.write_all(&payload).await?;
debug!("--> {request}");
let mut buffer = vec![];
- if let Some(t) = self.timeout {
- timeout(
- std::time::Duration::from_secs(t),
- read_until(&self.conn, &mut buffer),
- )
- .await?
+ if let Some(t) = self.config.timeout {
+ self.codec.read_until_timeout(&mut buffer, t).await?;
} else {
- read_until(&self.conn, &mut buffer).await
- }?;
+ self.codec.read_until(&mut buffer).await?;
+ };
let response = serde_json::from_slice::<message::Response>(&buffer)?;
debug!("<-- {response}");
diff --git a/jsonrpc/src/codec.rs b/jsonrpc/src/codec.rs
new file mode 100644
index 0000000..ea97e54
--- /dev/null
+++ b/jsonrpc/src/codec.rs
@@ -0,0 +1,100 @@
+use memchr::memchr;
+
+use karyons_core::async_utils::timeout;
+use karyons_net::Conn;
+
+use crate::{Error, Result};
+
+const DEFAULT_BUFFER_SIZE: usize = 1024;
+const DEFAULT_MAX_ALLOWED_MSG_SIZE: usize = 1024 * 1024; // 1MB
+
+// TODO: Add unit tests for Codec's functions.
+
+/// Represents Codec config
+#[derive(Clone)]
+pub struct CodecConfig {
+ pub default_buffer_size: usize,
+ /// The maximum allowed size to receive a message. If set to zero, there
+ /// will be no size limit.
+ pub max_allowed_msg_size: usize,
+}
+
+impl Default for CodecConfig {
+ fn default() -> Self {
+ Self {
+ default_buffer_size: DEFAULT_BUFFER_SIZE,
+ max_allowed_msg_size: DEFAULT_MAX_ALLOWED_MSG_SIZE,
+ }
+ }
+}
+
+pub struct Codec {
+ conn: Conn,
+ config: CodecConfig,
+}
+
+impl Codec {
+ /// Creates a new Codec
+ pub fn new(conn: Conn, config: CodecConfig) -> Self {
+ Self { conn, config }
+ }
+
+ /// Read all bytes into `buffer` until the `0x0` byte or EOF is
+ /// reached.
+ ///
+ /// If successful, this function will return the total number of bytes read.
+ pub async fn read_until(&self, buffer: &mut Vec<u8>) -> Result<usize> {
+ let delim = b'\0';
+
+ let mut read = 0;
+
+ loop {
+ let mut tmp_buf = vec![0; self.config.default_buffer_size];
+ let n = self.conn.read(&mut tmp_buf).await?;
+ if n == 0 {
+ return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
+ }
+
+ match memchr(delim, &tmp_buf) {
+ Some(i) => {
+ buffer.extend_from_slice(&tmp_buf[..i]);
+ read += i;
+ break;
+ }
+ None => {
+ buffer.extend_from_slice(&tmp_buf);
+ read += tmp_buf.len();
+ }
+ }
+
+ if self.config.max_allowed_msg_size != 0
+ && buffer.len() == self.config.max_allowed_msg_size
+ {
+ return Err(Error::InvalidMsg(
+ "Message exceeds the maximum allowed size",
+ ));
+ }
+ }
+
+ Ok(read)
+ }
+
+ /// Writes an entire buffer into the given connection.
+ pub async fn write_all(&self, mut buf: &[u8]) -> Result<()> {
+ while !buf.is_empty() {
+ let n = self.conn.write(buf).await?;
+ let (_, rest) = std::mem::take(&mut buf).split_at(n);
+ buf = rest;
+
+ if n == 0 {
+ return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
+ }
+ }
+
+ Ok(())
+ }
+
+ pub async fn read_until_timeout(&self, buffer: &mut Vec<u8>, t: u64) -> Result<usize> {
+ timeout(std::time::Duration::from_secs(t), self.read_until(buffer)).await?
+ }
+}
diff --git a/jsonrpc/src/lib.rs b/jsonrpc/src/lib.rs
index 2ac89c9..f73b5e6 100644
--- a/jsonrpc/src/lib.rs
+++ b/jsonrpc/src/lib.rs
@@ -7,7 +7,7 @@
//!
//! use serde_json::Value;
//!
-//! use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service};
+//! use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service, ServerConfig, ClientConfig};
//!
//! struct HelloWorld {}
//!
@@ -24,7 +24,8 @@
//!
//! // Creates a new server
//! let endpoint = "tcp://127.0.0.1:60000".parse().unwrap();
-//! let server = Server::new_with_endpoint(&endpoint, ex.clone()).await.unwrap();
+//! let config = ServerConfig::default();
+//! let server = Server::new_with_endpoint(&endpoint, config, ex.clone()).await.unwrap();
//!
//! // Register the HelloWorld service
//! register_service!(HelloWorld, say_hello);
@@ -39,7 +40,8 @@
//!
//! // Creates a new client
//! let endpoint = "tcp://127.0.0.1:60000".parse().unwrap();
-//! let client = Client::new_with_endpoint(&endpoint, None).await.unwrap();
+//! let config = ClientConfig::default();
+//! let client = Client::new_with_endpoint(&endpoint, config).await.unwrap();
//!
//! let result: String = client.call("HelloWorld.say_hello", "world".to_string()).await.unwrap();
//! };
@@ -47,17 +49,18 @@
//! ```
mod client;
+mod codec;
mod error;
pub mod message;
mod server;
mod service;
-mod utils;
pub const JSONRPC_VERSION: &str = "2.0";
use error::{Error, Result};
-pub use client::Client;
+pub use client::{Client, ClientConfig};
+pub use codec::CodecConfig;
pub use error::Error as JsonRPCError;
-pub use server::Server;
+pub use server::{Server, ServerConfig};
pub use service::{RPCMethod, RPCService};
diff --git a/jsonrpc/src/server.rs b/jsonrpc/src/server.rs
index 6c01a96..133f261 100644
--- a/jsonrpc/src/server.rs
+++ b/jsonrpc/src/server.rs
@@ -10,36 +10,49 @@ use karyons_core::{
use karyons_net::{listen, Conn, Endpoint, Listener};
use crate::{
+ codec::{Codec, CodecConfig},
message,
service::RPCService,
- utils::{read_until, write_all},
Error, Result, JSONRPC_VERSION,
};
/// Represents an RPC server
+#[derive(Default)]
+pub struct ServerConfig {
+ codec_config: CodecConfig,
+}
+
+/// Represents an RPC server
pub struct Server<'a> {
listener: Box<dyn Listener>,
services: RwLock<HashMap<String, Box<dyn RPCService + 'a>>>,
task_group: TaskGroup<'a>,
+ config: ServerConfig,
}
impl<'a> Server<'a> {
/// Creates a new RPC server.
- pub fn new(listener: Box<dyn Listener>, ex: Executor<'a>) -> Arc<Self> {
+ pub fn new(listener: Box<dyn Listener>, config: ServerConfig, ex: Executor<'a>) -> Arc<Self> {
Arc::new(Self {
listener,
services: RwLock::new(HashMap::new()),
task_group: TaskGroup::new(ex),
+ config,
})
}
/// Creates a new RPC server using the provided endpoint.
- pub async fn new_with_endpoint(endpoint: &Endpoint, ex: Executor<'a>) -> Result<Arc<Self>> {
+ pub async fn new_with_endpoint(
+ endpoint: &Endpoint,
+ config: ServerConfig,
+ ex: Executor<'a>,
+ ) -> Result<Arc<Self>> {
let listener = listen(endpoint).await?;
Ok(Arc::new(Self {
listener,
services: RwLock::new(HashMap::new()),
task_group: TaskGroup::new(ex),
+ config,
}))
}
@@ -77,15 +90,17 @@ impl<'a> Server<'a> {
}
};
+ let codec = Codec::new(conn, self.config.codec_config.clone());
+
let selfc = self.clone();
self.task_group.spawn(
async move {
loop {
let mut buffer = vec![];
- read_until(&conn, &mut buffer).await?;
+ codec.read_until(&mut buffer).await?;
let response = selfc.handle_request(&buffer).await;
let payload = serde_json::to_vec(&response)?;
- write_all(&conn, &payload).await?;
+ codec.write_all(&payload).await?;
debug!("--> {response}");
}
},
diff --git a/jsonrpc/src/utils.rs b/jsonrpc/src/utils.rs
deleted file mode 100644
index 1f21b7a..0000000
--- a/jsonrpc/src/utils.rs
+++ /dev/null
@@ -1,63 +0,0 @@
-use memchr::memchr;
-
-use karyons_net::Conn;
-
-use crate::{Error, Result};
-
-const DEFAULT_MSG_SIZE: usize = 1024;
-const MAX_ALLOWED_MSG_SIZE: usize = 1024 * 1024; // 1MB
-
-// TODO: Add unit tests for these functions.
-
-/// Read all bytes into `buffer` until the `0x0` byte or EOF is
-/// reached.
-///
-/// If successful, this function will return the total number of bytes read.
-pub async fn read_until(conn: &Conn, buffer: &mut Vec<u8>) -> Result<usize> {
- let delim = b'\0';
-
- let mut read = 0;
-
- loop {
- let mut tmp_buf = [0; DEFAULT_MSG_SIZE];
- let n = conn.read(&mut tmp_buf).await?;
- if n == 0 {
- return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
- }
-
- match memchr(delim, &tmp_buf) {
- Some(i) => {
- buffer.extend_from_slice(&tmp_buf[..i]);
- read += i;
- break;
- }
- None => {
- buffer.extend_from_slice(&tmp_buf);
- read += tmp_buf.len();
- }
- }
-
- if buffer.len() == MAX_ALLOWED_MSG_SIZE {
- return Err(Error::InvalidMsg(
- "Message exceeds the maximum allowed size",
- ));
- }
- }
-
- Ok(read)
-}
-
-/// Writes an entire buffer into the given connection.
-pub async fn write_all(conn: &Conn, mut buf: &[u8]) -> Result<()> {
- while !buf.is_empty() {
- let n = conn.write(buf).await?;
- let (_, rest) = std::mem::take(&mut buf).split_at(n);
- buf = rest;
-
- if n == 0 {
- return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into()));
- }
- }
-
- Ok(())
-}