lib.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. #![allow(unused)]
  2. extern crate bincode;
  3. extern crate futures;
  4. extern crate labrpc;
  5. extern crate rand;
  6. #[macro_use]
  7. extern crate serde_derive;
  8. extern crate tokio;
  9. use std::future::Future;
  10. use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
  11. use std::sync::Arc;
  12. use std::time::Duration;
  13. use futures::FutureExt;
  14. use parking_lot::{Condvar, Mutex};
  15. use rand::{thread_rng, Rng};
  16. use crate::rpcs::RpcClient;
  17. use std::cell::RefCell;
  18. pub mod rpcs;
  19. #[derive(Eq, PartialEq)]
  20. enum State {
  21. Follower,
  22. Candidate,
  23. // TODO: add PreVote
  24. Leader,
  25. }
  26. #[derive(
  27. Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
  28. )]
  29. struct Term(usize);
  30. #[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
  31. struct Peer(usize);
  32. #[derive(Clone, Copy, Debug, Serialize, Deserialize)]
  33. struct Command(usize);
  34. // TODO: remove all of the defaults.
  35. impl Default for State {
  36. fn default() -> Self {
  37. Self::Leader
  38. }
  39. }
  40. impl Default for Term {
  41. fn default() -> Self {
  42. Self(0)
  43. }
  44. }
  45. impl Default for Peer {
  46. fn default() -> Self {
  47. Self(0)
  48. }
  49. }
  50. #[derive(Clone, Copy, Serialize, Deserialize)]
  51. struct LogEntry {
  52. term: Term,
  53. index: usize,
  54. // TODO: Allow sending of arbitrary information.
  55. command: Command,
  56. }
  57. #[derive(Default)]
  58. struct RaftState {
  59. current_term: Term,
  60. voted_for: Option<Peer>,
  61. log: Vec<LogEntry>,
  62. commit_index: usize,
  63. last_applied: usize,
  64. next_index: Vec<usize>,
  65. match_index: Vec<usize>,
  66. current_step: Vec<i64>,
  67. state: State,
  68. leader_id: Peer,
  69. // Current election cancel token, might be None if no election is running.
  70. election_cancel_token: Option<futures::channel::oneshot::Sender<Term>>,
  71. // Timer will be removed upon shutdown or elected.
  72. election_timer: Option<tokio::time::Delay>,
  73. }
  74. #[derive(Default)]
  75. struct Raft {
  76. inner_state: Arc<Mutex<RaftState>>,
  77. peers: Vec<RpcClient>,
  78. me: Peer,
  79. // new_log_entry: Sender<usize>,
  80. // new_log_entry: Receiver<usize>,
  81. // apply_command_cond: Condvar
  82. keep_running: AtomicBool,
  83. // applyCh: Sender<ApplyMsg>
  84. }
  85. #[derive(Serialize, Deserialize)]
  86. struct RequestVoteArgs {
  87. term: Term,
  88. candidate_id: Peer,
  89. last_log_index: usize,
  90. last_log_term: Term,
  91. }
  92. #[derive(Serialize, Deserialize)]
  93. struct RequestVoteReply {
  94. term: Term,
  95. vote_granted: bool,
  96. }
  97. #[derive(Serialize, Deserialize)]
  98. struct AppendEntriesArgs {
  99. term: Term,
  100. leader_id: Peer,
  101. prev_log_index: usize,
  102. prev_log_term: Term,
  103. entries: Vec<LogEntry>,
  104. leader_commit: usize,
  105. }
  106. #[derive(Serialize, Deserialize)]
  107. struct AppendEntriesReply {
  108. term: Term,
  109. success: bool,
  110. }
  111. impl Raft {
  112. pub fn new() -> Self {
  113. let mut raft = Self {
  114. ..Default::default()
  115. };
  116. raft.inner_state.lock().log.push(LogEntry {
  117. term: Default::default(),
  118. index: 0,
  119. command: Command(0),
  120. });
  121. raft
  122. }
  123. pub(crate) fn process_request_vote(
  124. &self,
  125. args: RequestVoteArgs,
  126. ) -> RequestVoteReply {
  127. let mut rf = self.inner_state.lock();
  128. let term = rf.current_term;
  129. if args.term < term {
  130. return RequestVoteReply {
  131. term,
  132. vote_granted: false,
  133. };
  134. } else if args.term > term {
  135. rf.current_term = args.term;
  136. rf.voted_for = None;
  137. rf.state = State::Follower;
  138. rf.reset_election_timer();
  139. rf.stop_current_election();
  140. rf.persist();
  141. }
  142. let voted_for = rf.voted_for;
  143. let last_log_index = rf.log.len() - 1;
  144. let last_log_term = rf.log.last().unwrap().term;
  145. if (voted_for.is_none() || voted_for == Some(args.candidate_id))
  146. && (args.last_log_term > last_log_term
  147. || (args.last_log_term == last_log_term
  148. && args.last_log_index >= last_log_index))
  149. {
  150. rf.voted_for = Some(args.candidate_id);
  151. rf.reset_election_timer();
  152. // No need to stop the election. We are not a candidate.
  153. rf.persist();
  154. RequestVoteReply {
  155. term: args.term,
  156. vote_granted: true,
  157. }
  158. } else {
  159. RequestVoteReply {
  160. term: args.term,
  161. vote_granted: false,
  162. }
  163. }
  164. }
  165. pub(crate) fn process_append_entries(
  166. &self,
  167. args: AppendEntriesArgs,
  168. ) -> AppendEntriesReply {
  169. let mut rf = self.inner_state.lock();
  170. if rf.current_term > args.term {
  171. return AppendEntriesReply {
  172. term: rf.current_term,
  173. success: false,
  174. };
  175. }
  176. if rf.current_term < args.term {
  177. rf.current_term = args.term;
  178. rf.voted_for = None;
  179. }
  180. rf.state = State::Follower;
  181. rf.reset_election_timer();
  182. rf.stop_current_election();
  183. rf.leader_id = args.leader_id;
  184. if rf.log.len() <= args.prev_log_index
  185. || rf.log[args.prev_log_index].term != args.term
  186. {
  187. return AppendEntriesReply {
  188. term: args.term,
  189. success: false,
  190. };
  191. }
  192. for (i, entry) in args.entries.iter().enumerate() {
  193. let index = i + args.prev_log_index + 1;
  194. if rf.log.len() > index {
  195. if rf.log[index].term != entry.term {
  196. rf.log.truncate(index);
  197. rf.log.push(entry.clone());
  198. }
  199. } else {
  200. rf.log.push(entry.clone());
  201. }
  202. }
  203. if args.leader_commit > rf.commit_index {
  204. rf.commit_index = if args.leader_commit < rf.log.len() {
  205. args.leader_commit
  206. } else {
  207. rf.log.len() - 1
  208. };
  209. // TODO: apply commands.
  210. }
  211. AppendEntriesReply {
  212. term: args.term,
  213. success: true,
  214. }
  215. }
  216. async fn retry_rpc<Func, Fut, T>(
  217. max_retry: usize,
  218. mut task_gen: Func,
  219. ) -> std::io::Result<T>
  220. where
  221. Fut: Future<Output = std::io::Result<T>> + Send + 'static,
  222. Func: FnMut(usize) -> Fut,
  223. {
  224. for i in 0..max_retry {
  225. if let Ok(reply) = task_gen(i).await {
  226. return Ok(reply);
  227. }
  228. tokio::time::delay_for(Duration::from_millis((1 << i) * 10)).await;
  229. }
  230. Err(std::io::Error::new(
  231. std::io::ErrorKind::TimedOut,
  232. format!("Timed out after {} retries", max_retry),
  233. ))
  234. }
  235. fn run_election(&self) {
  236. let (term, last_log_index, last_log_term, cancel_token) = {
  237. let mut rf = self.inner_state.lock();
  238. let (tx, rx) = futures::channel::oneshot::channel();
  239. rf.current_term.0 += 1;
  240. rf.voted_for = Some(self.me);
  241. rf.state = State::Candidate;
  242. rf.reset_election_timer();
  243. rf.stop_current_election();
  244. rf.election_cancel_token.replace(tx);
  245. rf.persist();
  246. (
  247. rf.current_term,
  248. rf.log.len() - 1,
  249. rf.log.last().unwrap().term,
  250. rx,
  251. )
  252. };
  253. let me = self.me;
  254. let mut votes = vec![];
  255. for i in 0..self.peers.len() {
  256. if i != self.me.0 {
  257. // Make a clone now so that self will not be passed across await
  258. // boundary.
  259. let rpc_client = self.peers[i].clone();
  260. // RPCs are started right away.
  261. let one_vote = tokio::spawn(async move {
  262. let reply_future = Self::retry_rpc(4, move |_round| {
  263. rpc_client.clone().call_request_vote(RequestVoteArgs {
  264. term,
  265. candidate_id: me,
  266. last_log_index,
  267. last_log_term,
  268. })
  269. });
  270. if let Ok(reply) = reply_future.await {
  271. return Some(reply.vote_granted && reply.term == term);
  272. }
  273. return None;
  274. });
  275. // Futures must be pinned so that they have Unpin, as required
  276. // by futures::future::select.
  277. votes.push(one_vote);
  278. }
  279. }
  280. tokio::spawn(Self::count_vote_util_cancelled(
  281. term,
  282. self.inner_state.clone(),
  283. votes,
  284. self.peers.len() / 2,
  285. cancel_token,
  286. ));
  287. }
  288. async fn count_vote_util_cancelled(
  289. term: Term,
  290. rf: Arc<Mutex<RaftState>>,
  291. votes: Vec<tokio::task::JoinHandle<Option<bool>>>,
  292. majority: usize,
  293. cancel_token: futures::channel::oneshot::Receiver<Term>,
  294. ) {
  295. let mut vote_count = 0;
  296. let mut against_count = 0;
  297. let mut cancel_token = cancel_token;
  298. let mut futures_vec = votes;
  299. while vote_count < majority && against_count <= majority {
  300. // Mixing tokio futures with futures-rs ones. Fingers crossed.
  301. let selected = futures::future::select(
  302. cancel_token,
  303. futures::future::select_all(futures_vec),
  304. )
  305. .await;
  306. let ((one_vote, index, rest), new_token) = match selected {
  307. futures::future::Either::Left(_) => break,
  308. futures::future::Either::Right(tuple) => tuple,
  309. };
  310. futures_vec = rest;
  311. cancel_token = new_token;
  312. if let Ok(Some(vote)) = one_vote {
  313. if vote {
  314. vote_count += 1
  315. } else {
  316. against_count += 1
  317. }
  318. }
  319. }
  320. if vote_count < majority {
  321. return;
  322. }
  323. let mut rf = rf.lock();
  324. if rf.current_term == term && rf.state == State::Candidate {
  325. rf.state = State::Leader;
  326. }
  327. let log_len = rf.log.len();
  328. for item in rf.next_index.iter_mut() {
  329. *item = log_len;
  330. }
  331. for item in rf.match_index.iter_mut() {
  332. *item = 0;
  333. }
  334. // TODO: send heartbeats.
  335. // Drop the timer and cancel token.
  336. rf.election_cancel_token.take();
  337. rf.election_timer.take();
  338. rf.persist();
  339. }
  340. }
  341. const HEARTBEAT_INTERVAL_MILLIS: u64 = 150;
  342. const ELECTION_TIMEOUT_BASE_MILLIS: u64 = 150;
  343. const ELECTION_TIMEOUT_VAR_MILLIS: u64 = 250;
  344. impl RaftState {
  345. fn reset_election_timer(&mut self) {
  346. self.election_timer.as_mut().map(|timer| {
  347. timer.reset(
  348. (std::time::Instant::now() + Self::election_timeout()).into(),
  349. )
  350. });
  351. }
  352. fn election_timeout() -> Duration {
  353. Duration::from_millis(
  354. ELECTION_TIMEOUT_BASE_MILLIS
  355. + thread_rng().gen_range(0, ELECTION_TIMEOUT_VAR_MILLIS),
  356. )
  357. }
  358. fn stop_current_election(&mut self) {
  359. self.election_cancel_token
  360. .take()
  361. .map(|sender| sender.send(self.current_term));
  362. }
  363. fn persist(&self) {
  364. // TODO: implement
  365. }
  366. }