pub mod builder; pub mod channel; pub mod pubsub_service; pub mod service; use std::{collections::HashMap, sync::Arc}; use log::{debug, error, trace, warn}; use karyon_core::async_util::{select, Either, TaskGroup, TaskResult}; use karyon_net::{Conn, Endpoint, Listener}; use crate::{message, Error, PubSubRPCService, RPCService, Result}; use channel::{ArcChannel, Channel}; const CHANNEL_CAP: usize = 10; pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request"; pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse"; pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found"; pub const INTERNAL_ERROR_MSG: &str = "Internal error"; fn pack_err_res(code: i32, msg: &str, id: Option) -> message::Response { let err = message::Error { code, message: msg.to_string(), data: None, }; message::Response { jsonrpc: message::JSONRPC_VERSION.to_string(), error: Some(err), result: None, id, } } struct NewRequest { srvc_name: String, method_name: String, msg: message::Request, } enum SanityCheckResult { NewReq(NewRequest), ErrRes(message::Response), } /// Represents an RPC server pub struct Server { listener: Listener, task_group: TaskGroup, services: HashMap>, pubsub_services: HashMap>, } impl Server { /// Returns the local endpoint. pub fn local_endpoint(&self) -> Result { self.listener.local_endpoint().map_err(Error::from) } /// Starts the RPC server pub async fn start(self: &Arc) { let on_failure = |result: TaskResult>| async move { if let TaskResult::Completed(Err(err)) = result { error!("Accept loop stopped: {err}"); } }; let selfc = self.clone(); self.task_group.spawn( async move { loop { match selfc.listener.accept().await { Ok(conn) => { if let Err(err) = selfc.handle_conn(conn).await { error!("Failed to handle a new conn: {err}") } } Err(err) => { error!("Failed to accept a new conn: {err}") } } } }, on_failure, ); } /// Shuts down the RPC server pub async fn shutdown(&self) { self.task_group.cancel().await; } /// Handles a new connection async fn handle_conn(self: &Arc, conn: Conn) -> Result<()> { let endpoint = conn.peer_endpoint().expect("get peer endpoint"); debug!("Handle a new connection {endpoint}"); // TODO Avoid depending on channels let (tx, rx) = async_channel::bounded::(CHANNEL_CAP); let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_CAP); let channel = Channel::new(ch_tx); let on_failure = |result: TaskResult>| async move { if let TaskResult::Completed(Err(err)) = result { debug!("Notification loop stopped: {err}"); } }; let selfc = self.clone(); let txc = tx.clone(); self.task_group.spawn( async move { loop { let nt = ch_rx.recv().await?; let params = Some(serde_json::json!(message::NotificationResult { subscription: nt.sub_id, result: Some(nt.result), })); let response = message::Notification { jsonrpc: message::JSONRPC_VERSION.to_string(), method: nt.method, params, }; debug!("--> {response}"); txc.send(serde_json::to_value(response)?).await?; } }, on_failure, ); let on_failure = |result: TaskResult>| async move { if let TaskResult::Completed(Err(err)) = result { error!("Connection {} dropped: {}", endpoint, err); } else { warn!("Connection {} dropped", endpoint); } }; self.task_group.spawn( async move { loop { match select(conn.recv(), rx.recv()).await { Either::Left(msg) => { selfc.new_request(tx.clone(), channel.clone(), msg?).await; } Either::Right(msg) => conn.send(msg?).await?, } } }, on_failure, ); Ok(()) } fn sanity_check(&self, request: serde_json::Value) -> SanityCheckResult { let rpc_msg = match serde_json::from_value::(request) { Ok(m) => m, Err(_) => { return SanityCheckResult::ErrRes(pack_err_res( message::PARSE_ERROR_CODE, FAILED_TO_PARSE_ERROR_MSG, None, )); } }; debug!("<-- {rpc_msg}"); let srvc_method_str = rpc_msg.method.clone(); let srvc_method: Vec<&str> = srvc_method_str.split('.').collect(); if srvc_method.len() < 2 { return SanityCheckResult::ErrRes(pack_err_res( message::INVALID_REQUEST_ERROR_CODE, INVALID_REQUEST_ERROR_MSG, Some(rpc_msg.id), )); } let srvc_name = srvc_method[0].to_string(); let method_name = srvc_method[1].to_string(); SanityCheckResult::NewReq(NewRequest { srvc_name, method_name, msg: rpc_msg, }) } /// Spawns a new task for handling a new request async fn new_request( self: &Arc, sender: async_channel::Sender, channel: ArcChannel, msg: serde_json::Value, ) { trace!("--> new request {msg}"); let on_failure = |result: TaskResult>| async move { if let TaskResult::Completed(Err(err)) = result { error!("Failed to handle a request: {err}"); } }; let selfc = self.clone(); self.task_group.spawn( async move { let response = selfc.handle_request(channel, msg).await; debug!("--> {response}"); sender.send(serde_json::json!(response)).await?; Ok(()) }, on_failure, ); } /// Handles a new request async fn handle_request( &self, channel: ArcChannel, msg: serde_json::Value, ) -> message::Response { let req = match self.sanity_check(msg) { SanityCheckResult::NewReq(req) => req, SanityCheckResult::ErrRes(res) => return res, }; let mut response = message::Response { jsonrpc: message::JSONRPC_VERSION.to_string(), error: None, result: None, id: Some(req.msg.id.clone()), }; if let Some(service) = self.pubsub_services.get(&req.srvc_name) { if let Some(method) = service.get_pubsub_method(&req.method_name) { let name = format!("{}.{}", service.name(), req.method_name); let params = req.msg.params.unwrap_or(serde_json::json!(())); response.result = match method(channel, name, params).await { Ok(res) => Some(res), Err(err) => return self.handle_error(err, req.msg.id), }; return response; } } if let Some(service) = self.services.get(&req.srvc_name) { if let Some(method) = service.get_method(&req.method_name) { let params = req.msg.params.unwrap_or(serde_json::json!(())); response.result = match method(params).await { Ok(res) => Some(res), Err(err) => return self.handle_error(err, req.msg.id), }; return response; } } pack_err_res( message::METHOD_NOT_FOUND_ERROR_CODE, METHOD_NOT_FOUND_ERROR_MSG, Some(req.msg.id), ) } fn handle_error(&self, err: Error, msg_id: serde_json::Value) -> message::Response { match err { Error::ParseJSON(_) => pack_err_res( message::PARSE_ERROR_CODE, FAILED_TO_PARSE_ERROR_MSG, Some(msg_id), ), Error::InvalidParams(msg) => { pack_err_res(message::INVALID_PARAMS_ERROR_CODE, msg, Some(msg_id)) } Error::InvalidRequest(msg) => { pack_err_res(message::INVALID_REQUEST_ERROR_CODE, msg, Some(msg_id)) } Error::RPCMethodError(code, msg) => pack_err_res(code, msg, Some(msg_id)), _ => pack_err_res( message::INTERNAL_ERROR_CODE, INTERNAL_ERROR_MSG, Some(msg_id), ), } } }