|
|
@@ -1,17 +1,55 @@
|
|
|
use std::collections::hash_map::Entry::Vacant;
|
|
|
+use std::future::Future;
|
|
|
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
|
|
|
+use std::pin::Pin;
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
+use futures::FutureExt;
|
|
|
use parking_lot::Mutex;
|
|
|
|
|
|
#[cfg(feature = "tracing")]
|
|
|
use crate::tracing::TraceHolder;
|
|
|
use crate::{ReplyMessage, RequestMessage, Result, ServerIdentifier};
|
|
|
|
|
|
-pub type RpcHandler = dyn Fn(RequestMessage) -> ReplyMessage;
|
|
|
+pub trait RpcHandler:
|
|
|
+ (Fn(RequestMessage) -> ReplyMessage) + Send + Sync + 'static
|
|
|
+{
|
|
|
+}
|
|
|
+
|
|
|
+impl<T> RpcHandler for T where
|
|
|
+ T: (Fn(RequestMessage) -> ReplyMessage) + Send + Sync + 'static
|
|
|
+{
|
|
|
+}
|
|
|
+
|
|
|
+pub trait AsyncRpcHandler:
|
|
|
+ (Fn(
|
|
|
+ RequestMessage,
|
|
|
+ ) -> Pin<Box<dyn Future<Output = ReplyMessage> + Send + 'static>>)
|
|
|
+ + Send
|
|
|
+ + Sync
|
|
|
+ + 'static
|
|
|
+{
|
|
|
+}
|
|
|
+
|
|
|
+impl<T> AsyncRpcHandler for T where
|
|
|
+ T: (Fn(
|
|
|
+ RequestMessage,
|
|
|
+ )
|
|
|
+ -> Pin<Box<dyn Future<Output = ReplyMessage> + Send + 'static>>)
|
|
|
+ + Send
|
|
|
+ + Sync
|
|
|
+ + 'static
|
|
|
+{
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Clone)]
|
|
|
+enum RpcHandlerType {
|
|
|
+ RpcHandler(Arc<dyn RpcHandler>),
|
|
|
+ AsyncRpcHandler(Arc<dyn AsyncRpcHandler>),
|
|
|
+}
|
|
|
|
|
|
struct ServerState {
|
|
|
- rpc_handlers: std::collections::HashMap<String, Arc<RpcHandler>>,
|
|
|
+ rpc_handlers: std::collections::HashMap<String, RpcHandlerType>,
|
|
|
rpc_count: usize,
|
|
|
}
|
|
|
|
|
|
@@ -22,12 +60,6 @@ pub struct Server {
|
|
|
interrupt: tokio::sync::Notify,
|
|
|
}
|
|
|
|
|
|
-impl Unpin for Server {}
|
|
|
-// Server contains a immutable name, a mutex-protected state, and a thread pool.
|
|
|
-// All of those 3 are `Send` and `Sync`.
|
|
|
-unsafe impl Send for Server {}
|
|
|
-unsafe impl Sync for Server {}
|
|
|
-
|
|
|
impl Server {
|
|
|
pub(crate) async fn dispatch(
|
|
|
self: Arc<Self>,
|
|
|
@@ -39,7 +71,7 @@ impl Server {
|
|
|
mark_trace!(trace, before_server_scheduling);
|
|
|
#[cfg(feature = "tracing")]
|
|
|
let trace_clone = trace.clone();
|
|
|
- let runner = move || {
|
|
|
+ let runner = async move {
|
|
|
let rpc_handler = {
|
|
|
// Blocking on a mutex in a thread pool. Sounds horrible, but
|
|
|
// in fact quite safe, given that the critical section is short.
|
|
|
@@ -49,9 +81,16 @@ impl Server {
|
|
|
};
|
|
|
mark_trace!(trace_clone, before_handling);
|
|
|
let response = match rpc_handler {
|
|
|
- Some(rpc_handler) => {
|
|
|
- Ok(catch_unwind(AssertUnwindSafe(|| rpc_handler(data))))
|
|
|
- }
|
|
|
+ Some(rpc_handler) => match rpc_handler {
|
|
|
+ RpcHandlerType::RpcHandler(rpc_handler) => {
|
|
|
+ Ok(catch_unwind(AssertUnwindSafe(|| rpc_handler(data))))
|
|
|
+ }
|
|
|
+ RpcHandlerType::AsyncRpcHandler(rpc_handler) => {
|
|
|
+ Ok(AssertUnwindSafe(rpc_handler(data))
|
|
|
+ .catch_unwind()
|
|
|
+ .await)
|
|
|
+ }
|
|
|
+ },
|
|
|
None => Err(std::io::Error::new(
|
|
|
std::io::ErrorKind::InvalidInput,
|
|
|
format!(
|
|
|
@@ -61,18 +100,18 @@ impl Server {
|
|
|
)),
|
|
|
};
|
|
|
mark_trace!(trace_clone, after_handling);
|
|
|
- return match response {
|
|
|
+ match response {
|
|
|
Ok(Ok(response)) => Ok(response),
|
|
|
Ok(Err(e)) => resume_unwind(e),
|
|
|
Err(e) => Err(e),
|
|
|
- };
|
|
|
+ }
|
|
|
};
|
|
|
let thread_pool = this.thread_pool.as_ref().unwrap();
|
|
|
// Using spawn() instead of spawn_blocking(), because the spawn() is
|
|
|
// better at handling a large number of small workloads. Running
|
|
|
// blocking code on async runner is fine, since all of the tasks we run
|
|
|
// on this pool are blocking (for a limited time).
|
|
|
- let result = thread_pool.spawn(async { runner() });
|
|
|
+ let result = thread_pool.spawn(runner);
|
|
|
mark_trace!(trace, after_server_scheduling);
|
|
|
let result = tokio::select! {
|
|
|
result = result => Some(result),
|
|
|
@@ -99,12 +138,34 @@ impl Server {
|
|
|
pub fn register_rpc_handler(
|
|
|
&mut self,
|
|
|
service_method: String,
|
|
|
- rpc_handler: Box<RpcHandler>,
|
|
|
+ rpc_handler: impl RpcHandler,
|
|
|
+ ) -> Result<()> {
|
|
|
+ self.register_rpc_handler_type(
|
|
|
+ service_method,
|
|
|
+ RpcHandlerType::RpcHandler(Arc::new(rpc_handler)),
|
|
|
+ )
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn register_async_rpc_handler(
|
|
|
+ &mut self,
|
|
|
+ service_method: String,
|
|
|
+ rpc_handler: impl AsyncRpcHandler,
|
|
|
+ ) -> Result<()> {
|
|
|
+ self.register_rpc_handler_type(
|
|
|
+ service_method,
|
|
|
+ RpcHandlerType::AsyncRpcHandler(Arc::new(rpc_handler)),
|
|
|
+ )
|
|
|
+ }
|
|
|
+
|
|
|
+ fn register_rpc_handler_type(
|
|
|
+ &mut self,
|
|
|
+ service_method: String,
|
|
|
+ rpc_handler: RpcHandlerType,
|
|
|
) -> Result<()> {
|
|
|
let mut state = self.state.lock();
|
|
|
let debug_service_method = service_method.clone();
|
|
|
if let Vacant(vacant) = state.rpc_handlers.entry(service_method) {
|
|
|
- vacant.insert(Arc::new(rpc_handler));
|
|
|
+ vacant.insert(rpc_handler);
|
|
|
Ok(())
|
|
|
} else {
|
|
|
Err(std::io::Error::new(
|