// lib.rs
//! Procedural macros for generating service-related boilerplate in the
//! Overwatch framework.
//!
//! This crate provides macros to derive service-related traits and
//! implementations to ensure compile-time validation and structured lifecycle
//! management for services.
//!
//! # Provided Macros
//!
//! - `#[derive_services]`: Modifies a struct by changing its fields to
//!   `OpaqueServiceRunnerHandle<T>` and automatically derives the `Services`
//!   trait.
//! - `#[derive(Services)]`: Implements the `Services` trait for a struct,
//!   generating necessary service lifecycle methods and runtime service ID
//!   management. **This derive macro is not meant to be used directly**.
//!
//! # Features
//!
//! - Ensures that all services are registered at compile-time, avoiding runtime
//!   checks and panics.
//! - Provides compile-time validation for service settings and runtime service
//!   identifiers.

use proc_macro::TokenStream;
use proc_macro_error2::{abort_call_site, proc_macro_error};
use proc_macro2::{Ident, Span};
use quote::{format_ident, quote};
use syn::{
    Data, DeriveInput, Field, Fields, GenericArgument, Generics, ItemStruct, PathArguments, Type,
    parse, parse_macro_input, parse_str, punctuated::Punctuated, token::Comma,
};

mod utils;

/// Procedural macro to derive service-related implementations for a struct.
///
/// This macro rewrites every field of a struct from `T` to
/// `::overwatch::OpaqueServiceRunnerHandle<T>` and adds
/// `#[derive(::overwatch::Services)]` to manage service lifecycle operations.
39 /// 40 /// # Example 41 /// ```rust,ignore 42 /// use overwatch_derive::derive_services; 43 /// 44 /// #[derive_services] 45 /// struct MyServices { 46 /// database: DatabaseService, 47 /// cache: CacheService, 48 /// } 49 /// ``` 50 /// This expands to: 51 /// ```rust,ignore 52 /// use overwatch::OpaqueServiceRunnerHandle; 53 /// use async_trait::async_trait; 54 /// 55 /// struct MyServices { 56 /// database: OpaqueServiceHandle<DatabaseService>, 57 /// cache: OpaqueServiceHandle<CacheService>, 58 /// } 59 /// 60 /// #[async_trait] 61 /// impl Services for MyServices { /* service lifecycle methods */ } 62 /// ``` 63 #[expect( 64 clippy::missing_panics_doc, 65 reason = "We will add docs to this macro later on." 66 )] 67 #[proc_macro_attribute] 68 pub fn derive_services(_attr: TokenStream, item: TokenStream) -> TokenStream { 69 let input = parse_macro_input!(item as ItemStruct); 70 let struct_name = &input.ident; 71 let visibility = &input.vis; 72 let generics = &input.generics; 73 74 let Fields::Named(named_fields) = input.fields else { 75 panic!("`derive_services` macro only supports structs with named fields"); 76 }; 77 let fields = named_fields.named; 78 79 let modified_fields = fields.iter().map(|field| { 80 let field_name = &field.ident; 81 let field_type = &field.ty; 82 let field_attrs = &field.attrs; // Preserve attributes (including feature flags) 83 84 let new_field_type = quote! { 85 ::overwatch::OpaqueServiceRunnerHandle<#field_type> 86 }; 87 88 quote! { 89 #(#field_attrs)* 90 #field_name: #new_field_type 91 } 92 }); 93 94 // Generate the modified struct with #[derive(Services)] 95 let modified_struct = quote! { 96 #[derive(::overwatch::Services)] 97 #visibility struct #struct_name #generics { 98 #(#modified_fields),* 99 } 100 }; 101 102 modified_struct.into() 103 } 104 105 /// Returns default instrumentation settings if the `instrumentation` feature is 106 /// enabled. 
107 /// 108 /// The output of this function is to be used in places that want to add tracing 109 /// capabilities to non `Result` types. For `Result` types, use 110 /// [`get_default_instrumentation_for_result`] instead. 111 fn get_default_instrumentation() -> proc_macro2::TokenStream { 112 #[cfg(feature = "instrumentation")] 113 quote! { 114 #[tracing::instrument(skip(self))] 115 } 116 117 #[cfg(not(feature = "instrumentation"))] 118 quote! {} 119 } 120 121 /// Returns instrumentation settings that track errors if `instrumentation` is 122 /// enabled. 123 /// 124 /// The output of this function is to be used in places that want to add tracing 125 /// capabilities to `Result` types. For non `Result` types, use 126 /// [`get_default_instrumentation`] instead. 127 fn get_default_instrumentation_for_result() -> proc_macro2::TokenStream { 128 #[cfg(feature = "instrumentation")] 129 quote! { 130 #[tracing::instrument(skip(self), err)] 131 } 132 133 #[cfg(not(feature = "instrumentation"))] 134 quote! {} 135 } 136 137 /// Returns instrumentation settings that ignore `settings` in traces. 138 fn get_default_instrumentation_without_settings() -> proc_macro2::TokenStream { 139 #[cfg(feature = "instrumentation")] 140 quote! { 141 #[tracing::instrument(skip(self, settings))] 142 } 143 144 #[cfg(not(feature = "instrumentation"))] 145 quote! {} 146 } 147 148 /// Derives the `Services` trait for a struct, implementing service lifecycle 149 /// operations. 150 /// 151 /// This macro generates the necessary implementations to manage services, 152 /// including: 153 /// - Initializing services. 154 /// - Starting/stopping services. 155 /// - Handling relays and status updates. 156 /// 157 /// **THIS MACRO IS NOT MEANT TO BE USED DIRECTLY BY DEVELOPERS, WHO SHOULD 158 /// RATHER USE THE `derive_services` MACRO**. 
159 /// 160 /// # Example 161 /// ```rust,ignore 162 /// use overwatch::OpaqueServiceHandle; 163 /// 164 /// #[derive(Services)] 165 /// struct MyServices { 166 /// database: OpaqueServiceHandle<DatabaseService>, 167 /// cache: OpaqueServiceHandle<CacheService>, 168 /// } 169 /// ``` 170 #[proc_macro_derive(Services)] 171 #[proc_macro_error] 172 pub fn services_derive(input: TokenStream) -> TokenStream { 173 let parsed_input: DeriveInput = parse(input).expect("A syn parseable token stream"); 174 let derived = impl_services(&parsed_input); 175 derived.into() 176 } 177 178 /// Creates a service settings identifier from a services identifier. 179 /// 180 /// This function takes a services identifier and appends `"ServiceSettings"` to 181 /// create the corresponding settings type name. 182 /// 183 /// # Arguments 184 /// 185 /// * `services_identifier` - The identifier of the services struct 186 /// 187 /// # Examples 188 /// 189 /// ```rust,ignore 190 /// use quote::format_ident; 191 /// 192 /// let service_id = format_ident!("AppServices"); 193 /// let settings_id = service_settings_identifier_from(&service_id); 194 /// // settings_id will be "AppServicesServiceSettings" 195 /// ``` 196 fn service_settings_identifier_from(services_identifier: &Ident) -> Ident { 197 format_ident!("{}ServiceSettings", services_identifier) 198 } 199 200 /// Creates a service settings field identifier from a field identifier. 201 /// 202 /// This function takes a field identifier and appends "_settings" to create 203 /// the corresponding settings field name. 
204 /// 205 /// # Arguments 206 /// 207 /// * `field_identifier` - The identifier of the service field 208 /// 209 /// # Examples 210 /// 211 /// ```rust,ignore 212 /// use quote::format_ident; 213 /// 214 /// let field_id = format_ident!("database"); 215 /// let settings_field_id = service_settings_field_identifier_from(&field_id); 216 /// // settings_field_id will be "database_settings" 217 /// ``` 218 fn service_settings_field_identifier_from(field_identifier: &Ident) -> Ident { 219 format_ident!("{}_settings", field_identifier) 220 } 221 222 /// Implements the [`overwatch::overwatch::Services`] trait for the given input. 223 /// 224 /// This function examines the input structure and generates the appropriate 225 /// implementation of the trait based on the structure's fields. 226 /// 227 /// # Arguments 228 /// 229 /// * `input` - The parsed derive input 230 /// 231 /// # Returns 232 /// 233 /// A token stream containing the Services trait implementation 234 /// 235 /// # Panics 236 /// 237 /// This function will abort compilation if the input is not a struct with named 238 /// fields. 239 fn impl_services(input: &DeriveInput) -> proc_macro2::TokenStream { 240 use syn::DataStruct; 241 242 let struct_identifier = &input.ident; 243 let data = &input.data; 244 let generics = &input.generics; 245 match data { 246 Data::Struct(DataStruct { 247 fields: Fields::Named(fields), 248 .. 249 }) => impl_services_for_struct(struct_identifier, generics, &fields.named), 250 _ => { 251 abort_call_site!( 252 "Deriving Services is only supported for named structs with at least one field." 253 ); 254 } 255 } 256 } 257 258 /// Implements the [`overwatch::overwatch::Services`] trait for a struct with 259 /// named fields. 260 /// 261 /// This function generates all necessary code for implementing the Services 262 /// trait, including runtime service types, settings, and implementation 263 /// methods. 
264 /// 265 /// # Arguments 266 /// 267 /// * `identifier` - The struct identifier 268 /// * `generics` - The struct's generic parameters 269 /// * `fields` - The struct's fields 270 /// 271 /// # Returns 272 /// 273 /// A token stream containing the combined implementations. 274 fn impl_services_for_struct( 275 identifier: &Ident, 276 generics: &Generics, 277 fields: &Punctuated<Field, Comma>, 278 ) -> proc_macro2::TokenStream { 279 let runtime_service_type = generate_runtime_service_types(fields); 280 let settings = generate_services_settings(identifier, generics, fields); 281 let services_impl = generate_services_impl(identifier, generics, fields); 282 283 quote! { 284 #runtime_service_type 285 286 #settings 287 288 #services_impl 289 } 290 } 291 292 /// Generates the services settings struct for a given service. 293 /// 294 /// This function creates a new struct that holds the settings for each service 295 /// field in the original struct. The generated struct will have the same 296 /// generics as the original struct. 297 /// 298 /// # Arguments 299 /// 300 /// * `services_identifier` - The identifier of the services struct 301 /// * `generics` - The generic parameters of the services struct 302 /// * `fields` - The fields of the services struct 303 /// 304 /// # Returns 305 /// 306 /// A token stream containing the settings struct definition. 307 fn generate_services_settings( 308 services_identifier: &Ident, 309 generics: &Generics, 310 fields: &Punctuated<Field, Comma>, 311 ) -> proc_macro2::TokenStream { 312 let services_settings = fields.iter().map(|field| { 313 let service_name = field.ident.as_ref().expect("A named struct attribute"); 314 let _type = utils::extract_type_from(&field.ty); 315 316 quote!(pub #service_name: <#_type as ::overwatch::services::ServiceData>::Settings) 317 }); 318 let services_settings_identifier = service_settings_identifier_from(services_identifier); 319 let where_clause = &generics.where_clause; 320 quote! 
{ 321 #[derive(::core::clone::Clone, ::core::fmt::Debug)] 322 pub struct #services_settings_identifier #generics #where_clause { 323 #( #services_settings ),* 324 } 325 } 326 } 327 328 const RUNTIME_SERVICE_ID_TYPE_NAME: &str = "RuntimeServiceId"; 329 fn get_runtime_service_id_type_name() -> Type { 330 parse_str(RUNTIME_SERVICE_ID_TYPE_NAME) 331 .expect("Runtime service ID type is a valid type token stream.") 332 } 333 334 /// Generates the [`overwatch::overwatch::Services`] trait implementation for a 335 /// struct. 336 /// 337 /// This function creates the full implementation of the `Services` trait, 338 /// including all required methods like `new`, `start_all`, `start`, `stop`, 339 /// etc. 340 /// 341 /// # Arguments 342 /// 343 /// * `services_identifier` - The identifier of the services struct 344 /// * `generics` - The generic parameters of the services struct 345 /// * `fields` - The fields of the services struct 346 /// 347 /// # Returns 348 /// 349 /// A token stream containing the Services trait implementation. 
350 fn generate_services_impl( 351 services_identifier: &Ident, 352 generics: &Generics, 353 fields: &Punctuated<Field, Comma>, 354 ) -> proc_macro2::TokenStream { 355 let services_settings_identifier = service_settings_identifier_from(services_identifier); 356 let impl_new = generate_new_impl(fields); 357 let impl_start = generate_start_impl(fields); 358 let impl_start_sequence = generate_start_sequence_impl(fields); 359 let impl_start_all = generate_start_all_impl(fields); 360 let impl_stop = generate_stop_impl(fields); 361 let impl_stop_sequence = generate_stop_sequence_impl(fields); 362 let impl_stop_all = generate_stop_all_impl(fields); 363 let impl_teardown = generate_teardown_impl(fields); 364 let impl_ids = generate_ids_impl(fields); 365 let impl_relay = generate_request_relay_impl(fields); 366 let impl_status = generate_request_status_watcher_impl(fields); 367 let impl_update_settings = generate_update_settings_impl(fields); 368 let impl_get_service_lifecycle_notifier = generate_get_service_lifecycle_notifier_impl(fields); 369 370 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); 371 372 let runtime_service_id_type_name = get_runtime_service_id_type_name(); 373 quote! { 374 #[::async_trait::async_trait] 375 impl #impl_generics ::overwatch::overwatch::Services for #services_identifier #ty_generics #where_clause { 376 type Settings = #services_settings_identifier #ty_generics; 377 type RuntimeServiceId = #runtime_service_id_type_name; 378 379 #impl_new 380 381 #impl_start 382 383 #impl_start_sequence 384 385 #impl_start_all 386 387 #impl_stop 388 389 #impl_stop_sequence 390 391 #impl_stop_all 392 393 #impl_teardown 394 395 #impl_ids 396 397 #impl_relay 398 399 #impl_status 400 401 #impl_update_settings 402 403 #impl_get_service_lifecycle_notifier 404 } 405 } 406 } 407 408 /// Generates the `new` method implementation for the `Services` trait. 
409 /// 410 /// This function creates the code to initialize each service field with its 411 /// corresponding settings and wrap it in an `OpaqueServiceHandle`. 412 /// 413 /// # Arguments 414 /// 415 /// * `fields` - The fields of the services struct 416 /// 417 /// # Returns 418 /// 419 /// A token stream containing the new method implementation. 420 fn generate_new_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 421 let fields_settings = fields.iter().map(|field| { 422 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 423 let settings_field_identifier = service_settings_field_identifier_from(field_identifier); 424 quote! { 425 #field_identifier: #settings_field_identifier 426 } 427 }); 428 429 let managers = fields.iter().map(|field| { 430 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 431 let service_type = utils::extract_type_from(&field.ty); 432 let settings_field_identifier = service_settings_field_identifier_from(field_identifier); 433 quote! { 434 #field_identifier: { 435 let runner = 436 ::overwatch::OpaqueServiceRunner::<#service_type, Self::RuntimeServiceId>::new( 437 #settings_field_identifier, overwatch_handle.clone(), <#service_type as ::overwatch::services::ServiceData>::SERVICE_RELAY_BUFFER_SIZE 438 ); 439 let service_runner_handle = runner.run::<#service_type>(); 440 service_runner_handle 441 } 442 } 443 }); 444 445 quote! { 446 fn new(settings: Self::Settings, overwatch_handle: ::overwatch::overwatch::handle::OverwatchHandle<Self::RuntimeServiceId>) -> ::core::result::Result<Self, ::overwatch::DynError> { 447 let Self::Settings { 448 #( #fields_settings ),* 449 } = settings; 450 451 let app = Self { 452 #( #managers ),* 453 }; 454 455 ::core::result::Result::Ok(app) 456 } 457 } 458 } 459 460 /// Generates the `start` method implementation for the `Services` trait. 
461 /// 462 /// This function creates code to start a specific service identified by its 463 /// `RuntimeServiceId`. It generates a match expression that maps each service 464 /// ID to the corresponding field's service runner. 465 /// 466 /// # Arguments 467 /// 468 /// * `fields` - The fields of the services struct 469 /// 470 /// # Returns 471 /// 472 /// A token stream containing the start method implementation. 473 fn generate_start_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 474 let instrumentation = get_default_instrumentation_for_result(); 475 476 let cases = fields.iter().map(|field| { 477 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 478 let type_id = utils::extract_type_from(&field.ty); 479 quote! { 480 &<Self::RuntimeServiceId as ::overwatch::services::AsServiceId<#type_id>>::SERVICE_ID => { 481 self.#field_identifier.service_handle().lifecycle_notifier().send( 482 ::overwatch::services::lifecycle::LifecycleMessage::Start(sender) 483 ).await?; 484 } 485 } 486 }); 487 488 quote! { 489 #instrumentation 490 async fn start(&mut self, service_id: &Self::RuntimeServiceId) -> ::core::result::Result<(), ::overwatch::overwatch::Error> { 491 let (sender, mut receiver) = ::overwatch::utils::finished_signal::channel(); 492 match service_id { 493 #( #cases ),* 494 }; 495 receiver.await.map_err(|error| { 496 let dyn_error: ::overwatch::DynError = Box::new(error); 497 ::overwatch::overwatch::Error::from(dyn_error) 498 }) 499 } 500 } 501 } 502 503 /// Generates the `start_sequence` method implementation for the `Services` 504 /// trait. 505 /// 506 /// This function creates code to start a list of services identified by their 507 /// `RuntimeServiceId`. 508 /// 509 /// # Arguments 510 /// 511 /// * `fields` - The fields of the services struct 512 /// 513 /// # Returns 514 /// 515 /// A token stream containing the `start_sequence` method implementation. 
516 fn generate_start_sequence_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 517 let instrumentation = get_default_instrumentation(); 518 519 let var_services_len = Ident::new("services_len", Span::call_site()); 520 let call_create_finished_signal_channels = 521 create_finished_signal_channels_from_variable(&var_services_len); 522 523 let var_service_ids = Ident::new("service_ids", Span::call_site()); 524 let var_service_id = Ident::new("service_id", Span::call_site()); 525 let match_cases = fields.iter().map(|field| { 526 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 527 let type_id = utils::extract_type_from(&field.ty); 528 let call_send_start = send_start_lifecycle_message_over_senders(field_identifier); 529 quote! { 530 &<Self::RuntimeServiceId as ::overwatch::services::AsServiceId<#type_id>>::SERVICE_ID => { 531 #call_send_start 532 } 533 } 534 }); 535 let loop_match = quote! { 536 for #var_service_id in #var_service_ids { 537 match #var_service_id { 538 #( #match_cases ),* 539 } 540 } 541 }; 542 543 let call_await_finished_signal_receivers = await_finished_signal_receivers(); 544 545 quote! { 546 #instrumentation 547 async fn start_sequence(&mut self, service_ids: &[Self::RuntimeServiceId]) -> ::core::result::Result<(), ::overwatch::overwatch::Error> { 548 let #var_services_len = service_ids.len(); 549 #call_create_finished_signal_channels; 550 551 #loop_match; 552 553 #call_await_finished_signal_receivers; 554 555 Ok(()) 556 } 557 } 558 } 559 560 /// Generates the `start_all` method implementation for the `Services` trait. 561 /// 562 /// This function creates code to start all service runners and return a 563 /// combined lifecycle handle that can be used to manage the running services. 564 /// 565 /// # Arguments 566 /// 567 /// * `fields` - The fields of the services struct 568 /// 569 /// # Returns 570 /// 571 /// A token stream containing the `start_all` method implementation. 
572 fn generate_start_all_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 573 let instrumentation = get_default_instrumentation(); 574 575 let fields_len = fields.len(); 576 let call_create_channels = create_finished_signal_channels_from_amount(fields_len); 577 578 let call_send_start_message = fields.iter().map(|field| { 579 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 580 send_start_lifecycle_message_over_senders(field_identifier) 581 }); 582 583 let call_recv_finished_signals = await_finished_signal_receivers(); 584 585 quote! { 586 #instrumentation 587 async fn start_all(&mut self) -> ::core::result::Result<(), ::overwatch::overwatch::Error> { 588 #call_create_channels 589 590 #( #call_send_start_message )* 591 592 #call_recv_finished_signals 593 594 Ok::<(), ::overwatch::overwatch::Error>(()) 595 } 596 } 597 } 598 599 /// Generates the `stop` method implementation for the `Services` trait. 600 /// 601 /// This function creates code to stop a specific service identified by its 602 /// `RuntimeServiceId`. Currently, this generates unimplemented stubs as the 603 /// service lifecycle is not yet fully implemented. 604 /// 605 /// # Arguments 606 /// 607 /// * `fields` - The fields of the services struct 608 /// 609 /// # Returns 610 /// 611 /// A token stream containing the stop method implementation. 612 fn generate_stop_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 613 let instrumentation = get_default_instrumentation(); 614 615 let cases = fields.iter().map(|field| { 616 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 617 let type_id = utils::extract_type_from(&field.ty); 618 quote! 
{ 619 &<Self::RuntimeServiceId as ::overwatch::services::AsServiceId<#type_id>>::SERVICE_ID => { 620 self.#field_identifier.service_handle().lifecycle_notifier().send( 621 ::overwatch::services::lifecycle::LifecycleMessage::Stop(sender) 622 ).await?; 623 } 624 } 625 }); 626 627 quote! { 628 #instrumentation 629 async fn stop(&mut self, service_id: &Self::RuntimeServiceId) -> ::core::result::Result<(), ::overwatch::overwatch::Error> { 630 let (sender, mut receiver) = ::overwatch::utils::finished_signal::channel(); 631 match service_id { 632 #( #cases ),* 633 }; 634 receiver.await.map_err(|error| { 635 let dyn_error: ::overwatch::DynError = Box::new(error); 636 ::overwatch::overwatch::Error::from(dyn_error) 637 }) 638 } 639 } 640 } 641 642 /// Generates the `stop_sequence` method implementation for the `Services` 643 /// trait. 644 /// 645 /// This function creates code to stop a list of services identified by their 646 /// `RuntimeServiceId`. 647 /// 648 /// # Arguments 649 /// 650 /// * `fields` - The fields of the services struct 651 /// 652 /// # Returns 653 /// 654 /// A token stream containing the `stop_sequence` method implementation. 655 fn generate_stop_sequence_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 656 let instrumentation = get_default_instrumentation(); 657 658 let var_services_len = Ident::new("services_len", Span::call_site()); 659 let call_create_finished_signal_channels = 660 create_finished_signal_channels_from_variable(&var_services_len); 661 662 let var_service_ids = Ident::new("service_ids", Span::call_site()); 663 let var_service_id = Ident::new("service_id", Span::call_site()); 664 let match_cases = fields.iter().map(|field| { 665 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 666 let type_id = utils::extract_type_from(&field.ty); 667 let call_send_stop = send_stop_lifecycle_message_over_senders(field_identifier); 668 quote! 
{ 669 &<Self::RuntimeServiceId as ::overwatch::services::AsServiceId<#type_id>>::SERVICE_ID => { 670 #call_send_stop 671 } 672 } 673 }); 674 let loop_match = quote! { 675 for #var_service_id in #var_service_ids { 676 match #var_service_id { 677 #( #match_cases ),* 678 } 679 } 680 }; 681 682 let call_await_finished_signal_receivers = await_finished_signal_receivers(); 683 684 quote! { 685 #instrumentation 686 async fn stop_sequence(&mut self, service_ids: &[Self::RuntimeServiceId]) -> ::core::result::Result<(), ::overwatch::overwatch::Error> { 687 let #var_services_len = service_ids.len(); 688 #call_create_finished_signal_channels; 689 690 #loop_match; 691 692 #call_await_finished_signal_receivers; 693 694 Ok(()) 695 } 696 } 697 } 698 699 /// Generates the `stop_all` method implementation for the `Services` trait. 700 /// 701 /// This function creates code to stop all service runners. 702 /// 703 /// # Arguments 704 /// 705 /// * `fields` - The fields of the services struct 706 /// 707 /// # Returns 708 /// 709 /// A token stream containing the `stop_all` method implementation. 710 fn generate_stop_all_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 711 let instrumentation = get_default_instrumentation(); 712 713 let fields_len = fields.len(); 714 let call_create_channels = create_finished_signal_channels_from_amount(fields_len); 715 716 let call_send_stop_message_to_services = fields.iter().map(|field| { 717 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 718 send_stop_lifecycle_message_over_senders(field_identifier) 719 }); 720 721 let call_recv_finished_signals = await_finished_signal_receivers(); 722 723 quote! 
{ 724 #instrumentation 725 async fn stop_all(&mut self) -> Result<(), ::overwatch::overwatch::Error> { 726 #call_create_channels 727 728 #( #call_send_stop_message_to_services )* 729 730 #call_recv_finished_signals 731 732 Ok::<(), ::overwatch::overwatch::Error>(()) 733 } 734 } 735 } 736 737 /// Generates the `teardown` method implementation for the `Services` trait. 738 /// 739 /// This function creates code to teardown the `Services` struct. 740 /// 741 /// # Arguments 742 /// 743 /// * `fields` - The fields of the services struct 744 /// 745 /// # Returns 746 /// 747 /// A token stream containing the `teardown` method implementation. 748 fn generate_teardown_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 749 let instrumentation = get_default_instrumentation(); 750 751 let call_abort_service_runner_join_handles = fields.iter().map(|field| { 752 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 753 quote! { 754 self.#field_identifier.runner_join_handle().abort(); 755 } 756 }); 757 758 let call_await_service_runner_join_handles = fields.iter().map(|field| { 759 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 760 quote! { 761 if let Err(error) = self.#field_identifier.runner_join_handle_owned().await { 762 ::tracing::error!("Error while awaiting ServiceRunner's JoinHandle: {error}"); 763 } 764 } 765 }); 766 767 quote! { 768 #instrumentation 769 async fn teardown(self) -> Result<(), ::overwatch::overwatch::Error> { 770 # (#call_abort_service_runner_join_handles)* 771 772 # (#call_await_service_runner_join_handles)* 773 774 Ok::<(), ::overwatch::overwatch::Error>(()) 775 } 776 } 777 } 778 779 /// Generates the `ids` method implementation for the `Services` trait. 780 /// 781 /// This function creates code to retrieve the `RuntimeServiceId` for each 782 /// service defined in the struct. It returns a [`Vec`] of `RuntimeServiceId`s. 
783 /// 784 /// # Arguments 785 /// 786 /// * `fields` - The fields of the services struct 787 /// 788 /// # Returns 789 /// 790 /// A token stream containing the `ids` method implementation. 791 fn generate_ids_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 792 let instrumentation = get_default_instrumentation(); 793 794 let service_ids = fields.iter().map(|field| { 795 let type_id = utils::extract_type_from(&field.ty); 796 quote! { 797 <Self::RuntimeServiceId as ::overwatch::services::AsServiceId<#type_id>>::SERVICE_ID 798 } 799 }); 800 801 quote! { 802 #instrumentation 803 fn ids(&self) -> Vec<Self::RuntimeServiceId> { 804 vec![ #( #service_ids ),* ] 805 } 806 } 807 } 808 809 /// Generates the `request_relay` method implementation for the `Services` 810 /// trait. 811 /// 812 /// This function creates code to request a message relay for a specific service 813 /// identified by its `RuntimeServiceId`. 814 /// 815 /// # Arguments 816 /// 817 /// * `fields` - The fields of the services struct 818 /// 819 /// # Returns 820 /// 821 /// A token stream containing the `request_relay` method implementation. 822 fn generate_request_relay_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 823 let instrumentation = get_default_instrumentation(); 824 825 let cases = fields.iter().map(|field| { 826 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 827 let type_id = utils::extract_type_from(&field.ty); 828 quote! { 829 &<Self::RuntimeServiceId as ::overwatch::services::AsServiceId<#type_id>>::SERVICE_ID => { 830 ::std::boxed::Box::new(self.#field_identifier.service_handle().relay_with()) 831 } 832 } 833 }); 834 835 quote! 
{ 836 #instrumentation 837 fn request_relay(&mut self, service_id: &Self::RuntimeServiceId) -> ::overwatch::services::relay::AnyMessage { 838 match service_id { 839 #( #cases )* 840 } 841 } 842 } 843 } 844 845 /// Generates the `request_status_watcher` method implementation for the 846 /// `Services` trait. 847 /// 848 /// This function creates code to request a status watcher for a specific 849 /// service identified by its `RuntimeServiceId`. The status watcher can be used 850 /// to monitor the service's status changes. 851 /// 852 /// # Arguments 853 /// 854 /// * `fields` - The fields of the services struct 855 /// 856 /// # Returns 857 /// 858 /// A token stream containing the `request_status_watcher` method 859 /// implementation. 860 fn generate_request_status_watcher_impl( 861 fields: &Punctuated<Field, Comma>, 862 ) -> proc_macro2::TokenStream { 863 let instrumentation = get_default_instrumentation(); 864 865 let cases = fields.iter().map(|field| { 866 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 867 let type_id = utils::extract_type_from(&field.ty); 868 quote! { 869 &<Self::RuntimeServiceId as ::overwatch::services::AsServiceId<#type_id>>::SERVICE_ID => { 870 self.#field_identifier.service_handle().status_watcher().clone() 871 } 872 } 873 }); 874 875 quote! { 876 #instrumentation 877 fn request_status_watcher(&self, service_id: &Self::RuntimeServiceId) -> ::overwatch::services::status::StatusWatcher { 878 match service_id { 879 #( #cases )* 880 } 881 } 882 } 883 } 884 885 /// Generates the `update_settings` method implementation for the `Services` 886 /// trait. 887 /// 888 /// This function creates code to update the settings for all services. It 889 /// destructures the settings struct and passes each field's settings to the 890 /// corresponding service handle. 
891 /// 892 /// # Arguments 893 /// 894 /// * `fields` - The fields of the services struct 895 /// 896 /// # Returns 897 /// 898 /// A token stream containing the `update_settings` method implementation. 899 fn generate_update_settings_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 900 let instrumentation = get_default_instrumentation_without_settings(); 901 902 let fields_settings = fields.iter().map(|field| { 903 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 904 let settings_field_identifier = service_settings_field_identifier_from(field_identifier); 905 quote! { 906 #field_identifier: #settings_field_identifier 907 } 908 }); 909 910 let update_settings_call = fields.iter().map(|field| { 911 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 912 let settings_field_identifier = service_settings_field_identifier_from(field_identifier); 913 quote! { 914 self.#field_identifier.service_handle().update_settings(#settings_field_identifier); 915 } 916 }); 917 918 quote! { 919 #instrumentation 920 fn update_settings(&mut self, settings: Self::Settings) { 921 let Self::Settings { 922 #( #fields_settings ),* 923 } = settings; 924 925 #( #update_settings_call )* 926 } 927 } 928 } 929 930 /// Generates the `get_service_lifecycle_notifier` method implementation for the 931 /// `Services` trait. 932 /// 933 /// This function creates code to retrieve the lifecycle handle for a specific 934 /// service identified by its `RuntimeServiceId`. The lifecycle handle can be 935 /// used to manage the service's lifecycle events. 936 /// 937 /// # Arguments 938 /// 939 /// * `fields` - The fields of the services struct 940 /// 941 /// # Returns 942 /// 943 /// A token stream containing the `get_service_lifecycle_notifier` method 944 /// implementation. 
945 fn generate_get_service_lifecycle_notifier_impl( 946 fields: &Punctuated<Field, Comma>, 947 ) -> proc_macro2::TokenStream { 948 let instrumentation = get_default_instrumentation(); 949 950 let cases = fields.iter().map(|field| { 951 let field_identifier = field.ident.as_ref().expect("A struct attribute identifier"); 952 let type_id = utils::extract_type_from(&field.ty); 953 quote! { 954 &<Self::RuntimeServiceId as ::overwatch::services::AsServiceId<#type_id>>::SERVICE_ID => { 955 self.#field_identifier.service_handle().lifecycle_notifier() 956 } 957 } 958 }); 959 960 quote! { 961 #instrumentation 962 fn get_service_lifecycle_notifier(&self, service_id: &Self::RuntimeServiceId) -> &::overwatch::services::lifecycle::LifecycleNotifier { 963 match service_id { 964 #( #cases ),* 965 } 966 } 967 } 968 } 969 970 /// Generates the runtime service type definitions. 971 /// 972 /// This function creates the `RuntimeServiceId` enum, service ID trait 973 /// implementations, and `AsServiceId` trait implementations for each service 974 /// type that is part of the specified runtime. 975 /// 976 /// # Arguments 977 /// 978 /// * `fields` - The fields of the services struct, indicating the different 979 /// services that are part of the runtime. 980 /// 981 /// # Returns 982 /// 983 /// A token stream containing all runtime service type definitions. 984 fn generate_runtime_service_types(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 985 let runtime_service_id = generate_runtime_service_id(fields); 986 let service_id_trait_impls = generate_service_id_trait_impls(fields); 987 let as_service_id_impl = generate_as_service_id_impl(fields); 988 989 quote! { 990 #runtime_service_id 991 992 #service_id_trait_impls 993 994 #as_service_id_impl 995 } 996 } 997 998 /// Generates a runtime service ID enum from the fields of a service container 999 /// struct. 
1000 /// 1001 /// This function creates an enum named `RuntimeServiceId` where each variant 1002 /// corresponds to a service defined in the service container struct. The enum 1003 /// is automatically derived with useful traits including `Debug`, `Clone`, 1004 /// `Copy`, `PartialEq` and `Eq`. 1005 /// 1006 /// The service names from the struct fields are converted to `PascalCase` for 1007 /// the enum variants. 1008 /// 1009 /// # Arguments 1010 /// 1011 /// * `fields` - A punctuated list of fields from the service container struct 1012 /// 1013 /// # Returns 1014 /// 1015 /// A `TokenStream` containing the definition of the `RuntimeServiceId` enum 1016 /// 1017 /// # Example 1018 /// 1019 /// For a service container struct like: 1020 /// 1021 /// ```rust,ignore 1022 /// struct MyServices { 1023 /// database: OpaqueServiceHandle<DatabaseService>, 1024 /// api_gateway: OpaqueServiceHandle<ApiGatewayService>, 1025 /// user_cache: OpaqueServiceHandle<CacheService<User>>, 1026 /// } 1027 /// ``` 1028 /// 1029 /// This function will generate: 1030 /// 1031 /// ```rust 1032 /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] 1033 /// pub enum RuntimeServiceId { 1034 /// Database, 1035 /// ApiGateway, 1036 /// UserCache, 1037 /// } 1038 /// ``` 1039 /// 1040 /// The generated enum serves as a unique identifier for each service in the 1041 /// application, enabling service lookup, lifecycle management, and message 1042 /// routing throughout the Overwatch framework. 1043 fn generate_runtime_service_id(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 1044 let services_names = fields 1045 .iter() 1046 .clone() 1047 .map(|field| (&field.ident, &field.attrs)); 1048 let enum_variants = services_names.map(|(service_name, service_attrs)| { 1049 let capitalized_service_name = format_ident!( 1050 "{}", 1051 utils::field_name_to_type_name( 1052 &service_name 1053 .clone() 1054 .expect("Expected struct named fields.") 1055 .to_string() 1056 ) 1057 ); 1058 1059 quote! 
{ #(#service_attrs),* #capitalized_service_name } 1060 }); 1061 let runtime_service_id_type_name = get_runtime_service_id_type_name(); 1062 let expanded = quote! { 1063 #[derive(::core::fmt::Debug, ::core::clone::Clone, ::core::marker::Copy, ::core::cmp::PartialEq, ::core::cmp::Eq)] 1064 pub enum #runtime_service_id_type_name { 1065 #(#enum_variants),* 1066 } 1067 }; 1068 1069 quote! { 1070 #expanded 1071 } 1072 } 1073 1074 /// Generates different trait implementations, e.g. `Display`, for 1075 /// `RuntimeServiceId`. 1076 /// 1077 /// # Returns 1078 /// 1079 /// A token stream containing the Display trait implementation 1080 fn generate_service_id_trait_impls(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 1081 let runtime_service_id_type_name = get_runtime_service_id_type_name(); 1082 1083 let runtime_service_id_from_str_impl = generate_runtime_service_id_from_str_impl(fields); 1084 1085 quote! { 1086 impl ::core::fmt::Display for #runtime_service_id_type_name { 1087 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { 1088 <Self as ::core::fmt::Debug>::fmt(self, f) 1089 } 1090 } 1091 1092 #runtime_service_id_from_str_impl 1093 } 1094 } 1095 1096 /// Generates the `RuntimeServiceId` enum from a string representation. 
1097 /// 1098 /// # Returns 1099 /// 1100 /// A token stream containing the implementation of the `From<Into<String>>` 1101 /// trait 1102 fn generate_runtime_service_id_from_str_impl( 1103 fields: &Punctuated<Field, Comma>, 1104 ) -> proc_macro2::TokenStream { 1105 let runtime_service_id_type_name = get_runtime_service_id_type_name(); 1106 1107 let available_services = fields 1108 .iter() 1109 .map(|field| { 1110 let field_identifier = field 1111 .ident 1112 .as_ref() 1113 .expect("Expected struct named fields.") 1114 .to_string(); 1115 utils::field_name_to_type_name(&field_identifier) 1116 }) 1117 .collect::<Vec<_>>() 1118 .join(", "); 1119 1120 let string_to_variant_pairs = fields.iter().map(|field| { 1121 let field_ident = field.ident.as_ref().expect("Expected struct named fields."); 1122 let type_name_capitalized = utils::field_name_to_type_name(&field_ident.to_string()); 1123 let type_identifier_capitalized = format_ident!("{}", type_name_capitalized); 1124 let runtime_service_id_variant = 1125 quote! { #runtime_service_id_type_name::#type_identifier_capitalized }; 1126 (type_name_capitalized, runtime_service_id_variant) 1127 }); 1128 1129 let arms = string_to_variant_pairs.map(|(name, variant)| { 1130 quote! { 1131 #name => { Ok(#variant) } 1132 } 1133 }); 1134 1135 quote! { 1136 impl ::std::str::FromStr for #runtime_service_id_type_name { 1137 type Err = ::overwatch::overwatch::Error; 1138 1139 fn from_str(value: &str) -> ::core::result::Result<Self, Self::Err> { 1140 match value.as_ref() { 1141 #( #arms ),* 1142 _ => { 1143 let error_string = format!( 1144 "Couldn't find a service with the name: {value}. Available services are: {}.", 1145 #available_services 1146 ); 1147 let error = ::overwatch::overwatch::Error::Any(::overwatch::DynError::from(error_string)); 1148 Err(error) 1149 } 1150 } 1151 } 1152 } 1153 } 1154 } 1155 1156 /// Generates implementations of the `AsServiceId` trait for service types. 
1157 /// 1158 /// This function creates trait implementations that map service types to their 1159 /// corresponding service ID enum variants. It examines the fields of a service 1160 /// container struct and automatically generates the necessary trait 1161 /// implementations to connect each service with its identifier in the runtime 1162 /// service ID enum. 1163 /// 1164 /// This is an internal function used by the `derive_services` macro to generate 1165 /// the necessary trait implementations for service identification. 1166 /// 1167 /// # Arguments 1168 /// 1169 /// * `fields` - A punctuated list of fields from the service container struct 1170 /// 1171 /// # Returns 1172 /// 1173 /// A `TokenStream` containing all the `AsServiceId` trait implementations for 1174 /// the service types 1175 /// 1176 /// # Example 1177 /// 1178 /// Assuming we have the following service container struct: 1179 /// 1180 /// ```rust,ignore 1181 /// use overwatch::OpaqueServiceHandle; 1182 /// 1183 /// struct MyServices { 1184 /// database: OpaqueServiceHandle<DatabaseService>, 1185 /// api: OpaqueServiceHandle<ApiService>, 1186 /// } 1187 /// ``` 1188 /// 1189 /// The function will generate code similar to: 1190 /// 1191 /// ```rust,ignore 1192 /// use overwatch::services::AsServiceId; 1193 /// 1194 /// impl AsServiceId<DatabaseService> for RuntimeServiceId { 1195 /// const SERVICE_ID: Self = RuntimeServiceId::Database; 1196 /// } 1197 /// 1198 /// impl AsServiceId<ApiService> for RuntimeServiceId { 1199 /// const SERVICE_ID: Self = RuntimeServiceId::Api; 1200 /// } 1201 /// ``` 1202 /// 1203 /// For services with generic parameters: 1204 /// 1205 /// ```rust,ignore 1206 /// use overwatch::OpaqueServiceHandle; 1207 /// 1208 /// struct MyServices { 1209 /// cache: OpaqueServiceHandle<CacheService<String, u64>>, 1210 /// } 1211 /// ``` 1212 /// 1213 /// It will generate: 1214 /// 1215 /// ```rust,ignore 1216 /// use overwatch::services::AsServiceId; 1217 /// 1218 /// impl 
AsServiceId<CacheService<String, u64>> for RuntimeServiceId { 1219 /// const SERVICE_ID: Self = RuntimeServiceId::Cache; 1220 /// } 1221 /// ``` 1222 /// 1223 /// This enables the runtime system to map between service types and their 1224 /// corresponding identifiers, which is essential for service lifecycle 1225 /// management and message routing. 1226 fn generate_as_service_id_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::TokenStream { 1227 let impl_blocks = fields.iter().filter_map(|field| { 1228 let field_type = &field.ty; 1229 let field_attrs = &field.attrs; 1230 let capitalized_service_name = format_ident!( 1231 "{}", 1232 utils::field_name_to_type_name( 1233 &field 1234 .ident 1235 .clone() 1236 .expect("Expected struct named fields.") 1237 .to_string() 1238 ) 1239 ); 1240 1241 let Type::Path(path) = &field_type else { 1242 return None; 1243 }; 1244 let path_segment = path.path.segments.last()?; 1245 1246 // Extract the inner type inside OpaqueServiceHandle<T> 1247 let PathArguments::AngleBracketed(args) = &path_segment.arguments else { 1248 return None; 1249 }; 1250 1251 let Some(GenericArgument::Type(inner_type)) = &args.args.first() else { 1252 return None; 1253 }; 1254 1255 let Type::Path(inner_path) = inner_type else { 1256 return None; 1257 }; 1258 1259 let inner_ident = &inner_path.path.segments.last().expect("Expected at least one segment in the inner type path").ident; 1260 let runtime_service_id_type_name = get_runtime_service_id_type_name(); 1261 1262 inner_path.path.segments.last().map_or_else( 1263 || None, 1264 |segment| match &segment.arguments { 1265 PathArguments::AngleBracketed(generic_args) => { 1266 let struct_generics: Vec<_> = generic_args.args.iter() 1267 .filter_map(|arg| match arg { 1268 GenericArgument::Type(Type::Path(type_path)) => Some(type_path.clone()), 1269 _ => None, 1270 }) 1271 .collect(); 1272 1273 Some(quote! 
{ 1274 #(#field_attrs),* 1275 impl ::overwatch::services::AsServiceId<#inner_ident<#(#struct_generics),*>> for #runtime_service_id_type_name { 1276 const SERVICE_ID: Self = #runtime_service_id_type_name::#capitalized_service_name; 1277 } 1278 }) 1279 }, 1280 // No generics case 1281 _ => Some(quote! { 1282 #(#field_attrs),* 1283 impl ::overwatch::services::AsServiceId<#inner_ident> for #runtime_service_id_type_name { 1284 const SERVICE_ID: Self = #runtime_service_id_type_name::#capitalized_service_name; 1285 } 1286 }), 1287 } 1288 ) 1289 }); 1290 1291 quote! { 1292 #(#impl_blocks)* 1293 } 1294 } 1295 1296 fn create_finished_signal_channels_from_amount(amount: usize) -> proc_macro2::TokenStream { 1297 quote! { 1298 let channels = (0..#amount).map(|_| { ::overwatch::utils::finished_signal::channel() }); 1299 let (mut senders, receivers): (Vec<_>, Vec<_>) = channels.into_iter().unzip(); 1300 } 1301 } 1302 1303 fn create_finished_signal_channels_from_variable(variable: &Ident) -> proc_macro2::TokenStream { 1304 quote! { 1305 let channels = (0..#variable).map(|_| { ::overwatch::utils::finished_signal::channel() }); 1306 let (mut senders, receivers): (Vec<_>, Vec<_>) = channels.into_iter().unzip(); 1307 } 1308 } 1309 1310 fn await_finished_signal_receivers() -> proc_macro2::TokenStream { 1311 quote! { 1312 for mut receiver in receivers { 1313 receiver.await.map_err(|error| { 1314 let dyn_error: ::overwatch::DynError = Box::new(error); 1315 ::overwatch::overwatch::Error::from(dyn_error) 1316 })?; 1317 } 1318 } 1319 } 1320 1321 fn send_lifecycle_message_over_senders( 1322 field: &Ident, 1323 lifecycle_variant: &str, 1324 ) -> proc_macro2::TokenStream { 1325 let lifecycle_variant = format_ident!("{}", lifecycle_variant); 1326 quote! 
{ 1327 self.#field.service_handle().lifecycle_notifier().send( 1328 ::overwatch::services::lifecycle::LifecycleMessage::#lifecycle_variant(senders.remove(0)) 1329 ).await?; 1330 } 1331 } 1332 1333 fn send_start_lifecycle_message_over_senders(field: &Ident) -> proc_macro2::TokenStream { 1334 send_lifecycle_message_over_senders(field, "Start") 1335 } 1336 1337 fn send_stop_lifecycle_message_over_senders(field: &Ident) -> proc_macro2::TokenStream { 1338 send_lifecycle_message_over_senders(field, "Stop") 1339 }