Просмотр исходного кода

Replace term marker with a static remote context.

Jing Yang 3 лет назад
Родитель
Сommit
ff45298c67
8 измененных файлов с 95 добавлено и 32 удалено
  1. 2 5
      src/election.rs
  2. 2 6
      src/heartbeats.rs
  3. 1 0
      src/lib.rs
  4. 21 5
      src/raft.rs
  5. 53 0
      src/remote_context.rs
  6. 0 0
      src/remote_peer.rs
  7. 2 4
      src/sync_log_entries.rs
  8. 14 12
      src/term_marker.rs

+ 2 - 5
src/election.rs

@@ -5,8 +5,8 @@ use std::time::{Duration, Instant};
 use parking_lot::{Condvar, Mutex};
 use rand::{thread_rng, Rng};
 
+use crate::remote_context::RemoteContext;
 use crate::sync_log_entries::SyncLogEntriesComms;
-use crate::term_marker::TermMarker;
 use crate::utils::{retry_rpc, RPC_DEADLINE};
 use crate::verify_authority::VerifyAuthorityDaemon;
 use crate::{
@@ -253,7 +253,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
         };
 
         let mut votes = vec![];
-        let term_marker = self.term_marker();
         for (index, rpc_client) in self.peers.iter().enumerate() {
             if index != self.me.0 {
                 // RpcClient must be cloned so that it lives long enough for
@@ -262,7 +261,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
                 let one_vote = self.thread_pool.spawn(Self::request_vote(
                     rpc_client.clone(),
                     args.clone(),
-                    term_marker.clone(),
                 ));
                 votes.push(one_vote);
             }
@@ -287,7 +285,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
     async fn request_vote(
         rpc_client: impl RemoteRaft<Command>,
         args: RequestVoteArgs,
-        term_marker: TermMarker<Command>,
     ) -> Option<bool> {
         let term = args.term;
         // See the comment in send_heartbeat() for this override.
@@ -298,7 +295,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
             })
             .await;
         if let Ok(reply) = reply {
-            term_marker.mark(reply.term);
+            RemoteContext::<Command>::term_marker().mark(reply.term);
             return Some(reply.vote_granted && reply.term == term);
         }
         None

+ 2 - 6
src/heartbeats.rs

@@ -4,7 +4,7 @@ use std::time::{Duration, Instant};
 
 use parking_lot::Mutex;
 
-use crate::term_marker::TermMarker;
+use crate::remote_context::RemoteContext;
 use crate::utils::{retry_rpc, RPC_DEADLINE};
 use crate::verify_authority::DaemonBeatTicker;
 use crate::{
@@ -74,8 +74,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
             if peer_index != self.me.0 {
                 // rf is now owned by the outer async function.
                 let rf = self.inner_state.clone();
-                // A function that updates term with responses to heartbeats.
-                let term_marker = self.term_marker();
                 // A function that casts an "authoritative" vote with Ok()
                 // responses to heartbeats.
                 let beat_ticker = self.beat_ticker(peer_index);
@@ -97,7 +95,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
                             tokio::spawn(Self::send_heartbeat(
                                 rpc_client.clone(),
                                 args,
-                                term_marker.clone(),
                                 beat_ticker.clone(),
                             ));
                         }
@@ -138,7 +135,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
         // the RPC client.
         rpc_client: impl RemoteRaft<Command>,
         args: AppendEntriesArgs<Command>,
-        term_watermark: TermMarker<Command>,
         beat_ticker: DaemonBeatTicker,
     ) -> std::io::Result<()> {
         let term = args.term;
@@ -166,7 +162,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
         if term == response.term {
             beat_ticker.tick(beat);
         } else {
-            term_watermark.mark(response.term);
+            RemoteContext::<Command>::term_marker().mark(response.term);
         }
         Ok(())
     }

+ 1 - 0
src/lib.rs

@@ -29,6 +29,7 @@ mod process_install_snapshot;
 mod process_request_vote;
 mod raft;
 mod raft_state;
