lib.rs 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. extern crate bincode;
  2. extern crate futures_channel;
  3. extern crate futures_util;
  4. extern crate labrpc;
  5. extern crate rand;
  6. #[macro_use]
  7. extern crate serde_derive;
  8. extern crate tokio;
  9. use std::convert::TryFrom;
  10. use std::sync::atomic::{AtomicBool, Ordering};
  11. use std::sync::Arc;
  12. use std::time::Duration;
  13. use crossbeam_utils::sync::WaitGroup;
  14. use parking_lot::{Condvar, Mutex};
  15. use crate::apply_command::ApplyCommandFnMut;
  16. pub use crate::apply_command::ApplyCommandMessage;
  17. use crate::daemon_env::{DaemonEnv, ThreadEnv};
  18. use crate::election::ElectionState;
  19. use crate::index_term::IndexTerm;
  20. use crate::persister::PersistedRaftState;
  21. pub use crate::persister::Persister;
  22. pub(crate) use crate::raft_state::RaftState;
  23. pub(crate) use crate::raft_state::State;
  24. pub use crate::rpcs::RpcClient;
  25. pub use crate::snapshot::Snapshot;
  26. use crate::snapshot::{RequestSnapshotFnMut, SnapshotDaemon};
  27. mod apply_command;
  28. mod daemon_env;
  29. mod election;
  30. mod heartbeats;
  31. mod index_term;
  32. mod log_array;
  33. mod persister;
  34. mod process_append_entries;
  35. mod process_install_snapshot;
  36. mod process_request_vote;
  37. mod raft_state;
  38. pub mod rpcs;
  39. mod snapshot;
  40. mod sync_log_entries;
  41. mod term_marker;
  42. pub mod utils;
  43. #[derive(
  44. Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
  45. )]
  46. pub struct Term(pub usize);
  47. #[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
  48. struct Peer(usize);
  49. pub type Index = usize;
  50. #[derive(Clone, Debug, Serialize, Deserialize)]
  51. struct LogEntry<Command> {
  52. index: Index,
  53. term: Term,
  54. command: Command,
  55. }
  56. #[derive(Clone)]
  57. pub struct Raft<Command> {
  58. inner_state: Arc<Mutex<RaftState<Command>>>,
  59. peers: Vec<Arc<RpcClient>>,
  60. me: Peer,
  61. persister: Arc<dyn Persister>,
  62. new_log_entry: Option<std::sync::mpsc::Sender<Option<Peer>>>,
  63. apply_command_signal: Arc<Condvar>,
  64. keep_running: Arc<AtomicBool>,
  65. election: Arc<ElectionState>,
  66. snapshot_daemon: SnapshotDaemon,
  67. thread_pool: Arc<tokio::runtime::Runtime>,
  68. daemon_env: DaemonEnv,
  69. stop_wait_group: WaitGroup,
  70. }
  71. #[derive(Clone, Debug, Serialize, Deserialize)]
  72. struct RequestVoteArgs {
  73. term: Term,
  74. candidate_id: Peer,
  75. last_log_index: Index,
  76. last_log_term: Term,
  77. }
  78. #[derive(Clone, Debug, Serialize, Deserialize)]
  79. struct RequestVoteReply {
  80. term: Term,
  81. vote_granted: bool,
  82. }
  83. #[derive(Clone, Debug, Serialize, Deserialize)]
  84. struct AppendEntriesArgs<Command> {
  85. term: Term,
  86. leader_id: Peer,
  87. prev_log_index: Index,
  88. prev_log_term: Term,
  89. entries: Vec<LogEntry<Command>>,
  90. leader_commit: Index,
  91. }
  92. #[derive(Clone, Debug, Serialize, Deserialize)]
  93. struct AppendEntriesReply {
  94. term: Term,
  95. success: bool,
  96. committed: Option<IndexTerm>,
  97. }
  98. #[derive(Clone, Debug, Serialize, Deserialize)]
  99. struct InstallSnapshotArgs {
  100. pub(crate) term: Term,
  101. leader_id: Peer,
  102. pub(crate) last_included_index: Index,
  103. last_included_term: Term,
  104. // TODO(ditsing): Serde cannot handle Vec<u8> as efficient as expected.
  105. data: Vec<u8>,
  106. offset: usize,
  107. done: bool,
  108. }
  109. #[derive(Clone, Debug, Serialize, Deserialize)]
  110. struct InstallSnapshotReply {
  111. term: Term,
  112. committed: Option<IndexTerm>,
  113. }
  114. // Commands must be
  115. // 0. 'static: they have to live long enough for thread pools.
  116. // 1. clone: they are put in vectors and request messages.
  117. // 2. serializable: they are sent over RPCs and persisted.
  118. // 3. deserializable: they are restored from storage.
  119. // 4. send: they are referenced in futures.
  120. // 5. default, because we need an element for the first entry.
  121. impl<Command> Raft<Command>
  122. where
  123. Command: 'static
  124. + Clone
  125. + serde::Serialize
  126. + serde::de::DeserializeOwned
  127. + Send
  128. + Default,
  129. {
  130. /// Create a new raft instance.
  131. ///
  132. /// Each instance will create at least 4 + (number of peers) threads. The
  133. /// extensive usage of threads is to minimize latency.
  134. pub fn new(
  135. peers: Vec<RpcClient>,
  136. me: usize,
  137. persister: Arc<dyn Persister>,
  138. apply_command: impl ApplyCommandFnMut<Command>,
  139. max_state_size_bytes: Option<usize>,
  140. request_snapshot: impl RequestSnapshotFnMut,
  141. ) -> Self {
  142. let peer_size = peers.len();
  143. assert!(peer_size > me, "My index should be smaller than peer size.");
  144. let mut state = RaftState::create(peer_size, Peer(me));
  145. // COMMIT_INDEX_INVARIANT, SNAPSHOT_INDEX_INVARIANT: Initially
  146. // commit_index = log.start() and commit_index + 1 = log.end(). Thus
  147. // log.start() <= commit_index and commit_index < log.end() both hold.
  148. assert_eq!(state.commit_index + 1, state.log.end());
  149. if let Ok(persisted_state) =
  150. PersistedRaftState::try_from(persister.read_state())
  151. {
  152. state.current_term = persisted_state.current_term;
  153. state.voted_for = persisted_state.voted_for;
  154. state.log = persisted_state.log;
  155. state.commit_index = state.log.start();
  156. // COMMIT_INDEX_INVARIANT, SNAPSHOT_INDEX_INVARIANT: the saved
  157. // snapshot must have a valid log.start() and log.end(). Thus
  158. // log.start() <= commit_index and commit_index < log.end() hold.
  159. assert!(state.commit_index < state.log.end());
  160. }
  161. let election = ElectionState::create();
  162. election.reset_election_timer();
  163. let daemon_env = DaemonEnv::create();
  164. let thread_env = daemon_env.for_thread();
  165. let thread_pool = tokio::runtime::Builder::new_multi_thread()
  166. .enable_time()
  167. .thread_name(format!("raft-instance-{}", me))
  168. .worker_threads(peer_size)
  169. .on_thread_start(move || thread_env.clone().attach())
  170. .on_thread_stop(ThreadEnv::detach)
  171. .build()
  172. .expect("Creating thread pool should not fail");
  173. let peers = peers.into_iter().map(Arc::new).collect();
  174. let mut this = Raft {
  175. inner_state: Arc::new(Mutex::new(state)),
  176. peers,
  177. me: Peer(me),
  178. persister,
  179. new_log_entry: None,
  180. apply_command_signal: Arc::new(Default::default()),
  181. keep_running: Arc::new(Default::default()),
  182. election: Arc::new(election),
  183. snapshot_daemon: Default::default(),
  184. thread_pool: Arc::new(thread_pool),
  185. daemon_env,
  186. stop_wait_group: WaitGroup::new(),
  187. };
  188. this.keep_running.store(true, Ordering::SeqCst);
  189. // Running in a standalone thread.
  190. this.run_snapshot_daemon(max_state_size_bytes, request_snapshot);
  191. // Running in a standalone thread.
  192. this.run_log_entry_daemon();
  193. // Running in a standalone thread.
  194. this.run_apply_command_daemon(apply_command);
  195. // One off function that schedules many little tasks, running on the
  196. // internal thread pool.
  197. this.schedule_heartbeats(Duration::from_millis(
  198. HEARTBEAT_INTERVAL_MILLIS,
  199. ));
  200. // The last step is to start running election timer.
  201. this.run_election_timer();
  202. this
  203. }
  204. }
  205. // Command must be
  206. // 0. 'static: Raft<Command> must be 'static, it is moved to another thread.
  207. // 1. clone: they are copied to the persister.
  208. // 2. send: Arc<Mutex<Vec<LogEntry<Command>>>> must be send, it is moved to another thread.
  209. // 3. serialize: they are converted to bytes to persist.
  210. // 4. default: a default value is used as the first element of log.
  211. impl<Command> Raft<Command>
  212. where
  213. Command: 'static + Clone + Send + serde::Serialize + Default,
  214. {
  215. /// Adds a new command to the log, returns its index and the current term.
  216. ///
  217. /// Returns `None` if we are not the leader. The log entry may not have been
  218. /// committed to the log when this method returns. When and if it is
  219. /// committed, the `apply_command` callback will be called.
  220. pub fn start(&self, command: Command) -> Option<(Term, Index)> {
  221. let mut rf = self.inner_state.lock();
  222. let term = rf.current_term;
  223. if !rf.is_leader() {
  224. return None;
  225. }
  226. let index = rf.log.add_command(term, command);
  227. self.persister.save_state(rf.persisted_state().into());
  228. let _ = self.new_log_entry.clone().unwrap().send(None);
  229. Some((term, index))
  230. }
  231. /// Cleanly shutdown this instance. This function never blocks forever. It
  232. /// either panics or returns eventually.
  233. pub fn kill(mut self) {
  234. self.keep_running.store(false, Ordering::SeqCst);
  235. self.election.stop_election_timer();
  236. self.new_log_entry.take().map(|n| n.send(None));
  237. self.apply_command_signal.notify_all();
  238. self.snapshot_daemon.kill();
  239. self.stop_wait_group.wait();
  240. std::sync::Arc::try_unwrap(self.thread_pool)
  241. .expect(
  242. "All references to the thread pool should have been dropped.",
  243. )
  244. .shutdown_timeout(Duration::from_millis(
  245. HEARTBEAT_INTERVAL_MILLIS * 2,
  246. ));
  247. // DaemonEnv must be shutdown after the thread pool, since there might
  248. // be tasks logging errors in the pool.
  249. self.daemon_env.shutdown();
  250. }
  251. /// Returns the current term and whether we are the leader.
  252. ///
  253. /// Take a quick peek at the current state of this instance. The returned
  254. /// value is stale as soon as this function returns.
  255. pub fn get_state(&self) -> (Term, bool) {
  256. let state = self.inner_state.lock();
  257. (state.current_term, state.is_leader())
  258. }
  259. }
  260. pub(crate) const HEARTBEAT_INTERVAL_MILLIS: u64 = 150;
  261. impl<C> Raft<C> {
  262. /// Pass this function to [`Raft::new`] if the application will not accept
  263. /// any request for taking snapshots.
  264. pub const NO_SNAPSHOT: fn(Index) = |_| {};
  265. }