Parcourir la source

Add basic structure of raft states.

Jing Yang il y a 5 ans
Parent
commit
3178f92b72
3 fichiers modifiés avec 111 ajouts et 36 suppressions
  1. 2 2
      Cargo.toml
  2. 57 2
      src/lib.rs
  3. 52 32
      src/rpcs.rs

+ 2 - 2
Cargo.toml

@@ -7,12 +7,12 @@ edition = "2018"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+bincode = "1.3.1"
 labrpc = { path = "../labrpc" }
-tokio = "0.2.22"
 parking_lot = "0.11.0"
 serde = "1.0.116"
 serde_derive = "1.0.116"
-bincode = "1.3.1"
+tokio = { version = "0.2.22", features = ["rt-threaded", "time"] }
 
 [dev-dependencies]
 futures = { version = "0.3.5", features = ["thread-pool"] }

+ 57 - 2
src/lib.rs

@@ -6,9 +6,64 @@ extern crate labrpc;
 extern crate serde_derive;
 extern crate tokio;
 
-mod rpcs;
+use crate::rpcs::RpcClient;
+use parking_lot::{Condvar, Mutex};
+use std::sync::atomic::AtomicBool;
 
-struct Raft {}
+pub mod rpcs;
+
+enum State {
+    Follower,
+    Candidate,
+    // TODO: add PreVote
+    Leader,
+}
+
+// TODO: remove all of the defaults.
+impl Default for State {
+    fn default() -> Self {
+        Self::Leader
+    }
+}
+
+#[derive(Default)]
+struct RaftState {
+    current_term: usize,
+    voted_for: i64,
+    // TODO: Allow sending of arbitrary information.
+    log: Vec<usize>,
+
+    commit_index: usize,
+    last_applied: usize,
+
+    next_index: Vec<usize>,
+    match_index: Vec<usize>,
+    current_step: Vec<i64>,
+
+    state: State,
+
+    leader_id: usize,
+    // election_timer: timer,
+}
+
+#[derive(Default)]
+struct Raft {
+    inner_state: Mutex<RaftState>,
+    peers: RpcClient,
+
+    me: usize,
+
+    vote_mutex: Mutex<()>,
+    vote_cond: Condvar,
+
+    // new_log_entry: Sender<usize>,
+    // new_log_entry: Receiver<usize>,
+    // apply_command_cond: Condvar
+
+    keep_running: AtomicBool,
+
+    // applyCh: Sender<ApplyMsg>
+}
 
 #[derive(Serialize, Deserialize)]
 struct RequestVoteArgs {

+ 52 - 32
src/rpcs.rs

@@ -25,21 +25,6 @@ impl RpcHandler for RequestVoteRpcHandler {
     }
 }
 
-pub(crate) const REQUEST_VOTE_RPC: &'static str = "Raft.RequestVote";
-pub(crate) async fn call_request_vote(
-    client: &Client,
-    request: RequestVoteArgs,
-) -> std::io::Result<RequestVoteReply> {
-    let data = RequestMessage::from(
-        bincode::serialize(&request)
-            .expect("Serialization of requests should not fail"),
-    );
-    let reply = client.call_rpc(REQUEST_VOTE_RPC.to_owned(), data).await?;
-
-    Ok(bincode::deserialize(reply.as_ref())
-        .expect("Deserialization of reply should not fail"))
-}
-
 struct AppendEntriesRpcHandler(Arc<Raft>);
 
 impl RpcHandler for AppendEntriesRpcHandler {
@@ -55,21 +40,51 @@ impl RpcHandler for AppendEntriesRpcHandler {
     }
 }
 
+pub(crate) const REQUEST_VOTE_RPC: &'static str = "Raft.RequestVote";
 pub(crate) const APPEND_ENTRIES_RPC: &'static str = "Raft.AppendEntries";
-pub(crate) async fn call_append_entries(
-    client: &Client,
-    request: AppendEntriesArgs,
-) -> std::io::Result<AppendEntriesReply> {
-    let data = RequestMessage::from(
-        bincode::serialize(&request)
-            .expect("Serialization of requests should not fail"),
-    );
-    let reply = client.call_rpc(APPEND_ENTRIES_RPC.to_owned(), data).await?;
-
-    Ok(bincode::deserialize(reply.as_ref())
-        .expect("Deserialization of reply should not fail"))
+
+#[derive(Default)]
+pub(crate) struct RpcClient(Vec<Client>);
+
+impl RpcClient {
+    pub(crate) async fn call_request_vote(
+        &self,
+        client_index: usize,
+        request: RequestVoteArgs,
+    ) -> std::io::Result<RequestVoteReply> {
+        let data = RequestMessage::from(
+            bincode::serialize(&request)
+                .expect("Serialization of requests should not fail"),
+        );
+        let reply = self.0[client_index]
+            .call_rpc(REQUEST_VOTE_RPC.to_owned(), data)
+            .await?;
+
+        Ok(bincode::deserialize(reply.as_ref())
+            .expect("Deserialization of reply should not fail"))
+    }
+
+    pub(crate) async fn call_append_entries(
+        &self,
+        client_index: usize,
+        request: AppendEntriesArgs,
+    ) -> std::io::Result<AppendEntriesReply> {
+        let data = RequestMessage::from(
+            bincode::serialize(&request)
+                .expect("Serialization of requests should not fail"),
+        );
+        let reply = self.0[client_index]
+            .call_rpc(APPEND_ENTRIES_RPC.to_owned(), data)
+            .await?;
+
+        Ok(bincode::deserialize(reply.as_ref())
+            .expect("Deserialization of reply should not fail"))
+    }
 }
 
+unsafe impl Send for RpcClient {}
+unsafe impl Sync for RpcClient {}
+
 pub(crate) fn register_server<S: AsRef<str>>(
     raft: Arc<Raft>,
     name: S,
@@ -104,7 +119,9 @@ mod tests {
     fn test_basic_message() -> std::io::Result<()> {
         let client = {
             let network = Network::run_daemon();
-            let raft = Arc::new(Raft {});
+            let raft = Arc::new(Raft {
+                ..Default::default()
+            });
             let name = "test-basic-message";
 
             register_server(raft, name, network.clone())?;
@@ -115,14 +132,17 @@ mod tests {
             client
         };
 
+        let rpc_client = RpcClient(vec![client]);
         let request = RequestVoteArgs { term: 2021 };
-        let response =
-            futures::executor::block_on(call_request_vote(&client, request))?;
+        let response = futures::executor::block_on(
+            rpc_client.call_request_vote(0, request),
+        )?;
         assert_eq!(2022, response.term);
 
         let request = AppendEntriesArgs { term: 2021 };
-        let response =
-            futures::executor::block_on(call_append_entries(&client, request))?;
+        let response = futures::executor::block_on(
+            rpc_client.call_append_entries(0, request),
+        )?;
         assert_eq!(2020, response.term);
 
         Ok(())