/ radicle-httpd / src / lib.rs
lib.rs
  1  #![allow(clippy::type_complexity)]
  2  #![allow(clippy::too_many_arguments)]
  3  #![recursion_limit = "256"]
  4  pub mod error;
  5  
  6  use std::collections::HashMap;
  7  use std::net::SocketAddr;
  8  use std::num::NonZeroUsize;
  9  use std::process::Command;
 10  use std::str;
 11  use std::sync::Arc;
 12  use std::time::Duration;
 13  
 14  use anyhow::Context as _;
 15  use axum::body::{Body, BoxBody, HttpBody};
 16  use axum::http::{Request, Response};
 17  use axum::middleware;
 18  use axum::Router;
 19  use tower_http::trace::TraceLayer;
 20  use tracing::Span;
 21  
 22  use radicle::identity::Id;
 23  use radicle::Profile;
 24  
 25  use tracing_extra::{tracing_middleware, ColoredStatus, Paint, RequestId, TracingInfo};
 26  
 27  mod api;
 28  mod axum_extra;
 29  mod cache;
 30  mod git;
 31  mod raw;
 32  #[cfg(test)]
 33  mod test;
 34  mod tracing_extra;
 35  
 36  /// Default cache HTTP size.
 37  pub const DEFAULT_CACHE_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(100) };
 38  
/// Configuration consumed by [`run`] to start the HTTP server.
#[derive(Debug, Clone)]
pub struct Options {
    /// Mapping from an alias string to an [`Id`]; handed to the git router.
    pub aliases: HashMap<String, Id>,
    /// Socket address the HTTP server binds to and serves on.
    pub listen: SocketAddr,
    /// Optional cache size, read by the API context through `&Options`.
    // NOTE(review): `None` presumably disables the cache or selects a default —
    // confirm against `api::Context::new`.
    pub cache: Option<NonZeroUsize>,
}
 45  
 46  /// Run the Server.
 47  pub async fn run(options: Options) -> anyhow::Result<()> {
 48      let git_version = Command::new("git")
 49          .arg("version")
 50          .output()
 51          .context("'git' command must be available")?
 52          .stdout;
 53  
 54      tracing::info!("{}", str::from_utf8(&git_version)?.trim());
 55  
 56      let listen = options.listen;
 57  
 58      tracing::info!("listening on http://{}", listen);
 59  
 60      let profile = Profile::load()?;
 61      let request_id = RequestId::new();
 62  
 63      tracing::info!("using radicle home at {}", profile.home().display());
 64  
 65      let app =
 66          router(options, profile)?
 67          .layer(middleware::from_fn(tracing_middleware))
 68          .layer(
 69              TraceLayer::new_for_http()
 70                  .make_span_with(move |_request: &Request<Body>| {
 71                      tracing::info_span!("request", id = %request_id.clone().next())
 72                  })
 73                  .on_response(
 74                      |response: &Response<BoxBody>, latency: Duration, _span: &Span| {
 75                          if let Some(info) = response.extensions().get::<TracingInfo>() {
 76                              tracing::info!(
 77                                  "{} \"{} {} {:?}\" {} {:?} {}",
 78                                  info.connect_info.0,
 79                                  info.method,
 80                                  info.uri,
 81                                  info.version,
 82                                  ColoredStatus(response.status()),
 83                                  latency,
 84                                  Paint::dim(
 85                                      response
 86                                          .body()
 87                                          .size_hint()
 88                                          .exact()
 89                                          .map(|n| n.to_string())
 90                                          .unwrap_or("0".to_string())
 91                                          .into()
 92                                  ),
 93                              );
 94                          } else {
 95                              tracing::info!("Processed");
 96                          }
 97                      },
 98                  ),
 99          )
100          .into_make_service_with_connect_info::<SocketAddr>();
101  
102      axum::Server::bind(&listen)
103          .serve(app)
104          .await
105          .map_err(anyhow::Error::from)
106  }
107  
108  /// Create a router consisting of other sub-routers.
109  fn router(options: Options, profile: Profile) -> anyhow::Result<Router> {
110      let profile = Arc::new(profile);
111      let ctx = api::Context::new(profile.clone(), &options);
112  
113      let api_router = api::router(ctx);
114      let git_router = git::router(profile.clone(), options.aliases);
115      let raw_router = raw::router(profile);
116  
117      let app = Router::new()
118          .merge(git_router)
119          .nest("/api", api_router)
120          .nest("/raw", raw_router);
121  
122      Ok(app)
123  }
124  
125  pub mod logger {
126      use tracing::dispatcher::Dispatch;
127  
128      pub fn init() -> Result<(), tracing::subscriber::SetGlobalDefaultError> {
129          tracing::dispatcher::set_global_default(Dispatch::new(subscriber()))
130      }
131  
132      #[cfg(feature = "logfmt")]
133      pub fn subscriber() -> impl tracing::Subscriber {
134          use tracing_subscriber::layer::SubscriberExt as _;
135          use tracing_subscriber::EnvFilter;
136  
137          tracing_subscriber::Registry::default()
138              .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
139              .with(tracing_logfmt::layer())
140      }
141  
142      #[cfg(not(feature = "logfmt"))]
143      pub fn subscriber() -> impl tracing::Subscriber {
144          tracing_subscriber::FmtSubscriber::builder()
145              .with_target(false)
146              .with_max_level(tracing::Level::DEBUG)
147              .finish()
148      }
149  }
150  
#[cfg(test)]
mod routes {
    use std::collections::HashMap;
    use std::net::SocketAddr;

    use axum::extract::connect_info::MockConnectInfo;
    use axum::http::StatusCode;

    use crate::test::{self, get};

    /// A path that matches none of the merged or nested routers must yield `404 Not Found`.
    #[tokio::test]
    async fn test_invalid_route_returns_404() {
        let addr = SocketAddr::from(([0, 0, 0, 0], 8080));
        let home = tempfile::tempdir().unwrap();

        let options = super::Options {
            aliases: HashMap::new(),
            listen: addr,
            cache: None,
        };
        let profile = test::profile(home.path(), [0xff; 32]);

        // The mock connect info stands in for the peer address a real connection provides.
        let app = super::router(options, profile)
            .unwrap()
            .layer(MockConnectInfo(addr));

        let status = get(&app, "/aa/a").await.status();

        assert_eq!(status, StatusCode::NOT_FOUND);
    }
}