
Add rpc_client and beat_ticker to the context.

Jing Yang, 3 years ago
parent commit 7329e47622
8 changed files with 146 additions and 77 deletions
  1. src/election.rs (+9 -16)
  2. src/heartbeats.rs (+8 -25)
  3. src/lib.rs (+1 -0)
  4. src/raft.rs (+19 -7)
  5. src/remote_context.rs (+68 -6)
  6. src/remote_peer.rs (+27 -0)
  7. src/sync_log_entries.rs (+14 -19)
  8. src/verify_authority.rs (+0 -4)
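
In short: the per-peer RPC client and heartbeat beat ticker move out of the argument lists of spawned futures and into the thread-local RemoteContext, keyed by Peer. A minimal sketch of the lookup pattern this enables, assuming a context has been installed on the current thread via RemoteContext::set_context (as the test at the bottom of src/remote_context.rs does) and that Peer is a Copy usize newtype, as its by-value uses in the diff suggest:

    // Illustration only, not part of the commit: with a context installed,
    // both per-peer handles are one lookup away from a plain Peer id.
    fn peer_handles<Command: 'static>(peer: Peer) {
        // &'static dyn RemoteRaft<Command>: free to move into spawned futures.
        let rpc_client = RemoteContext::<Command>::rpc_client(peer);
        // &'static DaemonBeatTicker for the same peer index.
        let beat_ticker = RemoteContext::<Command>::beat_ticker(peer);
        let _ = (rpc_client, beat_ticker);
    }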

+ 9 - 16
src/election.rs

@@ -10,8 +10,8 @@ use crate::sync_log_entries::SyncLogEntriesComms;
 use crate::utils::{retry_rpc, RPC_DEADLINE};
 use crate::verify_authority::VerifyAuthorityDaemon;
 use crate::{
-    Peer, Persister, Raft, RaftState, RemoteRaft, ReplicableCommand,
-    RequestVoteArgs, State, Term,
+    Peer, Persister, Raft, RaftState, ReplicableCommand, RequestVoteArgs,
+    State, Term,
 };
 
 struct VersionedDeadline {
@@ -253,15 +253,11 @@ impl<Command: ReplicableCommand> Raft<Command> {
         };
 
         let mut votes = vec![];
-        for (index, rpc_client) in self.peers.iter().enumerate() {
-            if index != self.me.0 {
-                // RpcClient must be cloned so that it lives long enough for
-                // spawn(), which requires a 'static lifetime.
-                // RPCs are started right away.
-                let one_vote = self.thread_pool.spawn(Self::request_vote(
-                    rpc_client.clone(),
-                    args.clone(),
-                ));
+        for peer in self.peers.clone().into_iter() {
+            if peer != self.me {
+                let one_vote = self
+                    .thread_pool
+                    .spawn(Self::request_vote(peer, args.clone()));
                 votes.push(one_vote);
             }
         }
@@ -282,13 +278,10 @@ impl<Command: ReplicableCommand> Raft<Command> {
     }
 
     const REQUEST_VOTE_RETRY: usize = 1;
-    async fn request_vote(
-        rpc_client: impl RemoteRaft<Command>,
-        args: RequestVoteArgs,
-    ) -> Option<bool> {
+    async fn request_vote(peer: Peer, args: RequestVoteArgs) -> Option<bool> {
         let term = args.term;
         // See the comment in send_heartbeat() for this override.
-        let rpc_client = &rpc_client;
+        let rpc_client = RemoteContext::<Command>::rpc_client(peer);
         let reply =
             retry_rpc(Self::REQUEST_VOTE_RETRY, RPC_DEADLINE, move |_round| {
                 rpc_client.request_vote(args.clone())
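
The payoff in election.rs: requesting a vote now needs only the Peer id plus the args, so nothing heavyweight is cloned into the spawned future. As a hypothetical illustration (this helper is not in the commit; the name and bounds are assumptions), a retry-free variant of the same pattern:

    // Hypothetical sketch; only rpc_client() and request_vote() come from the diff.
    async fn one_shot_vote<Command: 'static + Send>(
        peer: Peer,
        args: RequestVoteArgs,
    ) -> std::io::Result<RequestVoteReply> {
        RemoteContext::<Command>::rpc_client(peer)
            .request_vote(args)
            .await
    }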

+ 8 - 25
src/heartbeats.rs

@@ -6,10 +6,7 @@ use parking_lot::Mutex;
 
 use crate::remote_context::RemoteContext;
 use crate::utils::{retry_rpc, RPC_DEADLINE};
-use crate::verify_authority::DaemonBeatTicker;
-use crate::{
-    AppendEntriesArgs, Raft, RaftState, RemoteRaft, ReplicableCommand,
-};
+use crate::{AppendEntriesArgs, Peer, Raft, RaftState, ReplicableCommand};
 
 pub(crate) const HEARTBEAT_INTERVAL: Duration = Duration::from_millis(150);
 
@@ -70,17 +67,12 @@ impl<Command: ReplicableCommand> Raft<Command> {
     /// The request message is a stripped down version of `AppendEntries`. The
     /// response from the peer is ignored.
     pub(crate) fn schedule_heartbeats(&self, interval: Duration) {
-        for (peer_index, rpc_client) in self.peers.iter().enumerate() {
-            if peer_index != self.me.0 {
+        for peer in self.peers.clone().into_iter() {
+            if peer != self.me {
                 // rf is now owned by the outer async function.
                 let rf = self.inner_state.clone();
-                // A function that casts an "authoritative" vote with Ok()
-                // responses to heartbeats.
-                let beat_ticker = self.beat_ticker(peer_index);
                // An on-demand trigger for sending a heartbeat.
                 let mut trigger = self.heartbeats_daemon.sender.subscribe();
-                // RPC client must be cloned into the outer async function.
-                let rpc_client = rpc_client.clone();
                 // Shutdown signal.
                 let keep_running = self.keep_running.clone();
                 self.thread_pool.spawn(async move {
@@ -92,11 +84,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
                         let _ =
                             futures_util::future::select(tick, trigger).await;
                         if let Some(args) = Self::build_heartbeat(&rf) {
-                            tokio::spawn(Self::send_heartbeat(
-                                rpc_client.clone(),
-                                args,
-                                beat_ticker.clone(),
-                            ));
+                            tokio::spawn(Self::send_heartbeat(peer, args));
                         }
                     }
                 });
@@ -127,17 +115,12 @@ impl<Command: ReplicableCommand> Raft<Command> {
 
     const HEARTBEAT_RETRY: usize = 1;
     async fn send_heartbeat(
-        // Here rpc_client must be owned by the returned future. The returned
-        // future is scheduled to run on a thread pool. We do not control when
-        // the future will be run, or when it will be done with the RPC client.
-        // If a reference is passed in, the reference essentially has to be a
-        // static one, i.e. lives forever. Thus we chose to let the future own
-        // the RPC client.
-        rpc_client: impl RemoteRaft<Command>,
+        peer: Peer,
         args: AppendEntriesArgs<Command>,
-        beat_ticker: DaemonBeatTicker,
     ) -> std::io::Result<()> {
         let term = args.term;
+        let beat_ticker = RemoteContext::<Command>::beat_ticker(peer);
+
         let beat = beat_ticker.next_beat();
         // Passing a reference that is moved to the following closure.
         //
@@ -153,7 +136,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
         // Another option is to use non-move closures, in which case rpc_client
         // of type Arc can be passed-in directly. However that requires args to
        // be Sync because they can be shared by more than one future.
-        let rpc_client = &rpc_client;
+        let rpc_client = RemoteContext::<Command>::rpc_client(peer);
         let response =
             retry_rpc(Self::HEARTBEAT_RETRY, RPC_DEADLINE, move |_round| {
                 rpc_client.append_entries(args.clone())
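
The long ownership comment deleted from send_heartbeat is exactly what this refactor retires: because RemoteContext::rpc_client returns a &'static trait object, the per-heartbeat future captures only the small Peer value and the args. A condensed sketch of the new resolution path (the name, bounds, and elided reply handling are assumptions):

    // Sketch, not the literal commit code.
    async fn send_heartbeat_sketch<Command: 'static + Send>(
        peer: Peer,
        args: AppendEntriesArgs<Command>,
    ) -> std::io::Result<()> {
        // Both handles are 'static borrows of the thread-local context.
        let beat_ticker = RemoteContext::<Command>::beat_ticker(peer);
        let rpc_client = RemoteContext::<Command>::rpc_client(peer);
        let _beat = beat_ticker.next_beat();
        // Single attempt here; the real code wraps the call in retry_rpc().
        rpc_client.append_entries(args).await.map(|_reply| ())
    }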

+ 1 - 0
src/lib.rs

@@ -30,6 +30,7 @@ mod process_request_vote;
 mod raft;
 mod raft_state;
 mod remote_context;
+mod remote_peer;
 mod remote_raft;
 mod replicable_command;
 mod snapshot;

+ 19 - 7
src/raft.rs

@@ -13,6 +13,7 @@ use crate::election::ElectionState;
 use crate::heartbeats::{HeartbeatsDaemon, HEARTBEAT_INTERVAL};
 use crate::persister::PersistedRaftState;
 use crate::remote_context::RemoteContext;
+use crate::remote_peer::RemotePeer;
 use crate::snapshot::{RequestSnapshotFnMut, SnapshotDaemon};
 use crate::sync_log_entries::SyncLogEntriesComms;
 use crate::term_marker::TermMarker;
@@ -30,7 +31,7 @@ pub struct Peer(pub usize);
 #[derive(Clone)]
 pub struct Raft<Command> {
     pub(crate) inner_state: Arc<Mutex<RaftState<Command>>>,
-    pub(crate) peers: Vec<Arc<dyn RemoteRaft<Command>>>,
+    pub(crate) peers: Vec<Peer>,
 
     pub(crate) me: Peer,
 
@@ -99,7 +100,21 @@ impl<Command: ReplicableCommand> Raft<Command> {
             election.clone(),
             persister.clone(),
         );
-        let context = RemoteContext::create(term_marker);
+
+        let verify_authority_daemon = VerifyAuthorityDaemon::create(peer_size);
+        let remote_peers = peers
+            .into_iter()
+            .enumerate()
+            .map(|(index, remote_raft)| {
+                RemotePeer::create(
+                    Peer(index),
+                    remote_raft,
+                    verify_authority_daemon.beat_ticker(index),
+                )
+            })
+            .collect();
+
+        let context = RemoteContext::create(term_marker, remote_peers);
 
         let daemon_env = DaemonEnv::create();
         let thread_env = daemon_env.for_thread();
@@ -118,10 +133,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
             })
             .build()
             .expect("Creating thread pool should not fail");
-        let peers = peers
-            .into_iter()
-            .map(|r| Arc::new(r) as Arc<dyn RemoteRaft<Command>>)
-            .collect();
+        let peers = (0..peer_size).map(Peer).collect();
         let (sync_log_entries_comms, sync_log_entries_daemon) =
             crate::sync_log_entries::create(peer_size);
 
@@ -135,7 +147,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
             keep_running: Arc::new(AtomicBool::new(true)),
             election,
             snapshot_daemon: SnapshotDaemon::create(),
-            verify_authority_daemon: VerifyAuthorityDaemon::create(peer_size),
+            verify_authority_daemon,
             heartbeats_daemon: HeartbeatsDaemon::create(),
             thread_pool: thread_pool.handle().clone(),
             stop_wait_group: WaitGroup::new(),
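
Note the construction order in raft.rs: the VerifyAuthorityDaemon is created first because each RemotePeer eagerly captures its beat_ticker(index); only then is the context assembled, and the Raft struct itself keeps nothing but lightweight Peer ids. Condensed from the hunks above (peers is the incoming Vec of RPC client implementations):

    // Condensed restatement of the diff, not new behavior.
    let verify_authority_daemon = VerifyAuthorityDaemon::create(peer_size);
    let remote_peers = peers
        .into_iter()
        .enumerate()
        .map(|(index, remote_raft)| {
            RemotePeer::create(
                Peer(index),
                remote_raft,
                verify_authority_daemon.beat_ticker(index),
            )
        })
        .collect();
    let context = RemoteContext::create(term_marker, remote_peers);
    // The old Vec<Arc<dyn RemoteRaft<Command>>> field becomes Vec<Peer>.
    let peers: Vec<Peer> = (0..peer_size).map(Peer).collect();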

+ 68 - 6
src/remote_context.rs

@@ -1,22 +1,44 @@
 use std::any::Any;
 use std::cell::RefCell;
 
+use crate::remote_peer::RemotePeer;
 use crate::term_marker::TermMarker;
+use crate::verify_authority::DaemonBeatTicker;
+use crate::{Peer, RemoteRaft};
 
 #[derive(Clone)]
 pub(crate) struct RemoteContext<Command> {
     term_marker: TermMarker<Command>,
+    remote_peers: Vec<RemotePeer<Command, Peer>>,
 }
 
 impl<Command: 'static> RemoteContext<Command> {
-    pub fn create(term_marker: TermMarker<Command>) -> Self {
-        Self { term_marker }
+    pub fn create(
+        term_marker: TermMarker<Command>,
+        remote_peers: Vec<RemotePeer<Command, Peer>>,
+    ) -> Self {
+        Self {
+            term_marker,
+            remote_peers,
+        }
     }
 
     pub fn term_marker() -> &'static TermMarker<Command> {
         &Self::fetch_context().term_marker
     }
 
+    pub fn remote_peer(peer: Peer) -> &'static RemotePeer<Command, Peer> {
+        &Self::fetch_context().remote_peers[peer.0]
+    }
+
+    pub fn rpc_client(peer: Peer) -> &'static dyn RemoteRaft<Command> {
+        Self::remote_peer(peer).rpc_client.as_ref()
+    }
+
+    pub fn beat_ticker(peer: Peer) -> &'static DaemonBeatTicker {
+        &Self::remote_peer(peer).beat_ticker
+    }
+
     thread_local! {
         // Using Any to mask the fact that we are storing a generic struct.
         static REMOTE_CONTEXT: RefCell<Option<&'static dyn Any>> = RefCell::new(None);
@@ -52,13 +74,21 @@ impl<Command: 'static> RemoteContext<Command> {
 
 #[cfg(test)]
 mod tests {
+    use std::sync::Arc;
+
+    use async_trait::async_trait;
     use bytes::Bytes;
     use parking_lot::Mutex;
-    use std::sync::Arc;
 
     use crate::election::ElectionState;
+    use crate::remote_peer::RemotePeer;
     use crate::term_marker::TermMarker;
-    use crate::{Peer, Persister, RaftState};
+    use crate::verify_authority::VerifyAuthorityDaemon;
+    use crate::{
+        AppendEntriesArgs, AppendEntriesReply, InstallSnapshotArgs,
+        InstallSnapshotReply, Peer, Persister, RaftState, RemoteRaft,
+        RequestVoteArgs, RequestVoteReply,
+    };
 
     use super::RemoteContext;
 
@@ -77,14 +107,46 @@ mod tests {
         fn save_snapshot_and_state(&self, _: Bytes, _: &[u8]) {}
     }
 
+    struct DoNothingRemoteRaft;
+    #[async_trait]
+    impl<Command: 'static + Send> RemoteRaft<Command> for DoNothingRemoteRaft {
+        async fn request_vote(
+            &self,
+            _args: RequestVoteArgs,
+        ) -> std::io::Result<RequestVoteReply> {
+            unimplemented!()
+        }
+
+        async fn append_entries(
+            &self,
+            _args: AppendEntriesArgs<Command>,
+        ) -> std::io::Result<AppendEntriesReply> {
+            unimplemented!()
+        }
+
+        async fn install_snapshot(
+            &self,
+            _args: InstallSnapshotArgs,
+        ) -> std::io::Result<InstallSnapshotReply> {
+            unimplemented!()
+        }
+    }
+
     #[test]
     fn test_context_api() {
         let rf = Arc::new(Mutex::new(RaftState::<i32>::create(1, Peer(0))));
         let election = Arc::new(ElectionState::create());
+        let verify_authority_daemon = VerifyAuthorityDaemon::create(1);
         let term_marker =
             TermMarker::create(rf, election, Arc::new(DoNothingPersister));
-
-        let context = Box::new(RemoteContext::create(term_marker));
+        let remote_peer = RemotePeer::create(
+            Peer(0),
+            DoNothingRemoteRaft,
+            verify_authority_daemon.beat_ticker(0),
+        );
+
+        let context =
+            Box::new(RemoteContext::create(term_marker, vec![remote_peer]));
         let context_ptr: *const RemoteContext<i32> = &*context;
 
         RemoteContext::set_context(context);
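
The test above doubles as a usage demo for the install side of the API: the context is boxed, handed to set_context (which parks it in the type-erased &'static dyn Any thread-local), and queried statically afterwards. In sketch form, reusing the i32 command type from the test:

    // Per-thread installation, as exercised by test_context_api above.
    let context = Box::new(RemoteContext::create(term_marker, vec![remote_peer]));
    RemoteContext::set_context(context);
    // From here on, code running on this thread can resolve peer handles:
    let rpc_client = RemoteContext::<i32>::rpc_client(Peer(0));
    let beat_ticker = RemoteContext::<i32>::beat_ticker(Peer(0));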

+ 27 - 0
src/remote_peer.rs

@@ -0,0 +1,27 @@
+use std::sync::Arc;
+
+use crate::verify_authority::DaemonBeatTicker;
+use crate::RemoteRaft;
+
+#[derive(Clone)]
+pub(crate) struct RemotePeer<Command, UniqueId> {
+    #[allow(dead_code)]
+    pub unique_id: UniqueId,
+    pub rpc_client: Arc<dyn RemoteRaft<Command>>,
+    pub beat_ticker: DaemonBeatTicker,
+}
+
+impl<Command, UniqueId> RemotePeer<Command, UniqueId> {
+    pub fn create<RpcClient: 'static + RemoteRaft<Command>>(
+        unique_id: UniqueId,
+        rpc_client: RpcClient,
+        beat_ticker: DaemonBeatTicker,
+    ) -> Self {
+        let rpc_client = Arc::new(rpc_client);
+        RemotePeer {
+            unique_id,
+            rpc_client,
+            beat_ticker,
+        }
+    }
+}
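
RemotePeer::create erases the concrete client type behind Arc<dyn RemoteRaft<Command>>, so any implementation can be handed in; the currently unused unique_id (note the #[allow(dead_code)]) leaves room for identifiers other than Peer. Usage as in the new test in src/remote_context.rs:

    // From the test: DoNothingRemoteRaft is just some impl of RemoteRaft.
    let remote_peer = RemotePeer::create(
        Peer(0),
        DoNothingRemoteRaft,
        verify_authority_daemon.beat_ticker(0),
    );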

+ 14 - 19
src/sync_log_entries.rs

@@ -8,10 +8,9 @@ use crate::heartbeats::HEARTBEAT_INTERVAL;
 use crate::peer_progress::PeerProgress;
 use crate::remote_context::RemoteContext;
 use crate::utils::{retry_rpc, SharedSender, RPC_DEADLINE};
-use crate::verify_authority::DaemonBeatTicker;
 use crate::{
     check_or_record, AppendEntriesArgs, Index, IndexTerm, InstallSnapshotArgs,
-    Peer, Raft, RaftState, RemoteRaft, ReplicableCommand, Term,
+    Peer, Raft, RaftState, ReplicableCommand, Term,
 };
 
 #[derive(Eq, PartialEq)]
@@ -132,9 +131,9 @@ impl<Command: ReplicableCommand> Raft<Command> {
                 if !this.inner_state.lock().is_leader() {
                     continue;
                 }
-                for (i, rpc_client) in this.peers.iter().enumerate() {
-                    if i != this.me.0 && event.should_schedule(Peer(i)) {
-                        let progress = &peer_progress[i];
+                for peer in this.peers.clone().into_iter() {
+                    if peer != this.me && event.should_schedule(peer) {
+                        let progress = &peer_progress[peer.0];
                         if let Event::NewTerm(_term, index) = event {
                             progress.reset_progress(index);
                         }
@@ -144,11 +143,9 @@ impl<Command: ReplicableCommand> Raft<Command> {
                             task_number += 1;
                             this.thread_pool.spawn(Self::sync_log_entries(
                                 this.inner_state.clone(),
-                                rpc_client.clone(),
                                 this.sync_log_entries_comms.clone(),
                                 progress.clone(),
                                 this.apply_command_signal.clone(),
-                                this.beat_ticker(i),
                                 TaskNumber(task_number),
                             ));
                         }
@@ -200,14 +197,11 @@ impl<Command: ReplicableCommand> Raft<Command> {
     /// failure of the last case, we will never hit the other failure again,
     /// since in the last case we always sync a log entry at a committed index,
     /// and a committed log entry can never diverge.
-    #[allow(clippy::too_many_arguments)]
     async fn sync_log_entries(
         rf: Arc<Mutex<RaftState<Command>>>,
-        rpc_client: impl RemoteRaft<Command>,
         comms: SyncLogEntriesComms,
         progress: PeerProgress,
         apply_command_signal: Arc<Condvar>,
-        beat_ticker: DaemonBeatTicker,
         task_number: TaskNumber,
     ) {
         if !progress.take_task() {
@@ -222,8 +216,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
                 let term = args.term;
                 let prev_log_index = args.prev_log_index;
                 let match_index = args.prev_log_index + args.entries.len();
-                let succeeded =
-                    Self::append_entries(&rpc_client, args, beat_ticker).await;
+                let succeeded = Self::append_entries(peer, args).await;
 
                 (term, prev_log_index, match_index, succeeded)
             }
@@ -231,9 +224,7 @@ impl<Command: ReplicableCommand> Raft<Command> {
                 let term = args.term;
                 let prev_log_index = args.last_included_index;
                 let match_index = args.last_included_index;
-                let succeeded =
-                    Self::install_snapshot(&rpc_client, args, beat_ticker)
-                        .await;
+                let succeeded = Self::install_snapshot(peer, args).await;
 
                 (term, prev_log_index, match_index, succeeded)
             }
@@ -477,10 +468,12 @@ impl<Command: ReplicableCommand> Raft<Command> {
 
     const APPEND_ENTRIES_RETRY: usize = 1;
     async fn append_entries(
-        rpc_client: &dyn RemoteRaft<Command>,
+        peer: Peer,
         args: AppendEntriesArgs<Command>,
-        beat_ticker: DaemonBeatTicker,
     ) -> std::io::Result<SyncLogEntriesResult> {
+        let beat_ticker = RemoteContext::<Command>::beat_ticker(peer);
+        let rpc_client = RemoteContext::<Command>::rpc_client(peer);
+
         let term = args.term;
         let beat = beat_ticker.next_beat();
         let reply = retry_rpc(
@@ -520,10 +513,12 @@ impl<Command: ReplicableCommand> Raft<Command> {
 
     const INSTALL_SNAPSHOT_RETRY: usize = 1;
     async fn install_snapshot(
-        rpc_client: &dyn RemoteRaft<Command>,
+        peer: Peer,
         args: InstallSnapshotArgs,
-        beat_ticker: DaemonBeatTicker,
     ) -> std::io::Result<SyncLogEntriesResult> {
+        let beat_ticker = RemoteContext::<Command>::beat_ticker(peer);
+        let rpc_client = RemoteContext::<Command>::rpc_client(peer);
+
         let term = args.term;
         let beat = beat_ticker.next_beat();
         let reply = retry_rpc(
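
Both RPC wrappers in this file now share the same two-lookup prologue, which is also why the #[allow(clippy::too_many_arguments)] on sync_log_entries could be dropped above. A hypothetical retry-free variant for the snapshot path, mirroring the heartbeat sketch earlier (name, bounds, and elisions are assumptions):

    // Hypothetical sketch; the real code converts the reply into a
    // SyncLogEntriesResult and wraps the call in retry_rpc().
    async fn install_snapshot_sketch<Command: 'static + Send>(
        peer: Peer,
        args: InstallSnapshotArgs,
    ) -> std::io::Result<InstallSnapshotReply> {
        let _beat = RemoteContext::<Command>::beat_ticker(peer).next_beat();
        RemoteContext::<Command>::rpc_client(peer)
            .install_snapshot(args)
            .await
    }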

+ 0 - 4
src/verify_authority.rs

@@ -400,10 +400,6 @@ impl<Command: 'static + Send> Raft<Command> {
                 .expect("Verify authority daemon never drops senders")
         })
     }
-
-    pub(crate) fn beat_ticker(&self, peer_index: usize) -> DaemonBeatTicker {
-        self.verify_authority_daemon.beat_ticker(peer_index)
-    }
 }
 
 #[cfg(test)]