lib.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. #![allow(unused)]
  2. extern crate bincode;
  3. extern crate futures;
  4. extern crate labrpc;
  5. extern crate rand;
  6. #[macro_use]
  7. extern crate serde_derive;
  8. extern crate tokio;
  9. use std::future::Future;
  10. use std::sync::atomic::AtomicBool;
  11. use std::sync::Arc;
  12. use std::time::Duration;
  13. use parking_lot::Mutex;
  14. use rand::{thread_rng, Rng};
  15. use crate::rpcs::RpcClient;
  16. use crate::utils::retry_rpc;
  17. pub mod rpcs;
  18. mod utils;
  /// Role a Raft server currently plays in the cluster.
  #[derive(PartialEq, Eq)]
  enum State {
      /// Neither campaigning nor leading.
      Follower,
      /// Running an election (see `Raft::run_election`).
      Candidate,
      // TODO: add PreVote
      /// Won an election; only a leader builds heartbeats.
      Leader,
  }
  26. #[derive(
  27. Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
  28. )]
  29. struct Term(usize);
  30. #[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
  31. struct Peer(usize);
  32. #[derive(Clone, Copy, Debug, Serialize, Deserialize)]
  33. struct Command(usize);
  34. // TODO: remove all of the defaults.
  35. impl Default for State {
  36. fn default() -> Self {
  37. Self::Leader
  38. }
  39. }
  40. impl Default for Term {
  41. fn default() -> Self {
  42. Self(0)
  43. }
  44. }
  45. impl Default for Peer {
  46. fn default() -> Self {
  47. Self(0)
  48. }
  49. }
  50. #[derive(Clone, Copy, Serialize, Deserialize)]
  51. struct LogEntry {
  52. term: Term,
  53. index: usize,
  54. // TODO: Allow sending of arbitrary information.
  55. command: Command,
  56. }
  57. #[derive(Default)]
  58. struct RaftState {
  59. current_term: Term,
  60. voted_for: Option<Peer>,
  61. log: Vec<LogEntry>,
  62. commit_index: usize,
  63. last_applied: usize,
  64. next_index: Vec<usize>,
  65. match_index: Vec<usize>,
  66. current_step: Vec<i64>,
  67. state: State,
  68. leader_id: Peer,
  69. // Current election cancel token, might be None if no election is running.
  70. election_cancel_token: Option<futures::channel::oneshot::Sender<Term>>,
  71. // Timer will be removed upon shutdown or elected.
  72. election_timer: Option<tokio::time::Delay>,
  73. }
  74. #[derive(Default)]
  75. struct Raft {
  76. inner_state: Arc<Mutex<RaftState>>,
  77. peers: Vec<RpcClient>,
  78. me: Peer,
  79. // new_log_entry: Sender<usize>,
  80. // new_log_entry: Receiver<usize>,
  81. // apply_command_cond: Condvar
  82. keep_running: AtomicBool,
  83. // applyCh: Sender<ApplyMsg>
  84. }
  85. #[derive(Clone, Serialize, Deserialize)]
  86. struct RequestVoteArgs {
  87. term: Term,
  88. candidate_id: Peer,
  89. last_log_index: usize,
  90. last_log_term: Term,
  91. }
  92. #[derive(Clone, Serialize, Deserialize)]
  93. struct RequestVoteReply {
  94. term: Term,
  95. vote_granted: bool,
  96. }
  97. #[derive(Clone, Serialize, Deserialize)]
  98. struct AppendEntriesArgs {
  99. term: Term,
  100. leader_id: Peer,
  101. prev_log_index: usize,
  102. prev_log_term: Term,
  103. entries: Vec<LogEntry>,
  104. leader_commit: usize,
  105. }
  106. #[derive(Clone, Serialize, Deserialize)]
  107. struct AppendEntriesReply {
  108. term: Term,
  109. success: bool,
  110. }
  111. impl Raft {
  112. pub fn new() -> Self {
  113. let raft = Self {
  114. ..Default::default()
  115. };
  116. raft.inner_state.lock().log.push(LogEntry {
  117. term: Default::default(),
  118. index: 0,
  119. command: Command(0),
  120. });
  121. raft
  122. }
  123. pub(crate) fn process_request_vote(
  124. &self,
  125. args: RequestVoteArgs,
  126. ) -> RequestVoteReply {
  127. let mut rf = self.inner_state.lock();
  128. let term = rf.current_term;
  129. if args.term < term {
  130. return RequestVoteReply {
  131. term,
  132. vote_granted: false,
  133. };
  134. } else if args.term > term {
  135. rf.current_term = args.term;
  136. rf.voted_for = None;
  137. rf.state = State::Follower;
  138. rf.reset_election_timer();
  139. rf.stop_current_election();
  140. rf.persist();
  141. }
  142. let voted_for = rf.voted_for;
  143. let (last_log_index, last_log_term) = rf.last_log_index_and_term();
  144. if (voted_for.is_none() || voted_for == Some(args.candidate_id))
  145. && (args.last_log_term > last_log_term
  146. || (args.last_log_term == last_log_term
  147. && args.last_log_index >= last_log_index))
  148. {
  149. rf.voted_for = Some(args.candidate_id);
  150. rf.reset_election_timer();
  151. // No need to stop the election. We are not a candidate.
  152. rf.persist();
  153. RequestVoteReply {
  154. term: args.term,
  155. vote_granted: true,
  156. }
  157. } else {
  158. RequestVoteReply {
  159. term: args.term,
  160. vote_granted: false,
  161. }
  162. }
  163. }
  164. pub(crate) fn process_append_entries(
  165. &self,
  166. args: AppendEntriesArgs,
  167. ) -> AppendEntriesReply {
  168. let mut rf = self.inner_state.lock();
  169. if rf.current_term > args.term {
  170. return AppendEntriesReply {
  171. term: rf.current_term,
  172. success: false,
  173. };
  174. }
  175. if rf.current_term < args.term {
  176. rf.current_term = args.term;
  177. rf.voted_for = None;
  178. }
  179. rf.state = State::Follower;
  180. rf.reset_election_timer();
  181. rf.stop_current_election();
  182. rf.leader_id = args.leader_id;
  183. if rf.log.len() <= args.prev_log_index
  184. || rf.log[args.prev_log_index].term != args.term
  185. {
  186. return AppendEntriesReply {
  187. term: args.term,
  188. success: false,
  189. };
  190. }
  191. for (i, entry) in args.entries.iter().enumerate() {
  192. let index = i + args.prev_log_index + 1;
  193. if rf.log.len() > index {
  194. if rf.log[index].term != entry.term {
  195. rf.log.truncate(index);
  196. rf.log.push(entry.clone());
  197. }
  198. } else {
  199. rf.log.push(entry.clone());
  200. }
  201. }
  202. if args.leader_commit > rf.commit_index {
  203. rf.commit_index = if args.leader_commit < rf.log.len() {
  204. args.leader_commit
  205. } else {
  206. rf.log.len() - 1
  207. };
  208. // TODO: apply commands.
  209. }
  210. AppendEntriesReply {
  211. term: args.term,
  212. success: true,
  213. }
  214. }
  215. fn run_election(&self) {
  216. let me = self.me;
  217. let (term, args, cancel_token) = {
  218. let mut rf = self.inner_state.lock();
  219. let (tx, rx) = futures::channel::oneshot::channel();
  220. rf.current_term.0 += 1;
  221. rf.voted_for = Some(self.me);
  222. rf.state = State::Candidate;
  223. rf.reset_election_timer();
  224. rf.stop_current_election();
  225. rf.election_cancel_token.replace(tx);
  226. rf.persist();
  227. let term = rf.current_term;
  228. let (last_log_index, last_log_term) = rf.last_log_index_and_term();
  229. (
  230. term,
  231. RequestVoteArgs {
  232. term,
  233. candidate_id: me,
  234. last_log_index,
  235. last_log_term,
  236. },
  237. rx,
  238. )
  239. };
  240. let mut votes = vec![];
  241. for (index, rpc_client) in self.peers.iter().enumerate() {
  242. if index != self.me.0 {
  243. // RpcClient must be cloned to avoid sending its reference
  244. // across threads.
  245. let rpc_client = rpc_client.clone();
  246. // RPCs are started right away.
  247. let one_vote = tokio::spawn(Self::request_one_vote(
  248. rpc_client,
  249. args.clone(),
  250. ));
  251. // Futures must be pinned so that they have Unpin, as required
  252. // by futures::future::select.
  253. votes.push(one_vote);
  254. }
  255. }
  256. tokio::spawn(Self::count_vote_util_cancelled(
  257. term,
  258. self.inner_state.clone(),
  259. votes,
  260. self.peers.len() / 2,
  261. cancel_token,
  262. ));
  263. }
  264. const REQUEST_VOTE_RETRY: usize = 4;
  265. async fn request_one_vote(
  266. rpc_client: RpcClient,
  267. args: RequestVoteArgs,
  268. ) -> Option<bool> {
  269. let term = args.term;
  270. let reply = retry_rpc(Self::REQUEST_VOTE_RETRY, move |_round| {
  271. rpc_client.clone().call_request_vote(args.clone())
  272. })
  273. .await;
  274. if let Ok(reply) = reply {
  275. return Some(reply.vote_granted && reply.term == term);
  276. }
  277. return None;
  278. }
  279. async fn count_vote_util_cancelled(
  280. term: Term,
  281. rf: Arc<Mutex<RaftState>>,
  282. votes: Vec<tokio::task::JoinHandle<Option<bool>>>,
  283. majority: usize,
  284. cancel_token: futures::channel::oneshot::Receiver<Term>,
  285. ) {
  286. let mut vote_count = 0;
  287. let mut against_count = 0;
  288. let mut cancel_token = cancel_token;
  289. let mut futures_vec = votes;
  290. while vote_count < majority && against_count <= majority {
  291. // Mixing tokio futures with futures-rs ones. Fingers crossed.
  292. let selected = futures::future::select(
  293. cancel_token,
  294. futures::future::select_all(futures_vec),
  295. )
  296. .await;
  297. let ((one_vote, _, rest), new_token) = match selected {
  298. futures::future::Either::Left(_) => break,
  299. futures::future::Either::Right(tuple) => tuple,
  300. };
  301. futures_vec = rest;
  302. cancel_token = new_token;
  303. if let Ok(Some(vote)) = one_vote {
  304. if vote {
  305. vote_count += 1
  306. } else {
  307. against_count += 1
  308. }
  309. }
  310. }
  311. if vote_count < majority {
  312. return;
  313. }
  314. let mut rf = rf.lock();
  315. if rf.current_term == term && rf.state == State::Candidate {
  316. rf.state = State::Leader;
  317. }
  318. let log_len = rf.log.len();
  319. for item in rf.next_index.iter_mut() {
  320. *item = log_len;
  321. }
  322. for item in rf.match_index.iter_mut() {
  323. *item = 0;
  324. }
  325. // TODO: send heartbeats.
  326. // Drop the timer and cancel token.
  327. rf.election_cancel_token.take();
  328. rf.election_timer.take();
  329. rf.persist();
  330. }
  331. fn schedule_heartbeats(&self, interval: Duration) {
  332. for (peer_index, rpc_client) in self.peers.iter().enumerate() {
  333. if peer_index != self.me.0 {
  334. // Interval and rf are now owned by the outer async function.
  335. let mut interval = tokio::time::interval(interval);
  336. let rf = self.inner_state.clone();
  337. // RPC client must be cloned into the outer async function.
  338. let rpc_client = rpc_client.clone();
  339. tokio::spawn(async move {
  340. loop {
  341. // TODO: shutdown signal or cancel token.
  342. interval.tick().await;
  343. if let Some(args) = Self::build_heartbeat(&rf) {
  344. tokio::spawn(Self::send_heartbeat(
  345. rpc_client.clone(),
  346. args,
  347. ));
  348. }
  349. }
  350. });
  351. }
  352. }
  353. }
  354. fn build_heartbeat(
  355. rf: &Arc<Mutex<RaftState>>,
  356. ) -> Option<AppendEntriesArgs> {
  357. let rf = rf.lock();
  358. // copy states.
  359. let term = rf.current_term;
  360. let is_leader = rf.state == State::Leader;
  361. let (last_log_index, last_log_term) = rf.last_log_index_and_term();
  362. let commit_index = rf.commit_index;
  363. let leader_id = rf.leader_id;
  364. if !is_leader {
  365. return None;
  366. }
  367. let args = AppendEntriesArgs {
  368. term,
  369. leader_id,
  370. prev_log_index: last_log_index,
  371. prev_log_term: last_log_term,
  372. entries: vec![],
  373. leader_commit: commit_index,
  374. };
  375. Some(args)
  376. }
  377. const HEARTBEAT_RETRY: usize = 3;
  378. async fn send_heartbeat(
  379. rpc_client: RpcClient,
  380. args: AppendEntriesArgs,
  381. ) -> std::io::Result<()> {
  382. retry_rpc(Self::HEARTBEAT_RETRY, move |_round| {
  383. rpc_client.clone().call_append_entries(args.clone())
  384. })
  385. .await?;
  386. Ok(())
  387. }
  388. fn run_log_entry_daemon(
  389. &self,
  390. ) -> (
  391. std::thread::JoinHandle<()>,
  392. std::sync::mpsc::Sender<Option<Peer>>,
  393. ) {
  394. let (tx, rx) = std::sync::mpsc::channel::<Option<Peer>>();
  395. // Clone everything that the thread needs.
  396. let rerun = tx.clone();
  397. let peers = self.peers.clone();
  398. let rf = self.inner_state.clone();
  399. let me = self.me;
  400. let handle = std::thread::spawn(move || {
  401. while let Ok(peer) = rx.recv() {
  402. for (i, rpc_client) in peers.iter().enumerate() {
  403. if i != me.0 && peer.map(|p| p.0 == i).unwrap_or(true) {
  404. let rf = rf.clone();
  405. let rpc_client = rpc_client.clone();
  406. let rerun = rerun.clone();
  407. let peer_index = i;
  408. tokio::spawn(async move {
  409. // TODO: cancel in flight changes?
  410. let args =
  411. Self::build_append_entries(&rf, peer_index);
  412. let succeeded =
  413. Self::append_entries(rpc_client, args).await;
  414. match succeeded {
  415. Ok(done) => {
  416. if !done {
  417. let mut rf = rf.lock();
  418. let step =
  419. &mut rf.current_step[peer_index];
  420. *step += 1;
  421. let diff = (1 << 8) << *step;
  422. let next_index =
  423. &mut rf.next_index[peer_index];
  424. if diff >= *next_index {
  425. *next_index = 1usize;
  426. } else {
  427. *next_index -= diff;
  428. }
  429. rerun.send(Some(Peer(peer_index)));
  430. }
  431. }
  432. Err(_) => {
  433. tokio::time::delay_for(
  434. Duration::from_millis(
  435. HEARTBEAT_INTERVAL_MILLIS,
  436. ),
  437. )
  438. .await;
  439. rerun.send(Some(Peer(peer_index)));
  440. }
  441. };
  442. });
  443. }
  444. }
  445. }
  446. });
  447. (handle, tx)
  448. }
  449. fn build_append_entries(
  450. rf: &Arc<Mutex<RaftState>>,
  451. peer_index: usize,
  452. ) -> AppendEntriesArgs {
  453. let rf = rf.lock();
  454. let (prev_log_index, prev_log_term) = rf.last_log_index_and_term();
  455. AppendEntriesArgs {
  456. term: rf.current_term,
  457. leader_id: rf.leader_id,
  458. prev_log_index,
  459. prev_log_term,
  460. entries: rf.log[rf.next_index[peer_index]..].to_vec(),
  461. leader_commit: rf.commit_index,
  462. }
  463. }
  464. const APPEND_ENTRIES_RETRY: usize = 3;
  465. async fn append_entries(
  466. rpc_client: RpcClient,
  467. args: AppendEntriesArgs,
  468. ) -> std::io::Result<bool> {
  469. let term = args.term;
  470. let reply = retry_rpc(Self::APPEND_ENTRIES_RETRY, move |_round| {
  471. rpc_client.clone().call_append_entries(args.clone())
  472. })
  473. .await?;
  474. Ok(reply.term != term || reply.success)
  475. }
  476. }
  /// Delay, in milliseconds, before `run_log_entry_daemon` retries a peer whose
  /// AppendEntries RPC failed.
  // NOTE(review): if this is also meant as the heartbeat period, it equals the
  // election-timeout lower bound below. Raft wants heartbeats well below the
  // election timeout — confirm what callers pass to `schedule_heartbeats`.
  const HEARTBEAT_INTERVAL_MILLIS: u64 = 150_u64;

  /// Lower bound, in milliseconds, of a randomized election timeout.
  const ELECTION_TIMEOUT_BASE_MILLIS: u64 = 150_u64;

  /// Width, in milliseconds, of the range added to the base election timeout.
  /// The timeout is drawn from `[BASE, BASE + VAR)`.
  const ELECTION_TIMEOUT_VAR_MILLIS: u64 = 250_u64;
  480. impl RaftState {
  481. fn reset_election_timer(&mut self) {
  482. self.election_timer.as_mut().map(|timer| {
  483. timer.reset(
  484. (std::time::Instant::now() + Self::election_timeout()).into(),
  485. )
  486. });
  487. }
  488. fn election_timeout() -> Duration {
  489. Duration::from_millis(
  490. ELECTION_TIMEOUT_BASE_MILLIS
  491. + thread_rng().gen_range(0, ELECTION_TIMEOUT_VAR_MILLIS),
  492. )
  493. }
  494. fn stop_current_election(&mut self) {
  495. self.election_cancel_token
  496. .take()
  497. .map(|sender| sender.send(self.current_term));
  498. }
  499. fn persist(&self) {
  500. // TODO: implement
  501. }
  502. fn last_log_index_and_term(&self) -> (usize, Term) {
  503. let len = self.log.len();
  504. assert!(len > 0, "There should always be at least one entry in log");
  505. (len - 1, self.log.last().unwrap().term)
  506. }
  507. }