Переглянути джерело

Add functions that create clients and servers.

Both KV service and Raft service are supported.
Jing Yang 4 роки тому
батько
коміт
0cb5982d91
5 змінених файлів з 124 додано та 19 видалено
  1. 3 1
      durio/Cargo.toml
  2. 47 2
      durio/src/kv_service.rs
  3. 1 0
      durio/src/main.rs
  4. 26 16
      durio/src/raft_service.rs
  5. 47 0
      durio/src/utils.rs

+ 3 - 1
durio/Cargo.toml

@@ -14,13 +14,15 @@ categories = ["raft"]
 
 [dependencies]
 async-trait = "0.1"
+futures-util = "0.3.15"
 kvraft = { path = "../kvraft" }
 ruaft = { path = "..", features = ["integration-test"] }
 serde = "1.0"
 serde_derive = "1.0"
 serde_json = "1.0"
-tarpc = "0.27"
+tarpc = { version = "0.27", features = ["serde-transport", "tcp"] }
 tokio = { version = "1.7", features = ["macros", "rt-multi-thread", "time", "parking_lot"] }
+tokio-serde = { version = "0.8", features = ["json"] }
 warp = "0.3"
 
 [dev-dependencies]

+ 47 - 2
durio/src/kv_service.rs

@@ -1,11 +1,16 @@
+use std::future::Future;
+use std::net::SocketAddr;
 use std::sync::Arc;
 
+use async_trait::async_trait;
 use tarpc::context::Context;
 
-use kvraft::{GetArgs, GetReply, KVServer, PutAppendArgs, PutAppendReply};
+use kvraft::{
+    GetArgs, GetReply, KVServer, PutAppendArgs, PutAppendReply, RemoteKvraft,
+};
 
 #[tarpc::service]
-trait KVService {
+pub(crate) trait KVService {
     async fn get(args: GetArgs) -> GetReply;
     async fn put_append(args: PutAppendArgs) -> PutAppendReply;
 }
@@ -27,3 +32,43 @@ impl KVService for KVRpcServer {
         self.0.put_append(args).await
     }
 }
+
+#[async_trait]
+impl RemoteKvraft for KVServiceClient {
+    async fn get(&self, args: GetArgs) -> std::io::Result<GetReply> {
+        self.get(Context::current(), args)
+            .await
+            .map_err(crate::utils::translate_rpc_error)
+    }
+
+    async fn put_append(
+        &self,
+        args: PutAppendArgs,
+    ) -> std::io::Result<PutAppendReply> {
+        self.put_append(Context::current(), args)
+            .await
+            .map_err(crate::utils::translate_rpc_error)
+    }
+}
+
+#[allow(dead_code)]
+pub(crate) async fn connect_to_kv_service(
+    addr: SocketAddr,
+) -> std::io::Result<KVServiceClient> {
+    let conn = tarpc::serde_transport::tcp::connect(
+        addr,
+        tokio_serde::formats::Json::default,
+    )
+    .await?;
+    let client =
+        KVServiceClient::new(tarpc::client::Config::default(), conn).spawn();
+    Ok(client)
+}
+
+pub(crate) fn start_kv_service_server(
+    addr: SocketAddr,
+    kv_server: Arc<KVServer>,
+) -> impl Future<Output = std::io::Result<()>> {
+    let server = KVRpcServer(kv_server);
+    crate::utils::start_tarpc_server(addr, server.serve())
+}

+ 1 - 0
durio/src/main.rs

@@ -1,5 +1,6 @@
 mod kv_service;
 mod raft_service;
+mod utils;
 
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;

+ 26 - 16
durio/src/raft_service.rs

@@ -1,3 +1,5 @@
+use std::future::Future;
+use std::net::SocketAddr;
 use std::sync::Arc;
 
 use async_trait::async_trait;
@@ -8,11 +10,9 @@ use ruaft::{
     AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
     InstallSnapshotReply, Raft, RemoteRaft, RequestVoteArgs, RequestVoteReply,
 };
-use std::io::ErrorKind;
-use tarpc::client::RpcError;
 
 #[tarpc::service]
