浏览代码

Implement install snapshot.

Jing Yang 4 年之前
父节点
当前提交
eb08693e11
共有 3 个文件被更改,包括 43 次插入11 次删除
  1. 31 10
      kvraft/src/server.rs
  2. 11 0
      kvraft/src/snapshot_holder.rs
  3. 1 1
      tests/snapshot_tests.rs

+ 31 - 10
kvraft/src/server.rs

@@ -7,7 +7,7 @@ use std::time::Duration;
 
 use parking_lot::{Condvar, Mutex};
 
-use ruaft::{Persister, Raft, RpcClient, Term};
+use ruaft::{ApplyCommandMessage, Persister, Raft, RpcClient, Term};
 
 use crate::common::{
     ClerkId, GetArgs, GetReply, KVError, PutAppendArgs, PutAppendEnum,
@@ -22,8 +22,6 @@ pub struct KVServer {
     // snapshot
 }
 
-type IndexedCommand = (usize, UniqueKVOp);
-
 #[derive(Clone, Default, Serialize, Deserialize)]
 pub struct UniqueKVOp {
     op: KVOp,
@@ -98,8 +96,8 @@ impl KVServer {
         max_state_size_bytes: Option<usize>,
     ) -> Arc<Self> {
         let (tx, rx) = channel();
-        let apply_command = move |index, command| {
-            tx.send((index, command))
+        let apply_command = move |message| {
+            tx.send(message)
                 .expect("The receiving end of apply command channel should have not been dropped");
         };
         let snapshot_holder = Arc::new(SnapshotHolder::default());
@@ -182,18 +180,41 @@ impl KVServer {
         };
     }
 
+    fn restore_state(&self, mut new_state: KVServerState) {
+        let mut state = self.state.lock();
+        std::mem::swap(&mut new_state, &mut *state);
+
+        for result_holder in new_state.queries.values() {
+            *result_holder.result.lock() = Err(CommitError::NotLeader);
+            result_holder.condvar.notify_all();
+        }
+    }
+
     fn process_command(
         self: &Arc<Self>,
         snapshot_holder: Arc<SnapshotHolder<KVServerState>>,
-        command_channel: Receiver<IndexedCommand>,
+        command_channel: Receiver<ApplyCommandMessage<UniqueKVOp>>,
     ) {
         let this = Arc::downgrade(self);
         std::thread::spawn(move || {
-            while let Ok((index, command)) = command_channel.recv() {
+            while let Ok(message) = command_channel.recv() {
                 if let Some(this) = this.upgrade() {
-                    this.apply_op(command.unique_id, command.me, command.op);
-                    snapshot_holder.take_snapshot(&this.state.lock(), index);
-                    snapshot_holder.unblock_response();
+                    match message {
+                        ApplyCommandMessage::Snapshot(snapshot) => {
+                            let state = snapshot_holder.load_snapshot(snapshot);
+                            this.restore_state(state);
+                        }
+                        ApplyCommandMessage::Command(index, command) => {
+                            this.apply_op(
+                                command.unique_id,
+                                command.me,
+                                command.op,
+                            );
+                            snapshot_holder
+                                .take_snapshot(&this.state.lock(), index);
+                            snapshot_holder.unblock_response();
+                        }
+                    }
                 } else {
                     break;
                 }

+ 11 - 0
kvraft/src/snapshot_holder.rs

@@ -5,6 +5,7 @@ use parking_lot::{Condvar, Mutex};
 use serde::Serialize;
 
 use ruaft::Snapshot;
+use serde::de::DeserializeOwned;
 
 #[derive(Default)]
 pub(crate) struct SnapshotHolder<T> {
@@ -83,3 +84,13 @@ impl<T: Serialize> SnapshotHolder<T> {
         requests.drain(0..processed);
     }
 }
+
+impl<T: DeserializeOwned> SnapshotHolder<T> {
+    pub fn load_snapshot(&self, snapshot: Snapshot) -> T {
+        let state = bincode::deserialize(&snapshot.data)
+            .expect("Deserialization should never fail");
+        *self.current_snapshot.lock() = snapshot;
+
+        state
+    }
+}

+ 1 - 1
tests/snapshot_tests.rs

@@ -10,7 +10,7 @@ fn install_snapshot_rpc() {
     const SERVERS: usize = 3;
     const MAX_RAFT_STATE: usize = 1000;
     const KEY: &str = "a";
-    let cfg = Arc::new(make_config(SERVERS, true, MAX_RAFT_STATE));
+    let cfg = Arc::new(make_config(SERVERS, false, MAX_RAFT_STATE));
     defer!(cfg.clean_up());
 
     let mut clerk = cfg.make_clerk();