utils.rs 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. use std::io::ErrorKind;
  2. use std::net::SocketAddr;
  3. use futures_util::StreamExt;
  4. use tarpc::client::RpcError;
  5. use tarpc::server::{Channel, Serve};
  6. pub(crate) fn translate_rpc_error(e: RpcError) -> std::io::Error {
  7. match e {
  8. RpcError::Disconnected => std::io::Error::new(ErrorKind::BrokenPipe, e),
  9. RpcError::DeadlineExceeded => {
  10. std::io::Error::new(ErrorKind::TimedOut, e)
  11. }
  12. RpcError::Server(server_error) => {
  13. std::io::Error::new(ErrorKind::Other, server_error)
  14. }
  15. }
  16. }
  17. pub(crate) async fn start_tarpc_server<Request, Reply, ServeFn>(
  18. addr: SocketAddr,
  19. serve: ServeFn,
  20. ) -> std::io::Result<()>
  21. where
  22. Request: Send + 'static + serde::de::DeserializeOwned,
  23. Reply: Send + 'static + serde::ser::Serialize,
  24. ServeFn:
  25. tarpc::server::Serve<Request, Resp = Reply> + Send + 'static + Clone,
  26. <ServeFn as Serve<Request>>::Fut: Send,
  27. {
  28. let mut listener = tarpc::serde_transport::tcp::listen(
  29. addr,
  30. tokio_serde::formats::Json::default,
  31. )
  32. .await?;
  33. tokio::spawn(async move {
  34. while let Some(conn) = listener.next().await {
  35. if let Ok(conn) = conn {
  36. let channel = tarpc::server::BaseChannel::with_defaults(conn)
  37. .max_concurrent_requests(1);
  38. tokio::spawn(channel.execute(serve.clone()));
  39. }
  40. }
  41. });
  42. Ok(())
  43. }