network.rs 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668
  1. use std::collections::HashMap;
  2. use std::sync::{
  3. atomic::{AtomicBool, Ordering},
  4. Arc,
  5. };
  6. use std::time::{Duration, Instant};
  7. use crossbeam_channel::{Receiver, Sender, TryRecvError};
  8. use parking_lot::Mutex;
  9. use rand::{thread_rng, Rng};
  10. use crate::{
  11. Client, ClientIdentifier, Result, RpcOnWire, Server, ServerIdentifier,
  12. };
  13. pub struct Network {
  14. // Settings.
  15. reliable: bool,
  16. long_delays: bool,
  17. long_reordering: bool,
  18. // Clients
  19. clients: HashMap<ClientIdentifier, (bool, ServerIdentifier)>,
  20. servers: HashMap<ServerIdentifier, Arc<Server>>,
  21. // Network bus
  22. request_bus: Sender<RpcOnWire>,
  23. request_pipe: Option<Receiver<RpcOnWire>>,
  24. // Closing signal.
  25. keep_running: bool,
  26. // Whether the network is active or not.
  27. stopped: AtomicBool,
  28. // RPC Counter, using Cell for interior mutability.
  29. rpc_count: std::cell::Cell<usize>,
  30. }
  31. impl Network {
  32. pub fn set_reliable(&mut self, yes: bool) {
  33. self.reliable = yes
  34. }
  35. pub fn set_long_reordering(&mut self, yes: bool) {
  36. self.long_reordering = yes
  37. }
  38. pub fn set_long_delays(&mut self, yes: bool) {
  39. self.long_delays = yes
  40. }
  41. pub fn stop(&mut self) {
  42. self.keep_running = false;
  43. }
  44. pub fn stopped(&self) -> bool {
  45. self.stopped.load(Ordering::Acquire)
  46. }
  47. pub fn make_client<C: Into<ClientIdentifier>, S: Into<ServerIdentifier>>(
  48. &mut self,
  49. client: C,
  50. server: S,
  51. ) -> Client {
  52. let (client, server) = (client.into(), server.into());
  53. self.clients.insert(client.clone(), (true, server.clone()));
  54. Client {
  55. client,
  56. server,
  57. request_bus: self.request_bus.clone(),
  58. }
  59. }
  60. pub fn set_enable_client<C: AsRef<str>>(&mut self, client: C, yes: bool) {
  61. if let Some(pair) = self.clients.get_mut(client.as_ref()) {
  62. pair.0 = yes;
  63. }
  64. }
  65. pub fn add_server<S: Into<ServerIdentifier>>(
  66. &mut self,
  67. server_name: S,
  68. server: Server,
  69. ) {
  70. self.servers.insert(server_name.into(), Arc::new(server));
  71. }
  72. pub fn remove_server<S: AsRef<str>>(&mut self, server_name: S) {
  73. self.servers.remove(server_name.as_ref());
  74. }
  75. pub fn get_rpc_count<S: AsRef<str>>(
  76. &self,
  77. server_name: S,
  78. ) -> Option<usize> {
  79. self.servers
  80. .get(server_name.as_ref())
  81. .map(|s| s.rpc_count())
  82. }
  83. #[allow(clippy::ptr_arg)]
  84. fn dispatch(&self, client: &ClientIdentifier) -> Result<Arc<Server>> {
  85. let (enabled, server_name) =
  86. self.clients.get(client).ok_or_else(|| {
  87. std::io::Error::new(
  88. std::io::ErrorKind::PermissionDenied,
  89. format!("Client {} is not connected.", client),
  90. )
  91. })?;
  92. if !enabled {
  93. return Err(std::io::Error::new(
  94. std::io::ErrorKind::BrokenPipe,
  95. format!("Client {} is disabled.", client),
  96. ));
  97. }
  98. let server = self.servers.get(server_name).ok_or_else(|| {
  99. std::io::Error::new(
  100. std::io::ErrorKind::NotFound,
  101. format!(
  102. "Cannot connect {} to server {}: server not found.",
  103. client, server_name,
  104. ),
  105. )
  106. })?;
  107. Ok(server.clone())
  108. }
  109. pub fn get_total_rpc_count(&self) -> usize {
  110. self.rpc_count.get()
  111. }
  112. }
  113. impl Network {
  114. const MAX_MINOR_DELAY_MILLIS: u64 = 27;
  115. const MAX_SHORT_DELAY_MILLIS: u64 = 100;
  116. const MAX_LONG_DELAY_MILLIS: u64 = 7000;
  117. const DROP_RATE: (u32, u32) = (100, 1000);
  118. const LONG_REORDERING_RATE: (u32, u32) = (600u32, 900u32);
  119. const LONG_REORDERING_BASE_DELAY_MILLIS: u64 = 200;
  120. const LONG_REORDERING_RANDOM_DELAY_BOUND_MILLIS: u64 = 2000;
  121. const SHUTDOWN_DELAY: Duration = Duration::from_micros(100);
  122. async fn delay_for_millis(milli_seconds: u64) {
  123. tokio::time::sleep(Duration::from_millis(milli_seconds)).await;
  124. }
  125. async fn serve_rpc(network: Arc<Mutex<Self>>, rpc: RpcOnWire) {
  126. let (server_result, reliable, long_reordering, long_delays) = {
  127. let network = network.lock();
  128. network.increase_rpc_count();
  129. (
  130. network.dispatch(&rpc.client),
  131. network.reliable,
  132. network.long_reordering,
  133. network.long_delays,
  134. )
  135. };
  136. // Random delay before sending requests to server.
  137. if !reliable {
  138. let minor_delay =
  139. thread_rng().gen_range(0, Self::MAX_MINOR_DELAY_MILLIS);
  140. Self::delay_for_millis(minor_delay).await;
  141. // Random drop of a DROP_RATE / DROP_BASE chance.
  142. if thread_rng().gen_ratio(Self::DROP_RATE.0, Self::DROP_RATE.1) {
  143. // Note this is different from the original Go version.
  144. // Here we don't reply to client until timeout actually passes.
  145. Self::delay_for_millis(Self::MAX_MINOR_DELAY_MILLIS).await;
  146. let _ = rpc.reply_channel.send(Err(std::io::Error::new(
  147. std::io::ErrorKind::TimedOut,
  148. "Remote server did not respond in time.",
  149. )));
  150. return;
  151. }
  152. }
  153. let mut lookup_result = None;
  154. let reply = match server_result {
  155. // Call the server.
  156. Ok(server) => {
  157. // Simulates the copy from network to server.
  158. let data = rpc.request.clone();
  159. lookup_result.replace(server.clone());
  160. // No need to set timeout. The RPCs are not supposed to block.
  161. server.dispatch(rpc.service_method, data).await
  162. }
  163. // If the server does not exist, return error after a random delay.
  164. Err(e) => {
  165. let long_delay = rand::thread_rng().gen_range(
  166. 0,
  167. if long_delays {
  168. Self::MAX_LONG_DELAY_MILLIS
  169. } else {
  170. Self::MAX_SHORT_DELAY_MILLIS
  171. },
  172. );
  173. Self::delay_for_millis(long_delay).await;
  174. Err(e)
  175. }
  176. };
  177. if reply.is_ok() {
  178. // The lookup must have succeeded.
  179. let lookup_result = lookup_result.unwrap();
  180. // The server's address must have been changed, given that we take
  181. // ownership of the server when it is registered.
  182. let unchanged = match network.lock().dispatch(&rpc.client) {
  183. Ok(server) => Arc::ptr_eq(&server, &lookup_result),
  184. Err(_) => false,
  185. };
  186. // Fail the RPC if the client has been disconnected, or the server
  187. // has been updated.
  188. if !unchanged {
  189. let _ = rpc.reply_channel.send(Err(std::io::Error::new(
  190. std::io::ErrorKind::ConnectionReset,
  191. "Network connection has been reset.".to_owned(),
  192. )));
  193. return;
  194. }
  195. // Random drop again.
  196. if !reliable
  197. && thread_rng().gen_ratio(Self::DROP_RATE.0, Self::DROP_RATE.1)
  198. {
  199. let _ = rpc.reply_channel.send(Err(std::io::Error::new(
  200. std::io::ErrorKind::TimedOut,
  201. "The network did not send respond in time.",
  202. )));
  203. return;
  204. } else if long_reordering {
  205. let should_reorder = thread_rng().gen_ratio(
  206. Self::LONG_REORDERING_RATE.0,
  207. Self::LONG_REORDERING_RATE.1,
  208. );
  209. if should_reorder {
  210. let long_delay_bound = thread_rng().gen_range(
  211. 0,
  212. Self::LONG_REORDERING_RANDOM_DELAY_BOUND_MILLIS,
  213. );
  214. let long_delay = Self::LONG_REORDERING_BASE_DELAY_MILLIS
  215. + thread_rng().gen_range(0, 1 + long_delay_bound);
  216. Self::delay_for_millis(long_delay).await;
  217. // Falling through to send the result.
  218. }
  219. }
  220. }
  221. if let Err(_e) = rpc.reply_channel.send(reply) {
  222. // TODO(ditsing): log and do nothing.
  223. }
  224. }
  225. pub fn run_daemon() -> Arc<Mutex<Network>> {
  226. let mut network = Network::new();
  227. let rx = network
  228. .request_pipe
  229. .take()
  230. .expect("Newly created network should have a rx");
  231. // Using Mutex instead of RWLock, because most of the access are reads.
  232. let network = Arc::new(Mutex::new(network));
  233. // Using tokio instead of futures-rs, because we need timer futures.
  234. let thread_pool = tokio::runtime::Builder::new_multi_thread()
  235. .worker_threads(10)
  236. .max_threads(20)
  237. .thread_name("network")
  238. .enable_time()
  239. .build()
  240. .expect("Creating network thread pool should not fail");
  241. let other = network.clone();
  242. std::thread::spawn(move || {
  243. let network = other;
  244. let mut stop_timer = Instant::now();
  245. loop {
  246. // If the lock of network is unfair, we could starve threads
  247. // trying to add / remove RPC servers, or change settings.
  248. // Having a shutdown delay helps minimise lock holding.
  249. if stop_timer.elapsed() >= Self::SHUTDOWN_DELAY {
  250. if !network.lock().keep_running {
  251. break;
  252. }
  253. stop_timer = Instant::now();
  254. }
  255. match rx.try_recv() {
  256. Ok(rpc) => {
  257. thread_pool
  258. .spawn(Self::serve_rpc(network.clone(), rpc));
  259. }
  260. // All senders have disconnected. This should never happen,
  261. // since the network instance itself holds a sender.
  262. Err(TryRecvError::Disconnected) => break,
  263. Err(TryRecvError::Empty) => {
  264. std::thread::sleep(Self::SHUTDOWN_DELAY)
  265. }
  266. }
  267. }
  268. // Shutdown might leak outstanding tasks if timed-out.
  269. thread_pool.shutdown_timeout(Self::SHUTDOWN_DELAY);
  270. // rx is dropped here, all clients should get disconnected error
  271. // and stop sending messages.
  272. drop(rx);
  273. network.lock().stopped.store(true, Ordering::Release);
  274. });
  275. network
  276. }
  277. }
  278. impl Network {
  279. fn increase_rpc_count(&self) {
  280. self.rpc_count.set(self.rpc_count.get() + 1);
  281. }
  282. fn new() -> Self {
  283. // The channel has infinite buffer, could OOM the server if there are
  284. // too many pending RPCs to be served.
  285. let (tx, rx) = crossbeam_channel::unbounded();
  286. Network {
  287. reliable: true,
  288. long_delays: false,
  289. long_reordering: false,
  290. clients: Default::default(),
  291. servers: Default::default(),
  292. request_bus: tx,
  293. request_pipe: Some(rx),
  294. keep_running: true,
  295. stopped: Default::default(),
  296. rpc_count: std::cell::Cell::new(0),
  297. }
  298. }
  299. }
  300. #[cfg(test)]
  301. mod tests {
  302. use std::sync::Barrier;
  303. use parking_lot::MutexGuard;
  304. use crate::test_utils::{
  305. junk_server::{
  306. make_test_server, JunkRpcs, NON_CLIENT, NON_SERVER, TEST_CLIENT,
  307. TEST_SERVER,
  308. },
  309. make_aborting_rpc, make_echo_rpc,
  310. };
  311. use crate::{ReplyMessage, RequestMessage, Result};
  312. use super::*;
  313. fn make_network() -> Network {
  314. Network::new()
  315. }
  316. #[test]
  317. fn test_rpc_count_works() {
  318. let network = make_network();
  319. assert_eq!(0, network.get_total_rpc_count());
  320. network.increase_rpc_count();
  321. assert_eq!(1, network.get_total_rpc_count());
  322. }
  323. fn unlock<T>(network: &Arc<Mutex<T>>) -> MutexGuard<T> {
  324. network.lock()
  325. }
  326. #[test]
  327. fn test_network_shutdown() {
  328. let network = Network::run_daemon();
  329. let sender = {
  330. let mut network = unlock(&network);
  331. network.keep_running = false;
  332. network.request_bus.clone()
  333. };
  334. while !unlock(&network).stopped() {
  335. std::thread::sleep(Network::SHUTDOWN_DELAY)
  336. }
  337. let (rpc, _) = make_echo_rpc("client", "server", &[]);
  338. let result = sender.send(rpc);
  339. assert!(
  340. result.is_err(),
  341. "Network is shutdown, requests should not be processed."
  342. );
  343. }
  344. fn send_rpc<C: Into<String>, S: Into<String>>(
  345. rpc: RpcOnWire,
  346. rx: futures::channel::oneshot::Receiver<Result<ReplyMessage>>,
  347. client: C,
  348. server: S,
  349. enabled: bool,
  350. ) -> Result<ReplyMessage> {
  351. let network = Network::run_daemon();
  352. let sender = {
  353. let mut network = unlock(&network);
  354. network
  355. .clients
  356. .insert(client.into(), (enabled, server.into()));
  357. network
  358. .servers
  359. .insert(TEST_SERVER.into(), Arc::new(make_test_server()));
  360. network.request_bus.clone()
  361. };
  362. let result = sender.send(rpc);
  363. assert!(
  364. result.is_ok(),
  365. "Network is running, requests should be processed."
  366. );
  367. let reply = match futures::executor::block_on(rx) {
  368. Ok(reply) => reply,
  369. Err(e) => panic!("Future execution should not fail: {}", e),
  370. };
  371. reply
  372. }
  373. #[test]
  374. fn test_proxy_rpc() -> Result<()> {
  375. let (rpc, rx) =
  376. make_echo_rpc(TEST_CLIENT, TEST_SERVER, &[0x09u8, 0x00u8]);
  377. let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
  378. match reply {
  379. Ok(reply) => assert_eq!(reply.as_ref(), &[0x00u8, 0x09u8]),
  380. Err(e) => panic!("Expecting echo message, got {}", e),
  381. }
  382. Ok(())
  383. }
  384. #[test]
  385. fn test_proxy_rpc_server_error() -> Result<()> {
  386. let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, TEST_SERVER);
  387. let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
  388. let err = reply.expect_err("Network should proxy server errors");
  389. assert_eq!(std::io::ErrorKind::ConnectionReset, err.kind());
  390. Ok(())
  391. }
  392. #[test]
  393. fn test_proxy_rpc_server_not_found() -> Result<()> {
  394. let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, NON_SERVER);
  395. let reply = send_rpc(rpc, rx, TEST_CLIENT, NON_SERVER, true);
  396. let err = reply.expect_err("Network should check server in memory");
  397. assert_eq!(std::io::ErrorKind::NotFound, err.kind());
  398. Ok(())
  399. }
  400. #[test]
  401. fn test_proxy_rpc_client_disabled() -> Result<()> {
  402. let (rpc, rx) = make_aborting_rpc(TEST_CLIENT, TEST_SERVER);
  403. let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, false);
  404. let err =
  405. reply.expect_err("Network should check if client is disabled");
  406. assert_eq!(std::io::ErrorKind::BrokenPipe, err.kind());
  407. Ok(())
  408. }
  409. #[test]
  410. fn test_proxy_rpc_no_such_client() -> Result<()> {
  411. let (rpc, rx) = make_aborting_rpc(NON_CLIENT, TEST_SERVER);
  412. let reply = send_rpc(rpc, rx, TEST_CLIENT, TEST_SERVER, true);
  413. let err = reply.expect_err("Network should check client names");
  414. assert_eq!(std::io::ErrorKind::PermissionDenied, err.kind());
  415. Ok(())
  416. }
  417. #[test]
  418. fn test_server_killed() -> Result<()> {
  419. let start_barrier = Arc::new(Barrier::new(2));
  420. let end_barrier = Arc::new(Barrier::new(2));
  421. let network = Network::run_daemon();
  422. let mut server = Server::make_server(TEST_SERVER);
  423. let start_barrier_clone = start_barrier.clone();
  424. let end_barrier_clone = end_barrier.clone();
  425. server.register_rpc_handler(
  426. "blocking".to_owned(),
  427. Box::new(move |args| {
  428. start_barrier_clone.wait();
  429. end_barrier_clone.wait();
  430. args.into()
  431. }),
  432. )?;
  433. unlock(&network).add_server(TEST_SERVER, server);
  434. let client = unlock(&network).make_client(TEST_CLIENT, TEST_SERVER);
  435. std::thread::spawn(move || {
  436. start_barrier.wait();
  437. unlock(&network).remove_server(TEST_SERVER);
  438. end_barrier.wait();
  439. });
  440. let reply = futures::executor::block_on(
  441. client.call_rpc("blocking".to_owned(), Default::default()),
  442. );
  443. let err = reply
  444. .expect_err("Client should receive error after server is killed");
  445. assert_eq!(std::io::ErrorKind::ConnectionReset, err.kind());
  446. Ok(())
  447. }
  448. fn make_network_and_client() -> (Arc<Mutex<Network>>, Client) {
  449. let network = Network::run_daemon();
  450. let server = make_test_server();
  451. unlock(&network).add_server(TEST_SERVER, server);
  452. let client = unlock(&network).make_client(TEST_CLIENT, TEST_SERVER);
  453. (network, client)
  454. }
  455. #[test]
  456. fn test_basic_functions() -> Result<()> {
  457. // Initialize
  458. let (network, client) = make_network_and_client();
  459. assert_eq!(0, unlock(&network).get_total_rpc_count());
  460. let request = RequestMessage::from_static(&[0x17, 0x20]);
  461. let reply_data = &[0x20, 0x17];
  462. // Send first request.
  463. let reply = futures::executor::block_on(
  464. client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
  465. )?;
  466. assert_eq!(reply_data, reply.as_ref());
  467. assert_eq!(1, unlock(&network).get_total_rpc_count());
  468. // Block the client.
  469. unlock(&network).set_enable_client(TEST_CLIENT, false);
  470. // Send second request.
  471. let reply = futures::executor::block_on(
  472. client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
  473. );
  474. reply.expect_err("Client is blocked");
  475. assert_eq!(2, unlock(&network).get_total_rpc_count());
  476. assert_eq!(Some(1), unlock(&network).get_rpc_count(TEST_SERVER));
  477. assert_eq!(None, unlock(&network).get_rpc_count(NON_SERVER));
  478. // Unblock the client, then remove the server.
  479. unlock(&network).set_enable_client(TEST_CLIENT, true);
  480. unlock(&network).remove_server(&TEST_SERVER);
  481. // Send third request.
  482. let reply = futures::executor::block_on(
  483. client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
  484. );
  485. reply.expect_err("Client is blocked");
  486. assert_eq!(3, unlock(&network).get_total_rpc_count());
  487. // Shutdown the network.
  488. unlock(&network).stop();
  489. while !unlock(&network).stopped() {
  490. std::thread::sleep(Duration::from_millis(10));
  491. }
  492. // Send forth request.
  493. let reply = futures::executor::block_on(
  494. client.call_rpc(JunkRpcs::Echo.name(), request.clone()),
  495. );
  496. reply.expect_err("Network is shutdown");
  497. assert_eq!(3, unlock(&network).get_total_rpc_count());
  498. // Done.
  499. Ok(())
  500. }
  501. #[test]
  502. fn test_unreliable() -> Result<()> {
  503. let (network, _) = make_network_and_client();
  504. network.lock().set_reliable(false);
  505. const RPC_COUNT: usize = 1000;
  506. let mut handles = vec![];
  507. for i in 0..RPC_COUNT {
  508. let client = network.lock().make_client(format!("{}-{}", TEST_CLIENT, i), TEST_SERVER);
  509. let handle = std::thread::spawn(move || {
  510. let reply = client.call_rpc(
  511. JunkRpcs::Echo.name(),
  512. RequestMessage::from_static(&[0x20, 0x17]),
  513. );
  514. futures::executor::block_on(reply)
  515. });
  516. handles.push(handle);
  517. }
  518. let mut success = 0;
  519. for handle in handles {
  520. let result = handle.join().expect("Thread join should not fail");
  521. if result.is_ok() {
  522. success += 1;
  523. }
  524. }
  525. let success = success as f64;
  526. assert!(success > RPC_COUNT as f64 * 0.75, "More than 15% RPC failed");
  527. assert!(success < RPC_COUNT as f64 * 0.85, "Less than 5% RPC failed");
  528. Ok(())
  529. }
  530. #[test]
  531. #[ignore = "Large tests with many threads"]
  532. fn test_many_requests() {
  533. let now = Instant::now();
  534. let (network, _) = make_network_and_client();
  535. let barrier = Arc::new(Barrier::new(THREAD_COUNT + 1));
  536. const THREAD_COUNT: usize = 200;
  537. const RPC_COUNT: usize = 100;
  538. let mut handles = vec![];
  539. for i in 0..THREAD_COUNT {
  540. let network_ref = network.clone();
  541. let barrier_ref = barrier.clone();
  542. let handle = std::thread::spawn(move || {
  543. let client = unlock(&network_ref)
  544. .make_client(format!("{}-{}", TEST_CLIENT, i), TEST_SERVER);
  545. // We should all create the client first.
  546. barrier_ref.wait();
  547. let mut results = vec![];
  548. for _ in 0..RPC_COUNT {
  549. let reply = client.call_rpc(
  550. JunkRpcs::Echo.name(),
  551. RequestMessage::from_static(&[0x20, 0x17]),
  552. );
  553. results.push(reply);
  554. }
  555. for result in results {
  556. futures::executor::block_on(result)
  557. .expect("All futures should succeed");
  558. }
  559. });
  560. handles.push(handle);
  561. }
  562. barrier.wait();
  563. for handle in handles {
  564. handle.join().expect("All threads should succeed");
  565. }
  566. eprintln!("Many requests test took {:?}", now.elapsed());
  567. }
  568. }