Selaa lähdekoodia

Merge branch 'snapshot_test'

1. Redesigned the snapshot thread system to use explicit blocking for better
shutdown.
2. Added many integration tests related to snapshots.
Jing Yang 4 vuotta sitten
vanhempi
commit
dd77cb06b3

+ 7 - 0
Cargo.toml

@@ -26,3 +26,10 @@ tokio = { version = "1.0", features = ["rt-multi-thread", "time", "parking_lot"]
 anyhow = "1.0"
 futures = { version = "0.3.8", features = ["thread-pool"] }
 scopeguard = "1.1.0"
+kvraft = { path = "kvraft" }
+
+[workspace]
+members = [
+    "kvraft",
+    "linearizability",
+]

+ 19 - 0
kvraft/Cargo.toml

@@ -0,0 +1,19 @@
+[package]
+name = "kvraft"
+version = "0.1.0"
+edition = "2018"
+
+[dependencies]
+bincode = "1.3.1"
+bytes = "1.0"
+labrpc = { path = "../../labrpc" }
+parking_lot = "0.11.1"
+rand = "0.8"
+ruaft = { path = "../" }
+linearizability = { path = "../linearizability" }
+serde = "1.0.116"
+serde_derive = "1.0.116"
+tokio = { version = "1.0", features = ["rt-multi-thread", "time", "parking_lot"] }
+
+[dev-dependencies]
+scopeguard = "1.1.0"

+ 243 - 0
kvraft/src/client.rs

@@ -0,0 +1,243 @@
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Once;
+use std::time::Duration;
+
+use labrpc::{Client, RequestMessage};
+use serde::de::DeserializeOwned;
+use serde::Serialize;
+
+use crate::common::{
+    GetArgs, GetEnum, GetReply, KVRaftOptions, PutAppendArgs, PutAppendEnum,
+    PutAppendReply, UniqueIdSequence, GET, PUT_APPEND,
+};
+use crate::common::{KVError, ValidReply};
+
/// A blocking key-value store client ("clerk") that lazily commits a
/// sentinel request to claim its clerk ID before serving the first call.
pub struct Clerk {
    // Guards the one-time sentinel commit performed by `init_once()`.
    init: Once,
    // The real client implementation; all operations delegate to it.
    inner: ClerkInner,
}
+
impl Clerk {
    /// Creates a clerk that talks to the given set of KV servers. No RPC is
    /// sent until the first operation.
    pub fn new(servers: Vec<Client>) -> Self {
        Self {
            init: Once::new(),
            inner: ClerkInner::new(servers),
        }
    }

    /// Fetches the value stored at `key`, retrying forever until a definitive
    /// answer (present or absent) is obtained.
    pub fn get<K: AsRef<str>>(&mut self, key: K) -> Option<String> {
        let inner = self.init_once();

        let key = key.as_ref();
        loop {
            // `ClerkInner::get` returns None when the attempt must be retried
            // with a fresh unique_id; see its doc comment.
            if let Some(val) = inner.get(key.to_owned(), Default::default()) {
                return val;
            }
        }
    }

    /// Stores `value` at `key`, retrying until the write is committed.
    pub fn put<K: AsRef<str>, V: AsRef<str>>(&mut self, key: K, value: V) {
        let inner = self.init_once();

        let key = key.as_ref();
        let value = value.as_ref();
        inner
            .put(key.to_owned(), value.to_owned(), Default::default())
            .expect("Put should never return error with unlimited retry.")
    }

    /// Appends `value` to the value at `key`, retrying until committed.
    pub fn append<K: AsRef<str>, V: AsRef<str>>(&mut self, key: K, value: V) {
        let inner = self.init_once();

        let key = key.as_ref();
        let value = value.as_ref();
        inner
            .append(key.to_owned(), value.to_owned(), Default::default())
            .expect("Append should never return error with unlimited retry.")
    }

    /// Runs the one-time sentinel commit (registering this clerk's ID on the
    /// servers) and returns the inner client.
    pub fn init_once(&mut self) -> &mut ClerkInner {
        // Borrow the two fields separately so the closure may mutate `inner`
        // while `init` is only read; borrowing through `self` directly would
        // conflict.
        let (init, inner) = (&self.init, &mut self.inner);
        init.call_once(|| inner.commit_sentinel());
        &mut self.inner
    }
}
+
/// The stateful client implementation behind `Clerk`.
pub struct ClerkInner {
    servers: Vec<Client>,

    // Index of the last server that answered; tried first on the next RPC.
    // NOTE(review): every method takes `&mut self`, so an AtomicUsize is not
    // strictly required here — a plain usize would do. Confirm no shared use.
    last_server_index: AtomicUsize,
    // Generator of (clerk_id, sequence_id) pairs attached to every request.
    unique_id: UniqueIdSequence,

    // Single-threaded runtime used to block on async RPCs with a timeout.
    executor: tokio::runtime::Runtime,
}
+
impl ClerkInner {
    /// Creates the inner client with a fresh random clerk ID and a private
    /// current-thread tokio runtime for issuing RPCs.
    pub fn new(servers: Vec<Client>) -> Self {
        Self {
            servers,

            last_server_index: AtomicUsize::new(0),
            unique_id: UniqueIdSequence::new(),

            executor: tokio::runtime::Builder::new_current_thread()
                .enable_time()
                .build()
                .expect("Creating thread pool should not fail"),
        }
    }

    /// Commits a no-op Get with sequence_id 0 so the servers record this
    /// clerk's ID. Loops until the sentinel is committed; draws a new clerk
    /// ID if the random one collides with an ID already in use.
    fn commit_sentinel(&mut self) {
        loop {
            let args = GetArgs {
                key: "".to_string(),
                op: GetEnum::NoDuplicate,
                unique_id: self.unique_id.zero(),
            };
            // max_retry = Some(1) still tries every server at least once; see
            // the clamping inside `call_rpc`.
            let reply: Option<GetReply> = self.call_rpc(GET, args, Some(1));
            if let Some(reply) = reply {
                match reply.result {
                    Ok(_) => {
                        // Discard the used unique_id.
                        self.unique_id.inc();
                        break;
                    }
                    Err(KVError::Expired) | Err(KVError::Conflict) => {
                        // The client ID happens to be re-used. The request does
                        // not fail as "Duplicate", because another client has
                        // committed more than just the sentinel.
                        self.unique_id = UniqueIdSequence::new();
                    }
                    Err(_) => {}
                };
            };
        }
    }

    // Per-attempt RPC deadline.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(1);

    /// Sends `args` to the servers in round-robin order, starting from the
    /// last server known to answer, until a valid reply arrives or
    /// `max_retry` attempts (at least one per server; None = unlimited) are
    /// exhausted. Returns None when every attempt failed or was rejected.
    fn call_rpc<M, A, R>(
        &mut self,
        method: M,
        args: A,
        max_retry: Option<usize>,
    ) -> Option<R>
    where
        M: AsRef<str>,
        A: Serialize,
        R: DeserializeOwned + ValidReply,
    {
        let method = method.as_ref().to_owned();
        let data = RequestMessage::from(
            bincode::serialize(&args)
                .expect("Serialization of requests should not fail"),
        );

        // Guarantee at least one attempt per server; None means unlimited.
        let max_retry =
            std::cmp::max(max_retry.unwrap_or(usize::MAX), self.servers.len());

        let mut index = self.last_server_index.load(Ordering::Relaxed);
        for _ in 0..max_retry {
            let client = &self.servers[index];
            // Bound each RPC by DEFAULT_TIMEOUT on the local runtime.
            let rpc_response = self.executor.block_on(async {
                tokio::time::timeout(
                    Self::DEFAULT_TIMEOUT,
                    client.call_rpc(method.clone(), data.clone()),
                )
                .await
            });
            // Flatten the timeout error into the RPC error type.
            let reply = match rpc_response {
                Ok(reply) => reply,
                Err(e) => Err(e.into()),
            };
            if let Ok(reply) = reply {
                let ret: R = bincode::deserialize(reply.as_ref())
                    .expect("Deserialization of reply should not fail");
                // A reply is "valid" unless the server timed out or denied
                // being the leader; see the `ValidReply` trait.
                if ret.is_reply_valid() {
                    self.last_server_index.store(index, Ordering::Relaxed);
                    return Some(ret);
                }
            }
            // Move on to the next server, wrapping around.
            index += 1;
            index %= self.servers.len();
        }
        None
    }

    /// This function returns None when
    /// 1. No KVServer can be reached, or
    /// 2. No KVServer claimed to be the leader, or
    /// 3. A KVServer committed the request but the result was not passed
    /// back to the clerk. We must retry with a new unique_id.
    ///
    /// In all 3 cases the request can be retried.
    ///
    /// This function does not expect a Conflict reply for the same unique_id.
    pub fn get(
        &mut self,
        key: String,
        options: KVRaftOptions,
    ) -> Option<Option<String>> {
        let args = GetArgs {
            key,
            op: GetEnum::AllowDuplicate,
            unique_id: self.unique_id.inc(),
        };
        let reply: GetReply = self.call_rpc(GET, args, options.max_retry)?;
        match reply.result {
            Ok(val) => Some(val),
            Err(KVError::Conflict) => panic!("We should never see a conflict."),
            _ => None,
        }
    }

    /// This function returns None when
    /// 1. No KVServer can be reached, or
    /// 2. No KVServer claimed to be the leader.
    ///
    /// Some(()) is returned if the request has been committed previously,
    /// under the assumption that two different requests with the same
    /// unique_id must be identical.
    ///
    /// This function does not expect a Conflict reply for the same unique_id.
    fn put_append(
        &mut self,
        key: String,
        value: String,
        op: PutAppendEnum,
        options: KVRaftOptions,
    ) -> Option<()> {
        let args = PutAppendArgs {
            key,
            value,
            op,
            unique_id: self.unique_id.inc(),
        };
        let reply: PutAppendReply =
            self.call_rpc(PUT_APPEND, args, options.max_retry)?;
        match reply.result {
            Ok(val) => Some(val),
            // Expired means an identical earlier attempt already committed.
            Err(KVError::Expired) => Some(()),
            Err(KVError::Conflict) => panic!("We should never see a conflict."),
            _ => None,
        }
    }

    /// Stores `value` at `key`; see `put_append` for the return contract.
    pub fn put(
        &mut self,
        key: String,
        value: String,
        options: KVRaftOptions,
    ) -> Option<()> {
        self.put_append(key, value, PutAppendEnum::Put, options)
    }

    /// Appends `value` at `key`; see `put_append` for the return contract.
    pub fn append(
        &mut self,
        key: String,
        value: String,
        options: KVRaftOptions,
    ) -> Option<()> {
        self.put_append(key, value, PutAppendEnum::Append, options)
    }
}

+ 132 - 0
kvraft/src/common.rs

@@ -0,0 +1,132 @@
+use std::sync::atomic::{AtomicU64, Ordering};
+
+use rand::{thread_rng, RngCore};
+
/// Identifier of a single clerk (client), drawn at random on startup.
pub type ClerkId = u64;

/// Globally unique request ID: a clerk ID plus a per-clerk sequence number.
/// The derived `Ord` orders first by clerk, then by sequence, which is what
/// the server's de-duplication relies on.
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    Hash,
    Ord,
    PartialOrd,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
)]
pub struct UniqueId {
    pub clerk_id: ClerkId,
    pub sequence_id: u64,
}
+
/// Generator of monotonically increasing `UniqueId`s for one clerk.
#[derive(Debug)]
pub struct UniqueIdSequence {
    // Random identity of this clerk, fixed for the sequence's lifetime.
    clerk_id: u64,
    // Next sequence number to hand out.
    // NOTE(review): `inc` takes `&mut self`, so an AtomicU64 is not strictly
    // required — a plain u64 would do. Confirm no shared use elsewhere.
    sequence_id: AtomicU64,
}
+
+impl UniqueIdSequence {
+    pub fn new() -> Self {
+        Self {
+            clerk_id: thread_rng().next_u64(),
+            sequence_id: AtomicU64::new(0),
+        }
+    }
+
+    pub fn zero(&self) -> UniqueId {
+        UniqueId {
+            clerk_id: self.clerk_id,
+            sequence_id: 0,
+        }
+    }
+
+    pub fn inc(&mut self) -> UniqueId {
+        let seq = self.sequence_id.fetch_add(1, Ordering::Relaxed);
+        UniqueId {
+            clerk_id: self.clerk_id,
+            sequence_id: seq,
+        }
+    }
+}
+
// RPC method names as registered on the server side.
pub(crate) const GET: &str = "KVServer.Get";
pub(crate) const PUT_APPEND: &str = "KVServer.PutAppend";

/// Which mutation a PutAppend request performs.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum PutAppendEnum {
    Put,
    Append,
}

/// Wire arguments of a PutAppend RPC.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PutAppendArgs {
    pub key: String,
    pub value: String,
    pub op: PutAppendEnum,

    // Identifies the request for server-side de-duplication.
    pub unique_id: UniqueId,
}