-trait RaftService {
+pub(crate) trait RaftService {
     async fn append_entries(
         args: AppendEntriesArgs<UniqueKVOp>,
     ) -> AppendEntriesReply;
@@ -22,6 +22,7 @@ trait RaftService {
     async fn request_vote(args: RequestVoteArgs) -> RequestVoteReply;
 }
 
+#[derive(Clone)]
 struct RaftRpcServer(Arc<Raft<UniqueKVOp>>);
 
 #[tarpc::server]
@@ -59,7 +60,7 @@ impl RemoteRaft<UniqueKVOp> for RaftServiceClient {
     ) -> std::io::Result<RequestVoteReply> {
         self.request_vote(Context::current(), args)
             .await
-            .map_err(translate_rpc_error)
+            .map_err(crate::utils::translate_rpc_error)
     }
 
     async fn append_entries(
@@ -68,7 +69,7 @@ impl RemoteRaft<UniqueKVOp> for RaftServiceClient {
     ) -> std::io::Result<AppendEntriesReply> {
         self.append_entries(Context::current(), args)
             .await
-            .map_err(translate_rpc_error)
+            .map_err(crate::utils::translate_rpc_error)
     }
 
     async fn install_snapshot(
@@ -77,18 +78,27 @@ impl RemoteRaft<UniqueKVOp> for RaftServiceClient {
     ) -> std::io::Result<InstallSnapshotReply> {
         self.install_snapshot(Context::current(), args)
             .await
-            .map_err(translate_rpc_error)
+            .map_err(crate::utils::translate_rpc_error)
     }
 }
 
-fn translate_rpc_error(e: RpcError) -> std::io::Error {
-    match e {
-        RpcError::Disconnected => std::io::Error::new(ErrorKind::BrokenPipe, e),
-        RpcError::DeadlineExceeded => {
-            std::io::Error::new(ErrorKind::TimedOut, e)
-        }
-        RpcError::Server(server_error) => {
-            std::io::Error::new(ErrorKind::Other, server_error)
-        }
-    }
+pub(crate) async fn connect_to_raft_service(
+    addr: SocketAddr,
+) -> std::io::Result<impl RemoteRaft<UniqueKVOp>> {
+    let conn = tarpc::serde_transport::tcp::connect(
+        addr,
+        tokio_serde::formats::Json::default,
+    )
+    .await?;
+    let client =
+        RaftServiceClient::new(tarpc::client::Config::default(), conn).spawn();
+    Ok(client)
+}
+
+pub(crate) fn start_raft_service_server(
+    addr: SocketAddr,
+    raft: Arc<Raft<UniqueKVOp>>,
+) -> impl Future<Output = std::io::Result<()>> {
+    let server = RaftRpcServer(raft);
+    crate::utils::start_tarpc_server(addr, server.serve())
 }

+ 47 - 0
durio/src/utils.rs

@@ -0,0 +1,47 @@
+use std::io::ErrorKind;
+use std::net::SocketAddr;
+
+use futures_util::StreamExt;
+use tarpc::client::RpcError;
+use tarpc::server::{Channel, Serve};
+
+pub(crate) fn translate_rpc_error(e: RpcError) -> std::io::Error {
+    match e {
+        RpcError::Disconnected => std::io::Error::new(ErrorKind::BrokenPipe, e),
+        RpcError::DeadlineExceeded => {
+            std::io::Error::new(ErrorKind::TimedOut, e)
+        }
+        RpcError::Server(server_error) => {
+            std::io::Error::new(ErrorKind::Other, server_error)
+        }
+    }
+}
+
+pub(crate) async fn start_tarpc_server<Request, Reply, ServeFn>(
+    addr: SocketAddr,
+    serve: ServeFn,
+) -> std::io::Result<()>
+where
+    Request: Send + 'static + serde::de::DeserializeOwned,
+    Reply: Send + 'static + serde::ser::Serialize,
+    ServeFn:
+        tarpc::server::Serve<Request, Resp = Reply> + Send + 'static + Clone,
+    <ServeFn as Serve<Request>>::Fut: Send,
+{
+    let mut listener = tarpc::serde_transport::tcp::listen(
+        addr,
+        tokio_serde::formats::Json::default,
+    )
+    .await?;
+
+    tokio::spawn(async move {
+        while let Some(conn) = listener.next().await {
+            if let Ok(conn) = conn {
+                let channel = tarpc::server::BaseChannel::with_defaults(conn)
+                    .max_concurrent_requests(1);
+                tokio::spawn(channel.execute(serve.clone()));
+            }
+        }
+    });
+    Ok(())
+}