Browse Source

Adapt to the new snapshot framework.

Jing Yang 4 năm trước cách đây
mục cha
commit
967dae925a
2 tập tin đã thay đổi với 27 bổ sung69 xóa
  1. 6 4
      kvraft/src/server.rs
  2. 21 65
      kvraft/src/snapshot_holder.rs

+ 6 - 4
kvraft/src/server.rs

@@ -211,16 +211,18 @@ impl KVServer {
                                 command.me,
                                 command.op,
                             );
-                            snapshot_holder
-                                .take_snapshot(&this.state.lock(), index);
-                            snapshot_holder.unblock_response();
+                            if let Some(snapshot) = snapshot_holder
+                                .take_snapshot(&this.state.lock(), index)
+                            {
+                                this.rf.lock().save_snapshot(snapshot);
+                                snapshot_holder.unblock_response(index);
+                            }
                         }
                     }
                 } else {
                     break;
                 }
             }
-            snapshot_holder.shutdown();
         });
     }
 

+ 21 - 65
kvraft/src/snapshot_holder.rs

@@ -1,110 +1,66 @@
 use std::marker::PhantomData;
-use std::sync::Arc;
 
-use parking_lot::{Condvar, Mutex};
+use parking_lot::Mutex;
 use serde::Serialize;
 
 use ruaft::Snapshot;
 use serde::de::DeserializeOwned;
-use std::sync::atomic::{AtomicBool, Ordering};
 
 #[derive(Default)]
 pub(crate) struct SnapshotHolder<T> {
-    snapshot_requests: Mutex<Vec<(usize, Arc<Condvar>)>>,
-    current_snapshot: Mutex<Snapshot>,
-    shutdown: AtomicBool,
+    snapshot_requests: Mutex<Vec<usize>>,
     phantom: PhantomData<T>,
 }
 
 impl<T> SnapshotHolder<T> {
-    pub fn request_snapshot(&self, min_index: usize) -> Snapshot {
-        if self.shutdown.load(Ordering::SeqCst) {
-            return self.current_snapshot.lock().clone();
-        }
-
-        let condvar = {
-            let mut requests = self.snapshot_requests.lock();
-            let pos =
-                requests.binary_search_by_key(&min_index, |&(index, _)| index);
-            match pos {
-                Ok(pos) => requests[pos].1.clone(),
-                Err(pos) => {
-                    assert!(pos == 0 || requests[pos - 1].0 < min_index);
-                    assert!(
-                        pos + 1 >= requests.len()
-                            || requests[pos + 1].0 > min_index
-                    );
-                    let condvar = Arc::new(Condvar::new());
-                    requests.insert(pos, (min_index, condvar.clone()));
-                    condvar
-                }
-            }
-        };
-
-        // Now wait for the snapshot
-        let mut current_snapshot = self.current_snapshot.lock();
-        while current_snapshot.last_included_index < min_index
-            && !self.shutdown.load(Ordering::SeqCst)
-        {
-            condvar.wait(&mut current_snapshot);
+    pub fn request_snapshot(&self, min_index: usize) {
+        let mut requests = self.snapshot_requests.lock();
+        let pos = requests.binary_search(&min_index);
+        if let Err(pos) = pos {
+            requests.insert(pos, min_index);
         }
-
-        current_snapshot.clone()
     }
 }
 
 impl<T: Serialize> SnapshotHolder<T> {
-    const MIN_SNAPSHOT_INDEX_INTERVAL: usize = 100;
-    pub fn take_snapshot(&self, state: &T, curr: usize) {
-        let expired = self.current_snapshot.lock().last_included_index
-            + Self::MIN_SNAPSHOT_INDEX_INTERVAL
-            <= curr;
+    pub fn take_snapshot(&self, state: &T, curr: usize) -> Option<Snapshot> {
         let requested = self
             .snapshot_requests
             .lock()
             .first()
-            .map_or(false, |&(min_index, _)| min_index <= curr);
+            .map_or(false, |&min_index| min_index <= curr);
 
-        if expired || requested {
+        if requested {
             let data = bincode::serialize(state)
                 .expect("Serialization should never fail.");
-            let mut current_snapshot = self.current_snapshot.lock();
-            assert!(current_snapshot.last_included_index < curr);
-            *current_snapshot = Snapshot {
-                last_included_index: curr,
+            return Some(Snapshot {
                 data,
-            }
+                last_included_index: curr,
+            });
         }
+        None
     }
 
-    pub fn unblock_response(&self) {
-        let curr = self.current_snapshot.lock().last_included_index;
+    pub fn unblock_response(&self, curr: usize) {
         let mut requests = self.snapshot_requests.lock();
         let mut processed = 0;
-        for (index, condvar) in requests.iter() {
-            if *index <= curr {
+        for &index in requests.iter() {
+            if index <= curr {
                 processed += 1;
-                condvar.notify_all();
             } else {
                 break;
             }
         }
         requests.drain(0..processed);
     }
-
-    pub fn shutdown(&self) {
-        self.shutdown.store(true, Ordering::SeqCst);
-        for (_, condvar) in self.snapshot_requests.lock().iter() {
-            condvar.notify_all();
-        }
-    }
 }
 
 impl<T: DeserializeOwned> SnapshotHolder<T> {
     pub fn load_snapshot(&self, snapshot: Snapshot) -> T {
-        let state = bincode::deserialize(&snapshot.data)
-            .expect("Deserialization should never fail");
-        *self.current_snapshot.lock() = snapshot;
+        let state = bincode::deserialize(&snapshot.data).expect(&*format!(
+            "Deserialization should never fail, {:?}",
+            &snapshot.data
+        ));
 
         state
     }