/// Wire reply of a PutAppend RPC.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PutAppendReply {
    pub result: Result<(), KVError>,
}

/// Whether a Get tolerates being answered by an earlier duplicate commit.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum GetEnum {
    AllowDuplicate,
    NoDuplicate,
}

/// Wire arguments of a Get RPC.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GetArgs {
    pub key: String,
    pub op: GetEnum,

    pub unique_id: UniqueId,
}

/// Wire reply of a Get RPC; `Ok(None)` means the key is absent.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GetReply {
    pub result: Result<Option<String>, KVError>,
}

/// Per-request options; a `max_retry` of None means retry without limit.
#[derive(Clone, Debug, Default)]
pub struct KVRaftOptions {
    pub max_retry: Option<usize>,
}

/// Errors a server may return over the wire.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum KVError {
    NotLeader,
    // The clerk already committed a request with a newer unique_id.
    Expired,
    TimedOut,
    // The committed result does not match the request kind / duplicate mode.
    Conflict,
}
+
/// Classifies replies: an "invalid" reply means the clerk should retry the
/// RPC against another server instead of accepting the answer.
pub trait ValidReply {
    fn is_reply_valid(&self) -> bool;
}
+
+impl<T> ValidReply for Result<T, KVError> {
+    fn is_reply_valid(&self) -> bool {
+        !matches!(
+            self.as_ref().err(),
+            Some(KVError::TimedOut) | Some(KVError::NotLeader)
+        )
+    }
+}
+
// Both reply types delegate validity to their inner Result.
impl ValidReply for PutAppendReply {
    fn is_reply_valid(&self) -> bool {
        self.result.is_reply_valid()
    }
}

impl ValidReply for GetReply {
    fn is_reply_valid(&self) -> bool {
        self.result.is_reply_valid()
    }
}

+ 18 - 0
kvraft/src/lib.rs

@@ -0,0 +1,18 @@
+extern crate labrpc;
+extern crate parking_lot;
+extern crate rand;
+extern crate ruaft;
+extern crate serde;
+#[macro_use]
+extern crate serde_derive;
+extern crate tokio;
+
+pub use client::Clerk;
+pub use server::KVServer;
+
+mod client;
+mod common;
+mod server;
+
+mod snapshot_holder;
+pub mod testing_utils;

+ 439 - 0
kvraft/src/server.rs

@@ -0,0 +1,439 @@
+use std::collections::hash_map::Entry;
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::mpsc::{channel, Receiver};
+use std::sync::Arc;
+use std::time::Duration;
+
+use parking_lot::{Condvar, Mutex};
+
+use ruaft::{ApplyCommandMessage, Persister, Raft, RpcClient, Term};
+
+use crate::common::{
+    ClerkId, GetArgs, GetEnum, GetReply, KVError, PutAppendArgs, PutAppendEnum,
+    PutAppendReply, UniqueId,
+};
+use crate::snapshot_holder::SnapshotHolder;
+
/// A replicated key-value server backed by one Raft peer.
pub struct KVServer {
    // This server's index among its peers.
    me: AtomicUsize,
    // The replicated state machine; written by the apply thread.
    state: Mutex<KVServerState>,
    // The underlying Raft instance driving consensus.
    rf: Mutex<Raft<UniqueKVOp>>,
    // Cleared by `kill()` so new queries fail fast.
    keep_running: AtomicBool,
}
+
/// A KV operation as stored in the Raft log, tagged with the proposing
/// server and the request's globally unique ID.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct UniqueKVOp {
    op: KVOp,
    // Index of the server that proposed this log entry.
    me: usize,
    unique_id: UniqueId,
}
+
/// The state machine content plus bookkeeping; everything except `queries`
/// is serialized into snapshots.
#[derive(Default, Serialize, Deserialize)]
struct KVServerState {
    kv: HashMap<String, String>,
    debug_kv: HashMap<String, String>,
    // Latest applied (unique_id, result) per clerk, for de-duplication.
    applied_op: HashMap<ClerkId, (UniqueId, CommitResult)>,
    // In-flight queries waiting for commit; intentionally not snapshotted.
    #[serde(skip)]
    queries: HashMap<UniqueId, Arc<ResultHolder>>,
}
+
/// The operations that may appear in the Raft log.
#[derive(Clone, Serialize, Deserialize)]
enum KVOp {
    NoOp,
    Get(String),
    Put(String, String),
    Append(String, String),
}

// NoOp is the natural default; it is skipped when applied.
impl Default for KVOp {
    fn default() -> Self {
        KVOp::NoOp
    }
}
+
/// Per-query rendezvous between the RPC handler and the apply thread.
struct ResultHolder {
    // Raft term in which `start()` was attempted for this query; doubles as
    // a tiny state machine (UNSEEN_TERM / ATTEMPTING_TERM sentinels).
    term: AtomicUsize,
    // Defaults to TimedOut; overwritten once the op is applied or cancelled.
    result: Mutex<Result<CommitResult, CommitError>>,
    condvar: Condvar,
}

/// The outcome of a successfully applied operation.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum CommitResult {
    Get(Option<String>),
    Put,
    Append,
}
+
/// Internal commit failures. `NotMe` and `Duplicate` carry the result that
/// was actually committed; the RPC handlers translate them before any
/// conversion to `KVError`.
#[derive(Clone, Debug)]
enum CommitError {
    NotLeader,
    // The request is older than the clerk's latest applied request.
    Expired(UniqueId),
    TimedOut,
    #[allow(dead_code)]
    Conflict,
    // Committed, but proposed by a different server than the one asked.
    NotMe(CommitResult),
    // Already committed earlier by this very server.
    Duplicate(CommitResult),
}
+
impl From<CommitError> for KVError {
    /// Maps internal errors to the wire-level error type. `NotMe` and
    /// `Duplicate` carry results and must be consumed before this
    /// conversion; converting them is a server-side bug.
    fn from(err: CommitError) -> Self {
        match err {
            CommitError::NotLeader => KVError::NotLeader,
            CommitError::Expired(_) => KVError::Expired,
            CommitError::TimedOut => KVError::TimedOut,
            CommitError::Conflict => KVError::Conflict,
            CommitError::NotMe(_) => panic!("NotMe is not a KVError"),
            CommitError::Duplicate(_) => panic!("Duplicate is not a KVError"),
        }
    }
}
+
impl KVServer {
    /// Builds a KV server on top of a new Raft instance and spawns the apply
    /// thread that consumes committed log entries and snapshots.
    pub fn new(
        servers: Vec<RpcClient>,
        me: usize,
        persister: Arc<dyn Persister>,
        max_state_size_bytes: Option<usize>,
    ) -> Arc<Self> {
        let (tx, rx) = channel();
        // Raft pushes committed entries into this channel; the apply thread
        // (spawned by `process_command`) drains it.
        let apply_command = move |message| {
            tx.send(message)
                .expect("The receiving end of apply command channel should have not been dropped");
        };
        let snapshot_holder = Arc::new(SnapshotHolder::default());
        let snapshot_holder_clone = snapshot_holder.clone();
        let ret = Arc::new(Self {
            me: AtomicUsize::new(me),
            state: Default::default(),
            rf: Mutex::new(Raft::new(
                servers,
                me,
                persister,
                apply_command,
                max_state_size_bytes,
                // Raft asks for a snapshot covering at least `index`.
                move |index| snapshot_holder_clone.request_snapshot(index),
            )),
            keep_running: AtomicBool::new(true),
        });
        ret.process_command(snapshot_holder, rx);
        ret
    }

    /// Applies one committed `op` to the state machine, de-duplicating per
    /// clerk and waking the pending query (if any) for `unique_id`.
    /// `leader` is the server that proposed the entry.
    fn apply_op(&self, unique_id: UniqueId, leader: usize, op: KVOp) {
        // The borrow checker does not allow borrowing two fields of an instance
        // inside a MutexGuard. But it does allow borrowing two fields of the
        // instance itself. Calling deref_mut() on the MutexGuard works, too!
        let state = &mut *self.state.lock();
        let (applied_op, kv) = (&mut state.applied_op, &mut state.kv);
        let entry = applied_op.entry(unique_id.clerk_id);
        if let Entry::Occupied(curr) = &entry {
            let (applied_unique_id, _) = curr.get();
            if *applied_unique_id >= unique_id {
                // Redelivered.
                // It is guaranteed that we have no pending queries with the
                // same unique_id, because
                // 1. When inserting into queries, we first check the unique_id
                // is strictly larger than the one in applied_op.
                // 2. When modifying entries in applied_op, the unique_id can
                // only grow larger. And we make sure there is no entries with
                // the same unique_id in queries.
                // TODO(ditsing): in case 2), make sure there is no entries in
                // queries that have a smaller unique_id.
                assert!(!state.queries.contains_key(&unique_id));
                return;
            }
        }

        // Apply the op to the in-memory map; NoOps are skipped entirely.
        let result = match op {
            KVOp::NoOp => return,
            KVOp::Get(key) => CommitResult::Get(kv.get(&key).cloned()),
            KVOp::Put(key, value) => {
                kv.insert(key, value);
                CommitResult::Put
            }
            KVOp::Append(key, value) => {
                kv.entry(key)
                    .and_modify(|str| str.push_str(&value))
                    .or_insert(value);
                CommitResult::Append
            }
        };

        // Record the newest applied unique_id and result for this clerk.
        match entry {
            Entry::Occupied(mut curr) => {
                curr.insert((unique_id, result.clone()));
            }
            Entry::Vacant(vacant) => {
                vacant.insert((unique_id, result.clone()));
            }
        }

        if let Some(result_holder) = state.queries.remove(&unique_id) {
            // This KV server might not be the same leader that committed the
            // query. We are not sure if it is a duplicate or a conflict. To
            // tell the difference, terms of all queries must be stored.
            *result_holder.result.lock() = if leader == self.me() {
                Ok(result)
            } else {
                Err(CommitError::NotMe(result))
            };
            result_holder.condvar.notify_all();
        };
    }

    /// Replaces the whole state machine with one restored from a snapshot,
    /// failing every query that was in flight on this server.
    fn restore_state(&self, mut new_state: KVServerState) {
        let mut state = self.state.lock();
        // Cleanup all existing queries.
        for result_holder in state.queries.values() {
            *result_holder.result.lock() = Err(CommitError::NotLeader);
            result_holder.condvar.notify_all();
        }

        std::mem::swap(&mut new_state, &mut *state);
    }

    /// Spawns the apply thread: drains Raft's message channel, applies
    /// commands, restores snapshots, and offers fresh snapshots back to
    /// Raft. Holds only a Weak reference, so the thread cannot keep the
    /// server alive after `kill()`.
    fn process_command(
        self: &Arc<Self>,
        snapshot_holder: Arc<SnapshotHolder<KVServerState>>,
        command_channel: Receiver<ApplyCommandMessage<UniqueKVOp>>,
    ) {
        let this = Arc::downgrade(self);
        std::thread::spawn(move || {
            // recv() fails once Raft drops the sender, ending this thread.
            while let Ok(message) = command_channel.recv() {
                if let Some(this) = this.upgrade() {
                    match message {
                        ApplyCommandMessage::Snapshot(snapshot) => {
                            let state = snapshot_holder.load_snapshot(snapshot);
                            this.restore_state(state);
                        }
                        ApplyCommandMessage::Command(index, command) => {
                            this.apply_op(
                                command.unique_id,
                                command.me,
                                command.op,
                            );
                            // Serve any pending snapshot requests now that
                            // entries up to `index` are applied.
                            if let Some(snapshot) = snapshot_holder
                                .take_snapshot(&this.state.lock(), index)
                            {
                                this.rf.lock().save_snapshot(snapshot);
                            }
                        }
                    }
                } else {
                    break;
                }
            }
        });
    }

