utils.rs 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. use std::io::ErrorKind;
  2. use std::net::SocketAddr;
  3. use futures_util::StreamExt;
  4. use tarpc::client::RpcError;
  5. use tarpc::server::{Channel, Serve};
  6. pub(crate) fn deadline_forever() -> std::time::SystemTime {
  7. std::time::SystemTime::now()
  8. // This is the maximum deadline allowed by tarpc / tokio_util.
  9. + std::time::Duration::from_secs(2 * 365 * 24 * 60 * 60)
  10. }
  11. pub(crate) fn context() -> tarpc::context::Context {
  12. let mut context = tarpc::context::Context::current();
  13. context.deadline = deadline_forever();
  14. context
  15. }
  16. pub(crate) fn translate_rpc_error(e: RpcError) -> std::io::Error {
  17. match e {
  18. RpcError::Shutdown | RpcError::Send(_) | RpcError::Receive(_) => {
  19. std::io::Error::new(ErrorKind::BrokenPipe, e)
  20. }
  21. RpcError::DeadlineExceeded => {
  22. std::io::Error::new(ErrorKind::TimedOut, e)
  23. }
  24. RpcError::Server(server_error) => {
  25. std::io::Error::new(ErrorKind::Other, server_error)
  26. }
  27. }
  28. }
  29. pub(crate) async fn start_tarpc_server<Request, Reply, ServeFn>(
  30. addr: SocketAddr,
  31. serve: ServeFn,
  32. ) -> std::io::Result<()>
  33. where
  34. Request: Send + 'static + serde::de::DeserializeOwned,
  35. Reply: Send + 'static + serde::ser::Serialize,
  36. ServeFn:
  37. tarpc::server::Serve<Request, Resp = Reply> + Send + 'static + Clone,
  38. <ServeFn as Serve<Request>>::Fut: Send,
  39. {
  40. let mut listener = tarpc::serde_transport::tcp::listen(
  41. addr,
  42. tokio_serde::formats::Json::default,
  43. )
  44. .await?;
  45. tokio::spawn(async move {
  46. while let Some(conn) = listener.next().await {
  47. if let Ok(conn) = conn {
  48. let channel = tarpc::server::BaseChannel::with_defaults(conn)
  49. .max_concurrent_requests(1);
  50. tokio::spawn(channel.execute(serve.clone()));
  51. }
  52. }
  53. });
  54. Ok(())
  55. }