server.rs 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. use std::collections::hash_map::Entry::Vacant;
  2. use std::sync::Arc;
  3. use crate::{ReplyMessage, RequestMessage, Result};
  4. pub trait RpcHandler {
  5. // Note this method is not async.
  6. fn call(&self, data: RequestMessage) -> ReplyMessage;
  7. }
  8. struct ServerState {
  9. rpc_handlers: std::collections::HashMap<String, Arc<Box<dyn RpcHandler>>>,
  10. rpc_count: std::cell::Cell<usize>,
  11. }
  12. pub struct Server {
  13. name: String,
  14. state: std::sync::Mutex<ServerState>,
  15. thread_pool: futures::executor::ThreadPool,
  16. }
  17. impl Unpin for Server {}
  18. // Server contains a immutable name, a mutex-protected state, and a thread pool.
  19. // All of those 3 are `Send` and `Sync`.
  20. unsafe impl Send for Server {}
  21. unsafe impl Sync for Server {}
  22. impl Server {
  23. const THREAD_POOL_SIZE: usize = 4;
  24. pub async fn dispatch(
  25. self: Arc<Self>,
  26. service_method: String,
  27. data: RequestMessage,
  28. ) -> Result<ReplyMessage> {
  29. let (tx, rx) = futures::channel::oneshot::channel();
  30. let this = self.clone();
  31. this.thread_pool.spawn_ok(async move {
  32. let rpc_handler = {
  33. // Blocking on a mutex in a thread pool. Sounds horrible, but
  34. // in fact quite safe, given that the critical section is short.
  35. let state = self
  36. .state
  37. .lock()
  38. .expect("The server state mutex should not be poisoned");
  39. state.rpc_count.set(state.rpc_count.get() + 1);
  40. state.rpc_handlers.get(&service_method).cloned()
  41. };
  42. let response = match rpc_handler {
  43. Some(rpc_handler) => Ok(rpc_handler.call(data)),
  44. None => Err(std::io::Error::new(
  45. std::io::ErrorKind::InvalidInput,
  46. format!(
  47. "Method {} on server {} not found.",
  48. service_method, self.name
  49. ),
  50. )),
  51. };
  52. #[allow(clippy::redundant_pattern_matching)]
  53. if let Err(_) = tx.send(response) {
  54. // Receiving end is dropped. Never mind.
  55. // Do nothing.
  56. }
  57. });
  58. rx.await.map_err(|_e| {
  59. std::io::Error::new(
  60. std::io::ErrorKind::ConnectionReset,
  61. format!("Remote server {} cancelled the RPC.", this.name),
  62. )
  63. })?
  64. }
  65. pub fn register_rpc_handler(
  66. &mut self,
  67. service_method: String,
  68. rpc_handler: Box<dyn RpcHandler>,
  69. ) -> Result<()> {
  70. let mut state = self
  71. .state
  72. .lock()
  73. .expect("The server state mutex should not be poisoned");
  74. let debug_service_method = service_method.clone();
  75. if let Vacant(vacant) = state.rpc_handlers.entry(service_method) {
  76. vacant.insert(Arc::new(rpc_handler));
  77. Ok(())
  78. } else {
  79. Err(std::io::Error::new(
  80. std::io::ErrorKind::AlreadyExists,
  81. format!(
  82. "Service method {} already exists in server {}.",
  83. debug_service_method, self.name
  84. ),
  85. ))
  86. }
  87. }
  88. pub fn rpc_count(&self) -> usize {
  89. self.state
  90. .lock()
  91. .expect("The server state mutex should not be poisoned")
  92. .rpc_count
  93. .get()
  94. }
  95. pub fn make_server(name: String) -> Self {
  96. let state = std::sync::Mutex::new(ServerState {
  97. rpc_handlers: std::collections::HashMap::new(),
  98. rpc_count: std::cell::Cell::new(0),
  99. });
  100. let thread_pool = futures::executor::ThreadPool::builder()
  101. .name_prefix(name.clone())
  102. .pool_size(Self::THREAD_POOL_SIZE)
  103. .create()
  104. .expect("Creating thread pools should not fail");
  105. Self {
  106. name,
  107. state,
  108. thread_pool,
  109. }
  110. }
  111. }
  112. #[cfg(test)]
  113. mod tests {
  114. use crate::test_utils::junk_server::{
  115. make_test_server, EchoRpcHandler,
  116. JunkRpcs::{Aborting, Echo},
  117. };
  118. use super::*;
  119. fn rpc_handlers_len(server: &Server) -> usize {
  120. server
  121. .state
  122. .lock()
  123. .expect("The server state mutex should not be poisoned.")
  124. .rpc_handlers
  125. .len()
  126. }
  127. #[test]
  128. fn test_register_rpc_handler() -> Result<()> {
  129. let server = make_test_server();
  130. assert_eq!(2, rpc_handlers_len(server.as_ref()));
  131. Ok(())
  132. }
  133. #[test]
  134. fn test_register_rpc_handler_failure() -> Result<()> {
  135. let mut server = make_test_server();
  136. let server = std::sync::Arc::get_mut(&mut server)
  137. .expect("Server should only be held by the current thread");
  138. let result = server.register_rpc_handler(
  139. "echo".to_string(),
  140. Box::new(EchoRpcHandler {}),
  141. );
  142. assert!(result.is_err());
  143. assert_eq!(2, rpc_handlers_len(server));
  144. Ok(())
  145. }
  146. #[test]
  147. fn test_serve_rpc() -> Result<()> {
  148. let server = make_test_server();
  149. let reply = server.dispatch(
  150. "echo".to_string(),
  151. RequestMessage::from_static(&[0x08, 0x07]),
  152. );
  153. let result = futures::executor::block_on(reply)?;
  154. assert_eq!(ReplyMessage::from_static(&[0x07, 0x08]), result);
  155. Ok(())
  156. }
  157. #[test]
  158. fn test_rpc_not_found() -> Result<()> {
  159. let server = make_test_server();
  160. let reply = server.dispatch("acorn".to_string(), RequestMessage::new());
  161. match futures::executor::block_on(reply) {
  162. Ok(_) => panic!("acorn service is not registered."),
  163. Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput),
  164. }
  165. Ok(())
  166. }
  167. #[test]
  168. fn test_rpc_error() -> Result<()> {
  169. let server = make_test_server();
  170. let reply = futures::executor::block_on(
  171. server.dispatch(Aborting.name(), RequestMessage::new()),
  172. );
  173. assert_eq!(
  174. reply
  175. .err()
  176. .expect("Aborting RPC should return error")
  177. .kind(),
  178. std::io::ErrorKind::ConnectionReset,
  179. );
  180. Ok(())
  181. }
  182. #[test]
  183. fn test_server_survives_3_rpc_errors() -> Result<()> {
  184. let server = make_test_server();
  185. // TODO(ditsing): server hangs after the 4th RPC error.
  186. for _ in 0..3 {
  187. let server_clone = server.clone();
  188. let _ = futures::executor::block_on(
  189. server_clone.dispatch(Aborting.name(), RequestMessage::new()),
  190. );
  191. }
  192. let reply = server
  193. .dispatch(Echo.name(), RequestMessage::from_static(&[0x08, 0x07]));
  194. let result = futures::executor::block_on(reply)?;
  195. assert_eq!(ReplyMessage::from_static(&[0x07, 0x08]), result);
  196. Ok(())
  197. }
  198. }