Преглед изворни кода

Create an RPC interceptor and a sample test.

Jing Yang пре 3 година
родитељ
комит
d173e212bb

+ 1 - 1
kvraft/src/lib.rs

@@ -2,7 +2,7 @@ pub use async_client::{AsyncClerk, AsyncClient};
 pub use client::Clerk;
 pub use common::{
     CommitSentinelArgs, CommitSentinelReply, GetArgs, GetReply, PutAppendArgs,
-    PutAppendReply,
+    PutAppendEnum, PutAppendReply, UniqueId,
 };
 pub use remote_kvraft::RemoteKvraft;
 pub use server::KVServer;

+ 7 - 1
src/utils/integration_test.rs

@@ -1,7 +1,7 @@
 #![cfg(feature = "integration-test")]
 
 use crate::{
-    AppendEntriesArgs, AppendEntriesReply, Peer, RequestVoteArgs,
+    AppendEntriesArgs, AppendEntriesReply, IndexTerm, Peer, RequestVoteArgs,
     RequestVoteReply, Term,
 };
 
@@ -40,6 +40,12 @@ pub fn unpack_request_vote_reply(reply: RequestVoteReply) -> (Term, bool) {
     (reply.term, reply.vote_granted)
 }
 
+pub fn unpack_append_entries_args<T>(
+    request: AppendEntriesArgs<T>,
+) -> Option<IndexTerm> {
+    request.entries.last().map(|e| e.into())
+}
+
 pub fn unpack_append_entries_reply(reply: AppendEntriesReply) -> (Term, bool) {
     (reply.term, reply.success)
 }

+ 3 - 0
test_configs/Cargo.toml

@@ -8,11 +8,14 @@ anyhow = "1.0"
 async-trait = "0.1"
 bincode = "1.3.3"
 bytes = "1.1"
+crossbeam-channel = "0.5.5"
+futures-channel = "0.3.21"
 futures-util = "0.3.21"
 kvraft = { path = "../kvraft", features = ["integration-test"] }
 labrpc = "0.2.2"
 linearizability = { path = "../linearizability" }
 log = "0.4"
+once_cell = "1.12.0"
 parking_lot = "0.12"
 rand = "0.8"
 ruaft = { path = "..", features = ["integration-test"] }

+ 273 - 0
test_configs/src/interceptor/mod.rs

