api.rs
1 use std::fmt::{Debug, Formatter}; 2 use std::net::SocketAddr; 3 use std::panic::AssertUnwindSafe; 4 use std::sync::Arc; 5 use std::time::Duration; 6 7 use anyhow::Context; 8 use async_trait::async_trait; 9 use fedimint_core::core::ModuleInstanceId; 10 use fedimint_core::module::{ApiEndpoint, ApiEndpointContext, ApiError, ApiRequestErased}; 11 use fedimint_logging::LOG_NET_API; 12 use futures::FutureExt; 13 use jsonrpsee::server::{PingConfig, RpcServiceBuilder, ServerBuilder, ServerHandle}; 14 use jsonrpsee::types::ErrorObject; 15 use jsonrpsee::RpcModule; 16 use tracing::{error, info}; 17 18 use crate::metrics; 19 20 /// A state that has context for the API, passed to each rpc handler callback 21 #[derive(Clone)] 22 pub struct RpcHandlerCtx<M> { 23 pub rpc_context: Arc<M>, 24 } 25 26 impl<M> RpcHandlerCtx<M> { 27 pub fn new_module(state: M) -> RpcModule<RpcHandlerCtx<M>> { 28 RpcModule::new(Self { 29 rpc_context: Arc::new(state), 30 }) 31 } 32 } 33 34 impl<M: Debug> Debug for RpcHandlerCtx<M> { 35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 36 f.write_str("State { ... }") 37 } 38 } 39 40 /// How long to wait before timing out client connections 41 const API_ENDPOINT_TIMEOUT: Duration = Duration::from_secs(60); 42 43 /// Has the context necessary for serving API endpoints 44 /// 45 /// Returns the specific `State` the endpoint requires and the 46 /// `ApiEndpointContext` which all endpoints can access. 47 #[async_trait] 48 pub trait HasApiContext<State> { 49 async fn context( 50 &self, 51 request: &ApiRequestErased, 52 id: Option<ModuleInstanceId>, 53 ) -> (&State, ApiEndpointContext<'_>); 54 } 55 56 pub type ApiResult<T> = Result<T, ApiError>; 57 58 pub fn check_auth(context: &mut ApiEndpointContext) -> ApiResult<()> { 59 if !context.has_auth() { 60 Err(ApiError::unauthorized()) 61 } else { 62 Ok(()) 63 } 64 } 65 66 pub async fn spawn<T>( 67 name: &'static str, 68 api_bind: &SocketAddr, 69 module: RpcModule<RpcHandlerCtx<T>>, 70 max_connections: u32, 71 ) -> ServerHandle { 72 info!(target: LOG_NET_API, "Starting api on ws://{api_bind}"); 73 74 ServerBuilder::new() 75 .max_connections(max_connections) 76 .enable_ws_ping(PingConfig::new().ping_interval(Duration::from_secs(10))) 77 .set_rpc_middleware(RpcServiceBuilder::new().layer(metrics::jsonrpsee::MetricsLayer)) 78 .build(&api_bind.to_string()) 79 .await 80 .context(format!("Bind address: {api_bind}")) 81 .context(format!("API name: {name}")) 82 .expect("Could not build API server") 83 .start(module) 84 } 85 86 pub fn attach_endpoints<State, T>( 87 rpc_module: &mut RpcModule<RpcHandlerCtx<T>>, 88 endpoints: Vec<ApiEndpoint<State>>, 89 module_instance_id: Option<ModuleInstanceId>, 90 ) where 91 T: HasApiContext<State> + Sync + Send + 'static, 92 State: Sync + Send + 'static, 93 { 94 for endpoint in endpoints { 95 let path = if let Some(module_instance_id) = module_instance_id { 96 // This memory leak is fine because it only happens on server startup 97 // and path has to live till the end of program anyways. 98 Box::leak(format!("module_{}_{}", module_instance_id, endpoint.path).into_boxed_str()) 99 } else { 100 endpoint.path 101 }; 102 // Check if paths contain any abnormal characters 103 if path.contains(|c: char| !matches!(c, '0'..='9' | 'a'..='z' | '_')) { 104 panic!("Constructing bad path name {path}"); 105 } 106 107 // Another memory leak that is fine because the function is only called once at 108 // startup 109 let handler: &'static _ = Box::leak(endpoint.handler); 110 111 rpc_module 112 .register_async_method(path, move |params, rpc_state| async move { 113 let params = params.one::<serde_json::Value>()?; 114 let rpc_context = &rpc_state.rpc_context; 115 116 // Using AssertUnwindSafe here is far from ideal. In theory this means we could 117 // end up with an inconsistent state in theory. In practice most API functions 118 // are only reading and the few that do write anything are atomic. Lastly, this 119 // is only the last line of defense 120 AssertUnwindSafe(tokio::time::timeout(API_ENDPOINT_TIMEOUT, async { 121 let request = serde_json::from_value(params) 122 .map_err(|e| ApiError::bad_request(e.to_string()))?; 123 let (state, context) = rpc_context.context(&request, module_instance_id).await; 124 125 (handler)(state, context, request).await 126 })) 127 .catch_unwind() 128 .await 129 .map_err(|_| { 130 error!( 131 target: LOG_NET_API, 132 path, "API handler panicked, DO NOT IGNORE, FIX IT!!!" 133 ); 134 ErrorObject::owned(500, "API handler panicked", None::<()>) 135 })? 136 .map_err(|tokio::time::error::Elapsed { .. }| { 137 // TODO: find a better error for this, the error we used before: 138 // jsonrpsee::core::Error::RequestTimeout 139 // was moved to be client-side only 140 ErrorObject::owned(-32000, "Request timeout", None::<()>) 141 })? 142 .map_err(|e| ErrorObject::owned(e.code, e.message, None::<()>)) 143 }) 144 .expect("Failed to register async method"); 145 } 146 }