rpcs.rs 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. use async_trait::async_trait;
  2. use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
  3. use parking_lot::Mutex;
  4. use serde::de::DeserializeOwned;
  5. use serde::Serialize;
  6. use kvraft::{
  7. GetArgs, GetReply, KVServer, PutAppendArgs, PutAppendReply, RemoteKvraft,
  8. };
  9. use ruaft::{
  10. AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
  11. InstallSnapshotReply, Raft, RequestVoteArgs, RequestVoteReply,
  12. };
  /// RPC method name under which Raft `RequestVote` requests are registered
  /// (`register_server`) and sent (`RemoteRaft::request_vote`).
  const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
  /// RPC method name for Raft `AppendEntries` requests.
  const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
  /// RPC method name for Raft `InstallSnapshot` requests.
  const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
  /// RPC client that sends Raft and KV requests over a `labrpc` [`Client`],
  /// encoding requests and decoding replies with `bincode`.
  pub struct RpcClient(Client);
  17. impl RpcClient {
  18. pub fn new(client: Client) -> Self {
  19. Self(client)
  20. }
  21. async fn call_rpc<Method, Request, Reply>(
  22. &self,
  23. method: Method,
  24. request: Request,
  25. ) -> std::io::Result<Reply>
  26. where
  27. Method: AsRef<str>,
  28. Request: Serialize,
  29. Reply: DeserializeOwned,
  30. {
  31. let data = RequestMessage::from(
  32. bincode::serialize(&request)
  33. .expect("Serialization of requests should not fail"),
  34. );
  35. let reply = self.0.call_rpc(method.as_ref().to_owned(), data).await?;
  36. Ok(bincode::deserialize(reply.as_ref())
  37. .expect("Deserialization of reply should not fail"))
  38. }
  39. }
  40. #[async_trait]
  41. impl<Command: 'static + Send + Serialize> ruaft::RemoteRaft<Command>
  42. for RpcClient
  43. {
  44. async fn request_vote(
  45. &self,
  46. args: RequestVoteArgs,
  47. ) -> std::io::Result<RequestVoteReply> {
  48. self.call_rpc(REQUEST_VOTE_RPC, args).await
  49. }
  50. async fn append_entries(
  51. &self,
  52. args: AppendEntriesArgs<Command>,
  53. ) -> std::io::Result<AppendEntriesReply> {
  54. self.call_rpc(APPEND_ENTRIES_RPC, args).await
  55. }
  56. async fn install_snapshot(
  57. &self,
  58. args: InstallSnapshotArgs,
  59. ) -> std::io::Result<InstallSnapshotReply> {
  60. self.call_rpc(INSTALL_SNAPSHOT_RPC, args).await
  61. }
  62. }
  /// RPC method name for `KVServer` `Get` requests, used by
  /// `RemoteKvraft::get` and `register_kv_server`.
  const GET: &str = "KVServer.Get";
  /// RPC method name for `KVServer` `PutAppend` requests.
  const PUT_APPEND: &str = "KVServer.PutAppend";
  65. #[async_trait]
  66. impl RemoteKvraft for RpcClient {
  67. async fn get(&self, args: GetArgs) -> std::io::Result<GetReply> {
  68. self.call_rpc(GET, args).await
  69. }
  70. async fn put_append(
  71. &self,
  72. args: PutAppendArgs,
  73. ) -> std::io::Result<PutAppendReply> {
  74. self.call_rpc(PUT_APPEND, args).await
  75. }
  76. }
  77. pub fn make_rpc_handler<Request, Reply, F>(
  78. func: F,
  79. ) -> Box<dyn Fn(RequestMessage) -> ReplyMessage>
  80. where
  81. Request: DeserializeOwned,
  82. Reply: Serialize,
  83. F: 'static + Fn(Request) -> Reply,
  84. {
  85. Box::new(move |request| {
  86. let reply = func(
  87. bincode::deserialize(&request)
  88. .expect("Deserialization should not fail"),
  89. );
  90. ReplyMessage::from(
  91. bincode::serialize(&reply).expect("Serialization should not fail"),
  92. )
  93. })
  94. }
  95. pub fn register_server<
  96. Command: 'static + Clone + Serialize + DeserializeOwned + Default,
  97. R: 'static + AsRef<Raft<Command>> + Clone,
  98. S: AsRef<str>,
  99. >(
  100. raft: R,
  101. name: S,
  102. network: &Mutex<Network>,
  103. ) -> std::io::Result<()> {
  104. let mut network = network.lock();
  105. let server_name = name.as_ref();
  106. let mut server = Server::make_server(server_name);
  107. let raft_clone = raft.clone();
  108. server.register_rpc_handler(
  109. REQUEST_VOTE_RPC.to_owned(),
  110. make_rpc_handler(move |args| {
  111. raft_clone.as_ref().process_request_vote(args)
  112. }),
  113. )?;
  114. let raft_clone = raft.clone();
  115. server.register_rpc_handler(
  116. APPEND_ENTRIES_RPC.to_owned(),
  117. make_rpc_handler(move |args| {
  118. raft_clone.as_ref().process_append_entries(args)
  119. }),
  120. )?;
  121. let raft_clone = raft;
  122. server.register_rpc_handler(
  123. INSTALL_SNAPSHOT_RPC.to_owned(),
  124. make_rpc_handler(move |args| {
  125. raft_clone.as_ref().process_install_snapshot(args)
  126. }),
  127. )?;
  128. network.add_server(server_name, server);
  129. Ok(())
  130. }
  131. pub fn register_kv_server<
  132. KV: 'static + AsRef<KVServer> + Clone,
  133. S: AsRef<str>,
  134. >(
  135. kv: KV,
  136. name: S,
  137. network: &Mutex<Network>,
  138. ) -> std::io::Result<()> {
  139. let mut network = network.lock();
  140. let server_name = name.as_ref();
  141. let mut server = Server::make_server(server_name);
  142. let kv_clone = kv.clone();
  143. server.register_rpc_handler(
  144. GET.to_owned(),
  145. make_rpc_handler(move |args| kv_clone.as_ref().get(args)),
  146. )?;
  147. server.register_rpc_handler(
  148. PUT_APPEND.to_owned(),
  149. make_rpc_handler(move |args| kv.as_ref().put_append(args)),
  150. )?;
  151. network.add_server(server_name, server);
  152. Ok(())
  153. }
  // Round-trip tests for the RPC glue in this file, run over a `labrpc` network.
  #[cfg(test)]
  mod tests {
      use std::sync::Arc;

      use bytes::Bytes;

      use ruaft::{ApplyCommandMessage, RemoteRaft, Term};

      use super::*;
      use ruaft::utils::integration_test::{
          make_append_entries_args, make_request_vote_args,
          unpack_append_entries_reply, unpack_request_vote_reply,
      };

      /// A `ruaft::Persister` that discards everything written to it and always
      /// reports an empty, zero-sized state.
      struct DoNothingPersister;
      impl ruaft::Persister for DoNothingPersister {
          fn read_state(&self) -> Bytes {
              Bytes::new()
          }

          fn save_state(&self, _bytes: Bytes) {}

          fn state_size(&self) -> usize {
              0
          }

          fn save_snapshot_and_state(&self, _: Bytes, _: &[u8]) {}
      }

      /// Registers a Raft instance on a `labrpc` network with `register_server`.
      /// It then sends that instance a `RequestVote` and an `AppendEntries` for
      /// term 2021 through an `RpcClient`, and checks that both succeed.
      #[test]
      fn test_basic_message() -> std::io::Result<()> {
          test_utils::init_test_log!();

          let client = {
              let network = Network::run_daemon();
              let name = "test-basic-message";

              // Client given to the Raft instance as its only peer (index 0).
              let client = network
                  .lock()
                  .make_client("test-basic-message", name.to_owned());

              let raft = Arc::new(Raft::new(
                  vec![RpcClient(client)],
                  0,
                  Arc::new(DoNothingPersister),
                  |_: ApplyCommandMessage<i32>| {},
                  None,
                  crate::utils::NO_SNAPSHOT,
              ));
              register_server(raft, name, network.as_ref())?;

              // Client used by the test itself. It is bound to a local before
              // being returned so that the temporary lock guard, which borrows
              // `network`, is dropped before `network` goes out of scope.
              let client = network
                  .lock()
                  .make_client("test-basic-message", name.to_owned());
              client
          };

          let rpc_client = RpcClient(client);

          // `RequestVoteArgs` has no command type parameter, so the cast selects
          // the `RemoteRaft<i32>` impl explicitly.
          let request = make_request_vote_args(Term(2021), 0, 0, Term(0));
          let response = futures::executor::block_on(
              (&rpc_client as &dyn RemoteRaft<i32>).request_vote(request),
          )?;
          let (_, vote_granted) = unpack_request_vote_reply(response);
          assert!(vote_granted);

          // Here the turbofish on the args fixes the command type.
          let request =
              make_append_entries_args::<i32>(Term(2021), 0, 0, Term(0), 0);
          let response =
              futures::executor::block_on(rpc_client.append_entries(request))?;
          let (Term(term), success) = unpack_append_entries_reply(response);
          assert_eq!(2021, term);
          assert!(success);
          Ok(())
      }
  }