Explorar o código

Create an RPC client wrapper for Ruaft.

Jing Yang %!s(int64=4) %!d(string=hai) anos
pai
achega
ba0270ed2d
Modificáronse 2 ficheiros con 49 adicións e 3 borrados
  1. 1 0
      durio/Cargo.toml
  2. 48 3
      durio/src/ruaft_service.rs

+ 1 - 0
durio/Cargo.toml

@@ -13,6 +13,7 @@ keywords = ["raft"]
 categories = ["raft"]
 
 [dependencies]
+async-trait = "0.1"
 kvraft = { path = "../kvraft" }
 ruaft = { path = "..", features = ["integration-test"] }
 serde = "1.0"

+ 48 - 3
durio/src/ruaft_service.rs

@@ -1,15 +1,18 @@
 use std::sync::Arc;
 
+use async_trait::async_trait;
 use tarpc::context::Context;
 
 use kvraft::UniqueKVOp;
 use ruaft::{
     AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
-    InstallSnapshotReply, Raft, RequestVoteArgs, RequestVoteReply,
+    InstallSnapshotReply, Raft, RemoteRaft, RequestVoteArgs, RequestVoteReply,
 };
+use std::io::ErrorKind;
+use tarpc::client::RpcError;
 
 #[tarpc::service]
-trait RuaftSerivce {
+trait RuaftService {
     async fn append_entries(
         args: AppendEntriesArgs<UniqueKVOp>,
     ) -> AppendEntriesReply;
@@ -22,7 +25,7 @@ trait RuaftSerivce {
 struct RuaftRpcServer(Arc<Raft<UniqueKVOp>>);
 
 #[tarpc::server]
-impl RuaftSerivce for RuaftRpcServer {
+impl RuaftService for RuaftRpcServer {
     async fn append_entries(
         self,
         _context: Context,
@@ -47,3 +50,45 @@ impl RuaftSerivce for RuaftRpcServer {
         self.0.process_request_vote(args)
     }
 }
+
+#[async_trait]
+impl RemoteRaft<UniqueKVOp> for RuaftServiceClient {
+    async fn request_vote(
+        &self,
+        args: RequestVoteArgs,
+    ) -> std::io::Result<RequestVoteReply> {
+        self.request_vote(Context::current(), args)
+            .await
+            .map_err(translate_rpc_error)
+    }
+
+    async fn append_entries(
+        &self,
+        args: AppendEntriesArgs<UniqueKVOp>,
+    ) -> std::io::Result<AppendEntriesReply> {
+        self.append_entries(Context::current(), args)
+            .await
+            .map_err(translate_rpc_error)
+    }
+
+    async fn install_snapshot(
+        &self,
+        args: InstallSnapshotArgs,
+    ) -> std::io::Result<InstallSnapshotReply> {
+        self.install_snapshot(Context::current(), args)
+            .await
+            .map_err(translate_rpc_error)
+    }
+}
+
+fn translate_rpc_error(e: RpcError) -> std::io::Error {
+    match e {
+        RpcError::Disconnected => std::io::Error::new(ErrorKind::BrokenPipe, e),
+        RpcError::DeadlineExceeded => {
+            std::io::Error::new(ErrorKind::TimedOut, e)
+        }
+        RpcError::Server(server_error) => {
+            std::io::Error::new(ErrorKind::Other, server_error)
+        }
+    }
+}