network.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. use std::collections::HashMap;
  2. use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
  3. use std::sync::{
  4. atomic::{AtomicBool, Ordering},
  5. Arc, Mutex,
  6. };
  7. use std::time::{Duration, Instant};
  8. use rand::{thread_rng, Rng};
  9. use crate::{
  10. Client, ClientIdentifier, Result, RpcOnWire, Server, ServerIdentifier,
  11. };
  12. pub struct Network {
  13. // Settings.
  14. reliable: bool,
  15. long_delays: bool,
  16. long_reordering: bool,
  17. // Clients
  18. clients: HashMap<ClientIdentifier, (bool, ServerIdentifier)>,
  19. servers: HashMap<ServerIdentifier, Arc<Server>>,
  20. // Network bus
  21. request_bus: Sender<RpcOnWire>,
  22. request_pipe: Option<Receiver<RpcOnWire>>,
  23. // Closing signal.
  24. keep_running: bool,
  25. // Whether the network is active or not.
  26. stopped: AtomicBool,
  27. // RPC Counter, using Cell for interior mutability.
  28. rpc_count: std::cell::Cell<usize>,
  29. }
  30. impl Network {
  31. pub fn set_reliable(&mut self, yes: bool) {
  32. self.reliable = yes
  33. }
  34. pub fn set_long_reordering(&mut self, yes: bool) {
  35. self.long_reordering = yes
  36. }
  37. pub fn set_long_delays(&mut self, yes: bool) {
  38. self.long_delays = yes
  39. }
  40. pub fn stop(&mut self) {
  41. self.keep_running = false;
  42. }
  43. pub fn stopped(&self) -> bool {
  44. self.stopped.load(Ordering::Acquire)
  45. }
  46. pub fn make_client<C: Into<ClientIdentifier>, S: Into<ServerIdentifier>>(
  47. &mut self,
  48. client: C,
  49. server: S,
  50. ) -> Client {
  51. let (client, server) = (client.into(), server.into());
  52. self.clients.insert(client.clone(), (true, server.clone()));
  53. Client {
  54. client,
  55. server,
  56. request_bus: self.request_bus.clone(),
  57. }
  58. }
  59. pub fn set_enable_client<C: AsRef<str> + Sized>(
  60. &mut self,
  61. client: C,
  62. yes: bool,
  63. ) {
  64. self.clients
  65. .get_mut(client.as_ref())
  66. .map(|pair| pair.0 = yes);
  67. }
  68. pub fn add_server<S: Into<ServerIdentifier>>(
  69. &mut self,
  70. server_name: S,
  71. server: Arc<Server>,
  72. ) {
  73. self.servers.insert(server_name.into(), server);
  74. }
  75. pub fn remove_server<S: AsRef<str> + Sized>(&mut self, server_name: &S) {
  76. self.servers.remove(server_name.as_ref());
  77. }
  78. pub fn get_rpc_count<S: AsRef<str> + Sized>(
  79. &self,
  80. server_name: S,
  81. ) -> Option<usize> {
  82. self.servers
  83. .get(server_name.as_ref())
  84. .map(|s| s.rpc_count())
  85. }
  86. fn dispatch(&self, client: &ClientIdentifier) -> Result<Arc<Server>> {
  87. let (enabled, server_name) =
  88. self.clients.get(client).ok_or_else(|| {
  89. std::io::Error::new(
  90. std::io::ErrorKind::PermissionDenied,
  91. format!("Client {} is not connected.", client),
  92. )
  93. })?;
  94. if !enabled {
  95. return Err(std::io::Error::new(
  96. std::io::ErrorKind::BrokenPipe,
  97. format!("Client {} is disabled.", client),
  98. ));
  99. }
  100. let server = self.servers.get(server_name).ok_or_else(|| {
  101. std::io::Error::new(
  102. std::io::ErrorKind::NotFound,
  103. format!(
  104. "Cannot connect {} to server {}: server not found.",
  105. client, server_name,
  106. ),
  107. )
  108. })?;
  109. Ok(server.clone())
  110. }
  111. pub fn get_total_rpc_count(&self) -> usize {
  112. self.rpc_count.get()
  113. }
  114. }
  115. impl Network {
  116. const MAX_MINOR_DELAY_MILLIS: u64 = 27;
  117. const MAX_SHORT_DELAY_MILLIS: u64 = 100;
  118. const MAX_LONG_DELAY_MILLIS: u64 = 7000;
  119. const DROP_RATE: (u32, u32) = (100, 1000);
  120. const LONG_REORDERING_RATE: (u32, u32) = (600u32, 900u32);
  121. const LONG_REORDERING_BASE_DELAY_MILLIS: u64 = 200;
  122. const LONG_REORDERING_RANDOM_DELAY_BOUND_MILLIS: u64 = 2000;
  123. const SHUTDOWN_DELAY: Duration = Duration::from_micros(20);
  124. async fn delay_for_millis(milli_seconds: u64) {
  125. tokio::time::delay_for(Duration::from_millis(milli_seconds)).await;
  126. }
  127. async fn serve_rpc(network: Arc<Mutex<Self>>, rpc: RpcOnWire) {
  128. let (server_result, reliable, long_reordering, long_delays) = {
  129. let network = network
  130. .lock()
  131. .expect("Network mutex should not be poisoned");
  132. network.increase_rpc_count();
  133. (
  134. network.dispatch(&rpc.client),
  135. network.reliable,
  136. network.long_reordering,
  137. network.long_delays,
  138. )
  139. };
  140. // Random delay before sending requests to server.
  141. if !reliable {
  142. let minor_delay =
  143. thread_rng().gen_range(0, Self::MAX_MINOR_DELAY_MILLIS);
  144. Self::delay_for_millis(minor_delay).await;
  145. // Random drop of a DROP_RATE / DROP_BASE chance.
  146. if thread_rng().gen_ratio(Self::DROP_RATE.0, Self::DROP_RATE.1) {
  147. // Note this is different from the original Go version.
  148. // Here we don't reply to client until timeout actually passes.
  149. Self::delay_for_millis(Self::MAX_MINOR_DELAY_MILLIS).await;
  150. let _ = rpc.reply_channel.send(Err(std::io::Error::new(
  151. std::io::ErrorKind::TimedOut,
  152. "Remote server did not respond in time.",
  153. )));
  154. return;
  155. }
  156. }
  157. let reply = match server_result {
  158. // Call the server.
  159. Ok(server) => {
  160. // Simulates the copy from network to server.
  161. let data = rpc.request.clone();
  162. server.dispatch(rpc.service_method, data).await
  163. }
  164. // If the server does not exist, return error after a random delay.
  165. Err(e) => {
  166. let long_delay = rand::thread_rng().gen_range(
  167. 0,
  168. if long_delays {
  169. Self::MAX_LONG_DELAY_MILLIS
  170. } else {
  171. Self::MAX_SHORT_DELAY_MILLIS
  172. },
  173. );
  174. Self::delay_for_millis(long_delay).await;
  175. Err(e)
  176. }
  177. };
  178. if reply.is_ok() {
  179. // Random drop again.
  180. if !reliable
  181. && thread_rng().gen_ratio(Self::DROP_RATE.0, Self::DROP_RATE.1)
  182. {
  183. let _ = rpc.reply_channel.send(Err(std::io::Error::new(
  184. std::io::ErrorKind::TimedOut,
  185. "The network did not send respond in time.",
  186. )));
  187. return;
  188. } else if long_reordering {
  189. let should_reorder = thread_rng().gen_ratio(
  190. Self::LONG_REORDERING_RATE.0,
  191. Self::LONG_REORDERING_RATE.1,
  192. );
  193. if should_reorder {
  194. let long_delay_bound = thread_rng().gen_range(
  195. 0,
  196. Self::LONG_REORDERING_RANDOM_DELAY_BOUND_MILLIS,
  197. );
  198. let long_delay = Self::LONG_REORDERING_BASE_DELAY_MILLIS
  199. + thread_rng().gen_range(0, 1 + long_delay_bound);
  200. Self::delay_for_millis(long_delay).await;
  201. // Falling through to send the result.
  202. }
  203. }
  204. }
  205. if let Err(_e) = rpc.reply_channel.send(reply) {
  206. // TODO(ditsing): log and do nothing.
  207. }
  208. }
  209. pub fn run_daemon() -> Arc<Mutex<Network>> {
  210. let mut network = Network::new();
  211. let rx = network
  212. .request_pipe
  213. .take()
  214. .expect("Newly created network should have a rx");
  215. let network = Arc::new(Mutex::new(network));
  216. let thread_pool = tokio::runtime::Builder::new()
  217. .threaded_scheduler()
  218. .core_threads(10)
  219. .max_threads(20)
  220. .thread_name("network")
  221. .enable_time()
  222. .build()
  223. .expect("Creating network thread pool should not fail");
  224. let other = network.clone();
  225. std::thread::spawn(move || {
  226. let network = other;
  227. let mut stop_timer = Instant::now();
  228. loop {
  229. // If the lock of network is unfair, we could starve threads
  230. // trying to add / remove RPC servers, or change settings.
  231. // Having a shutdown delay helps minimise lock holding.
  232. if stop_timer.elapsed() >= Self::SHUTDOWN_DELAY {
  233. let locked_network = network
  234. .lock()
  235. .expect("Network mutex should not be poisoned");
  236. if !locked_network.keep_running {
  237. break;
  238. }
  239. stop_timer = Instant::now();
  240. }
  241. match rx.try_recv() {
  242. Ok(rpc) => {
  243. thread_pool
  244. .spawn(Self::serve_rpc(network.clone(), rpc));
  245. }
  246. // All senders have disconnected. This should never happen,
  247. // since the network instance itself holds a sender.
  248. Err(TryRecvError::Disconnected) => break,
  249. Err(TryRecvError::Empty) => {
  250. std::thread::sleep(Self::SHUTDOWN_DELAY)
  251. }
  252. }
  253. }
  254. // Shutdown might leak outstanding tasks if timed-out.
  255. thread_pool.shutdown_timeout(Self::SHUTDOWN_DELAY);
  256. // rx is dropped here, all clients should get disconnected error
  257. // and stop sending messages.
  258. drop(rx);
  259. network
  260. .lock()
  261. .expect("Network mutex should not be poisoned")
  262. .stopped
  263. .store(true, Ordering::Release);
  264. });
  265. network
  266. }
  267. }
  268. impl Network {
  269. fn increase_rpc_count(&self) {
  270. self.rpc_count.set(self.rpc_count.get() + 1);
  271. }
  272. fn new() -> Self {
  273. // The channel has infinite buffer, could OOM the server if there are
  274. // too many pending RPCs to be served.
  275. let (tx, rx) = channel();
  276. Network {
  277. reliable: true,
  278. long_delays: false,
  279. long_reordering: false,
  280. clients: Default::default(),
  281. servers: Default::default(),
  282. request_bus: tx,
  283. request_pipe: Some(rx),
  284. keep_running: true,
  285. stopped: Default::default(),
  286. rpc_count: std::cell::Cell::new(0),
  287. }
  288. }
  289. }
  #[cfg(test)]
  mod tests {
      use std::sync::MutexGuard;

      use crate::test_utils::{
          junk_server::{
              make_test_server, JunkRpcs, NON_CLIENT, NON_SERVER, TEST_CLIENT,
              TEST_SERVER,
          },
          make_aborting_rpc, make_echo_rpc,
      };
      use crate::{ReplyMessage, RequestMessage, Result};

      use super::*;

      /// Creates a network without starting its daemon thread.
      fn make_network() -> Network {
          Network::new()
      }

      /// The total RPC counter starts at zero and grows by one per call.
      #[test]
      fn test_rpc_count_works() {
          let network = make_network();
          assert_eq!(0, network.get_total_rpc_count());
          network.increase_rpc_count();
          assert_eq!(1, network.get_total_rpc_count());
      }

      /// Locks the shared value, panicking if the mutex is poisoned.
      fn unlock<T>(network: &Arc<Mutex<T>>) -> MutexGuard<'_, T> {
          network
              .lock()
              .expect("Network mutex should not be poisoned")
      }

      /// After the daemon has stopped, sending on the request bus must fail
      /// because the receiving end has been dropped.
      #[test]
      fn test_network_shutdown() {
          let network = Network::run_daemon();
          let sender = {
              let mut network = unlock(&network);
              network.keep_running = false;
              network.request_bus.clone()
          };
          while !unlock(&network).stopped() {
              std::thread::sleep(Network::SHUTDOWN_DELAY)
          }
          let (rpc, _) = make_echo_rpc("client", "server", &[]);
          let result = sender.send(rpc);
          assert!(
              result.is_err(),
              "Network is shutdown, requests should not be processed."
          );
      }

      /// Starts a fresh network in which `client` is connected to `server`
      /// (enabled or not, as requested) and the test server is registered
      /// under `TEST_SERVER`, sends `rpc` through it and blocks until the
      /// reply arrives on `rx`.
      fn send_rpc<C: Into<String>, S: Into<String>>(
          rpc: RpcOnWire,
          rx: futures::channel::oneshot::Receiver<Result<ReplyMessage>>,
          client: C,
          server: S,
          enabled: bool,
      ) -> Result<ReplyMessage> {
          let network = Network::run_daemon();
          let sender = {
              let mut network = unlock(&network);
              network
                  .clients
                  .insert(client.into(), (enabled, server.into()));
              network
                  .servers
                  .insert(TEST_SERVER.into(), make_test_server());
              network.request_bus.clone()
          };
          let result = sender.send(rpc);
          assert!(
              result.is_ok(),
              "Network is running, requests should be processed."
          );
          match futures::executor::block_on(rx) {
              Ok(reply) => reply,
              Err(e) => panic!("Future execution should not fail: {}", e),
          }
      }

      /// A request to a registered server is proxied and its reply returned.
      #[test]
      fn test_proxy_rpc() -> Result<()> {
          let (rpc, rx) =
              make_echo_rpc(TEST_CLIENT, TEST_SERVER, &[0x09u8, 0x00u8]);
          let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
          match reply {
              Ok(reply) => assert_eq!(reply.as_ref(), &[0x00u8, 0x09u8]),
              Err(e) => panic!("Expecting echo message, got {}", e),
          }
          Ok(())
      }

      /// Errors produced by the server are passed through to the caller.
      #[test]
      fn test_proxy_rpc_server_error() -> Result<()> {
          let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, TEST_SERVER);
          let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
          let err = reply.expect_err("Network should proxy server errors");
          assert_eq!(std::io::ErrorKind::ConnectionReset, err.kind());
          Ok(())
      }

      /// A client connected to an unregistered server gets `NotFound`.
      #[test]
      fn test_proxy_rpc_server_not_found() -> Result<()> {
          let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, NON_SERVER);
          let reply = send_rpc(rpc, rx, TEST_CLIENT, NON_SERVER, true);
          let err = reply.expect_err("Network should check server in memory");
          assert_eq!(std::io::ErrorKind::NotFound, err.kind());
          Ok(())
      }

      /// A disabled client gets `BrokenPipe`.
      #[test]
      fn test_proxy_rpc_client_disabled() -> Result<()> {
          let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, TEST_SERVER);
          let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, false);
          let err =
              reply.expect_err("Network should check if client is disabled");
          assert_eq!(std::io::ErrorKind::BrokenPipe, err.kind());
          Ok(())
      }

      /// An RPC from a client that was never registered gets
      /// `PermissionDenied`: the RPC is built for `NON_CLIENT` while only
      /// `TEST_CLIENT` is registered.
      #[test]
      fn test_proxy_rpc_no_such_client() -> Result<()> {
          let (rpc, rx) = make_aborting_rpc(NON_CLIENT, TEST_SERVER);
          let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
          let err = reply.expect_err("Network should check client names");
          assert_eq!(std::io::ErrorKind::PermissionDenied, err.kind());
          Ok(())
      }

      /// End-to-end walk through the public API: serve a request, block the
      /// client, remove the server, and finally shut the network down.
      #[test]
      fn test_basic_functions() -> Result<()> {
          // Initialize
          let network = Network::run_daemon();
          let server = make_test_server();
          unlock(&network).add_server(TEST_SERVER, server);
          let client = unlock(&network).make_client(TEST_CLIENT, TEST_SERVER);
          assert_eq!(0, unlock(&network).get_total_rpc_count());

          let request = RequestMessage::from_static(&[0x17, 0x20]);
          let reply_data = &[0x20, 0x17];

          // Send first request.
          let reply = futures::executor::block_on(
              client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
          )?;
          assert_eq!(reply_data, reply.as_ref());
          assert_eq!(1, unlock(&network).get_total_rpc_count());

          // Block the client.
          unlock(&network).set_enable_client(TEST_CLIENT, false);

          // Send second request.
          let reply = futures::executor::block_on(
              client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
          );
          reply.expect_err("Client is blocked");
          assert_eq!(2, unlock(&network).get_total_rpc_count());
          assert_eq!(Some(1), unlock(&network).get_rpc_count(TEST_SERVER));
          assert_eq!(None, unlock(&network).get_rpc_count(NON_SERVER));

          // Unblock the client, then remove the server.
          unlock(&network).set_enable_client(TEST_CLIENT, true);
          unlock(&network).remove_server(&TEST_SERVER);

          // Send third request. The client is enabled again, so this fails
          // because the server is gone.
          let reply = futures::executor::block_on(
              client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
          );
          reply.expect_err("Server is removed");
          assert_eq!(3, unlock(&network).get_total_rpc_count());

          // Shutdown the network.
          unlock(&network).stop();
          while !unlock(&network).stopped() {
              std::thread::sleep(Duration::from_millis(10));
          }

          // Send fourth request. It never reaches the network, so the total
          // count stays at three.
          let reply = futures::executor::block_on(
              client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
          );
          reply.expect_err("Network is shutdown");
          assert_eq!(3, unlock(&network).get_total_rpc_count());

          // Done.
          Ok(())
      }
  }