// raft.rs
  1. use std::sync::atomic::{AtomicBool, Ordering};
  2. use std::sync::Arc;
  3. use std::time::Duration;
  4. use crossbeam_utils::sync::WaitGroup;
  5. use parking_lot::{Condvar, Mutex};
  6. use serde_derive::{Deserialize, Serialize};
  7. use crate::apply_command::ApplyCommandFnMut;
  8. use crate::daemon_env::{DaemonEnv, ThreadEnv};
  9. use crate::daemon_watch::{Daemon, DaemonWatch};
  10. use crate::election::ElectionState;
  11. use crate::heartbeats::{HeartbeatsDaemon, HEARTBEAT_INTERVAL};
  12. use crate::persister::PersistedRaftState;
  13. use crate::remote_context::RemoteContext;
  14. use crate::remote_peer::RemotePeer;
  15. use crate::snapshot::{RequestSnapshotFnMut, SnapshotDaemon};
  16. use crate::sync_log_entries::SyncLogEntriesComms;
  17. use crate::term_marker::TermMarker;
  18. use crate::verify_authority::VerifyAuthorityDaemon;
  19. use crate::{IndexTerm, Persister, RaftState, RemoteRaft, ReplicableCommand};
/// A Raft term number. Terms only ever increase; the derived `Ord` lets
/// callers compare terms directly (e.g. `new_term > current_term`).
#[derive(
    Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize,
)]
pub struct Term(pub usize);
/// Index of a peer within the Raft cluster. The local instance is identified
/// by its own `Peer` value (see `Raft::me`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Peer(pub usize);
/// One Raft instance. `Clone` produces another handle to the same shared
/// state; all clones must be dropped before `RaftJoinHandle::join()` returns.
#[derive(Clone)]
pub struct Raft<Command> {
    // Mutable core state (current term, vote, log, commit index), shared
    // between all clones and daemons behind a mutex.
    pub(crate) inner_state: Arc<Mutex<RaftState<Command>>>,
    // Every peer in the cluster except this instance (see `new()`).
    pub(crate) peers: Vec<Peer>,
    // This instance's own index in the cluster.
    pub(crate) me: Peer,
    // Durable storage for the persisted part of the state.
    pub(crate) persister: Arc<dyn Persister>,
    // Channel endpoints for the log-sync daemon; `update_followers()` is
    // called when a new entry is appended, `kill()` on shutdown.
    pub(crate) sync_log_entries_comms: SyncLogEntriesComms,
    // Wakes the apply-command daemon; also notified by `kill()`.
    pub(crate) apply_command_signal: Arc<Condvar>,
    // Set to false by `kill()` to ask daemons and scheduled tasks to stop.
    pub(crate) keep_running: Arc<AtomicBool>,
    // Election timer state; reset at startup, stopped by `kill()`.
    pub(crate) election: Arc<ElectionState>,
    pub(crate) snapshot_daemon: SnapshotDaemon,
    pub(crate) verify_authority_daemon: VerifyAuthorityDaemon,
    pub(crate) heartbeats_daemon: HeartbeatsDaemon,
    // Handle to the tokio runtime owned by the `RaftJoinHandle`; used to
    // schedule heartbeat and RPC tasks.
    pub(crate) thread_pool: tokio::runtime::Handle,
    pub(crate) daemon_env: DaemonEnv,
    // Cloned into the join handle; `join()` waits until all clones of this
    // instance have been dropped.
    stop_wait_group: WaitGroup,
    // Populated at the end of `new()` and taken (once) by `kill()`.
    join_handle: Arc<Mutex<Option<RaftJoinHandle>>>,
}
impl<Command: ReplicableCommand> Raft<Command> {
    /// Create a new raft instance.
    ///
    /// Each instance will create at least 4 + (number of peers) threads. The
    /// extensive usage of threads is to minimize latency.
    ///
    /// `peers` contains the RPC clients for every member of the cluster,
    /// including this instance itself at index `me`. `apply_command` is called
    /// for committed entries; `request_snapshot` is called by the snapshot
    /// daemon. `max_state_size_bytes` bounds the persisted state size before a
    /// snapshot is requested (None = unbounded) — TODO confirm against the
    /// snapshot daemon implementation.
    ///
    /// Panics if `me` is not a valid index into `peers`, or if the persisted
    /// log fails validation.
    pub fn new(
        peers: Vec<impl RemoteRaft<Command> + 'static>,
        me: usize,
        persister: impl Persister + 'static,
        apply_command: impl ApplyCommandFnMut<Command>,
        max_state_size_bytes: Option<usize>,
        request_snapshot: impl RequestSnapshotFnMut,
    ) -> Self {
        let peer_size = peers.len();
        assert!(peer_size > me, "My index should be smaller than peer size.");
        let mut state = RaftState::create(peer_size, Peer(me));
        // COMMIT_INDEX_INVARIANT, SNAPSHOT_INDEX_INVARIANT: Initially
        // commit_index = log.start() and commit_index + 1 = log.end(). Thus
        // log.start() <= commit_index and commit_index < log.end() both hold.
        assert_eq!(state.commit_index + 1, state.log.end());
        // Restore term, vote and log from the persister, if it holds a valid
        // previously-saved state; otherwise start from the fresh state above.
        if let Ok(persisted_state) =
            PersistedRaftState::try_from(persister.read_state())
        {
            state.current_term = persisted_state.current_term;
            state.voted_for = persisted_state.voted_for;
            state.log = persisted_state.log;
            state.commit_index = state.log.start();
            // COMMIT_INDEX_INVARIANT, SNAPSHOT_INDEX_INVARIANT: the saved
            // snapshot must have a valid log.start() and log.end(). Thus
            // log.start() <= commit_index and commit_index < log.end() hold.
            assert!(state.commit_index < state.log.end());
            state
                .log
                .validate(state.current_term)
                .expect("Persisted log should not contain error");
        }
        let inner_state = Arc::new(Mutex::new(state));
        let election = Arc::new(ElectionState::create());
        election.reset_election_timer();
        let persister = Arc::new(persister);
        // The term marker shares the state, election and persister; it is
        // handed to worker threads through the RemoteContext below.
        let term_marker = TermMarker::create(
            inner_state.clone(),
            election.clone(),
            persister.clone(),
        );
        let verify_authority_daemon = VerifyAuthorityDaemon::create(peer_size);
        // Wrap each RPC client in a RemotePeer carrying its index and the
        // beat ticker for that index from the verify-authority daemon.
        let remote_peers = peers
            .into_iter()
            .enumerate()
            .map(|(index, remote_raft)| {
                RemotePeer::create(
                    Peer(index),
                    remote_raft,
                    verify_authority_daemon.beat_ticker(index),
                )
            })
            .collect();
        let context = RemoteContext::create(term_marker, remote_peers);
        let daemon_env = DaemonEnv::create();
        let thread_env = daemon_env.for_thread();
        // Each worker thread attaches the remote context and thread env when
        // it starts and detaches them when it stops, so tasks running on the
        // pool can reach both as thread-locals.
        let thread_pool = tokio::runtime::Builder::new_multi_thread()
            .enable_time()
            .enable_io()
            .thread_name(format!("raft-instance-{}", me))
            .worker_threads(peer_size)
            .on_thread_start(move || {
                context.clone().attach();
                thread_env.clone().attach();
            })
            .on_thread_stop(move || {
                RemoteContext::<Command>::detach();
                ThreadEnv::detach();
            })
            .build()
            .expect("Creating thread pool should not fail");
        // The stored peer list excludes this instance itself.
        let peers = (0..peer_size).filter(|p| *p != me).map(Peer).collect();
        let (sync_log_entries_comms, sync_log_entries_daemon) =
            crate::sync_log_entries::create(peer_size);
        let mut this = Raft {
            inner_state,
            peers,
            me: Peer(me),
            persister,
            sync_log_entries_comms,
            apply_command_signal: Arc::new(Condvar::new()),
            keep_running: Arc::new(AtomicBool::new(true)),
            election,
            snapshot_daemon: SnapshotDaemon::create(),
            verify_authority_daemon,
            heartbeats_daemon: HeartbeatsDaemon::create(),
            // Keep only a Handle; the owned Runtime moves into the join
            // handle below so it can be shut down at join() time.
            thread_pool: thread_pool.handle().clone(),
            stop_wait_group: WaitGroup::new(),
            daemon_env: daemon_env.clone(),
            // The join handle will be created later.
            join_handle: Arc::new(Mutex::new(None)),
        };
        let mut daemon_watch = DaemonWatch::create(daemon_env.for_thread());
        // Running in a standalone thread.
        let verify_authority_daemon = this.run_verify_authority_daemon();
        daemon_watch
            .create_daemon(Daemon::VerifyAuthority, verify_authority_daemon);
        // Running in a standalone thread.
        let snapshot_daemon =
            this.run_snapshot_daemon(max_state_size_bytes, request_snapshot);
        daemon_watch.create_daemon(Daemon::Snapshot, snapshot_daemon);
        // Running in a standalone thread.
        let sync_log_entry_daemon =
            this.run_log_entry_daemon(sync_log_entries_daemon);
        daemon_watch
            .create_daemon(Daemon::SyncLogEntries, sync_log_entry_daemon);
        // Running in a standalone thread.
        let apply_command_daemon = this.run_apply_command_daemon(apply_command);
        daemon_watch.create_daemon(Daemon::ApplyCommand, apply_command_daemon);
        // One off function that schedules many little tasks, running on the
        // internal thread pool.
        this.schedule_heartbeats(HEARTBEAT_INTERVAL);
        // The last step is to start running election timer.
        daemon_watch.create_daemon(Daemon::ElectionTimer, {
            let raft = this.clone();
            move || raft.run_election_timer()
        });
        // Create the join handle, taking ownership of the runtime, the daemon
        // watch and the daemon env so that kill()/join() can tear them down.
        this.join_handle.lock().replace(RaftJoinHandle {
            stop_wait_group: this.stop_wait_group.clone(),
            thread_pool,
            daemon_watch,
            daemon_env,
        });
        this
    }
}
// Command must be:
// 0. 'static: Raft<Command> must be 'static; it is moved to another thread.
// 1. Clone: commands are copied to the persister.
// 2. Send: Arc<Mutex<Vec<LogEntry<Command>>>> must be Send; it is moved to
//    another thread.
// 3. Serialize: commands are converted to bytes to persist.
  180. impl<Command: ReplicableCommand> Raft<Command> {
  181. /// Adds a new command to the log, returns its index and the current term.
  182. ///
  183. /// Returns `None` if we are not the leader. The log entry may not have been
  184. /// committed to the log when this method returns. When and if it is
  185. /// committed, the `apply_command` callback will be called.
  186. pub fn start(&self, command: Command) -> Option<IndexTerm> {
  187. let _guard = self.daemon_env.for_scope();
  188. let mut rf = self.inner_state.lock();
  189. let term = rf.current_term;
  190. if !rf.is_leader() {
  191. return None;
  192. }
  193. let index = rf.log.add_command(term, command);
  194. self.persister.save_state(rf.persisted_state().into());
  195. self.sync_log_entries_comms.update_followers(index);
  196. log::info!("{:?} started new entry at {} {:?}", self.me, index, term);
  197. Some(IndexTerm::pack(index, term))
  198. }
  199. /// Cleanly shutdown this instance. This function never blocks forever. It
  200. /// either panics or returns eventually.
  201. pub fn kill(self) -> RaftJoinHandle {
  202. self.keep_running.store(false, Ordering::Release);
  203. self.election.stop_election_timer();
  204. self.sync_log_entries_comms.kill();
  205. self.apply_command_signal.notify_all();
  206. self.snapshot_daemon.kill();
  207. self.verify_authority_daemon.kill();
  208. self.join_handle.lock().take().unwrap()
  209. }
  210. /// Returns the current term and whether we are the leader.
  211. ///
  212. /// Take a quick peek at the current state of this instance. The returned
  213. /// value is stale as soon as this function returns.
  214. pub fn get_state(&self) -> (Term, bool) {
  215. let state = self.inner_state.lock();
  216. (state.current_term, state.is_leader())
  217. }
  218. }
