Procházet zdrojové kódy

Merge branch 'snapshot': first version of snapshot taking.

Jing Yang před 5 lety
rodič
revize
5760aa3932
7 změněných souborů, 342 přidání a 19 odebrání
  1. 107 0
      src/install_snapshot.rs
  2. 66 14
      src/lib.rs
  3. 3 0
      src/persister.rs
  4. 55 1
      src/rpcs.rs
  5. 86 0
      src/snapshot.rs
  6. 13 4
      tests/config/mod.rs
  7. 12 0
      tests/config/persister/mod.rs

+ 107 - 0
src/install_snapshot.rs

@@ -0,0 +1,107 @@
+use crate::utils::retry_rpc;
+use crate::{
+    Index, Peer, Raft, RaftState, RpcClient, State, Term, RPC_DEADLINE,
+};
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub(crate) struct InstallSnapshotArgs {
+    pub(crate) term: Term,
+    leader_id: Peer,
+    pub(crate) last_included_index: Index,
+    last_included_term: Term,
+    // TODO(ditsing): Serde cannot handle Vec<u8> as efficient as expected.
+    data: Vec<u8>,
+    offset: usize,
+    done: bool,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub(crate) struct InstallSnapshotReply {
+    term: Term,
+}
+
+impl<C: Clone + Default + serde::Serialize> Raft<C> {
+    pub(crate) fn process_install_snapshot(
+        &self,
+        args: InstallSnapshotArgs,
+    ) -> InstallSnapshotReply {
+        if args.offset != 0 || !args.done {
+            panic!("Current implementation cannot handle segmented snapshots.")
+        }
+
+        let mut rf = self.inner_state.lock();
+        if rf.current_term > args.term {
+            return InstallSnapshotReply {
+                term: rf.current_term,
+            };
+        }
+
+        if rf.current_term < args.term {
+            rf.current_term = args.term;
+            rf.voted_for = None;
+            self.persister.save_state(rf.persisted_state().into());
+        }
+
+        rf.state = State::Follower;
+        rf.leader_id = args.leader_id;
+
+        self.election.reset_election_timer();
+
+        // The above code is exactly the same as AppendEntries.
+
+        if args.last_included_index < rf.log.end()
+            && args.last_included_index >= rf.log.start()
+            && args.last_included_term == rf.log[args.last_included_index].term
+        {
+            rf.log.shift(args.last_included_index, args.data);
+        } else {
+            rf.log.reset(
+                args.last_included_index,
+                args.last_included_term,
+                args.data,
+            );
+        }
+        // The length of the log might shrink.
+        let last_log_index = rf.log.last_index_term().index;
+        if rf.commit_index > last_log_index {
+            rf.commit_index = last_log_index;
+        }
+        self.persister.save_snapshot_and_state(
+            rf.persisted_state().into(),
+            rf.log.snapshot().1,
+        );
+
+        self.apply_command_signal.notify_one();
+        InstallSnapshotReply { term: args.term }
+    }
+
+    pub(crate) fn build_install_snapshot(
+        rf: &RaftState<C>,
+    ) -> InstallSnapshotArgs {
+        let (last, snapshot) = rf.log.snapshot();
+        InstallSnapshotArgs {
+            term: rf.current_term,
+            leader_id: rf.leader_id,
+            last_included_index: last.index,
+            last_included_term: last.term,
+            data: snapshot.to_owned(),
+            offset: 0,
+            done: true,
+        }
+    }
+
+    const INSTALL_SNAPSHOT_RETRY: usize = 1;
+    pub(crate) async fn send_install_snapshot(
+        rpc_client: &RpcClient,
+        args: InstallSnapshotArgs,
+    ) -> std::io::Result<Option<bool>> {
+        let term = args.term;
+        let reply = retry_rpc(
+            Self::INSTALL_SNAPSHOT_RETRY,
+            RPC_DEADLINE,
+            move |_round| rpc_client.call_install_snapshot(args.clone()),
+        )
+        .await?;
+        Ok(if reply.term == term { Some(true) } else { None })
+    }
+}

+ 66 - 14
src/lib.rs

@@ -16,18 +16,23 @@ use crossbeam_utils::sync::WaitGroup;
 use parking_lot::{Condvar, Mutex};
 use rand::{thread_rng, Rng};
 
+use crate::install_snapshot::InstallSnapshotArgs;
 use crate::persister::PersistedRaftState;
 pub use crate::persister::Persister;
 pub(crate) use crate::raft_state::RaftState;
 pub(crate) use crate::raft_state::State;
 pub use crate::rpcs::RpcClient;
+pub use crate::snapshot::Snapshot;
+use crate::snapshot::SnapshotDaemon;
 use crate::utils::retry_rpc;
 
 mod index_term;
+mod install_snapshot;
 mod log_array;
 mod persister;
 mod raft_state;
 pub mod rpcs;
+mod snapshot;
 pub mod utils;
 
 #[derive(
@@ -66,6 +71,7 @@ pub struct Raft<Command> {
     apply_command_signal: Arc<Condvar>,
     keep_running: Arc<AtomicBool>,
     election: Arc<ElectionState>,
+    snapshot_daemon: SnapshotDaemon,
 
     thread_pool: Arc<tokio::runtime::Runtime>,
 
@@ -125,14 +131,17 @@ where
     ///
     /// Each instance will create at least 3 + (number of peers) threads. The
     /// extensive usage of threads is to minimize latency.
-    pub fn new<Func>(
+    pub fn new<ApplyCommandFunc, RequestSnapshotFunc>(
         peers: Vec<RpcClient>,
         me: usize,
         persister: Arc<dyn Persister>,
-        apply_command: Func,
+        apply_command: ApplyCommandFunc,
+        max_state_size_bytes: Option<usize>,
+        request_snapshot: RequestSnapshotFunc,
     ) -> Self
     where
-        Func: 'static + Send + FnMut(Index, Command),
+        ApplyCommandFunc: 'static + Send + FnMut(Index, Command),
+        RequestSnapshotFunc: 'static + Send + FnMut(Index) -> Snapshot,
     {
         let peer_size = peers.len();
         let mut state = RaftState {
@@ -179,6 +188,7 @@ where
             apply_command_signal: Arc::new(Default::default()),
             keep_running: Arc::new(Default::default()),
             election: Arc::new(election),
+            snapshot_daemon: Default::default(),
             thread_pool: Arc::new(thread_pool),
             stop_wait_group: WaitGroup::new(),
         };
@@ -195,6 +205,7 @@ where
         ));
         // The last step is to start running election timer.
         this.run_election_timer();
+        this.run_snapshot_daemon(max_state_size_bytes, request_snapshot);
         this
     }
 }
