rpcs.rs 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. use std::future::Future;
  2. use async_trait::async_trait;
  3. use futures_util::future::BoxFuture;
  4. use futures_util::FutureExt;
  5. use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
  6. use parking_lot::Mutex;
  7. use serde::de::DeserializeOwned;
  8. use serde::Serialize;
  9. use kvraft::{
  10. CommitSentinelArgs, CommitSentinelReply, GetArgs, GetReply, KVServer,
  11. PutAppendArgs, PutAppendReply, RemoteKvraft,
  12. };
  13. use ruaft::{
  14. AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
  15. InstallSnapshotReply, Raft, ReplicableCommand, RequestVoteArgs,
  16. RequestVoteReply,
  17. };
  18. const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
  19. const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
  20. const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
  21. pub struct RpcClient(Client);
  22. impl RpcClient {
  23. pub fn new(client: Client) -> Self {
  24. Self(client)
  25. }
  26. async fn call_rpc<Method, Request, Reply>(
  27. &self,
  28. method: Method,
  29. request: Request,
  30. ) -> std::io::Result<Reply>
  31. where
  32. Method: AsRef<str>,
  33. Request: Serialize,
  34. Reply: DeserializeOwned,
  35. {
  36. let data = RequestMessage::from(
  37. bincode::serialize(&request)
  38. .expect("Serialization of requests should not fail"),
  39. );
  40. let reply = self.0.call_rpc(method.as_ref().to_owned(), data).await?;
  41. Ok(bincode::deserialize(reply.as_ref())
  42. .expect("Deserialization of reply should not fail"))
  43. }
  44. }
  45. #[async_trait]
  46. impl<Command: ReplicableCommand> ruaft::RemoteRaft<Command> for RpcClient {
  47. async fn request_vote(
  48. &self,
  49. args: RequestVoteArgs,
  50. ) -> std::io::Result<RequestVoteReply> {
  51. self.call_rpc(REQUEST_VOTE_RPC, args).await
  52. }
  53. async fn append_entries(
  54. &self,
  55. args: AppendEntriesArgs<Command>,
  56. ) -> std::io::Result<AppendEntriesReply> {
  57. self.call_rpc(APPEND_ENTRIES_RPC, args).await
  58. }
  59. async fn install_snapshot(
  60. &self,
  61. args: InstallSnapshotArgs,
  62. ) -> std::io::Result<InstallSnapshotReply> {
  63. self.call_rpc(INSTALL_SNAPSHOT_RPC, args).await
  64. }
  65. }
  66. const GET: &str = "KVServer.Get";
  67. const PUT_APPEND: &str = "KVServer.PutAppend";
  68. const COMMIT_SENTINEL: &str = "KVServer.CommitSentinel";
  69. #[async_trait]
  70. impl RemoteKvraft for RpcClient {
  71. async fn get(&self, args: GetArgs) -> std::io::Result<GetReply> {
  72. self.call_rpc(GET, args).await
  73. }
  74. async fn put_append(
  75. &self,
  76. args: PutAppendArgs,
  77. ) -> std::io::Result<PutAppendReply> {
  78. self.call_rpc(PUT_APPEND, args).await
  79. }
  80. async fn commit_sentinel(
  81. &self,
  82. args: CommitSentinelArgs,
  83. ) -> std::io::Result<CommitSentinelReply> {
  84. self.call_rpc(COMMIT_SENTINEL, args).await
  85. }
  86. }
  87. pub fn make_rpc_handler<Request, Reply, F>(
  88. func: F,
  89. ) -> impl Fn(RequestMessage) -> ReplyMessage
  90. where
  91. Request: DeserializeOwned,
  92. Reply: Serialize,
  93. F: 'static + Fn(Request) -> Reply,
  94. {
  95. move |request| {
  96. let reply = func(
  97. bincode::deserialize(&request)
  98. .expect("Deserialization should not fail"),
  99. );
  100. ReplyMessage::from(
  101. bincode::serialize(&reply).expect("Serialization should not fail"),
  102. )
  103. }
  104. }
  105. pub fn make_async_rpc_handler<'a, Request, Reply, F, Fut>(
  106. func: F,
  107. ) -> impl Fn(RequestMessage) -> BoxFuture<'a, ReplyMessage>
  108. where
  109. Request: DeserializeOwned + Send,
  110. Reply: Serialize,
  111. Fut: Future<Output = Reply> + Send + 'a,
  112. F: 'a + Send + Clone + FnOnce(Request) -> Fut,
  113. {
  114. move |request| {
  115. let func = func.clone();
  116. let fut = async move {
  117. let request = bincode::deserialize(&request)
  118. .expect("Deserialization should not fail");
  119. let reply = func(request).await;
  120. ReplyMessage::from(
  121. bincode::serialize(&reply)
  122. .expect("Serialization should not fail"),
  123. )
  124. };
  125. fut.boxed()
  126. }
  127. }
  128. pub fn register_server<
  129. Command: 'static + Clone + Send + Serialize + DeserializeOwned,
  130. S: AsRef<str>,
  131. >(
  132. raft: Raft<Command>,
  133. name: S,
  134. network: &Mutex<Network>,
  135. ) -> std::io::Result<()> {
  136. let mut network = network.lock();
  137. let server_name = name.as_ref();
  138. let mut server = Server::make_server(server_name);
  139. server.register_rpc_handler(REQUEST_VOTE_RPC.to_owned(), {
  140. let raft = raft.clone();
  141. make_rpc_handler(move |args| raft.process_request_vote(args))
  142. })?;
  143. server.register_rpc_handler(APPEND_ENTRIES_RPC.to_owned(), {
  144. let raft = raft.clone();
  145. make_rpc_handler(move |args| raft.process_append_entries(args))
  146. })?;
  147. server.register_rpc_handler(
  148. INSTALL_SNAPSHOT_RPC.to_owned(),
  149. make_rpc_handler(move |args| raft.process_install_snapshot(args)),
  150. )?;
  151. network.add_server(server_name, server);
  152. Ok(())
  153. }
  154. pub fn register_kv_server<
  155. KV: 'static + AsRef<KVServer> + Send + Sync + Clone,
  156. S: AsRef<str>,
  157. >(
  158. kv: KV,
  159. name: S,
  160. network: &Mutex<Network>,
  161. ) -> std::io::Result<()> {
  162. let mut network = network.lock();
  163. let server_name = name.as_ref();
  164. let mut server = Server::make_server(server_name);
  165. server.register_async_rpc_handler(GET.to_owned(), {
  166. let kv = kv.clone();
  167. // Note: make_async_rpc_handler expects a FnOnce.
  168. make_async_rpc_handler(move |args| async move {
  169. kv.as_ref().get(args).await
  170. })
  171. })?;
  172. server.register_async_rpc_handler(PUT_APPEND.to_owned(), {
  173. let kv = kv.clone();
  174. make_async_rpc_handler(move |args| async move {
  175. kv.as_ref().put_append(args).await
  176. })
  177. })?;
  178. server.register_async_rpc_handler(
  179. COMMIT_SENTINEL.to_owned(),
  180. make_async_rpc_handler(move |args| async move {
  181. kv.as_ref().commit_sentinel(args).await
  182. }),
  183. )?;
  184. network.add_server(server_name, server);
  185. Ok(())
  186. }
  187. #[cfg(test)]
  188. mod tests {
  189. use ruaft::utils::do_nothing::DoNothingPersister;
  190. use ruaft::utils::integration_test::{
  191. make_append_entries_args, make_request_vote_args,
  192. unpack_append_entries_reply, unpack_request_vote_reply,
  193. };
  194. use ruaft::{ApplyCommandMessage, RemoteRaft, Term};
  195. use super::*;
  196. #[test]
  197. fn test_basic_message() -> std::io::Result<()> {
  198. test_utils::init_test_log!();
  199. let client = {
  200. let network = Network::run_daemon();
  201. let name = "test-basic-message";
  202. let client = network
  203. .lock()
  204. .make_client("test-basic-message", name.to_owned());
  205. let raft = Raft::new(
  206. vec![RpcClient(client)],
  207. 0,
  208. DoNothingPersister,
  209. |_: ApplyCommandMessage<i32>| {},
  210. None,
  211. crate::utils::NO_SNAPSHOT,
  212. );
  213. register_server(raft, name, network.as_ref())?;
  214. let client = network
  215. .lock()
  216. .make_client("test-basic-message", name.to_owned());
  217. client
  218. };
  219. let rpc_client = RpcClient(client);
  220. let request = make_request_vote_args(Term(2021), 0, 0, Term(0));
  221. let response = futures::executor::block_on(
  222. (&rpc_client as &dyn RemoteRaft<i32>).request_vote(request),
  223. )?;
  224. let (_, vote_granted) = unpack_request_vote_reply(response);
  225. assert!(vote_granted);
  226. let request =
  227. make_append_entries_args::<i32>(Term(2021), 0, 0, Term(0), 0);
  228. let response =
  229. futures::executor::block_on(rpc_client.append_entries(request))?;
  230. let (Term(term), success) = unpack_append_entries_reply(response);
  231. assert_eq!(2021, term);
  232. assert!(success);
  233. Ok(())
  234. }
  235. }