Procházet zdrojové kódy

Add support to async RPC handler.

A new AsyncRpcHandler type is added.

The network is different from other RPC systems, because

1. It must support adding/removing servers on the fly. The service dispatch
table is not static, and thus has to be guarded by a mutex.
2. It must simultaneously support more than one concrete types that handle
RPC. All those types are stored in the internal dispatch table, which must
use type erasure instead of generic.

There are two ways to store an RPC handler in the dispatch table
1. Arc<dyn Fn() + Send + Sync>
2. Box<dyn Factory<dyn FnOnce() + Send>>

Note 'dyn FnOnce()' cannot be clone. We chose the first one because the
second approach needs too many 'dyn's.
Jing Yang před 4 roky
rodič
revize
e925410606
3 změnil soubory, kde provedl 80 přidání a 20 odebrání
  1. 0 1
      src/lib.rs
  2. 78 17
      src/server.rs
  3. 2 2
      src/test_utils/junk_server.rs

+ 0 - 1
src/lib.rs

@@ -25,7 +25,6 @@ mod server;
 pub type Result<T> = std::io::Result<T>;
 pub use client::Client;
 pub use network::Network;
-pub use server::RpcHandler;
 pub use server::Server;
 #[cfg(feature = "tracing")]
 pub use tracing::Trace;

+ 78 - 17
src/server.rs

@@ -1,17 +1,55 @@
 use std::collections::hash_map::Entry::Vacant;
+use std::future::Future;
 use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
+use std::pin::Pin;
 use std::sync::Arc;
 
+use futures::FutureExt;
 use parking_lot::Mutex;
 
 #[cfg(feature = "tracing")]
 use crate::tracing::TraceHolder;
 use crate::{ReplyMessage, RequestMessage, Result, ServerIdentifier};
 
-pub type RpcHandler = dyn Fn(RequestMessage) -> ReplyMessage;
+pub trait RpcHandler:
+    (Fn(RequestMessage) -> ReplyMessage) + Send + Sync + 'static
+{
+}
+
+impl<T> RpcHandler for T where
+    T: (Fn(RequestMessage) -> ReplyMessage) + Send + Sync + 'static
+{
+}
+
+pub trait AsyncRpcHandler:
+    (Fn(
+        RequestMessage,
+    ) -> Pin<Box<dyn Future<Output = ReplyMessage> + Send + 'static>>)
+    + Send
+    + Sync
+    + 'static
+{
+}
+
+impl<T> AsyncRpcHandler for T where
+    T: (Fn(
+            RequestMessage,
+        )
+            -> Pin<Box<dyn Future<Output = ReplyMessage> + Send + 'static>>)
+        + Send
+        + Sync
+        + 'static
+{
+}
+
+#[derive(Clone)]
+enum RpcHandlerType {
+    RpcHandler(Arc<dyn RpcHandler>),
+    AsyncRpcHandler(Arc<dyn AsyncRpcHandler>),
+}
 
 struct ServerState {
-    rpc_handlers: std::collections::HashMap<String, Arc<RpcHandler>>,
+    rpc_handlers: std::collections::HashMap<String, RpcHandlerType>,
     rpc_count: usize,
 }
 
@@ -22,12 +60,6 @@ pub struct Server {
     interrupt: tokio::sync::Notify,
 }
 
-impl Unpin for Server {}
-// Server contains a immutable name, a mutex-protected state, and a thread pool.
-// All of those 3 are `Send` and `Sync`.
-unsafe impl Send for Server {}
-unsafe impl Sync for Server {}
-
 impl Server {
     pub(crate) async fn dispatch(
         self: Arc<Self>,
@@ -39,7 +71,7 @@ impl Server {
         mark_trace!(trace, before_server_scheduling);
         #[cfg(feature = "tracing")]
         let trace_clone = trace.clone();
-        let runner = move || {
+        let runner = async move {
             let rpc_handler = {
                 // Blocking on a mutex in a thread pool. Sounds horrible, but
                 // in fact quite safe, given that the critical section is short.
@@ -49,9 +81,16 @@ impl Server {
             };
             mark_trace!(trace_clone, before_handling);
             let response = match rpc_handler {
-                Some(rpc_handler) => {
-                    Ok(catch_unwind(AssertUnwindSafe(|| rpc_handler(data))))
-                }
+                Some(rpc_handler) => match rpc_handler {
+                    RpcHandlerType::RpcHandler(rpc_handler) => {
+                        Ok(catch_unwind(AssertUnwindSafe(|| rpc_handler(data))))
+                    }
+                    RpcHandlerType::AsyncRpcHandler(rpc_handler) => {
+                        Ok(AssertUnwindSafe(rpc_handler(data))
+                            .catch_unwind()
+                            .await)
+                    }
+                },
                 None => Err(std::io::Error::new(
                     std::io::ErrorKind::InvalidInput,
                     format!(
@@ -61,18 +100,18 @@ impl Server {
                 )),
             };
             mark_trace!(trace_clone, after_handling);
-            return match response {
+            match response {
                 Ok(Ok(response)) => Ok(response),
                 Ok(Err(e)) => resume_unwind(e),
                 Err(e) => Err(e),
-            };
+            }
         };
         let thread_pool = this.thread_pool.as_ref().unwrap();
         // Using spawn() instead of spawn_blocking(), because the spawn() is
         // better at handling a large number of small workloads. Running
         // blocking code on async runner is fine, since all of the tasks we run
         // on this pool are blocking (for a limited time).
-        let result = thread_pool.spawn(async { runner() });
+        let result = thread_pool.spawn(runner);
         mark_trace!(trace, after_server_scheduling);
         let result = tokio::select! {
             result = result => Some(result),
@@ -99,12 +138,34 @@ impl Server {
     pub fn register_rpc_handler(
         &mut self,
         service_method: String,
-        rpc_handler: Box<RpcHandler>,
+        rpc_handler: impl RpcHandler,
+    ) -> Result<()> {
+        self.register_rpc_handler_type(
+            service_method,
+            RpcHandlerType::RpcHandler(Arc::new(rpc_handler)),
+        )
+    }
+
+    pub fn register_async_rpc_handler(
+        &mut self,
+        service_method: String,
+        rpc_handler: impl AsyncRpcHandler,
+    ) -> Result<()> {
+        self.register_rpc_handler_type(
+            service_method,
+            RpcHandlerType::AsyncRpcHandler(Arc::new(rpc_handler)),
+        )
+    }
+
+    fn register_rpc_handler_type(
+        &mut self,
+        service_method: String,
+        rpc_handler: RpcHandlerType,
     ) -> Result<()> {
         let mut state = self.state.lock();
         let debug_service_method = service_method.clone();
         if let Vacant(vacant) = state.rpc_handlers.entry(service_method) {
-            vacant.insert(Arc::new(rpc_handler));
+            vacant.insert(rpc_handler);
             Ok(())
         } else {
             Err(std::io::Error::new(

+ 2 - 2
src/test_utils/junk_server.rs

@@ -31,11 +31,11 @@ pub fn make_test_server() -> Server {
     server
         .register_rpc_handler(
             JunkRpcs::Echo.name(),
-            Box::new(move |request| {
+            move |request: bytes::Bytes| {
                 let mut reply = bytes::BytesMut::from(request.as_ref());
                 reply.reverse();
                 reply.freeze()
-            }),
+            },
         )
         .expect("Registering the first RPC handler should not fail");
     server