@@ -318,6 +329,12 @@ where
     }
 }
 
+enum SyncLogEntryOperation<Command> {
+    AppendEntries(AppendEntriesArgs<Command>),
+    InstallSnapshot(InstallSnapshotArgs),
+    None,
+}
+
 // Command must be
 // 0. 'static: Raft<Command> must be 'static, it is moved to another thread.
 // 1. clone: they are copied to the persister.
@@ -683,13 +700,26 @@ where
             return;
         }
 
-        let args = match Self::build_append_entries(&rf, peer_index) {
-            Some(args) => args,
-            None => return,
+        let operation = Self::build_sync_log_entry(&rf, peer_index);
+        let (term, match_index, succeeded) = match operation {
+            SyncLogEntryOperation::AppendEntries(args) => {
+                let term = args.term;
+                let match_index = args.prev_log_index + args.entries.len();
+                let succeeded = Self::append_entries(&rpc_client, args).await;
+
+                (term, match_index, succeeded)
+            }
+            SyncLogEntryOperation::InstallSnapshot(args) => {
+                let term = args.term;
+                let match_index = args.last_included_index;
+                let succeeded =
+                    Self::send_install_snapshot(&rpc_client, args).await;
+
+                (term, match_index, succeeded)
+            }
+            SyncLogEntryOperation::None => return,
         };
-        let term = args.term;
-        let match_index = args.prev_log_index + args.entries.len();
-        let succeeded = Self::append_entries(&rpc_client, args).await;
+
         match succeeded {
             Ok(Some(true)) => {
                 let mut rf = rf.lock();
@@ -748,24 +778,43 @@ where
         };
     }
 
-    fn build_append_entries(
+    fn build_sync_log_entry(
         rf: &Mutex<RaftState<Command>>,
         peer_index: usize,
-    ) -> Option<AppendEntriesArgs<Command>> {
+    ) -> SyncLogEntryOperation<Command> {
         let rf = rf.lock();
         if !rf.is_leader() {
-            return None;
+            return SyncLogEntryOperation::None;
         }
+
+        // To send AppendEntries request, next_index must be strictly larger
+        // than start(). Otherwise we won't be able to know the log term of the
+        // entry right before next_index.
+        return if rf.next_index[peer_index] > rf.log.start() {
+            SyncLogEntryOperation::AppendEntries(Self::build_append_entries(
+                &rf, peer_index,
+            ))
+        } else {
+            SyncLogEntryOperation::InstallSnapshot(
+                Self::build_install_snapshot(&rf),
+            )
+        };
+    }
+
+    fn build_append_entries(
+        rf: &RaftState<Command>,
+        peer_index: usize,
+    ) -> AppendEntriesArgs<Command> {
         let prev_log_index = rf.next_index[peer_index] - 1;
         let prev_log_term = rf.log[prev_log_index].term;
-        Some(AppendEntriesArgs {
+        AppendEntriesArgs {
             term: rf.current_term,
             leader_id: rf.leader_id,
             prev_log_index,
             prev_log_term,
             entries: rf.log.after(rf.next_index[peer_index]).to_vec(),
             leader_commit: rf.commit_index,
-        })
+        }
     }
 
     const APPEND_ENTRIES_RETRY: usize = 1;
@@ -797,6 +846,7 @@ where
         let keep_running = self.keep_running.clone();
         let rf = self.inner_state.clone();
         let condvar = self.apply_command_signal.clone();
+        let snapshot_daemon = self.snapshot_daemon.clone();
         let stop_wait_group = self.stop_wait_group.clone();
         std::thread::spawn(move || {
             while keep_running.load(Ordering::SeqCst) {
@@ -827,6 +877,7 @@ where
                 // Release the lock while calling external functions.
                 for command in commands {
                     apply_command(index, command);
+                    snapshot_daemon.trigger();
                     index += 1;
                 }
             }
@@ -855,6 +906,7 @@ where
         self.election.stop_election_timer();
         self.new_log_entry.take().map(|n| n.send(None));
         self.apply_command_signal.notify_all();
+        self.snapshot_daemon.trigger();
         self.stop_wait_group.wait();
         std::sync::Arc::try_unwrap(self.thread_pool)
             .expect(

+ 3 - 0
src/persister.rs

@@ -9,6 +9,9 @@ use serde::Serialize;
 pub trait Persister: Send + Sync {
     fn read_state(&self) -> Bytes;
     fn save_state(&self, bytes: Bytes);
+    fn state_size(&self) -> usize;
+
+    fn save_snapshot_and_state(&self, state: Bytes, snapshot: &[u8]);
 }
 
 #[derive(Serialize, Deserialize)]

+ 55 - 1
src/rpcs.rs

@@ -3,6 +3,7 @@ use std::sync::Arc;
 use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
 use parking_lot::Mutex;
 
+use crate::install_snapshot::{InstallSnapshotArgs, InstallSnapshotReply};
 use crate::{
     AppendEntriesArgs, AppendEntriesReply, Raft, RequestVoteArgs,
     RequestVoteReply,
@@ -41,8 +42,23 @@ fn proxy_append_entries<
     )
 }
 
+fn proxy_install_snapshot<Command: Clone + Serialize + Default>(
+    raft: &Raft<Command>,
+    data: RequestMessage,
+) -> ReplyMessage {
+    let reply = raft.process_install_snapshot(
+        bincode::deserialize(data.as_ref())
+            .expect("Deserialization should not fail"),
+    );
+
+    ReplyMessage::from(
+        bincode::serialize(&reply).expect("Serialization should not fail"),
+    )
+}
+
 pub(crate) const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
 pub(crate) const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
+pub(crate) const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
 
 pub struct RpcClient(Client);
 
@@ -81,6 +97,24 @@ impl RpcClient {
         Ok(bincode::deserialize(reply.as_ref())
             .expect("Deserialization of reply should not fail"))
     }
+
+    pub(crate) async fn call_install_snapshot(
+        &self,
+        request: InstallSnapshotArgs,
+    ) -> std::io::Result<InstallSnapshotReply> {
+        let data = RequestMessage::from(
+            bincode::serialize(&request)
+                .expect("Serialization of requests should not fail"),
+        );
+
+        let reply = self
+            .0
+            .call_rpc(INSTALL_SNAPSHOT_RPC.to_owned(), data)
+            .await?;
+
+        Ok(bincode::deserialize(reply.as_ref())
+            .expect("Deserialization of reply should not fail"))
+    }
 }
 
 pub fn register_server<
@@ -103,7 +137,7 @@ pub fn register_server<
         }),
     )?;
 
