aboutsummaryrefslogtreecommitdiff
path: root/jsonrpc/src/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'jsonrpc/src/server.rs')
-rw-r--r--jsonrpc/src/server.rs193
1 files changed, 134 insertions, 59 deletions
diff --git a/jsonrpc/src/server.rs b/jsonrpc/src/server.rs
index 26d632a..1cc7e1f 100644
--- a/jsonrpc/src/server.rs
+++ b/jsonrpc/src/server.rs
@@ -1,17 +1,20 @@
use std::{collections::HashMap, sync::Arc};
use log::{debug, error, warn};
-use smol::lock::RwLock;
-use karyon_core::async_util::{Executor, TaskGroup, TaskResult};
+#[cfg(feature = "smol")]
+use futures_rustls::rustls;
+#[cfg(feature = "tokio")]
+use tokio_rustls::rustls;
-use karyon_net::{Conn, Listener, ToListener};
+use karyon_core::async_runtime::Executor;
+use karyon_core::async_util::{TaskGroup, TaskResult};
+
+use karyon_net::{Conn, Endpoint, Listener, ToEndpoint};
use crate::{
- codec::{Codec, CodecConfig},
- message,
- service::RPCService,
- Endpoint, Error, Result, JSONRPC_VERSION,
+ codec::{JsonCodec, WsJsonCodec},
+ message, Error, RPCService, Result,
};
pub const INVALID_REQUEST_ERROR_MSG: &str = "Invalid request";
@@ -27,69 +30,50 @@ fn pack_err_res(code: i32, msg: &str, id: Option<serde_json::Value>) -> message:
};
message::Response {
- jsonrpc: JSONRPC_VERSION.to_string(),
+ jsonrpc: message::JSONRPC_VERSION.to_string(),
error: Some(err),
result: None,
id,
}
}
-/// RPC server config
-#[derive(Default)]
-pub struct ServerConfig {
- codec_config: CodecConfig,
-}
-
/// Represents an RPC server
-pub struct Server<'a> {
- listener: Listener,
- services: RwLock<HashMap<String, Box<dyn RPCService + 'a>>>,
- task_group: TaskGroup<'a>,
- config: ServerConfig,
+pub struct Server {
+ listener: Listener<serde_json::Value>,
+ task_group: TaskGroup,
+ services: HashMap<String, Box<dyn RPCService + 'static>>,
}
-impl<'a> Server<'a> {
- /// Creates a new RPC server by passing a listener. It supports Tcp, Unix, and Tls.
- pub fn new<T: ToListener>(listener: T, config: ServerConfig, ex: Executor<'a>) -> Arc<Self> {
- Arc::new(Self {
- listener: listener.to_listener(),
- services: RwLock::new(HashMap::new()),
- task_group: TaskGroup::with_executor(ex),
- config,
- })
- }
-
+impl Server {
/// Returns the local endpoint.
pub fn local_endpoint(&self) -> Result<Endpoint> {
- self.listener.local_endpoint().map_err(Error::KaryonNet)
+ self.listener.local_endpoint().map_err(Error::from)
}
/// Starts the RPC server
pub async fn start(self: Arc<Self>) -> Result<()> {
loop {
- let conn = self.listener.accept().await?;
- if let Err(err) = self.handle_conn(conn).await {
- error!("Failed to handle a new conn: {err}")
+ match self.listener.accept().await {
+ Ok(conn) => {
+ if let Err(err) = self.handle_conn(conn).await {
+ error!("Failed to handle a new conn: {err}")
+ }
+ }
+ Err(err) => {
+ error!("Failed to accept a new conn: {err}")
+ }
}
}
}
- /// Attach a new service to the RPC server
- pub async fn attach_service(&self, service: impl RPCService + 'a) {
- self.services
- .write()
- .await
- .insert(service.name(), Box::new(service));
- }
-
/// Shuts down the RPC server
pub async fn shutdown(&self) {
self.task_group.cancel().await;
}
/// Handles a new connection
- async fn handle_conn(self: &Arc<Self>, conn: Conn) -> Result<()> {
- let endpoint = conn.peer_endpoint()?;
+ async fn handle_conn(self: &Arc<Self>, conn: Conn<serde_json::Value>) -> Result<()> {
+ let endpoint = conn.peer_endpoint().expect("get peer endpoint");
debug!("Handle a new connection {endpoint}");
let on_failure = |result: TaskResult<Result<()>>| async move {
@@ -100,19 +84,15 @@ impl<'a> Server<'a> {
}
};
- let codec = Codec::new(conn, self.config.codec_config.clone());
-
let selfc = self.clone();
self.task_group.spawn(
async move {
loop {
- let mut buffer = vec![];
- codec.read_until(&mut buffer).await?;
- let response = selfc.handle_request(&buffer).await;
- let mut payload = serde_json::to_vec(&response)?;
- payload.push(b'\n');
- codec.write_all(&payload).await?;
+ let msg = conn.recv().await?;
+ let response = selfc.handle_request(msg).await;
+ let response = serde_json::to_value(response)?;
debug!("--> {response}");
+ conn.send(response).await?;
}
},
on_failure,
@@ -122,14 +102,13 @@ impl<'a> Server<'a> {
}
/// Handles a request
- async fn handle_request(&self, buffer: &[u8]) -> message::Response {
- let rpc_msg = match serde_json::from_slice::<message::Request>(buffer) {
+ async fn handle_request(&self, msg: serde_json::Value) -> message::Response {
+ let rpc_msg = match serde_json::from_value::<message::Request>(msg) {
Ok(m) => m,
Err(_) => {
return pack_err_res(message::PARSE_ERROR_CODE, FAILED_TO_PARSE_ERROR_MSG, None);
}
};
-
debug!("<-- {rpc_msg}");
let srvc_method: Vec<&str> = rpc_msg.method.split('.').collect();
@@ -144,9 +123,7 @@ impl<'a> Server<'a> {
let srvc_name = srvc_method[0];
let method_name = srvc_method[1];
- let services = self.services.read().await;
-
- let service = match services.get(srvc_name) {
+ let service = match self.services.get(srvc_name) {
Some(s) => s,
None => {
return pack_err_res(
@@ -196,10 +173,108 @@ impl<'a> Server<'a> {
};
message::Response {
- jsonrpc: JSONRPC_VERSION.to_string(),
+ jsonrpc: message::JSONRPC_VERSION.to_string(),
error: None,
result: Some(result),
id: Some(rpc_msg.id),
}
}
}
+
+pub struct ServerBuilder {
+ endpoint: Endpoint,
+ tls_config: Option<rustls::ServerConfig>,
+ services: HashMap<String, Box<dyn RPCService + 'static>>,
+}
+
+impl ServerBuilder {
+ pub fn service(mut self, service: impl RPCService + 'static) -> Self {
+ self.services.insert(service.name(), Box::new(service));
+ self
+ }
+
+ pub fn tls_config(mut self, config: rustls::ServerConfig) -> Result<ServerBuilder> {
+ match self.endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) | Endpoint::Ws(..) | Endpoint::Wss(..) => {
+ self.tls_config = Some(config);
+ Ok(self)
+ }
+ _ => Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+ }
+ }
+
+ pub async fn build(self) -> Result<Arc<Server>> {
+ self._build(TaskGroup::new()).await
+ }
+
+ pub async fn build_with_executor(self, ex: Executor) -> Result<Arc<Server>> {
+ self._build(TaskGroup::with_executor(ex)).await
+ }
+
+ async fn _build(self, task_group: TaskGroup) -> Result<Arc<Server>> {
+ let listener: Listener<serde_json::Value> = match self.endpoint {
+ Endpoint::Tcp(..) | Endpoint::Tls(..) => match &self.tls_config {
+ Some(conf) => Box::new(
+ karyon_net::tls::listen(
+ &self.endpoint,
+ karyon_net::tls::ServerTlsConfig {
+ server_config: conf.clone(),
+ tcp_config: Default::default(),
+ },
+ JsonCodec {},
+ )
+ .await?,
+ ),
+ None => Box::new(
+ karyon_net::tcp::listen(&self.endpoint, Default::default(), JsonCodec {})
+ .await?,
+ ),
+ },
+ Endpoint::Ws(..) | Endpoint::Wss(..) => match &self.tls_config {
+ Some(conf) => Box::new(
+ karyon_net::ws::listen(
+ &self.endpoint,
+ karyon_net::ws::ServerWsConfig {
+ tcp_config: Default::default(),
+ wss_config: Some(karyon_net::ws::ServerWssConfig {
+ server_config: conf.clone(),
+ }),
+ },
+ WsJsonCodec {},
+ )
+ .await?,
+ ),
+ None => Box::new(
+ karyon_net::ws::listen(&self.endpoint, Default::default(), WsJsonCodec {})
+ .await?,
+ ),
+ },
+ Endpoint::Unix(..) => Box::new(karyon_net::unix::listen(
+ &self.endpoint,
+ Default::default(),
+ JsonCodec {},
+ )?),
+
+ _ => return Err(Error::UnsupportedProtocol(self.endpoint.to_string())),
+ };
+
+ Ok(Arc::new(Server {
+ listener,
+ task_group,
+ services: self.services,
+ }))
+ }
+}
+
+impl ServerBuilder {}
+
+impl Server {
+ pub fn builder(endpoint: impl ToEndpoint) -> Result<ServerBuilder> {
+ let endpoint = endpoint.to_endpoint()?;
+ Ok(ServerBuilder {
+ endpoint,
+ services: HashMap::new(),
+ tls_config: None,
+ })
+ }
+}