Selaa lähdekoodia

Create a term marker that uses RPC responses to update term: #7.

Jing Yang 4 vuotta sitten
vanhempi
commit
89238d8801
5 muutettua tiedostoa jossa 77 lisäystä ja 9 poistoa
  1. 9 4
      src/election.rs
  2. 11 4
      src/heartbeats.rs
  3. 1 0
      src/lib.rs
  4. 6 1
      src/sync_log_entries.rs
  5. 50 0
      src/term_marker.rs

+ 9 - 4
src/election.rs

@@ -5,6 +5,7 @@ use std::time::{Duration, Instant};
 use parking_lot::{Condvar, Mutex};
 use rand::{thread_rng, Rng};
 
+use crate::term_marker::TermMarker;
 use crate::utils::{retry_rpc, RPC_DEADLINE};
 use crate::{Peer, Raft, RaftState, RequestVoteArgs, RpcClient, State, Term};
 
@@ -245,15 +246,17 @@ where
         };
 
         let mut votes = vec![];
+        let term_marker = self.term_marker();
         for (index, rpc_client) in self.peers.iter().enumerate() {
             if index != self.me.0 {
                 // RpcClient must be cloned so that it lives long enough for
                 // spawn(), which requires static life time.
-                let rpc_client = rpc_client.clone();
                 // RPCs are started right away.
-                let one_vote = self
-                    .thread_pool
-                    .spawn(Self::request_vote(rpc_client, args.clone()));
+                let one_vote = self.thread_pool.spawn(Self::request_vote(
+                    rpc_client.clone(),
+                    args.clone(),
+                    term_marker.clone(),
+                ));
                 votes.push(one_vote);
             }
         }
@@ -275,6 +278,7 @@ where
     async fn request_vote(
         rpc_client: Arc<RpcClient>,
         args: RequestVoteArgs,
+        term_marker: TermMarker<Command>,
     ) -> Option<bool> {
         let term = args.term;
         // See the comment in send_heartbeat() for this override.
@@ -285,6 +289,7 @@ where
             })
             .await;
         if let Ok(reply) = reply {
+            term_marker.mark(reply.term);
             return Some(reply.vote_granted && reply.term == term);
         }
         None

+ 11 - 4
src/heartbeats.rs

@@ -4,6 +4,7 @@ use std::time::Duration;
 
 use parking_lot::Mutex;
 
+use crate::term_marker::TermMarker;
 use crate::utils::{retry_rpc, RPC_DEADLINE};
 use crate::{AppendEntriesArgs, Raft, RaftState, RpcClient};
 
@@ -32,6 +33,8 @@ where
             if peer_index != self.me.0 {
                 // rf is now owned by the outer async function.
                 let rf = self.inner_state.clone();
+                // A function that updates term with responses to heartbeats.
+                let term_marker = self.term_marker();
                 // RPC client must be cloned into the outer async function.
                 let rpc_client = rpc_client.clone();
                 // Shutdown signal.
@@ -44,6 +47,7 @@ where
                             tokio::spawn(Self::send_heartbeat(
                                 rpc_client.clone(),
                                 args,
+                                term_marker.clone(),
                             ));
                         }
                     }
@@ -77,6 +81,7 @@ where
     async fn send_heartbeat(
         rpc_client: Arc<RpcClient>,
         args: AppendEntriesArgs<Command>,
+        term_watermark: TermMarker<Command>,
     ) -> std::io::Result<()> {
         // Passing a reference that is moved to the following closure.
         //
@@ -93,10 +98,12 @@ where
         // of type Arc can be passed-in directly. However that requires args to
         // be sync because they can be shared by more than one futures.
         let rpc_client = rpc_client.as_ref();
-        retry_rpc(Self::HEARTBEAT_RETRY, RPC_DEADLINE, move |_round| {
-            rpc_client.call_append_entries(args.clone())
-        })
-        .await?;
+        let response =
+            retry_rpc(Self::HEARTBEAT_RETRY, RPC_DEADLINE, move |_round| {
+                rpc_client.call_append_entries(args.clone())
+            })
+            .await?;
+        term_watermark.mark(response.term);
         Ok(())
     }
 }

