Przeglądaj źródła

Migrate log array to generics.

Jing Yang 5 lat temu
rodzic
commit
b8e2904c31
4 zmienionych plików z 42 dodań i 35 usunięć
  1. 9 4
      src/lib.rs
  2. 27 27
      src/log_array.rs
  3. 1 1
      src/raft_state.rs
  4. 5 3
      src/rpcs.rs

+ 9 - 4
src/lib.rs

@@ -153,7 +153,8 @@ where
         {
             state.current_term = persisted_state.current_term;
             state.voted_for = persisted_state.voted_for;
-            state.log = log_array::LogArray::restore(persisted_state.log).unwrap();
+            state.log =
+                log_array::LogArray::restore(persisted_state.log).unwrap();
         }
 
         let election = ElectionState {
@@ -201,9 +202,10 @@ where
 // Command must be
 // 1. clone: they are copied to the persister.
 // 2. serialize: they are converted to bytes to persist.
+// 3. default: a default value is used as the first element of the log.
 impl<Command> Raft<Command>
 where
-    Command: Clone + serde::Serialize,
+    Command: Clone + serde::Serialize + Default,
 {
     pub(crate) fn process_request_vote(
         &self,
@@ -321,9 +323,10 @@ where
 // 1. clone: they are copied to the persister.
 // 2. send: Arc<Mutex<Vec<LogEntry<Command>>>> must be send, it is moved to another thread.
 // 3. serialize: they are converted to bytes to persist.
+// 4. default: a default value is used as the first element of log.
 impl<Command> Raft<Command>
 where
-    Command: 'static + Clone + Send + serde::Serialize,
+    Command: 'static + Clone + Send + serde::Serialize + Default,
 {
     fn run_election_timer(&self) -> std::thread::JoinHandle<()> {
         let this = self.clone();
@@ -807,7 +810,9 @@ where
                     if rf.last_applied < rf.commit_index {
                         let index = rf.last_applied + 1;
                         let last_one = rf.commit_index + 1;
-                        let commands: Vec<Command> = rf.log.between(index, last_one)
+                        let commands: Vec<Command> = rf
+                            .log
+                            .between(index, last_one)
                             .iter()
                             .map(|entry| entry.command.clone())
                             .collect();

+ 27 - 27
src/log_array.rs

@@ -1,13 +1,13 @@
-use crate::{Command, Index, LogEntry, Term};
+use crate::{Index, LogEntry, Term};
 use std::mem::swap;
 
-pub(crate) struct LogArray {
-    inner: Vec<LogEntry>,
+pub(crate) struct LogArray<C> {
+    inner: Vec<LogEntry<C>>,
     snapshot: bytes::Bytes,
 }
 
-impl LogArray {
-    pub fn create() -> LogArray {
+impl<C: Default> LogArray<C> {
+    pub fn create() -> LogArray<C> {
         let ret = LogArray {
             inner: vec![Self::build_first_entry(0, Term(0))],
             snapshot: bytes::Bytes::new(),
@@ -16,7 +16,7 @@ impl LogArray {
         ret
     }
 
-    pub fn restore(inner: Vec<LogEntry>) -> std::io::Result<Self> {
+    pub fn restore(inner: Vec<LogEntry<C>>) -> std::io::Result<Self> {
         Ok(LogArray {
             inner,
             snapshot: bytes::Bytes::new(),
@@ -25,7 +25,7 @@ impl LogArray {
 }
 
 // Log accessors
-impl LogArray {
+impl<C> LogArray<C> {
     pub fn start_offset(&self) -> Index {
         self.first_entry().index
     }
@@ -45,23 +45,23 @@ impl LogArray {
         (last_entry.index, last_entry.term)
     }
 
-    pub fn at(&self, index: Index) -> &LogEntry {
+    pub fn at(&self, index: Index) -> &LogEntry<C> {
         let index = self.check_start_index(index);
         &self.inner[index]
     }
 
-    pub fn after(&self, index: Index) -> &[LogEntry] {
+    pub fn after(&self, index: Index) -> &[LogEntry<C>] {
         let index = self.check_start_index(index);
         &self.inner[index..]
     }
 
-    pub fn between(&self, start: Index, end: Index) -> &[LogEntry] {
+    pub fn between(&self, start: Index, end: Index) -> &[LogEntry<C>] {
         let start = self.check_start_index(start);
         let end = self.check_end_index(end);
         &self.inner[start..end]
     }
 
-    pub fn all(&self) -> &[LogEntry] {
+    pub fn all(&self) -> &[LogEntry<C>] {
         &self.inner[..]
     }
 
@@ -71,8 +71,8 @@ impl LogArray {
     }
 }
 
-impl std::ops::Index<usize> for LogArray {
-    type Output = LogEntry;
+impl<C> std::ops::Index<usize> for LogArray<C> {
+    type Output = LogEntry<C>;
 
     fn index(&self, index: usize) -> &Self::Output {
         self.at(index)
@@ -80,8 +80,8 @@ impl std::ops::Index<usize> for LogArray {
 }
 
 // Mutations
-impl LogArray {
-    pub fn add(&mut self, term: Term, command: Command) -> Index {
+impl<C> LogArray<C> {
+    pub fn add(&mut self, term: Term, command: C) -> Index {
         let index = self.len();
         self.push(LogEntry {
             index,
@@ -91,13 +91,9 @@ impl LogArray {
         index
     }
 
-    pub fn push(&mut self, log_entry: LogEntry) {
+    pub fn push(&mut self, log_entry: LogEntry<C>) {
         let index = log_entry.index;
-        assert_eq!(
-            index,
-            self.len(),
-            "Expecting new index to be exact at len",
-        );
+        assert_eq!(index, self.len(), "Expecting new index to be exact at len");
         self.inner.push(log_entry);
         assert_eq!(
             index + 1,
@@ -117,7 +113,9 @@ impl LogArray {
         self.inner.truncate(index);
         self.check_one_element()
     }
+}
 
+impl<C: Default> LogArray<C> {
     #[allow(dead_code)]
     pub fn shift(&mut self, index: Index, snapshot: bytes::Bytes) {
         // Discard everything before index and store the snapshot.
@@ -145,7 +143,7 @@ impl LogArray {
         index: Index,
         term: Term,
         snapshot: bytes::Bytes,
-    ) -> Vec<LogEntry> {
+    ) -> Vec<LogEntry<C>> {
         let mut inner = vec![Self::build_first_entry(index, term)];
         swap(&mut inner, &mut self.inner);
         self.snapshot = snapshot;
@@ -156,14 +154,14 @@ impl LogArray {
     }
 }
 
-impl LogArray {
-    fn first_entry(&self) -> &LogEntry {
+impl<C> LogArray<C> {
+    fn first_entry(&self) -> &LogEntry<C> {
         self.inner
             .first()
             .expect("There must be at least one element in log")
     }
 
-    fn last_entry(&self) -> &LogEntry {
+    fn last_entry(&self) -> &LogEntry<C> {
         &self
             .inner
             .last()
@@ -216,12 +214,14 @@ impl LogArray {
             "There must be at least one element in log"
         )
     }
+}
 
-    fn build_first_entry(index: Index, term: Term) -> LogEntry {
+impl<C: Default> LogArray<C> {
+    fn build_first_entry(index: Index, term: Term) -> LogEntry<C> {
         LogEntry {
             index,
             term,
-            command: Command(0),
+            command: C::default(),
         }
     }
 }

+ 1 - 1
src/raft_state.rs

@@ -13,7 +13,7 @@ pub(crate) enum State {
 pub(crate) struct RaftState<Command> {
     pub current_term: Term,
     pub voted_for: Option<Peer>,
-    pub log: LogArray,
+    pub log: LogArray<Command>,
 
     pub commit_index: Index,
     pub last_applied: Index,

+ 5 - 3
src/rpcs.rs

@@ -10,7 +10,7 @@ use crate::{
 use serde::de::DeserializeOwned;
 use serde::Serialize;
 
-fn proxy_request_vote<Command: Clone + Serialize>(
+fn proxy_request_vote<Command: Clone + Serialize + Default>(
     raft: &Raft<Command>,
     data: RequestMessage,
 ) -> ReplyMessage {
@@ -25,7 +25,9 @@ fn proxy_request_vote<Command: Clone + Serialize>(
     )
 }
 
-fn proxy_append_entries<Command: Clone + Serialize + DeserializeOwned>(
+fn proxy_append_entries<
+    Command: Clone + Serialize + DeserializeOwned + Default,
+>(
     raft: &Raft<Command>,
     data: RequestMessage,
 ) -> ReplyMessage {
@@ -82,7 +84,7 @@ impl RpcClient {
 }
 
 pub fn register_server<
-    Command: 'static + Clone + Serialize + DeserializeOwned,
+    Command: 'static + Clone + Serialize + DeserializeOwned + Default,
     S: AsRef<str>,
 >(
     raft: Arc<Raft<Command>>,