    // Sentinel values stored in `ResultHolder::term`; real Raft terms are
    // strictly between the two (checked by `validate_term`).
    const UNSEEN_TERM: usize = 0;
    const ATTEMPTING_TERM: usize = usize::MAX;
    /// Proposes `op` through Raft and blocks until it commits, leadership is
    /// lost, the request turns out to be a duplicate / expired, or `timeout`
    /// elapses. Multiple callers with the same `unique_id` share one
    /// `ResultHolder`; only one of them calls `Raft::start` per term.
    fn block_for_commit(
        &self,
        unique_id: UniqueId,
        op: KVOp,
        timeout: Duration,
    ) -> Result<CommitResult, CommitError> {
        if !self.keep_running.load(Ordering::SeqCst) {
            return Err(CommitError::NotLeader);
        }
        let result_holder = {
            let mut state = self.state.lock();
            // Reject requests at or below the clerk's latest applied ID.
            let applied = state.applied_op.get(&unique_id.clerk_id);
            if let Some((applied_unique_id, result)) = applied {
                #[allow(clippy::comparison_chain)]
                if unique_id < *applied_unique_id {
                    return Err(CommitError::Expired(unique_id));
                } else if unique_id == *applied_unique_id {
                    return Err(CommitError::Duplicate(result.clone()));
                }
            };
            // Join an existing query with the same unique_id, or create one.
            let entry = state.queries.entry(unique_id).or_insert_with(|| {
                Arc::new(ResultHolder {
                    term: AtomicUsize::new(Self::UNSEEN_TERM),
                    result: Mutex::new(Err(CommitError::TimedOut)),
                    condvar: Condvar::new(),
                })
            });
            entry.clone()
        };

        let (Term(hold_term), is_leader) = self.rf.lock().get_state();
        if !is_leader {
            result_holder.condvar.notify_all();
            return Err(CommitError::NotLeader);
        }
        Self::validate_term(hold_term);

        // Claim the right to call start() by moving `term` from UNSEEN (or a
        // stale term) to ATTEMPTING; losers simply wait below.
        let set = result_holder.term.compare_exchange(
            Self::UNSEEN_TERM,
            Self::ATTEMPTING_TERM,
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        let start = match set {
            // Nobody has attempted start() yet.
            Ok(Self::UNSEEN_TERM) => true,
            Ok(_) => panic!(
                "compare_exchange should always return the current value 0"
            ),
            // Somebody is attempting start().
            Err(Self::ATTEMPTING_TERM) => false,
            // Somebody has attempted start().
            Err(prev_term) if prev_term < hold_term => {
                let set = result_holder.term.compare_exchange(
                    prev_term,
                    Self::ATTEMPTING_TERM,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                );
                set.is_ok()
            }
            _ => false,
        };
        if start {
            let op = UniqueKVOp {
                op,
                me: self.me(),
                unique_id,
            };
            let start = self.rf.lock().start(op);
            // None from start() means we are no longer the leader; record
            // UNSEEN so a later attempt in a newer term may retry.
            let start_term =
                start.map_or(Self::UNSEEN_TERM, |(Term(term), _)| {
                    Self::validate_term(term);
                    term
                });
            let set = result_holder.term.compare_exchange(
                Self::ATTEMPTING_TERM,
                start_term,
                Ordering::SeqCst,
                Ordering::SeqCst,
            );
            // Setting term must have been successful, and must return the
            // value previously set by this attempt.
            assert_eq!(set, Ok(Self::ATTEMPTING_TERM));

            if start_term == Self::UNSEEN_TERM {
                result_holder.condvar.notify_all();
                return Err(CommitError::NotLeader);
            }
        }

        let mut guard = result_holder.result.lock();
        // Wait for the op to be committed.
        // NOTE(review): there is no predicate loop around this wait; a
        // spurious wakeup would surface the default TimedOut early — confirm
        // against parking_lot's Condvar guarantees.
        result_holder.condvar.wait_for(&mut guard, timeout);

        // Copy the result out.
        let result = guard.clone();
        // If the result is OK, all other requests should see "Duplicate".
        if let Ok(result) = guard.clone() {
            *guard = Err(CommitError::Duplicate(result))
        }

        result
    }

    /// Asserts a real Raft term never collides with the two sentinels.
    fn validate_term(term: usize) {
        assert!(term > Self::UNSEEN_TERM, "Term must be larger than 0.");
        assert!(
            term < Self::ATTEMPTING_TERM,
            "Term must be smaller than usize::MAX."
        );
    }

    // How long an RPC handler waits for its op to commit.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(1);

    /// RPC handler for Get. Depending on `args.op`, a duplicate commit is
    /// either surfaced as the answer or rejected as a Conflict.
    pub fn get(&self, args: GetArgs) -> GetReply {
        let map_dup = match args.op {
            GetEnum::AllowDuplicate => |r| Ok(r),
            GetEnum::NoDuplicate => |_| Err(KVError::Conflict),
        };
        let result = match self.block_for_commit(
            args.unique_id,
            KVOp::Get(args.key),
            Self::DEFAULT_TIMEOUT,
        ) {
            Ok(result) => Ok(result),
            Err(CommitError::Duplicate(result)) => map_dup(result),
            Err(CommitError::NotMe(result)) => map_dup(result),
            Err(e) => Err(e.into()),
        };
        let result = match result {
            Ok(result) => result,
            Err(e) => return GetReply { result: Err(e) },
        };
        // A non-Get result under this unique_id means the clerk reused an ID
        // for a different op — report a Conflict.
        let result = match result {
            CommitResult::Get(result) => Ok(result),
            CommitResult::Put => Err(KVError::Conflict),
            CommitResult::Append => Err(KVError::Conflict),
        };
        GetReply { result }
    }

    /// RPC handler for Put/Append; duplicates of the identical request are
    /// treated as success, mismatched op kinds as Conflict.
    pub fn put_append(&self, args: PutAppendArgs) -> PutAppendReply {
        let op = match args.op {
            PutAppendEnum::Put => KVOp::Put(args.key, args.value),
            PutAppendEnum::Append => KVOp::Append(args.key, args.value),
        };
        let result = match self.block_for_commit(
            args.unique_id,
            op,
            Self::DEFAULT_TIMEOUT,
        ) {
            Ok(result) => result,
            Err(CommitError::Duplicate(result)) => result,
            Err(CommitError::NotMe(result)) => result,
            Err(e) => {
                return PutAppendReply {
                    result: Err(e.into()),
                }
            }
        };
        let result = match result {
            CommitResult::Put => {
                if args.op == PutAppendEnum::Put {
                    Ok(())
                } else {
                    Err(KVError::Conflict)
                }
            }
            CommitResult::Append => {
                if args.op == PutAppendEnum::Append {
                    Ok(())
                } else {
                    Err(KVError::Conflict)
                }
            }
            CommitResult::Get(_) => Err(KVError::Conflict),
        };

        PutAppendReply { result }
    }

    /// This server's index among its peers.
    pub fn me(&self) -> usize {
        self.me.load(Ordering::Relaxed)
    }

    /// Returns a clone of the underlying Raft handle.
    pub fn raft(&self) -> Raft<UniqueKVOp> {
        self.rf.lock().clone()
    }

    /// Shuts the server down: fails new and in-flight queries, then kills
    /// Raft, which in turn terminates the apply thread.
    pub fn kill(self: Arc<Self>) {
        // Return error to new queries.
        self.keep_running.store(false, Ordering::SeqCst);
        // Cancel all in-flight queries.
        for result_holder in self.state.lock().queries.values() {
            *result_holder.result.lock() = Err(CommitError::NotLeader);
            result_holder.condvar.notify_all();
        }

        let rf = self.raft();
        // We must drop self to remove the only clone of raft, so that
        // `rf.kill()` does not block.
        drop(self);
        rf.kill();
        // The process_command thread will exit, after Raft drops the reference
        // to the sender.
    }
}

+ 55 - 0
kvraft/src/snapshot_holder.rs

@@ -0,0 +1,55 @@
+use std::marker::PhantomData;
+
+use parking_lot::Mutex;
+use serde::de::DeserializeOwned;
+use serde::Serialize;
+
+use ruaft::Snapshot;
+
+#[derive(Default)]
+pub(crate) struct SnapshotHolder<T> {
+    snapshot_requests: Mutex<Vec<usize>>,
+    phantom: PhantomData<T>,
+}
+
+impl<T> SnapshotHolder<T> {
+    pub fn request_snapshot(&self, min_index: usize) {
+        let mut requests = self.snapshot_requests.lock();
+        let pos = requests.binary_search(&min_index);
+        if let Err(pos) = pos {
+            requests.insert(pos, min_index);
+        }
+    }
+}
+
+impl<T: Serialize> SnapshotHolder<T> {
+    pub fn take_snapshot(&self, state: &T, curr: usize) -> Option<Snapshot> {
+        let mut requests = self.snapshot_requests.lock();
+
+        let processed = requests.partition_point(|index| *index <= curr);
+        if processed == 0 {
+            return None;
+        }
+
+        requests.drain(0..processed);
+        drop(requests);
+
+        let data = bincode::serialize(state)
+            .expect("Serialization should never fail.");
+        Some(Snapshot {
+            data,
+            last_included_index: curr,
+        })
+    }
+}
+
+impl<T: DeserializeOwned> SnapshotHolder<T> {
+    pub fn load_snapshot(&self, snapshot: Snapshot) -> T {
+        let state = bincode::deserialize(&snapshot.data).expect(&*format!(
+            "Deserialization should never fail, {:?}",
+            &snapshot.data
+        ));
+
+        state
+    }
+}

+ 330 - 0
kvraft/src/testing_utils/config.rs

@@ -0,0 +1,330 @@
+use std::sync::Arc;
+
+use labrpc::Network;
+use parking_lot::Mutex;
+use rand::seq::SliceRandom;
+use rand::thread_rng;
+
+use ruaft::rpcs::register_server;
+use ruaft::{Persister, RpcClient};
+
+use crate::client::Clerk;
+use crate::server::KVServer;
+use crate::testing_utils::memory_persister::{MemoryPersister, MemoryStorage};
+use crate::testing_utils::rpcs::register_kv_server;
+
/// Mutable test-cluster state guarded by `Config::state`.
struct ConfigState {
    // One slot per server; `None` while that server is crashed.
    kv_servers: Vec<Option<Arc<KVServer>>>,
    // Number of clerks handed out so far; clerk ids start at 1
    // (see `make_limited_clerk`, which increments before use).
    next_clerk: usize,
}
+
/// Test harness owning the simulated network, the KV servers and their
/// in-memory storage.
pub struct Config {
    network: Arc<Mutex<labrpc::Network>>,
    server_count: usize,
    state: Mutex<ConfigState>,
    storage: Mutex<MemoryStorage>,
    // Snapshot threshold forwarded to each `KVServer` in `start_server`.
    maxraftstate: usize,
}
+
impl Config {
    /// RPC client name for clerk `i` talking to KV server `server`.
    fn kv_clerk_name(i: usize, server: usize) -> String {
        format!("kvraft-clerk-client-{}-to-{}", i, server)
    }

    /// Network name of KV (service-level) server `i`.
    fn kv_server_name(i: usize) -> String {
        format!("kv-server-{}", i)
    }

    /// Network name of Raft server `i`.
    fn server_name(i: usize) -> String {
        format!("kvraft-server-{}", i)
    }

    /// RPC client name for Raft peer `client` talking to Raft server `server`.
    fn client_name(client: usize, server: usize) -> String {
        format!("kvraft-client-{}-to-{}", client, server)
    }

    /// Boots server `index`: wires up Raft peer clients, reuses the persister
    /// currently stored at `index`, and registers both the Raft and the KV
    /// RPC services on the network.
    fn start_server(&self, index: usize) -> std::io::Result<()> {
        let mut clients = vec![];
        {
            let mut network = self.network.lock();
            for j in 0..self.server_count {
                clients.push(RpcClient::new(network.make_client(
                    Self::client_name(index, j),
                    Self::server_name(j),
                )))
            }
        }

        // Reuse the existing persister so a restarted server sees its
        // pre-crash state.
        let persister = self.storage.lock().at(index);

        let kv =
            KVServer::new(clients, index, persister, Some(self.maxraftstate));
        self.state.lock().kv_servers[index].replace(kv.clone());

        let raft = std::rc::Rc::new(kv.raft());

        register_server(raft, Self::server_name(index), self.network.as_ref())?;

        register_kv_server(
            kv,
            Self::kv_server_name(index),
            self.network.as_ref(),
        )?;
        Ok(())
    }

    /// Announces the start of a test case.
    pub fn begin<S: std::fmt::Display>(&self, msg: S) {
        eprintln!("{}", msg);
    }

    /// Returns all server indexes in random order.
    fn shuffled_indexes(&self) -> Vec<usize> {
        let mut indexes: Vec<usize> = (0..self.server_count).collect();
        indexes.shuffle(&mut thread_rng());
        indexes
    }

    /// Randomly splits the cluster in two and cuts traffic between the
    /// halves. The current leader (if any) is swapped to position 0 and
    /// therefore lands in the *second* returned group; the first returned
    /// group is the larger (or equal) half.
    pub fn partition(&self) -> (Vec<usize>, Vec<usize>) {
        let state = self.state.lock();
        let mut indexes = self.shuffled_indexes();

        // Swap leader to position 0.
        let leader_position = indexes
            .iter()
            .position(|index| {
                state.kv_servers[*index]
                    .as_ref()
                    .map_or(false, |kv| kv.raft().get_state().1)
            })
            .unwrap_or(0);
        indexes.swap(0, leader_position);

        // `split_off` keeps the head (with the leader) in `part_two`.
        let part_one = indexes.split_off(indexes.len() / 2);
        let part_two = indexes;
        self.network_partition(&part_one, &part_two);

        (part_one, part_two)
    }

