Browse Source

Implement run election.

Jing Yang 5 years ago
parent
commit
7ad4bd0a21
3 changed files with 180 additions and 30 deletions
  1. 2 1
      Cargo.toml
  2. 166 11
      src/lib.rs
  3. 12 18
      src/rpcs.rs

+ 2 - 1
Cargo.toml

@@ -8,12 +8,13 @@ edition = "2018"
 
 [dependencies]
 bincode = "1.3.1"
+futures = { version = "0.3.5" }
 labrpc = { path = "../labrpc" }
 parking_lot = "0.11.0"
 rand = "0.7.3"
 serde = "1.0.116"
 serde_derive = "1.0.116"
-tokio = { version = "0.2.22", features = ["rt-threaded", "time"] }
+tokio = { version = "0.2.22", features = ["rt-threaded", "sync", "time"] }
 
 [dev-dependencies]
 futures = { version = "0.3.5", features = ["thread-pool"] }

+ 166 - 11
src/lib.rs

@@ -1,20 +1,28 @@
 #![allow(unused)]
 
 extern crate bincode;
+extern crate futures;
 extern crate labrpc;
 extern crate rand;
 #[macro_use]
 extern crate serde_derive;
 extern crate tokio;
 
-use crate::rpcs::RpcClient;
+use std::future::Future;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+
+use futures::FutureExt;
 use parking_lot::{Condvar, Mutex};
 use rand::{thread_rng, Rng};
-use std::sync::atomic::AtomicBool;
-use tokio::time::Duration;
+
+use crate::rpcs::RpcClient;
+use std::cell::RefCell;
 
 pub mod rpcs;
 
