mod.rs 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. use std::future::Future;
  2. use std::sync::atomic::{AtomicUsize, Ordering};
  3. use std::sync::Arc;
  4. use std::time::{Duration, Instant};
  5. use async_trait::async_trait;
  6. use crossbeam_channel::{Receiver, Sender};
  7. use once_cell::sync::OnceCell;
  8. use kvraft::{
  9. GetArgs, KVServer, PutAppendArgs, PutAppendEnum, UniqueId, UniqueKVOp,
  10. };
  11. use ruaft::{
  12. AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
  13. InstallSnapshotReply, Raft, RemoteRaft, ReplicableCommand, RequestVoteArgs,
  14. RequestVoteReply,
  15. };
  16. use crate::Persister;
  17. type RaftId = usize;
  18. pub struct EventHandle {
  19. pub from: RaftId,
  20. pub to: RaftId,
  21. sender: futures_channel::oneshot::Sender<std::io::Result<()>>,
  22. }
  23. struct EventStub {
  24. receiver: futures_channel::oneshot::Receiver<std::io::Result<()>>,
  25. }
  26. fn create_event_pair(from: RaftId, to: RaftId) -> (EventHandle, EventStub) {
  27. let (sender, receiver) = futures_channel::oneshot::channel();
  28. (EventHandle { from, to, sender }, EventStub { receiver })
  29. }
  30. impl EventHandle {
  31. pub fn unblock(self) {
  32. self.sender.send(Ok(())).unwrap();
  33. }
  34. pub fn reply_error(self, e: std::io::Error) {
  35. self.sender.send(Err(e)).unwrap();
  36. }
  37. pub fn reply_interrupted_error(self) {
  38. self.reply_error(std::io::Error::from(std::io::ErrorKind::Interrupted))
  39. }
  40. }
  41. impl EventStub {
  42. pub async fn wait(self) -> std::io::Result<()> {
  43. self.receiver.await.unwrap_or(Ok(()))
  44. }
  45. }
  46. pub enum RaftRpcEvent<T> {
  47. RequestVoteRequest(RequestVoteArgs),
  48. RequestVoteResponse(RequestVoteArgs, RequestVoteReply),
  49. AppendEntriesRequest(AppendEntriesArgs<T>),
  50. AppendEntriesResponse(AppendEntriesArgs<T>, AppendEntriesReply),
  51. InstallSnapshotRequest(InstallSnapshotArgs),
  52. InstallSnapshotResponse(InstallSnapshotArgs, InstallSnapshotReply),
  53. }
  54. struct InterceptingRpcClient<T> {
  55. from: RaftId,
  56. to: RaftId,
  57. target: OnceCell<Raft<T>>,
  58. event_queue: Sender<(RaftRpcEvent<T>, EventHandle)>,
  59. }
  60. impl<T> InterceptingRpcClient<T> {
  61. async fn intercept(&self, event: RaftRpcEvent<T>) -> std::io::Result<()> {
  62. let (handle, stub) = create_event_pair(self.from, self.to);
  63. let _ = self.event_queue.send((event, handle));
  64. stub.wait().await
  65. }
  66. pub fn set_raft(&self, target: Raft<T>) {
  67. self.target
  68. .set(target)
  69. .map_err(|_| ())
  70. .expect("Raft should only be set once");
  71. }
  72. }
  73. #[async_trait]
  74. impl<T: ReplicableCommand> RemoteRaft<T> for &InterceptingRpcClient<T> {
  75. async fn request_vote(
  76. &self,
  77. args: RequestVoteArgs,
  78. ) -> std::io::Result<RequestVoteReply> {
  79. self.intercept(RaftRpcEvent::RequestVoteRequest(args.clone()))
  80. .await?;
  81. let reply = self.target.wait().process_request_vote(args.clone());
  82. self.intercept(RaftRpcEvent::RequestVoteResponse(args, reply.clone()))
  83. .await
  84. .map(|_| reply)
  85. }
  86. async fn append_entries(
  87. &self,
  88. args: AppendEntriesArgs<T>,
  89. ) -> std::io::Result<AppendEntriesReply> {
  90. let args_clone = args.clone();
  91. self.intercept(RaftRpcEvent::AppendEntriesRequest(args_clone))
  92. .await?;
  93. let reply = self.target.wait().process_append_entries(args.clone());
  94. self.intercept(RaftRpcEvent::AppendEntriesResponse(args, reply.clone()))
  95. .await
  96. .map(|_| reply)
  97. }
  98. async fn install_snapshot(
  99. &self,
  100. args: InstallSnapshotArgs,
  101. ) -> std::io::Result<InstallSnapshotReply> {
  102. self.intercept(RaftRpcEvent::InstallSnapshotRequest(args.clone()))
  103. .await?;
  104. let reply = self.target.wait().process_install_snapshot(args.clone());
  105. self.intercept(RaftRpcEvent::InstallSnapshotResponse(
  106. args,
  107. reply.clone(),
  108. ))
  109. .await
  110. .map(|_| reply)
  111. }
  112. }
  113. pub struct EventQueue<T> {
  114. pub receiver: Receiver<(RaftRpcEvent<T>, EventHandle)>,
  115. }
  116. fn make_grid_clients<T>(
  117. server_count: usize,
  118. ) -> (EventQueue<T>, Vec<Vec<InterceptingRpcClient<T>>>) {
  119. let (sender, receiver) = crossbeam_channel::unbounded();
  120. let mut all_clients = vec![];
  121. for from in 0..server_count {
  122. let mut clients = vec![];
  123. for to in 0..server_count {
  124. let interceptor = InterceptingRpcClient {
  125. from,
  126. to,
  127. target: Default::default(),
  128. event_queue: sender.clone(),
  129. };
  130. clients.push(interceptor);
  131. }
  132. all_clients.push(clients);
  133. }
  134. (EventQueue { receiver }, all_clients)
  135. }
  136. pub struct Config {
  137. pub event_queue: EventQueue<UniqueKVOp>,
  138. pub kv_servers: Vec<Arc<KVServer>>,
  139. seq: AtomicUsize,
  140. }
  141. impl Config {
  142. pub fn find_leader(&self) -> Option<&KVServer> {
  143. let start = Instant::now();
  144. while start.elapsed() < Duration::from_secs(1) {
  145. if let Some(kv_server) = self
  146. .kv_servers
  147. .iter()
  148. .find(|kv_server| kv_server.raft().get_state().1)
  149. {
  150. return Some(kv_server.as_ref());
  151. }
  152. }
  153. None
  154. }
  155. pub async fn put_to_kv(
  156. &self,
  157. kv_server: &KVServer,
  158. key: String,
  159. value: String,
  160. ) -> Result<(), ()> {
  161. let result = kv_server
  162. .put_append(PutAppendArgs {
  163. key,
  164. value,
  165. op: PutAppendEnum::Put,
  166. unique_id: UniqueId {
  167. clerk_id: 1,
  168. sequence_id: self.seq.fetch_add(1, Ordering::Relaxed)
  169. as u64,
  170. },
  171. })
  172. .await;
  173. result.result.map_err(|_| ())
  174. }
  175. pub async fn put(&self, key: String, value: String) -> Result<(), ()> {
  176. let kv_server = self.find_leader().unwrap();
  177. self.put_to_kv(kv_server, key, value).await
  178. }
  179. pub fn spawn_put_to_kv(
  180. self: &Arc<Self>,
  181. index: usize,
  182. key: String,
  183. value: String,
  184. ) -> impl Future<Output = Result<(), ()>> {
  185. let this = self.clone();
  186. async move {
  187. this.put_to_kv(this.kv_servers[index].as_ref(), key, value)
  188. .await
  189. }
  190. }
  191. pub fn spawn_put(
  192. self: &Arc<Self>,
  193. key: String,
  194. value: String,
  195. ) -> impl Future<Output = Result<(), ()>> {
  196. let this = self.clone();
  197. async move { this.put(key, value).await }
  198. }
  199. pub async fn get_from_kv(
  200. &self,
  201. kv_server: &KVServer,
  202. key: String,
  203. ) -> Result<String, ()> {
  204. let result = kv_server.get(GetArgs { key }).await;
  205. result.result.map(|v| v.unwrap_or_default()).map_err(|_| ())
  206. }
  207. pub fn spawn_get_from_kv(
  208. self: &Arc<Self>,
  209. index: usize,
  210. key: String,
  211. ) -> impl Future<Output = Result<String, ()>> {
  212. let this = self.clone();
  213. async move { this.get_from_kv(this.kv_servers[index].as_ref(), key).await }
  214. }
  215. pub async fn get(&self, key: String) -> Result<String, ()> {
  216. let kv_server = self.find_leader().unwrap();
  217. self.get_from_kv(kv_server, key).await
  218. }
  219. pub fn spawn_get(
  220. self: &Arc<Self>,
  221. key: String,
  222. ) -> impl Future<Output = Result<String, ()>> {
  223. let this = self.clone();
  224. async move { this.get(key).await }
  225. }
  226. }
  227. pub fn make_config(server_count: usize, max_state: Option<usize>) -> Config {
  228. let (event_queue, clients) = make_grid_clients(server_count);
  229. let persister = Arc::new(Persister::new());
  230. let mut kv_servers = vec![];
  231. let clients: Vec<Vec<&'static InterceptingRpcClient<UniqueKVOp>>> = clients
  232. .into_iter()
  233. .map(|v| {
  234. v.into_iter()
  235. .map(|c| {
  236. let c = Box::leak(Box::new(c));
  237. &*c
  238. })
  239. .collect()
  240. })
  241. .collect();
  242. for (index, client_vec) in clients.iter().enumerate() {
  243. let kv_server = KVServer::new(
  244. client_vec.to_vec(),
  245. index,
  246. persister.clone(),
  247. max_state,
  248. );
  249. kv_servers.push(kv_server);
  250. }
  251. for clients in clients.iter() {
  252. for j in 0..server_count {
  253. clients[j].set_raft(kv_servers[j].raft().clone());
  254. }
  255. }
  256. Config {
  257. event_queue,
  258. kv_servers,
  259. seq: AtomicUsize::new(0),
  260. }
  261. }