    /// Like `partition`, but with no special placement of the leader.
    pub fn random_partition(&self) -> (Vec<usize>, Vec<usize>) {
        let mut indexes = self.shuffled_indexes();
        let part_one = indexes.split_off(indexes.len() / 2);
        let part_two = indexes;
        self.network_partition(&part_one, &part_two);

        (part_one, part_two)
    }

    /// Enables (`yes = true`) or disables the Raft RPC links from every
    /// server in `from` to every server in `to`.
    fn set_connect(
        network: &mut Network,
        from: &[usize],
        to: &[usize],
        yes: bool,
    ) {
        for i in from {
            for j in to {
                network.set_enable_client(Self::client_name(*i, *j), yes)
            }
        }
    }

    /// Cuts all traffic between the two groups while (re-)enabling traffic
    /// within each group.
    pub fn network_partition(&self, part_one: &[usize], part_two: &[usize]) {
        let mut network = self.network.lock();
        Self::set_connect(&mut network, part_one, part_two, false);
        Self::set_connect(&mut network, part_two, part_one, false);
        Self::set_connect(&mut network, part_one, part_one, true);
        Self::set_connect(&mut network, part_two, part_two, true);
    }

    /// Re-enables every server-to-server link (heals all partitions).
    pub fn connect_all(&self) {
        let all: Vec<usize> = (0..self.state.lock().kv_servers.len()).collect();
        let mut network = self.network.lock();
        Self::set_connect(&mut network, &all, &all, true);
    }

    /// Simulates a crash of server `index`: disconnects it, removes its RPC
    /// services, swaps in a fresh persister preloaded with the old persisted
    /// data, then kills the server instance.
    fn crash_server(&self, index: usize) {
        {
            let all: Vec<usize> = (0..self.server_count).collect();

            let mut network = self.network.lock();
            Self::set_connect(&mut network, &all, &[index], false);
            Self::set_connect(&mut network, &[index], &all, false);

            network.remove_server(Self::server_name(index));
            network.remove_server(Self::kv_server_name(index));
        }

        // Copy out the persisted data before the still-running server can
        // mutate it further, so the replacement persister holds a consistent
        // pre-crash image for the next restart.
        let data = self.storage.lock().at(index).read();

        let persister = self.storage.lock().replace(index);
        persister.restore(data);

        if let Some(kv_server) = self.state.lock().kv_servers[index].take() {
            kv_server.kill();
        }
    }

    /// Crashes every server in the cluster.
    pub fn crash_all(&self) {
        for i in 0..self.server_count {
            self.crash_server(i);
        }
    }

    /// Restarts every server from its persisted state.
    pub fn restart_all(&self) {
        for index in 0..self.server_count {
            self.start_server(index)
                .expect("Start server should never fail");
        }
    }

    /// Enables or disables clerk `clerk_index`'s links to the servers in `to`.
    fn set_clerk_connect(
        network: &mut Network,
        clerk_index: usize,
        to: &[usize],
        yes: bool,
    ) {
        for j in to {
            network.set_enable_client(Self::kv_clerk_name(clerk_index, *j), yes)
        }
    }

    /// Creates a clerk that can only reach the servers listed in `to`.
    pub fn make_limited_clerk(&self, to: &[usize]) -> Clerk {
        let mut clients = vec![];
        let clerk_index = {
            let mut state = self.state.lock();
            state.next_clerk += 1;
            state.next_clerk
        };

        {
            let mut network = self.network.lock();
            for j in 0..self.server_count {
                clients.push(network.make_client(
                    Self::kv_clerk_name(clerk_index, j),
                    Self::kv_server_name(j),
                ));
            }
            // Disable clerk connection to all kv servers.
            Self::set_clerk_connect(
                &mut network,
                clerk_index,
                &(0..self.server_count).collect::<Vec<usize>>(),
                false,
            );
            // Enable clerk connection to some servers.
            Self::set_clerk_connect(&mut network, clerk_index, to, true);
        }

        // Shuffle so clerks do not all try the same server first.
        clients.shuffle(&mut thread_rng());
        Clerk::new(clients)
    }

    /// Creates a clerk that can reach every server.
    pub fn make_clerk(&self) -> Clerk {
        self.make_limited_clerk(&(0..self.server_count).collect::<Vec<usize>>())
    }

    /// Re-enables every clerk's links to every server. Clerk ids start at 1,
    /// hence the `clerk_index + 1`.
    pub fn connect_all_clerks(&self) {
        let mut network = self.network.lock();
        let all = &(0..self.server_count).collect::<Vec<usize>>();
        for clerk_index in 0..self.state.lock().next_clerk {
            Self::set_clerk_connect(&mut network, clerk_index + 1, all, true);
        }
    }

    /// Marks the end of a test case; currently a no-op.
    pub fn end(&self) {}

