server.rs 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. use crate::{ReplyMessage, RequestMessage, Result};
  2. use std::collections::hash_map::Entry::Vacant;
  3. use std::sync::Arc;
  4. pub trait RpcHandler {
  5. // Note this method is not async.
  6. fn call(&self, data: RequestMessage) -> ReplyMessage;
  7. }
  8. struct ServerState {
  9. rpc_handlers: std::collections::HashMap<String, Arc<Box<dyn RpcHandler>>>,
  10. rpc_count: std::cell::Cell<usize>,
  11. }
  12. pub struct Server {
  13. name: String,
  14. state: std::sync::Mutex<ServerState>,
  15. thread_pool: futures::executor::ThreadPool,
  16. }
  17. impl Unpin for Server {}
  18. // Server contains a immutable name, a mutex-protected state, and a thread pool.
  19. // All of those 3 are `Send` and `Sync`.
  20. unsafe impl Send for Server {}
  21. unsafe impl Sync for Server {}
  22. impl Server {
  23. const THREAD_POOL_SIZE: usize = 4;
  24. pub async fn dispatch(
  25. self: Arc<Self>,
  26. service_method: String,
  27. data: RequestMessage,
  28. ) -> Result<ReplyMessage> {
  29. let (tx, rx) = futures::channel::oneshot::channel();
  30. let this = self.clone();
  31. this.thread_pool.spawn_ok(async move {
  32. let rpc_handler = {
  33. // Blocking on a mutex in a thread pool. Sounds horrible, but
  34. // in fact quite safe, given that the critical section is short.
  35. let state = self
  36. .state
  37. .lock()
  38. .expect("The server state mutex should not be poisoned.");
  39. state.rpc_count.set(state.rpc_count.get() + 1);
  40. state.rpc_handlers.get(&service_method).map(|r| r.clone())
  41. };
  42. let response = match rpc_handler {
  43. Some(rpc_handler) => Ok(rpc_handler.call(data)),
  44. None => Err(std::io::Error::new(
  45. std::io::ErrorKind::InvalidInput,
  46. format!(
  47. "Method {} on server {} not found.",
  48. service_method, self.name
  49. ),
  50. )),
  51. };
  52. if let Err(_) = tx.send(response) {
  53. // Receiving end is dropped. Never mind.
  54. // Do nothing.
  55. }
  56. });
  57. rx.await.map_err(|_e| {
  58. std::io::Error::new(
  59. std::io::ErrorKind::ConnectionReset,
  60. format!("Remote server {} cancelled the RPC.", this.name),
  61. )
  62. })?
  63. }
  64. pub fn register_rpc_handler(
  65. &mut self,
  66. service_method: String,
  67. rpc_handler: Box<dyn RpcHandler>,
  68. ) -> Result<()> {
  69. let mut state = self
  70. .state
  71. .lock()
  72. .expect("The server state mutex should not be poisoned.");
  73. let debug_service_method = service_method.clone();
  74. if let Vacant(vacant) = state.rpc_handlers.entry(service_method) {
  75. vacant.insert(Arc::new(rpc_handler));
  76. Ok(())
  77. } else {
  78. Err(std::io::Error::new(
  79. std::io::ErrorKind::AlreadyExists,
  80. format!(
  81. "Service method {} already exists in server {}.",
  82. debug_service_method, self.name
  83. ),
  84. ))
  85. }
  86. }
  87. pub fn make_server(name: String) -> Self {
  88. let state = std::sync::Mutex::new(ServerState {
  89. rpc_handlers: std::collections::HashMap::new(),
  90. rpc_count: std::cell::Cell::new(0),
  91. });
  92. let thread_pool = futures::executor::ThreadPool::builder()
  93. .name_prefix(name.clone())
  94. .pool_size(Self::THREAD_POOL_SIZE)
  95. .create()
  96. .expect("Creating thread pools should not fail.");
  97. Self {
  98. name,
  99. state,
  100. thread_pool,
  101. }
  102. }
  103. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::junk_server::{EchoRpcHandler, make_server};

      /// Dispatches `method` with `request` on `server` and blocks the current
      /// thread until the reply (or error) comes back.
      fn dispatch_blocking(
          server: std::sync::Arc<Server>,
          method: &str,
          request: RequestMessage,
      ) -> Result<ReplyMessage> {
          futures::executor::block_on(server.dispatch(method.to_string(), request))
      }

      /// Number of handlers currently registered on `server`.
      fn handler_count(server: &Server) -> usize {
          server.state.lock().unwrap().rpc_handlers.len()
      }

      #[test]
      fn test_register_rpc_handler() -> Result<()> {
          let server = make_server();
          // The junk server is expected to come with exactly two handlers.
          assert_eq!(2, handler_count(&server));
          Ok(())
      }

      #[test]
      fn test_register_rpc_handler_failure() -> Result<()> {
          let mut shared = make_server();
          let server = std::sync::Arc::get_mut(&mut shared).unwrap();
          // Registering "echo" a second time must be refused and must leave
          // the handler table unchanged.
          let outcome =
              server.register_rpc_handler("echo".to_string(), Box::new(EchoRpcHandler {}));
          assert!(outcome.is_err());
          assert_eq!(2, handler_count(server));
          Ok(())
      }

      #[test]
      fn test_serve_rpc() -> Result<()> {
          let reply = dispatch_blocking(
              make_server(),
              "echo",
              RequestMessage::from_static(&[0x08, 0x07]),
          )?;
          assert_eq!(ReplyMessage::from_static(&[0x07, 0x08]), reply);
          Ok(())
      }

      #[test]
      fn test_rpc_not_found() -> Result<()> {
          let error = match dispatch_blocking(make_server(), "acorn", RequestMessage::new()) {
              Ok(_) => panic!("acorn service is not registered."),
              Err(error) => error,
          };
          assert_eq!(error.kind(), std::io::ErrorKind::InvalidInput);
          Ok(())
      }

      #[test]
      fn test_rpc_error() -> Result<()> {
          let error = dispatch_blocking(make_server(), "aborting", RequestMessage::new())
              .err()
              .expect("Aborting RPC should return error.");
          assert_eq!(error.kind(), std::io::ErrorKind::ConnectionReset);
          Ok(())
      }

      #[test]
      fn test_server_survives_3_rpc_errors() -> Result<()> {
          let server = make_server();
          // TODO(ditsing): server hangs after the 4th RPC error.
          // NOTE(review): the server's pool has THREAD_POOL_SIZE (4) workers;
          // presumably each failing RPC costs one of them. Confirm.
          (0..3).for_each(|_| {
              let _ = dispatch_blocking(
                  std::sync::Arc::clone(&server),
                  "aborting",
                  RequestMessage::new(),
              );
          });
          let reply = dispatch_blocking(
              server,
              "echo",
              RequestMessage::from_static(&[0x08, 0x07]),
          )?;
          assert_eq!(ReplyMessage::from_static(&[0x07, 0x08]), reply);
          Ok(())
      }
  }