+mod remote_context;
 mod remote_raft;
 mod replicable_command;
 mod snapshot;

+ 21 - 5
src/raft.rs

@@ -12,8 +12,10 @@ use crate::daemon_watch::{Daemon, DaemonWatch};
 use crate::election::ElectionState;
 use crate::heartbeats::{HeartbeatsDaemon, HEARTBEAT_INTERVAL};
 use crate::persister::PersistedRaftState;
+use crate::remote_context::RemoteContext;
 use crate::snapshot::{RequestSnapshotFnMut, SnapshotDaemon};
 use crate::sync_log_entries::SyncLogEntriesComms;
+use crate::term_marker::TermMarker;
 use crate::verify_authority::VerifyAuthorityDaemon;
 use crate::{IndexTerm, Persister, RaftState, RemoteRaft, ReplicableCommand};
 
@@ -88,9 +90,17 @@ impl<Command: ReplicableCommand> Raft<Command> {
                 .expect("Persisted log should not contain error");
         }
 
-        let election = ElectionState::create();
+        let inner_state = Arc::new(Mutex::new(state));
+        let election = Arc::new(ElectionState::create());
         election.reset_election_timer();
 
+        let term_marker = TermMarker::create(
+            inner_state.clone(),
+            election.clone(),
+            persister.clone(),
+        );
+        let context = RemoteContext::create(term_marker);
+
         let daemon_env = DaemonEnv::create();
         let thread_env = daemon_env.for_thread();
         let thread_pool = tokio::runtime::Builder::new_multi_thread()
@@ -98,8 +108,14 @@ impl<Command: ReplicableCommand> Raft<Command> {
             .enable_io()
             .thread_name(format!("raft-instance-{}", me))
             .worker_threads(peer_size)
-            .on_thread_start(move || thread_env.clone().attach())
-            .on_thread_stop(ThreadEnv::detach)
+            .on_thread_start(move || {
+                context.clone().attach();
+                thread_env.clone().attach();
+            })
+            .on_thread_stop(move || {
+                RemoteContext::<Command>::detach();
+                ThreadEnv::detach();
+            })
             .build()
             .expect("Creating thread pool should not fail");
         let peers = peers
@@ -110,14 +126,14 @@ impl<Command: ReplicableCommand> Raft<Command> {
             crate::sync_log_entries::create(peer_size);
 
         let mut this = Raft {
-            inner_state: Arc::new(Mutex::new(state)),
+            inner_state,
             peers,
             me: Peer(me),
             persister,
             sync_log_entries_comms,
             apply_command_signal: Arc::new(Condvar::new()),
             keep_running: Arc::new(AtomicBool::new(true)),
-            election: Arc::new(election),
+            election,
             snapshot_daemon: SnapshotDaemon::create(),
             verify_authority_daemon: VerifyAuthorityDaemon::create(peer_size),
             heartbeats_daemon: HeartbeatsDaemon::create(),

+ 53 - 0
src/remote_context.rs

@@ -0,0 +1,53 @@
+use std::any::Any;
+use std::cell::RefCell;
+
+use crate::term_marker::TermMarker;
+
+#[derive(Clone)]
+pub(crate) struct RemoteContext<Command> {
+    term_marker: TermMarker<Command>,
+}
+
+impl<Command: 'static> RemoteContext<Command> {
+    pub fn create(term_marker: TermMarker<Command>) -> Self {
+        Self { term_marker }
+    }
+
+    pub fn term_marker() -> &'static TermMarker<Command> {
+        &Self::fetch_context().term_marker
+    }
+
+    thread_local! {
+        // Using a pointer to expose a static reference.
+        // Using Any to mask the fact that we are storing a generic struct.
+        static REMOTE_CONTEXT: RefCell<*mut dyn Any> = RefCell::new(
+            std::ptr::null_mut::<()>() as *mut dyn Any);
+    }
+
+    pub fn attach(self) {
+        Self::set_context(Box::new(self))
+    }
+
+    pub fn detach() -> Box<Self> {
+        let static_context = Self::fetch_context();
+        unsafe { Box::from_raw(static_context) }
+    }
+
+    fn set_context(context: Box<Self>) {
+        let context_ptr = Box::into_raw(context);
+        let any_ptr: *mut dyn Any = context_ptr;
+        Self::REMOTE_CONTEXT.with(|context| *context.borrow_mut() = any_ptr);
+    }
+
+    fn fetch_context() -> &'static mut Self {
+        let any_ptr = Self::REMOTE_CONTEXT.with(|context| *context.borrow());
+        if any_ptr.is_null() {
+            panic!("Context is not set");
+        }
+        unsafe {
+            (*any_ptr)
+                .downcast_mut::<Self>()
+                .expect("Context is set to the wrong type.")
+        }
+    }
+}

