Pārlūkot izejas kodu

Merge branch 'merge_tests': Separate test configs from production code.

Test configs are moved out of ruaft and kvraft. Those two crates no longer
depend on labrpc.

A RPC interface has been created to allow integration of other RPC frameworks.

No significant performance regressions discovered.
Jing Yang 4 gadi atpakaļ
vecāks
revīzija
48584f190e

+ 3 - 5
Cargo.toml

@@ -19,7 +19,6 @@ bytes = "1.0"
 crossbeam-utils = "0.8"
 futures-channel = "0.3.15"
 futures-util = "0.3.15"
-labrpc = "0.1.12"
 log = "0.4"
 parking_lot = "0.11.1"
 rand = "0.8"
@@ -33,17 +32,16 @@ default = []
 integration-test = ["test_utils"]
 
 [dev-dependencies]
-anyhow = "1.0"
-futures = { version = "0.3.15", features = ["thread-pool"] }
-ruaft = { path = ".", features = ["integration-test"] }
+kvraft = { path = "kvraft" }
 scopeguard = "1.1.0"
 stdext = "0.3"
+test_configs = { path = "test_configs" }
 test_utils = { path = "test_utils" }
-kvraft = { path = "kvraft" }
 
 [workspace]
 members = [
     "kvraft",
     "linearizability",
+    "test_configs",
     "test_utils",
 ]

+ 2 - 2
README.md