-    let raft_clone = raft;
+    let raft_clone = raft.clone();
     server.register_rpc_handler(
         APPEND_ENTRIES_RPC.to_owned(),
         Box::new(move |request| {
@@ -111,6 +145,14 @@ pub fn register_server<
         }),
     )?;
 
+    let raft_clone = raft;
+    server.register_rpc_handler(
+        INSTALL_SNAPSHOT_RPC.to_owned(),
+        Box::new(move |request| {
+            proxy_install_snapshot(raft_clone.as_ref(), request)
+        }),
+    )?;
+
     network.add_server(server_name, server);
 
     Ok(())
@@ -123,6 +165,7 @@ mod tests {
     use crate::{Peer, Term};
 
     use super::*;
+    use crate::snapshot::Snapshot;
 
     type DoNothingPersister = ();
     impl crate::Persister for DoNothingPersister {
@@ -131,6 +174,12 @@ mod tests {
         }
 
         fn save_state(&self, _bytes: Bytes) {}
+
+        fn state_size(&self) -> usize {
+            0
+        }
+
+        fn save_snapshot_and_state(&self, _: Bytes, _: &[u8]) {}
     }
 
     #[test]
@@ -148,6 +197,11 @@ mod tests {
                 0,
                 Arc::new(()),
                 |_, _: i32| {},
+                None,
+                |index| Snapshot {
+                    last_included_index: index,
+                    data: vec![],
+                },
             ));
             register_server(raft, name, network.as_ref())?;
 

