rpcs.rs 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. use std::future::Future;
  2. use async_trait::async_trait;
  3. use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
  4. use parking_lot::Mutex;
  5. use serde::de::DeserializeOwned;
  6. use serde::Serialize;
  7. use futures_util::future::BoxFuture;
  8. use futures_util::FutureExt;
  9. use kvraft::{
  10. GetArgs, GetReply, KVServer, PutAppendArgs, PutAppendReply, RemoteKvraft,
  11. };
  12. use ruaft::{
  13. AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
  14. InstallSnapshotReply, Raft, RequestVoteArgs, RequestVoteReply,
  15. };
  16. const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
  17. const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
  18. const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
  19. pub struct RpcClient(Client);
  20. impl RpcClient {
  21. pub fn new(client: Client) -> Self {
  22. Self(client)
  23. }
  24. async fn call_rpc<Method, Request, Reply>(
  25. &self,
  26. method: Method,
  27. request: Request,
  28. ) -> std::io::Result<Reply>
  29. where
  30. Method: AsRef<str>,
  31. Request: Serialize,
  32. Reply: DeserializeOwned,
  33. {
  34. let data = RequestMessage::from(
  35. bincode::serialize(&request)
  36. .expect("Serialization of requests should not fail"),
  37. );
  38. let reply = self.0.call_rpc(method.as_ref().to_owned(), data).await?;
  39. Ok(bincode::deserialize(reply.as_ref())
  40. .expect("Deserialization of reply should not fail"))
  41. }
  42. }
  43. #[async_trait]
  44. impl<Command: 'static + Send + Serialize> ruaft::RemoteRaft<Command>
  45. for RpcClient
  46. {
  47. async fn request_vote(
  48. &self,
  49. args: RequestVoteArgs,
  50. ) -> std::io::Result<RequestVoteReply> {
  51. self.call_rpc(REQUEST_VOTE_RPC, args).await
  52. }
  53. async fn append_entries(
  54. &self,
  55. args: AppendEntriesArgs<Command>,
  56. ) -> std::io::Result<AppendEntriesReply> {
  57. self.call_rpc(APPEND_ENTRIES_RPC, args).await
  58. }
  59. async fn install_snapshot(
  60. &self,
  61. args: InstallSnapshotArgs,
  62. ) -> std::io::Result<InstallSnapshotReply> {
  63. self.call_rpc(INSTALL_SNAPSHOT_RPC, args).await
  64. }
  65. }
  66. const GET: &str = "KVServer.Get";
  67. const PUT_APPEND: &str = "KVServer.PutAppend";
  68. #[async_trait]
  69. impl RemoteKvraft for RpcClient {
  70. async fn get(&self, args: GetArgs) -> std::io::Result<GetReply> {
  71. self.call_rpc(GET, args).await
  72. }
  73. async fn put_append(
  74. &self,
  75. args: PutAppendArgs,
  76. ) -> std::io::Result<PutAppendReply> {
  77. self.call_rpc(PUT_APPEND, args).await
  78. }
  79. }
  80. pub fn make_rpc_handler<Request, Reply, F>(
  81. func: F,
  82. ) -> impl Fn(RequestMessage) -> ReplyMessage
  83. where
  84. Request: DeserializeOwned,
  85. Reply: Serialize,
  86. F: 'static + Fn(Request) -> Reply,
  87. {
  88. move |request| {
  89. let reply = func(
  90. bincode::deserialize(&request)
  91. .expect("Deserialization should not fail"),
  92. );
  93. ReplyMessage::from(
  94. bincode::serialize(&reply).expect("Serialization should not fail"),
  95. )
  96. }
  97. }
  98. pub fn make_async_rpc_handler<'a, Request, Reply, F, Fut>(
  99. func: F,
  100. ) -> impl Fn(RequestMessage) -> BoxFuture<'a, ReplyMessage>
  101. where
  102. Request: DeserializeOwned + Send,
  103. Reply: Serialize,
  104. Fut: Future<Output = Reply> + Send + 'a,
  105. F: 'a + Send + Clone + FnOnce(Request) -> Fut,
  106. {
  107. move |request| {
  108. let func = func.clone();
  109. let fut = async move {
  110. let request = bincode::deserialize(&request)
  111. .expect("Deserialization should not fail");
  112. let reply = func(request).await;
  113. ReplyMessage::from(
  114. bincode::serialize(&reply)
  115. .expect("Serialization should not fail"),
  116. )
  117. };
  118. fut.boxed()
  119. }
  120. }
  121. pub fn register_server<
  122. Command: 'static + Clone + Serialize + DeserializeOwned + Default,
  123. R: 'static + AsRef<Raft<Command>> + Send + Sync + Clone,
  124. S: AsRef<str>,
  125. >(
  126. raft: R,
  127. name: S,
  128. network: &Mutex<Network>,
  129. ) -> std::io::Result<()> {
  130. let mut network = network.lock();
  131. let server_name = name.as_ref();
  132. let mut server = Server::make_server(server_name);
  133. let raft_clone = raft.clone();
  134. server.register_rpc_handler(
  135. REQUEST_VOTE_RPC.to_owned(),
  136. make_rpc_handler(move |args| {
  137. raft_clone.as_ref().process_request_vote(args)
  138. }),
  139. )?;
  140. let raft_clone = raft.clone();
  141. server.register_rpc_handler(
  142. APPEND_ENTRIES_RPC.to_owned(),
  143. make_rpc_handler(move |args| {
  144. raft_clone.as_ref().process_append_entries(args)
  145. }),
  146. )?;
  147. let raft_clone = raft;
  148. server.register_rpc_handler(
  149. INSTALL_SNAPSHOT_RPC.to_owned(),
  150. make_rpc_handler(move |args| {
  151. raft_clone.as_ref().process_install_snapshot(args)
  152. }),
  153. )?;
  154. network.add_server(server_name, server);
  155. Ok(())
  156. }
  157. pub fn register_kv_server<
  158. KV: 'static + AsRef<KVServer> + Send + Sync + Clone,
  159. S: AsRef<str>,
  160. >(
  161. kv: KV,
  162. name: S,
  163. network: &Mutex<Network>,
  164. ) -> std::io::Result<()> {
  165. let mut network = network.lock();
  166. let server_name = name.as_ref();
  167. let mut server = Server::make_server(server_name);
  168. let kv_clone = kv.clone();
  169. server.register_async_rpc_handler(
  170. GET.to_owned(),
  171. make_async_rpc_handler(move |args| async move {
  172. kv_clone.as_ref().get(args).await
  173. }),
  174. )?;
  175. server.register_async_rpc_handler(
  176. PUT_APPEND.to_owned(),
  177. make_async_rpc_handler(move |args| async move {
  178. kv.as_ref().put_append(args).await
  179. }),
  180. )?;
  181. network.add_server(server_name, server);
  182. Ok(())
  183. }
  184. #[cfg(test)]
  185. mod tests {
  186. use std::sync::Arc;
  187. use bytes::Bytes;
  188. use ruaft::{ApplyCommandMessage, RemoteRaft, Term};
  189. use super::*;
  190. use ruaft::utils::integration_test::{
  191. make_append_entries_args, make_request_vote_args,
  192. unpack_append_entries_reply, unpack_request_vote_reply,
  193. };
  194. struct DoNothingPersister;
  195. impl ruaft::Persister for DoNothingPersister {
  196. fn read_state(&self) -> Bytes {
  197. Bytes::new()
  198. }
  199. fn save_state(&self, _bytes: Bytes) {}
  200. fn state_size(&self) -> usize {
  201. 0
  202. }
  203. fn save_snapshot_and_state(&self, _: Bytes, _: &[u8]) {}
  204. }
  205. #[test]
  206. fn test_basic_message() -> std::io::Result<()> {
  207. test_utils::init_test_log!();
  208. let client = {
  209. let network = Network::run_daemon();
  210. let name = "test-basic-message";
  211. let client = network
  212. .lock()
  213. .make_client("test-basic-message", name.to_owned());
  214. let raft = Arc::new(Raft::new(
  215. vec![RpcClient(client)],
  216. 0,
  217. Arc::new(DoNothingPersister),
  218. |_: ApplyCommandMessage<i32>| {},
  219. None,
  220. crate::utils::NO_SNAPSHOT,
  221. ));
  222. register_server(raft, name, network.as_ref())?;
  223. let client = network
  224. .lock()
  225. .make_client("test-basic-message", name.to_owned());
  226. client
  227. };
  228. let rpc_client = RpcClient(client);
  229. let request = make_request_vote_args(Term(2021), 0, 0, Term(0));
  230. let response = futures::executor::block_on(
  231. (&rpc_client as &dyn RemoteRaft<i32>).request_vote(request),
  232. )?;
  233. let (_, vote_granted) = unpack_request_vote_reply(response);
  234. assert!(vote_granted);
  235. let request =
  236. make_append_entries_args::<i32>(Term(2021), 0, 0, Term(0), 0);
  237. let response =
  238. futures::executor::block_on(rpc_client.append_entries(request))?;
  239. let (Term(term), success) = unpack_append_entries_reply(response);
  240. assert_eq!(2021, term);
  241. assert!(success);
  242. Ok(())
  243. }
  244. }