    /// Tears the cluster down: removes all RPC services, stops the network
    /// daemon, then kills every still-running server.
    pub fn clean_up(&self) {
        let mut network = self.network.lock();
        for i in 0..self.server_count {
            network.remove_server(Self::server_name(i));
            network.remove_server(Self::kv_server_name(i));
        }
        network.stop();
        drop(network);

        for kv_server in &mut self.state.lock().kv_servers {
            if let Some(kv_server) = kv_server.take() {
                kv_server.kill();
            }
        }
    }
}
+
+impl Config {
+    fn check_size(
+        &self,
+        upper: usize,
+        size_fn: impl Fn(&MemoryPersister) -> usize,
+    ) -> Result<(), String> {
+        let mut over_limits = String::new();
+        for (index, p) in self.storage.lock().all().iter().enumerate() {
+            let size = size_fn(p);
+            if size > upper {
+                let str = format!(" (index {}, size {})", index, size);
+                over_limits.push_str(&str);
+            }
+        }
+        if !over_limits.is_empty() {
+            return Err(format!(
+                "logs were not trimmed to {}:{}",
+                upper, over_limits
+            ));
+        }
+        Ok(())
+    }
+
+    pub fn check_log_size(&self, upper: usize) -> Result<(), String> {
+        self.check_size(upper, MemoryPersister::state_size)
+    }
+
+    pub fn check_snapshot_size(&self, upper: usize) -> Result<(), String> {
+        self.check_size(upper, MemoryPersister::snapshot_size)
+    }
+}
+
+pub fn make_config(
+    server_count: usize,
+    unreliable: bool,
+    maxraftstate: usize,
+) -> Config {
+    let network = labrpc::Network::run_daemon();
+    {
+        let mut unlocked_network = network.lock();
+        unlocked_network.set_reliable(!unreliable);
+        unlocked_network.set_long_delays(true);
+    }
+
+    let state = Mutex::new(ConfigState {
+        kv_servers: vec![None; server_count],
+        next_clerk: 0,
+    });
+
+    let mut storage = MemoryStorage::default();
+    for _ in 0..server_count {
+        storage.make();
+    }
+    let storage = Mutex::new(storage);
+
+    let cfg = Config {
+        network,
+        server_count,
+        state,
+        storage,
+        maxraftstate,
+    };
+
+    for i in 0..server_count {
+        cfg.start_server(i)
+            .expect("Starting server should not fail");
+    }
+
+    cfg
+}
+
/// Blocks the current thread for `mills` milliseconds.
pub fn sleep_millis(mills: u64) {
    let duration = std::time::Duration::from_millis(mills);
    std::thread::sleep(duration)
}

/// A generous election timeout used by tests when waiting for elections.
pub const LONG_ELECTION_TIMEOUT_MILLIS: u64 = 1000;

/// Blocks the current thread for `count` election timeouts.
pub fn sleep_election_timeouts(count: u64) {
    sleep_millis(count * LONG_ELECTION_TIMEOUT_MILLIS)
}

+ 300 - 0
kvraft/src/testing_utils/generic_test.rs

@@ -0,0 +1,300 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use std::thread::JoinHandle;
+use std::time::{Duration, Instant};
+
+use parking_lot::Mutex;
+use rand::{thread_rng, Rng};
+
+use linearizability::{KvInput, KvModel, KvOp, KvOutput, Operation};
+
+use crate::testing_utils::config::{
+    make_config, sleep_election_timeouts, sleep_millis, Config,
+    LONG_ELECTION_TIMEOUT_MILLIS,
+};
+use crate::Clerk;
+
+pub fn spawn_clients<T, Func>(
+    config: Arc<Config>,
+    clients: usize,
+    func: Func,
+) -> Vec<JoinHandle<T>>
+where
+    T: 'static + Send,
+    Func: 'static + Clone + Send + Sync + Fn(usize, Clerk) -> T,
+{
+    let mut client_threads = vec![];
+    for i in 0..clients {
+        let clerk = config.make_clerk();
+        let func = func.clone();
+        client_threads.push(std::thread::spawn(move || func(i, clerk)))
+    }
+    eprintln!("spawning clients done.");
+    client_threads
+}
+
+fn appending_client(
+    index: usize,
+    mut clerk: Clerk,
+    stop: Arc<AtomicBool>,
+) -> (usize, String) {
+    eprintln!("client {} running.", index);
+    let mut op_count = 0usize;
+    let key = index.to_string();
+    let mut last = String::new();
+    let mut rng = thread_rng();
+
+    clerk.put(&key, &last);
+
+    while !stop.load(Ordering::Acquire) {
+        eprintln!("client {} starting {}.", index, op_count);
+        if rng.gen_ratio(1, 2) {
+            let value = format!("({}, {}), ", index, op_count);
+
+            last.push_str(&value);
+            clerk.append(&key, &value);
+
+            op_count += 1;
+        } else {
+            let value = clerk
+                .get(&key)
+                .unwrap_or_else(|| panic!("Key {} should exist.", index));
+            assert_eq!(value, last);
+        }
+        eprintln!("client {} done {}.", index, op_count);
+    }
+    eprintln!("client {} done.", index);
+    (op_count, last)
+}
+
/// A test client for the linearizability check: issues a random mix of
/// operations on keys drawn from `0..client_count` (so clients deliberately
/// collide), and records each operation with call/return timestamps into
/// `ops`.
///
/// Branch probabilities: Append 50%, Put 5% (10% of the remaining half),
/// Get 45%.
///
/// Returns the operation count; the `String` half of the return value is
/// always empty and only exists to match `appending_client`'s shape.
fn linearizability_client(
    index: usize,
    client_count: usize,
    mut clerk: Clerk,
    stop: Arc<AtomicBool>,
    ops: Arc<Mutex<Vec<Operation<KvInput, KvOutput>>>>,
) -> (usize, String) {
    let mut op_count = 0usize;
    while !stop.load(Ordering::Acquire) {
        let key = thread_rng().gen_range(0..client_count).to_string();
        let value = format!("({}, {}), ", index, op_count);
        // Timestamps taken immediately around the blocking clerk call.
        let call_time = Instant::now();
        let call_op;
        let return_op;
        if thread_rng().gen_ratio(500, 1000) {
            clerk.append(&key, &value);
            call_op = KvInput {
                op: KvOp::Append,
                key,
                value,
            };
            return_op = KvOutput::default();
        } else if thread_rng().gen_ratio(100, 1000) {
            clerk.put(&key, &value);
            call_op = KvInput {
                op: KvOp::Put,
                key,
                value,
            };
            return_op = KvOutput::default();
        } else {
            // Get: the observed value is the operation's output.
            let result = clerk.get(&key).unwrap_or_default();
            call_op = KvInput {
                op: KvOp::Get,
                key,
                value: Default::default(),
            };
            return_op = result;
        }
        let return_time = Instant::now();
        ops.lock().push(Operation {
            call_op,
            call_time,
            return_op,
            return_time,
        });

        op_count += 1;
    }
    (op_count, String::new())
}
+
+const PARTITION_MAX_DELAY_MILLIS: u64 = 200;
+
+fn run_partition(cfg: Arc<Config>, stop: Arc<AtomicBool>) {
+    while !stop.load(Ordering::Acquire) {
+        cfg.random_partition();
+        let delay = thread_rng().gen_range(
+            LONG_ELECTION_TIMEOUT_MILLIS
+                ..LONG_ELECTION_TIMEOUT_MILLIS + PARTITION_MAX_DELAY_MILLIS,
+        );
+        std::thread::sleep(Duration::from_millis(delay));
+    }
+}
+
/// Elapsed-time checkpoints, measured from the start of one test round, used
/// for the per-round diagnostics printed at the end of `generic_test`.
#[derive(Debug)]
struct Laps {
    clients_started: Duration,
    partition_done: Duration,
    crash_done: Duration,
    running_time: Duration,
    partition_stopped: Duration,
    client_spawn: Duration,
    client_waits: Duration,
}
+
/// Knobs for `generic_test`. `Default` gives a reliable, uninterrupted run;
/// flip individual fields to add stress.
#[derive(Default)]
pub struct GenericTestParams {
    // Number of concurrent client threads.
    pub clients: usize,
    // Make the simulated network drop/delay messages.
    pub unreliable: bool,
    // Keep re-partitioning the network while clients run.
    pub partition: bool,
    // Crash and restart all servers mid-round.
    pub crash: bool,
    // Raft state size that triggers snapshots; `None` effectively disables
    // snapshotting (treated as usize::MAX).
    pub maxraftstate: Option<usize>,
    // Minimum operations each client must complete per round (default 10).
    pub min_ops: Option<usize>,
    // Record full histories and run the linearizability checker.
    pub test_linearizability: bool,
}
+
/// Runs a configurable end-to-end stress test: for 3 rounds, `clients`
/// concurrent client threads operate on the cluster while optional network
/// partitions and whole-cluster crashes are injected. After each round every
/// client's progress is validated; optionally the recorded history is then
/// checked for linearizability.
pub fn generic_test(test_params: GenericTestParams) {
    let GenericTestParams {
        clients,
        unreliable,
        partition,
        crash,
        maxraftstate,
        min_ops,
        test_linearizability,
    } = test_params;
    // `None` disables snapshotting by making the threshold unreachable.
    let maxraftstate = maxraftstate.unwrap_or(usize::MAX);
    let min_ops = min_ops.unwrap_or(10);
    let servers: usize = if test_linearizability { 7 } else { 5 };
    let cfg = Arc::new(make_config(servers, unreliable, maxraftstate));

    cfg.begin("");
    let mut clerk = cfg.make_clerk();
    // Shared history of all operations, filled by linearizability clients.
    let ops = Arc::new(Mutex::new(vec![]));

    let mut laps = vec![];
    const ROUNDS: usize = 3;
    for _ in 0..ROUNDS {
        let start = Instant::now();
        // Network partition thread.
        let partition_stop = Arc::new(AtomicBool::new(false));
        // KV server clients.
        let clients_stop = Arc::new(AtomicBool::new(false));

        // Spawn the clients from a helper thread so this thread can proceed
        // to inject partitions/crashes while clerks are still being created.
        let config = cfg.clone();
        let clients_stop_clone = clients_stop.clone();
        let ops_clone = ops.clone();
        let spawn_client_results = std::thread::spawn(move || {
            spawn_clients(config, clients, move |index: usize, clerk: Clerk| {
                if !test_linearizability {
                    appending_client(index, clerk, clients_stop_clone.clone())
                } else {
                    linearizability_client(
                        index,
                        clients,
                        clerk,
                        clients_stop_clone.clone(),
                        ops_clone.clone(),
                    )
                }
            })
        });
        let clients_started = start.elapsed();

        let partition_result = if partition {
            // Let the clients perform some operations without interruption.
            sleep_millis(1000);
            let config = cfg.clone();
            let partition_stop_clone = partition_stop.clone();
            Some(std::thread::spawn(|| {
                run_partition(config, partition_stop_clone)
            }))
        } else {
            None
        };
        let partition_done = start.elapsed();

        if crash {
            cfg.crash_all();
            sleep_election_timeouts(1);
            cfg.restart_all();
        }
        let crash_done = start.elapsed();

        // Main test window: let the clients run under stress.
        std::thread::sleep(Duration::from_secs(5));
        let running_time = start.elapsed();

        // Stop partitions.
        partition_stop.store(true, Ordering::Release);
        if let Some(result) = partition_result {
            result.join().expect("Partition thread should never fail");
            // Heal the network and give the cluster time to re-elect.
            cfg.connect_all();
            sleep_election_timeouts(1);
        }
        let partition_stopped = start.elapsed();

        // Tell all clients to stop.
        clients_stop.store(true, Ordering::Release);

        let client_results = spawn_client_results
            .join()
            .expect("Spawning clients should never fail.");
        let client_spawn = start.elapsed();
        for (index, client_result) in client_results.into_iter().enumerate() {
            let (op_count, last_result) =
                client_result.join().expect("Client should never fail");
            // Appending clients report their expected value; verify the
            // cluster agrees. Linearizability clients report an empty string.
            if !last_result.is_empty() {
                let real_result = clerk
                    .get(index.to_string())
                    .unwrap_or_else(|| panic!("Key {} should exist.", index));
                assert_eq!(real_result, last_result);
            }
            eprintln!("Client {} committed {} operations", index, op_count);
            assert!(
                op_count >= min_ops,
                "Client {} committed {} operations, less than {}",
                index,
                op_count,
                min_ops
            );
        }
        let client_waits = start.elapsed();
        laps.push(Laps {
            clients_started,
            partition_done,
            crash_done,
            running_time,
            partition_stopped,
            client_spawn,
            client_waits,
        });
    }

    cfg.end();
    cfg.clean_up();

    for (index, laps) in laps.iter().enumerate() {
        eprintln!("Round {} diagnostics: {:?}", index, laps);
    }

    if test_linearizability {
        // The checker wants a 'static history; leak the (test-only) vec.
        let ops: &'static Vec<Operation<KvInput, KvOutput>> =
            Box::leak(Box::new(
                Arc::try_unwrap(ops)
                    .expect("No one should be holding ops")
                    .into_inner(),
            ));
        let start = Instant::now();
        eprintln!("Searching for linearization arrangements ...");
        assert!(
            linearizability::check_operations_timeout::<KvModel>(&ops, None),
            "History {:?} is not linearizable,",
            ops,
        );
        eprintln!(
            "Searching for linearization arrangements done after {:?}.",
            start.elapsed()
        );
    }
}

+ 85 - 0
kvraft/src/testing_utils/memory_persister.rs

@@ -0,0 +1,85 @@
+use std::sync::Arc;
+
+use parking_lot::Mutex;
+
/// A point-in-time copy of everything one server persists.
#[derive(Clone)]
pub struct State {
    // Serialized Raft state (see the `Persister` impl's save/read_state).
    bytes: bytes::Bytes,
    // Raw application snapshot bytes.
    snapshot: Vec<u8>,
}
+
+pub struct MemoryPersister {
+    state: Mutex<State>,
+}
+
+impl MemoryPersister {
+    pub fn new() -> Self {
+        Self {
+            state: Mutex::new(State {
+                bytes: bytes::Bytes::new(),
+                snapshot: vec![],
+            }),
+        }
+    }
+}
+
+impl ruaft::Persister for MemoryPersister {
+    fn read_state(&self) -> bytes::Bytes {
+        self.state.lock().bytes.clone()
+    }
+
+    fn save_state(&self, data: bytes::Bytes) {
+        self.state.lock().bytes = data;
+    }
+
+    fn state_size(&self) -> usize {
+        self.state.lock().bytes.len()
+    }
+
+    fn save_snapshot_and_state(&self, state: bytes::Bytes, snapshot: &[u8]) {
+        let mut this = self.state.lock();
+        this.bytes = state;
+        this.snapshot = snapshot.to_vec();
+    }
+}
+
+impl MemoryPersister {
+    pub fn read(&self) -> State {
+        self.state.lock().clone()
+    }
+
+    pub fn restore(&self, state: State) {
+        *self.state.lock() = state;
+    }
+
+    pub fn snapshot_size(&self) -> usize {
+        self.state.lock().snapshot.len()
+    }
+}
+
/// The collection of in-memory persisters for a whole test cluster, indexed
/// by server id.
#[derive(Default)]
pub struct MemoryStorage {
    state_vec: Vec<Arc<MemoryPersister>>,
}
+
+impl MemoryStorage {
+    pub fn make(&mut self) -> Arc<MemoryPersister> {
+        let persister = Arc::new(MemoryPersister::new());
+        self.state_vec.push(persister.clone());
+        persister
+    }
+
+    pub fn at(&self, index: usize) -> Arc<MemoryPersister> {
+        self.state_vec[index].clone()
+    }
+
+    pub fn replace(&mut self, index: usize) -> Arc<MemoryPersister> {
+        let persister = Arc::new(MemoryPersister::new());
+        self.state_vec[index] = persister.clone();
+        persister
+    }
+
+    pub fn all(&self) -> &Vec<Arc<MemoryPersister>> {
+        &self.state_vec
+    }
+}

+ 4 - 0
kvraft/src/testing_utils/mod.rs

@@ -0,0 +1,4 @@
//! Helpers shared by the kvraft integration tests.

pub mod config;
pub mod generic_test;
mod memory_persister;
mod rpcs;

+ 35 - 0
kvraft/src/testing_utils/rpcs.rs

@@ -0,0 +1,35 @@
+use labrpc::{Network, Server};
+use parking_lot::Mutex;
+
+use ruaft::rpcs::make_rpc_handler;
+
+use crate::common::{GET, PUT_APPEND};
+use crate::server::KVServer;
+
+pub fn register_kv_server<
+    KV: 'static + AsRef<KVServer> + Clone,
+    S: AsRef<str>,
+>(
+    kv: KV,
+    name: S,
+    network: &Mutex<Network>,
+) -> std::io::Result<()> {
+    let mut network = network.lock();
+    let server_name = name.as_ref();
+    let mut server = Server::make_server(server_name);
+
+    let kv_clone = kv.clone();
+    server.register_rpc_handler(
+        GET.to_owned(),
+        make_rpc_handler(move |args| kv_clone.as_ref().get(args)),
+    )?;
+
+    server.register_rpc_handler(
+        PUT_APPEND.to_owned(),
+        make_rpc_handler(move |args| kv.as_ref().put_append(args)),
+    )?;
+
+    network.add_server(server_name, server);
+
+    Ok(())
+}

+ 254 - 0
kvraft/tests/service_test.rs

@@ -0,0 +1,254 @@
+extern crate kvraft;
+#[macro_use]
+extern crate scopeguard;
+
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+
+use kvraft::testing_utils::config::{
+    make_config, sleep_election_timeouts, sleep_millis,
+};
+use kvraft::testing_utils::generic_test::{
+    generic_test, spawn_clients, GenericTestParams,
+};
+
type Result = std::result::Result<(), String>;

/// Validates the value accumulated by `clients` concurrent appending clients.
///
/// `value` must be a concatenation of "(client, op)" groups, e.g.
/// "(0, 0)(1, 0)(0, 1)". For every client its `op` counters must appear in
/// order starting at 0, and the final per-client counts must equal
/// `expected`. Returns `Err` with a diagnostic message otherwise.
fn check_concurrent_results(
    value: String,
    clients: usize,
    expected: Vec<usize>,
) -> Result {
    // Strip the outer parentheses; "(a)(b)" becomes "a)(b".
    let inner_value = value
        .strip_prefix('(')
        .and_then(|v| v.strip_suffix(')'))
        .ok_or_else(|| format!("Malformed value string {}", value))?;

    // progress[c] is the next op counter we expect to see from client c.
    let mut progress = vec![0; clients];
    for pair_str in inner_value.split(")(") {
        let mut nums = vec![];
        for num_str in pair_str.split(", ") {
            let num: usize = num_str.parse().map_err(|_| {
                format!("Parsing '{:?}' failed within '{:?}'", num_str, value)
            })?;
            nums.push(num);
        }
        if nums.len() != 2 {
            // The old message claimed "more than two numbers", but a group
            // can also contain fewer; state the actual requirement.
            return Err(format!(
                "Expected exactly two numbers in group '{:?}' within '{:?}'",
                pair_str, value,
            ));
        }
        let (client, curr) = (nums[0], nums[1]);
        // Guard against a corrupt client index instead of panicking on an
        // out-of-bounds `progress[client]` access.
        let next = *progress.get(client).ok_or_else(|| {
            format!("Client index {} out of range in {}", client, value)
        })?;
        if next != curr {
            return Err(format!(
                "Client {} failed, expecting {}, got {}, others are {:?} in {}",
                client, next, curr, progress, value,
            ));
        }
        progress[client] = curr + 1;
    }
    // Report the final mismatch as an error, consistent with every other
    // failure path of this Result-returning validator, instead of panicking.
    if progress != expected {
        return Err(format!(
            "Expecting progress {:?}, got {:?} in {}",
            expected, progress, value
        ));
    }
    Ok(())
}
+
/// One client, reliable network: the basic end-to-end smoke test.
#[test]
fn basic_service() {
    generic_test(GenericTestParams {
        clients: 1,
        ..Default::default()
    });
}
+
/// Five concurrent clients on a reliable network.
#[test]
fn concurrent_client() {
    generic_test(GenericTestParams {
        clients: 5,
        ..Default::default()
    });
}
+
/// Five concurrent clients with message loss/delay enabled.
#[test]
fn unreliable_many_clients() {
    generic_test(GenericTestParams {
        clients: 5,
        unreliable: true,
        ..Default::default()
    });
}
+
/// Five clients append concurrently to a single key over an unreliable
/// network; the final value must contain every client's appends, each
/// client's counters in order.
#[test]
fn unreliable_one_key_many_clients() -> Result {
    const SERVERS: usize = 5;
    let cfg = Arc::new(make_config(SERVERS, true, 0));
    defer!(cfg.clean_up());

    let mut clerk = cfg.make_clerk();

    cfg.begin("Test: concurrent append to same key, unreliable (3A)");

    clerk.put("k", "");

    const CLIENTS: usize = 5;
    const ATTEMPTS: usize = 10;
    let client_results =
        spawn_clients(cfg.clone(), CLIENTS, |index, mut clerk| {
            for i in 0..ATTEMPTS {
                clerk.append("k", format!("({}, {})", index, i));
            }
        });
    // Wait for every appender to finish before reading the final value.
    for client_result in client_results {
        client_result.join().expect("Client should never fail");
    }

    let value = clerk.get("k").expect("Key should exist");

    check_concurrent_results(value, CLIENTS, vec![ATTEMPTS; CLIENTS])
}
+
/// Partitions the cluster once: the majority side must keep making progress,
/// the minority side must block, and blocked requests must complete after
/// the partition heals.
#[test]
fn one_partition() -> Result {
    const SERVERS: usize = 5;
    let cfg = Arc::new(make_config(SERVERS, false, 0));
    defer!(cfg.clean_up());

    cfg.begin("Test: progress in majority (3A)");

    const KEY: &str = "1";
    let mut clerk = cfg.make_clerk();
    clerk.put(KEY, "13");

    // `partition` returns (majority, minority) halves.
    let (majority, minority) = cfg.partition();

    assert!(minority.len() < majority.len());
    assert_eq!(minority.len() + majority.len(), SERVERS);

    let mut clerk_majority = cfg.make_limited_clerk(&majority);
    let mut clerk_minority1 = cfg.make_limited_clerk(&minority);
    let mut clerk_minority2 = cfg.make_limited_clerk(&minority);

    clerk_majority.put(KEY, "14");
    assert_eq!(clerk_majority.get(KEY), Some("14".to_owned()));

    cfg.begin("Test: no progress in minority (3A)");
    // Each minority request sets one bit in `counter` when (if) it completes.
    let counter = Arc::new(AtomicUsize::new(0));
    let counter1 = counter.clone();
    std::thread::spawn(move || {
        clerk_minority1.put(KEY, "15");
        counter1.fetch_or(1, Ordering::SeqCst);
    });
    let counter2 = counter.clone();
    std::thread::spawn(move || {
        clerk_minority2.get(KEY);
        counter2.fetch_or(2, Ordering::SeqCst);
    });

    sleep_millis(1000);

    // Neither minority request may have completed while partitioned.
    assert_eq!(counter.load(Ordering::SeqCst), 0);

    assert_eq!(clerk_majority.get(KEY), Some("14".to_owned()));
    clerk_majority.put(KEY, "16");
    assert_eq!(clerk_majority.get(KEY), Some("16".to_owned()));

    cfg.begin("Test: completion after heal (3A)");

    cfg.connect_all();
    cfg.connect_all_clerks();

    // Poll until both blocked minority requests finish (both bits set).
    sleep_election_timeouts(1);
    for _ in 0..100 {
        sleep_millis(60);
        if counter.load(Ordering::SeqCst) == 3 {
            break;
        }
    }

    assert_eq!(counter.load(Ordering::SeqCst), 3);
    // The minority's put("15") was the last write to commit.
    assert_eq!(clerk.get(KEY), Some("15".to_owned()));

    Ok(())
}
+
/// One client while the network is repeatedly re-partitioned.
#[test]
fn many_partitions_one_client() {
    generic_test(GenericTestParams {
        clients: 1,
        partition: true,
        ..Default::default()
    });
}
+
/// Five clients while the network is repeatedly re-partitioned.
#[test]
fn many_partitions_many_client() {
    generic_test(GenericTestParams {
        clients: 5,
        partition: true,
        ..Default::default()
    });
}
+
/// One client with a whole-cluster crash and restart mid-round.
#[test]
fn persist_one_client() {
    generic_test(GenericTestParams {
        clients: 1,
        crash: true,
        ..Default::default()
    });
}
+
/// Five clients with a whole-cluster crash and restart mid-round.
#[test]
fn persist_concurrent() {
    generic_test(GenericTestParams {
        clients: 5,
        crash: true,
        ..Default::default()
    });
}
+
/// Five clients, unreliable network, plus a crash/restart mid-round.
#[test]
fn persist_concurrent_unreliable() {
    generic_test(GenericTestParams {
        clients: 5,
        unreliable: true,
        crash: true,
        ..Default::default()
    });
}
+
/// Five clients with both partitions and a crash/restart mid-round.
#[test]
fn persist_partition() {
    generic_test(GenericTestParams {
        clients: 5,
        partition: true,
        crash: true,
        ..Default::default()
    });
}
+
/// The harshest appending-client mix: unreliable network, partitions, and
/// crashes; the per-client minimum op count is relaxed to 5 accordingly.
#[test]
fn persist_partition_unreliable() {
    generic_test(GenericTestParams {
        clients: 5,
        unreliable: true,
        partition: true,
        crash: true,
        min_ops: Some(5),
        ..Default::default()
    });
}
+
/// Full-stress run (unreliable + partitions + crashes) whose recorded
/// history is validated by the linearizability checker instead of the
/// per-key ordering check; min_ops is 0 since progress is not the point.
#[test]
fn linearizability() {
    generic_test(GenericTestParams {
        clients: 15,
        unreliable: true,
        partition: true,
        crash: true,
        maxraftstate: None,
        min_ops: Some(0),
        test_linearizability: true,
    });
}

+ 7 - 0
linearizability/Cargo.toml

@@ -0,0 +1,7 @@
+[package]
+name = "linearizability"
+version = "0.1.0"
+edition = "2018"
+
+[dependencies]
+bit-set = "0.5"

+ 225 - 0
linearizability/src/lib.rs

@@ -0,0 +1,225 @@
+use std::collections::HashSet;
+use std::fmt::Debug;
+use std::time::{Duration, Instant};
+
+use bit_set::BitSet;
+
+pub use model::KvInput;
+pub use model::KvModel;
+pub use model::KvOp;
+pub use model::KvOutput;
+pub use model::Model;
+
+use crate::offset_linked_list::{NodeRef, OffsetLinkedList};
+
+mod model;
+mod offset_linked_list;
+
/// A single client operation as observed by the history recorder: the
/// invocation (`call_op` issued at `call_time`) and its response
/// (`return_op` observed at `return_time`).
#[derive(Debug)]
pub struct Operation<C: Debug, R: Debug> {
    pub call_op: C,
    pub call_time: Instant,
    pub return_op: R,
    pub return_time: Instant,
}

/// One endpoint of an operation on the time line: either the invocation
/// (which carries a reference to the operation) or the matching response.
enum EntryKind<'a, C: Debug, R: Debug> {
    Call(&'a Operation<C, R>),
    Return,
}

/// A time-line entry. `id` pairs a call entry with its return entry, and
/// `other` holds the index of the paired entry in the sorted entry vector.
struct Entry<'a, C: Debug, R: Debug> {
    kind: EntryKind<'a, C, R>,
    id: usize,
    time: Instant,
    other: usize,
}

