
Update kvraft to be async.

Notable problems solved:

1. Logging. KVServer now also runs in an environment we do not control, so it
no longer relies on attaching a thread-local logger to the handling thread.
2. Blocking threads. Migrated away from blocking mutexes and condition
variables to async-friendly channels and atomic counters, as sketched below.
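
A minimal, self-contained sketch of the pattern adopted here (illustrative names only, not the project's code): a futures oneshot channel whose receiver is wrapped in Shared so every retry of the same request can await the same commit result, a tokio timeout in place of Condvar::wait_for, and an atomic counter so only the first waiter sees the plain Ok while later waiters see a Duplicate.

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;

use futures::FutureExt;

#[derive(Clone, Debug)]
enum Outcome {
    Ok(String),
    Duplicate(String),
    NotLeader,
    TimedOut,
}

struct Holder {
    // How many waiters have already observed a successful result.
    peeks: AtomicUsize,
    // Every waiter clones this shared receiver and awaits it.
    result: futures::future::Shared<
        futures::channel::oneshot::Receiver<Result<String, ()>>,
    >,
}

async fn wait_for_commit(holder: Arc<Holder>, timeout: Duration) -> Outcome {
    match tokio::time::timeout(timeout, holder.result.clone()).await {
        Ok(Ok(Ok(value))) => {
            // The first waiter gets the real result; later ones see Duplicate.
            if holder.peeks.fetch_add(1, Ordering::Relaxed) == 0 {
                Outcome::Ok(value)
            } else {
                Outcome::Duplicate(value)
            }
        }
        // The committing side reported an error for this op.
        Ok(Ok(Err(()))) => Outcome::NotLeader,
        // The sender was dropped, e.g. pending queries were drained.
        Ok(Err(_canceled)) => Outcome::NotLeader,
        // The deadline elapsed before the op was committed.
        Err(_elapsed) => Outcome::TimedOut,
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = futures::channel::oneshot::channel();
    let holder = Arc::new(Holder {
        peeks: AtomicUsize::new(0),
        result: rx.shared(),
    });

    // Pretend the apply loop committed the op and published its result.
    let _ = tx.send(Ok("value".to_owned()));

    let first = wait_for_commit(holder.clone(), Duration::from_secs(1)).await;
    let second = wait_for_commit(holder, Duration::from_secs(1)).await;
    println!("{:?} then {:?}", first, second);
}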
Jing Yang, 4 years ago
commit 0e2bfbd2c4

+ 1 - 0
kvraft/Cargo.toml

@@ -6,6 +6,7 @@ edition = "2018"
 [dependencies]
 async-trait = "0.1"
 bincode = "1.3.3"
+futures = "0.3.15"
 log = "0.4.14"
 parking_lot = "0.11.1"
 rand = "0.8"

+ 57 - 45
kvraft/src/server.rs

@@ -5,10 +5,12 @@ use std::sync::mpsc::{channel, Receiver};
 use std::sync::Arc;
 use std::time::Duration;
 
-use parking_lot::{Condvar, Mutex};
+use futures::FutureExt;
+use parking_lot::Mutex;
 use serde_derive::{Deserialize, Serialize};
 
 use ruaft::{ApplyCommandMessage, Persister, Raft, RemoteRaft, Term};
+use test_utils::log_with;
 use test_utils::thread_local_logger::LocalLogger;
 
 use crate::common::{
@@ -38,7 +40,15 @@ struct KVServerState {
     debug_kv: HashMap<String, String>,
     applied_op: HashMap<ClerkId, (UniqueId, CommitResult)>,
     #[serde(skip)]
-    queries: HashMap<UniqueId, Arc<ResultHolder>>,
+    queries: HashMap<
+        UniqueId,
+        (
+            Arc<ResultHolder>,
+            futures::channel::oneshot::Sender<
+                Result<CommitResult, CommitError>,
+            >,
+        ),
+    >,
 }
 
 #[derive(Clone, Serialize, Deserialize)]
@@ -57,8 +67,10 @@ impl Default for KVOp {
 
 struct ResultHolder {
     term: AtomicUsize,
-    result: Mutex<Result<CommitResult, CommitError>>,
-    condvar: Condvar,
+    peeks: AtomicUsize,
+    result: futures::future::Shared<
+        futures::channel::oneshot::Receiver<Result<CommitResult, CommitError>>,
+    >,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -173,25 +185,23 @@ impl KVServer {
             }
         }
 
-        if let Some(result_holder) = state.queries.remove(&unique_id) {
+        if let Some((_, sender)) = state.queries.remove(&unique_id) {
             // This KV server might not be the same leader that committed the
             // query. We are not sure if it is a duplicate or a conflict. To
             // tell the difference, terms of all queries must be stored.
-            *result_holder.result.lock() = if leader == self.me {
+            let _ = sender.send(if leader == self.me {
                 Ok(result)
             } else {
                 Err(CommitError::NotMe(result))
-            };
-            result_holder.condvar.notify_all();
+            });
         };
     }
 
     fn restore_state(&self, mut new_state: KVServerState) {
         let mut state = self.state.lock();
         // Cleanup all existing queries.
-        for result_holder in state.queries.values() {
-            *result_holder.result.lock() = Err(CommitError::NotLeader);
-            result_holder.condvar.notify_all();
+        for (_, (_, sender)) in state.queries.drain() {
+            let _ = sender.send(Err(CommitError::NotLeader));
         }
 
         std::mem::swap(&mut new_state, &mut *state);
@@ -238,7 +248,7 @@ impl KVServer {
 
     const UNSEEN_TERM: usize = 0;
     const ATTEMPTING_TERM: usize = usize::MAX;
-    fn block_for_commit(
+    async fn block_for_commit(
         &self,
         unique_id: UniqueId,
         op: KVOp,
@@ -259,18 +269,21 @@ impl KVServer {
                 }
             };
             let entry = state.queries.entry(unique_id).or_insert_with(|| {
-                Arc::new(ResultHolder {
-                    term: AtomicUsize::new(Self::UNSEEN_TERM),
-                    result: Mutex::new(Err(CommitError::TimedOut)),
-                    condvar: Condvar::new(),
-                })
+                let (tx, rx) = futures::channel::oneshot::channel();
+                (
+                    Arc::new(ResultHolder {
+                        term: AtomicUsize::new(Self::UNSEEN_TERM),
+                        peeks: AtomicUsize::new(0),
+                        result: rx.shared(),
+                    }),
+                    tx,
+                )
             });
-            entry.clone()
+            entry.0.clone()
         };
 
         let (Term(hold_term), is_leader) = self.rf.get_state();
         if !is_leader {
-            result_holder.condvar.notify_all();
             return Err(CommitError::NotLeader);
         }
         Self::validate_term(hold_term);
@@ -307,7 +320,7 @@ impl KVServer {
                 me: self.me,
                 unique_id,
             };
-            let start = self.rf.start(op);
+            let start = log_with!(self.logger, self.rf.start(op));
             let start_term =
                 start.map_or(Self::UNSEEN_TERM, |(Term(term), _)| {
                     Self::validate_term(term);
@@ -324,23 +337,26 @@ impl KVServer {
             assert_eq!(set, Ok(Self::ATTEMPTING_TERM));
 
             if start_term == Self::UNSEEN_TERM {
-                result_holder.condvar.notify_all();
                 return Err(CommitError::NotLeader);
             }
         }
 
-        let mut guard = result_holder.result.lock();
+        let result = result_holder.result.clone();
         // Wait for the op to be committed.
-        result_holder.condvar.wait_for(&mut guard, timeout);
-
-        // Copy the result out.
-        let result = guard.clone();
-        // If the result is OK, all other requests should see "Duplicate".
-        if let Ok(result) = guard.clone() {
-            *guard = Err(CommitError::Duplicate(result))
+        let result = tokio::time::timeout(timeout, result).await;
+        match result {
+            Ok(Ok(Ok(result))) => {
+                // If the result is OK, all other requests should see "Duplicate".
+                if result_holder.peeks.fetch_add(1, Ordering::Relaxed) == 0 {
+                    Ok(result)
+                } else {
+                    Err(CommitError::Duplicate(result))
+                }
+            }
+            Ok(Ok(Err(e))) => Err(e),
+            Ok(Err(_)) => Err(CommitError::NotLeader),
+            Err(_) => Err(CommitError::TimedOut),
         }
-
-        result
     }
 
     fn validate_term(term: usize) {
@@ -353,17 +369,17 @@ impl KVServer {
 
     const DEFAULT_TIMEOUT: Duration = Duration::from_secs(1);
 
-    pub fn get(&self, args: GetArgs) -> GetReply {
-        self.logger.clone().attach();
+    pub async fn get(&self, args: GetArgs) -> GetReply {
         let map_dup = match args.op {
             GetEnum::AllowDuplicate => |r| Ok(r),
             GetEnum::NoDuplicate => |_| Err(KVError::Conflict),
         };
-        let result = match self.block_for_commit(
+        let result_fut = self.block_for_commit(
             args.unique_id,
             KVOp::Get(args.key),
             Self::DEFAULT_TIMEOUT,
-        ) {
+        );
+        let result = match result_fut.await {
             Ok(result) => Ok(result),
             Err(CommitError::Duplicate(result)) => map_dup(result),
             Err(CommitError::NotMe(result)) => map_dup(result),
@@ -381,17 +397,14 @@ impl KVServer {
         GetReply { result }
     }
 
-    pub fn put_append(&self, args: PutAppendArgs) -> PutAppendReply {
-        self.logger.clone().attach();
+    pub async fn put_append(&self, args: PutAppendArgs) -> PutAppendReply {
         let op = match args.op {
             PutAppendEnum::Put => KVOp::Put(args.key, args.value),
             PutAppendEnum::Append => KVOp::Append(args.key, args.value),
         };
-        let result = match self.block_for_commit(
-            args.unique_id,
-            op,
-            Self::DEFAULT_TIMEOUT,
-        ) {
+        let result_fut =
+            self.block_for_commit(args.unique_id, op, Self::DEFAULT_TIMEOUT);
+        let result = match result_fut.await {
             Ok(result) => result,
             Err(CommitError::Duplicate(result)) => result,
             Err(CommitError::NotMe(result)) => result,
@@ -430,9 +443,8 @@ impl KVServer {
         // Return error to new queries.
         self.keep_running.store(false, Ordering::SeqCst);
         // Cancel all in-flight queries.
-        for result_holder in self.state.lock().queries.values() {
-            *result_holder.result.lock() = Err(CommitError::NotLeader);
-            result_holder.condvar.notify_all();
+        for (_, (_, sender)) in self.state.lock().queries.drain() {
+            let _ = sender.send(Err(CommitError::NotLeader));
         }
 
         let rf = self.raft().clone();

+ 2 - 1
test_configs/Cargo.toml

@@ -8,8 +8,9 @@ anyhow = "1.0"
 async-trait = "0.1"
 bincode = "1.3.3"
 bytes = "1.0"
+futures-util = "0.3.15"
 kvraft = { path = "../kvraft" }
-labrpc = "0.1.12"
+labrpc = "0.2.1"
 linearizability = { path = "../linearizability" }
 log = "0.4"
 parking_lot = "0.11.1"

+ 1 - 1
test_configs/src/kvraft/config.rs

@@ -58,7 +58,7 @@ impl Config {
             KVServer::new(clients, index, persister, Some(self.maxraftstate));
         self.state.lock().kv_servers[index].replace(kv.clone());
 
-        let raft = std::rc::Rc::new(kv.raft().clone());
+        let raft = Arc::new(kv.raft().clone());
 
         register_server(raft, Self::server_name(index), self.network.as_ref())?;
 

+ 1 - 2
test_configs/src/raft/config.rs

@@ -1,6 +1,5 @@
 use std::collections::HashMap;
 use std::path::PathBuf;
-use std::rc::Rc;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 
@@ -323,7 +322,7 @@ impl Config {
         );
         self.state.lock().rafts[index].replace(raft.clone());
 
-        let raft = Rc::new(raft);
+        let raft = Arc::new(raft);
         register_server(raft, Self::server_name(index), self.network.as_ref())?;
         Ok(())
     }

+ 41 - 9
test_configs/src/rpcs.rs

@@ -1,9 +1,13 @@
+use std::future::Future;
+
 use async_trait::async_trait;
 use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
 use parking_lot::Mutex;
 use serde::de::DeserializeOwned;
 use serde::Serialize;
 
+use futures_util::future::BoxFuture;
+use futures_util::FutureExt;
 use kvraft::{
     GetArgs, GetReply, KVServer, PutAppendArgs, PutAppendReply, RemoteKvraft,
 };
@@ -90,13 +94,13 @@ impl RemoteKvraft for RpcClient {
 
 pub fn make_rpc_handler<Request, Reply, F>(
     func: F,
-) -> Box<dyn Fn(RequestMessage) -> ReplyMessage>
+) -> impl Fn(RequestMessage) -> ReplyMessage
 where
     Request: DeserializeOwned,
     Reply: Serialize,
     F: 'static + Fn(Request) -> Reply,
 {
-    Box::new(move |request| {
+    move |request| {
         let reply = func(
             bincode::deserialize(&request)
                 .expect("Deserialization should not fail"),
@@ -105,12 +109,36 @@ where
         ReplyMessage::from(
             bincode::serialize(&reply).expect("Serialization should not fail"),
         )
-    })
+    }
+}
+
+pub fn make_async_rpc_handler<'a, Request, Reply, F, Fut>(
+    func: F,
+) -> impl Fn(RequestMessage) -> BoxFuture<'a, ReplyMessage>
+where
+    Request: DeserializeOwned + Send,
+    Reply: Serialize,
+    Fut: Future<Output = Reply> + Send + 'a,
+    F: 'a + Send + Clone + FnOnce(Request) -> Fut,
+{
+    move |request| {
+        let func = func.clone();
+        let fut = async move {
+            let request = bincode::deserialize(&request)
+                .expect("Deserialization should not fail");
+            let reply = func(request).await;
+            ReplyMessage::from(
+                bincode::serialize(&reply)
+                    .expect("Serialization should not fail"),
+            )
+        };
+        fut.boxed()
+    }
 }
 
 pub fn register_server<
     Command: 'static + Clone + Serialize + DeserializeOwned + Default,
-    R: 'static + AsRef<Raft<Command>> + Clone,
+    R: 'static + AsRef<Raft<Command>> + Send + Sync + Clone,
     S: AsRef<str>,
 >(
     raft: R,
@@ -150,7 +178,7 @@ pub fn register_server<
     Ok(())
 }
 pub fn register_kv_server<
-    KV: 'static + AsRef<KVServer> + Clone,
+    KV: 'static + AsRef<KVServer> + Send + Sync + Clone,
     S: AsRef<str>,
 >(
     kv: KV,
@@ -162,14 +190,18 @@ pub fn register_kv_server<
     let mut server = Server::make_server(server_name);
 
     let kv_clone = kv.clone();
-    server.register_rpc_handler(
+    server.register_async_rpc_handler(
         GET.to_owned(),
-        make_rpc_handler(move |args| kv_clone.as_ref().get(args)),
+        make_async_rpc_handler(move |args| async move {
+            kv_clone.as_ref().get(args).await
+        }),
     )?;
 
-    server.register_rpc_handler(
+    server.register_async_rpc_handler(
         PUT_APPEND.to_owned(),
-        make_rpc_handler(move |args| kv.as_ref().put_append(args)),
+        make_async_rpc_handler(move |args| async move {
+            kv.as_ref().put_append(args).await
+        }),
     )?;
 
     network.add_server(server_name, server);
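
To make the shape of make_async_rpc_handler easier to see in isolation, here is a simplified, hypothetical analogue that does not depend on labrpc: the factory captures an async closure, clones it per request so the FnOnce can be moved into an owned async block, and boxes the resulting future so every handler shares one object-safe signature.

use std::future::Future;

use futures_util::future::BoxFuture;
use futures_util::FutureExt;

fn make_handler<F, Fut>(
    func: F,
) -> impl Fn(Vec<u8>) -> BoxFuture<'static, Vec<u8>>
where
    F: 'static + Send + Clone + FnOnce(Vec<u8>) -> Fut,
    Fut: Future<Output = Vec<u8>> + Send + 'static,
{
    move |request| {
        // Clone the FnOnce so this Fn can be invoked repeatedly.
        let func = func.clone();
        async move { func(request).await }.boxed()
    }
}

#[tokio::main]
async fn main() {
    let handler = make_handler(|bytes: Vec<u8>| async move {
        // Echo handler: a real one would deserialize the request, call the
        // service, and serialize the reply.
        bytes
    });
    let reply = handler(b"ping".to_vec()).await;
    assert_eq!(reply, b"ping".to_vec());
}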