瀏覽代碼

Create a lazy wrapper around Raft.

Otherwise we would have to connect to other Raft instances before
starting the local one, which creates a dependency loop.
Jing Yang 4 年之前
父節點
當前提交
10f1e91521
共有 2 個文件被更改,包括 36 次插入28 次删除
  1. 30 17
      durio/src/raft_service.rs
  2. 6 11
      durio/src/run.rs

+ 30 - 17
durio/src/raft_service.rs

@@ -52,17 +52,36 @@ impl RaftService for RaftRpcServer {
     }
 }
 
-pub(crate) struct OptionalRaftServiceClient(Option<RaftServiceClient>);
+pub(crate) struct LazyRaftServiceClient {
+    socket_addr: SocketAddr,
+    once_cell: tokio::sync::OnceCell<RaftServiceClient>,
+}
+
+impl LazyRaftServiceClient {
+    pub(crate) fn new(socket_addr: SocketAddr) -> Self {
+        Self {
+            socket_addr,
+            once_cell: tokio::sync::OnceCell::new(),
+        }
+    }
+
+    pub(crate) async fn get_or_try_init(
+        &self,
+    ) -> std::io::Result<&RaftServiceClient> {
+        self.once_cell
+            .get_or_try_init(|| connect_to_raft_service(self.socket_addr))
+            .await
+    }
+}
 
 #[async_trait]
-impl RemoteRaft<UniqueKVOp> for OptionalRaftServiceClient {
+impl RemoteRaft<UniqueKVOp> for LazyRaftServiceClient {
     async fn request_vote(
         &self,
         args: RequestVoteArgs,
     ) -> std::io::Result<RequestVoteReply> {
-        self.0
-            .as_ref()
-            .unwrap()
+        self.get_or_try_init()
+            .await?
             .request_vote(Context::current(), args)
             .await
             .map_err(crate::utils::translate_rpc_error)
@@ -72,9 +91,8 @@ impl RemoteRaft<UniqueKVOp> for OptionalRaftServiceClient {
         &self,
         args: AppendEntriesArgs<UniqueKVOp>,
     ) -> std::io::Result<AppendEntriesReply> {
-        self.0
-            .as_ref()
-            .unwrap()
+        self.get_or_try_init()
+            .await?
             .append_entries(Context::current(), args)
             .await
             .map_err(crate::utils::translate_rpc_error)
@@ -84,22 +102,17 @@ impl RemoteRaft<UniqueKVOp> for OptionalRaftServiceClient {
         &self,
         args: InstallSnapshotArgs,
     ) -> std::io::Result<InstallSnapshotReply> {
-        self.0
-            .as_ref()
-            .unwrap()
+        self.get_or_try_init()
+            .await?
             .install_snapshot(Context::current(), args)
             .await
             .map_err(crate::utils::translate_rpc_error)
     }
 }
 
-pub(crate) fn no_raft_service() -> OptionalRaftServiceClient {
-    OptionalRaftServiceClient(None)
-}
-
 pub(crate) async fn connect_to_raft_service(
     addr: SocketAddr,
-) -> std::io::Result<OptionalRaftServiceClient> {
+) -> std::io::Result<RaftServiceClient> {
     let conn = tarpc::serde_transport::tcp::connect(
         addr,
         tokio_serde::formats::Json::default,
@@ -107,7 +120,7 @@ pub(crate) async fn connect_to_raft_service(
     .await?;
     let client =
         RaftServiceClient::new(tarpc::client::Config::default(), conn).spawn();
-    Ok(OptionalRaftServiceClient(Some(client)))
+    Ok(client)
 }
 
 pub(crate) fn start_raft_service_server(

+ 6 - 11
durio/src/run.rs

@@ -5,23 +5,18 @@ use kvraft::KVServer;
 
 use crate::kv_service::start_kv_service_server;
 use crate::persister::Persister;
-use crate::raft_service::{
-    connect_to_raft_service, no_raft_service, start_raft_service_server,
-};
+use crate::raft_service::{start_raft_service_server, LazyRaftServiceClient};
 
 pub(crate) async fn run_kv_instance(
     addr: SocketAddr,
     raft_peers: Vec<SocketAddr>,
     me: usize,
 ) -> std::io::Result<Arc<KVServer>> {
+    let local_raft_peer = raft_peers[me];
+
     let mut remote_rafts = vec![];
-    for (index, raft_peer) in raft_peers.iter().enumerate() {
-        let remote_raft = if index == me {
-            no_raft_service()
-        } else {
-            connect_to_raft_service(*raft_peer).await?
-        };
-        remote_rafts.push(remote_raft);
+    for raft_peer in raft_peers {
+        remote_rafts.push(LazyRaftServiceClient::new(raft_peer));
     }
 
     let persister = Arc::new(Persister::new());
@@ -29,7 +24,7 @@ pub(crate) async fn run_kv_instance(
     let kv_server = KVServer::new(remote_rafts, me, persister, None);
     let raft = Arc::new(kv_server.raft().clone());
 
-    start_raft_service_server(raft_peers[me], raft).await?;
+    start_raft_service_server(local_raft_peer, raft).await?;
     start_kv_service_server(addr, kv_server.clone()).await?;
 
     Ok(kv_server)