snapshot_holder.rs 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. use std::marker::PhantomData;
  2. use parking_lot::Mutex;
  3. use serde::Serialize;
  4. use ruaft::Snapshot;
  5. use serde::de::DeserializeOwned;
  6. #[derive(Default)]
  7. pub(crate) struct SnapshotHolder<T> {
  8. snapshot_requests: Mutex<Vec<usize>>,
  9. phantom: PhantomData<T>,
  10. }
  11. impl<T> SnapshotHolder<T> {
  12. pub fn request_snapshot(&self, min_index: usize) {
  13. let mut requests = self.snapshot_requests.lock();
  14. let pos = requests.binary_search(&min_index);
  15. if let Err(pos) = pos {
  16. requests.insert(pos, min_index);
  17. }
  18. }
  19. }
  20. impl<T: Serialize> SnapshotHolder<T> {
  21. pub fn take_snapshot(&self, state: &T, curr: usize) -> Option<Snapshot> {
  22. let requested = self
  23. .snapshot_requests
  24. .lock()
  25. .first()
  26. .map_or(false, |&min_index| min_index <= curr);
  27. if requested {
  28. let data = bincode::serialize(state)
  29. .expect("Serialization should never fail.");
  30. return Some(Snapshot {
  31. data,
  32. last_included_index: curr,
  33. });
  34. }
  35. None
  36. }
  37. pub fn unblock_response(&self, curr: usize) {
  38. let mut requests = self.snapshot_requests.lock();
  39. let mut processed = 0;
  40. for &index in requests.iter() {
  41. if index <= curr {
  42. processed += 1;
  43. } else {
  44. break;
  45. }
  46. }
  47. requests.drain(0..processed);
  48. }
  49. }
  50. impl<T: DeserializeOwned> SnapshotHolder<T> {
  51. pub fn load_snapshot(&self, snapshot: Snapshot) -> T {
  52. let state = bincode::deserialize(&snapshot.data).expect(&*format!(
  53. "Deserialization should never fail, {:?}",
  54. &snapshot.data
  55. ));
  56. state
  57. }
  58. }