Bläddra i källkod

Merge branch 'join_handle'

Jing Yang 3 år sedan
förälder
incheckning
ba700950cc

+ 2 - 3
durio/src/raft_service.rs

@@ -1,6 +1,5 @@
 use std::future::Future;
 use std::net::SocketAddr;
-use std::sync::Arc;
 
 use async_trait::async_trait;
 use tarpc::context::Context;
@@ -23,7 +22,7 @@ pub(crate) trait RaftService {
 }
 
 #[derive(Clone)]
-struct RaftRpcServer(Arc<Raft<UniqueKVOp>>);
+struct RaftRpcServer(Raft<UniqueKVOp>);
 
 #[tarpc::server]
 impl RaftService for RaftRpcServer {
@@ -125,7 +124,7 @@ pub(crate) async fn connect_to_raft_service(
 
 pub(crate) fn start_raft_service_server(
     addr: SocketAddr,
-    raft: Arc<Raft<UniqueKVOp>>,
+    raft: Raft<UniqueKVOp>,
 ) -> impl Future<Output = std::io::Result<()>> {
     let server = RaftRpcServer(raft);
     crate::utils::start_tarpc_server(addr, server.serve())

+ 1 - 1
durio/src/run.rs

@@ -22,7 +22,7 @@ pub(crate) async fn run_kv_instance(
     let persister = Arc::new(DoNothingPersister::default());
 
     let kv_server = KVServer::new(remote_rafts, me, persister, None);
-    let raft = Arc::new(kv_server.raft().clone());
+    let raft = kv_server.raft().clone();
 
     start_raft_service_server(local_raft_peer, raft).await?;
     start_kv_service_server(addr, kv_server.clone()).await?;

+ 1 - 1
kvraft/src/server.rs

@@ -520,7 +520,7 @@ impl KVServer {
         // We must drop self to remove the only clone of raft, so that
         // `rf.kill()` does not block.
         drop(self);
-        rf.kill();
+        rf.kill().join();
         // The process_command thread will exit, after Raft drops the reference
         // to the sender.
     }

+ 3 - 6
src/apply_command.rs

@@ -1,6 +1,5 @@
 use std::sync::atomic::Ordering;
 
-use crate::daemon_watch::Daemon;
 use crate::heartbeats::HEARTBEAT_INTERVAL;
 use crate::{Index, Raft, ReplicableCommand, Snapshot};
 
@@ -49,13 +48,13 @@ impl<Command: ReplicableCommand> Raft<Command> {
     pub(crate) fn run_apply_command_daemon(
         &self,
         mut apply_command: impl ApplyCommandFnMut<Command>,
-    ) {
+    ) -> impl FnOnce() {
         let keep_running = self.keep_running.clone();
         let me = self.me;
         let rf = self.inner_state.clone();
         let condvar = self.apply_command_signal.clone();
         let snapshot_daemon = self.snapshot_daemon.clone();
-        let apply_command_daemon = move || {
+        move || {
             log::info!("{:?} apply command daemon running ...", me);
 
             while keep_running.load(Ordering::Relaxed) {
@@ -110,8 +109,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
                 }
             }
             log::info!("{:?} apply command daemon done.", me);
-        };
-        self.daemon_watch
-            .create_daemon(Daemon::ApplyCommand, apply_command_daemon);
+        }
     }
 }

+ 6 - 10
src/daemon_watch.rs

@@ -1,7 +1,5 @@
 use crate::daemon_env::ThreadEnv;
 use crossbeam_utils::sync::WaitGroup;
-use parking_lot::Mutex;
-use std::sync::Arc;
 
 #[derive(Debug)]
 pub(crate) enum Daemon {
@@ -17,10 +15,8 @@ pub(crate) enum Daemon {
 /// [`DaemonWatch`] manages daemon threads and makes sure that panics are
 /// recorded during shutdown. It collects daemon panics and send them to
 /// [`crate::DaemonEnv`].
-#[derive(Clone)]
 pub(crate) struct DaemonWatch {
-    #[allow(clippy::type_complexity)]
-    daemons: Arc<Mutex<Vec<(Daemon, std::thread::JoinHandle<()>)>>>,
+    daemons: Vec<(Daemon, std::thread::JoinHandle<()>)>,
     thread_env: ThreadEnv,
     stop_wait_group: WaitGroup,
 }
@@ -28,7 +24,7 @@ pub(crate) struct DaemonWatch {
 impl DaemonWatch {
     pub fn create(thread_env: ThreadEnv) -> Self {
         Self {
-            daemons: Arc::new(Mutex::new(vec![])),
+            daemons: vec![],
             thread_env,
             stop_wait_group: WaitGroup::new(),
         }
@@ -36,7 +32,7 @@ impl DaemonWatch {
 
     /// Register a daemon thread to make sure it is correctly shutdown when the
     /// Raft instance is killed.
-    pub fn create_daemon<F, T>(&self, daemon: Daemon, func: F)
+    pub fn create_daemon<F, T>(&mut self, daemon: Daemon, func: F)
     where
         F: FnOnce() -> T,
         F: Send + 'static,
@@ -53,13 +49,13 @@ impl DaemonWatch {
                 drop(stop_wait_group);
             })
             .expect("Creating daemon thread should never fail");
-        self.daemons.lock().push((daemon, thread));
+        self.daemons.push((daemon, thread));
     }
 
     pub fn wait_for_daemons(self) {
         self.stop_wait_group.wait();
         self.thread_env.attach();
-        for (daemon, join_handle) in self.daemons.lock().drain(..) {
+        for (daemon, join_handle) in self.daemons.into_iter() {
             if let Some(err) = join_handle.join().err() {
                 let err_str = err.downcast_ref::<&str>().map(|s| s.to_owned());
                 let err_string =
@@ -81,7 +77,7 @@ mod tests {
     #[test]
     fn test_watch_daemon_shutdown() {
         let daemon_env = DaemonEnv::create();
-        let daemon_watch = DaemonWatch::create(daemon_env.for_thread());
+        let mut daemon_watch = DaemonWatch::create(daemon_env.for_thread());
         daemon_watch.create_daemon(Daemon::ApplyCommand, || {
             panic!("message with type &str");
         });

+ 3 - 6
src/election.rs

@@ -5,7 +5,6 @@ use std::time::{Duration, Instant};
 use parking_lot::{Condvar, Mutex};
 use rand::{thread_rng, Rng};
 
-use crate::daemon_watch::Daemon;
 use crate::sync_log_entries::SyncLogEntriesComms;
 use crate::term_marker::TermMarker;
 use crate::utils::{retry_rpc, RPC_DEADLINE};
@@ -127,10 +126,10 @@ impl<Command: ReplicableCommand> Raft<Command> {
     /// election timer. There could be more than one vote-counting tasks running
     /// at the same time, but all earlier tasks except the newest one will
     /// eventually realize the term they were competing for has passed and quit.
-    pub(crate) fn run_election_timer(&self) {
+    pub(crate) fn run_election_timer(&self) -> impl FnOnce() {
         let this = self.clone();
 
-        let election_daemon = move || {
+        move || {
             log::info!("{:?} election timer daemon running ...", this.me);
 
             let election = this.election.clone();
@@ -213,9 +212,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
             }
 
             log::info!("{:?} election timer daemon done.", this.me);
-        };
-        self.daemon_watch
-            .create_daemon(Daemon::ElectionTimer, election_daemon);
+        }
     }
 
     fn run_election(

+ 77 - 30
src/raft.rs

@@ -1,3 +1,4 @@
+use crossbeam_utils::sync::WaitGroup;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use std::time::Duration;
@@ -7,16 +8,14 @@ use serde_derive::{Deserialize, Serialize};
 
 use crate::apply_command::ApplyCommandFnMut;
 use crate::daemon_env::{DaemonEnv, ThreadEnv};
-use crate::daemon_watch::DaemonWatch;
+use crate::daemon_watch::{Daemon, DaemonWatch};
 use crate::election::ElectionState;
 use crate::heartbeats::{HeartbeatsDaemon, HEARTBEAT_INTERVAL};
 use crate::persister::PersistedRaftState;
 use crate::snapshot::{RequestSnapshotFnMut, SnapshotDaemon};
 use crate::sync_log_entries::SyncLogEntriesComms;
 use crate::verify_authority::VerifyAuthorityDaemon;
-use crate::{
-    utils, IndexTerm, Persister, RaftState, RemoteRaft, ReplicableCommand,
-};
+use crate::{IndexTerm, Persister, RaftState, RemoteRaft, ReplicableCommand};
 
 #[derive(
     Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
@@ -43,10 +42,11 @@ pub struct Raft<Command> {
     pub(crate) verify_authority_daemon: VerifyAuthorityDaemon,
     pub(crate) heartbeats_daemon: HeartbeatsDaemon,
 
-    pub(crate) thread_pool: utils::ThreadPoolHolder,
-    pub(crate) daemon_watch: DaemonWatch,
-
+    pub(crate) thread_pool: tokio::runtime::Handle,
     pub(crate) daemon_env: DaemonEnv,
+
+    stop_wait_group: WaitGroup,
+    join_handle: Arc<Mutex<Option<RaftJoinHandle>>>,
 }
 
 impl<Command: ReplicableCommand> Raft<Command> {
@@ -102,7 +102,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
             .on_thread_stop(ThreadEnv::detach)
             .build()
             .expect("Creating thread pool should not fail");
-        let daemon_watch = DaemonWatch::create(daemon_env.for_thread());
         let peers = peers
             .into_iter()
             .map(|r| Arc::new(r) as Arc<dyn RemoteRaft<Command>>)
@@ -122,24 +121,49 @@ impl<Command: ReplicableCommand> Raft<Command> {
             snapshot_daemon: SnapshotDaemon::create(),
             verify_authority_daemon: VerifyAuthorityDaemon::create(peer_size),
             heartbeats_daemon: HeartbeatsDaemon::create(),
-            thread_pool: utils::ThreadPoolHolder::new(thread_pool),
-            daemon_watch,
-            daemon_env,
+            thread_pool: thread_pool.handle().clone(),
+            stop_wait_group: WaitGroup::new(),
+            daemon_env: daemon_env.clone(),
+            // The join handle will be created later.
+            join_handle: Arc::new(Mutex::new(None)),
         };
 
+        let mut daemon_watch = DaemonWatch::create(daemon_env.for_thread());
         // Running in a standalone thread.
-        this.run_verify_authority_daemon();
+        daemon_watch.create_daemon(
+            Daemon::VerifyAuthority,
+            this.run_verify_authority_daemon(),
+        );
         // Running in a standalone thread.
-        this.run_snapshot_daemon(max_state_size_bytes, request_snapshot);
+        daemon_watch.create_daemon(
+            Daemon::Snapshot,
+            this.run_snapshot_daemon(max_state_size_bytes, request_snapshot),
+        );
         // Running in a standalone thread.
-        this.run_log_entry_daemon(sync_log_entries_daemon);
+        daemon_watch.create_daemon(
+            Daemon::SyncLogEntries,
+            this.run_log_entry_daemon(sync_log_entries_daemon),
+        );
         // Running in a standalone thread.
-        this.run_apply_command_daemon(apply_command);
+        daemon_watch.create_daemon(
+            Daemon::ApplyCommand,
+            this.run_apply_command_daemon(apply_command),
+        );
         // One off function that schedules many little tasks, running on the
         // internal thread pool.
         this.schedule_heartbeats(HEARTBEAT_INTERVAL);
         // The last step is to start running election timer.
-        this.run_election_timer();
+        daemon_watch
+            .create_daemon(Daemon::ElectionTimer, this.run_election_timer());
+
+        // Create the join handle
+        this.join_handle.lock().replace(RaftJoinHandle {
+            stop_wait_group: this.stop_wait_group.clone(),
+            thread_pool,
+            daemon_watch,
+            daemon_env,
+        });
+
         this
     }
 }
@@ -171,12 +195,9 @@ impl<Command: ReplicableCommand> Raft<Command> {
         Some(IndexTerm::pack(index, term))
     }
 
-    const SHUTDOWN_TIMEOUT: Duration =
-        Duration::from_millis(HEARTBEAT_INTERVAL.as_millis() as u64 * 2);
-
     /// Cleanly shutdown this instance. This function never blocks forever. It
     /// either panics or returns eventually.
-    pub fn kill(self) {
+    pub fn kill(self) -> RaftJoinHandle {
         self.keep_running.store(false, Ordering::Release);
         self.election.stop_election_timer();
         self.sync_log_entries_comms.kill();
@@ -184,16 +205,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
         self.snapshot_daemon.kill();
         self.verify_authority_daemon.kill();
 
-        self.daemon_watch.wait_for_daemons();
-        self.thread_pool
-            .take()
-            .expect(
-                "All references to the thread pool should have been dropped.",
-            )
-            .shutdown_timeout(Self::SHUTDOWN_TIMEOUT);
-        // DaemonEnv must be shutdown after the thread pool, since there might
-        // be tasks logging errors in the pool.
-        self.daemon_env.shutdown();
+        self.join_handle.lock().take().unwrap()
     }
 
     /// Returns the current term and whether we are the leader.
@@ -206,6 +218,41 @@ impl<Command: ReplicableCommand> Raft<Command> {
     }
 }
 
+/// A join handle returned by `Raft::kill()`. Join this handle to cleanly
+/// shutdown a Raft instance.
+///
+/// All clones of the same Raft instance created by `Raft::clone()` must be
+/// dropped before `RaftJoinHandle::join()` can return.
+///
+/// After `RaftJoinHandle::join()` returns, all threads and thread pools created
+/// by this Raft instance will have stopped. No callbacks will be called. No new
+/// commits will be created by this Raft instance.
+#[must_use]
+pub struct RaftJoinHandle {
+    stop_wait_group: WaitGroup,
+    thread_pool: tokio::runtime::Runtime,
+    daemon_watch: DaemonWatch,
+    daemon_env: DaemonEnv,
+}
+
+impl RaftJoinHandle {
+    const SHUTDOWN_TIMEOUT: Duration =
+        Duration::from_millis(HEARTBEAT_INTERVAL.as_millis() as u64 * 2);
+
+    /// Waits for the Raft instance to shutdown.
+    ///
+    /// See the struct documentation for more details.
+    pub fn join(self) {
+        // Wait for all Raft instances to be dropped.
+        self.stop_wait_group.wait();
+        self.daemon_watch.wait_for_daemons();
+        self.thread_pool.shutdown_timeout(Self::SHUTDOWN_TIMEOUT);
+        // DaemonEnv must be shutdown after the thread pool, since there might
+        // be tasks logging errors in the pool.
+        self.daemon_env.shutdown();
+    }
+}
+
 #[cfg(test)]
 mod tests {
     #[test]

+ 7 - 5
src/snapshot.rs

@@ -6,7 +6,6 @@ use parking_lot::{Condvar, Mutex};
 
 use crate::check_or_record;
 use crate::daemon_env::ErrorKind;
-use crate::daemon_watch::Daemon;
 use crate::{Index, Raft};
 
 #[derive(Clone, Debug, Default)]
@@ -128,10 +127,10 @@ impl<C: 'static + Clone + Send + serde::Serialize> Raft<C> {
         &mut self,
         max_state_size: Option<usize>,
         mut request_snapshot: impl RequestSnapshotFnMut,
-    ) {
+    ) -> impl FnOnce() {
         let max_state_size = match max_state_size {
             Some(max_state_size) => max_state_size,
-            None => return,
+            None => usize::MAX,
         };
 
         let parker = Parker::new();
@@ -218,7 +217,10 @@ impl<C: 'static + Clone + Send + serde::Serialize> Raft<C> {
                 );
             }
         };
-        self.daemon_watch
-            .create_daemon(Daemon::Snapshot, snapshot_daemon);
+        move || {
+            if max_state_size != usize::MAX {
+                snapshot_daemon()
+            }
+        }
     }
 }

+ 3 - 6
src/sync_log_entries.rs

@@ -4,7 +4,6 @@ use std::sync::Arc;
 use parking_lot::{Condvar, Mutex};
 
 use crate::daemon_env::ErrorKind;
-use crate::daemon_watch::Daemon;
 use crate::heartbeats::HEARTBEAT_INTERVAL;
 use crate::peer_progress::PeerProgress;
 use crate::term_marker::TermMarker;
@@ -120,10 +119,10 @@ impl<Command: ReplicableCommand> Raft<Command> {
     pub(crate) fn run_log_entry_daemon(
         &self,
         SyncLogEntriesDaemon { rx, peer_progress }: SyncLogEntriesDaemon,
-    ) {
+    ) -> impl FnOnce() {
         // Clone everything that the thread needs.
         let this = self.clone();
-        let sync_log_entry_daemon = move || {
+        move || {
             log::info!("{:?} sync log entries daemon running ...", this.me);
 
             let mut task_number = 0;
@@ -160,9 +159,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
             }
 
             log::info!("{:?} sync log entries daemon done.", this.me);
-        };
-        self.daemon_watch
-            .create_daemon(Daemon::SyncLogEntries, sync_log_entry_daemon);
+        }
     }
 
     /// Syncs log entries to a peer once, requests a new sync if that fails.

+ 0 - 2
src/utils/mod.rs

@@ -1,8 +1,6 @@
 pub use rpcs::{retry_rpc, RPC_DEADLINE};
 pub use shared_sender::SharedSender;
-pub use thread_pool_holder::ThreadPoolHolder;
 
 pub mod integration_test;
 mod rpcs;
 mod shared_sender;
-mod thread_pool_holder;

+ 0 - 37
src/utils/thread_pool_holder.rs

@@ -1,37 +0,0 @@
-lazy_static::lazy_static! {
-    static ref THREAD_POOLS: parking_lot::Mutex<std::collections::HashMap<u64, tokio::runtime::Runtime>> =
-        parking_lot::Mutex::new(std::collections::HashMap::new());
-}
-
-#[derive(Clone)]
-pub struct ThreadPoolHolder {
-    id: u64,
-    handle: tokio::runtime::Handle,
-}
-
-impl ThreadPoolHolder {
-    pub fn new(runtime: tokio::runtime::Runtime) -> Self {
-        let handle = runtime.handle().clone();
-        loop {
-            let id: u64 = rand::random();
-            if let std::collections::hash_map::Entry::Vacant(v) =
-                THREAD_POOLS.lock().entry(id)
-            {
-                v.insert(runtime);
-                break Self { id, handle };
-            }
-        }
-    }
-
-    pub fn take(self) -> Option<tokio::runtime::Runtime> {
-        THREAD_POOLS.lock().remove(&self.id)
-    }
-}
-
-impl std::ops::Deref for ThreadPoolHolder {
-    type Target = tokio::runtime::Handle;
-
-    fn deref(&self) -> &Self::Target {
-        &self.handle
-    }
-}

+ 3 - 6
src/verify_authority.rs

@@ -7,7 +7,6 @@ use std::time::{Duration, Instant};
 use parking_lot::{Condvar, Mutex};
 
 use crate::beat_ticker::{Beat, SharedBeatTicker};
-use crate::daemon_watch::Daemon;
 use crate::heartbeats::HEARTBEAT_INTERVAL;
 use crate::{Index, Raft, Term};
 
@@ -332,13 +331,13 @@ impl<Command: 'static + Send> Raft<Command> {
     const BEAT_RECORDING_MAX_PAUSE: Duration = Duration::from_millis(20);
 
     /// Create a thread and runs the verify authority daemon.
-    pub(crate) fn run_verify_authority_daemon(&self) {
+    pub(crate) fn run_verify_authority_daemon(&self) -> impl FnOnce() {
         let me = self.me;
         let keep_running = self.keep_running.clone();
         let this_daemon = self.verify_authority_daemon.clone();
         let rf = self.inner_state.clone();
 
-        let verify_authority_daemon = move || {
+        move || {
             log::info!("{:?} verify authority daemon running ...", me);
             while keep_running.load(Ordering::Relaxed) {
                 this_daemon.wait_for(Self::BEAT_RECORDING_MAX_PAUSE);
@@ -350,9 +349,7 @@ impl<Command: 'static + Send> Raft<Command> {
                     .run_verify_authority_iteration(current_term, commit_index);
             }
             log::info!("{:?} verify authority daemon done.", me);
-        };
-        self.daemon_watch
-            .create_daemon(Daemon::VerifyAuthority, verify_authority_daemon);
+        }
     }
 
     /// Create a verify authority request. Returns None if we are not the

+ 1 - 1
test_configs/src/kvraft/config.rs

@@ -58,7 +58,7 @@ impl Config {
             KVServer::new(clients, index, persister, Some(self.maxraftstate));
         self.state.lock().kv_servers[index].replace(kv.clone());
 
-        let raft = Arc::new(kv.raft().clone());
+        let raft = kv.raft().clone();
 
         register_server(raft, Self::server_name(index), self.network.as_ref())?;
 

+ 2 - 3
test_configs/src/raft/config.rs

@@ -286,7 +286,7 @@ impl Config {
         // might directly or indirectly block on the log lock, e.g. through
         // the apply command function.
         if let Some(raft) = raft {
-            raft.kill();
+            raft.kill().join();
         }
         let mut log = self.log.lock();
         log.saved[index] = Arc::new(crate::Persister::new());
@@ -321,7 +321,6 @@ impl Config {
         );
         self.state.lock().rafts[index].replace(raft.clone());
 
-        let raft = Arc::new(raft);
         register_server(raft, Self::server_name(index), self.network.as_ref())?;
         Ok(())
     }
@@ -371,7 +370,7 @@ impl Config {
         drop(network);
         for raft in &mut self.state.lock().rafts {
             if let Some(raft) = raft.take() {
-                raft.kill();
+                raft.kill().join();
             }
         }
         log::trace!("Cleaning up test raft.config done.");

+ 7 - 10
test_configs/src/rpcs.rs

@@ -145,11 +145,10 @@ where
 }
 
 pub fn register_server<
-    Command: 'static + Clone + Serialize + DeserializeOwned,
-    R: 'static + AsRef<Raft<Command>> + Send + Sync + Clone,
+    Command: 'static + Clone + Send + Serialize + DeserializeOwned,
     S: AsRef<str>,
 >(
-    raft: R,
+    raft: Raft<Command>,
     name: S,
     network: &Mutex<Network>,
 ) -> std::io::Result<()> {
@@ -159,19 +158,17 @@ pub fn register_server<
 
     server.register_rpc_handler(REQUEST_VOTE_RPC.to_owned(), {
         let raft = raft.clone();
-        make_rpc_handler(move |args| raft.as_ref().process_request_vote(args))
+        make_rpc_handler(move |args| raft.process_request_vote(args))
     })?;
 
     server.register_rpc_handler(APPEND_ENTRIES_RPC.to_owned(), {
         let raft = raft.clone();
-        make_rpc_handler(move |args| raft.as_ref().process_append_entries(args))
+        make_rpc_handler(move |args| raft.process_append_entries(args))
     })?;
 
     server.register_rpc_handler(
         INSTALL_SNAPSHOT_RPC.to_owned(),
-        make_rpc_handler(move |args| {
-            raft.as_ref().process_install_snapshot(args)
-        }),
+        make_rpc_handler(move |args| raft.process_install_snapshot(args)),
     )?;
 
     network.add_server(server_name, server);
@@ -258,14 +255,14 @@ mod tests {
                 .lock()
                 .make_client("test-basic-message", name.to_owned());
 
-            let raft = Arc::new(Raft::new(
+            let raft = Raft::new(
                 vec![RpcClient(client)],
                 0,
                 Arc::new(DoNothingPersister),
                 |_: ApplyCommandMessage<i32>| {},
                 None,
                 crate::utils::NO_SNAPSHOT,
-            ));
+            );
             register_server(raft, name, network.as_ref())?;
 
             let client = network