From 4d51e3211740764764a6423f8ead4944e1790341 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Sun, 19 Nov 2023 22:19:06 +0300 Subject: karyons jsonrpc implementation --- Cargo.toml | 2 + jsonrpc/Cargo.toml | 27 +++++ jsonrpc/README.md | 50 +++++++++ jsonrpc/examples/client.py | 59 ++++++++++ jsonrpc/examples/client.rs | 34 ++++++ jsonrpc/examples/server.rs | 59 ++++++++++ jsonrpc/src/client.rs | 77 +++++++++++++ jsonrpc/src/error.rs | 34 ++++++ jsonrpc/src/lib.rs | 61 +++++++++++ jsonrpc/src/message.rs | 91 ++++++++++++++++ jsonrpc/src/server.rs | 263 +++++++++++++++++++++++++++++++++++++++++++++ jsonrpc/src/utils.rs | 63 +++++++++++ 12 files changed, 820 insertions(+) create mode 100644 jsonrpc/Cargo.toml create mode 100644 jsonrpc/README.md create mode 100644 jsonrpc/examples/client.py create mode 100644 jsonrpc/examples/client.rs create mode 100644 jsonrpc/examples/server.rs create mode 100644 jsonrpc/src/client.rs create mode 100644 jsonrpc/src/error.rs create mode 100644 jsonrpc/src/lib.rs create mode 100644 jsonrpc/src/message.rs create mode 100644 jsonrpc/src/server.rs create mode 100644 jsonrpc/src/utils.rs diff --git a/Cargo.toml b/Cargo.toml index aa5ba29..4d85f3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "core", "net", "p2p", + "jsonrpc", ] resolver = "2" @@ -15,3 +16,4 @@ edition = "2021" karyons_core = { path = "core" } karyons_net = { path = "net" } karyons_p2p = { path = "p2p" } +karyons_jsonrpc = { path = "jsonrpc" } diff --git a/jsonrpc/Cargo.toml b/jsonrpc/Cargo.toml new file mode 100644 index 0000000..12bea1a --- /dev/null +++ b/jsonrpc/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "karyons_jsonrpc" +version.workspace = true +edition.workspace = true + +[dependencies] +karyons_core.workspace = true +karyons_net.workspace = true + +smol = "1.3.0" +log = "0.4.20" +rand = "0.8.5" +serde = { version = "1.0.192", features = ["derive"] } +serde_json = "1.0.108" +thiserror = "1.0.50" +memchr = "2.6.4" + +[[example]] +name = "server" +path = "examples/server.rs" + +[[example]] +name = "client" +path = "examples/client.rs" + +[dev-dependencies] +env_logger = "0.10.0" diff --git a/jsonrpc/README.md b/jsonrpc/README.md new file mode 100644 index 0000000..82ab2e9 --- /dev/null +++ b/jsonrpc/README.md @@ -0,0 +1,50 @@ +# karyons jsonrpc + +A fast and lightweight async [JSONRPC2.0](https://www.jsonrpc.org/specification) implementation. + +## Example + +```rust +use std::sync::Arc; + +use serde_json::Value; + +use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service}; + +struct HelloWorld {} + +impl HelloWorld { + async fn say_hello(&self, params: Value) -> Result { + let msg: String = serde_json::from_value(params)?; + Ok(serde_json::json!(format!("Hello {msg}!"))) + } +} + +////////////////// +// Server +////////////////// +let ex = Arc::new(smol::Executor::new()); + +// Creates a new server +let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); +let server = Server::new_with_endpoint(&endpoint, ex.clone()).await.unwrap(); + +// Register the HelloWorld service +register_service!(HelloWorld, say_hello); +server.attach_service(HelloWorld{}); + +// Starts the server +ex.run(server.start()); + + +////////////////// +// Client +////////////////// + +// Creates a new client +let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); +let client = Client::new_with_endpoint(&endpoint, None).await.unwrap(); + +let result: String = client.call("HelloWorld.say_hello", "world".to_string()).await.unwrap(); + +``` diff --git a/jsonrpc/examples/client.py b/jsonrpc/examples/client.py new file mode 100644 index 0000000..14b3cf9 --- /dev/null +++ b/jsonrpc/examples/client.py @@ -0,0 +1,59 @@ +import socket +import random +import json + +HOST = "127.0.0.1" +PORT = 60000 + +s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +s.connect((HOST, PORT)) + +req = { + "jsonrpc": "2.0", + "id": str(random.randint(0, 1000)), + "method": "Calc.add", + "params": {"x": 4, "y": 3}, +} +print("Send: ", req) +s.sendall(json.dumps(req).encode()) +res = s.recv(1024) +res = json.loads(res) +print("Received: ", res) + +req = { + "jsonrpc": "2.0", + "id": str(random.randint(0, 1000)), + "method": "Calc.sub", + "params": {"x": 4, "y": 3}, +} +print("Send: ", req) +s.sendall(json.dumps(req).encode()) +res = s.recv(1024) +res = json.loads(res) +print("Received: ", res) + +req = { + "jsonrpc": "2.0", + "id": str(random.randint(0, 1000)), + "method": "Calc.ping", + "params": None, +} +print("Send: ", req) +s.sendall(json.dumps(req).encode()) +res = s.recv(1024) +res = json.loads(res) +print("Received: ", res) + +req = { + "jsonrpc": "2.0", + "id": str(random.randint(0, 1000)), + "method": "Calc.version", + "params": None, +} +print("Send: ", req) +s.sendall(json.dumps(req).encode()) +res = s.recv(1024) +res = json.loads(res) +print("Received: ", res) + +s.close() diff --git a/jsonrpc/examples/client.rs b/jsonrpc/examples/client.rs new file mode 100644 index 0000000..0063098 --- /dev/null +++ b/jsonrpc/examples/client.rs @@ -0,0 +1,34 @@ +use serde::{Deserialize, Serialize}; + +use karyons_jsonrpc::Client; + +#[derive(Deserialize, Serialize)] +struct Req { + x: u32, + y: u32, +} + +#[derive(Deserialize, Serialize, Debug)] +struct Pong {} + +fn main() { + env_logger::init(); + smol::future::block_on(async { + let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); + let client = Client::new_with_endpoint(&endpoint, None).await.unwrap(); + + let params = Req { x: 10, y: 7 }; + let result: u32 = client.call("Calc.add", params).await.unwrap(); + println!("result {result}"); + + let params = Req { x: 10, y: 7 }; + let result: u32 = client.call("Calc.sub", params).await.unwrap(); + println!("result {result}"); + + let result: Pong = client.call("Calc.ping", ()).await.unwrap(); + println!("result {:?}", result); + + let result: String = client.call("Calc.version", ()).await.unwrap(); + println!("result {result}"); + }); +} diff --git a/jsonrpc/examples/server.rs b/jsonrpc/examples/server.rs new file mode 100644 index 0000000..367bfe9 --- /dev/null +++ b/jsonrpc/examples/server.rs @@ -0,0 +1,59 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use karyons_jsonrpc::{register_service, JsonRPCError, Server}; + +struct Calc { + version: String, +} + +#[derive(Deserialize, Serialize)] +struct Req { + x: u32, + y: u32, +} + +#[derive(Deserialize, Serialize)] +struct Pong {} + +impl Calc { + async fn ping(&self, _params: Value) -> Result { + Ok(serde_json::json!(Pong {})) + } + + async fn add(&self, params: Value) -> Result { + let params: Req = serde_json::from_value(params)?; + Ok(serde_json::json!(params.x + params.y)) + } + + async fn sub(&self, params: Value) -> Result { + let params: Req = serde_json::from_value(params)?; + Ok(serde_json::json!(params.x - params.y)) + } + + async fn version(&self, _params: Value) -> Result { + Ok(serde_json::json!(self.version)) + } +} + +fn main() { + env_logger::init(); + let ex = Arc::new(smol::Executor::new()); + smol::block_on(ex.clone().run(async { + // Creates a new server + let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); + let server = Server::new_with_endpoint(&endpoint, ex).await.unwrap(); + + // Register the Calc service + register_service!(Calc, ping, add, sub, version); + let calc = Calc { + version: String::from("0.1"), + }; + server.attach_service(calc).await; + + // Start the server + server.start().await.unwrap(); + })); +} diff --git a/jsonrpc/src/client.rs b/jsonrpc/src/client.rs new file mode 100644 index 0000000..2863204 --- /dev/null +++ b/jsonrpc/src/client.rs @@ -0,0 +1,77 @@ +use log::debug; +use serde::{de::DeserializeOwned, Serialize}; + +use karyons_core::{async_utils::timeout, utils::random_32}; +use karyons_net::{dial, Conn, Endpoint}; + +use crate::{ + message, + utils::{read_until, write_all}, + Error, Result, JSONRPC_VERSION, +}; + +/// Represents an RPC client +pub struct Client { + conn: Conn, + timeout: Option, +} + +impl Client { + /// Creates a new RPC client. + pub fn new(conn: Conn, timeout: Option) -> Self { + Self { conn, timeout } + } + + /// Creates a new RPC client using the provided endpoint. + pub async fn new_with_endpoint(endpoint: &Endpoint, timeout: Option) -> Result { + let conn = dial(endpoint).await?; + Ok(Self { conn, timeout }) + } + + /// Calls the named method, waits for the response, and returns the result. + pub async fn call( + &self, + method: &str, + params: T, + ) -> Result { + let id = serde_json::json!(random_32()); + + let request = message::Request { + jsonrpc: JSONRPC_VERSION.to_string(), + id, + method: method.to_string(), + params: serde_json::json!(params), + }; + + let payload = serde_json::to_vec(&request)?; + write_all(&self.conn, &payload).await?; + debug!("--> {request}"); + + let mut buffer = vec![]; + if let Some(t) = self.timeout { + timeout( + std::time::Duration::from_secs(t), + read_until(&self.conn, &mut buffer), + ) + .await? + } else { + read_until(&self.conn, &mut buffer).await + }?; + + let response = serde_json::from_slice::(&buffer)?; + debug!("<-- {response}"); + + if let Some(error) = response.error { + return Err(Error::CallError(error.code, error.message)); + } + + if response.id.is_none() || response.id.unwrap() != request.id { + return Err(Error::InvalidMsg("Invalid response id")); + } + + match response.result { + Some(result) => Ok(serde_json::from_value::(result)?), + None => Err(Error::InvalidMsg("Invalid response result")), + } + } +} diff --git a/jsonrpc/src/error.rs b/jsonrpc/src/error.rs new file mode 100644 index 0000000..5437c6d --- /dev/null +++ b/jsonrpc/src/error.rs @@ -0,0 +1,34 @@ +use thiserror::Error as ThisError; + +pub type Result = std::result::Result; + +/// Represents karyons's jsonrpc Error. +#[derive(ThisError, Debug)] +pub enum Error { + #[error(transparent)] + IO(#[from] std::io::Error), + + #[error("Call Error: code: {0} msg: {1}")] + CallError(i32, String), + + #[error("RPC Method Error: code: {0} msg: {1}")] + RPCMethodError(i32, &'static str), + + #[error("Invalid Params: {0}")] + InvalidParams(&'static str), + + #[error("Invalid Request: {0}")] + InvalidRequest(&'static str), + + #[error(transparent)] + ParseJSON(#[from] serde_json::Error), + + #[error("Invalid Message Error: {0}")] + InvalidMsg(&'static str), + + #[error(transparent)] + KaryonsCore(#[from] karyons_core::error::Error), + + #[error(transparent)] + KaryonsNet(#[from] karyons_net::NetError), +} diff --git a/jsonrpc/src/lib.rs b/jsonrpc/src/lib.rs new file mode 100644 index 0000000..8a547d9 --- /dev/null +++ b/jsonrpc/src/lib.rs @@ -0,0 +1,61 @@ +//! A fast and lightweight async [JSONRPC 2.0](https://www.jsonrpc.org/specification) implementation. +//! +//! # Example +//! +//! ``` +//! use std::sync::Arc; +//! +//! use serde_json::Value; +//! +//! use karyons_jsonrpc::{JsonRPCError, Server, Client, register_service}; +//! +//! struct HelloWorld {} +//! +//! impl HelloWorld { +//! async fn say_hello(&self, params: Value) -> Result { +//! let msg: String = serde_json::from_value(params)?; +//! Ok(serde_json::json!(format!("Hello {msg}!"))) +//! } +//! } +//! +//! // Server +//! async { +//! let ex = Arc::new(smol::Executor::new()); +//! +//! // Creates a new server +//! let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); +//! let server = Server::new_with_endpoint(&endpoint, ex.clone()).await.unwrap(); +//! +//! // Register the HelloWorld service +//! register_service!(HelloWorld, say_hello); +//! server.attach_service(HelloWorld{}); +//! +//! // Starts the server +//! ex.run(server.start()); +//! }; +//! +//! // Client +//! async { +//! +//! // Creates a new client +//! let endpoint = "tcp://127.0.0.1:60000".parse().unwrap(); +//! let client = Client::new_with_endpoint(&endpoint, None).await.unwrap(); +//! +//! let result: String = client.call("HelloWorld.say_hello", "world".to_string()).await.unwrap(); +//! }; +//! +//! ``` + +mod client; +mod error; +pub mod message; +mod server; +mod utils; + +pub const JSONRPC_VERSION: &str = "2.0"; + +use error::{Error, Result}; + +pub use client::Client; +pub use error::Error as JsonRPCError; +pub use server::{RPCMethod, RPCService, Server}; diff --git a/jsonrpc/src/message.rs b/jsonrpc/src/message.rs new file mode 100644 index 0000000..89ef613 --- /dev/null +++ b/jsonrpc/src/message.rs @@ -0,0 +1,91 @@ +use serde::{Deserialize, Serialize}; + +/// Parse error: Invalid JSON was received by the server. +pub const PARSE_ERROR_CODE: i32 = -32700; + +/// Invalid request: The JSON sent is not a valid Request object. +pub const INVALID_REQUEST_ERROR_CODE: i32 = -32600; + +/// Method not found: The method does not exist / is not available. +pub const METHOD_NOT_FOUND_ERROR_CODE: i32 = -32601; + +/// Invalid params: Invalid method parameter(s). +pub const INVALID_PARAMS_ERROR_CODE: i32 = -32602; + +/// Internal error: Internal JSON-RPC error. +pub const INTERNAL_ERROR_CODE: i32 = -32603; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + pub jsonrpc: String, + pub method: String, + pub params: serde_json::Value, + pub id: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Response { + pub jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Error { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Notification { + pub jsonrpc: String, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +impl std::fmt::Display for Request { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{jsonrpc: {}, method: {}, params: {:?}, id: {:?}}}", + self.jsonrpc, self.method, self.params, self.id + ) + } +} + +impl std::fmt::Display for Response { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{jsonrpc: {}, result': {:?}, error: {:?} , id: {:?}}}", + self.jsonrpc, self.result, self.error, self.id + ) + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "RpcError {{ code: {}, message: {}, data: {:?} }} ", + self.code, self.message, self.data + ) + } +} + +impl std::fmt::Display for Notification { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{jsonrpc: {}, method: {}, params: {:?}}}", + self.jsonrpc, self.method, self.params + ) + } +} diff --git a/jsonrpc/src/server.rs b/jsonrpc/src/server.rs new file mode 100644 index 0000000..9642381 --- /dev/null +++ b/jsonrpc/src/server.rs @@ -0,0 +1,263 @@ +use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; + +use log::{debug, error, warn}; +use smol::lock::RwLock; + +use karyons_core::{ + async_utils::{TaskGroup, TaskResult}, + Executor, +}; +use karyons_net::{listen, Conn, Endpoint, Listener}; + +use crate::{ + message, + utils::{read_until, write_all}, + Error, Result, JSONRPC_VERSION, +}; + +/// Represents an RPC server +pub struct Server<'a> { + listener: Box, + services: RwLock>>, + task_group: TaskGroup<'a>, +} + +impl<'a> Server<'a> { + /// Creates a new RPC server. + pub fn new(listener: Box, ex: Executor<'a>) -> Arc { + Arc::new(Self { + listener, + services: RwLock::new(HashMap::new()), + task_group: TaskGroup::new(ex), + }) + } + + /// Creates a new RPC server using the provided endpoint. + pub async fn new_with_endpoint(endpoint: &Endpoint, ex: Executor<'a>) -> Result> { + let listener = listen(endpoint).await?; + Ok(Arc::new(Self { + listener, + services: RwLock::new(HashMap::new()), + task_group: TaskGroup::new(ex), + })) + } + + /// Starts the RPC server + pub async fn start(self: Arc) -> Result<()> { + loop { + let conn = self.listener.accept().await?; + self.handle_conn(conn).await?; + } + } + + /// Attach a new service to the RPC server + pub async fn attach_service(&self, service: impl RPCService + 'a) { + self.services + .write() + .await + .insert(service.name(), Box::new(service)); + } + + /// Shuts down the RPC server + pub async fn shutdown(&self) { + self.task_group.cancel().await; + } + + /// Handles a new connection + async fn handle_conn(self: &Arc, conn: Conn) -> Result<()> { + let endpoint = conn.peer_endpoint()?; + debug!("Handle a new connection {endpoint}"); + + let on_failure = |result: TaskResult>| async move { + if let TaskResult::Completed(Err(err)) = result { + error!("Connection {} dropped: {}", endpoint, err); + } else { + warn!("Connection {} dropped", endpoint); + } + }; + + let selfc = self.clone(); + self.task_group.spawn( + async move { + loop { + let mut buffer = vec![]; + read_until(&conn, &mut buffer).await?; + let response = selfc.handle_request(&buffer).await; + let payload = serde_json::to_vec(&response)?; + write_all(&conn, &payload).await?; + debug!("--> {response}"); + } + }, + on_failure, + ); + + Ok(()) + } + + /// Handles a request + async fn handle_request(&self, buffer: &[u8]) -> message::Response { + let rpc_msg = match serde_json::from_slice::(buffer) { + Ok(m) => m, + Err(_) => { + return self.pack_err_res(message::PARSE_ERROR_CODE, "Failed to parse", None); + } + }; + + debug!("<-- {rpc_msg}"); + + let srvc_method: Vec<&str> = rpc_msg.method.split('.').collect(); + if srvc_method.len() != 2 { + return self.pack_err_res( + message::INVALID_REQUEST_ERROR_CODE, + "Invalid request", + Some(rpc_msg.id), + ); + } + + let srvc_name = srvc_method[0]; + let method_name = srvc_method[1]; + + let services = self.services.read().await; + + let service = match services.get(srvc_name) { + Some(s) => s, + None => { + return self.pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + "Method not found", + Some(rpc_msg.id), + ); + } + }; + + let method = match service.get_method(method_name) { + Some(m) => m, + None => { + return self.pack_err_res( + message::METHOD_NOT_FOUND_ERROR_CODE, + "Method not found", + Some(rpc_msg.id), + ); + } + }; + + let result = match method(rpc_msg.params.clone()).await { + Ok(res) => res, + Err(Error::ParseJSON(_)) => { + return self.pack_err_res( + message::PARSE_ERROR_CODE, + "Failed to parse", + Some(rpc_msg.id), + ); + } + Err(Error::InvalidParams(msg)) => { + return self.pack_err_res( + message::INVALID_PARAMS_ERROR_CODE, + msg, + Some(rpc_msg.id), + ); + } + Err(Error::InvalidRequest(msg)) => { + return self.pack_err_res( + message::INVALID_REQUEST_ERROR_CODE, + msg, + Some(rpc_msg.id), + ); + } + Err(Error::RPCMethodError(code, msg)) => { + return self.pack_err_res(code, msg, Some(rpc_msg.id)); + } + Err(_) => { + return self.pack_err_res( + message::INTERNAL_ERROR_CODE, + "Internal error", + Some(rpc_msg.id), + ); + } + }; + + message::Response { + jsonrpc: JSONRPC_VERSION.to_string(), + error: None, + result: Some(result), + id: Some(rpc_msg.id), + } + } + + fn pack_err_res( + &self, + code: i32, + msg: &str, + id: Option, + ) -> message::Response { + let err = message::Error { + code, + message: msg.to_string(), + data: None, + }; + + message::Response { + jsonrpc: JSONRPC_VERSION.to_string(), + error: Some(err), + result: None, + id, + } + } +} + +/// Represents the RPC method +pub type RPCMethod<'a> = Box RPCMethodOutput<'a> + Send + 'a>; +type RPCMethodOutput<'a> = + Pin> + Send + Sync + 'a>>; + +/// Defines the interface for an RPC service. +pub trait RPCService: Sync + Send { + fn get_method<'a>(&'a self, name: &'a str) -> Option; + fn name(&self) -> String; +} + +/// Implements the `RPCService` trait for a provided type. +/// +/// # Example +/// +/// ``` +/// use serde_json::Value; +/// +/// use karyons_jsonrpc::{JsonRPCError, register_service}; +/// +/// struct Hello {} +/// +/// impl Hello { +/// async fn say_hello(&self, params: Value) -> Result { +/// Ok(serde_json::json!("hello!")) +/// } +/// } +/// +/// register_service!(Hello, say_hello); +/// +/// ``` +#[macro_export] +macro_rules! register_service { + ($t:ty, $($m:ident),*) => { + impl karyons_jsonrpc::RPCService for $t { + fn get_method<'a>( + &'a self, + name: &'a str + ) -> Option { + match name { + $( + stringify!($m) => { + Some(Box::new(move |params: serde_json::Value| Box::pin(self.$m(params)))) + } + )* + _ => None, + } + + + } + fn name(&self) -> String{ + stringify!($t).to_string() + } + } + }; +} diff --git a/jsonrpc/src/utils.rs b/jsonrpc/src/utils.rs new file mode 100644 index 0000000..1f21b7a --- /dev/null +++ b/jsonrpc/src/utils.rs @@ -0,0 +1,63 @@ +use memchr::memchr; + +use karyons_net::Conn; + +use crate::{Error, Result}; + +const DEFAULT_MSG_SIZE: usize = 1024; +const MAX_ALLOWED_MSG_SIZE: usize = 1024 * 1024; // 1MB + +// TODO: Add unit tests for these functions. + +/// Read all bytes into `buffer` until the `0x0` byte or EOF is +/// reached. +/// +/// If successful, this function will return the total number of bytes read. +pub async fn read_until(conn: &Conn, buffer: &mut Vec) -> Result { + let delim = b'\0'; + + let mut read = 0; + + loop { + let mut tmp_buf = [0; DEFAULT_MSG_SIZE]; + let n = conn.read(&mut tmp_buf).await?; + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + + match memchr(delim, &tmp_buf) { + Some(i) => { + buffer.extend_from_slice(&tmp_buf[..i]); + read += i; + break; + } + None => { + buffer.extend_from_slice(&tmp_buf); + read += tmp_buf.len(); + } + } + + if buffer.len() == MAX_ALLOWED_MSG_SIZE { + return Err(Error::InvalidMsg( + "Message exceeds the maximum allowed size", + )); + } + } + + Ok(read) +} + +/// Writes an entire buffer into the given connection. +pub async fn write_all(conn: &Conn, mut buf: &[u8]) -> Result<()> { + while !buf.is_empty() { + let n = conn.write(buf).await?; + let (_, rest) = std::mem::take(&mut buf).split_at(n); + buf = rest; + + if n == 0 { + return Err(Error::IO(std::io::ErrorKind::UnexpectedEof.into())); + } + } + + Ok(()) +} -- cgit v1.2.3