lib.rs
1 #![allow(clippy::type_complexity)] 2 3 pub mod shutdown; 4 5 use anyhow::anyhow; 6 use arroyo_types::POSTHOG_KEY; 7 use axum::body::Bytes; 8 use axum::extract::State; 9 use axum::http::StatusCode; 10 use axum::routing::get; 11 use axum::Router; 12 use hyper::Body; 13 use lazy_static::lazy_static; 14 use once_cell::sync::OnceCell; 15 use prometheus::{register_int_counter, Encoder, IntCounter, ProtobufEncoder, TextEncoder}; 16 use reqwest::Client; 17 use serde_json::{json, Value}; 18 use std::error::Error; 19 use std::fs; 20 use std::future::Future; 21 use std::net::SocketAddr; 22 use std::path::PathBuf; 23 use std::sync::Arc; 24 use std::task::{Context, Poll}; 25 use tonic::body::BoxBody; 26 use tonic::transport::Server; 27 use tower::layer::util::Stack; 28 use tower::{Layer, Service}; 29 use tower_http::classify::{GrpcCode, GrpcErrorsAsFailures, SharedClassifier}; 30 use tower_http::trace::{DefaultOnFailure, TraceLayer}; 31 32 use tracing::metadata::LevelFilter; 33 use tracing::{debug, info, span, Level}; 34 use tracing_subscriber::fmt::format::{FmtSpan, Format}; 35 use tracing_subscriber::prelude::*; 36 use tracing_subscriber::EnvFilter; 37 use tracing_subscriber::Registry; 38 39 use arroyo_rpc::config::{config, LogFormat}; 40 use tracing_appender::non_blocking::WorkerGuard; 41 use tracing_log::LogTracer; 42 use uuid::Uuid; 43 44 pub const BUILD_TIMESTAMP: &str = env!("VERGEN_BUILD_TIMESTAMP"); 45 pub const GIT_SHA: &str = env!("VERGEN_GIT_SHA"); 46 pub const VERSION: &str = "0.12.0-dev"; 47 48 static CLUSTER_ID: OnceCell<String> = OnceCell::new(); 49 50 pub fn init_logging(name: &str) -> WorkerGuard { 51 init_logging_with_filter( 52 name, 53 EnvFilter::builder() 54 .with_default_directive(LevelFilter::INFO.into()) 55 .from_env_lossy(), 56 ) 57 } 58 59 pub fn init_logging_with_filter(_name: &str, filter: EnvFilter) -> WorkerGuard { 60 if let Err(e) = LogTracer::init() { 61 eprintln!("Failed to initialize log tracer {:?}", e); 62 } 63 64 let filter = filter.add_directive("refinery_core=warn".parse().unwrap()); 65 66 let (nonblocking, guard) = tracing_appender::non_blocking(std::io::stderr()); 67 68 match config().logging.format { 69 LogFormat::Plaintext => { 70 tracing::subscriber::set_global_default( 71 Registry::default().with( 72 tracing_subscriber::fmt::layer() 73 .with_line_number(false) 74 .with_file(false) 75 .with_span_events(FmtSpan::NONE) 76 .with_writer(nonblocking) 77 .with_filter(filter), 78 ), 79 ) 80 .expect("Unable to set global log subscriber"); 81 } 82 LogFormat::Logfmt => { 83 tracing::subscriber::set_global_default( 84 Registry::default().with( 85 tracing_subscriber::fmt::layer() 86 .event_format(tracing_logfmt::EventsFormatter) 87 .fmt_fields(tracing_logfmt::FieldsFormatter) 88 .with_writer(nonblocking) 89 .with_filter(filter), 90 ), 91 ) 92 .expect("Unable to set global log subscriber"); 93 } 94 LogFormat::Json => { 95 tracing::subscriber::set_global_default( 96 Registry::default().with( 97 tracing_subscriber::fmt::layer() 98 .event_format(Format::default().json()) 99 .with_writer(nonblocking) 100 .with_filter(filter), 101 ), 102 ) 103 .expect("Unable to set global log subscriber"); 104 } 105 } 106 107 std::panic::set_hook(Box::new(|panic| { 108 if let Some(location) = panic.location() { 109 tracing::error!( 110 message = %panic, 111 panic.file = location.file(), 112 panic.line = location.line(), 113 panic.column = location.column(), 114 ); 115 } else { 116 tracing::error!(message = %panic); 117 } 118 })); 119 120 guard 121 } 122 123 fn existing_cluster_id(path: Option<&PathBuf>) -> Option<String> { 124 let path = path?; 125 if path.exists() { 126 let s = fs::read_to_string(path).ok()?.trim().to_string(); 127 Uuid::parse_str(&s).ok()?; 128 Some(s) 129 } else { 130 None 131 } 132 } 133 134 pub fn set_cluster_id(cluster_id: &str) { 135 let path = dirs::config_dir().map(|p| p.join("arroyo").join("cluster-info")); 136 137 if let Some(id) = existing_cluster_id(path.as_ref()) { 138 CLUSTER_ID.set(id).unwrap(); 139 } else { 140 CLUSTER_ID.set(cluster_id.to_string()).unwrap(); 141 if let Some(path) = path { 142 let _ = fs::write(&path, cluster_id); 143 } 144 } 145 } 146 147 pub fn get_cluster_id() -> String { 148 CLUSTER_ID.get().map(|s| s.to_string()).unwrap() 149 } 150 151 pub fn log_event(name: &str, mut props: Value) { 152 static CLIENT: OnceCell<Client> = OnceCell::new(); 153 let cluster_id = get_cluster_id(); 154 if !config().disable_telemetry { 155 let name = name.to_string(); 156 tokio::task::spawn(async move { 157 let client = CLIENT.get_or_init(Client::new); 158 159 if let Some(props) = props.as_object_mut() { 160 props.insert("distinct_id".to_string(), Value::String(cluster_id)); 161 props.insert("git_sha".to_string(), Value::String(GIT_SHA.to_string())); 162 props.insert("version".to_string(), Value::String(VERSION.to_string())); 163 props.insert( 164 "build_timestamp".to_string(), 165 Value::String(BUILD_TIMESTAMP.to_string()), 166 ); 167 } 168 169 let obj = json!({ 170 "api_key": POSTHOG_KEY, 171 "event": name, 172 "properties": props, 173 }); 174 175 if let Err(e) = client 176 .post("https://events.arroyo.dev/capture") 177 .json(&obj) 178 .send() 179 .await 180 { 181 debug!("Failed to record event: {}", e); 182 } 183 }); 184 } 185 } 186 187 struct AdminState { 188 name: String, 189 } 190 191 async fn root<'a>(State(state): State<Arc<AdminState>>) -> String { 192 format!("{}\n", state.name) 193 } 194 195 async fn status<'a>() -> String { 196 "ok".to_string() 197 } 198 199 async fn metrics() -> Result<Bytes, StatusCode> { 200 let encoder = TextEncoder::new(); 201 let registry = prometheus::default_registry(); 202 match encoder.encode_to_string(®istry.gather()) { 203 Ok(s) => Ok(Bytes::from(s)), 204 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), 205 } 206 } 207 208 async fn metrics_proto() -> Result<Bytes, StatusCode> { 209 let encoder = ProtobufEncoder::new(); 210 let registry = prometheus::default_registry(); 211 let mut buf = vec![]; 212 encoder 213 .encode(®istry.gather(), &mut buf) 214 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 215 216 Ok(buf.into()) 217 } 218 219 async fn config_route() -> Result<String, StatusCode> { 220 Ok(toml::to_string(&*config()).unwrap()) 221 } 222 223 async fn details<'a>(State(state): State<Arc<AdminState>>) -> String { 224 serde_json::to_string_pretty(&json!({ 225 "service": state.name, 226 "git_sha": GIT_SHA, 227 "version": VERSION, 228 "build_timestamp": BUILD_TIMESTAMP, 229 })) 230 .unwrap() 231 } 232 233 pub async fn start_admin_server(service: &str) -> anyhow::Result<()> { 234 let addr = config().admin.bind_address; 235 let port = config().admin.http_port; 236 237 info!("Starting {} admin server on {}:{}", service, addr, port); 238 239 let state = Arc::new(AdminState { 240 name: format!("arroyo-{}", service), 241 }); 242 let app = Router::new() 243 .route("/status", get(status)) 244 .route("/name", get(root)) 245 .route("/metrics", get(metrics)) 246 .route("/metrics.pb", get(metrics_proto)) 247 .route("/details", get(details)) 248 .route("/config", get(config_route)) 249 .with_state(state); 250 251 let addr = SocketAddr::new(addr, port); 252 253 axum::Server::bind(&addr) 254 .serve(app.into_make_service()) 255 .await 256 .map_err(|e| anyhow!("Failed to start admin HTTP server: {}", e)) 257 } 258 259 lazy_static! { 260 static ref REQUEST_COUNTER: IntCounter = 261 register_int_counter!("grpc_request_counter", "grpc requests").unwrap(); 262 } 263 264 #[derive(Debug, Clone, Default)] 265 pub struct GrpcErrorLogMiddlewareLayer; 266 267 impl<S> Layer<S> for GrpcErrorLogMiddlewareLayer { 268 type Service = GrpcErrorLogMiddleware<S>; 269 270 fn layer(&self, service: S) -> Self::Service { 271 GrpcErrorLogMiddleware { inner: service } 272 } 273 } 274 275 #[derive(Debug, Clone)] 276 pub struct GrpcErrorLogMiddleware<S> { 277 inner: S, 278 } 279 280 impl<S> Service<hyper::Request<Body>> for GrpcErrorLogMiddleware<S> 281 where 282 S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static, 283 S::Future: Send + 'static, 284 { 285 type Response = S::Response; 286 type Error = S::Error; 287 type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>; 288 289 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 290 self.inner.poll_ready(cx) 291 } 292 293 fn call(&mut self, req: hyper::Request<Body>) -> Self::Future { 294 let clone = self.inner.clone(); 295 let path = req.uri().clone(); 296 let mut inner = std::mem::replace(&mut self.inner, clone); 297 298 Box::pin(async move { 299 // Do extra async work here... 300 let response = inner.call(req).await?; 301 302 let code = response 303 .headers() 304 .get("grpc-status") 305 .iter() 306 .flat_map(|status| status.to_str().ok()) 307 .flat_map(|status| status.parse::<i32>().ok()) 308 .find_map(|code| match code { 309 0 => Some(GrpcCode::Ok), 310 1 => Some(GrpcCode::Cancelled), 311 2 => Some(GrpcCode::Unknown), 312 3 => Some(GrpcCode::InvalidArgument), 313 4 => Some(GrpcCode::DeadlineExceeded), 314 5 => Some(GrpcCode::NotFound), 315 6 => Some(GrpcCode::AlreadyExists), 316 7 => Some(GrpcCode::PermissionDenied), 317 8 => Some(GrpcCode::ResourceExhausted), 318 9 => Some(GrpcCode::FailedPrecondition), 319 10 => Some(GrpcCode::Aborted), 320 11 => Some(GrpcCode::OutOfRange), 321 12 => Some(GrpcCode::Unimplemented), 322 13 => Some(GrpcCode::Internal), 323 14 => Some(GrpcCode::Unavailable), 324 15 => Some(GrpcCode::DataLoss), 325 16 => Some(GrpcCode::Unauthenticated), 326 _ => None, 327 }); 328 329 if let Some(code) = code { 330 span!( 331 Level::ERROR, 332 "response failed", 333 code = format!("{:?}", code), 334 path = format!("{:?}", path) 335 ); 336 } 337 338 Ok(response) 339 }) 340 } 341 } 342 343 pub fn grpc_server() -> Server< 344 Stack< 345 Stack< 346 GrpcErrorLogMiddlewareLayer, 347 Stack<TraceLayer<SharedClassifier<GrpcErrorsAsFailures>>, tower::layer::util::Identity>, 348 >, 349 tower::layer::util::Identity, 350 >, 351 > { 352 let layer = tower::ServiceBuilder::new() 353 .layer(TraceLayer::new_for_grpc().on_failure(DefaultOnFailure::new().level(Level::TRACE))) 354 .layer(GrpcErrorLogMiddlewareLayer) 355 .into_inner(); 356 357 Server::builder().layer(layer) 358 } 359 360 pub async fn wrap_start( 361 name: &str, 362 addr: SocketAddr, 363 result: impl Future<Output = Result<(), impl Error>>, 364 ) -> anyhow::Result<()> { 365 result.await.map_err(|e| { 366 anyhow!( 367 "Failed to start {} server on {}: {}", 368 name, 369 addr, 370 e.source() 371 .map(|e| e.to_string()) 372 .unwrap_or_else(|| e.to_string()) 373 ) 374 }) 375 }