/// Flattens operations into a time-ordered list of call/return entries and
/// cross-links each call with its return through the `other` index.
fn operation_to_entries<'a, C: Debug, R: Debug>(
    ops: &[&'a Operation<C, R>],
) -> Vec<Entry<'a, C, R>> {
    let mut result = vec![];
    for op in ops {
        let id = result.len() >> 1;
        result.push(Entry {
            kind: EntryKind::Return,
            id,
            time: op.return_time,
            other: 0,
        });
        result.push(Entry {
            kind: EntryKind::Call(op),
            id,
            time: op.call_time,
            other: 0,
        });
    }
    // Sort by time, breaking ties with calls before returns. Without the
    // tie-break, the stable sort keeps the push order (return before call)
    // for an operation whose call and return share a timestamp, producing
    // an invalid history in which a return precedes its own call.
    result.sort_by_cached_key(|e| (e.time, matches!(e.kind, EntryKind::Return)));
    // Record where each pair ended up, then cross-link through `other`.
    let mut call_pos = vec![0; ops.len()];
    let mut return_pos = vec![0; ops.len()];
    for (index, entry) in result.iter().enumerate() {
        match entry.kind {
            EntryKind::Call(_) => call_pos[entry.id] = index,
            EntryKind::Return => return_pos[entry.id] = index,
        }
    }
    for i in 0..ops.len() {
        result[call_pos[i]].other = return_pos[i];
        result[return_pos[i]].other = call_pos[i];
    }
    result
}
+
+/// Searches for a linearization of a single partition's history against
+/// model `T`: returns true when some sequential ordering of the operations,
+/// consistent with their call/return times, is accepted by the model.
+fn check_history<T: Model>(
+    ops: &[&Operation<<T as Model>::Input, <T as Model>::Output>],
+) -> bool {
+    let entries = operation_to_entries(ops);
+    let mut list = OffsetLinkedList::create(entries);
+
+    // Memoized (linearized-operation set, model state) configurations that
+    // have already been explored; revisiting one cannot succeed.
+    let mut all = HashSet::new();
+    // Backtracking stack: for each tentatively linearized operation, the
+    // entry position plus the model state and id set from before the step.
+    let mut stack = vec![];
+
+    let mut flag = BitSet::new();
+    let mut leg = list.first().expect("Linked list should not be empty");
+    let mut curr = T::create();
+    while !list.is_empty() {
+        let entry = list.get(leg);
+        let other = NodeRef(entry.other);
+        match entry.kind {
+            EntryKind::Call(ops) => {
+                // Try to linearize this operation next.
+                let mut next = curr.clone();
+                if next.step(&ops.call_op, &ops.return_op) {
+                    let mut next_flag = flag.clone();
+                    next_flag.insert(entry.id);
+                    if all.insert((next_flag.clone(), next.clone())) {
+                        // New configuration: commit to it, detach this
+                        // operation's call and return entries, and restart
+                        // scanning from the front of what remains.
+                        std::mem::swap(&mut curr, &mut next);
+                        std::mem::swap(&mut flag, &mut next_flag);
+                        stack.push((leg, next, next_flag));
+
+                        list.lift(leg);
+                        list.lift(other);
+
+                        if let Some(first) = list.first() {
+                            leg = first;
+                        } else {
+                            break;
+                        }
+                    } else {
+                        // Already-explored configuration; try a later call.
+                        leg = list
+                            .succ(leg)
+                            .expect("There should be another element");
+                    }
+                } else {
+                    // The model rejects this operation here; try later ones.
+                    leg = list
+                        .succ(leg)
+                        .expect("There should be another element");
+                }
+            }
+            EntryKind::Return => {
+                // Reaching a return entry means no pending call could be
+                // linearized before it: undo the most recent choice.
+                if stack.is_empty() {
+                    return false;
+                }
+                let (prev_leg, prev, prev_flag) = stack.pop().unwrap();
+                leg = prev_leg;
+                curr = prev;
+                flag = prev_flag;
+
+                // Re-attach the undone operation's entries and move past it.
+                list.unlift(NodeRef(list.get(leg).other));
+                list.unlift(leg);
+                leg = list.succ(leg).expect("There should be another element");
+            }
+        }
+    }
+    true
+}
+
+/// Splits `history` into independent partitions (per `T::partition`) and
+/// checks each one for linearizability on its own OS thread. Returns true
+/// only if every partition is linearizable; failing partitions are dumped
+/// to stderr.
+///
+/// The history must be `'static` (callers typically leak it) so the worker
+/// threads can reference it.
+///
+/// NOTE(review): the `Option<Duration>` argument looks like an intended
+/// time limit, but it is ignored — the search always runs to completion.
+pub fn check_operations_timeout<T: Model>(
+    history: &'static [Operation<<T as Model>::Input, <T as Model>::Output>],
+    _: Option<Duration>,
+) -> bool
+where
+    <T as Model>::Input: Sync,
+    <T as Model>::Output: Sync,
+{
+    let mut results = vec![];
+    let mut partitions = vec![];
+    for sub_history in T::partition(history) {
+        // Keep a copy for error reporting; the original moves into the
+        // search thread below.
+        partitions.push(sub_history.clone());
+        results
+            .push(std::thread::spawn(move || check_history::<T>(&sub_history)));
+    }
+    let mut failed = vec![];
+    for (index, result) in results.into_iter().enumerate() {
+        let result = result.join().expect("Search thread should never panic");
+        if !result {
+            eprintln!("Partition {} failed: {:?}.", index, partitions[index]);
+            failed.push(index);
+        }
+    }
+    failed.is_empty()
+}
+
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use crate::{check_operations_timeout, Model, Operation};

    /// A toy model: the first successful step fixes `base`, and every
    /// subsequent step must repeat the same input while the output counts
    /// up by exactly one.
    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
    struct CountingModel {
        base: usize,
        cnt: usize,
    }

    impl Model for CountingModel {
        type Input = usize;
        type Output = usize;

        fn create() -> Self {
            Self { base: 0, cnt: 0 }
        }

        fn step(&mut self, input: &Self::Input, output: &Self::Output) -> bool {
            if self.base == 0 && *input != 0 && *output == 1 {
                self.base = *input;
                self.cnt = 1;
                true
            } else if self.base == *input && self.cnt + 1 == *output {
                self.cnt += 1;
                true
            } else {
                false
            }
        }
    }

    #[test]
    fn no_accept() {
        // `check_operations_timeout` demands a 'static history, so the
        // vector is leaked deliberately; the process exits after the test.
        // No `mut` binding is needed: `Box::leak` yields `&'static mut`,
        // through which `push` works directly.
        let ops = Box::leak(Box::new(vec![]));
        let start = Instant::now();
        for i in 0..4 {
            ops.push(Operation {
                call_op: 0usize,
                call_time: start,
                return_op: i as usize,
                return_time: start + Duration::from_secs(i),
            });
        }
        assert!(!check_operations_timeout::<CountingModel>(ops, None));
    }

    #[test]
    fn accept() {
        // Overlapping operations 1 -> 1..=4 form a valid counting history.
        let ops = Box::leak(Box::new(vec![]));
        let start = Instant::now();
        for i in 0..4 {
            ops.push(Operation {
                call_op: 1usize,
                call_time: start + Duration::from_secs(i * 2),
                return_op: (i + 1) as usize,
                return_time: start + Duration::from_secs(i + 4),
            });
        }
        assert!(check_operations_timeout::<CountingModel>(ops, None));
    }
}