@@ -0,0 +1,273 @@
+use std::future::Future;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+
+use async_trait::async_trait;
+use crossbeam_channel::{Receiver, Sender};
+use once_cell::sync::OnceCell;
+
+use kvraft::{
+    GetArgs, KVServer, PutAppendArgs, PutAppendEnum, UniqueId, UniqueKVOp,
+};
+use ruaft::{
+    AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
+    InstallSnapshotReply, Raft, RemoteRaft, ReplicableCommand, RequestVoteArgs,
+    RequestVoteReply,
+};
+
+use crate::Persister;
+
+type RaftId = usize;
+
+pub struct EventHandle {
+    pub from: RaftId,
+    pub to: RaftId,
+    sender: futures_channel::oneshot::Sender<std::io::Result<()>>,
+}
+
+struct EventStub {
+    receiver: futures_channel::oneshot::Receiver<std::io::Result<()>>,
+}
+
+fn create_event_pair(from: RaftId, to: RaftId) -> (EventHandle, EventStub) {
+    let (sender, receiver) = futures_channel::oneshot::channel();
+    (EventHandle { from, to, sender }, EventStub { receiver })
+}
+
+impl EventHandle {
+    pub fn unblock(self) {
+        self.sender.send(Ok(())).unwrap();
+    }
+
+    pub fn reply_error(self, e: std::io::Error) {
+        self.sender.send(Err(e)).unwrap();
+    }
+
+    pub fn reply_interrupted_error(self) {
+        self.reply_error(std::io::Error::from(std::io::ErrorKind::Interrupted))
+    }
+}
+
+impl EventStub {
+    pub async fn wait(self) -> std::io::Result<()> {
+        self.receiver.await.unwrap_or(Ok(()))
+    }
+}
+
+pub enum RaftRpcEvent<T> {
+    RequestVoteRequest(RequestVoteArgs),
+    RequestVoteResponse(RequestVoteArgs, RequestVoteReply),
+    AppendEntriesRequest(AppendEntriesArgs<T>),
+    AppendEntriesResponse(AppendEntriesArgs<T>, AppendEntriesReply),
+    InstallSnapshotRequest(InstallSnapshotArgs),
+    InstallSnapshotResponse(InstallSnapshotArgs, InstallSnapshotReply),
+}
+
+struct InterceptingRpcClient<T> {
+    from: RaftId,
+    to: RaftId,
+    target: OnceCell<Raft<T>>,
+    event_queue: Sender<(RaftRpcEvent<T>, EventHandle)>,
+}
+
+impl<T> InterceptingRpcClient<T> {
+    async fn intercept(&self, event: RaftRpcEvent<T>) -> std::io::Result<()> {
+        let (handle, stub) = create_event_pair(self.from, self.to);
+        let _ = self.event_queue.send((event, handle));
+        stub.wait().await
+    }
+
+    pub fn set_raft(&self, target: Raft<T>) {
+        self.target
+            .set(target)
+            .map_err(|_| ())
+            .expect("Raft should only be set once");
+    }
+}
+
+#[async_trait]
+impl<T: ReplicableCommand> RemoteRaft<T> for &InterceptingRpcClient<T> {
+    async fn request_vote(
+        &self,
+        args: RequestVoteArgs,
+    ) -> std::io::Result<RequestVoteReply> {
+        let event_result = self
+            .intercept(RaftRpcEvent::RequestVoteRequest(args.clone()))
+            .await;
+        if let Err(e) = event_result {
+            return Err(e);
+        };
+
+        let reply = self.target.wait().process_request_vote(args.clone());
+
+        self.intercept(RaftRpcEvent::RequestVoteResponse(args, reply.clone()))
+            .await
+            .map(|_| reply)
+    }
+
+    async fn append_entries(
+        &self,
+        args: AppendEntriesArgs<T>,
+    ) -> std::io::Result<AppendEntriesReply> {
+        let args_clone = args.clone();
+        let event_result = self
+            .intercept(RaftRpcEvent::AppendEntriesRequest(args_clone))
+            .await;
+        if let Err(e) = event_result {
+            return Err(e);
+        };
+
+        let reply = self.target.wait().process_append_entries(args.clone());
+
+        self.intercept(RaftRpcEvent::AppendEntriesResponse(args, reply.clone()))
+            .await
+            .map(|_| reply)
+    }
+
+    async fn install_snapshot(
+        &self,
+        args: InstallSnapshotArgs,
+    ) -> std::io::Result<InstallSnapshotReply> {
+        let event_result = self
+            .intercept(RaftRpcEvent::InstallSnapshotRequest(args.clone()))
+            .await;
+        if let Err(e) = event_result {
+            return Err(e);
+        };
+
+        let reply = self.target.wait().process_install_snapshot(args.clone());
+
+        self.intercept(RaftRpcEvent::InstallSnapshotResponse(
+            args,
+            reply.clone(),
+        ))
+        .await
+        .map(|_| reply)
+    }
+}
+
+pub struct EventQueue<T> {
+    pub receiver: Receiver<(RaftRpcEvent<T>, EventHandle)>,
+}
+
+fn make_grid_clients<T>(
+    server_count: usize,
+) -> (EventQueue<T>, Vec<Vec<InterceptingRpcClient<T>>>) {
+    let (sender, receiver) = crossbeam_channel::unbounded();
+    let mut all_clients = vec![];
+    for from in 0..server_count {
+        let mut clients = vec![];
+        for to in 0..server_count {
+            let interceptor = InterceptingRpcClient {
+                from,
+                to,
+                target: Default::default(),
+                event_queue: sender.clone(),
+            };
+            clients.push(interceptor);
+        }
+        all_clients.push(clients);
+    }
+    (EventQueue { receiver }, all_clients)
+}
+
+pub struct Config {
+    pub event_queue: EventQueue<UniqueKVOp>,
+    pub kv_servers: Vec<Arc<KVServer>>,
+    seq: AtomicUsize,
+}
+
+impl Config {
+    pub fn find_leader(&self) -> Option<&KVServer> {
+        let start = Instant::now();
+        while start.elapsed() < Duration::from_secs(1) {
+            if let Some(kv_server) = self
+                .kv_servers
+                .iter()
+                .find(|kv_server| kv_server.raft().get_state().1)
+            {
+                return Some(kv_server.as_ref());
+            }
+        }
+        None
+    }
+
+    pub async fn put(&self, key: String, value: String) -> Result<(), ()> {
+        let kv_server = self.find_leader().unwrap();
+        let result = kv_server
+            .put_append(PutAppendArgs {
+                key,
+                value,
+                op: PutAppendEnum::Put,
+                unique_id: UniqueId {
+                    clerk_id: 1,
+                    sequence_id: self.seq.fetch_add(1, Ordering::Relaxed)
+                        as u64,
+                },
+            })
+            .await;
+        result.result.map_err(|_| ())
+    }
+
+    pub fn spawn_put(
+        self: &Arc<Self>,
+        key: String,
+        value: String,
+    ) -> impl Future<Output = Result<(), ()>> {
+        let this = self.clone();
+        async move { this.put(key, value).await }
+    }
+
+    pub async fn get(&self, key: String) -> Result<String, ()> {
+        let kv_server = self.find_leader().unwrap();
+        let result = kv_server.get(GetArgs { key }).await;
+        result.result.map(|v| v.unwrap_or_default()).map_err(|_| ())
+    }
+
+    pub fn spawn_get(
+        self: &Arc<Self>,
+        key: String,
+    ) -> impl Future<Output = Result<String, ()>> {
+        let this = self.clone();
+        async move { this.get(key).await }
+    }
+}
+
+pub fn make_config(server_count: usize, max_state: Option<usize>) -> Config {
+    let (event_queue, clients) = make_grid_clients(server_count);
+    let persister = Arc::new(Persister::new());
+    let mut kv_servers = vec![];
+    let clients: Vec<Vec<&'static InterceptingRpcClient<UniqueKVOp>>> = clients
+        .into_iter()
+        .map(|v| {
+            v.into_iter()
+                .map(|c| {
+                    let c = Box::leak(Box::new(c));
+                    &*c
+                })
+                .collect()
+        })
+        .collect();
+    for (index, client_vec) in clients.iter().enumerate() {
+        let kv_server = KVServer::new(
+            client_vec.to_vec(),
+            index,
+            persister.clone(),
+            max_state,
+        );
+        kv_servers.push(kv_server);
+    }
+
+    for clients in clients.iter() {
+        for j in 0..server_count {
+            clients[j].set_raft(kv_servers[j].raft().clone());
+        }
+    }
+
+    Config {
+        event_queue,
+        kv_servers,
+        seq: AtomicUsize::new(0),
+    }
+}

