// lib.rs
  1. use std::convert::TryFrom;
  2. use std::sync::atomic::{AtomicBool, Ordering};
  3. use std::sync::Arc;
  4. use std::time::Duration;
  5. use crossbeam_utils::sync::WaitGroup;
  6. use parking_lot::{Condvar, Mutex};
  7. use serde_derive::{Deserialize, Serialize};
  8. use crate::apply_command::ApplyCommandFnMut;
  9. pub use crate::apply_command::ApplyCommandMessage;
  10. use crate::daemon_env::{DaemonEnv, ThreadEnv};
  11. use crate::election::ElectionState;
  12. use crate::index_term::IndexTerm;
  13. use crate::persister::PersistedRaftState;
  14. pub use crate::persister::Persister;
  15. pub(crate) use crate::raft_state::RaftState;
  16. pub(crate) use crate::raft_state::State;
  17. pub use crate::remote_raft::RemoteRaft;
  18. pub use crate::snapshot::Snapshot;
  19. use crate::snapshot::{RequestSnapshotFnMut, SnapshotDaemon};
  20. mod apply_command;
  21. mod daemon_env;
  22. mod election;
  23. mod heartbeats;
  24. mod index_term;
  25. mod log_array;
  26. mod persister;
  27. mod process_append_entries;
  28. mod process_install_snapshot;
  29. mod process_request_vote;
  30. mod raft_state;
  31. mod remote_raft;
  32. mod snapshot;
  33. mod sync_log_entries;
  34. mod term_marker;
  35. pub mod utils;
  36. #[derive(
  37. Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
  38. )]
  39. pub struct Term(pub usize);
  40. #[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
  41. struct Peer(usize);
  /// Position of an entry in the Raft log.
  pub type Index = usize;
  43. #[derive(Clone, Debug, Serialize, Deserialize)]
  44. struct LogEntry<Command> {
  45. index: Index,
  46. term: Term,
  47. command: Command,
  48. }
  49. #[derive(Clone)]
  50. pub struct Raft<Command> {
  51. inner_state: Arc<Mutex<RaftState<Command>>>,
  52. peers: Vec<Arc<dyn RemoteRaft<Command>>>,
  53. me: Peer,
  54. persister: Arc<dyn Persister>,
  55. new_log_entry: Option<utils::SharedSender<Option<Peer>>>,
  56. apply_command_signal: Arc<Condvar>,
  57. keep_running: Arc<AtomicBool>,
  58. election: Arc<ElectionState>,
  59. snapshot_daemon: SnapshotDaemon,
  60. thread_pool: Arc<tokio::runtime::Runtime>,
  61. daemon_env: DaemonEnv,
  62. stop_wait_group: WaitGroup,
  63. }
  64. #[derive(Clone, Debug, Serialize, Deserialize)]
  65. pub struct RequestVoteArgs {
  66. term: Term,
  67. candidate_id: Peer,
  68. last_log_index: Index,
  69. last_log_term: Term,
  70. }
  71. #[derive(Clone, Debug, Serialize, Deserialize)]
  72. pub struct RequestVoteReply {
  73. term: Term,
  74. vote_granted: bool,
  75. }
  76. #[derive(Clone, Debug, Serialize, Deserialize)]
  77. pub struct AppendEntriesArgs<Command> {
  78. term: Term,
  79. leader_id: Peer,
  80. prev_log_index: Index,
  81. prev_log_term: Term,
  82. entries: Vec<LogEntry<Command>>,
  83. leader_commit: Index,
  84. }
  85. #[derive(Clone, Debug, Serialize, Deserialize)]
  86. pub struct AppendEntriesReply {
  87. term: Term,
  88. success: bool,
  89. committed: Option<IndexTerm>,
  90. }
  91. #[derive(Clone, Debug, Serialize, Deserialize)]
  92. pub struct InstallSnapshotArgs {
  93. term: Term,
  94. leader_id: Peer,
  95. last_included_index: Index,
  96. last_included_term: Term,
  97. // TODO(ditsing): Serde cannot handle Vec<u8> as efficient as expected.
  98. data: Vec<u8>,
  99. offset: usize,
  100. done: bool,
  101. }
  102. #[derive(Clone, Debug, Serialize, Deserialize)]
  103. pub struct InstallSnapshotReply {
  104. term: Term,
  105. committed: Option<IndexTerm>,
  106. }
  107. // Commands must be
  108. // 0. 'static: they have to live long enough for thread pools.
  109. // 1. clone: they are put in vectors and request messages.
  110. // 2. serializable: they are sent over RPCs and persisted.
  111. // 3. deserializable: they are restored from storage.
  112. // 4. send: they are referenced in futures.
  113. // 5. default, because we need an element for the first entry.
  114. impl<Command> Raft<Command>
  115. where
  116. Command: 'static
  117. + Clone
  118. + serde::Serialize
  119. + serde::de::DeserializeOwned
  120. + Send
  121. + Default,
  122. {
  123. /// Create a new raft instance.
  124. ///
  125. /// Each instance will create at least 4 + (number of peers) threads. The
  126. /// extensive usage of threads is to minimize latency.
  127. pub fn new(
  128. peers: Vec<impl RemoteRaft<Command> + 'static>,
  129. me: usize,
  130. persister: Arc<dyn Persister>,
  131. apply_command: impl ApplyCommandFnMut<Command>,
  132. max_state_size_bytes: Option<usize>,
  133. request_snapshot: impl RequestSnapshotFnMut,
  134. ) -> Self {
  135. let peer_size = peers.len();
  136. assert!(peer_size > me, "My index should be smaller than peer size.");
  137. let mut state = RaftState::create(peer_size, Peer(me));
  138. // COMMIT_INDEX_INVARIANT, SNAPSHOT_INDEX_INVARIANT: Initially
  139. // commit_index = log.start() and commit_index + 1 = log.end(). Thus
  140. // log.start() <= commit_index and commit_index < log.end() both hold.
  141. assert_eq!(state.commit_index + 1, state.log.end());
  142. if let Ok(persisted_state) =
  143. PersistedRaftState::try_from(persister.read_state())
  144. {
  145. state.current_term = persisted_state.current_term;
  146. state.voted_for = persisted_state.voted_for;
  147. state.log = persisted_state.log;
  148. state.commit_index = state.log.start();
  149. // COMMIT_INDEX_INVARIANT, SNAPSHOT_INDEX_INVARIANT: the saved
  150. // snapshot must have a valid log.start() and log.end(). Thus
  151. // log.start() <= commit_index and commit_index < log.end() hold.
  152. assert!(state.commit_index < state.log.end());
  153. state
  154. .log
  155. .validate(state.current_term)
  156. .expect("Persisted log should not contain error");
  157. }
  158. let election = ElectionState::create();
  159. election.reset_election_timer();
  160. let daemon_env = DaemonEnv::create();
  161. let thread_env = daemon_env.for_thread();
  162. let thread_pool = tokio::runtime::Builder::new_multi_thread()
  163. .enable_time()
  164. .enable_io()
  165. .thread_name(format!("raft-instance-{}", me))
  166. .worker_threads(peer_size)
  167. .on_thread_start(move || thread_env.clone().attach())
  168. .on_thread_stop(ThreadEnv::detach)
  169. .build()
  170. .expect("Creating thread pool should not fail");
  171. let peers = peers
  172. .into_iter()
  173. .map(|r| Arc::new(r) as Arc<dyn RemoteRaft<Command>>)
  174. .collect();
  175. let mut this = Raft {
  176. inner_state: Arc::new(Mutex::new(state)),
  177. peers,
  178. me: Peer(me),
  179. persister,
  180. new_log_entry: None,
  181. apply_command_signal: Arc::new(Default::default()),
  182. keep_running: Arc::new(Default::default()),
  183. election: Arc::new(election),
  184. snapshot_daemon: Default::default(),
  185. thread_pool: Arc::new(thread_pool),
  186. daemon_env,
  187. stop_wait_group: WaitGroup::new(),
  188. };
  189. this.keep_running.store(true, Ordering::SeqCst);
  190. // Running in a standalone thread.
  191. this.run_snapshot_daemon(max_state_size_bytes, request_snapshot);
  192. // Running in a standalone thread.
  193. this.run_log_entry_daemon();
  194. // Running in a standalone thread.
  195. this.run_apply_command_daemon(apply_command);
  196. // One off function that schedules many little tasks, running on the
  197. // internal thread pool.
  198. this.schedule_heartbeats(Duration::from_millis(
  199. HEARTBEAT_INTERVAL_MILLIS,
  200. ));
  201. // The last step is to start running election timer.
  202. this.run_election_timer();
  203. this
  204. }
  205. }
  206. // Command must be
  207. // 0. 'static: Raft<Command> must be 'static, it is moved to another thread.
  208. // 1. clone: they are copied to the persister.
  209. // 2. send: Arc<Mutex<Vec<LogEntry<Command>>>> must be send, it is moved to another thread.
  210. // 3. serialize: they are converted to bytes to persist.
  211. // 4. default: a default value is used as the first element of log.
  212. impl<Command> Raft<Command>
  213. where
  214. Command: 'static + Clone + Send + serde::Serialize + Default,
  215. {
  216. /// Adds a new command to the log, returns its index and the current term.
  217. ///
  218. /// Returns `None` if we are not the leader. The log entry may not have been
  219. /// committed to the log when this method returns. When and if it is
  220. /// committed, the `apply_command` callback will be called.
  221. pub fn start(&self, command: Command) -> Option<(Term, Index)> {
  222. let mut rf = self.inner_state.lock();
  223. let term = rf.current_term;
  224. if !rf.is_leader() {
  225. return None;
  226. }
  227. let index = rf.log.add_command(term, command);
  228. self.persister.save_state(rf.persisted_state().into());
  229. // Several attempts have been made to remove the unwrap below.
  230. let _ = self.new_log_entry.as_ref().unwrap().send(None);
  231. log::info!("{:?} started new entry at {} {:?}", self.me, index, term);
  232. Some((term, index))
  233. }
  234. /// Cleanly shutdown this instance. This function never blocks forever. It
  235. /// either panics or returns eventually.
  236. pub fn kill(mut self) {
  237. self.keep_running.store(false, Ordering::SeqCst);
  238. self.election.stop_election_timer();
  239. self.new_log_entry.take().map(|n| n.send(None));
  240. self.apply_command_signal.notify_all();
  241. self.snapshot_daemon.kill();
  242. // We cannot easily combine stop_wait_group into DaemonEnv because of
  243. // shutdown dependencies. The thread pool is not managed by DaemonEnv,
  244. // but it cannot be shutdown until all daemons are. On the other hand
  245. // the thread pool uses DaemonEnv, thus must be shutdown before
  246. // DaemonEnv. The shutdown sequence is stop_wait_group -> thread_pool
  247. // -> DaemonEnv. The first and third cannot be combined with the second
  248. // in the middle.
  249. self.stop_wait_group.wait();
  250. std::sync::Arc::try_unwrap(self.thread_pool)
  251. .expect(
  252. "All references to the thread pool should have been dropped.",
  253. )
  254. .shutdown_timeout(Duration::from_millis(
  255. HEARTBEAT_INTERVAL_MILLIS * 2,
  256. ));
  257. // DaemonEnv must be shutdown after the thread pool, since there might
  258. // be tasks logging errors in the pool.
  259. self.daemon_env.shutdown();
  260. }
  261. /// Returns the current term and whether we are the leader.
  262. ///
  263. /// Take a quick peek at the current state of this instance. The returned
  264. /// value is stale as soon as this function returns.
  265. pub fn get_state(&self) -> (Term, bool) {
  266. let state = self.inner_state.lock();
  267. (state.current_term, state.is_leader())
  268. }
  269. }
  /// Interval, in milliseconds, passed to `schedule_heartbeats` by `Raft::new`.
  /// `Raft::kill` allows the thread pool twice this long to shut down.
  pub(crate) const HEARTBEAT_INTERVAL_MILLIS: u64 = 150;
  #[cfg(test)]
  mod tests {
      /// Compiles only when `T` is `Sync`; it does nothing at runtime.
      fn assert_sync<T: Sync>() {}

      #[test]
      fn test_raft_must_sync() {
          // This is a compile-time check. It passes for a command type such
          // as `i32`. With `std::rc::Rc<i32>` as the command type, the raft is
          // not `Sync` and this would fail to compile.
          assert_sync::<Option<super::Raft<i32>>>();
      }
  }