rpcs.rs 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
  2. use parking_lot::Mutex;
  3. use crate::{
  4. AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
  5. InstallSnapshotReply, Raft, RequestVoteArgs, RequestVoteReply,
  6. };
  7. use serde::de::DeserializeOwned;
  8. use serde::Serialize;
  /// Method name under which the vote-request handler is registered and
  /// through which `RpcClient::call_request_vote` reaches it.
  9. pub(crate) const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
  /// Method name under which the append-entries handler is registered and
  /// through which `RpcClient::call_append_entries` reaches it.
  10. pub(crate) const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
  /// Method name under which the install-snapshot handler is registered and
  /// through which `RpcClient::call_install_snapshot` reaches it.
  11. pub(crate) const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
  12. pub struct RpcClient(Client);
  13. impl RpcClient {
  14. pub fn new(client: Client) -> Self {
  15. Self(client)
  16. }
  17. async fn call_rpc<Method, Request, Reply>(
  18. &self,
  19. method: Method,
  20. request: Request,
  21. ) -> std::io::Result<Reply>
  22. where
  23. Method: AsRef<str>,
  24. Request: Serialize,
  25. Reply: DeserializeOwned,
  26. {
  27. let data = RequestMessage::from(
  28. bincode::serialize(&request)
  29. .expect("Serialization of requests should not fail"),
  30. );
  31. let reply = self.0.call_rpc(method.as_ref().to_owned(), data).await?;
  32. Ok(bincode::deserialize(reply.as_ref())
  33. .expect("Deserialization of reply should not fail"))
  34. }
  35. pub(crate) async fn call_request_vote(
  36. &self,
  37. request: RequestVoteArgs,
  38. ) -> std::io::Result<RequestVoteReply> {
  39. self.call_rpc(REQUEST_VOTE_RPC, request).await
  40. }
  41. pub(crate) async fn call_append_entries<Command: Serialize>(
  42. &self,
  43. request: AppendEntriesArgs<Command>,
  44. ) -> std::io::Result<AppendEntriesReply> {
  45. self.call_rpc(APPEND_ENTRIES_RPC, request).await
  46. }
  47. pub(crate) async fn call_install_snapshot(
  48. &self,
  49. request: InstallSnapshotArgs,
  50. ) -> std::io::Result<InstallSnapshotReply> {
  51. self.call_rpc(INSTALL_SNAPSHOT_RPC, request).await
  52. }
  53. }
  54. pub fn make_rpc_handler<Request, Reply, F>(
  55. func: F,
  56. ) -> Box<dyn Fn(RequestMessage) -> ReplyMessage>
  57. where
  58. Request: DeserializeOwned,
  59. Reply: Serialize,
  60. F: 'static + Fn(Request) -> Reply,
  61. {
  62. Box::new(move |request| {
  63. let reply = func(
  64. bincode::deserialize(&request)
  65. .expect("Deserialization should not fail"),
  66. );
  67. ReplyMessage::from(
  68. bincode::serialize(&reply).expect("Serialization should not fail"),
  69. )
  70. })
  71. }
  72. pub fn register_server<
  73. Command: 'static + Clone + Serialize + DeserializeOwned + Default,
  74. R: 'static + AsRef<Raft<Command>> + Clone,
  75. S: AsRef<str>,
  76. >(
  77. raft: R,
  78. name: S,
  79. network: &Mutex<Network>,
  80. ) -> std::io::Result<()> {
  81. let mut network = network.lock();
  82. let server_name = name.as_ref();
  83. let mut server = Server::make_server(server_name);
  84. let raft_clone = raft.clone();
  85. server.register_rpc_handler(
  86. REQUEST_VOTE_RPC.to_owned(),
  87. make_rpc_handler(move |args| {
  88. raft_clone.as_ref().process_request_vote(args)
  89. }),
  90. )?;
  91. let raft_clone = raft.clone();
  92. server.register_rpc_handler(
  93. APPEND_ENTRIES_RPC.to_owned(),
  94. make_rpc_handler(move |args| {
  95. raft_clone.as_ref().process_append_entries(args)
  96. }),
  97. )?;
  98. let raft_clone = raft;
  99. server.register_rpc_handler(
  100. INSTALL_SNAPSHOT_RPC.to_owned(),
  101. make_rpc_handler(move |args| {
  102. raft_clone.as_ref().process_install_snapshot(args)
  103. }),
  104. )?;
  105. network.add_server(server_name, server);
  106. Ok(())
  107. }
  #[cfg(test)]
  mod tests {
      use std::sync::Arc;

      use bytes::Bytes;

      use crate::{ApplyCommandMessage, Peer, Term};

      use super::*;

      /// Persister stand-in for tests: saves nothing and reads back an empty
      /// state.
      type DoNothingPersister = ();

      impl crate::Persister for DoNothingPersister {
          fn read_state(&self) -> Bytes {
              Bytes::new()
          }

          fn save_state(&self, _bytes: Bytes) {}

          fn state_size(&self) -> usize {
              0
          }

          fn save_snapshot_and_state(&self, _: Bytes, _: &[u8]) {}
      }

      /// Starts a network, registers a single Raft server on it, and returns
      /// a client connected to that server. The network handle is dropped
      /// when this function returns.
      fn connect_to_fresh_server() -> std::io::Result<Client> {
          let network = Network::run_daemon();
          let name = "test-basic-message";

          // The Raft instance gets its own client as its only peer.
          let peer_client = network
              .lock()
              .make_client("test-basic-message", name.to_owned());
          let peers = vec![RpcClient(peer_client)];
          let raft = Raft::new(
              peers,
              0,
              Arc::new(()),
              |_: ApplyCommandMessage<i32>| {},
              None,
              Raft::<i32>::NO_SNAPSHOT,
          );
          register_server(Arc::new(raft), name, network.as_ref())?;

          let test_client = network
              .lock()
              .make_client("test-basic-message", name.to_owned());
          Ok(test_client)
      }

      #[test]
      fn test_basic_message() -> std::io::Result<()> {
          use futures::executor::block_on;

          let rpc_client = RpcClient(connect_to_fresh_server()?);

          let vote_args = RequestVoteArgs {
              term: Term(2021),
              candidate_id: Peer(0),
              last_log_index: 0,
              last_log_term: Term(0),
          };
          let vote_reply = block_on(rpc_client.call_request_vote(vote_args))?;
          assert!(vote_reply.vote_granted);

          let append_args = AppendEntriesArgs::<i32> {
              term: Term(2021),
              leader_id: Peer(0),
              prev_log_index: 0,
              prev_log_term: Term(0),
              entries: vec![],
              leader_commit: 0,
          };
          let append_reply =
              block_on(rpc_client.call_append_entries(append_args))?;
          assert_eq!(2021, append_reply.term.0);
          assert!(append_reply.success);

          Ok(())
      }
  }