snapshot_holder.rs 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. use std::marker::PhantomData;
  2. use parking_lot::Mutex;
  3. use serde::de::DeserializeOwned;
  4. use serde::Serialize;
  5. use ruaft::Snapshot;
  6. #[derive(Default)]
  7. pub(crate) struct SnapshotHolder<T> {
  8. snapshot_requests: Mutex<Vec<usize>>,
  9. phantom: PhantomData<T>,
  10. }
  11. impl<T> SnapshotHolder<T> {
  12. pub fn request_snapshot(&self, min_index: usize) {
  13. let mut requests = self.snapshot_requests.lock();
  14. let pos = requests.binary_search(&min_index);
  15. if let Err(pos) = pos {
  16. requests.insert(pos, min_index);
  17. }
  18. }
  19. }
  20. impl<T: Serialize> SnapshotHolder<T> {
  21. pub fn take_snapshot(&self, state: &T, curr: usize) -> Option<Snapshot> {
  22. let mut requests = self.snapshot_requests.lock();
  23. let processed = requests.partition_point(|index| *index <= curr);
  24. if processed == 0 {
  25. return None;
  26. }
  27. requests.drain(0..processed);
  28. drop(requests);
  29. let data = bincode::serialize(state)
  30. .expect("Serialization should never fail.");
  31. Some(Snapshot {
  32. data,
  33. last_included_index: curr,
  34. })
  35. }
  36. }
  37. impl<T: DeserializeOwned> SnapshotHolder<T> {
  38. pub fn load_snapshot(&self, snapshot: Snapshot) -> T {
  39. if let Ok(result) = bincode::deserialize(&snapshot.data) {
  40. result
  41. } else {
  42. panic!("Deserialization should never fail, {:?}", snapshot.data)
  43. }
  44. }
  45. }