+ 1 - 0
src/lib.rs

@@ -42,6 +42,7 @@ mod raft_state;
 pub mod rpcs;
 mod snapshot;
 mod sync_log_entries;
+mod term_marker;
 pub mod utils;
 
 #[derive(

+ 6 - 1
src/sync_log_entries.rs

@@ -7,6 +7,7 @@ use parking_lot::{Condvar, Mutex};
 use crate::check_or_record;
 use crate::daemon_env::ErrorKind;
 use crate::index_term::IndexTerm;
+use crate::term_marker::TermMarker;
 use crate::utils::{retry_rpc, RPC_DEADLINE};
 use crate::{
     AppendEntriesArgs, InstallSnapshotArgs, Peer, Raft, RaftState, RpcClient,
@@ -91,6 +92,7 @@ where
                                 this.new_log_entry.clone().unwrap(),
                                 openings[i].0.clone(),
                                 this.apply_command_signal.clone(),
+                                this.term_marker(),
                             ));
                         }
                     }
@@ -152,6 +154,7 @@ where
         rerun: std::sync::mpsc::Sender<Option<Peer>>,
         opening: Arc<AtomicUsize>,
         apply_command_signal: Arc<Condvar>,
+        term_marker: TermMarker<Command>,
     ) {
         if opening.swap(0, Ordering::SeqCst) == 0 {
             return;
@@ -274,7 +277,9 @@ where
                 let _ = rerun.send(Some(Peer(peer_index)));
             }
             // Do nothing, not our term anymore.
-            Ok(SyncLogEntriesResult::TermElapsed(_)) => {}
+            Ok(SyncLogEntriesResult::TermElapsed(term)) => {
+                term_marker.mark(term);
+            }
             Err(_) => {
                 tokio::time::sleep(Duration::from_millis(
                     HEARTBEAT_INTERVAL_MILLIS,

+ 50 - 0
src/term_marker.rs

@@ -0,0 +1,50 @@
+use std::sync::Arc;
+
+use parking_lot::Mutex;
+
+use crate::election::ElectionState;
+use crate::{Persister, Raft, RaftState, State, Term};
+use serde::Serialize;
+
+/// A closure that updates the `Term` of the `RaftState`.
+#[derive(Clone)]
+pub(crate) struct TermMarker<Command> {
+    rf: Arc<Mutex<RaftState<Command>>>,
+    election: Arc<ElectionState>,
+    persister: Arc<dyn Persister>,
+}
+
+impl<Command: Clone + Serialize> TermMarker<Command> {
+    pub fn mark(&self, term: Term) {
+        let mut rf = self.rf.lock();
+        mark_term(&mut rf, &self.election, self.persister.as_ref(), term)
+    }
+}
+
+impl<Command: Clone + Serialize> Raft<Command> {
+    /// Create a `TermMarker` that can be passed to tasks.
+    pub(crate) fn term_marker(&self) -> TermMarker<Command> {
+        TermMarker {
+            rf: self.inner_state.clone(),
+            election: self.election.clone(),
+            persister: self.persister.clone(),
+        }
+    }
+}
+
+/// Update the term of the `RaftState`.
+pub(crate) fn mark_term<Command: Clone + Serialize>(
+    rf: &mut RaftState<Command>,
+    election: &ElectionState,
+    persister: &dyn Persister,
+    term: Term,
+) {
+    if term > rf.current_term {
+        rf.current_term = term;
+        rf.voted_for = None;
+        rf.state = State::Follower;
+
+        election.reset_election_timer();
+        persister.save_state(rf.persisted_state().into());
+    }
+}