lib.rs
  1  #![allow(clippy::type_complexity)]
  2  
  3  pub mod shutdown;
  4  
  5  use anyhow::anyhow;
  6  use arroyo_types::POSTHOG_KEY;
  7  use axum::body::Bytes;
  8  use axum::extract::State;
  9  use axum::http::StatusCode;
 10  use axum::routing::get;
 11  use axum::Router;
 12  use hyper::Body;
 13  use lazy_static::lazy_static;
 14  use once_cell::sync::OnceCell;
 15  use prometheus::{register_int_counter, Encoder, IntCounter, ProtobufEncoder, TextEncoder};
 16  use reqwest::Client;
 17  use serde_json::{json, Value};
 18  use std::error::Error;
 19  use std::fs;
 20  use std::future::Future;
 21  use std::net::SocketAddr;
 22  use std::path::PathBuf;
 23  use std::sync::Arc;
 24  use std::task::{Context, Poll};
 25  use tonic::body::BoxBody;
 26  use tonic::transport::Server;
 27  use tower::layer::util::Stack;
 28  use tower::{Layer, Service};
 29  use tower_http::classify::{GrpcCode, GrpcErrorsAsFailures, SharedClassifier};
 30  use tower_http::trace::{DefaultOnFailure, TraceLayer};
 31  
 32  use tracing::metadata::LevelFilter;
 33  use tracing::{debug, info, span, Level};
 34  use tracing_subscriber::fmt::format::{FmtSpan, Format};
 35  use tracing_subscriber::prelude::*;
 36  use tracing_subscriber::EnvFilter;
 37  use tracing_subscriber::Registry;
 38  
 39  use arroyo_rpc::config::{config, LogFormat};
 40  use tracing_appender::non_blocking::WorkerGuard;
 41  use tracing_log::LogTracer;
 42  use uuid::Uuid;
 43  
/// Build timestamp captured at compile time from the `VERGEN_BUILD_TIMESTAMP` environment
/// variable (presumably populated by a `vergen` build script — confirm in build.rs).
pub const BUILD_TIMESTAMP: &str = env!("VERGEN_BUILD_TIMESTAMP");
/// Git commit SHA captured at compile time from the `VERGEN_GIT_SHA` environment variable.
pub const GIT_SHA: &str = env!("VERGEN_GIT_SHA");
/// Version string reported in telemetry events and by the admin `/details` endpoint.
pub const VERSION: &str = "0.12.0-dev";

