rpcs.rs 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. use std::sync::{Arc, Mutex};
  2. use labrpc::{
  3. Client, Network, ReplyMessage, RequestMessage, RpcHandler, Server,
  4. };
  5. use crate::{
  6. AppendEntriesArgs, AppendEntriesReply, Raft, RequestVoteArgs,
  7. RequestVoteReply,
  8. };
  9. struct RequestVoteRpcHandler(Arc<Raft>);
  10. impl RpcHandler for RequestVoteRpcHandler {
  11. fn call(&self, data: RequestMessage) -> ReplyMessage {
  12. let reply = self.0.process_request_vote(
  13. bincode::deserialize(data.as_ref())
  14. .expect("Deserialization of requests should not fail"),
  15. );
  16. ReplyMessage::from(
  17. bincode::serialize(&reply)
  18. .expect("Serialization of reply should not fail"),
  19. )
  20. }
  21. }
  22. struct AppendEntriesRpcHandler(Arc<Raft>);
  23. impl RpcHandler for AppendEntriesRpcHandler {
  24. fn call(&self, data: RequestMessage) -> ReplyMessage {
  25. let reply = self.0.process_append_entries(
  26. bincode::deserialize(data.as_ref())
  27. .expect("Deserialization should not fail"),
  28. );
  29. ReplyMessage::from(
  30. bincode::serialize(&reply).expect("Serialization should not fail"),
  31. )
  32. }
  33. }
  /// RPC method name the AppendEntries handler is registered under and that
  /// `RpcClient::call_append_entries` targets.
  pub(crate) const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
  /// RPC method name the RequestVote handler is registered under and that
  /// `RpcClient::call_request_vote` targets.
  pub(crate) const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
  36. #[derive(Clone)]
  37. pub struct RpcClient(Client);
  38. impl RpcClient {
  39. pub fn new(client: Client) -> Self {
  40. Self(client)
  41. }
  42. pub(crate) async fn call_request_vote(
  43. self: Self,
  44. request: RequestVoteArgs,
  45. ) -> std::io::Result<RequestVoteReply> {
  46. let data = RequestMessage::from(
  47. bincode::serialize(&request)
  48. .expect("Serialization of requests should not fail"),
  49. );
  50. let reply = self.0.call_rpc(REQUEST_VOTE_RPC.to_owned(), data).await?;
  51. Ok(bincode::deserialize(reply.as_ref())
  52. .expect("Deserialization of reply should not fail"))
  53. }
  54. pub(crate) async fn call_append_entries(
  55. self: Self,
  56. request: AppendEntriesArgs,
  57. ) -> std::io::Result<AppendEntriesReply> {
  58. let data = RequestMessage::from(
  59. bincode::serialize(&request)
  60. .expect("Serialization of requests should not fail"),
  61. );
  62. let reply =
  63. self.0.call_rpc(APPEND_ENTRIES_RPC.to_owned(), data).await?;
  64. Ok(bincode::deserialize(reply.as_ref())
  65. .expect("Deserialization of reply should not fail"))
  66. }
  67. }
  68. pub fn register_server<S: AsRef<str>>(
  69. raft: Arc<Raft>,
  70. name: S,
  71. network: &Mutex<Network>,
  72. ) -> std::io::Result<()> {
  73. let mut network =
  74. network.lock().expect("Network lock should not be poisoned");
  75. let server_name = name.as_ref().clone();
  76. let mut server = Server::make_server(server_name.clone());
  77. let request_vote_rpc_handler = RequestVoteRpcHandler(raft.clone());
  78. server.register_rpc_handler(
  79. REQUEST_VOTE_RPC.to_owned(),
  80. Box::new(request_vote_rpc_handler),
  81. )?;
  82. let append_entries_rpc_handler = AppendEntriesRpcHandler(raft);
  83. server.register_rpc_handler(
  84. APPEND_ENTRIES_RPC.to_owned(),
  85. Box::new(append_entries_rpc_handler),
  86. )?;
  87. network.add_server(server_name, server);
  88. Ok(())
  89. }
  #[cfg(test)]
  mod tests {
      use bytes::Bytes;

      use crate::{Peer, Term};

      use super::*;

      /// Persister used by these tests: `read_state` returns empty bytes and
      /// `save_state` discards its input.
      type DoNothingPersister = ();
      impl crate::Persister for DoNothingPersister {
          fn read_state(&self) -> Bytes {
              Bytes::new()
          }

          fn save_state(&self, _bytes: Bytes) {}
      }

      /// Sends one RequestVote and one AppendEntries RPC through a `labrpc`
      /// network to a registered Raft server, and checks the decoded replies.
      #[test]
      fn test_basic_message() -> std::io::Result<()> {
          let client = {
              let network = Network::run_daemon();
              let name = "test-basic-message";
              let client = network
                  .lock()
                  .expect("Network lock should not be poisoned")
                  .make_client(name, name.to_owned());
              let raft = Arc::new(Raft::new(
                  vec![RpcClient(client.clone())],
                  0,
                  Arc::new(()),
                  |_, _| {},
              ));
              register_server(raft, name, network.as_ref())?;
              client
          };

          let rpc_client = RpcClient(client);

          let request = RequestVoteArgs {
              term: Term(2021),
              candidate_id: Peer(0),
              last_log_index: 0,
              last_log_term: Term(0),
          };
          let response = futures::executor::block_on(
              rpc_client.clone().call_request_vote(request),
          )?;
          assert!(response.vote_granted);

          let request = AppendEntriesArgs {
              term: Term(2021),
              leader_id: Peer(0),
              prev_log_index: 0,
              prev_log_term: Term(0),
              entries: vec![],
              leader_commit: 0,
          };
          // Last use of the stub, so it can be consumed without cloning.
          let response =
              futures::executor::block_on(rpc_client.call_append_entries(request))?;
          assert_eq!(response.term.0, 2021);
          assert!(response.success);
          Ok(())
      }
  }