/// A join handle returned by `Raft::kill()`. Join this handle to cleanly
/// shutdown a Raft instance.
///
/// All clones of the same Raft instance created by `Raft::clone()` must be
/// dropped before `RaftJoinHandle::join()` can return.
///
/// After `RaftJoinHandle::join()` returns, all threads and thread pools created
/// by this Raft instance will have stopped. No callbacks will be called. No new
/// commits will be created by this Raft instance.
#[must_use]
pub struct RaftJoinHandle {
    // Released as clones of the Raft instance are dropped; join() waits on it.
    stop_wait_group: WaitGroup,
    // The owned tokio runtime; shut down (with a timeout) during join().
    thread_pool: tokio::runtime::Runtime,
    // Tracks the daemon threads; join() waits for all of them to exit.
    daemon_watch: DaemonWatch,
    // Shut down last: tasks on the pool may still log errors into it.
    daemon_env: DaemonEnv,
}
impl RaftJoinHandle {
    // Allow the thread pool two heartbeat intervals to drain before forcing
    // shutdown.
    const SHUTDOWN_TIMEOUT: Duration =
        Duration::from_millis(HEARTBEAT_INTERVAL.as_millis() as u64 * 2);
    /// Waits for the Raft instance to shutdown.
    ///
    /// See the struct documentation for more details.
    pub fn join(self) {
        // Wait for all Raft instances to be dropped.
        self.stop_wait_group.wait();
        // Then wait for every daemon thread to exit.
        self.daemon_watch.wait_for_daemons();
        self.thread_pool.shutdown_timeout(Self::SHUTDOWN_TIMEOUT);
        // DaemonEnv must be shutdown after the thread pool, since there might
        // be tasks logging errors in the pool.
        self.daemon_env.shutdown();
    }
}
  251. #[cfg(test)]
  252. mod tests {
  253. use crate::utils::do_nothing::{DoNothingPersister, DoNothingRemoteRaft};
  254. use crate::ApplyCommandMessage;
  255. use super::*;
  256. #[test]
  257. fn test_raft_must_sync() {
  258. let optional_raft: Option<super::Raft<i32>> = None;
  259. fn must_sync<T: Sync>(value: T) {
  260. drop(value)
  261. }
  262. must_sync(optional_raft)
  263. // The following raft is not Sync.
  264. // let optional_raft: Option<super::Raft<std::rc::Rc<i32>>> = None;
  265. }
  266. #[test]
  267. fn test_no_me_in_peers() {
  268. let peer_size = 5;
  269. let me = 2;
  270. let raft = Raft::new(
  271. vec![DoNothingRemoteRaft {}; peer_size],
  272. me,
  273. DoNothingPersister {},
  274. |_: ApplyCommandMessage<i32>| {},
  275. None,
  276. |_| {},
  277. );
  278. assert_eq!(4, raft.peers.len());
  279. for peer in &raft.peers {
  280. assert_ne!(peer.0, me);
  281. }
  282. }
  283. }