/// Process-wide cluster identifier; written once by `set_cluster_id` and read by
/// `get_cluster_id`.
static CLUSTER_ID: OnceCell<String> = OnceCell::new();
 49  
 50  pub fn init_logging(name: &str) -> WorkerGuard {
 51      init_logging_with_filter(
 52          name,
 53          EnvFilter::builder()
 54              .with_default_directive(LevelFilter::INFO.into())
 55              .from_env_lossy(),
 56      )
 57  }
 58  
 59  pub fn init_logging_with_filter(_name: &str, filter: EnvFilter) -> WorkerGuard {
 60      if let Err(e) = LogTracer::init() {
 61          eprintln!("Failed to initialize log tracer {:?}", e);
 62      }
 63  
 64      let filter = filter.add_directive("refinery_core=warn".parse().unwrap());
 65  
 66      let (nonblocking, guard) = tracing_appender::non_blocking(std::io::stderr());
 67  
 68      match config().logging.format {
 69          LogFormat::Plaintext => {
 70              tracing::subscriber::set_global_default(
 71                  Registry::default().with(
 72                      tracing_subscriber::fmt::layer()
 73                          .with_line_number(false)
 74                          .with_file(false)
 75                          .with_span_events(FmtSpan::NONE)
 76                          .with_writer(nonblocking)
 77                          .with_filter(filter),
 78                  ),
 79              )
 80              .expect("Unable to set global log subscriber");
 81          }
 82          LogFormat::Logfmt => {
 83              tracing::subscriber::set_global_default(
 84                  Registry::default().with(
 85                      tracing_subscriber::fmt::layer()
 86                          .event_format(tracing_logfmt::EventsFormatter)
 87                          .fmt_fields(tracing_logfmt::FieldsFormatter)
 88                          .with_writer(nonblocking)
 89                          .with_filter(filter),
 90                  ),
 91              )
 92              .expect("Unable to set global log subscriber");
 93          }
 94          LogFormat::Json => {
 95              tracing::subscriber::set_global_default(
 96                  Registry::default().with(
 97                      tracing_subscriber::fmt::layer()
 98                          .event_format(Format::default().json())
 99                          .with_writer(nonblocking)
100                          .with_filter(filter),
101                  ),
102              )
103              .expect("Unable to set global log subscriber");
104          }
105      }
106  
107      std::panic::set_hook(Box::new(|panic| {
108          if let Some(location) = panic.location() {
109              tracing::error!(
110                  message = %panic,
111                  panic.file = location.file(),
112                  panic.line = location.line(),
113                  panic.column = location.column(),
114              );
115          } else {
116              tracing::error!(message = %panic);
117          }
118      }));
119  
120      guard
121  }
122  
123  fn existing_cluster_id(path: Option<&PathBuf>) -> Option<String> {
124      let path = path?;
125      if path.exists() {
126          let s = fs::read_to_string(path).ok()?.trim().to_string();
127          Uuid::parse_str(&s).ok()?;
128          Some(s)
129      } else {
130          None
131      }
132  }
133  
134  pub fn set_cluster_id(cluster_id: &str) {
135      let path = dirs::config_dir().map(|p| p.join("arroyo").join("cluster-info"));
136  
137      if let Some(id) = existing_cluster_id(path.as_ref()) {
138          CLUSTER_ID.set(id).unwrap();
139      } else {
140          CLUSTER_ID.set(cluster_id.to_string()).unwrap();
141          if let Some(path) = path {
142              let _ = fs::write(&path, cluster_id);
143          }
144      }
145  }
146  
147  pub fn get_cluster_id() -> String {
148      CLUSTER_ID.get().map(|s| s.to_string()).unwrap()
149  }
150  
151  pub fn log_event(name: &str, mut props: Value) {
152      static CLIENT: OnceCell<Client> = OnceCell::new();
153      let cluster_id = get_cluster_id();
154      if !config().disable_telemetry {
155          let name = name.to_string();
156          tokio::task::spawn(async move {
157              let client = CLIENT.get_or_init(Client::new);
158  
159              if let Some(props) = props.as_object_mut() {
160                  props.insert("distinct_id".to_string(), Value::String(cluster_id));
161                  props.insert("git_sha".to_string(), Value::String(GIT_SHA.to_string()));
162                  props.insert("version".to_string(), Value::String(VERSION.to_string()));
163                  props.insert(
164                      "build_timestamp".to_string(),
165                      Value::String(BUILD_TIMESTAMP.to_string()),
166                  );
167              }
168  
169              let obj = json!({
170                  "api_key": POSTHOG_KEY,
171                  "event": name,
172                  "properties": props,
173              });
174  
175              if let Err(e) = client
176                  .post("https://events.arroyo.dev/capture")
177                  .json(&obj)
178                  .send()
179                  .await
180              {
181                  debug!("Failed to record event: {}", e);
182              }
183          });
184      }
185  }
186  
/// Shared state handed to the admin HTTP handlers.
struct AdminState {
    /// Service name served by `/name` and `/details` (built as `arroyo-<service>` by
    /// `start_admin_server`).
    name: String,
}
190  
191  async fn root<'a>(State(state): State<Arc<AdminState>>) -> String {
192      format!("{}\n", state.name)
193  }
194  
/// Admin `/status` handler: always responds with the body `ok`.
async fn status() -> String {
    "ok".to_string()
}
198  
199  async fn metrics() -> Result<Bytes, StatusCode> {
200      let encoder = TextEncoder::new();
201      let registry = prometheus::default_registry();
202      match encoder.encode_to_string(&registry.gather()) {
203          Ok(s) => Ok(Bytes::from(s)),
204          Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
205      }
206  }
207  
208  async fn metrics_proto() -> Result<Bytes, StatusCode> {
209      let encoder = ProtobufEncoder::new();
210      let registry = prometheus::default_registry();
211      let mut buf = vec![];
212      encoder
213          .encode(&registry.gather(), &mut buf)
214          .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
215  
216      Ok(buf.into())
217  }
218  
219  async fn config_route() -> Result<String, StatusCode> {
220      Ok(toml::to_string(&*config()).unwrap())
221  }
222  
223  async fn details<'a>(State(state): State<Arc<AdminState>>) -> String {
224      serde_json::to_string_pretty(&json!({
225          "service": state.name,
226          "git_sha": GIT_SHA,
227          "version": VERSION,
228          "build_timestamp": BUILD_TIMESTAMP,
229      }))
230      .unwrap()
231  }
232  
233  pub async fn start_admin_server(service: &str) -> anyhow::Result<()> {
234      let addr = config().admin.bind_address;
235      let port = config().admin.http_port;
236  
237      info!("Starting {} admin server on {}:{}", service, addr, port);
238  
239      let state = Arc::new(AdminState {
240          name: format!("arroyo-{}", service),
241      });
242      let app = Router::new()
243          .route("/status", get(status))
244          .route("/name", get(root))
245          .route("/metrics", get(metrics))
246          .route("/metrics.pb", get(metrics_proto))
247          .route("/details", get(details))
248          .route("/config", get(config_route))
249          .with_state(state);
250  
251      let addr = SocketAddr::new(addr, port);
252  
253      axum::Server::bind(&addr)
254          .serve(app.into_make_service())
255          .await
256          .map_err(|e| anyhow!("Failed to start admin HTTP server: {}", e))
257  }
258  
lazy_static! {
    /// Prometheus counter registered in the default registry as `grpc_request_counter`.
    /// Registration panics (via `unwrap`) if a metric with that name already exists.
    /// NOTE(review): nothing in this file increments it — confirm it is used elsewhere.
    static ref REQUEST_COUNTER: IntCounter =
        register_int_counter!("grpc_request_counter", "grpc requests").unwrap();
}
263  
264  #[derive(Debug, Clone, Default)]
265  pub struct GrpcErrorLogMiddlewareLayer;
266  
267  impl<S> Layer<S> for GrpcErrorLogMiddlewareLayer {
268      type Service = GrpcErrorLogMiddleware<S>;
269  
270      fn layer(&self, service: S) -> Self::Service {
271          GrpcErrorLogMiddleware { inner: service }
272      }
273  }
274  
/// Service wrapper that forwards requests to `inner` and inspects the `grpc-status` header
/// of each response; created by `GrpcErrorLogMiddlewareLayer`.
#[derive(Debug, Clone)]
pub struct GrpcErrorLogMiddleware<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
}
279  
280  impl<S> Service<hyper::Request<Body>> for GrpcErrorLogMiddleware<S>
281  where
282      S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
283      S::Future: Send + 'static,
284  {
285      type Response = S::Response;
286      type Error = S::Error;
287      type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
288  
289      fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
290          self.inner.poll_ready(cx)
291      }
292  
293      fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
294          let clone = self.inner.clone();
295          let path = req.uri().clone();
296          let mut inner = std::mem::replace(&mut self.inner, clone);
297  
298          Box::pin(async move {
299              // Do extra async work here...
300              let response = inner.call(req).await?;
301  
302              let code = response
303                  .headers()
304                  .get("grpc-status")
305                  .iter()
306                  .flat_map(|status| status.to_str().ok())
307                  .flat_map(|status| status.parse::<i32>().ok())
308                  .find_map(|code| match code {
309                      0 => Some(GrpcCode::Ok),
310                      1 => Some(GrpcCode::Cancelled),
311                      2 => Some(GrpcCode::Unknown),
312                      3 => Some(GrpcCode::InvalidArgument),
313                      4 => Some(GrpcCode::DeadlineExceeded),
314                      5 => Some(GrpcCode::NotFound),
315                      6 => Some(GrpcCode::AlreadyExists),
316                      7 => Some(GrpcCode::PermissionDenied),
317                      8 => Some(GrpcCode::ResourceExhausted),
318                      9 => Some(GrpcCode::FailedPrecondition),
319                      10 => Some(GrpcCode::Aborted),
320                      11 => Some(GrpcCode::OutOfRange),
321                      12 => Some(GrpcCode::Unimplemented),
322                      13 => Some(GrpcCode::Internal),
323                      14 => Some(GrpcCode::Unavailable),
324                      15 => Some(GrpcCode::DataLoss),
325                      16 => Some(GrpcCode::Unauthenticated),
326                      _ => None,
327                  });
328  
329              if let Some(code) = code {
330                  span!(
331                      Level::ERROR,
332                      "response failed",
333                      code = format!("{:?}", code),
334                      path = format!("{:?}", path)
335                  );
336              }
337  
338              Ok(response)
339          })
340      }
341  }
342  
343  pub fn grpc_server() -> Server<
344      Stack<
345          Stack<
346              GrpcErrorLogMiddlewareLayer,
347              Stack<TraceLayer<SharedClassifier<GrpcErrorsAsFailures>>, tower::layer::util::Identity>,
348          >,
349          tower::layer::util::Identity,
350      >,
351  > {
352      let layer = tower::ServiceBuilder::new()
353          .layer(TraceLayer::new_for_grpc().on_failure(DefaultOnFailure::new().level(Level::TRACE)))
354          .layer(GrpcErrorLogMiddlewareLayer)
355          .into_inner();
356  
357      Server::builder().layer(layer)
358  }
359  
360  pub async fn wrap_start(
361      name: &str,
362      addr: SocketAddr,
363      result: impl Future<Output = Result<(), impl Error>>,
364  ) -> anyhow::Result<()> {
365      result.await.map_err(|e| {
366          anyhow!(
367              "Failed to start {} server on {}: {}",
368              name,
369              addr,
370              e.source()
371                  .map(|e| e.to_string())
372                  .unwrap_or_else(|| e.to_string())
373          )
374      })
375  }