| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- use crossbeam_channel::Sender;
- #[cfg(feature = "tracing")]
- use crate::tracing::{Trace, TraceHolder};
- use crate::{
- ClientIdentifier, ReplyMessage, RequestMessage, Result, RpcOnWire,
- ServerIdentifier,
- };
- // Client interface, used by the RPC client.
- pub struct Client {
- pub(crate) client: ClientIdentifier,
- pub(crate) server: ServerIdentifier,
- pub(crate) request_bus: Sender<Option<RpcOnWire>>,
- }
- impl Client {
- /// Error type and meaning
- /// * Not connected: The client did not have a chance to send the request
- /// because the network is down.
- /// * Permission denied: The network does not allow the client to send
- /// requests.
- /// * Broken pipe: The network no longer allows the client to send requests.
- /// * Not found: The network could not find the target server.
- /// * Invalid input: The server could not find the service / method to call.
- /// * Connection reset: The server received the request, but decided to stop
- /// responding.
- /// * Connection aborted: The client will not receive a reply because the
- /// the connection is closed by the network.
- pub async fn call_rpc(
- &self,
- service_method: String,
- request: RequestMessage,
- ) -> Result<ReplyMessage> {
- #[cfg(feature = "tracing")]
- {
- let trace = TraceHolder::make();
- self.trace_and_call_rpc(service_method, request, trace)
- .await
- }
- #[cfg(not(feature = "tracing"))]
- self.trace_and_call_rpc(service_method, request).await
- }
- async fn trace_and_call_rpc(
- &self,
- service_method: String,
- request: RequestMessage,
- #[cfg(feature = "tracing")] trace: TraceHolder,
- ) -> Result<ReplyMessage> {
- #[cfg(feature = "tracing")]
- let local_trace = trace.clone();
- let (tx, rx) = futures::channel::oneshot::channel();
- let rpc = RpcOnWire {
- client: self.client.clone(),
- server: self.server.clone(),
- service_method,
- request,
- reply_channel: tx,
- #[cfg(feature = "tracing")]
- trace,
- };
- mark_trace!(local_trace, assemble);
- self.request_bus.send(Some(rpc)).map_err(|e| {
- // The receiving end has been closed. Network connection is broken.
- std::io::Error::new(
- std::io::ErrorKind::NotConnected,
- format!(
- "Cannot send rpc, client {} is disconnected. {}",
- self.client, e
- ),
- )
- })?;
- mark_trace!(local_trace, enqueue);
- #[allow(clippy::let_and_return)]
- let ret = rx.await.map_err(|e| {
- std::io::Error::new(
- // The network closed our connection. The server might not even
- // get a chance to see the request.
- std::io::ErrorKind::ConnectionAborted,
- format!("Network request is dropped: {}", e),
- )
- })?;
- mark_trace!(local_trace, response);
- ret
- }
- #[cfg(feature = "tracing")]
- pub async fn trace_rpc(
- &self,
- service_method: String,
- request: RequestMessage,
- ) -> (Result<ReplyMessage>, Trace) {
- let trace = TraceHolder::make();
- let local_trace = trace.clone();
- let response = self
- .trace_and_call_rpc(service_method, request, trace)
- .await;
- (response, local_trace.extract())
- }
- }
- #[cfg(test)]
- mod tests {
- use crossbeam_channel::{unbounded, Sender};
- use super::*;
- fn make_rpc_call(tx: Sender<Option<RpcOnWire>>) -> Result<ReplyMessage> {
- let client = Client {
- client: "C".into(),
- server: "S".into(),
- request_bus: tx,
- };
- let request = RequestMessage::from_static(&[0x17, 0x20]);
- futures::executor::block_on(client.call_rpc("hello".into(), request))
- }
- fn make_rpc_call_and_reply(
- reply: Result<ReplyMessage>,
- ) -> Result<ReplyMessage> {
- let (tx, rx) = unbounded();
- let handle = std::thread::spawn(move || make_rpc_call(tx));
- let rpc = rx
- .recv()
- .expect("The request message should arrive")
- .expect("The request message should not be null");
- assert_eq!("C", &rpc.client);
- assert_eq!("S", &rpc.server);
- assert_eq!("hello", &rpc.service_method);
- assert_eq!(&[0x17, 0x20], rpc.request.as_ref());
- rpc.reply_channel
- .send(reply)
- .expect("The reply channel should not be closed");
- handle.join().expect("Rpc sending thread should succeed")
- }
- #[test]
- fn test_call_rpc() -> Result<()> {
- let data = &[0x11, 0x99];
- let reply =
- make_rpc_call_and_reply(Ok(ReplyMessage::from_static(data)))?;
- assert_eq!(data, reply.as_ref());
- Ok(())
- }
- #[test]
- fn test_call_rpc_remote_error() -> Result<()> {
- let reply = make_rpc_call_and_reply(Err(std::io::Error::new(
- std::io::ErrorKind::AddrInUse,
- "",
- )));
- if let Err(e) = reply {
- assert_eq!(std::io::ErrorKind::AddrInUse, e.kind());
- } else {
- panic!("Client should propagate remote error.")
- }
- Ok(())
- }
- #[test]
- fn test_call_rpc_remote_dropped() -> Result<()> {
- let (tx, rx) = unbounded();
- let handle = std::thread::spawn(move || make_rpc_call(tx));
- let rpc = rx
- .recv()
- .expect("The request message should arrive")
- .expect("The request message should not be null");
- drop(rpc.reply_channel);
- let reply = handle.join().expect("Rpc sending thread should succeed");
- if let Err(e) = reply {
- assert_eq!(std::io::ErrorKind::ConnectionAborted, e.kind());
- } else {
- panic!(
- "Client should return error. Reply channel has been dropped."
- )
- }
- Ok(())
- }
- #[test]
- fn test_call_rpc_not_connected() -> Result<()> {
- let (tx, rx) = unbounded();
- {
- drop(rx);
- }
- let handle = std::thread::spawn(move || make_rpc_call(tx));
- let reply = handle.join().expect("Rpc sending thread should succeed");
- if let Err(e) = reply {
- assert_eq!(std::io::ErrorKind::NotConnected, e.kind());
- } else {
- panic!("Client should return error. request_bus has been dropped.")
- }
- Ok(())
- }
- async fn make_rpc(client: Client) -> Result<ReplyMessage> {
- let request = RequestMessage::from_static(&[0x17, 0x20]);
- client.call_rpc("hello".into(), request).await
- }
- #[test]
- fn test_call_across_threads() -> Result<()> {
- let (tx, rx) = unbounded();
- let rpc_future = {
- let client = Client {
- client: "C".into(),
- server: "S".into(),
- request_bus: tx,
- };
- make_rpc(client)
- };
- std::thread::spawn(move || {
- let _ = futures::executor::block_on(rpc_future);
- });
- let rpc = rx
- .recv()
- .expect("The request message should arrive")
- .expect("The request message should not be null");
- rpc.reply_channel
- .send(Ok(Default::default()))
- .expect("The reply channel should not be closed");
- Ok(())
- }
- }
|