utils.rs 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. use std::io::ErrorKind;
  2. use std::net::SocketAddr;
  3. use futures_util::StreamExt;
  4. use tarpc::client::RpcError;
  5. use tarpc::server::{Channel, Serve};
  /// Returns a wall-clock deadline two 365-day years from now.
  ///
  /// Used as a "never expire" deadline for tarpc requests. Per the original
  /// author, this is the maximum deadline tarpc / tokio_util will accept. A
  /// larger value should not be substituted without checking that limit.
  pub(crate) fn deadline_forever() -> std::time::SystemTime {
      use std::time::{Duration, SystemTime};

      const SECS_PER_DAY: u64 = 24 * 60 * 60;
      // This is the maximum deadline allowed by tarpc / tokio_util.
      const HORIZON: Duration = Duration::from_secs(2 * 365 * SECS_PER_DAY);

      SystemTime::now() + HORIZON
  }
  11. pub(crate) fn context() -> tarpc::context::Context {
  12. let mut context = tarpc::context::Context::current();
  13. context.deadline = deadline_forever();
  14. context
  15. }
  16. pub(crate) fn translate_rpc_error(e: RpcError) -> std::io::Error {
  17. match e {
  18. RpcError::Disconnected => std::io::Error::new(ErrorKind::BrokenPipe, e),
  19. RpcError::DeadlineExceeded => {
  20. std::io::Error::new(ErrorKind::TimedOut, e)
  21. }
  22. RpcError::Server(server_error) => {
  23. std::io::Error::new(ErrorKind::Other, server_error)
  24. }
  25. }
  26. }
  27. pub(crate) async fn start_tarpc_server<Request, Reply, ServeFn>(
  28. addr: SocketAddr,
  29. serve: ServeFn,
  30. ) -> std::io::Result<()>
  31. where
  32. Request: Send + 'static + serde::de::DeserializeOwned,
  33. Reply: Send + 'static + serde::ser::Serialize,
  34. ServeFn:
  35. tarpc::server::Serve<Request, Resp = Reply> + Send + 'static + Clone,
  36. <ServeFn as Serve<Request>>::Fut: Send,
  37. {
  38. let mut listener = tarpc::serde_transport::tcp::listen(
  39. addr,
  40. tokio_serde::formats::Json::default,
  41. )
  42. .await?;
  43. tokio::spawn(async move {
  44. while let Some(conn) = listener.next().await {
  45. if let Ok(conn) = conn {
  46. let channel = tarpc::server::BaseChannel::with_defaults(conn)
  47. .max_concurrent_requests(1);
  48. tokio::spawn(channel.execute(serve.clone()));
  49. }
  50. }
  51. });
  52. Ok(())
  53. }