Kaynağa Gözat

Move the join logic to a standalone handler.

Jing Yang 3 yıl önce
ebeveyn
işleme
c71fd0fa03

+ 1 - 1
kvraft/src/server.rs

@@ -520,7 +520,7 @@ impl KVServer {
         // We must drop self to remove the only clone of raft, so that
         // `rf.kill()` does not block.
         drop(self);
-        rf.kill();
+        rf.kill().join();
         // The process_command thread will exit, after Raft drops the reference
         // to the sender.
     }

+ 42 - 20
src/raft.rs

@@ -14,9 +14,7 @@ use crate::persister::PersistedRaftState;
 use crate::snapshot::{RequestSnapshotFnMut, SnapshotDaemon};
 use crate::sync_log_entries::SyncLogEntriesComms;
 use crate::verify_authority::VerifyAuthorityDaemon;
-use crate::{
-    utils, IndexTerm, Persister, RaftState, RemoteRaft, ReplicableCommand,
-};
+use crate::{IndexTerm, Persister, RaftState, RemoteRaft, ReplicableCommand};
 
 #[derive(
     Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
@@ -43,9 +41,10 @@ pub struct Raft<Command> {
     pub(crate) verify_authority_daemon: VerifyAuthorityDaemon,
     pub(crate) heartbeats_daemon: HeartbeatsDaemon,
 
-    pub(crate) thread_pool: utils::ThreadPoolHolder,
-    pub(crate) daemon_watch: DaemonWatch,
+    pub(crate) thread_pool: tokio::runtime::Handle,
 
+    pub(crate) runtime: Arc<Mutex<Option<tokio::runtime::Runtime>>>,
+    pub(crate) daemon_watch: DaemonWatch,
     pub(crate) daemon_env: DaemonEnv,
 }
 
@@ -122,7 +121,8 @@ impl<Command: ReplicableCommand> Raft<Command> {
             snapshot_daemon: SnapshotDaemon::create(),
             verify_authority_daemon: VerifyAuthorityDaemon::create(peer_size),
             heartbeats_daemon: HeartbeatsDaemon::create(),
-            thread_pool: utils::ThreadPoolHolder::new(thread_pool),
+            thread_pool: thread_pool.handle().clone(),
+            runtime: Arc::new(Mutex::new(Some(thread_pool))),
             daemon_watch,
             daemon_env,
         };
@@ -171,12 +171,9 @@ impl<Command: ReplicableCommand> Raft<Command> {
         Some(IndexTerm::pack(index, term))
     }
 
-    const SHUTDOWN_TIMEOUT: Duration =
-        Duration::from_millis(HEARTBEAT_INTERVAL.as_millis() as u64 * 2);
-
     /// Cleanly shutdown this instance. This function never blocks forever. It
     /// either panics or returns eventually.
-    pub fn kill(self) {
+    pub fn kill(self) -> RaftJoinHandle {
         self.keep_running.store(false, Ordering::Release);
         self.election.stop_election_timer();
         self.sync_log_entries_comms.kill();
@@ -184,16 +181,12 @@ impl<Command: ReplicableCommand> Raft<Command> {
         self.snapshot_daemon.kill();
         self.verify_authority_daemon.kill();
 
-        self.daemon_watch.wait_for_daemons();
-        self.thread_pool
-            .take()
-            .expect(
-                "All references to the thread pool should have been dropped.",
-            )
-            .shutdown_timeout(Self::SHUTDOWN_TIMEOUT);
-        // DaemonEnv must be shutdown after the thread pool, since there might
-        // be tasks logging errors in the pool.
-        self.daemon_env.shutdown();
+        let runtime = self.runtime.lock().take().unwrap();
+        RaftJoinHandle {
+            runtime,
+            daemon_watch: self.daemon_watch,
+            daemon_env: self.daemon_env,
+        }
     }
 
     /// Returns the current term and whether we are the leader.
@@ -206,6 +199,35 @@ impl<Command: ReplicableCommand> Raft<Command> {
     }
 }
 
+/// A join handle returned by `Raft::kill()`. Join this handle to cleanly
+/// shutdown a Raft instance.
+///
+/// All clones of the same Raft instance created by `Raft::clone()` must be
+/// dropped before `RaftJoinHandle::join()` can return.
+///
+/// After `RaftJoinHandle::join()` returns, all threads and thread pools created
+/// by this Raft instance will have stopped. No callbacks will be called. No new
+/// commits will be created by this Raft instance.
+#[must_use]
+pub struct RaftJoinHandle {
+    thread_pool: tokio::runtime::Runtime,
+    daemon_watch: DaemonWatch,
+    daemon_env: DaemonEnv,
+}
+
+impl RaftJoinHandle {
+    const SHUTDOWN_TIMEOUT: Duration =
+        Duration::from_millis(HEARTBEAT_INTERVAL.as_millis() as u64 * 2);
+
+    pub fn join(self) {
+        self.daemon_watch.wait_for_daemons();
+        self.runtime.shutdown_timeout(Self::SHUTDOWN_TIMEOUT);
+        // DaemonEnv must be shutdown after the thread pool, since there might
+        // be tasks logging errors in the pool.
+        self.daemon_env.shutdown();
+    }
+}
+
 #[cfg(test)]
 mod tests {
     #[test]

+ 0 - 2
src/utils/mod.rs

@@ -1,8 +1,6 @@
 pub use rpcs::{retry_rpc, RPC_DEADLINE};
 pub use shared_sender::SharedSender;
-pub use thread_pool_holder::ThreadPoolHolder;
 
 pub mod integration_test;
 mod rpcs;
 mod shared_sender;
-mod thread_pool_holder;

+ 0 - 37
src/utils/thread_pool_holder.rs

@@ -1,37 +0,0 @@
-lazy_static::lazy_static! {
-    static ref THREAD_POOLS: parking_lot::Mutex<std::collections::HashMap<u64, tokio::runtime::Runtime>> =
-        parking_lot::Mutex::new(std::collections::HashMap::new());
-}
-
-#[derive(Clone)]
-pub struct ThreadPoolHolder {
-    id: u64,
-    handle: tokio::runtime::Handle,
-}
-
-impl ThreadPoolHolder {
-    pub fn new(runtime: tokio::runtime::Runtime) -> Self {
-        let handle = runtime.handle().clone();
-        loop {
-            let id: u64 = rand::random();
-            if let std::collections::hash_map::Entry::Vacant(v) =
-                THREAD_POOLS.lock().entry(id)
-            {
-                v.insert(runtime);
-                break Self { id, handle };
-            }
-        }
-    }
-
-    pub fn take(self) -> Option<tokio::runtime::Runtime> {
-        THREAD_POOLS.lock().remove(&self.id)
-    }
-}
-
-impl std::ops::Deref for ThreadPoolHolder {
-    type Target = tokio::runtime::Handle;
-
-    fn deref(&self) -> &Self::Target {
-        &self.handle
-    }
-}

+ 2 - 2
test_configs/src/raft/config.rs

@@ -286,7 +286,7 @@ impl Config {
         // might directly or indirectly block on the log lock, e.g. through
         // the apply command function.
         if let Some(raft) = raft {
-            raft.kill();
+            raft.kill().join();
         }
         let mut log = self.log.lock();
         log.saved[index] = Arc::new(crate::Persister::new());
@@ -371,7 +371,7 @@ impl Config {
         drop(network);
         for raft in &mut self.state.lock().rafts {
             if let Some(raft) = raft.take() {
-                raft.kill();
+                raft.kill().join();
             }
         }
         log::trace!("Cleaning up test raft.config done.");