rpcs.rs 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. use std::future::Future;
  2. use async_trait::async_trait;
  3. use futures_util::future::BoxFuture;
  4. use futures_util::FutureExt;
  5. use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
  6. use parking_lot::Mutex;
  7. use serde::de::DeserializeOwned;
  8. use serde::Serialize;
  9. use kvraft::{
  10. CommitSentinelArgs, CommitSentinelReply, GetArgs, GetReply, KVServer,
  11. PutAppendArgs, PutAppendReply, RemoteKvraft,
  12. };
  13. use ruaft::{
  14. AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
  15. InstallSnapshotReply, Raft, RequestVoteArgs, RequestVoteReply,
  16. };
  17. const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
  18. const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
  19. const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
  20. pub struct RpcClient(Client);
  21. impl RpcClient {
  22. pub fn new(client: Client) -> Self {
  23. Self(client)
  24. }
  25. async fn call_rpc<Method, Request, Reply>(
  26. &self,
  27. method: Method,
  28. request: Request,
  29. ) -> std::io::Result<Reply>
  30. where
  31. Method: AsRef<str>,
  32. Request: Serialize,
  33. Reply: DeserializeOwned,
  34. {
  35. let data = RequestMessage::from(
  36. bincode::serialize(&request)
  37. .expect("Serialization of requests should not fail"),
  38. );
  39. let reply = self.0.call_rpc(method.as_ref().to_owned(), data).await?;
  40. Ok(bincode::deserialize(reply.as_ref())
  41. .expect("Deserialization of reply should not fail"))
  42. }
  43. }
  44. #[async_trait]
  45. impl<Command: 'static + Send + Serialize> ruaft::RemoteRaft<Command>
  46. for RpcClient
  47. {
  48. async fn request_vote(
  49. &self,
  50. args: RequestVoteArgs,
  51. ) -> std::io::Result<RequestVoteReply> {
  52. self.call_rpc(REQUEST_VOTE_RPC, args).await
  53. }
  54. async fn append_entries(
  55. &self,
  56. args: AppendEntriesArgs<Command>,
  57. ) -> std::io::Result<AppendEntriesReply> {
  58. self.call_rpc(APPEND_ENTRIES_RPC, args).await
  59. }
  60. async fn install_snapshot(
  61. &self,
  62. args: InstallSnapshotArgs,
  63. ) -> std::io::Result<InstallSnapshotReply> {
  64. self.call_rpc(INSTALL_SNAPSHOT_RPC, args).await
  65. }
  66. }
  67. const GET: &str = "KVServer.Get";
  68. const PUT_APPEND: &str = "KVServer.PutAppend";
  69. const COMMIT_SENTINEL: &str = "KVServer.CommitSentinel";
  70. #[async_trait]
  71. impl RemoteKvraft for RpcClient {
  72. async fn get(&self, args: GetArgs) -> std::io::Result<GetReply> {
  73. self.call_rpc(GET, args).await
  74. }
  75. async fn put_append(
  76. &self,
  77. args: PutAppendArgs,
  78. ) -> std::io::Result<PutAppendReply> {
  79. self.call_rpc(PUT_APPEND, args).await
  80. }
  81. async fn commit_sentinel(
  82. &self,
  83. args: CommitSentinelArgs,
  84. ) -> std::io::Result<CommitSentinelReply> {
  85. self.call_rpc(COMMIT_SENTINEL, args).await
  86. }
  87. }
  88. pub fn make_rpc_handler<Request, Reply, F>(
  89. func: F,
  90. ) -> impl Fn(RequestMessage) -> ReplyMessage
  91. where
  92. Request: DeserializeOwned,
  93. Reply: Serialize,
  94. F: 'static + Fn(Request) -> Reply,
  95. {
  96. move |request| {
  97. let reply = func(
  98. bincode::deserialize(&request)
  99. .expect("Deserialization should not fail"),
  100. );
  101. ReplyMessage::from(
  102. bincode::serialize(&reply).expect("Serialization should not fail"),
  103. )
  104. }
  105. }
  106. pub fn make_async_rpc_handler<'a, Request, Reply, F, Fut>(
  107. func: F,
  108. ) -> impl Fn(RequestMessage) -> BoxFuture<'a, ReplyMessage>
  109. where
  110. Request: DeserializeOwned + Send,
  111. Reply: Serialize,
  112. Fut: Future<Output = Reply> + Send + 'a,
  113. F: 'a + Send + Clone + FnOnce(Request) -> Fut,
  114. {
  115. move |request| {
  116. let func = func.clone();
  117. let fut = async move {
  118. let request = bincode::deserialize(&request)
  119. .expect("Deserialization should not fail");
  120. let reply = func(request).await;
  121. ReplyMessage::from(
  122. bincode::serialize(&reply)
  123. .expect("Serialization should not fail"),
  124. )
  125. };
  126. fut.boxed()
  127. }
  128. }
  129. pub fn register_server<
  130. Command: 'static + Clone + Serialize + DeserializeOwned,
  131. R: 'static + AsRef<Raft<Command>> + Send + Sync + Clone,
  132. S: AsRef<str>,
  133. >(
  134. raft: R,
  135. name: S,
  136. network: &Mutex<Network>,
  137. ) -> std::io::Result<()> {
  138. let mut network = network.lock();
  139. let server_name = name.as_ref();
  140. let mut server = Server::make_server(server_name);
  141. let raft_clone = raft.clone();
  142. server.register_rpc_handler(
  143. REQUEST_VOTE_RPC.to_owned(),
  144. make_rpc_handler(move |args| {
  145. raft_clone.as_ref().process_request_vote(args)
  146. }),
  147. )?;
  148. let raft_clone = raft.clone();
  149. server.register_rpc_handler(
  150. APPEND_ENTRIES_RPC.to_owned(),
  151. make_rpc_handler(move |args| {
  152. raft_clone.as_ref().process_append_entries(args)
  153. }),
  154. )?;
  155. let raft_clone = raft;
  156. server.register_rpc_handler(
  157. INSTALL_SNAPSHOT_RPC.to_owned(),
  158. make_rpc_handler(move |args| {
  159. raft_clone.as_ref().process_install_snapshot(args)
  160. }),
  161. )?;
  162. network.add_server(server_name, server);
  163. Ok(())
  164. }
  165. pub fn register_kv_server<
  166. KV: 'static + AsRef<KVServer> + Send + Sync + Clone,
  167. S: AsRef<str>,
  168. >(
  169. kv: KV,
  170. name: S,
  171. network: &Mutex<Network>,
  172. ) -> std::io::Result<()> {
  173. let mut network = network.lock();
  174. let server_name = name.as_ref();
  175. let mut server = Server::make_server(server_name);
  176. let kv_clone = kv.clone();
  177. server.register_async_rpc_handler(
  178. GET.to_owned(),
  179. make_async_rpc_handler(move |args| async move {
  180. kv_clone.as_ref().get(args).await
  181. }),
  182. )?;
  183. let kv_clone = kv.clone();
  184. server.register_async_rpc_handler(
  185. PUT_APPEND.to_owned(),
  186. make_async_rpc_handler(move |args| async move {
  187. kv_clone.as_ref().put_append(args).await
  188. }),
  189. )?;
  190. server.register_async_rpc_handler(
  191. COMMIT_SENTINEL.to_owned(),
  192. make_async_rpc_handler(move |args| async move {
  193. kv.as_ref().commit_sentinel(args).await
  194. }),
  195. )?;
  196. network.add_server(server_name, server);
  197. Ok(())
  198. }
  #[cfg(test)]
  mod tests {
      use std::sync::Arc;

      use bytes::Bytes;
      use ruaft::utils::integration_test::{
          make_append_entries_args, make_request_vote_args,
          unpack_append_entries_reply, unpack_request_vote_reply,
      };
      use ruaft::{ApplyCommandMessage, RemoteRaft, Term};

      use super::*;

      /// A persister that keeps nothing: it reports empty state of size zero
      /// and discards every save.
      struct DoNothingPersister;

      impl ruaft::Persister for DoNothingPersister {
          fn read_state(&self) -> Bytes {
              Bytes::new()
          }

          fn save_state(&self, _bytes: Bytes) {}

          fn state_size(&self) -> usize {
              0
          }

          fn save_snapshot_and_state(&self, _: Bytes, _: &[u8]) {}
      }

      /// Sends a `RequestVote` and an `AppendEntries` RPC through a labrpc
      /// network to a Raft instance registered with `register_server`, and
      /// checks both are accepted.
      #[test]
      fn test_basic_message() -> std::io::Result<()> {
          test_utils::init_test_log!();

          let client = {
              let network = Network::run_daemon();
              let name = "test-basic-message";

              let peer_client = network
                  .lock()
                  .make_client("test-basic-message", name.to_owned());
              let raft = Arc::new(Raft::new(
                  vec![RpcClient::new(peer_client)],
                  0,
                  Arc::new(DoNothingPersister),
                  |_: ApplyCommandMessage<i32>| {},
                  None,
                  crate::utils::NO_SNAPSHOT,
              ));
              register_server(raft, name, network.as_ref())?;

              // Keep this `let`: returning `network.lock().make_client(..)` as
              // the block's tail expression would (in pre-2024 editions) keep
              // the lock-guard temporary alive past `network` itself.
              let test_client = network
                  .lock()
                  .make_client("test-basic-message", name.to_owned());
              test_client
          };
          let rpc_client = RpcClient::new(client);

          // `RequestVoteArgs` does not mention the command type, so select the
          // `RemoteRaft<i32>` impl explicitly through a trait object.
          let vote_request = make_request_vote_args(Term(2021), 0, 0, Term(0));
          let vote_reply = futures::executor::block_on(
              (&rpc_client as &dyn RemoteRaft<i32>).request_vote(vote_request),
          )?;
          let (_, vote_granted) = unpack_request_vote_reply(vote_reply);
          assert!(vote_granted);

          let append_request =
              make_append_entries_args::<i32>(Term(2021), 0, 0, Term(0), 0);
          let append_reply = futures::executor::block_on(
              rpc_client.append_entries(append_request),
          )?;
          let (Term(term), success) = unpack_append_entries_reply(append_reply);
          assert_eq!(2021, term);
          assert!(success);
          Ok(())
      }
  }