+ 1 - 0
test_configs/src/lib.rs

@@ -1,3 +1,4 @@
+pub mod interceptor;
 pub mod kvraft;
 mod persister;
 pub mod raft;

+ 55 - 0
tests/regression_tests.rs

@@ -0,0 +1,55 @@
+use ruaft::utils::integration_test::{
+    unpack_append_entries_args, unpack_append_entries_reply,
+};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use test_configs::interceptor::{make_config, RaftRpcEvent};
+use test_utils::init_test_log;
+
+#[test]
+fn smoke_test() {
+    init_test_log!();
+    let server_count = 3;
+    let config = make_config(server_count, None);
+    let config = Arc::new(config);
+    let thread_pool = tokio::runtime::Runtime::new().unwrap();
+    let put = thread_pool.spawn(
+        config.spawn_put("commit".to_string(), "consistency".to_string()),
+    );
+
+    let mut responded = false;
+    while let Ok((event, handle)) = config.event_queue.receiver.recv() {
+        if let RaftRpcEvent::AppendEntriesResponse(args, reply) = event {
+            if let Some(index_term) = unpack_append_entries_args(args) {
+                let (term, success) = unpack_append_entries_reply(reply);
+                if term == index_term.term && success && index_term.index >= 1 {
+                    responded = true;
+                    break;
+                }
+            }
+        }
+        handle.unblock();
+    }
+    assert!(responded, "At least one peer must have responded OK");
+    let result = thread_pool.block_on(put).unwrap();
+    assert!(result.is_ok());
+
+    let get = thread_pool.spawn(config.spawn_get("commit".to_string()));
+    let start = Instant::now();
+    while let Ok((_event, handle)) = config
+        .event_queue
+        .receiver
+        .recv_timeout(Duration::from_secs(1))
+    {
+        if get.is_finished() {
+            break;
+        }
+        if start.elapsed() >= Duration::from_secs(1) {
+            break;
+        }
+        handle.unblock();
+    }
+    assert!(get.is_finished());
+    let value = thread_pool.block_on(get).unwrap().unwrap();
+    assert_eq!("consistency", value);
+}