/ fedimint-server / src / net / api.rs
api.rs
  1  use std::fmt::{Debug, Formatter};
  2  use std::net::SocketAddr;
  3  use std::panic::AssertUnwindSafe;
  4  use std::sync::Arc;
  5  use std::time::Duration;
  6  
  7  use anyhow::Context;
  8  use async_trait::async_trait;
  9  use fedimint_core::core::ModuleInstanceId;
 10  use fedimint_core::module::{ApiEndpoint, ApiEndpointContext, ApiError, ApiRequestErased};
 11  use fedimint_logging::LOG_NET_API;
 12  use futures::FutureExt;
 13  use jsonrpsee::server::{PingConfig, RpcServiceBuilder, ServerBuilder, ServerHandle};
 14  use jsonrpsee::types::ErrorObject;
 15  use jsonrpsee::RpcModule;
 16  use tracing::{error, info};
 17  
 18  use crate::metrics;
 19  
 20  /// A state that has context for the API, passed to each rpc handler callback
 21  #[derive(Clone)]
 22  pub struct RpcHandlerCtx<M> {
 23      pub rpc_context: Arc<M>,
 24  }
 25  
 26  impl<M> RpcHandlerCtx<M> {
 27      pub fn new_module(state: M) -> RpcModule<RpcHandlerCtx<M>> {
 28          RpcModule::new(Self {
 29              rpc_context: Arc::new(state),
 30          })
 31      }
 32  }
 33  
 34  impl<M: Debug> Debug for RpcHandlerCtx<M> {
 35      fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 36          f.write_str("State { ... }")
 37      }
 38  }
 39  
 40  /// How long to wait before timing out client connections
 41  const API_ENDPOINT_TIMEOUT: Duration = Duration::from_secs(60);
 42  
 43  /// Has the context necessary for serving API endpoints
 44  ///
 45  /// Returns the specific `State` the endpoint requires and the
 46  /// `ApiEndpointContext` which all endpoints can access.
 47  #[async_trait]
 48  pub trait HasApiContext<State> {
 49      async fn context(
 50          &self,
 51          request: &ApiRequestErased,
 52          id: Option<ModuleInstanceId>,
 53      ) -> (&State, ApiEndpointContext<'_>);
 54  }
 55  
 56  pub type ApiResult<T> = Result<T, ApiError>;
 57  
 58  pub fn check_auth(context: &mut ApiEndpointContext) -> ApiResult<()> {
 59      if !context.has_auth() {
 60          Err(ApiError::unauthorized())
 61      } else {
 62          Ok(())
 63      }
 64  }
 65  
 66  pub async fn spawn<T>(
 67      name: &'static str,
 68      api_bind: &SocketAddr,
 69      module: RpcModule<RpcHandlerCtx<T>>,
 70      max_connections: u32,
 71  ) -> ServerHandle {
 72      info!(target: LOG_NET_API, "Starting api on ws://{api_bind}");
 73  
 74      ServerBuilder::new()
 75          .max_connections(max_connections)
 76          .enable_ws_ping(PingConfig::new().ping_interval(Duration::from_secs(10)))
 77          .set_rpc_middleware(RpcServiceBuilder::new().layer(metrics::jsonrpsee::MetricsLayer))
 78          .build(&api_bind.to_string())
 79          .await
 80          .context(format!("Bind address: {api_bind}"))
 81          .context(format!("API name: {name}"))
 82          .expect("Could not build API server")
 83          .start(module)
 84  }
 85  
 86  pub fn attach_endpoints<State, T>(
 87      rpc_module: &mut RpcModule<RpcHandlerCtx<T>>,
 88      endpoints: Vec<ApiEndpoint<State>>,
 89      module_instance_id: Option<ModuleInstanceId>,
 90  ) where
 91      T: HasApiContext<State> + Sync + Send + 'static,
 92      State: Sync + Send + 'static,
 93  {
 94      for endpoint in endpoints {
 95          let path = if let Some(module_instance_id) = module_instance_id {
 96              // This memory leak is fine because it only happens on server startup
 97              // and path has to live till the end of program anyways.
 98              Box::leak(format!("module_{}_{}", module_instance_id, endpoint.path).into_boxed_str())
 99          } else {
100              endpoint.path
101          };
102          // Check if paths contain any abnormal characters
103          if path.contains(|c: char| !matches!(c, '0'..='9' | 'a'..='z' | '_')) {
104              panic!("Constructing bad path name {path}");
105          }
106  
107          // Another memory leak that is fine because the function is only called once at
108          // startup
109          let handler: &'static _ = Box::leak(endpoint.handler);
110  
111          rpc_module
112              .register_async_method(path, move |params, rpc_state| async move {
113                  let params = params.one::<serde_json::Value>()?;
114                  let rpc_context = &rpc_state.rpc_context;
115  
116                  // Using AssertUnwindSafe here is far from ideal. In theory this means we could
117                  // end up with an inconsistent state in theory. In practice most API functions
118                  // are only reading and the few that do write anything are atomic. Lastly, this
119                  // is only the last line of defense
120                  AssertUnwindSafe(tokio::time::timeout(API_ENDPOINT_TIMEOUT, async {
121                      let request = serde_json::from_value(params)
122                          .map_err(|e| ApiError::bad_request(e.to_string()))?;
123                      let (state, context) = rpc_context.context(&request, module_instance_id).await;
124  
125                      (handler)(state, context, request).await
126                  }))
127                  .catch_unwind()
128                  .await
129                  .map_err(|_| {
130                      error!(
131                          target: LOG_NET_API,
132                          path, "API handler panicked, DO NOT IGNORE, FIX IT!!!"
133                      );
134                      ErrorObject::owned(500, "API handler panicked", None::<()>)
135                  })?
136                  .map_err(|tokio::time::error::Elapsed { .. }| {
137                      // TODO: find a better error for this, the error we used before:
138                      // jsonrpsee::core::Error::RequestTimeout
139                      // was moved to be client-side only
140                      ErrorObject::owned(-32000, "Request timeout", None::<()>)
141                  })?
142                  .map_err(|e| ErrorObject::owned(e.code, e.message, None::<()>))
143              })
144              .expect("Failed to register async method");
145      }
146  }