Sfoglia il codice sorgente

Allow server to return early when it is removed.

Jing Yang 4 anni fa
parent
commit
99f258f104
3 ha cambiato i file con 37 aggiunte e 28 eliminazioni
  1. 1 1
      Cargo.toml
  2. 14 23
      src/network.rs
  3. 22 4
      src/server.rs

+ 1 - 1
Cargo.toml

@@ -12,7 +12,7 @@ crossbeam-channel = "0.5.0"
 futures = { version = "0.3.8", default-features = false, features = ["thread-pool"] }
 parking_lot = "0.11.1"
 rand = "0.8.0"
-tokio = { version = "1.0", features = ["rt-multi-thread", "time", "parking_lot"] }
+tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "sync", "time", "parking_lot"] }
 
 [features]
 default = []

+ 14 - 23
src/network.rs

@@ -89,7 +89,10 @@ impl Network {
     }
 
     pub fn remove_server<S: AsRef<str>>(&mut self, server_name: S) {
-        self.servers.remove(server_name.as_ref());
+        let server = self.servers.remove(server_name.as_ref());
+        if let Some(server) = server {
+            server.interrupt();
+        }
     }
 
     pub fn get_rpc_count<S: AsRef<str>>(
@@ -186,13 +189,11 @@ impl Network {
             }
         }
 
-        let mut lookup_result = None;
         let reply = match server_result {
             // Call the server.
             Ok(server) => {
                 // Simulates the copy from network to server.
                 let data = rpc.request.clone();
-                lookup_result.replace(server.clone());
                 // No need to set timeout. The RPCs are not supposed to block.
                 mark_trace!(rpc.trace, before_serving);
                 #[cfg(not(feature = "tracing"))]
@@ -223,26 +224,13 @@ impl Network {
         };
         mark_trace!(rpc.trace, after_serving);
 
+        let client = &rpc.client;
+        let reply = reply.and_then(|reply| {
+            // Fail the RPC if the client has been disconnected.
+            network.lock().dispatch(client).map(|_| reply)
+        });
+
         if reply.is_ok() {
-            // The lookup must have succeeded.
-            let lookup_result = lookup_result.unwrap();
-            // The server's address must have been changed, given that we take
-            // ownership of the server when it is registered.
-            let unchanged = match network.lock().dispatch(&rpc.client) {
-                Ok(server) => Arc::ptr_eq(&server, &lookup_result),
-                Err(_) => false,
-            };
-
-            // Fail the RPC if the client has been disconnected, or the server
-            // has been updated.
-            if !unchanged {
-                let _ = rpc.reply_channel.send(Err(std::io::Error::new(
-                    std::io::ErrorKind::ConnectionReset,
-                    "Network connection has been reset.".to_owned(),
-                )));
-                mark_trace!(rpc.trace, served);
-                return;
-            }
             // Random drop again.
             if !reliable
                 && thread_rng().gen_ratio(Self::DROP_RATE.0, Self::DROP_RATE.1)
@@ -534,7 +522,10 @@ mod tests {
 
         let err = reply
             .expect_err("Client should receive error after server is killed");
-        assert_eq!(std::io::ErrorKind::ConnectionReset, err.kind());
+        assert!(
+            std::io::ErrorKind::ConnectionReset == err.kind()
+                || std::io::ErrorKind::NotFound == err.kind()
+        );
 
         Ok(())
     }

+ 22 - 4
src/server.rs

@@ -18,6 +18,7 @@ pub struct Server {
     name: String,
     state: Mutex<ServerState>,
     thread_pool: futures::executor::ThreadPool,
+    interrupt: tokio::sync::Notify,
 }
 
 impl Unpin for Server {}
@@ -67,12 +68,24 @@ impl Server {
             mark_trace!(trace_clone, handler_response);
         });
         mark_trace!(trace, after_server_scheduling);
-        let ret = rx.await.map_err(|_e| {
-            std::io::Error::new(
+        let result = tokio::select! {
+            result = rx => Some(result),
+            _ = this.interrupt.notified() => None,
+        };
+        let ret = match result {
+            Some(Ok(ret)) => ret,
+            Some(Err(_)) => Err(std::io::Error::new(
                 std::io::ErrorKind::ConnectionReset,
                 format!("Remote server {} cancelled the RPC.", this.name),
-            )
-        })?;
+            )),
+            None => {
+                // Fail the RPC if the server has been terminated.
+                Err(std::io::Error::new(
+                    std::io::ErrorKind::ConnectionReset,
+                    "Network connection has been reset.".to_owned(),
+                ))
+            }
+        };
         mark_trace!(trace, server_response);
         ret
     }
@@ -102,6 +115,10 @@ impl Server {
         self.state.lock().rpc_count.get()
     }
 
+    pub fn interrupt(&self) {
+        self.interrupt.notify_waiters();
+    }
+
     pub fn make_server<S: Into<ServerIdentifier>>(name: S) -> Self {
         let state = Mutex::new(ServerState {
             rpc_handlers: std::collections::HashMap::new(),
@@ -117,6 +134,7 @@ impl Server {
             name,
             state,
             thread_pool,
+            interrupt: tokio::sync::Notify::new(),
         }
     }
 }