@@ -81,7 +81,7 @@ Things would be better after I implement an RPC interface and improve the `persi
 
 - [x] Split into multiple files
 - [x] Add public documentation
-- [ ] Add a proper RPC interface to all public methods
-- [ ] Benchmarks
+- [x] Add a proper RPC interface to all public methods
 - [x] Allow storing arbitrary information
 - [x] Add more logging.
+- [ ] Benchmarks

+ 3 - 4
kvraft/Cargo.toml

@@ -4,19 +4,18 @@ version = "0.1.0"
 edition = "2018"
 
 [dependencies]
+async-trait = "0.1"
 bincode = "1.3.3"
-bytes = "1.0"
-labrpc = "0.1.12"
+log = "0.4.14"
 parking_lot = "0.11.1"
 rand = "0.8"
 ruaft = { path = "..", features = ["integration-test"] }
-linearizability = { path = "../linearizability" }
 serde = "1.0.116"
 serde_derive = "1.0.116"
 test_utils = { path = "../test_utils" }
 tokio = { version = "1.7", features = ["time", "parking_lot"] }
-log = "0.4.14"
 
 [dev-dependencies]
 scopeguard = "1.1.0"
 stdext = "0.3.0"
+test_configs = { path = "../test_configs" }

+ 34 - 33
kvraft/src/client.rs

@@ -1,16 +1,13 @@
+use std::future::Future;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Once;
 use std::time::Duration;
 
-use labrpc::{Client, RequestMessage};
-use serde::de::DeserializeOwned;
-use serde::Serialize;
-
 use crate::common::{
-    GetArgs, GetEnum, GetReply, KVRaftOptions, PutAppendArgs, PutAppendEnum,
-    PutAppendReply, UniqueIdSequence, GET, PUT_APPEND,
+    GetArgs, GetEnum, GetReply, KVError, KVRaftOptions, PutAppendArgs,
+    PutAppendEnum, PutAppendReply, UniqueIdSequence, ValidReply,
 };
-use crate::common::{KVError, ValidReply};
+use crate::RemoteKvraft;
 
 pub struct Clerk {
     init: Once,
@@ -18,7 +15,7 @@ pub struct Clerk {
 }
 
 impl Clerk {
-    pub fn new(servers: Vec<Client>) -> Self {
+    pub fn new(servers: Vec<impl RemoteKvraft>) -> Self {
         Self {
             init: Once::new(),
             inner: ClerkInner::new(servers),
@@ -64,7 +61,7 @@ impl Clerk {
 }
 
 pub struct ClerkInner {
-    servers: Vec<Client>,
+    servers: Vec<Box<dyn RemoteKvraft>>,
 
     last_server_index: AtomicUsize,
     unique_id: UniqueIdSequence,
@@ -73,7 +70,11 @@ pub struct ClerkInner {
 }
 
 impl ClerkInner {
-    pub fn new(servers: Vec<Client>) -> Self {
+    pub fn new(servers: Vec<impl RemoteKvraft>) -> Self {
+        let servers = servers
+            .into_iter()
+            .map(|s| Box::new(s) as Box<dyn RemoteKvraft>)
+            .collect();
         Self {
             servers,
 
@@ -94,7 +95,8 @@ impl ClerkInner {
                 op: GetEnum::NoDuplicate,
                 unique_id: self.unique_id.zero(),
             };
-            let reply: Option<GetReply> = self.call_rpc(GET, args, Some(1));
+            let reply: Option<GetReply> =
+                self.retry_rpc(|remote, args| remote.get(args), args, Some(1));
             if let Some(reply) = reply {
                 match reply.result {
                     Ok(_) => {
@@ -115,24 +117,18 @@ impl ClerkInner {
     }
 
     const DEFAULT_TIMEOUT: Duration = Duration::from_secs(1);
-
-    fn call_rpc<M, A, R>(
-        &mut self,
-        method: M,
-        args: A,
+    pub fn retry_rpc<'a, Func, Fut, Args, Reply>(
+        &'a mut self,
+        mut future_func: Func,
+        args: Args,
         max_retry: Option<usize>,
-    ) -> Option<R>
+    ) -> Option<Reply>
     where
-        M: AsRef<str>,
-        A: Serialize,
-        R: DeserializeOwned + ValidReply,
+        Args: Clone,
+        Reply: ValidReply,
+        Fut: Future<Output = std::io::Result<Reply>> + Send + 'a,
+        Func: FnMut(&'a dyn RemoteKvraft, Args) -> Fut,
     {
-        let method = method.as_ref().to_owned();
-        let data = RequestMessage::from(
-            bincode::serialize(&args)
-                .expect("Serialization of requests should not fail"),
-        );
-
         let max_retry =
             std::cmp::max(max_retry.unwrap_or(usize::MAX), self.servers.len());
 
@@ -142,7 +138,7 @@ impl ClerkInner {
             let rpc_response = self.executor.block_on(async {
                 tokio::time::timeout(
                     Self::DEFAULT_TIMEOUT,
-                    client.call_rpc(method.clone(), data.clone()),
+                    future_func(client.as_ref(), args.clone()),
                 )
                 .await
             });
@@ -150,9 +146,7 @@ impl ClerkInner {
                 Ok(reply) => reply,
                 Err(e) => Err(e.into()),
             };
-            if let Ok(reply) = reply {
-                let ret: R = bincode::deserialize(reply.as_ref())
-                    .expect("Deserialization of reply should not fail");
+            if let Ok(ret) = reply {
                 if ret.is_reply_valid() {
                     self.last_server_index.store(index, Ordering::Relaxed);
                     return Some(ret);
@@ -183,7 +177,11 @@ impl ClerkInner {
             op: GetEnum::AllowDuplicate,
             unique_id: self.unique_id.inc(),
         };
-        let reply: GetReply = self.call_rpc(GET, args, options.max_retry)?;
+        let reply: GetReply = self.retry_rpc(
+            |remote, args| remote.get(args),
+            args,
+            options.max_retry,
+        )?;
         match reply.result {
             Ok(val) => Some(val),
             Err(KVError::Conflict) => panic!("We should never see a conflict."),
@@ -213,8 +211,11 @@ impl ClerkInner {
             op,
             unique_id: self.unique_id.inc(),
         };
-        let reply: PutAppendReply =
-            self.call_rpc(PUT_APPEND, args, options.max_retry)?;
+        let reply: PutAppendReply = self.retry_rpc(
+            |remote, args| remote.put_append(args),
+            args,
+            options.max_retry,
+        )?;
         match reply.result {
             Ok(val) => Some(val),
             Err(KVError::Expired) => Some(()),

+ 0 - 3
kvraft/src/common.rs

@@ -52,9 +52,6 @@ impl UniqueIdSequence {
     }
 }
 
-pub(crate) const GET: &str = "KVServer.Get";
-pub(crate) const PUT_APPEND: &str = "KVServer.PutAppend";
-
 #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 pub enum PutAppendEnum {
     Put,

+ 3 - 1
kvraft/src/lib.rs

@@ -1,9 +1,11 @@
 pub use client::Clerk;
+pub use common::{GetArgs, GetReply, PutAppendArgs, PutAppendReply};
+pub use remote_kvraft::RemoteKvraft;
 pub use server::KVServer;
 
 mod client;
 mod common;
 mod server;
 
+mod remote_kvraft;
 mod snapshot_holder;
-pub mod testing_utils;

+ 13 - 0
kvraft/src/remote_kvraft.rs

@@ -0,0 +1,13 @@
+use async_trait::async_trait;
+
+use crate::common::{GetArgs, GetReply, PutAppendArgs, PutAppendReply};
+
+#[async_trait]
+pub trait RemoteKvraft: Send + Sync + 'static {
+    async fn get(&self, args: GetArgs) -> std::io::Result<GetReply>;
+
+    async fn put_append(
+        &self,
+        args: PutAppendArgs,
+    ) -> std::io::Result<PutAppendReply>;
+}

+ 0 - 85
kvraft/src/testing_utils/memory_persister.rs

@@ -1,85 +0,0 @@
-use std::sync::Arc;
-
-use parking_lot::Mutex;
-
-#[derive(Clone)]
-pub struct State {
-    bytes: bytes::Bytes,
-    snapshot: Vec<u8>,
-}
-
-pub struct MemoryPersister {
-    state: Mutex<State>,
-}
-
-impl MemoryPersister {
-    pub fn new() -> Self {
-        Self {
-            state: Mutex::new(State {
-                bytes: bytes::Bytes::new(),
-                snapshot: vec![],
-            }),
-        }
-    }
-}
-
-impl ruaft::Persister for MemoryPersister {
-    fn read_state(&self) -> bytes::Bytes {
-        self.state.lock().bytes.clone()
-    }
-
-    fn save_state(&self, data: bytes::Bytes) {
-        self.state.lock().bytes = data;
-    }
-
-    fn state_size(&self) -> usize {
-        self.state.lock().bytes.len()
-    }
-
-    fn save_snapshot_and_state(&self, state: bytes::Bytes, snapshot: &[u8]) {
-        let mut this = self.state.lock();
-        this.bytes = state;
-        this.snapshot = snapshot.to_vec();
-    }
-}
-
-impl MemoryPersister {
-    pub fn read(&self) -> State {
-        self.state.lock().clone()
-    }
-
-    pub fn restore(&self, state: State) {
-        *self.state.lock() = state;
-    }
-
-    pub fn snapshot_size(&self) -> usize {
-        self.state.lock().snapshot.len()
-    }
-}
-
-#[derive(Default)]
-pub struct MemoryStorage {
-    state_vec: Vec<Arc<MemoryPersister>>,
-}
-
-impl MemoryStorage {
-    pub fn make(&mut self) -> Arc<MemoryPersister> {
-        let persister = Arc::new(MemoryPersister::new());
-        self.state_vec.push(persister.clone());
-        persister
-    }
-
-    pub fn at(&self, index: usize) -> Arc<MemoryPersister> {
-        self.state_vec[index].clone()
-    }
-
-    pub fn replace(&mut self, index: usize) -> Arc<MemoryPersister> {
-        let persister = Arc::new(MemoryPersister::new());
-        self.state_vec[index] = persister.clone();
-        persister
-    }
-
-    pub fn all(&self) -> &Vec<Arc<MemoryPersister>> {
-        &self.state_vec
-    }
-}

+ 0 - 35
kvraft/src/testing_utils/rpcs.rs

@@ -1,35 +0,0 @@
-use labrpc::{Network, Server};
-use parking_lot::Mutex;
-
-use ruaft::rpcs::make_rpc_handler;
-
-use crate::common::{GET, PUT_APPEND};
-use crate::server::KVServer;
-
-pub fn register_kv_server<
-    KV: 'static + AsRef<KVServer> + Clone,
-    S: AsRef<str>,
->(
-    kv: KV,
-    name: S,
-    network: &Mutex<Network>,
-) -> std::io::Result<()> {
-    let mut network = network.lock();
-    let server_name = name.as_ref();
-    let mut server = Server::make_server(server_name);
-
-    let kv_clone = kv.clone();
-    server.register_rpc_handler(
-        GET.to_owned(),
-        make_rpc_handler(move |args| kv_clone.as_ref().get(args)),
-    )?;
-
-    server.register_rpc_handler(
-        PUT_APPEND.to_owned(),
-        make_rpc_handler(move |args| kv.as_ref().put_append(args)),
-    )?;
-
-    network.add_server(server_name, server);
-
-    Ok(())
-}

+ 4 - 5
kvraft/tests/service_test.rs

@@ -1,13 +1,12 @@
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 
-use kvraft::testing_utils::config::{
-    make_config, sleep_election_timeouts, sleep_millis,
-};
-use kvraft::testing_utils::generic_test::{
+use scopeguard::defer;
+use test_configs::kvraft::config::make_config;
+use test_configs::kvraft::generic_test::{
     generic_test, spawn_clients, GenericTestParams,
 };
-use scopeguard::defer;
+use test_configs::utils::{sleep_election_timeouts, sleep_millis};
 use test_utils::init_test_log;
 use test_utils::thread_local_logger::LocalLogger;
 

+ 0 - 1
src/lib.rs

@@ -32,7 +32,6 @@ mod process_install_snapshot;
 mod process_request_vote;
 mod raft_state;
 mod remote_raft;
-pub mod rpcs;
 mod snapshot;
 mod sync_log_entries;
 mod term_marker;

+ 1 - 1
src/process_append_entries.rs

@@ -12,7 +12,7 @@ impl<Command> Raft<Command>
 where
     Command: Clone + serde::Serialize + Default,
 {
-    pub(crate) fn process_append_entries(
+    pub fn process_append_entries(
         &self,
         args: AppendEntriesArgs<Command>,
     ) -> AppendEntriesReply {

+ 1 - 1
src/process_install_snapshot.rs

@@ -3,7 +3,7 @@ use crate::daemon_env::ErrorKind;
 use crate::{InstallSnapshotArgs, InstallSnapshotReply, Raft, State};
 
 impl<C: Clone + Default + serde::Serialize> Raft<C> {
-    pub(crate) fn process_install_snapshot(
+    pub fn process_install_snapshot(
         &self,
         args: InstallSnapshotArgs,
     ) -> InstallSnapshotReply {

+ 1 - 1
src/process_request_vote.rs

@@ -8,7 +8,7 @@ impl<Command> Raft<Command>
 where
     Command: Clone + serde::Serialize + Default,
 {
-    pub(crate) fn process_request_vote(
+    pub fn process_request_vote(
         &self,
         args: RequestVoteArgs,
     ) -> RequestVoteReply {

+ 49 - 0
src/utils.rs

@@ -30,3 +30,52 @@ where
 }
 
 pub const RPC_DEADLINE: Duration = Duration::from_secs(2);
+
+#[cfg(feature = "integration-test")]
+pub mod integration_test {
+    use crate::{
+        AppendEntriesArgs, AppendEntriesReply, Peer, RequestVoteArgs,
+        RequestVoteReply, Term,
+    };
+
+    pub fn make_request_vote_args(
+        term: Term,
+        peer_id: usize,
+        last_log_index: usize,
+        last_log_term: Term,
+    ) -> RequestVoteArgs {
+        RequestVoteArgs {
+            term,
+            candidate_id: Peer(peer_id),
+            last_log_index,
+            last_log_term,
+        }
+    }
+
+    pub fn make_append_entries_args<Command>(
+        term: Term,
+        leader_id: usize,
+        prev_log_index: usize,
+        prev_log_term: Term,
+        leader_commit: usize,
+    ) -> AppendEntriesArgs<Command> {
+        AppendEntriesArgs {
+            term,
+            leader_id: Peer(leader_id),
+            prev_log_index,
+            prev_log_term,
+            entries: vec![],
+            leader_commit,
+        }
+    }
+
+    pub fn unpack_request_vote_reply(reply: RequestVoteReply) -> (Term, bool) {
+        (reply.term, reply.vote_granted)
+    }
+
+    pub fn unpack_append_entries_reply(
+        reply: AppendEntriesReply,
+    ) -> (Term, bool) {
+        (reply.term, reply.success)
+    }
+}

+ 23 - 0
test_configs/Cargo.toml

@@ -0,0 +1,23 @@
+[package]
+name = "test_configs"
+version = "0.1.0"
+edition = "2018"
+
+[dependencies]
+anyhow = "1.0"
+async-trait = "0.1"
+bincode = "1.3.3"
+bytes = "1.0"
+kvraft = { path = "../kvraft" }
+labrpc = "0.1.12"
+linearizability = { path = "../linearizability" }
+log = "0.4"
+parking_lot = "0.11.1"
+rand = "0.8"
+ruaft = { path = "..", features = ["integration-test"] }
+serde = "1.0.116"
+test_utils = { path = "../test_utils" }
+
+[dev-dependencies]
+futures = { version = "0.3.15", features = ["thread-pool"] }
+stdext = "0.3"

+ 19 - 31
kvraft/src/testing_utils/config.rs → test_configs/src/kvraft/config.rs

@@ -5,13 +5,10 @@ use parking_lot::Mutex;
 use rand::seq::SliceRandom;
 use rand::thread_rng;
 
-use ruaft::rpcs::register_server;
-use ruaft::Persister;
+use crate::{register_kv_server, register_server, Persister, RpcClient};
 
-use crate::client::Clerk;
-use crate::server::KVServer;
-use crate::testing_utils::memory_persister::{MemoryPersister, MemoryStorage};
-use crate::testing_utils::rpcs::register_kv_server;
+use kvraft::Clerk;
+use kvraft::KVServer;
 
 struct ConfigState {
     kv_servers: Vec<Option<Arc<KVServer>>>,
@@ -22,7 +19,7 @@ pub struct Config {
     network: Arc<Mutex<labrpc::Network>>,
     server_count: usize,
     state: Mutex<ConfigState>,
-    storage: Mutex<MemoryStorage>,
+    storage: Mutex<Vec<Arc<Persister>>>,
     maxraftstate: usize,
 }
 
@@ -48,14 +45,14 @@ impl Config {
         {
             let mut network = self.network.lock();
             for j in 0..self.server_count {
-                clients.push(ruaft::rpcs::RpcClient::new(network.make_client(
+                clients.push(crate::RpcClient::new(network.make_client(
                     Self::client_name(index, j),
                     Self::server_name(j),
                 )))
             }
         }
 
-        let persister = self.storage.lock().at(index);
+        let persister = self.storage.lock()[index].clone();
 
         let kv =
             KVServer::new(clients, index, persister, Some(self.maxraftstate));
@@ -153,9 +150,10 @@ impl Config {
             network.remove_server(Self::kv_server_name(index));
         }
 
-        let data = self.storage.lock().at(index).read();
+        let data = self.storage.lock()[index].read();
 
-        let persister = self.storage.lock().replace(index);
+        let persister = Arc::new(Persister::new());
+        self.storage.lock()[index] = persister.clone();
         persister.restore(data);
 
         if let Some(kv_server) = self.state.lock().kv_servers[index].take() {
@@ -198,10 +196,10 @@ impl Config {
         {
             let mut network = self.network.lock();
             for j in 0..self.server_count {
-                clients.push(network.make_client(
+                clients.push(RpcClient::new(network.make_client(
                     Self::kv_clerk_name(clerk_index, j),
                     Self::kv_server_name(j),
-                ));
+                )));
             }
             // Disable clerk connection to all kv servers.
             Self::set_clerk_connect(
@@ -253,10 +251,10 @@ impl Config {
     fn check_size(
         &self,
         upper: usize,
-        size_fn: impl Fn(&MemoryPersister) -> usize,
+        size_fn: impl Fn(&Persister) -> usize,
     ) -> Result<(), String> {
         let mut over_limits = String::new();
-        for (index, p) in self.storage.lock().all().iter().enumerate() {
+        for (index, p) in self.storage.lock().iter().enumerate() {
             let size = size_fn(p);
             if size > upper {
                 let str = format!(" (index {}, size {})", index, size);
@@ -273,11 +271,11 @@ impl Config {
     }
 
     pub fn check_log_size(&self, upper: usize) -> Result<(), String> {
-        self.check_size(upper, MemoryPersister::state_size)
+        self.check_size(upper, ruaft::Persister::state_size)
     }
 
     pub fn check_snapshot_size(&self, upper: usize) -> Result<(), String> {
-        self.check_size(upper, MemoryPersister::snapshot_size)
+        self.check_size(upper, Persister::snapshot_size)
     }
 }
 
@@ -298,11 +296,10 @@ pub fn make_config(
         next_clerk: 0,
     });
 
-    let mut storage = MemoryStorage::default();
-    for _ in 0..server_count {
-        storage.make();
-    }
-    let storage = Mutex::new(storage);
+    let storage = Mutex::new(vec![]);
+    storage
+        .lock()
+        .resize_with(server_count, || Arc::new(Persister::new()));
 
     let cfg = Config {
         network,
@@ -319,12 +316,3 @@ pub fn make_config(
 
     cfg
 }
-
-pub fn sleep_millis(mills: u64) {
-    std::thread::sleep(std::time::Duration::from_millis(mills))
-}
-
-pub const LONG_ELECTION_TIMEOUT_MILLIS: u64 = 1000;
-pub fn sleep_election_timeouts(count: u64) {
-    sleep_millis(LONG_ELECTION_TIMEOUT_MILLIS * count)
-}

+ 5 - 5
kvraft/src/testing_utils/generic_test.rs → test_configs/src/kvraft/generic_test.rs

@@ -9,11 +9,11 @@ use test_utils::thread_local_logger::LocalLogger;
 
 use linearizability::{KvInput, KvModel, KvOp, KvOutput, Operation};
 
-use crate::testing_utils::config::{
-    make_config, sleep_election_timeouts, sleep_millis, Config,
-    LONG_ELECTION_TIMEOUT_MILLIS,
+use super::config::{make_config, Config};
+use crate::utils::{
+    sleep_election_timeouts, sleep_millis, LONG_ELECTION_TIMEOUT_MILLIS,
 };
-use crate::Clerk;
+use kvraft::Clerk;
 
 pub fn spawn_clients<T, Func>(
     config: Arc<Config>,
@@ -129,7 +129,7 @@ fn run_partition(cfg: Arc<Config>, stop: Arc<AtomicBool>) {
             LONG_ELECTION_TIMEOUT_MILLIS
                 ..LONG_ELECTION_TIMEOUT_MILLIS + PARTITION_MAX_DELAY_MILLIS,
         );
-        std::thread::sleep(Duration::from_millis(delay));
+        sleep_millis(delay);
     }
 }
 

+ 0 - 2
kvraft/src/testing_utils/mod.rs → test_configs/src/kvraft/mod.rs

@@ -1,4 +1,2 @@
 pub mod config;
 pub mod generic_test;
-mod memory_persister;
-mod rpcs;

+ 8 - 0
test_configs/src/lib.rs

@@ -0,0 +1,8 @@
+pub mod kvraft;
+mod persister;
+pub mod raft;
+mod rpcs;
+pub mod utils;
+
+pub use persister::Persister;
+use rpcs::{register_kv_server, register_server, RpcClient};

+ 18 - 3
tests/config/persister/mod.rs → test_configs/src/persister.rs

@@ -1,8 +1,9 @@
 use parking_lot::Mutex;
 
-struct State {
-    bytes: bytes::Bytes,
-    snapshot: Vec<u8>,
+#[derive(Clone)]
+pub struct State {
+    pub bytes: bytes::Bytes,
+    pub snapshot: Vec<u8>,
 }
 
 pub struct Persister {
@@ -45,3 +46,17 @@ impl ruaft::Persister for Persister {
         this.snapshot = snapshot.to_vec();
     }
 }
+
+impl Persister {
+    pub fn read(&self) -> State {
+        self.state.lock().clone()
+    }
+
+    pub fn restore(&self, state: State) {
+        *self.state.lock() = state;
+    }
+
+    pub fn snapshot_size(&self) -> usize {
+        self.state.lock().snapshot.len()
+    }
+}

+ 10 - 21
tests/config/mod.rs → test_configs/src/raft/config.rs

@@ -2,19 +2,17 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use std::rc::Rc;
 use std::sync::Arc;
-use std::time::Instant;
+use std::time::{Duration, Instant};
 
 pub use anyhow::Result;
 use anyhow::{anyhow, bail};
 use parking_lot::Mutex;
 use rand::{thread_rng, Rng};
-use tokio::time::Duration;
 
-use ruaft::rpcs::register_server;
+use crate::register_server;
+use crate::utils::sleep_millis;
 use ruaft::{ApplyCommandMessage, Persister, Raft, Term};
 
-pub mod persister;
-
 struct ConfigState {
     rafts: Vec<Option<Raft<i32>>>,
     connected: Vec<bool>,
@@ -24,7 +22,7 @@ struct LogState {
     committed_logs: Vec<Vec<i32>>,
     results: Vec<Result<()>>,
     max_index: usize,
-    saved: Vec<Arc<persister::Persister>>,
+    saved: Vec<Arc<crate::Persister>>,
 }
 
 pub struct Config {
@@ -291,7 +289,7 @@ impl Config {
             raft.kill();
         }
         let mut log = self.log.lock();
-        log.saved[index] = Arc::new(persister::Persister::new());
+        log.saved[index] = Arc::new(crate::Persister::new());
         log.saved[index].save_state(data);
     }
 
@@ -304,7 +302,7 @@ impl Config {
         {
             let mut network = self.network.lock();
             for j in 0..self.server_count {
-                clients.push(ruaft::rpcs::RpcClient::new(network.make_client(
+                clients.push(crate::RpcClient::new(network.make_client(
                     Self::client_name(index, j),
                     Self::server_name(j),
                 )))
@@ -365,7 +363,7 @@ impl Config {
     pub fn end(&self) {}
 
     pub fn cleanup(&self) {
-        log::trace!("Cleaning up test config ...");
+        log::trace!("Cleaning up test raft.config ...");
         let mut network = self.network.lock();
         for i in 0..self.server_count {
             network.remove_server(Self::server_name(i));
@@ -377,7 +375,7 @@ impl Config {
                 raft.kill();
             }
         }
-        log::trace!("Cleaning up test config done.");
+        log::trace!("Cleaning up test raft.config done.");
         eprintln!(
             "Ruaft log file for {}: {:?}",
             self.test_path,
@@ -446,7 +444,7 @@ impl Config {
 #[macro_export]
 macro_rules! make_config {
     ($server_count:expr, $unreliable:expr) => {
-        $crate::config::make_config(
+        $crate::raft::config::make_config(
             $server_count,
             $unreliable,
             stdext::function_name!(),
@@ -476,7 +474,7 @@ pub fn make_config(
     });
 
     let mut saved = vec![];
-    saved.resize_with(server_count, || Arc::new(persister::Persister::new()));
+    saved.resize_with(server_count, || Arc::new(crate::Persister::new()));
     let log = Arc::new(Mutex::new(LogState {
         committed_logs: vec![vec![]; server_count],
         results: vec![],
@@ -500,12 +498,3 @@ pub fn make_config(
 
     cfg
 }
-
-pub fn sleep_millis(mills: u64) {
-    std::thread::sleep(std::time::Duration::from_millis(mills))
-}
-
-pub const LONG_ELECTION_TIMEOUT_MILLIS: u64 = 1000;
-pub fn sleep_election_timeouts(count: u64) {
-    sleep_millis(LONG_ELECTION_TIMEOUT_MILLIS * count)
-}

+ 1 - 0
test_configs/src/raft/mod.rs

@@ -0,0 +1 @@
+pub mod config;

+ 71 - 28
src/rpcs.rs → test_configs/src/rpcs.rs

@@ -1,13 +1,16 @@
 use async_trait::async_trait;
 use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
 use parking_lot::Mutex;
+use serde::de::DeserializeOwned;
+use serde::Serialize;
 
-use crate::{
+use kvraft::{
+    GetArgs, GetReply, KVServer, PutAppendArgs, PutAppendReply, RemoteKvraft,
+};
+use ruaft::{
     AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
     InstallSnapshotReply, Raft, RequestVoteArgs, RequestVoteReply,
 };
-use serde::de::DeserializeOwned;
-use serde::Serialize;
 
 const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
 const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
@@ -43,8 +46,8 @@ impl RpcClient {
 }
 
 #[async_trait]
-impl<Command: 'static + Send + Serialize>
-    crate::remote_raft::RemoteRaft<Command> for RpcClient
+impl<Command: 'static + Send + Serialize> ruaft::RemoteRaft<Command>
+    for RpcClient
 {
     async fn request_vote(
         &self,
@@ -68,6 +71,23 @@ impl<Command: 'static + Send + Serialize>
     }
 }
 
+const GET: &str = "KVServer.Get";
+const PUT_APPEND: &str = "KVServer.PutAppend";
+
+#[async_trait]
+impl RemoteKvraft for RpcClient {
+    async fn get(&self, args: GetArgs) -> std::io::Result<GetReply> {
+        self.call_rpc(GET, args).await
+    }
+
+    async fn put_append(
+        &self,
+        args: PutAppendArgs,
+    ) -> std::io::Result<PutAppendReply> {
+        self.call_rpc(PUT_APPEND, args).await
+    }
+}
+
 pub fn make_rpc_handler<Request, Reply, F>(
     func: F,
 ) -> Box<dyn Fn(RequestMessage) -> ReplyMessage>
@@ -129,6 +149,33 @@ pub fn register_server<
 
     Ok(())
 }
+pub fn register_kv_server<
+    KV: 'static + AsRef<KVServer> + Clone,
+    S: AsRef<str>,
+>(
+    kv: KV,
+    name: S,
+    network: &Mutex<Network>,
+) -> std::io::Result<()> {
+    let mut network = network.lock();
+    let server_name = name.as_ref();
+    let mut server = Server::make_server(server_name);
+
+    let kv_clone = kv.clone();
+    server.register_rpc_handler(
+        GET.to_owned(),
+        make_rpc_handler(move |args| kv_clone.as_ref().get(args)),
+    )?;
+
+    server.register_rpc_handler(
+        PUT_APPEND.to_owned(),
+        make_rpc_handler(move |args| kv.as_ref().put_append(args)),
+    )?;
+
+    network.add_server(server_name, server);
+
+    Ok(())
+}
 
 #[cfg(test)]
 mod tests {
@@ -136,12 +183,16 @@ mod tests {
 
     use bytes::Bytes;
 
-    use crate::{ApplyCommandMessage, Peer, RemoteRaft, Term};
+    use ruaft::{ApplyCommandMessage, RemoteRaft, Term};
 
     use super::*;
+    use ruaft::utils::integration_test::{
+        make_append_entries_args, make_request_vote_args,
+        unpack_append_entries_reply, unpack_request_vote_reply,
+    };
 
-    type DoNothingPersister = ();
-    impl crate::Persister for DoNothingPersister {
+    struct DoNothingPersister;
+    impl ruaft::Persister for DoNothingPersister {
         fn read_state(&self) -> Bytes {
             Bytes::new()
         }
@@ -157,6 +208,8 @@ mod tests {
 
     #[test]
     fn test_basic_message() -> std::io::Result<()> {
+        test_utils::init_test_log!();
+
         let client = {
             let network = Network::run_daemon();
             let name = "test-basic-message";
@@ -168,7 +221,7 @@ mod tests {
             let raft = Arc::new(Raft::new(
                 vec![RpcClient(client)],
                 0,
-                Arc::new(()),
+                Arc::new(DoNothingPersister),
                 |_: ApplyCommandMessage<i32>| {},
                 None,
                 Raft::<i32>::NO_SNAPSHOT,
@@ -182,30 +235,20 @@ mod tests {
         };
 
         let rpc_client = RpcClient(client);
-        let request = RequestVoteArgs {
-            term: Term(2021),
-
-            candidate_id: Peer(0),
-            last_log_index: 0,
-            last_log_term: Term(0),
-        };
+        let request = make_request_vote_args(Term(2021), 0, 0, Term(0));
         let response = futures::executor::block_on(
             (&rpc_client as &dyn RemoteRaft<i32>).request_vote(request),
         )?;
-        assert!(response.vote_granted);
-
-        let request = AppendEntriesArgs::<i32> {
-            term: Term(2021),
-            leader_id: Peer(0),
-            prev_log_index: 0,
-            prev_log_term: Term(0),
-            entries: vec![],
-            leader_commit: 0,
-        };
+        let (_, vote_granted) = unpack_request_vote_reply(response);
+        assert!(vote_granted);
+
+        let request =
+            make_append_entries_args::<i32>(Term(2021), 0, 0, Term(0), 0);
         let response =
             futures::executor::block_on(rpc_client.append_entries(request))?;
-        assert_eq!(2021, response.term.0);
-        assert!(response.success);
+        let (Term(term), success) = unpack_append_entries_reply(response);
+        assert_eq!(2021, term);
+        assert!(success);
 
         Ok(())
     }

+ 10 - 0
test_configs/src/utils.rs

@@ -0,0 +1,10 @@
+use std::time::Duration;
+
+pub fn sleep_millis(mills: u64) {
+    std::thread::sleep(Duration::from_millis(mills))
+}
+
+pub const LONG_ELECTION_TIMEOUT_MILLIS: u64 = 1000;
+pub fn sleep_election_timeouts(count: u64) {
+    sleep_millis(LONG_ELECTION_TIMEOUT_MILLIS * count)
+}

+ 9 - 10
tests/agreement_tests.rs

@@ -1,9 +1,8 @@
 #![allow(clippy::identity_op)]
 use rand::{thread_rng, Rng};
 use scopeguard::defer;
-
-// This is to remove the annoying "unused code in config" warnings.
-pub mod config;
+use test_configs::utils::{sleep_election_timeouts, sleep_millis};
+use test_configs::{make_config, raft::config};
 
 #[test]
 fn basic_agreement() -> config::Result<()> {
@@ -45,7 +44,7 @@ fn fail_agree() -> config::Result<()> {
     // agree despite one disconnected server?
     cfg.one(102, SERVERS - 1, false)?;
     cfg.one(103, SERVERS - 1, false)?;
-    config::sleep_election_timeouts(1);
+    sleep_election_timeouts(1);
     cfg.one(104, SERVERS - 1, false)?;
     cfg.one(105, SERVERS - 1, false)?;
 
@@ -54,7 +53,7 @@ fn fail_agree() -> config::Result<()> {
 
     // agree with full set of servers?
     cfg.one(106, SERVERS, true)?;
-    config::sleep_election_timeouts(1);
+    sleep_election_timeouts(1);
     cfg.one(107, SERVERS, true)?;
 
     cfg.end();
@@ -82,7 +81,7 @@ fn fail_no_agree() -> config::Result<()> {
     let index = result.unwrap().1;
     assert_eq!(2, index, "expected index 2, got {}", index);
 
-    config::sleep_election_timeouts(2);
+    sleep_election_timeouts(2);
 
     let (commit_count, _) = cfg.committed_count(index)?;
     assert_eq!(
@@ -173,7 +172,7 @@ fn backup() -> config::Result<()> {
         cfg.leader_start(leader1, thread_rng().gen());
     }
 
-    config::sleep_election_timeouts(2);
+    sleep_election_timeouts(2);
 
     cfg.disconnect((leader1 + 0) % SERVERS);
     cfg.disconnect((leader1 + 1) % SERVERS);
@@ -201,7 +200,7 @@ fn backup() -> config::Result<()> {
         cfg.leader_start(leader2, thread_rng().gen());
     }
 
-    config::sleep_election_timeouts(2);
+    sleep_election_timeouts(2);
 
     // bring original leader back to life,
     for i in 0..SERVERS {
@@ -249,7 +248,7 @@ fn count() -> config::Result<()> {
             break (false, 0);
         }
         if retries != 0 {
-            config::sleep_millis(3000);
+            sleep_millis(3000);
         }
         retries += 1;
 
@@ -320,7 +319,7 @@ fn count() -> config::Result<()> {
 
     assert!(success, "term change too often");
 
-    config::sleep_election_timeouts(1);
+    sleep_election_timeouts(1);
 
     let diff = cfg.total_rpcs() - total;
     assert!(

+ 4 - 5
tests/election_tests.rs

@@ -1,7 +1,6 @@
 use scopeguard::defer;
-
-// This is to remove the annoying "unused code in config" warnings.
-pub mod config;
+use test_configs::utils::{sleep_election_timeouts, sleep_millis};
+use test_configs::{make_config, raft::config};
 
 #[test]
 fn initial_election() -> config::Result<()> {
@@ -13,10 +12,10 @@ fn initial_election() -> config::Result<()> {
 
     cfg.check_one_leader()?;
 
-    config::sleep_millis(50);
+    sleep_millis(50);
 
     let first_term = cfg.check_terms()?;
-    config::sleep_election_timeouts(2);
+    sleep_election_timeouts(2);
 
     let second_term = cfg.check_terms()?;
 

+ 14 - 13
tests/persist_tests.rs

@@ -6,9 +6,10 @@ use std::sync::Arc;
 
 use rand::{thread_rng, Rng};
 use scopeguard::defer;
-
-// This is to remove the annoying "unused code in config" warnings.
-pub mod config;
+use test_configs::utils::{
+    sleep_election_timeouts, sleep_millis, LONG_ELECTION_TIMEOUT_MILLIS,
+};
+use test_configs::{make_config, raft::config};
 
 #[test]
 fn persist1() -> config::Result<()> {
@@ -90,7 +91,7 @@ fn persist2() -> config::Result<()> {
         cfg.connect((leader1 + 1) % SERVERS);
         cfg.connect((leader1 + 2) % SERVERS);
 
-        config::sleep_election_timeouts(1);
+        sleep_election_timeouts(1);
 
         cfg.start1((leader1 + 3) % SERVERS)?;
         cfg.connect((leader1 + 3) % SERVERS);
@@ -165,13 +166,13 @@ fn figure8() -> config::Result<()> {
         }
 
         let millis_upper = if thread_rng().gen_ratio(100, 1000) {
-            config::LONG_ELECTION_TIMEOUT_MILLIS >> 1
+            LONG_ELECTION_TIMEOUT_MILLIS >> 1
         } else {
             // Magic number 13?
             13
         };
         let millis = thread_rng().gen_range(0..millis_upper);
-        config::sleep_millis(millis);
+        sleep_millis(millis);
 
         if let Some(leader) = leader {
             cfg.crash1(leader);
@@ -263,13 +264,13 @@ fn figure8_unreliable() -> config::Result<()> {
         }
 
         let millis_upper = if thread_rng().gen_ratio(100, 1000) {
-            config::LONG_ELECTION_TIMEOUT_MILLIS >> 1
+            LONG_ELECTION_TIMEOUT_MILLIS >> 1
         } else {
             // Magic number 13?
             13
         };
         let millis = thread_rng().gen_range(0..millis_upper);
-        config::sleep_millis(millis);
+        sleep_millis(millis);
 
         if let Some(leader) = leader {
             if thread_rng().gen_ratio(1, 2) {
@@ -346,10 +347,10 @@ fn internal_churn(unreliable: bool) -> config::Result<()> {
                             }
                             // The contract we started might not get
                         }
-                        config::sleep_millis(*millis);
+                        sleep_millis(*millis);
                     }
                 } else {
-                    config::sleep_millis(79 + client_index * 17);
+                    sleep_millis(79 + client_index * 17);
                 }
             }
 
@@ -376,10 +377,10 @@ fn internal_churn(unreliable: bool) -> config::Result<()> {
                 cfg.crash1(server);
             }
         }
-        config::sleep_millis(config::LONG_ELECTION_TIMEOUT_MILLIS / 10 * 7);
+        sleep_millis(LONG_ELECTION_TIMEOUT_MILLIS / 10 * 7);
     }
 
-    config::sleep_election_timeouts(1);
+    sleep_election_timeouts(1);
     cfg.set_unreliable(false);
     for i in 0..SERVERS {
         if !cfg.is_server_alive(i) {
@@ -395,7 +396,7 @@ fn internal_churn(unreliable: bool) -> config::Result<()> {
         all_cmds.append(&mut cmds);
     }
 
-    config::sleep_election_timeouts(1);
+    sleep_election_timeouts(1);
 
     let last_cmd_index = cfg.one(thread_rng().gen(), SERVERS, true)?;
     let mut consented = vec![];

+ 3 - 2
tests/snapshot_tests.rs

@@ -2,8 +2,9 @@ use std::sync::Arc;
 
 use scopeguard::defer;
 
-use kvraft::testing_utils::config::{make_config, sleep_election_timeouts};
-use kvraft::testing_utils::generic_test::{generic_test, GenericTestParams};
+use test_configs::kvraft::config::make_config;
+use test_configs::kvraft::generic_test::{generic_test, GenericTestParams};
+use test_configs::utils::sleep_election_timeouts;
 use test_utils::init_test_log;
 
 #[test]