client.rs 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. use crossbeam_channel::Sender;
  2. #[cfg(feature = "tracing")]
  3. use crate::tracing::{Trace, TraceHolder};
  4. use crate::{
  5. ClientIdentifier, ReplyMessage, RequestMessage, Result, RpcOnWire,
  6. ServerIdentifier,
  7. };
  8. // Client interface, used by the RPC client.
  9. pub struct Client {
  10. pub(crate) client: ClientIdentifier,
  11. pub(crate) server: ServerIdentifier,
  12. pub(crate) request_bus: Sender<Option<RpcOnWire>>,
  13. }
  14. impl Client {
  15. /// Error type and meaning
  16. /// * Not connected: The client did not have a chance to send the request
  17. /// because the network is down.
  18. /// * Permission denied: The network does not allow the client to send
  19. /// requests.
  20. /// * Broken pipe: The network no longer allows the client to send requests.
  21. /// * Not found: The network could not find the target server.
  22. /// * Invalid input: The server could not find the service / method to call.
  23. /// * Connection reset: The server received the request, but decided to stop
  24. /// responding.
  25. /// * Connection aborted: The client will not receive a reply because the
  26. /// the connection is closed by the network.
  27. pub async fn call_rpc(
  28. &self,
  29. service_method: String,
  30. request: RequestMessage,
  31. ) -> Result<ReplyMessage> {
  32. #[cfg(feature = "tracing")]
  33. {
  34. let trace = TraceHolder::make();
  35. self.trace_and_call_rpc(service_method, request, trace)
  36. .await
  37. }
  38. #[cfg(not(feature = "tracing"))]
  39. self.trace_and_call_rpc(service_method, request).await
  40. }
  41. async fn trace_and_call_rpc(
  42. &self,
  43. service_method: String,
  44. request: RequestMessage,
  45. #[cfg(feature = "tracing")] trace: TraceHolder,
  46. ) -> Result<ReplyMessage> {
  47. #[cfg(feature = "tracing")]
  48. let local_trace = trace.clone();
  49. let (tx, rx) = futures::channel::oneshot::channel();
  50. let rpc = RpcOnWire {
  51. client: self.client.clone(),
  52. server: self.server.clone(),
  53. service_method,
  54. request,
  55. reply_channel: tx,
  56. #[cfg(feature = "tracing")]
  57. trace,
  58. };
  59. mark_trace!(local_trace, assemble);
  60. self.request_bus.send(Some(rpc)).map_err(|e| {
  61. // The receiving end has been closed. Network connection is broken.
  62. std::io::Error::new(
  63. std::io::ErrorKind::NotConnected,
  64. format!(
  65. "Cannot send rpc, client {} is disconnected. {}",
  66. self.client, e
  67. ),
  68. )
  69. })?;
  70. mark_trace!(local_trace, enqueue);
  71. #[allow(clippy::let_and_return)]
  72. let ret = rx.await.map_err(|e| {
  73. std::io::Error::new(
  74. // The network closed our connection. The server might not even
  75. // get a chance to see the request.
  76. std::io::ErrorKind::ConnectionAborted,
  77. format!("Network request is dropped: {}", e),
  78. )
  79. })?;
  80. mark_trace!(local_trace, response);
  81. ret
  82. }
  83. #[cfg(feature = "tracing")]
  84. pub async fn trace_rpc(
  85. &self,
  86. service_method: String,
  87. request: RequestMessage,
  88. ) -> (Result<ReplyMessage>, Trace) {
  89. let trace = TraceHolder::make();
  90. let local_trace = trace.clone();
  91. let response = self
  92. .trace_and_call_rpc(service_method, request, trace)
  93. .await;
  94. (response, local_trace.extract())
  95. }
  96. }
  97. #[cfg(test)]
  98. mod tests {
  99. use crossbeam_channel::{unbounded, Sender};
  100. use super::*;
  101. fn make_rpc_call(tx: Sender<Option<RpcOnWire>>) -> Result<ReplyMessage> {
  102. let client = Client {
  103. client: "C".into(),
  104. server: "S".into(),
  105. request_bus: tx,
  106. };
  107. let request = RequestMessage::from_static(&[0x17, 0x20]);
  108. futures::executor::block_on(client.call_rpc("hello".into(), request))
  109. }
  110. fn make_rpc_call_and_reply(
  111. reply: Result<ReplyMessage>,
  112. ) -> Result<ReplyMessage> {
  113. let (tx, rx) = unbounded();
  114. let handle = std::thread::spawn(move || make_rpc_call(tx));
  115. let rpc = rx
  116. .recv()
  117. .expect("The request message should arrive")
  118. .expect("The request message should not be null");
  119. assert_eq!("C", &rpc.client);
  120. assert_eq!("S", &rpc.server);
  121. assert_eq!("hello", &rpc.service_method);
  122. assert_eq!(&[0x17, 0x20], rpc.request.as_ref());
  123. rpc.reply_channel
  124. .send(reply)
  125. .expect("The reply channel should not be closed");
  126. handle.join().expect("Rpc sending thread should succeed")
  127. }
  128. #[test]
  129. fn test_call_rpc() -> Result<()> {
  130. let data = &[0x11, 0x99];
  131. let reply =
  132. make_rpc_call_and_reply(Ok(ReplyMessage::from_static(data)))?;
  133. assert_eq!(data, reply.as_ref());
  134. Ok(())
  135. }
  136. #[test]
  137. fn test_call_rpc_remote_error() -> Result<()> {
  138. let reply = make_rpc_call_and_reply(Err(std::io::Error::new(
  139. std::io::ErrorKind::AddrInUse,
  140. "",
  141. )));
  142. if let Err(e) = reply {
  143. assert_eq!(std::io::ErrorKind::AddrInUse, e.kind());
  144. } else {
  145. panic!("Client should propagate remote error.")
  146. }
  147. Ok(())
  148. }
  149. #[test]
  150. fn test_call_rpc_remote_dropped() -> Result<()> {
  151. let (tx, rx) = unbounded();
  152. let handle = std::thread::spawn(move || make_rpc_call(tx));
  153. let rpc = rx
  154. .recv()
  155. .expect("The request message should arrive")
  156. .expect("The request message should not be null");
  157. drop(rpc.reply_channel);
  158. let reply = handle.join().expect("Rpc sending thread should succeed");
  159. if let Err(e) = reply {
  160. assert_eq!(std::io::ErrorKind::ConnectionAborted, e.kind());
  161. } else {
  162. panic!(
  163. "Client should return error. Reply channel has been dropped."
  164. )
  165. }
  166. Ok(())
  167. }
  168. #[test]
  169. fn test_call_rpc_not_connected() -> Result<()> {
  170. let (tx, rx) = unbounded();
  171. {
  172. drop(rx);
  173. }
  174. let handle = std::thread::spawn(move || make_rpc_call(tx));
  175. let reply = handle.join().expect("Rpc sending thread should succeed");
  176. if let Err(e) = reply {
  177. assert_eq!(std::io::ErrorKind::NotConnected, e.kind());
  178. } else {
  179. panic!("Client should return error. request_bus has been dropped.")
  180. }
  181. Ok(())
  182. }
  183. async fn make_rpc(client: Client) -> Result<ReplyMessage> {
  184. let request = RequestMessage::from_static(&[0x17, 0x20]);
  185. client.call_rpc("hello".into(), request).await
  186. }
  187. #[test]
  188. fn test_call_across_threads() -> Result<()> {
  189. let (tx, rx) = unbounded();
  190. let rpc_future = {
  191. let client = Client {
  192. client: "C".into(),
  193. server: "S".into(),
  194. request_bus: tx,
  195. };
  196. make_rpc(client)
  197. };
  198. std::thread::spawn(move || {
  199. let _ = futures::executor::block_on(rpc_future);
  200. });
  201. let rpc = rx
  202. .recv()
  203. .expect("The request message should arrive")
  204. .expect("The request message should not be null");
  205. rpc.reply_channel
  206. .send(Ok(Default::default()))
  207. .expect("The reply channel should not be closed");
  208. Ok(())
  209. }
  210. }