rpcs.rs 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
  2. use parking_lot::Mutex;
  3. use crate::install_snapshot::{InstallSnapshotArgs, InstallSnapshotReply};
  4. use crate::{
  5. AppendEntriesArgs, AppendEntriesReply, Raft, RequestVoteArgs,
  6. RequestVoteReply,
  7. };
  8. use serde::de::DeserializeOwned;
  9. use serde::Serialize;
  // Method names shared by `RpcClient` (caller side) and `register_server`
  // (handler side); both sides must use the same string for a call to match.

  /// Method name for RequestVote calls.
  pub(crate) const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
  /// Method name for AppendEntries calls.
  pub(crate) const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
  /// Method name for InstallSnapshot calls.
  pub(crate) const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
  13. pub struct RpcClient(Client);
  14. impl RpcClient {
  15. pub fn new(client: Client) -> Self {
  16. Self(client)
  17. }
  18. async fn call_rpc<Method, Request, Reply>(
  19. &self,
  20. method: Method,
  21. request: Request,
  22. ) -> std::io::Result<Reply>
  23. where
  24. Method: AsRef<str>,
  25. Request: Serialize,
  26. Reply: DeserializeOwned,
  27. {
  28. let data = RequestMessage::from(
  29. bincode::serialize(&request)
  30. .expect("Serialization of requests should not fail"),
  31. );
  32. let reply = self.0.call_rpc(method.as_ref().to_owned(), data).await?;
  33. Ok(bincode::deserialize(reply.as_ref())
  34. .expect("Deserialization of reply should not fail"))
  35. }
  36. pub(crate) async fn call_request_vote(
  37. &self,
  38. request: RequestVoteArgs,
  39. ) -> std::io::Result<RequestVoteReply> {
  40. self.call_rpc(REQUEST_VOTE_RPC, request).await
  41. }
  42. pub(crate) async fn call_append_entries<Command: Serialize>(
  43. &self,
  44. request: AppendEntriesArgs<Command>,
  45. ) -> std::io::Result<AppendEntriesReply> {
  46. self.call_rpc(APPEND_ENTRIES_RPC, request).await
  47. }
  48. pub(crate) async fn call_install_snapshot(
  49. &self,
  50. request: InstallSnapshotArgs,
  51. ) -> std::io::Result<InstallSnapshotReply> {
  52. self.call_rpc(INSTALL_SNAPSHOT_RPC, request).await
  53. }
  54. }
  55. pub fn make_rpc_handler<Request, Reply, F>(
  56. func: F,
  57. ) -> Box<dyn Fn(RequestMessage) -> ReplyMessage>
  58. where
  59. Request: DeserializeOwned,
  60. Reply: Serialize,
  61. F: 'static + Fn(Request) -> Reply,
  62. {
  63. Box::new(move |request| {
  64. let reply = func(
  65. bincode::deserialize(&request)
  66. .expect("Deserialization should not fail"),
  67. );
  68. ReplyMessage::from(
  69. bincode::serialize(&reply).expect("Serialization should not fail"),
  70. )
  71. })
  72. }
  73. pub fn register_server<
  74. Command: 'static + Clone + Serialize + DeserializeOwned + Default,
  75. R: 'static + AsRef<Raft<Command>> + Clone,
  76. S: AsRef<str>,
  77. >(
  78. raft: R,
  79. name: S,
  80. network: &Mutex<Network>,
  81. ) -> std::io::Result<()> {
  82. let mut network = network.lock();
  83. let server_name = name.as_ref();
  84. let mut server = Server::make_server(server_name);
  85. let raft_clone = raft.clone();
  86. server.register_rpc_handler(
  87. REQUEST_VOTE_RPC.to_owned(),
  88. make_rpc_handler(move |args| {
  89. raft_clone.as_ref().process_request_vote(args)
  90. }),
  91. )?;
  92. let raft_clone = raft.clone();
  93. server.register_rpc_handler(
  94. APPEND_ENTRIES_RPC.to_owned(),
  95. make_rpc_handler(move |args| {
  96. raft_clone.as_ref().process_append_entries(args)
  97. }),
  98. )?;
  99. let raft_clone = raft;
  100. server.register_rpc_handler(
  101. INSTALL_SNAPSHOT_RPC.to_owned(),
  102. make_rpc_handler(move |args| {
  103. raft_clone.as_ref().process_install_snapshot(args)
  104. }),
  105. )?;
  106. network.add_server(server_name, server);
  107. Ok(())
  108. }
  #[cfg(test)]
  mod tests {
      use std::sync::Arc;

      use bytes::Bytes;

      use crate::{ApplyCommandMessage, Peer, Term};

      use super::*;

      /// Persister implementation on `()` that stores nothing and reports size 0.
      type DoNothingPersister = ();
      impl crate::Persister for DoNothingPersister {
          fn read_state(&self) -> Bytes {
              Bytes::new()
          }

          fn save_state(&self, _bytes: Bytes) {}

          fn state_size(&self) -> usize {
              0
          }

          fn save_snapshot_and_state(&self, _: Bytes, _: &[u8]) {}
      }

      /// Registers a Raft instance on a `labrpc` network and checks its replies.
      ///
      /// The test sends one RequestVote and one AppendEntries request through
      /// an `RpcClient` and asserts on the replies.
      #[test]
      fn test_basic_message() -> std::io::Result<()> {
          let client = {
              let network = Network::run_daemon();
              let name = "test-basic-message";
              let client = network
                  .lock()
                  .make_client("test-basic-message", name.to_owned());
              let raft = Arc::new(Raft::new(
                  vec![RpcClient(client)],
                  0,
                  Arc::new(()),
                  |_: ApplyCommandMessage<i32>| {},
                  None,
                  Raft::<i32>::NO_SNAPSHOT,
              ));
              register_server(raft, name, network.as_ref())?;
              // Client handle used by the test body below.
              network
                  .lock()
                  .make_client("test-basic-message", name.to_owned())
          };
          let rpc_client = RpcClient(client);

          let request = RequestVoteArgs {
              term: Term(2021),
              candidate_id: Peer(0),
              last_log_index: 0,
              last_log_term: Term(0),
          };
          let response =
              futures::executor::block_on(rpc_client.call_request_vote(request))?;
          assert!(response.vote_granted);

          let request = AppendEntriesArgs::<i32> {
              term: Term(2021),
              leader_id: Peer(0),
              prev_log_index: 0,
              prev_log_term: Term(0),
              entries: vec![],
              leader_commit: 0,
          };
          let response = futures::executor::block_on(
              rpc_client.call_append_entries(request),
          )?;
          assert_eq!(2021, response.term.0);
          assert!(response.success);

          Ok(())
      }
  }