+ 86 - 0
src/snapshot.rs

@@ -0,0 +1,86 @@
+use crate::{Index, Raft};
+use crossbeam_utils::sync::{Parker, Unparker};
+use std::sync::atomic::Ordering;
+
+pub struct Snapshot {
+    pub last_included_index: Index,
+    pub data: Vec<u8>,
+}
+
+#[derive(Clone, Debug, Default)]
+pub(crate) struct SnapshotDaemon {
+    unparker: Option<Unparker>,
+}
+
+impl SnapshotDaemon {
+    pub(crate) fn trigger(&self) {
+        match &self.unparker {
+            Some(unparker) => unparker.unpark(),
+            None => {}
+        }
+    }
+}
+
+impl<C: 'static + Clone + Default + Send + serde::Serialize> Raft<C> {
+    pub(crate) fn run_snapshot_daemon<Func>(
+        &mut self,
+        max_state_size: Option<usize>,
+        mut request_snapshot: Func,
+    ) where
+        Func: 'static + Send + FnMut(Index) -> Snapshot,
+    {
+        let max_state_size = match max_state_size {
+            Some(max_state_size) => max_state_size,
+            None => return,
+        };
+
+        let parker = Parker::new();
+        let unparker = parker.unparker().clone();
+        self.snapshot_daemon.unparker.replace(unparker.clone());
+
+        let keep_running = self.keep_running.clone();
+        let rf = self.inner_state.clone();
+        let persister = self.persister.clone();
+        let stop_wait_group = self.stop_wait_group.clone();
+
+        std::thread::spawn(move || loop {
+            parker.park();
+            if !keep_running.load(Ordering::SeqCst) {
+                drop(stop_wait_group);
+                break;
+            }
+            if persister.state_size() >= max_state_size {
+                let log_start = rf.lock().log.first_index_term();
+                let snapshot = request_snapshot(log_start.index + 1);
+
+                let mut rf = rf.lock();
+                if rf.log.first_index_term() != log_start {
+                    // Another snapshot was installed, let's try again.
+                    unparker.unpark();
+                    continue;
+                }
+                if snapshot.last_included_index <= rf.log.start() {
+                    // It seems the request_snapshot callback is misbehaving,
+                    // let's try again.
+                    unparker.unpark();
+                    continue;
+                }
+
+                if snapshot.last_included_index >= rf.log.end() {
+                    // We recently rolled back some of the committed logs. This
+                    // can happen but usually the same exact log entries will be
+                    // installed in the next AppendEntries request.
+                    // There is no need to retry, because when the log entries
+                    // are re-committed, we will be notified again.
+                    continue;
+                }
+
+                rf.log.shift(snapshot.last_included_index, snapshot.data);
+                persister.save_snapshot_and_state(
+                    rf.persisted_state().into(),
+                    rf.log.snapshot().1,
+                );
+            }
+        });
+    }
+}

