Explorar el Código

Use Arc to save some clone time.

Jing Yang hace 5 años
padre
commit
7850d934df
Se han modificado 3 ficheros con 28 adiciones y 26 borrados
  1. 16 16
      src/lib.rs
  2. 10 8
      src/rpcs.rs
  3. 2 2
      src/utils.rs

+ 16 - 16
src/lib.rs

@@ -79,7 +79,7 @@ struct ElectionState {
 #[derive(Clone)]
 pub struct Raft {
     inner_state: Arc<Mutex<RaftState>>,
-    peers: Vec<RpcClient>,
+    peers: Vec<Arc<RpcClient>>,
 
     me: Peer,
 
@@ -178,6 +178,7 @@ impl Raft {
             .max_threads(peer_size * 2)
             .build()
             .expect("Creating thread pool should not fail");
+        let peers = peers.into_iter().map(|r| Arc::new(r)).collect();
         let mut this = Raft {
             inner_state: Arc::new(Mutex::new(state)),
             peers,
@@ -459,13 +460,13 @@ impl Raft {
 
     const REQUEST_VOTE_RETRY: usize = 1;
     async fn request_vote(
-        rpc_client: RpcClient,
+        rpc_client: Arc<RpcClient>,
         args: RequestVoteArgs,
     ) -> Option<bool> {
         let term = args.term;
         let reply =
-            retry_rpc(Self::REQUEST_VOTE_RETRY, RPC_DEADLINE, move |_round| {
-                rpc_client.clone().call_request_vote(args.clone())
+            retry_rpc(Self::REQUEST_VOTE_RETRY, RPC_DEADLINE, |_round| {
+                rpc_client.call_request_vote(args.clone())
             })
             .await;
         if let Ok(reply) = reply {
@@ -586,11 +587,11 @@ impl Raft {
 
     const HEARTBEAT_RETRY: usize = 1;
     async fn send_heartbeat(
-        rpc_client: RpcClient,
+        rpc_client: Arc<RpcClient>,
         args: AppendEntriesArgs,
     ) -> std::io::Result<()> {
-        retry_rpc(Self::HEARTBEAT_RETRY, RPC_DEADLINE, move |_round| {
-            rpc_client.clone().call_append_entries(args.clone())
+        retry_rpc(Self::HEARTBEAT_RETRY, RPC_DEADLINE, |_round| {
+            rpc_client.call_append_entries(args.clone())
         })
         .await?;
         Ok(())
@@ -639,7 +640,7 @@ impl Raft {
 
     async fn sync_log_entry(
         rf: Arc<Mutex<RaftState>>,
-        rpc_client: RpcClient,
+        rpc_client: Arc<RpcClient>,
         peer_index: usize,
         rerun: std::sync::mpsc::Sender<Option<Peer>>,
         openings: Arc<Vec<AtomicUsize>>,
@@ -656,7 +657,7 @@ impl Raft {
         };
         let term = args.term;
         let match_index = args.prev_log_index + args.entries.len();
-        let succeeded = Self::append_entries(rpc_client, args).await;
+        let succeeded = Self::append_entries(&rpc_client, args).await;
         match succeeded {
             Ok(Some(true)) => {
                 let mut rf = rf.lock();
@@ -737,16 +738,15 @@ impl Raft {
 
     const APPEND_ENTRIES_RETRY: usize = 1;
     async fn append_entries(
-        rpc_client: RpcClient,
+        rpc_client: &RpcClient,
         args: AppendEntriesArgs,
     ) -> std::io::Result<Option<bool>> {
         let term = args.term;
-        let reply = retry_rpc(
-            Self::APPEND_ENTRIES_RETRY,
-            RPC_DEADLINE,
-            move |_round| rpc_client.clone().call_append_entries(args.clone()),
-        )
-        .await?;
+        let reply =
+            retry_rpc(Self::APPEND_ENTRIES_RETRY, RPC_DEADLINE, |_round| {
+                rpc_client.call_append_entries(args.clone())
+            })
+            .await?;
         Ok(if reply.term == term {
             Some(reply.success)
         } else {

+ 10 - 8
src/rpcs.rs

@@ -44,7 +44,6 @@ impl RpcHandler for AppendEntriesRpcHandler {
 pub(crate) const REQUEST_VOTE_RPC: &str = "Raft.RequestVote";
 pub(crate) const APPEND_ENTRIES_RPC: &str = "Raft.AppendEntries";
 
-#[derive(Clone)]
 pub struct RpcClient(Client);
 
 impl RpcClient {
@@ -53,7 +52,7 @@ impl RpcClient {
     }
 
     pub(crate) async fn call_request_vote(
-        self: Self,
+        &self,
         request: RequestVoteArgs,
     ) -> std::io::Result<RequestVoteReply> {
         let data = RequestMessage::from(
@@ -68,7 +67,7 @@ impl RpcClient {
     }
 
     pub(crate) async fn call_append_entries(
-        self: Self,
+        &self,
         request: AppendEntriesArgs,
     ) -> std::io::Result<AppendEntriesReply> {
         let data = RequestMessage::from(
@@ -138,12 +137,16 @@ mod tests {
                 .make_client("test-basic-message", name.to_owned());
 
             let raft = Arc::new(Raft::new(
-                vec![RpcClient(client.clone())],
+                vec![RpcClient(client)],
                 0,
                 Arc::new(()),
                 |_, _| {},
             ));
             register_server(raft, name, network.as_ref())?;
+
+            let client = network
+                .lock()
+                .make_client("test-basic-message", name.to_owned());
             client
         };
 
@@ -155,9 +158,8 @@ mod tests {
             last_log_index: 0,
             last_log_term: Term(0),
         };
-        let response = futures::executor::block_on(
-            rpc_client.clone().call_request_vote(request),
-        )?;
+        let response =
+            futures::executor::block_on(rpc_client.call_request_vote(request))?;
         assert_eq!(true, response.vote_granted);
 
         let request = AppendEntriesArgs {
@@ -169,7 +171,7 @@ mod tests {
             leader_commit: 0,
         };
         let response = futures::executor::block_on(
-            rpc_client.clone().call_append_entries(request),
+            rpc_client.call_append_entries(request),
         )?;
         assert_eq!(2021, response.term.0);
         assert_eq!(true, response.success);

+ 2 - 2
src/utils.rs

@@ -1,13 +1,13 @@
 use std::future::Future;
 use std::time::Duration;
 
-pub async fn retry_rpc<Func, Fut, T>(
+pub async fn retry_rpc<'a, Func, Fut, T>(
     max_retry: usize,
     deadline: Duration,
     mut task_gen: Func,
 ) -> std::io::Result<T>
 where
-    Fut: Future<Output = std::io::Result<T>> + Send + 'static,
+    Fut: Future<Output = std::io::Result<T>> + Send + 'a,
     Func: FnMut(usize) -> Fut,
 {
     for i in 0..max_retry {