
First draft, contains server and network.

Jing Yang 5 years ago
parent
commit
a2d6c7ef16
4 changed files with 243 additions and 4 deletions
  1. Cargo.toml (+2 -2)
  2. rustfmt.toml (+1 -0)
  3. src/lib.rs (+129 -2)
  4. src/server.rs (+111 -0)

+ 2 - 2
Cargo.toml

@@ -4,6 +4,6 @@ version = "0.1.0"
 authors = ["Jing Yang <ditsing@gmail.com>"]
 edition = "2018"
 
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
 [dependencies]
+bytes = "0.5.6"
+futures = { version = "0.3.5", features = ["thread-pool"] }

+ 1 - 0
rustfmt.toml

@@ -0,0 +1 @@
+max_width = 80

+ 129 - 2
src/lib.rs

@@ -1,7 +1,134 @@
+extern crate bytes;
+extern crate futures;
+
+mod server;
+
+use crate::server::Server;
+use bytes::Bytes;
+use std::collections::HashMap;
+use std::sync::Arc;
+
+type Result<T> = std::io::Result<T>;
+
+// Messages passed over the network.
+struct RequestMessage<'a> {
+    service_method: String,
+    arg: &'a [u8],
+}
+
+type ReplyMessage = Bytes;
+
+type ServerIdentifier = String;
+type ClientIdentifier = String;
+
+// Client interface, used by the RPC client.
+struct Client {
+    client: ClientIdentifier,
+    server: ServerIdentifier,
+    // Closing signal.
+}
+
+struct Network {
+    // TODO: needs a lock field.
+    // Settings.
+    reliable: bool,
+    long_delays: bool,
+    long_reordering: bool,
+
+    // Clients
+    clients: HashMap<ClientIdentifier, (bool, ServerIdentifier)>,
+    servers: HashMap<ServerIdentifier, Arc<Server>>,
+
+    // Closing signal.
+
+    // RPC counter, using `Cell` for interior mutability.
+    rpc_count: std::cell::Cell<usize>,
+}
+
+impl Network {
+    pub fn cleanup(self) {
+        unimplemented!()
+    }
+
+    pub fn set_reliable(&mut self, yes: bool) {
+        self.reliable = yes
+    }
+
+    pub fn set_long_reordering(&mut self, yes: bool) {
+        self.long_reordering = yes
+    }
+
+    pub fn set_long_delays(&mut self, yes: bool) {
+        self.long_delays = yes
+    }
+
+    pub fn make_connection(_server_name: ServerIdentifier) -> Client {
+        unimplemented!()
+    }
+
+    pub async fn dispatch(
+        &self,
+        client: ClientIdentifier,
+        request: RequestMessage<'_>,
+    ) -> Result<ReplyMessage> {
+        // TODO: acquire a lock.
+        self.increase_rpc_count();
+        let server_name = self
+            .clients
+            .get(&client)
+            .map(|(enabled, server)| if !enabled { None } else { Some(server) })
+            .flatten()
+            .ok_or_else(|| {
+                std::io::Error::new(
+                    std::io::ErrorKind::NotConnected,
+                    format!("Client {} is not connected", client),
+                )
+            })?;
+        let server = self.servers.get(server_name).ok_or_else(|| {
+            std::io::Error::new(
+                std::io::ErrorKind::NotFound,
+                format!(
+                    "Cannot connect {} to server {}: server not found.",
+                    client, server_name,
+                ),
+            )
+        })?;
+        let data = Bytes::copy_from_slice(request.arg);
+        server.clone().dispatch(request.service_method, data).await
+    }
+
+    pub fn get_total_rpc_count(&self) -> usize {
+        self.rpc_count.get()
+    }
+}
+
+impl Network {
+    fn increase_rpc_count(&self) {
+        self.rpc_count.set(self.rpc_count.get() + 1)
+    }
+}
+
 #[cfg(test)]
 mod tests {
+    use super::*;
+
+    fn make_network() -> Network {
+        Network {
+            reliable: false,
+            long_delays: false,
+            long_reordering: false,
+            clients: Default::default(),
+            servers: Default::default(),
+            rpc_count: std::cell::Cell::new(0),
+        }
+    }
+
     #[test]
-    fn it_works() {
-        assert_eq!(2 + 2, 4);
+    fn rpc_count_works() {
+        let network = make_network();
+        assert_eq!(0, network.get_total_rpc_count());
+
+        network.increase_rpc_count();
+        assert_eq!(1, network.get_total_rpc_count());
     }
 }
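
Not part of the commit: a minimal sketch of how Network::dispatch could be
exercised from the tests module above. The "server-1"/"client-1" names and
the "Echo.echo" method are made up for illustration; since no handler is
registered, the call is expected to come back as an error.

    #[test]
    fn dispatch_reports_missing_handler() {
        use crate::server::Server;
        use std::sync::Arc;

        // Wire one server and one enabled client into a fresh network.
        let server = Arc::new(Server::make_server("server-1".to_string()));
        let mut network = make_network();
        network.servers.insert("server-1".to_string(), server);
        network
            .clients
            .insert("client-1".to_string(), (true, "server-1".to_string()));

        // Nothing is registered for "Echo.echo", so the server should
        // answer with an InvalidData error.
        let request = RequestMessage {
            service_method: "Echo.echo".to_string(),
            arg: b"ping",
        };
        let reply = futures::executor::block_on(
            network.dispatch("client-1".to_string(), request),
        );
        assert!(reply.is_err());
        assert_eq!(1, network.get_total_rpc_count());
    }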

+ 111 - 0
src/server.rs

@@ -0,0 +1,111 @@
+use crate::Result;
+use bytes::Bytes;
+use std::collections::hash_map::Entry::Vacant;
+use std::sync::Arc;
+
+pub trait RpcHandler {
+    // Note this method is not async.
+    fn call(&self, data: Bytes) -> Bytes;
+}
+
+struct ServerState {
+    rpc_handlers: std::collections::HashMap<String, Arc<Box<dyn RpcHandler>>>,
+    rpc_count: std::cell::Cell<usize>,
+}
+
+pub struct Server {
+    name: String,
+    state: std::sync::Mutex<ServerState>,
+    thread_pool: futures::executor::ThreadPool,
+}
+
+impl Unpin for Server {}
+// Server holds an immutable name, a mutex-protected state, and a thread
+// pool; the unsafe impls below assume `RpcHandler`s are thread-safe.
+unsafe impl Send for Server {}
+unsafe impl Sync for Server {}
+
+impl Server {
+    const THREAD_POOL_SIZE: usize = 4;
+    pub async fn dispatch(
+        self: Arc<Self>,
+        service_method: String,
+        data: Bytes,
+    ) -> Result<Bytes> {
+        let (tx, rx) = futures::channel::oneshot::channel();
+        let this = self.clone();
+        this.thread_pool.spawn_ok(async move {
+            let rpc_handler = {
+                // Blocking on a mutex in a thread pool. Sounds horrible, but
+                // in fact quite safe, given that the critical section is short.
+                let state = self
+                    .state
+                    .lock()
+                    .expect("The server state mutex should not be poisoned.");
+                state.rpc_count.set(state.rpc_count.get() + 1);
+                state.rpc_handlers.get(&service_method).map(|r| r.clone())
+            };
+            let response = match rpc_handler {
+                Some(rpc_handler) => Ok(rpc_handler.call(data)),
+                None => Err(std::io::Error::new(
+                    std::io::ErrorKind::InvalidData,
+                    format!(
+                        "Method {} on server {} not found.",
+                        service_method, self.name
+                    ),
+                )),
+            };
+            if let Err(_) = tx.send(response) {
+                // Receiving end is dropped. Never mind.
+                // Do nothing.
+            }
+        });
+        rx.await.map_err(|_e| {
+            std::io::Error::new(
+                std::io::ErrorKind::ConnectionReset,
+                format!("Remote server {} cancelled the RPC.", this.name),
+            )
+        })?
+    }
+
+    pub fn register_rpc_handler(
+        &mut self,
+        service_method: String,
+        rpc_handler: Box<dyn RpcHandler>,
+    ) -> Result<()> {
+        let mut state = self
+            .state
+            .lock()
+            .expect("The server state mutex should not be poisoned.");
+        let debug_service_method = service_method.clone();
+        if let Vacant(vacant) = state.rpc_handlers.entry(service_method) {
+            vacant.insert(Arc::new(rpc_handler));
+            Ok(())
+        } else {
+            Err(std::io::Error::new(
+                std::io::ErrorKind::AlreadyExists,
+                format!(
+                    "Service method {} already exists in server {}.",
+                    debug_service_method, self.name
+                ),
+            ))
+        }
+    }
+
+    pub fn make_server(name: String) -> Self {
+        let state = std::sync::Mutex::new(ServerState {
+            rpc_handlers: std::collections::HashMap::new(),
+            rpc_count: std::cell::Cell::new(0),
+        });
+        let thread_pool = futures::executor::ThreadPool::builder()
+            .name_prefix(name.clone())
+            .pool_size(Self::THREAD_POOL_SIZE)
+            .create()
+            .expect("Creating thread pools should not fail.");
+        Self {
+            name,
+            state,
+            thread_pool,
+        }
+    }
+}
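
Also not part of the commit: a usage sketch for the server above, registering
a hypothetical echo handler and driving a single RPC to completion with
futures::executor::block_on. The "echo-server" and "Echo.echo" names are made
up for illustration, and since the server module is private, code like this
would live inside the crate (for example in a test).

    use crate::server::{RpcHandler, Server};
    use bytes::Bytes;
    use std::sync::Arc;

    // A hypothetical handler that echoes the request payload back unchanged.
    struct EchoHandler;

    impl RpcHandler for EchoHandler {
        fn call(&self, data: Bytes) -> Bytes {
            data
        }
    }

    fn echo_round_trip() -> std::io::Result<()> {
        let mut server = Server::make_server("echo-server".to_string());
        server.register_rpc_handler(
            "Echo.echo".to_string(),
            Box::new(EchoHandler),
        )?;

        // dispatch consumes an Arc<Server> and runs the handler on the
        // server's own thread pool.
        let server = Arc::new(server);
        let reply = futures::executor::block_on(server.dispatch(
            "Echo.echo".to_string(),
            Bytes::from_static(b"ping"),
        ))?;
        assert_eq!(reply, Bytes::from_static(b"ping"));
        Ok(())
    }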