rpcs.rs 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. use std::future::Future;
  2. use async_trait::async_trait;
  3. use futures_util::future::BoxFuture;
  4. use futures_util::FutureExt;
  5. use labrpc::{Client, Network, ReplyMessage, RequestMessage, Server};
  6. use parking_lot::Mutex;
  7. use serde::de::DeserializeOwned;
  8. use serde::Serialize;
  9. use kvraft::{
  10. CommitSentinelArgs, CommitSentinelReply, GetArgs, GetReply, KVServer,
  11. PutAppendArgs, PutAppendReply, RemoteKvraft,
  12. };
  13. use ruaft::{
  14. AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
  15. InstallSnapshotReply, Raft, ReplicableCommand, RequestVoteArgs,
  16. RequestVoteReply,
  17. };
  18. const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
  19. const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
  20. const INSTALL_SNAPSHOT_RPC: &str = "Raft.InstallSnapshot";
  21. pub struct RpcClient(Client);
  22. impl RpcClient {
  23. pub fn new(client: Client) -> Self {
  24. Self(client)
  25. }
  26. async fn call_rpc<Method, Request, Reply>(
  27. &self,
  28. method: Method,
  29. request: Request,
  30. ) -> std::io::Result<Reply>
  31. where
  32. Method: AsRef<str>,
  33. Request: Serialize,
  34. Reply: DeserializeOwned,
  35. {
  36. let data = RequestMessage::from(
  37. bincode::serialize(&request)
  38. .expect("Serialization of requests should not fail"),
  39. );
  40. let reply = self.0.call_rpc(method.as_ref().to_owned(), data).await?;
  41. Ok(bincode::deserialize(reply.as_ref())
  42. .expect("Deserialization of reply should not fail"))
  43. }
  44. }
  45. #[async_trait]
  46. impl<Command: ReplicableCommand> ruaft::RemoteRaft<Command> for RpcClient {
  47. async fn request_vote(
  48. &self,
  49. args: RequestVoteArgs,
  50. ) -> std::io::Result<RequestVoteReply> {
  51. self.call_rpc(REQUEST_VOTE_RPC, args).await
  52. }
  53. async fn append_entries(
  54. &self,
  55. args: AppendEntriesArgs<Command>,
  56. ) -> std::io::Result<AppendEntriesReply> {
  57. self.call_rpc(APPEND_ENTRIES_RPC, args).await
  58. }
  59. async fn install_snapshot(
  60. &self,
  61. args: InstallSnapshotArgs,
  62. ) -> std::io::Result<InstallSnapshotReply> {
  63. self.call_rpc(INSTALL_SNAPSHOT_RPC, args).await
  64. }
  65. }
  66. const GET: &str = "KVServer.Get";
  67. const PUT_APPEND: &str = "KVServer.PutAppend";
  68. const COMMIT_SENTINEL: &str = "KVServer.CommitSentinel";
  69. #[async_trait]
  70. impl RemoteKvraft for RpcClient {
  71. async fn get(&self, args: GetArgs) -> std::io::Result<GetReply> {
  72. self.call_rpc(GET, args).await
  73. }
  74. async fn put_append(
  75. &self,
  76. args: PutAppendArgs,
  77. ) -> std::io::Result<PutAppendReply> {
  78. self.call_rpc(PUT_APPEND, args).await
  79. }
  80. async fn commit_sentinel(
  81. &self,
  82. args: CommitSentinelArgs,
  83. ) -> std::io::Result<CommitSentinelReply> {
  84. self.call_rpc(COMMIT_SENTINEL, args).await
  85. }
  86. }
  87. pub fn make_rpc_handler<Request, Reply, F>(
  88. func: F,
  89. ) -> impl Fn(RequestMessage) -> ReplyMessage
  90. where
  91. Request: DeserializeOwned,
  92. Reply: Serialize,
  93. F: 'static + Fn(Request) -> Reply,
  94. {
  95. move |request| {
  96. let reply = func(
  97. bincode::deserialize(&request)
  98. .expect("Deserialization should not fail"),
  99. );
  100. ReplyMessage::from(
  101. bincode::serialize(&reply).expect("Serialization should not fail"),
  102. )
  103. }
  104. }
  105. pub fn make_async_rpc_handler<'a, Request, Reply, F, Fut>(
  106. func: F,
  107. ) -> impl Fn(RequestMessage) -> BoxFuture<'a, ReplyMessage>
  108. where
  109. Request: DeserializeOwned + Send,
  110. Reply: Serialize,
  111. Fut: Future<Output = Reply> + Send + 'a,
  112. F: 'a + Send + Clone + FnOnce(Request) -> Fut,
  113. {
  114. move |request| {
  115. let func = func.clone();
  116. let fut = async move {
  117. let request = bincode::deserialize(&request)
  118. .expect("Deserialization should not fail");
  119. let reply = func(request).await;
  120. ReplyMessage::from(
  121. bincode::serialize(&reply)
  122. .expect("Serialization should not fail"),
  123. )
  124. };
  125. fut.boxed()
  126. }
  127. }
  128. pub fn register_server<
  129. Command: 'static + Clone + Serialize + DeserializeOwned,
  130. R: 'static + AsRef<Raft<Command>> + Send + Sync + Clone,
  131. S: AsRef<str>,
  132. >(
  133. raft: R,
  134. name: S,
  135. network: &Mutex<Network>,
  136. ) -> std::io::Result<()> {
  137. let mut network = network.lock();
  138. let server_name = name.as_ref();
  139. let mut server = Server::make_server(server_name);
  140. server.register_rpc_handler(REQUEST_VOTE_RPC.to_owned(), {
  141. let raft = raft.clone();
  142. make_rpc_handler(move |args| raft.as_ref().process_request_vote(args))
  143. })?;
  144. server.register_rpc_handler(APPEND_ENTRIES_RPC.to_owned(), {
  145. let raft = raft.clone();
  146. make_rpc_handler(move |args| raft.as_ref().process_append_entries(args))
  147. })?;
  148. server.register_rpc_handler(
  149. INSTALL_SNAPSHOT_RPC.to_owned(),
  150. make_rpc_handler(move |args| {
  151. raft.as_ref().process_install_snapshot(args)
  152. }),
  153. )?;
  154. network.add_server(server_name, server);
  155. Ok(())
  156. }
  157. pub fn register_kv_server<
  158. KV: 'static + AsRef<KVServer> + Send + Sync + Clone,
  159. S: AsRef<str>,
  160. >(
  161. kv: KV,
  162. name: S,
  163. network: &Mutex<Network>,
  164. ) -> std::io::Result<()> {
  165. let mut network = network.lock();
  166. let server_name = name.as_ref();
  167. let mut server = Server::make_server(server_name);
  168. server.register_async_rpc_handler(GET.to_owned(), {
  169. let kv = kv.clone();
  170. make_async_rpc_handler(move |args| async move {
  171. kv.as_ref().get(args).await
  172. })
  173. })?;
  174. server.register_async_rpc_handler(PUT_APPEND.to_owned(), {
  175. let kv = kv.clone();
  176. make_async_rpc_handler(move |args| async move {
  177. kv.as_ref().put_append(args).await
  178. })
  179. })?;
  180. server.register_async_rpc_handler(
  181. COMMIT_SENTINEL.to_owned(),
  182. make_async_rpc_handler(move |args| async move {
  183. kv.as_ref().commit_sentinel(args).await
  184. }),
  185. )?;
  186. network.add_server(server_name, server);
  187. Ok(())
  188. }
  189. #[cfg(test)]
  190. mod tests {
  191. use std::sync::Arc;
  192. use bytes::Bytes;
  193. use ruaft::utils::integration_test::{
  194. make_append_entries_args, make_request_vote_args,
  195. unpack_append_entries_reply, unpack_request_vote_reply,
  196. };
  197. use ruaft::{ApplyCommandMessage, RemoteRaft, Term};
  198. use super::*;
  199. struct DoNothingPersister;
  200. impl ruaft::Persister for DoNothingPersister {
  201. fn read_state(&self) -> Bytes {
  202. Bytes::new()
  203. }
  204. fn save_state(&self, _bytes: Bytes) {}
  205. fn state_size(&self) -> usize {
  206. 0
  207. }
  208. fn save_snapshot_and_state(&self, _: Bytes, _: &[u8]) {}
  209. }
  210. #[test]
  211. fn test_basic_message() -> std::io::Result<()> {
  212. test_utils::init_test_log!();
  213. let client = {
  214. let network = Network::run_daemon();
  215. let name = "test-basic-message";
  216. let client = network
  217. .lock()
  218. .make_client("test-basic-message", name.to_owned());
  219. let raft = Arc::new(Raft::new(
  220. vec![RpcClient(client)],
  221. 0,
  222. Arc::new(DoNothingPersister),
  223. |_: ApplyCommandMessage<i32>| {},
  224. None,
  225. crate::utils::NO_SNAPSHOT,
  226. ));
  227. register_server(raft, name, network.as_ref())?;
  228. let client = network
  229. .lock()
  230. .make_client("test-basic-message", name.to_owned());
  231. client
  232. };
  233. let rpc_client = RpcClient(client);
  234. let request = make_request_vote_args(Term(2021), 0, 0, Term(0));
  235. let response = futures::executor::block_on(
  236. (&rpc_client as &dyn RemoteRaft<i32>).request_vote(request),
  237. )?;
  238. let (_, vote_granted) = unpack_request_vote_reply(response);
  239. assert!(vote_granted);
  240. let request =
  241. make_append_entries_args::<i32>(Term(2021), 0, 0, Term(0), 0);
  242. let response =
  243. futures::executor::block_on(rpc_client.append_entries(request))?;
  244. let (Term(term), success) = unpack_append_entries_reply(response);
  245. assert_eq!(2021, term);
  246. assert!(success);
  247. Ok(())
  248. }
  249. }