Prechádzať zdrojové kódy

Simplify RPC handler and client wrapper using generics.

Jing Yang 4 rokov pred
rodič
commit
1dec92e51a
1 zmenil súbory, kde vykonal 44 pridanie a 76 odobranie
  1. 44 76
      src/rpcs.rs

+ 44 - 76
src/rpcs.rs

@@ -9,51 +9,6 @@ use crate::{
 use serde::de::DeserializeOwned;
 use serde::Serialize;
 
-fn proxy_request_vote<Command: Clone + Serialize + Default>(
-    raft: &Raft<Command>,
-    data: RequestMessage,
-) -> ReplyMessage {
-    let reply = raft.process_request_vote(
-        bincode::deserialize(data.as_ref())
-            .expect("Deserialization of requests should not fail"),
-    );
-
-    ReplyMessage::from(
-        bincode::serialize(&reply)
-            .expect("Serialization of reply should not fail"),
-    )
-}
-
-fn proxy_append_entries<
-    Command: Clone + Serialize + DeserializeOwned + Default,
->(
-    raft: &Raft<Command>,
-    data: RequestMessage,
-) -> ReplyMessage {
-    let reply = raft.process_append_entries(
-        bincode::deserialize(data.as_ref())
-            .expect("Deserialization should not fail"),
-    );
-
-    ReplyMessage::from(
-        bincode::serialize(&reply).expect("Serialization should not fail"),
-    )
-}
-
-fn proxy_install_snapshot<Command: Clone + Serialize + Default>(
-    raft: &Raft<Command>,
-    data: RequestMessage,
-) -> ReplyMessage {
-    let reply = raft.process_install_snapshot(
-        bincode::deserialize(data.as_ref())
-            .expect("Deserialization should not fail"),
-    );
-
-    ReplyMessage::from(
-        bincode::serialize(&reply).expect("Serialization should not fail"),
-    )
-}
-
 pub(crate) const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
 pub(crate) const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
 pub(crate) const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
@@ -65,54 +20,67 @@ impl RpcClient {
         Self(client)
     }
 
-    pub(crate) async fn call_request_vote(
+    async fn call_rpc<Method, Request, Reply>(
         &self,
-        request: RequestVoteArgs,
-    ) -> std::io::Result<RequestVoteReply> {
+        method: Method,
+        request: Request,
+    ) -> std::io::Result<Reply>
+    where
+        Method: AsRef<str>,
+        Request: Serialize,
+        Reply: DeserializeOwned,
+    {
         let data = RequestMessage::from(
             bincode::serialize(&request)
                 .expect("Serialization of requests should not fail"),
         );
 
-        let reply = self.0.call_rpc(REQUEST_VOTE_RPC.to_owned(), data).await?;
+        let reply = self.0.call_rpc(method.as_ref().to_owned(), data).await?;
 
         Ok(bincode::deserialize(reply.as_ref())
             .expect("Deserialization of reply should not fail"))
     }
 
+    pub(crate) async fn call_request_vote(
+        &self,
+        request: RequestVoteArgs,
+    ) -> std::io::Result<RequestVoteReply> {
+        self.call_rpc(REQUEST_VOTE_RPC, request).await
+    }
+
     pub(crate) async fn call_append_entries<Command: Serialize>(
         &self,
         request: AppendEntriesArgs<Command>,
     ) -> std::io::Result<AppendEntriesReply> {
-        let data = RequestMessage::from(
-            bincode::serialize(&request)
-                .expect("Serialization of requests should not fail"),
-        );
-
-        let reply =
-            self.0.call_rpc(APPEND_ENTRIES_RPC.to_owned(), data).await?;
-
-        Ok(bincode::deserialize(reply.as_ref())
-            .expect("Deserialization of reply should not fail"))
+        self.call_rpc(APPEND_ENTRIES_RPC, request).await
     }
 
     pub(crate) async fn call_install_snapshot(
         &self,
         request: InstallSnapshotArgs,
     ) -> std::io::Result<InstallSnapshotReply> {
-        let data = RequestMessage::from(
-            bincode::serialize(&request)
-                .expect("Serialization of requests should not fail"),
-        );
+        self.call_rpc(INSTALL_SNAPSHOT_RPC, request).await
+    }
+}
 
-        let reply = self
-            .0
-            .call_rpc(INSTALL_SNAPSHOT_RPC.to_owned(), data)
-            .await?;
+fn make_rpc_handler<Request, Reply, F>(
+    func: F,
+) -> Box<dyn Fn(RequestMessage) -> ReplyMessage>
+where
+    Request: DeserializeOwned,
+    Reply: Serialize,
+    F: 'static + Fn(Request) -> Reply,
+{
+    Box::new(move |request| {
+        let reply = func(
+            bincode::deserialize(&request)
+                .expect("Deserialization should not fail"),
+        );
 
-        Ok(bincode::deserialize(reply.as_ref())
-            .expect("Deserialization of reply should not fail"))
-    }
+        ReplyMessage::from(
+            bincode::serialize(&reply).expect("Serialization should not fail"),
+        )
+    })
 }
 
 pub fn register_server<
@@ -131,24 +99,24 @@ pub fn register_server<
     let raft_clone = raft.clone();
     server.register_rpc_handler(
         REQUEST_VOTE_RPC.to_owned(),
-        Box::new(move |request| {
-            proxy_request_vote(raft_clone.as_ref(), request)
+        make_rpc_handler(move |args| {
+            raft_clone.as_ref().process_request_vote(args)
         }),
     )?;
 
     let raft_clone = raft.clone();
     server.register_rpc_handler(
         APPEND_ENTRIES_RPC.to_owned(),
-        Box::new(move |request| {
-            proxy_append_entries(raft_clone.as_ref(), request)
+        make_rpc_handler(move |args| {
+            raft_clone.as_ref().process_append_entries(args)
         }),
     )?;
 
     let raft_clone = raft;
     server.register_rpc_handler(
         INSTALL_SNAPSHOT_RPC.to_owned(),
-        Box::new(move |request| {
-            proxy_install_snapshot(raft_clone.as_ref(), request)
+        make_rpc_handler(move |args| {
+            raft_clone.as_ref().process_install_snapshot(args)
         }),
     )?;