+ 13 - 4
tests/config/mod.rs

@@ -8,7 +8,7 @@ use rand::{thread_rng, Rng};
 use tokio::time::Duration;
 
 use ruaft::rpcs::register_server;
-use ruaft::{Persister, Raft, RpcClient};
+use ruaft::{Persister, Raft, RpcClient, Snapshot};
 
 pub mod persister;
 
@@ -307,10 +307,19 @@ impl Config {
         let persister = self.log.lock().saved[index].clone();
 
         let log_clone = self.log.clone();
-        let raft =
-            Raft::new(clients, index, persister, move |cmd_index, cmd| {
+        let raft = Raft::new(
+            clients,
+            index,
+            persister,
+            move |cmd_index, cmd| {
                 Self::apply_command(log_clone.clone(), index, cmd_index, cmd)
-            });
+            },
+            None,
+            |index| Snapshot {
+                last_included_index: index,
+                data: vec![],
+            },
+        );
         self.state.lock().rafts[index].replace(raft.clone());
 
         let raft = Arc::new(raft);

+ 12 - 0
tests/config/persister/mod.rs

@@ -2,6 +2,7 @@ use parking_lot::Mutex;
 
 struct State {
     bytes: bytes::Bytes,
+    snapshot: Vec<u8>,
 }
 
 pub struct Persister {
@@ -13,6 +14,7 @@ impl Persister {
         Self {
             state: Mutex::new(State {
                 bytes: bytes::Bytes::new(),
+                snapshot: vec![],
             }),
         }
     }
@@ -26,4 +28,14 @@ impl ruaft::Persister for Persister {
     fn save_state(&self, data: bytes::Bytes) {
         self.state.lock().bytes = data;
     }
+
+    fn state_size(&self) -> usize {
+        self.state.lock().bytes.len()
+    }
+
+    fn save_snapshot_and_state(&self, state: bytes::Bytes, snapshot: &[u8]) {
+        let mut this = self.state.lock();
+        this.bytes = state;
+        this.snapshot = snapshot.to_vec();
+    }
 }