
Use a shared pool for all servers in a network.

A panic in a server's RPC handler will not kill a tokio pool thread, so all
servers are safe to run in the same pool.
Jing Yang, 4 years ago
parent
commit e6e351a577
6 changed files with 42 additions and 33 deletions
  1. Cargo.toml (+5 −1)
  2. src/lib.rs (+0 −1)
  3. src/network.rs (+10 −1)
  4. src/server.rs (+17 −28)
  5. src/test_utils/junk_server.rs (+10 −0)
  6. src/tracing.rs (+0 −2)
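
The safety claim in the commit message is easy to check in isolation. Below is a minimal standalone sketch (not part of this repository) showing that a panic in a spawned tokio task is caught at the task boundary and reported through the JoinHandle as a JoinError, leaving the worker threads and the runtime fully usable:

#[tokio::main(worker_threads = 2)]
async fn main() {
    // The panic unwinds only to the task boundary; tokio catches it there
    // and reports it through the JoinHandle instead of killing the worker.
    let err = tokio::spawn(async { panic!("handler aborted") })
        .await
        .unwrap_err();
    assert!(err.is_panic());

    // The same runtime (and the same worker threads) keeps serving tasks.
    let answer = tokio::spawn(async { 42 }).await.unwrap();
    assert_eq!(answer, 42);
}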

Cargo.toml (+5 −1)

@@ -12,11 +12,15 @@ homepage = "https://github.com/ditsing/labrpc"
 [dependencies]
 bytes = "1.0"
 crossbeam-channel = "0.5.1"
-futures = { version = "0.3.15", default-features = false, features = ["thread-pool"] }
+futures = { version = "0.3.15", default-features = false }
 parking_lot = "0.11.1"
 rand = "0.8.0"
 tokio = { version = "1.7", features = ["macros", "rt-multi-thread", "sync", "time", "parking_lot"] }
 
+[dev-dependencies]
+futures = { version = "0.3.15", default-features = false, features = ["thread-pool"] }
+lazy_static = "1.4.0"
+
 [features]
 default = []
 tracing = []

src/lib.rs (+0 −1)

@@ -1,6 +1,5 @@
 extern crate bytes;
 extern crate crossbeam_channel;
-extern crate futures;
 extern crate rand;
 extern crate tokio;
 

src/network.rs (+10 −1)

@@ -25,6 +25,8 @@ pub struct Network {
 
     // Network bus
     request_bus: Sender<Option<RpcOnWire>>,
+    // Server thread pool.
+    server_pool: tokio::runtime::Runtime,
 
     // Closing signal.
     keep_running: bool,
@@ -82,8 +84,9 @@ impl Network {
     pub fn add_server<S: Into<ServerIdentifier>>(
         &mut self,
         server_name: S,
-        server: Server,
+        mut server: Server,
     ) {
+        server.use_pool(self.server_pool.handle().clone());
         self.servers.insert(server_name.into(), Arc::new(server));
     }
 
@@ -325,6 +328,11 @@ impl Network {
     }
 
     fn new() -> (Self, Receiver<Option<RpcOnWire>>) {
+        // Server thread pool
+        let server_pool = tokio::runtime::Builder::new_multi_thread()
+            .thread_name("server-pool")
+            .build()
+            .expect("Creating server thread pool should not fail");
         // The channel has infinite buffer, could OOM the server if there are
         // too many pending RPCs to be served.
         let (tx, rx) = crossbeam_channel::unbounded();
@@ -335,6 +343,7 @@ impl Network {
             clients: Default::default(),
             servers: Default::default(),
             request_bus: tx,
+            server_pool,
             keep_running: true,
             stopped: Default::default(),
             rpc_count: std::cell::Cell::new(0),
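
For context on the wiring above: a tokio::runtime::Handle is a cheap, cloneable reference to a single runtime, which is what lets one server_pool back every server on the network. A standalone sketch of that sharing pattern (illustrative only, names borrowed from the diff):

use tokio::runtime::Builder;

fn main() {
    // Built once, as Network::new() does for `server_pool`.
    let pool = Builder::new_multi_thread()
        .thread_name("server-pool")
        .build()
        .expect("Creating server thread pool should not fail");

    // Each "server" receives its own clone of the Handle. A Handle is an
    // Arc-like reference, so cloning is cheap and every clone schedules
    // work onto the same set of worker threads.
    for i in 0..3 {
        let handle = pool.handle().clone();
        handle.spawn(async move {
            println!("server {i} runs on the shared pool");
        });
    }

    // Give the detached tasks a moment to run before the runtime drops.
    std::thread::sleep(std::time::Duration::from_millis(50));
}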

src/server.rs (+17 −28)

@@ -1,5 +1,5 @@
 use std::collections::hash_map::Entry::Vacant;
-use std::panic::{catch_unwind, AssertUnwindSafe};
+use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
 use std::sync::Arc;
 
 use parking_lot::Mutex;
@@ -18,7 +18,7 @@ struct ServerState {
 pub struct Server {
     name: String,
     state: Mutex<ServerState>,
-    thread_pool: futures::executor::ThreadPool,
+    thread_pool: Option<tokio::runtime::Handle>,
     interrupt: tokio::sync::Notify,
 }
 
@@ -29,19 +29,17 @@ unsafe impl Send for Server {}
 unsafe impl Sync for Server {}
 
 impl Server {
-    const THREAD_POOL_SIZE: usize = 4;
-    pub async fn dispatch(
+    pub(crate) async fn dispatch(
         self: Arc<Self>,
         service_method: String,
         data: RequestMessage,
         #[cfg(feature = "tracing")] trace: TraceHolder,
     ) -> Result<ReplyMessage> {
-        let (tx, rx) = futures::channel::oneshot::channel();
         let this = self.clone();
         mark_trace!(trace, before_server_scheduling);
         #[cfg(feature = "tracing")]
         let trace_clone = trace.clone();
-        this.thread_pool.spawn_ok(async move {
+        let result = this.thread_pool.as_ref().unwrap().spawn(async move {
             let rpc_handler = {
                 // Blocking on a mutex in a thread pool. Sounds horrible, but
                 // in fact quite safe, given that the critical section is short.
@@ -52,13 +50,7 @@ impl Server {
             mark_trace!(trace_clone, before_handling);
             let response = match rpc_handler {
                 Some(rpc_handler) => {
-                    match catch_unwind(AssertUnwindSafe(|| rpc_handler(data))) {
-                        Ok(result) => Ok(result),
-                        Err(_) => {
-                            drop(tx);
-                            return;
-                        }
-                    }
+                    Ok(catch_unwind(AssertUnwindSafe(|| rpc_handler(data))))
                 }
                 None => Err(std::io::Error::new(
                     std::io::ErrorKind::InvalidInput,
@@ -69,16 +61,15 @@ impl Server {
                 )),
             };
             mark_trace!(trace_clone, after_handling);
-            #[allow(clippy::redundant_pattern_matching)]
-            if let Err(_) = tx.send(response) {
-                // Receiving end is dropped. Never mind.
-                // Do nothing.
-            }
-            mark_trace!(trace_clone, handler_response);
+            return match response {
+                Ok(Ok(response)) => Ok(response),
+                Ok(Err(e)) => resume_unwind(e),
+                Err(e) => Err(e),
+            };
         });
         mark_trace!(trace, after_server_scheduling);
         let result = tokio::select! {
-            result = rx => Some(result),
+            result = result => Some(result),
             _ = this.interrupt.notified() => None,
         };
         let ret = match result {
@@ -133,19 +124,17 @@ impl Server {
             rpc_handlers: std::collections::HashMap::new(),
             rpc_count: 0,
         });
-        let name = name.into();
-        let thread_pool = futures::executor::ThreadPool::builder()
-            .name_prefix(name.clone())
-            .pool_size(Self::THREAD_POOL_SIZE)
-            .create()
-            .expect("Creating thread pools should not fail");
         Self {
-            name,
+            name: name.into(),
             state,
-            thread_pool,
+            thread_pool: None,
             interrupt: tokio::sync::Notify::new(),
         }
     }
+
+    pub(crate) fn use_pool(&mut self, thread_pool: tokio::runtime::Handle) {
+        self.thread_pool = Some(thread_pool);
+    }
 }
 
 #[cfg(test)]
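
A note on the catch_unwind/resume_unwind pair introduced above: re-raising the caught panic inside the spawned task hands the original payload to tokio, which delivers it to the awaiting side as a JoinError (and resume_unwind does not run the panic hook again, so the panic is not reported twice). A standalone sketch of the mechanism, with a stand-in handler:

use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};

#[tokio::main]
async fn main() {
    // Stand-in for a registered RPC handler that aborts.
    let rpc_handler = || -> u32 { panic!("Aborting rpc...") };

    let join = tokio::spawn(async move {
        match catch_unwind(AssertUnwindSafe(rpc_handler)) {
            Ok(reply) => reply,
            // Re-raise so the payload travels through the JoinError.
            Err(payload) => resume_unwind(payload),
        }
    });

    let err = join.await.unwrap_err();
    assert!(err.is_panic());
    // into_panic() returns the exact payload the handler panicked with.
    assert_eq!(
        err.into_panic().downcast_ref::<&str>(),
        Some(&"Aborting rpc...")
    );
}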

src/test_utils/junk_server.rs (+10 −0)

@@ -1,3 +1,5 @@
+use lazy_static::lazy_static;
+
 use crate::Server;
 
 pub const TEST_SERVER: &str = &"test-server";
@@ -42,5 +44,13 @@ pub fn make_test_server() -> Server {
             Box::new(move |_| panic!("Aborting rpc...")),
         )
         .expect("Registering the second RPC handler should not fail");
+    lazy_static! {
+        static ref DEFAULT_RUNTIME: tokio::runtime::Runtime =
+            tokio::runtime::Builder::new_multi_thread()
+                .build()
+                .expect("Build server default runtime should not fail");
+    }
+    server.use_pool(DEFAULT_RUNTIME.handle().clone());
+
     server
 }

src/tracing.rs (+0 −2)

@@ -23,7 +23,6 @@ pub struct Trace {
     pub after_server_scheduling: Duration,
     pub before_handling: Duration,
     pub after_handling: Duration,
-    pub handler_response: Duration,
     pub server_response: Duration,
     pub after_server_response: Duration,
     /// The delay of when the request is served by the server.
@@ -55,7 +54,6 @@ impl Trace {
             after_server_scheduling: Default::default(),
             before_handling: Default::default(),
             after_handling: Default::default(),
-            handler_response: Default::default(),
             server_response: Default::default(),
             after_server_response: Default::default(),
             after_serving: Default::default(),