|
|
@@ -1,110 +1,66 @@
|
|
|
use std::marker::PhantomData;
|
|
|
-use std::sync::Arc;
|
|
|
|
|
|
-use parking_lot::{Condvar, Mutex};
|
|
|
+use parking_lot::Mutex;
|
|
|
use serde::Serialize;
|
|
|
|
|
|
use ruaft::Snapshot;
|
|
|
use serde::de::DeserializeOwned;
|
|
|
-use std::sync::atomic::{AtomicBool, Ordering};
|
|
|
|
|
|
#[derive(Default)]
|
|
|
pub(crate) struct SnapshotHolder<T> {
|
|
|
- snapshot_requests: Mutex<Vec<(usize, Arc<Condvar>)>>,
|
|
|
- current_snapshot: Mutex<Snapshot>,
|
|
|
- shutdown: AtomicBool,
|
|
|
+ snapshot_requests: Mutex<Vec<usize>>,
|
|
|
phantom: PhantomData<T>,
|
|
|
}
|
|
|
|
|
|
impl<T> SnapshotHolder<T> {
|
|
|
- pub fn request_snapshot(&self, min_index: usize) -> Snapshot {
|
|
|
- if self.shutdown.load(Ordering::SeqCst) {
|
|
|
- return self.current_snapshot.lock().clone();
|
|
|
- }
|
|
|
-
|
|
|
- let condvar = {
|
|
|
- let mut requests = self.snapshot_requests.lock();
|
|
|
- let pos =
|
|
|
- requests.binary_search_by_key(&min_index, |&(index, _)| index);
|
|
|
- match pos {
|
|
|
- Ok(pos) => requests[pos].1.clone(),
|
|
|
- Err(pos) => {
|
|
|
- assert!(pos == 0 || requests[pos - 1].0 < min_index);
|
|
|
- assert!(
|
|
|
- pos + 1 >= requests.len()
|
|
|
- || requests[pos + 1].0 > min_index
|
|
|
- );
|
|
|
- let condvar = Arc::new(Condvar::new());
|
|
|
- requests.insert(pos, (min_index, condvar.clone()));
|
|
|
- condvar
|
|
|
- }
|
|
|
- }
|
|
|
- };
|
|
|
-
|
|
|
- // Now wait for the snapshot
|
|
|
- let mut current_snapshot = self.current_snapshot.lock();
|
|
|
- while current_snapshot.last_included_index < min_index
|
|
|
- && !self.shutdown.load(Ordering::SeqCst)
|
|
|
- {
|
|
|
- condvar.wait(&mut current_snapshot);
|
|
|
+ pub fn request_snapshot(&self, min_index: usize) {
|
|
|
+ let mut requests = self.snapshot_requests.lock();
|
|
|
+ let pos = requests.binary_search(&min_index);
|
|
|
+ if let Err(pos) = pos {
|
|
|
+ requests.insert(pos, min_index);
|
|
|
}
|
|
|
-
|
|
|
- current_snapshot.clone()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
impl<T: Serialize> SnapshotHolder<T> {
|
|
|
- const MIN_SNAPSHOT_INDEX_INTERVAL: usize = 100;
|
|
|
- pub fn take_snapshot(&self, state: &T, curr: usize) {
|
|
|
- let expired = self.current_snapshot.lock().last_included_index
|
|
|
- + Self::MIN_SNAPSHOT_INDEX_INTERVAL
|
|
|
- <= curr;
|
|
|
+ pub fn take_snapshot(&self, state: &T, curr: usize) -> Option<Snapshot> {
|
|
|
let requested = self
|
|
|
.snapshot_requests
|
|
|
.lock()
|
|
|
.first()
|
|
|
- .map_or(false, |&(min_index, _)| min_index <= curr);
|
|
|
+ .map_or(false, |&min_index| min_index <= curr);
|
|
|
|
|
|
- if expired || requested {
|
|
|
+ if requested {
|
|
|
let data = bincode::serialize(state)
|
|
|
.expect("Serialization should never fail.");
|
|
|
- let mut current_snapshot = self.current_snapshot.lock();
|
|
|
- assert!(current_snapshot.last_included_index < curr);
|
|
|
- *current_snapshot = Snapshot {
|
|
|
- last_included_index: curr,
|
|
|
+ return Some(Snapshot {
|
|
|
data,
|
|
|
- }
|
|
|
+ last_included_index: curr,
|
|
|
+ });
|
|
|
}
|
|
|
+ None
|
|
|
}
|
|
|
|
|
|
- pub fn unblock_response(&self) {
|
|
|
- let curr = self.current_snapshot.lock().last_included_index;
|
|
|
+ pub fn unblock_response(&self, curr: usize) {
|
|
|
let mut requests = self.snapshot_requests.lock();
|
|
|
let mut processed = 0;
|
|
|
- for (index, condvar) in requests.iter() {
|
|
|
- if *index <= curr {
|
|
|
+ for &index in requests.iter() {
|
|
|
+ if index <= curr {
|
|
|
processed += 1;
|
|
|
- condvar.notify_all();
|
|
|
} else {
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
requests.drain(0..processed);
|
|
|
}
|
|
|
-
|
|
|
- pub fn shutdown(&self) {
|
|
|
- self.shutdown.store(true, Ordering::SeqCst);
|
|
|
- for (_, condvar) in self.snapshot_requests.lock().iter() {
|
|
|
- condvar.notify_all();
|
|
|
- }
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
impl<T: DeserializeOwned> SnapshotHolder<T> {
|
|
|
pub fn load_snapshot(&self, snapshot: Snapshot) -> T {
|
|
|
- let state = bincode::deserialize(&snapshot.data)
|
|
|
- .expect("Deserialization should never fail");
|
|
|
- *self.current_snapshot.lock() = snapshot;
|
|
|
+ let state = bincode::deserialize(&snapshot.data).expect(&*format!(
|
|
|
+ "Deserialization should never fail, {:?}",
|
|
|
+ &snapshot.data
|
|
|
+ ));
|
|
|
|
|
|
state
|
|
|
}
|