+ 82 - 0
linearizability/src/model.rs

@@ -0,0 +1,82 @@
+use std::collections::HashMap;
+
+use crate::Operation;
+
+/// A sequential specification of the system under test.
+///
+/// Implementations must be `Clone + Eq + Hash` so the checker can copy
+/// states while backtracking and memoize visited configurations.
+pub trait Model:
+    std::cmp::Eq + std::clone::Clone + std::hash::Hash + std::fmt::Debug
+{
+    type Input: std::fmt::Debug;
+    type Output: std::fmt::Debug;
+
+    /// Returns the model's initial state.
+    fn create() -> Self;
+    /// Splits a history into independent sub-histories that can be checked
+    /// separately. The default keeps everything in a single partition.
+    fn partition(
+        history: &[Operation<Self::Input, Self::Output>],
+    ) -> Vec<Vec<&Operation<Self::Input, Self::Output>>> {
+        let history: Vec<&Operation<Self::Input, Self::Output>> =
+            history.iter().collect();
+        return vec![history];
+    }
+    /// Advances the model by one operation, returning false when `output`
+    /// cannot be observed for `input` in the current state.
+    fn step(&mut self, input: &Self::Input, output: &Self::Output) -> bool;
+}
+
/// The kind of key/value operation recorded in a history.
#[derive(Clone, Debug)]
pub enum KvOp {
    Get,
    Put,
    Append,
}

/// The invocation side of a key/value operation: which operation was
/// issued, on which key, and (for `Put` / `Append`) with which value.
#[derive(Clone, Debug)]
pub struct KvInput {
    pub op: KvOp,
    pub key: String,
    pub value: String,
}

/// The response side of a key/value operation: the value a `Get` observed.
pub type KvOutput = String;

// NOTE: no `unsafe impl Sync for KvInput` is needed. Every field (`KvOp`,
// `String`) is `Sync`, so the compiler derives `Sync` (and `Send`)
// automatically; the hand-written unsafe impl was redundant, and would have
// silently become unsound had a non-`Sync` field ever been added.
+
+#[derive(Clone, Debug, Eq, PartialEq, Hash)]
+pub struct KvModel {
+    expected_output: String,
+}
+
+impl Model for KvModel {
+    type Input = KvInput;
+    type Output = KvOutput;
+
+    fn create() -> Self {
+        KvModel {
+            expected_output: String::new(),
+        }
+    }
+
+    fn partition(
+        history: &[Operation<KvInput, KvOutput>],
+    ) -> Vec<Vec<&Operation<KvInput, KvOutput>>> {
+        let mut by_key =
+            HashMap::<String, Vec<&Operation<KvInput, KvOutput>>>::new();
+        for op in history {
+            by_key.entry(op.call_op.key.clone()).or_default().push(op);
+        }
+        let mut result = vec![];
+        for (_, values) in by_key {
+            result.push(values);
+        }
+        result
+    }
+
+    fn step(&mut self, input: &KvInput, output: &KvOutput) -> bool {
+        match input.op {
+            KvOp::Get => self.expected_output == *output,
+            KvOp::Put => {
+                self.expected_output = input.value.clone();
+                true
+            }
+            KvOp::Append => {
+                self.expected_output += &input.value;
+                true
+            }
+        }
+    }
+}

+ 198 - 0
linearizability/src/offset_linked_list.rs

@@ -0,0 +1,198 @@
+use std::mem::MaybeUninit;
+
/// A doubly-linked-list node stored inline in a vector. Index 0 is a
/// sentinel head whose `data` is never initialized.
struct Node<T> {
    prev: usize,
    succ: usize,
    data: MaybeUninit<T>,
}

impl<T> Default for Node<T> {
    fn default() -> Self {
        Self {
            prev: 0,
            succ: 0,
            data: MaybeUninit::uninit(),
        }
    }
}

/// A doubly linked list backed by a vector, supporting O(1) detachment
/// ("lift") of a node and O(1) re-insertion ("unlift") at its old position,
/// which is exactly what the linearizability search needs for backtracking.
///
/// NOTE(review): element destructors never run — the `MaybeUninit` payloads
/// are leaked on drop. Harmless for the reference-only entries stored by
/// the checker, but it would leak heap-owning element types; confirm before
/// reusing this container elsewhere.
pub struct OffsetLinkedList<T> {
    nodes: Vec<Node<T>>,
}

/// Stable handle to the `i`-th element (0-based, excluding the sentinel).
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct NodeRef(pub usize);

impl NodeRef {
    /// Converts an internal vector index back into a public handle; the
    /// sentinel head maps to `None`.
    fn from_index(index: usize) -> Option<Self> {
        if index == OffsetLinkedList::<()>::HEAD {
            None
        } else {
            Some(Self(index - 1))
        }
    }
}

/// Iterator over the elements currently linked into the list.
pub struct Iter<'a, T> {
    list: &'a OffsetLinkedList<T>,
    index: usize,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index == OffsetLinkedList::<()>::HEAD {
            None
        } else {
            let node = self.list.at(self.index);
            self.index = node.succ;
            // SAFETY: `create` initializes the data of every non-sentinel
            // node, and the iterator only visits non-sentinel indices.
            Some(unsafe { &*node.data.as_ptr() })
        }
    }
}

impl<T> OffsetLinkedList<T> {
    /// Internal vector index of the sentinel head node.
    const HEAD: usize = 0;

    /// Builds a list containing `data` in order.
    pub fn create(data: Vec<T>) -> Self {
        let len = data.len();
        let mut nodes = Vec::with_capacity(len + 1);
        nodes.resize_with(len + 1, Node::default);
        for (i, data) in data.into_iter().enumerate() {
            nodes[i].succ = i + 1;
            nodes[i + 1].prev = i;
            nodes[i + 1].data = MaybeUninit::new(data);
        }
        // Close the circle through the sentinel.
        nodes[Self::HEAD].prev = len;
        nodes[len].succ = Self::HEAD;
        Self { nodes }
    }

    /// Maps a public handle to the internal (sentinel-offset) index.
    fn offset_index(&self, index: NodeRef) -> usize {
        assert!(index.0 + 1 < self.nodes.len());
        index.0 + 1
    }

    /// Detaches the node from the chain. Its own links are left intact so
    /// `unlift` can splice it back at the same position.
    pub fn lift(&mut self, index: NodeRef) {
        let index = self.offset_index(index);
        let prev = self.nodes[index].prev;
        let succ = self.nodes[index].succ;
        self.nodes[prev].succ = succ;
        self.nodes[succ].prev = prev;
    }

    /// Re-inserts a previously lifted node at its old position. Lifts must
    /// be undone in reverse order for the links to remain consistent.
    pub fn unlift(&mut self, index: NodeRef) {
        let index = self.offset_index(index);
        let prev = self.nodes[index].prev;
        let succ = self.nodes[index].succ;
        self.nodes[prev].succ = index;
        self.nodes[succ].prev = index;
    }

    /// Returns a reference to the element behind the handle.
    pub fn get(&self, index: NodeRef) -> &T {
        let index = self.offset_index(index);
        // SAFETY: `offset_index` only yields non-sentinel indices, whose
        // data is initialized in `create`.
        unsafe { &*self.nodes[index].data.as_ptr() }
    }

    /// The linked predecessor of the node, if it is not the first element.
    #[allow(dead_code)]
    pub fn prev(&self, index: NodeRef) -> Option<NodeRef> {
        let index = self.offset_index(index);
        let prev = self.nodes[index].prev;
        NodeRef::from_index(prev)
    }

    /// The linked successor of the node, if it is not the last element.
    pub fn succ(&self, index: NodeRef) -> Option<NodeRef> {
        let index = self.offset_index(index);
        NodeRef::from_index(self.nodes[index].succ)
    }

    /// The first linked element, or `None` when the list is empty.
    pub fn first(&self) -> Option<NodeRef> {
        NodeRef::from_index(self.nodes[Self::HEAD].succ)
    }

    /// The last linked element, or `None` when the list is empty.
    #[allow(dead_code)]
    pub fn last(&self) -> Option<NodeRef> {
        NodeRef::from_index(self.nodes[Self::HEAD].prev)
    }

    /// True when no element is currently linked in.
    pub fn is_empty(&self) -> bool {
        self.nodes[Self::HEAD].succ == Self::HEAD
    }

    fn at(&self, index: usize) -> &Node<T> {
        &self.nodes[index]
    }

    /// Iterates over the elements that are currently linked in.
    #[allow(dead_code)]
    pub fn iter(&self) -> Iter<'_, T> {
        // Start at the first *linked* node rather than at raw index 1:
        // starting at 1 would index out of bounds (and panic) on an empty
        // list, where the sentinel is the only node, and would wrongly
        // yield a lifted first element.
        Iter {
            list: self,
            index: self.nodes[Self::HEAD].succ,
        }
    }
}
+
#[cfg(test)]
mod tests {
    use crate::offset_linked_list::{NodeRef, OffsetLinkedList};

    /// Builds a list containing the lowercase alphabet.
    fn make_list() -> OffsetLinkedList<char> {
        let data: Vec<char> = ('a'..='z').collect();
        OffsetLinkedList::create(data)
    }

    /// Asserts that walking the list front to back spells `ans`.
    fn assert_char_list_eq(list: &OffsetLinkedList<char>, ans: &str) {
        let mut list_str = String::new();
        let mut leg = list.first();
        while let Some(curr) = leg {
            list_str.push(*list.get(curr));
            leg = list.succ(curr);
        }
        assert_eq!(&list_str, ans);
    }

    #[test]
    fn linked_list() {
        let mut list = make_list();
        let data_str: String = ('a'..='z').collect();
        assert_char_list_eq(&list, &data_str);

        // Lift every third visited element (plus the final one).
        let mut leg = list.first().unwrap();
        for i in 0..10 {
            if i % 3 == 0 {
                list.lift(leg);
            }
            leg = list.succ(leg).unwrap();
        }
        list.lift(leg);
        // The parameter is already a `&str`; the previous `&"..."` passed a
        // needless `&&str` and relied on deref coercion.
        assert_char_list_eq(&list, "bcefhilmnopqrstuvwxyz");

        list.unlift(NodeRef(0));
        list.unlift(NodeRef(3));
        list.unlift(NodeRef(10));
        assert_char_list_eq(&list, "abcdefhiklmnopqrstuvwxyz");
    }

    #[test]
    fn empty_linked_list() {
        let mut list = make_list();
        assert!(!list.is_empty());

        // Lifting every element must leave the list empty.
        let mut leg = list.first();
        while let Some(curr) = leg {
            leg = list.succ(curr);
            list.lift(curr);
        }
        assert!(list.is_empty())
    }

    #[test]
    fn iterate_linked_list() {
        let list_str: String = make_list().iter().collect();
        let data_str: String = ('a'..='z').collect();
        assert_eq!(data_str, list_str);
    }
}

+ 3 - 5
src/lib.rs

@@ -330,6 +330,7 @@ where
             };
             self.apply_command_signal.notify_one();
         }