+#[derive(Eq, PartialEq)]
 enum State {
     Follower,
     Candidate,
@@ -73,20 +81,20 @@ struct RaftState {
     state: State,
 
     leader_id: Peer,
-    // Timer will be removed upon shutdown.
+
+    // Current election cancel token, might be None if no election is running.
+    election_cancel_token: Option<tokio::sync::oneshot::Sender<Term>>,
+    // Timer will be removed upon shutdown or elected.
     election_timer: Option<tokio::time::Delay>,
 }
 
 #[derive(Default)]
 struct Raft {
-    inner_state: Mutex<RaftState>,
-    peers: RpcClient,
+    inner_state: Arc<Mutex<RaftState>>,
+    peers: Vec<RpcClient>,
 
     me: Peer,
 
-    vote_mutex: Mutex<()>,
-    vote_cond: Condvar,
-
     // new_log_entry: Sender<usize>,
     // new_log_entry: Receiver<usize>,
     // apply_command_cond: Condvar
@@ -153,8 +161,8 @@ impl Raft {
             rf.current_term = args.term;
             rf.voted_for = None;
             rf.state = State::Follower;
-            // TODO: quit current election
             rf.reset_election_timer();
+            rf.stop_current_election();
             rf.persist();
         }
 
@@ -168,6 +176,7 @@ impl Raft {
         {
             rf.voted_for = Some(args.candidate_id);
             rf.reset_election_timer();
+            // No need to stop the election. We are not a candidate.
             rf.persist();
 
             RequestVoteReply {
@@ -201,7 +210,7 @@ impl Raft {
 
         rf.state = State::Follower;
         rf.reset_election_timer();
-        // TODO: stop previous election
+        rf.stop_current_election();
         rf.leader_id = args.leader_id;
 
         if rf.log.len() <= args.prev_log_index
@@ -239,6 +248,146 @@ impl Raft {
             success: true,
         }
     }
+
+    async fn retry_rpc<Func, Fut, T>(
+        max_retry: usize,
+        mut task_gen: Func,
+    ) -> std::io::Result<T>
+    where
+        Fut: Future<Output = std::io::Result<T>> + Send + 'static,
+        Func: FnMut(usize) -> Fut,
+    {
+        for i in 0..max_retry {
+            if let Ok(reply) = task_gen(i).await {
+                return Ok(reply);
+            }
+            tokio::time::delay_for(tokio::time::Duration::from_millis(
+                (1 << i) * 10,
+            ))
+            .await;
+        }
+        Err(std::io::Error::new(
+            std::io::ErrorKind::TimedOut,
+            format!("Timed out after {} retries", max_retry),
+        ))
+    }
+
+    fn run_election(&self) {
+        let (term, last_log_index, last_log_term, cancel_token) = {
+            let mut rf = self.inner_state.lock();
+
+            let (tx, rx) = tokio::sync::oneshot::channel();
+            rf.current_term.0 += 1;
+
+            rf.voted_for = Some(self.me);
+            rf.state = State::Candidate;
+            rf.reset_election_timer();
+            rf.stop_current_election();
+
+            rf.election_cancel_token.replace(tx);
+
+            rf.persist();
+
+            (
+                rf.current_term,
+                rf.log.len() - 1,
+                rf.log.last().unwrap().term,
+                rx,
+            )
+        };
+
+        let me = self.me;
+
+        let mut votes = vec![];
+        for i in 0..self.peers.len() {
+            if i != self.me.0 {
+                // Make a clone now so that self will not be passed across await
+                // boundary.
+                let rpc_client = self.peers[i].clone();
+                let one_vote = async move {
+                    let reply_future = Self::retry_rpc(4, move |_round| {
+                        rpc_client.clone().call_request_vote(RequestVoteArgs {
+                            term,
+                            candidate_id: me,
+                            last_log_index,
+                            last_log_term,
+                        })
+                    });
+                    if let Ok(reply) = reply_future.await {
+                        return Some(reply.vote_granted && reply.term == term);
+                    }
+                    return None;
+                };
+                // Futures must be pinned so that they have Unpin, as required
+                // by futures::future::select.
+                votes.push(Box::pin(one_vote));
+            }
+        }
+
+        tokio::spawn(Self::count_vote_util_cancelled(
+            term,
+            self.inner_state.clone(),
+            votes,
+            self.peers.len() / 2,
+            cancel_token,
+        ));
+    }
+
+    async fn count_vote_util_cancelled(
+        term: Term,
+        rf: Arc<Mutex<RaftState>>,
+        votes: Vec<impl Future<Output = Option<bool>> + Unpin>,
+        majority: usize,
+        cancel_token: tokio::sync::oneshot::Receiver<Term>,
+    ) {
+        let mut vote_count = 0;
+        let mut against_count = 0;
+        let mut cancel_token = cancel_token;
+        let mut futures_vec = votes;
+        while vote_count < majority && against_count <= majority {
+            // Mixing tokio futures with futures-rs ones. Fingers crossed.
+            let selected = futures::future::select(
+                cancel_token,
+                futures::future::select_all(futures_vec),
+            )
+            .await;
+            let ((one_vote, index, rest), new_token) = match selected {
+                futures::future::Either::Left(_) => break,
+                futures::future::Either::Right(tuple) => tuple,
+            };
+
+            futures_vec = rest;
+            cancel_token = new_token;
+
+            if let Some(vote) = one_vote {
+                if vote {
+                    vote_count += 1
+                } else {
+                    against_count += 1
+                }
+            }
+        }
+
+        if vote_count < majority {
+            return;
+        }
+        let mut rf = rf.lock();
+        if rf.current_term == term && rf.state == State::Candidate {
+            rf.state = State::Leader;
+        }
+        let log_len = rf.log.len();
+        for item in rf.next_index.iter_mut() {
+            *item = log_len;
+        }
+        for item in rf.match_index.iter_mut() {
+            *item = 0;
+        }
+        // TODO: send heartbeats.
+        // Drop the timer and cancel token.
+        rf.election_cancel_token.take();
+        rf.election_timer.take();
+        rf.persist();
+    }
 }
 
 const HEARTBEAT_INTERVAL_MILLIS: u64 = 150;
@@ -259,6 +408,12 @@ impl RaftState {
         )
     }
 
+    fn stop_current_election(&mut self) {
+        self.election_cancel_token
+            .take()
+            .map(|sender| sender.send(self.current_term));
+    }
+
     fn persist(&self) {
         // TODO: implement
     }

+ 12 - 18
src/rpcs.rs

@@ -43,48 +43,42 @@ impl RpcHandler for AppendEntriesRpcHandler {
 pub(crate) const REQUEST_VOTE_RPC: &'static str = "Raft.RequestVote";
 pub(crate) const APPEND_ENTRIES_RPC: &'static str = "Raft.AppendEntries";
 
-#[derive(Default)]
-pub(crate) struct RpcClient(Vec<Client>);
+#[derive(Clone)]
+pub(crate) struct RpcClient(Client);
 
 impl RpcClient {
     pub(crate) async fn call_request_vote(
-        &self,
-        client_index: usize,
+        self: Self,
         request: RequestVoteArgs,
     ) -> std::io::Result<RequestVoteReply> {
         let data = RequestMessage::from(
             bincode::serialize(&request)
                 .expect("Serialization of requests should not fail"),
         );
-        let reply = self.0[client_index]
-            .call_rpc(REQUEST_VOTE_RPC.to_owned(), data)
-            .await?;
+
+        let reply = self.0.call_rpc(REQUEST_VOTE_RPC.to_owned(), data).await?;
 
         Ok(bincode::deserialize(reply.as_ref())
             .expect("Deserialization of reply should not fail"))
     }
 
     pub(crate) async fn call_append_entries(
-        &self,
-        client_index: usize,
+        self: Self,
         request: AppendEntriesArgs,
     ) -> std::io::Result<AppendEntriesReply> {
         let data = RequestMessage::from(
             bincode::serialize(&request)
                 .expect("Serialization of requests should not fail"),
         );
-        let reply = self.0[client_index]
-            .call_rpc(APPEND_ENTRIES_RPC.to_owned(), data)
-            .await?;
+
+        let reply =
+            self.0.call_rpc(APPEND_ENTRIES_RPC.to_owned(), data).await?;
 
         Ok(bincode::deserialize(reply.as_ref())
             .expect("Deserialization of reply should not fail"))
     }
 }
 
-unsafe impl Send for RpcClient {}
-unsafe impl Sync for RpcClient {}
-
 pub(crate) fn register_server<S: AsRef<str>>(
     raft: Arc<Raft>,
     name: S,
@@ -132,7 +126,7 @@ mod tests {
             client
         };
 
-        let rpc_client = RpcClient(vec![client]);
+        let rpc_client = RpcClient(client);
         let request = RequestVoteArgs {
             term: Term(2021),
 
@@ -141,7 +135,7 @@ mod tests {
             last_log_term: Default::default(),
         };
         let response = futures::executor::block_on(
-            rpc_client.call_request_vote(0, request),
+            rpc_client.clone().call_request_vote(request),
         )?;
         assert_eq!(true, response.vote_granted);
 
@@ -154,7 +148,7 @@ mod tests {
             leader_commit: 0,
         };
         let response = futures::executor::block_on(
-            rpc_client.call_append_entries(0, request),
+            rpc_client.clone().call_append_entries(request),
         )?;
         assert_eq!(2021, response.term.0);
         assert_eq!(false, response.success);