+ 0 - 0
src/remote_peer.rs


+ 2 - 4
src/sync_log_entries.rs

@@ -6,7 +6,7 @@ use parking_lot::{Condvar, Mutex};
 use crate::daemon_env::ErrorKind;
 use crate::heartbeats::HEARTBEAT_INTERVAL;
 use crate::peer_progress::PeerProgress;
-use crate::term_marker::TermMarker;
+use crate::remote_context::RemoteContext;
 use crate::utils::{retry_rpc, SharedSender, RPC_DEADLINE};
 use crate::verify_authority::DaemonBeatTicker;
 use crate::{
@@ -148,7 +148,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
                                 this.sync_log_entries_comms.clone(),
                                 progress.clone(),
                                 this.apply_command_signal.clone(),
-                                this.term_marker(),
                                 this.beat_ticker(i),
                                 TaskNumber(task_number),
                             ));
@@ -208,7 +207,6 @@ impl<Command: ReplicableCommand> Raft<Command> {
         comms: SyncLogEntriesComms,
         progress: PeerProgress,
         apply_command_signal: Arc<Condvar>,
-        term_marker: TermMarker<Command>,
         beat_ticker: DaemonBeatTicker,
         task_number: TaskNumber,
     ) {
@@ -367,7 +365,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
             }
             // Do nothing, not our term anymore.
             Ok(SyncLogEntriesResult::TermElapsed(term)) => {
-                term_marker.mark(term);
+                RemoteContext::<Command>::term_marker().mark(term);
             }
             Err(_) => {
                 tokio::time::sleep(HEARTBEAT_INTERVAL).await;

+ 14 - 12
src/term_marker.rs

@@ -4,7 +4,7 @@ use parking_lot::Mutex;
 use serde::Serialize;
 
 use crate::election::ElectionState;
-use crate::{Persister, Raft, RaftState, State, Term};
+use crate::{Persister, RaftState, State, Term};
 
 /// A closure that updates the `Term` of the `RaftState`.
 #[derive(Clone)]
@@ -15,23 +15,25 @@ pub(crate) struct TermMarker<Command> {
 }
 
 impl<Command: Clone + Serialize> TermMarker<Command> {
+    /// Create a `TermMarker` that can be passed to async tasks.
+    pub fn create(
+        rf: Arc<Mutex<RaftState<Command>>>,
+        election: Arc<ElectionState>,
+        persister: Arc<dyn Persister>,
+    ) -> Self {
+        Self {
+            rf,
+            election,
+            persister,
+        }
+    }
+
     pub fn mark(&self, term: Term) {
         let mut rf = self.rf.lock();
         mark_term(&mut rf, &self.election, self.persister.as_ref(), term)
     }
 }
 
-impl<Command: Clone + Serialize> Raft<Command> {
-    /// Create a `TermMarker` that can be passed to tasks.
-    pub(crate) fn term_marker(&self) -> TermMarker<Command> {
-        TermMarker {
-            rf: self.inner_state.clone(),
-            election: self.election.clone(),
-            persister: self.persister.clone(),
-        }
-    }
-}
-
 /// Update the term of the `RaftState`.
 pub(crate) fn mark_term<Command: Clone + Serialize>(
     rf: &mut RaftState<Command>,