Kaynağa Gözat

Add shutdown to snapshot holder and a new test.

Jing Yang 4 yıl önce
ebeveyn
işleme
cd5818f558
5 değiştirilmiş dosya ile 37 ekleme ve 4 silme
  1. 1 0
      kvraft/src/server.rs
  2. 16 1
      kvraft/src/snapshot_holder.rs
  3. 1 1
      src/lib.rs
  4. 8 2
      src/snapshot.rs
  5. 11 0
      tests/snapshot_tests.rs

+ 1 - 0
kvraft/src/server.rs

@@ -220,6 +220,7 @@ impl KVServer {
                     break;
                 }
             }
+            snapshot_holder.shutdown();
         });
     }
 

+ 16 - 1
kvraft/src/snapshot_holder.rs

@@ -6,16 +6,22 @@ use serde::Serialize;
 
 use ruaft::Snapshot;
 use serde::de::DeserializeOwned;
+use std::sync::atomic::{AtomicBool, Ordering};
 
 #[derive(Default)]
 pub(crate) struct SnapshotHolder<T> {
     snapshot_requests: Mutex<Vec<(usize, Arc<Condvar>)>>,
     current_snapshot: Mutex<Snapshot>,
+    shutdown: AtomicBool,
     phantom: PhantomData<T>,
 }
 
 impl<T> SnapshotHolder<T> {
     pub fn request_snapshot(&self, min_index: usize) -> Snapshot {
+        if self.shutdown.load(Ordering::SeqCst) {
+            return self.current_snapshot.lock().clone();
+        }
+
         let condvar = {
             let mut requests = self.snapshot_requests.lock();
             let pos =
@@ -37,7 +43,9 @@ impl<T> SnapshotHolder<T> {
 
         // Now wait for the snapshot
         let mut current_snapshot = self.current_snapshot.lock();
-        while current_snapshot.last_included_index < min_index {
+        while current_snapshot.last_included_index < min_index
+            && !self.shutdown.load(Ordering::SeqCst)
+        {
             condvar.wait(&mut current_snapshot);
         }
 
@@ -83,6 +91,13 @@ impl<T: Serialize> SnapshotHolder<T> {
         }
         requests.drain(0..processed);
     }
+
+    pub fn shutdown(&self) {
+        self.shutdown.store(true, Ordering::SeqCst);
+        for (_, condvar) in self.snapshot_requests.lock().iter() {
+            condvar.notify_all();
+        }
+    }
 }
 
 impl<T: DeserializeOwned> SnapshotHolder<T> {

+ 1 - 1
src/lib.rs

@@ -929,7 +929,7 @@ where
         self.election.stop_election_timer();
         self.new_log_entry.take().map(|n| n.send(None));
         self.apply_command_signal.notify_all();
-        self.snapshot_daemon.trigger();
+        self.snapshot_daemon.kill();
         self.stop_wait_group.wait();
         std::sync::Arc::try_unwrap(self.thread_pool)
             .expect(

+ 8 - 2
src/snapshot.rs

@@ -1,7 +1,9 @@
-use crate::{Index, Raft};
-use crossbeam_utils::sync::{Parker, Unparker};
 use std::sync::atomic::Ordering;
 
+use crossbeam_utils::sync::{Parker, Unparker};
+
+use crate::{Index, Raft};
+
 #[derive(Clone, Debug, Default)]
 pub struct Snapshot {
     pub last_included_index: Index,
@@ -27,6 +29,10 @@ impl SnapshotDaemon {
             None => {}
         }
     }
+
+    pub(crate) fn kill(&self) {
+        self.trigger();
+    }
 }
 
 impl<C: 'static + Clone + Default + Send + serde::Serialize> Raft<C> {

+ 11 - 0
tests/snapshot_tests.rs

@@ -3,6 +3,7 @@ extern crate kvraft;
 extern crate scopeguard;
 
 use kvraft::testing_utils::config::{make_config, sleep_election_timeouts};
+use kvraft::testing_utils::generic_test::{generic_test, GenericTestParams};
 use std::sync::Arc;
 
 #[test]
@@ -90,3 +91,13 @@ fn snapshot_size() {
 
     cfg.end();
 }
+
+#[test]
+fn snapshot_recover() {
+    generic_test(GenericTestParams {
+        clients: 1,
+        crash: true,
+        maxraftstate: Some(1000),
+        ..Default::default()
+    })
+}