+        self.snapshot_daemon.log_grow(rf.log.start(), rf.log.end());
 
         AppendEntriesReply {
             term: args.term,
@@ -929,7 +930,7 @@ where
         self.election.stop_election_timer();
         self.new_log_entry.take().map(|n| n.send(None));
         self.apply_command_signal.notify_all();
-        self.snapshot_daemon.trigger();
+        self.snapshot_daemon.kill();
         self.stop_wait_group.wait();
         std::sync::Arc::try_unwrap(self.thread_pool)
             .expect(
@@ -987,8 +988,5 @@ impl ElectionState {
 }
 
 impl<C> Raft<C> {
-    pub const NO_SNAPSHOT: fn(Index) -> Snapshot = |index| Snapshot {
-        last_included_index: index,
-        data: vec![],
-    };
+    pub const NO_SNAPSHOT: fn(Index) = |_| {};
 }

+ 1 - 1
src/rpcs.rs

@@ -63,7 +63,7 @@ impl RpcClient {
     }
 }
 
-fn make_rpc_handler<Request, Reply, F>(
+pub fn make_rpc_handler<Request, Reply, F>(
     func: F,
 ) -> Box<dyn Fn(RequestMessage) -> ReplyMessage>
 where

+ 51 - 22
src/snapshot.rs

@@ -1,6 +1,10 @@
-use crate::{Index, Raft};
-use crossbeam_utils::sync::{Parker, Unparker};
 use std::sync::atomic::Ordering;
+use std::sync::Arc;
+
+use crossbeam_utils::sync::{Parker, Unparker};
+use parking_lot::{Condvar, Mutex};
+
+use crate::{Index, Raft};
 
 #[derive(Clone, Debug, Default)]
 pub struct Snapshot {
@@ -11,25 +15,51 @@ pub struct Snapshot {
 #[derive(Clone, Debug, Default)]
 pub(crate) struct SnapshotDaemon {
     unparker: Option<Unparker>,
+    // Latest snapshot handed over by the application, paired with a condvar
+    // the daemon thread blocks on while waiting for a fresh one.
+    current_snapshot: Arc<(Mutex<Snapshot>, Condvar)>,
 }
 
-pub trait RequestSnapshotFnMut:
-    'static + Send + FnMut(Index) -> Snapshot
-{
-}
+pub trait RequestSnapshotFnMut: 'static + Send + FnMut(Index) {}
 
-impl<T: 'static + Send + FnMut(Index) -> Snapshot> RequestSnapshotFnMut for T {}
+impl<T: 'static + Send + FnMut(Index)> RequestSnapshotFnMut for T {}
 
impl SnapshotDaemon {
+    /// Stores a snapshot produced by the application. Only a snapshot that
+    /// covers more of the log than the current one is kept; in either case
+    /// the daemon thread waiting on the condvar is woken up.
+    pub(crate) fn save_snapshot(&self, snapshot: Snapshot) {
+        let mut curr = self.current_snapshot.0.lock();
+        if curr.last_included_index < snapshot.last_included_index {
+            *curr = snapshot;
+        }
+        self.current_snapshot.1.notify_one();
+    }
+
+    /// Wakes the parked daemon thread so it re-checks the persisted state.
     pub(crate) fn trigger(&self) {
         match &self.unparker {
             Some(unparker) => unparker.unpark(),
             None => {}
         }
     }
+
+    // Minimum number of log entries between snapshot wake-ups.
+    const MIN_SNAPSHOT_INDEX_INTERVAL: usize = 100;
+
+    /// Called when the log grows; nudges the daemon once the log spans
+    /// more than `MIN_SNAPSHOT_INDEX_INTERVAL` entries.
+    pub(crate) fn log_grow(&self, first_index: Index, last_index: Index) {
+        if last_index - first_index > Self::MIN_SNAPSHOT_INDEX_INTERVAL {
+            self.trigger();
+        }
+    }
+
+    /// Shuts the daemon down: wakes it both from `park` and from the
+    /// condvar wait so it can observe `keep_running == false` and exit.
+    pub(crate) fn kill(&self) {
+        self.trigger();
+        // Acquire the lock to make sure the daemon thread either has been
+        // waiting on the condition, or has not checked `keep_running` yet.
+        let _ = self.current_snapshot.0.lock();
+        self.current_snapshot.1.notify_all();
+    }
 }
 
 impl<C: 'static + Clone + Default + Send + serde::Serialize> Raft<C> {
+    pub fn save_snapshot(&self, snapshot: Snapshot) {
+        self.snapshot_daemon.save_snapshot(snapshot)
+    }
+
     pub(crate) fn run_snapshot_daemon(
         &mut self,
         max_state_size: Option<usize>,
@@ -47,6 +77,7 @@ impl<C: 'static + Clone + Default + Send + serde::Serialize> Raft<C> {
         let keep_running = self.keep_running.clone();
         let rf = self.inner_state.clone();
         let persister = self.persister.clone();
+        let snapshot_daemon = self.snapshot_daemon.clone();
         let stop_wait_group = self.stop_wait_group.clone();
 
         std::thread::spawn(move || loop {
@@ -56,12 +87,23 @@ impl<C: 'static + Clone + Default + Send + serde::Serialize> Raft<C> {
                 drop(keep_running);
                 drop(rf);
                 drop(persister);
+                drop(snapshot_daemon);
                 drop(stop_wait_group);
                 break;
             }
             if persister.state_size() >= max_state_size {
                 let log_start = rf.lock().log.first_index_term();
-                let snapshot = request_snapshot(log_start.index + 1);
+                let snapshot = {
+                    let mut snapshot =
+                        snapshot_daemon.current_snapshot.0.lock();
+                    if keep_running.load(Ordering::SeqCst)
+                        && snapshot.last_included_index <= log_start.index
+                    {
+                        request_snapshot(log_start.index + 1);
+                        snapshot_daemon.current_snapshot.1.wait(&mut snapshot);
+                    }
+                    snapshot.clone()
+                };
 
                 let mut rf = rf.lock();
                 if rf.log.first_index_term() != log_start {
@@ -76,20 +118,7 @@ impl<C: 'static + Clone + Default + Send + serde::Serialize> Raft<C> {
                     continue;
                 }
 
-                if snapshot.last_included_index >= rf.log.end() {
-                    // We recently rolled back some of the committed logs. This
-                    // can happen but usually the same exact log entries will be
-                    // installed in the next AppendEntries request.
-                    // There is no need to retry, because when the log entries
-                    // are re-committed, we will be notified again.
-
-                    // We will not be notified when the log length changes. Thus
-                    // when the log length grows to passing last_included_index
-                    // the first time, no snapshot will be taken, although
-                    // nothing is preventing it to be done. We will wait until
-                    // at least one more entry is committed.
-                    continue;
-                }
+                assert!(snapshot.last_included_index < rf.log.end());
 
                 rf.log.shift(snapshot.last_included_index, snapshot.data);
                 persister.save_snapshot_and_state(

+ 160 - 0
tests/snapshot_tests.rs

@@ -0,0 +1,160 @@
+extern crate kvraft;
+#[macro_use]
+extern crate scopeguard;
+
+use kvraft::testing_utils::config::{make_config, sleep_election_timeouts};
+use kvraft::testing_utils::generic_test::{generic_test, GenericTestParams};
+use std::sync::Arc;
+
+#[test]
+fn install_snapshot_rpc() {
+    const SERVERS: usize = 3;
+    const MAX_RAFT_STATE: usize = 1000;
+    const KEY: &str = "a";
+    let cfg = Arc::new(make_config(SERVERS, false, MAX_RAFT_STATE));
+    defer!(cfg.clean_up());
+
+    let mut clerk = cfg.make_clerk();
+
+    cfg.begin("Test: InstallSnapshot RPC (3B)");
+
+    // Seed a value, then partition the cluster and push enough writes into
+    // the majority to grow its Raft state past the snapshot threshold.
+    clerk.put("a", "A");
+    assert_eq!(clerk.get(KEY), Some("A".to_owned()));
+    let (majority, minority) = cfg.partition();
+    {
+        let mut clerk = cfg.make_limited_clerk(&majority);
+        for i in 0..50 {
+            let i_str = i.to_string();
+            clerk.put(&i_str, &i_str);
+        }
+        sleep_election_timeouts(1);
+        clerk.put("b", "B");
+    }
+
+    // The majority must have compacted its log by snapshotting.
+    cfg.check_log_size(MAX_RAFT_STATE * 2)
+        .expect("Log does not seem to be trimmed:");
+
+    // Swap majority and minority.
+    let (mut majority, mut minority) = (minority, majority);
+    majority.push(
+        minority
+            .pop()
+            .expect("There should be at least one server in the majority."),
+    );
+    cfg.network_partition(&majority, &minority);
+
+    // The new majority pairs the stale servers with one up-to-date server
+    // whose log is trimmed, so catching the stale servers up requires the
+    // InstallSnapshot RPC rather than plain AppendEntries.
+    {
+        let mut clerk = cfg.make_limited_clerk(&majority);
+        clerk.put("c", "C");
+        clerk.put("d", "D");
+        assert_eq!(clerk.get(KEY), Some("A".to_owned()));
+        assert_eq!(clerk.get("b"), Some("B".to_owned()));
+        assert_eq!(clerk.get("c"), Some("C".to_owned()));
+        assert_eq!(clerk.get("d"), Some("D".to_owned()));
+        assert_eq!(clerk.get("1"), Some("1".to_owned()));
+        assert_eq!(clerk.get("49"), Some("49".to_owned()));
+    }
+
+    // Heal the partition and verify all history survived.
+    cfg.connect_all();
+    clerk.put("e", "E");
+    assert_eq!(clerk.get("c"), Some("C".to_owned()));
+    assert_eq!(clerk.get("e"), Some("E".to_owned()));
+    assert_eq!(clerk.get("1"), Some("1".to_owned()));
+    assert_eq!(clerk.get("49"), Some("49".to_owned()));
+
+    cfg.end();
+}
+
+#[test]
+fn snapshot_size() {
+    const SERVERS: usize = 3;
+    const MAX_RAFT_STATE: usize = 1000;
+    const MAX_SNAPSHOT_STATE: usize = 500;
+    let cfg = Arc::new(make_config(SERVERS, false, MAX_RAFT_STATE));
+    defer!(cfg.clean_up());
+
+    let mut clerk = cfg.make_clerk();
+
+    cfg.begin("Test: snapshot size is reasonable (3B)");
+
+    // Churn a single key 400 times: the log keeps growing, but the
+    // materialized key/value state stays tiny.
+    for _ in 0..200 {
+        clerk.put("x", "0");
+        assert_eq!(clerk.get("x"), Some("0".to_owned()));
+        clerk.put("x", "1");
+        assert_eq!(clerk.get("x"), Some("1".to_owned()));
+    }
+
+    // The log must be compacted, and the snapshot must reflect the small
+    // state rather than the whole operation history.
+    cfg.check_log_size(MAX_RAFT_STATE * 2)
+        .expect("Log does not seem to be trimmed:");
+    cfg.check_snapshot_size(MAX_SNAPSHOT_STATE)
+        .expect("Snapshot size is too big:");
+
+    cfg.end();
+}
+
+// The tests below reuse the generic workload harness with `maxraftstate`
+// set, forcing periodic log compaction while the workload runs.
+
+#[test]
+fn snapshot_recover_test() {
+    // Single client; crashes force recovery from persisted snapshots.
+    generic_test(GenericTestParams {
+        clients: 1,
+        crash: true,
+        maxraftstate: Some(1000),
+        ..Default::default()
+    })
+}
+
+#[test]
+fn snapshot_recover_many_clients() {
+    // Twenty concurrent clients with crashes and snapshotting.
+    generic_test(GenericTestParams {
+        clients: 20,
+        crash: true,
+        maxraftstate: Some(1000),
+        min_ops: Some(0),
+        ..Default::default()
+    })
+}
+
+#[test]
+fn snapshot_unreliable_test() {
+    // Snapshotting under an unreliable (message-dropping) network.
+    generic_test(GenericTestParams {
+        clients: 5,
+        unreliable: true,
+        maxraftstate: Some(1000),
+        ..Default::default()
+    })
+}
+
+#[test]
+fn snapshot_unreliable_recover_test() {
+    // Unreliable network combined with server crashes.
+    generic_test(GenericTestParams {
+        clients: 5,
+        unreliable: true,
+        crash: true,
+        maxraftstate: Some(1000),
+        ..Default::default()
+    })
+}
+
+#[test]
+fn snapshot_unreliable_recover_partition() {
+    // Unreliable network, crashes and partitions, all with snapshotting.
+    generic_test(GenericTestParams {
+        clients: 5,
+        unreliable: true,
+        crash: true,
+        partition: true,
+        maxraftstate: Some(1000),
+        min_ops: Some(0),
+        ..Default::default()
+    })
+}
+#[test]
+fn linearizability() {
+    // Same heavyweight configuration as the non-snapshot linearizability
+    // test, but with snapshotting enabled (`maxraftstate: Some(1000)`).
+    generic_test(GenericTestParams {
+        clients: 15,
+        unreliable: true,
+        partition: true,
+        crash: true,
+        maxraftstate: Some(1000),
+        min_ops: Some(0),
+        test_linearizability: true,
+    });
+}