瀏覽代碼

Store DaemonEnv in the thread context to enable global logging.

The next step is to replace all assertions and error printers with
DaemonEnv.
Jing Yang 4 年之前
父節點
當前提交
bc61f08723
共有 5 個文件被更改,包括 109 次插入和 10 次刪除
  1. 4 0
      src/apply_command.rs
  2. 78 4
      src/daemon_env.rs
  3. 3 1
      src/install_snapshot.rs
  4. 21 4
      src/lib.rs
  5. 3 1
      src/snapshot.rs

+ 4 - 0
src/apply_command.rs

@@ -30,8 +30,12 @@ where
         let rf = self.inner_state.clone();
         let condvar = self.apply_command_signal.clone();
         let snapshot_daemon = self.snapshot_daemon.clone();
+        let daemon_env = self.daemon_env.clone();
         let stop_wait_group = self.stop_wait_group.clone();
         let join_handle = std::thread::spawn(move || {
+            // Note: do not change this to `let _ = ...`.
+            let _guard = daemon_env.for_scope();
+
             while keep_running.load(Ordering::SeqCst) {
                 let messages = {
                     let mut rf = rf.lock();

+ 78 - 4
src/daemon_env.rs

@@ -1,4 +1,5 @@
-use std::sync::Arc;
+use std::cell::RefCell;
+use std::sync::{Arc, Weak};
 
 use parking_lot::Mutex;
 
@@ -7,9 +8,9 @@ use crate::{Peer, RaftState, State, Term};
 
 #[macro_export]
 macro_rules! check_or_record {
-    ($daemon_env:expr, $condition:expr, $error_kind:expr, $message:expr, $rf:expr) => {
+    ($condition:expr, $error_kind:expr, $message:expr, $rf:expr) => {
         if !$condition {
-            $daemon_env.record_error(
+            crate::daemon_env::ThreadEnv::upgrade().record_error(
                 $error_kind,
                 $message,
                 $rf,
@@ -19,9 +20,10 @@ macro_rules! check_or_record {
     };
 }
 
-#[derive(Clone, Debug, Default)]
+#[derive(Clone, Debug)]
 pub(crate) struct DaemonEnv {
     data: Arc<Mutex<DaemonEnvData>>,
+    thread_env: ThreadEnv,
 }
 
 #[derive(Debug, Default)]
@@ -132,3 +134,75 @@ struct StrippedRaftState {
     state: State,
     leader_id: Peer,
 }
+
+impl DaemonEnv {
+    pub(crate) fn create() -> Self {
+        let data = Default::default();
+        // Pre-create a template thread_env, so that we can clone the weak
+        // pointer instead of downgrading frequently.
+        let thread_env = ThreadEnv {
+            data: Arc::downgrade(&data),
+        };
+        Self { data, thread_env }
+    }
+
+    pub(crate) fn for_thread(&self) -> ThreadEnv {
+        self.thread_env.clone()
+    }
+
+    pub(crate) fn for_scope(&self) -> ThreadEnvGuard {
+        self.for_thread().attach();
+        ThreadEnvGuard {}
+    }
+}
+
+#[derive(Clone, Debug, Default)]
+pub(crate) struct ThreadEnv {
+    data: Weak<Mutex<DaemonEnvData>>,
+}
+
+impl ThreadEnv {
+    thread_local! {static ENV: RefCell<ThreadEnv> = Default::default()}
+
+    // The dance between Arc<> and Weak<> is complex, but useful:
+    // 1) We do not have to worry about slow RPC threads causing
+    // DaemonEnv::shutdown() to fail. They only hold a Weak<> pointer after all;
+    // 2) We have one system that works both in the environments that we control
+    // (daemon threads and our own thread pools), and in those we don't (RPC
+    // handling methods);
+    // 3) Utils (log_array, persister) can log errors without access to Raft;
+    // 4) Because of 2), we do not need to expose DaemonEnv externally outside
+    // this crate, even though there is a public macro referencing it.
+    //
+    // On the other hand, the cost is fairly small, because:
+    // 1) Clone of weak is cheap: one branch plus one relaxed atomic load;
+    // downgrade is more expensive, but we only do it once;
+    // 2) Upgrade of weak is expensive, but that only happens when there is
+    // an error, which should be (knock wood) rare;
+    // 3) Set and unset a thread_local value is cheap, too.
+    pub fn upgrade() -> DaemonEnv {
+        let env = Self::ENV.with(|env| env.borrow().clone());
+        DaemonEnv {
+            data: env.data.upgrade().unwrap(),
+            thread_env: env,
+        }
+    }
+
+    pub fn attach(self) {
+        Self::ENV.with(|env| env.replace(self));
+    }
+
+    pub fn detach() {
+        Self::ENV.with(|env| env.replace(Default::default()));
+    }
+}
+
+pub(crate) struct ThreadEnvGuard {}
+
+impl Drop for ThreadEnvGuard {
+    fn drop(&mut self) {
+        ThreadEnv::detach()
+    }
+}
+
+// TODO(ditsing): add tests.

+ 3 - 1
src/install_snapshot.rs

@@ -30,6 +30,9 @@ impl<C: Clone + Default + serde::Serialize> Raft<C> {
         &self,
         args: InstallSnapshotArgs,
     ) -> InstallSnapshotReply {
+        // Note: do not change this to `let _ = ...`.
+        let _guard = self.daemon_env.for_scope();
+
         if args.offset != 0 || !args.done {
             panic!("Current implementation cannot handle segmented snapshots.")
         }
@@ -84,7 +87,6 @@ impl<C: Clone + Default + serde::Serialize> Raft<C> {
             }
         } else {
             check_or_record!(
-                self.daemon_env,
                 args.last_included_index > rf.commit_index,
                 ErrorKind::SnapshotBeforeCommitted(
                     args.last_included_index,

+ 21 - 4
src/lib.rs

@@ -18,7 +18,7 @@ use rand::{thread_rng, Rng};
 
 use crate::apply_command::ApplyCommandFnMut;
 pub use crate::apply_command::ApplyCommandMessage;
-use crate::daemon_env::{DaemonEnv, ErrorKind};
+use crate::daemon_env::{DaemonEnv, ErrorKind, ThreadEnv};
 use crate::index_term::IndexTerm;
 use crate::install_snapshot::InstallSnapshotArgs;
 use crate::persister::PersistedRaftState;
@@ -176,10 +176,14 @@ where
         };
         election.reset_election_timer();
 
+        let daemon_env = DaemonEnv::create();
+        let thread_env = daemon_env.for_thread();
         let thread_pool = tokio::runtime::Builder::new_multi_thread()
             .enable_time()
             .thread_name(format!("raft-instance-{}", me))
             .worker_threads(peer_size)
+            .on_thread_start(move || thread_env.clone().attach())
+            .on_thread_stop(ThreadEnv::detach)
             .build()
             .expect("Creating thread pool should not fail");
         let peers = peers.into_iter().map(Arc::new).collect();
@@ -194,7 +198,7 @@ where
             election: Arc::new(election),
             snapshot_daemon: Default::default(),
             thread_pool: Arc::new(thread_pool),
-            daemon_env: Default::default(),
+            daemon_env,
             stop_wait_group: WaitGroup::new(),
         };
 
@@ -228,6 +232,9 @@ where
         &self,
         args: RequestVoteArgs,
     ) -> RequestVoteReply {
+        // Note: do not change this to `let _ = ...`.
+        let _guard = self.daemon_env.for_scope();
+
         let mut rf = self.inner_state.lock();
 
         let term = rf.current_term;
@@ -277,6 +284,9 @@ where
         &self,
         args: AppendEntriesArgs<Command>,
     ) -> AppendEntriesReply {
+        // Note: do not change this to `let _ = ...`.
+        let _guard = self.daemon_env.for_scope();
+
         let mut rf = self.inner_state.lock();
         if rf.current_term > args.term {
             return AppendEntriesReply {
@@ -313,7 +323,6 @@ where
             if rf.log.end() > index {
                 if rf.log[index].term != entry.term {
                     check_or_record!(
-                        self.daemon_env,
                         index > rf.commit_index,
                         ErrorKind::RollbackCommitted(index),
                         "Entries before commit index should never be rolled back",
@@ -373,6 +382,9 @@ where
     fn run_election_timer(&self) {
         let this = self.clone();
         let join_handle = std::thread::spawn(move || {
+            // Note: do not change this to `let _ = ...`.
+            let _guard = this.daemon_env.for_scope();
+
             let election = this.election.clone();
 
             let mut should_run = None;
@@ -675,6 +687,9 @@ where
         // Clone everything that the thread needs.
         let this = self.clone();
         let join_handle = std::thread::spawn(move || {
+            // Note: do not change this to `let _ = ...`.
+            let _guard = this.daemon_env.for_scope();
+
             let mut openings = vec![];
             openings.resize_with(this.peers.len(), || {
                 Opening(Arc::new(AtomicUsize::new(0)))
@@ -941,7 +956,6 @@ where
         self.apply_command_signal.notify_all();
         self.snapshot_daemon.kill();
         self.stop_wait_group.wait();
-        self.daemon_env.shutdown();
         std::sync::Arc::try_unwrap(self.thread_pool)
             .expect(
                 "All references to the thread pool should have been dropped.",
@@ -949,6 +963,9 @@ where
             .shutdown_timeout(Duration::from_millis(
                 HEARTBEAT_INTERVAL_MILLIS * 2,
             ));
+        // DaemonEnv must be shutdown after the thread pool, since there might
+        // be tasks logging errors in the pool.
+        self.daemon_env.shutdown();
     }
 
     pub fn get_state(&self) -> (Term, bool) {

+ 3 - 1
src/snapshot.rs

@@ -84,6 +84,9 @@ impl<C: 'static + Clone + Default + Send + serde::Serialize> Raft<C> {
         let stop_wait_group = self.stop_wait_group.clone();
 
         let join_handle = std::thread::spawn(move || loop {
+            // Note: do not change this to `let _ = ...`.
+            let _guard = daemon_env.for_scope();
+
             parker.park();
             if !keep_running.load(Ordering::SeqCst) {
                 // Explicitly drop every thing.
@@ -123,7 +126,6 @@ impl<C: 'static + Clone + Default + Send + serde::Serialize> Raft<C> {
                 }
 
                 check_or_record!(
-                    daemon_env,
                     snapshot.last_included_index < rf.log.end(),
                     ErrorKind::SnapshotAfterLogEnd(
                         snapshot.last_included_index,