/ timed / src / lib.rs
lib.rs
  1  // Copyright (c) 2025 ADnet Contributors
  2  // This file is part of the alpha-std library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  // With credits to kardeiz/funtime.
 17  
 18  extern crate proc_macro;
 19  
 20  use proc_macro::TokenStream;
 21  use quote::quote;
 22  use syn::*;
 23  
 24  #[proc_macro_attribute]
 25  pub fn timed(_attrs: TokenStream, item: TokenStream) -> TokenStream {
 26      if let Ok(mut fun) = parse::<ItemFn>(item.clone()) {
 27          let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut fun.block.stmts);
 28          fun.block.stmts = new_stmts;
 29          return quote!(#fun).into();
 30      }
 31  
 32      if let Ok(mut fun) = parse::<TraitItemMethod>(item.clone()) {
 33          if let Some(block) = fun.default.as_mut() {
 34              let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut block.stmts);
 35              block.stmts = new_stmts;
 36              return quote!(#fun).into();
 37          }
 38      }
 39  
 40      if let Ok(mut fun) = parse::<ImplItemMethod>(item) {
 41          let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut fun.block.stmts);
 42          fun.block.stmts = new_stmts;
 43          return quote!(#fun).into();
 44      }
 45  
 46      panic!("`timed` only works on functions")
 47  }
 48  
 49  #[cfg(feature = "timed")]
 50  fn rewrite_stmts(name: String, stmts: &mut Vec<Stmt>) -> Vec<Stmt> {
 51      /// Truncates the given statement to the specified number of characters.
 52      fn truncate(stmt: &Stmt, len: usize) -> String {
 53          // Convert the statement to a string.
 54          let string = quote::ToTokens::to_token_stream(stmt)
 55              .to_string()
 56              .replace("\n", " ");
 57          // If the statement is too long, truncate it.
 58          match string.chars().count() > len {
 59              // Truncate the statement and append "..." to the end.
 60              true => string
 61                  .chars()
 62                  .take(len)
 63                  .chain("...".chars())
 64                  .collect::<String>(),
 65              // Otherwise, return the statement as-is.
 66              false => string,
 67          }
 68      }
 69  
 70      let setup: Block = parse_quote! {{
 71          struct Timed {
 72              start: std::time::Instant,
 73              name: &'static str,
 74              buffer: String,
 75              prev_mark: Option<std::time::Duration>,
 76          }
 77  
 78          impl Timed {
 79              fn new(name: &'static str) -> Self {
 80                  use std::fmt::Write;
 81                  let mut buffer = String::new();
 82                  writeln!(&mut buffer, "Start: `{}`", name).unwrap();
 83                  Timed {
 84                      start: std::time::Instant::now(),
 85                      name,
 86                      buffer,
 87                      prev_mark: None,
 88                  }
 89              }
 90  
 91              fn mark_elapsed(&mut self, short: &str) {
 92                  use std::fmt::Write;
 93  
 94                  let mut elapsed = self.start.elapsed();
 95                  if let Some(prev) = self.prev_mark.replace(elapsed) {
 96                      elapsed -= prev;
 97                  }
 98  
 99                  let elapsed = {
100                      let secs = elapsed.as_secs();
101                      let millis = elapsed.subsec_millis();
102                      let micros = elapsed.subsec_micros() % 1000;
103                      let nanos = elapsed.subsec_nanos() % 1000;
104                      if secs != 0 {
105                          format!("{}.{:0>3}s", secs, millis)
106                      } else if millis > 0 {
107                          format!("{}.{:0>3}ms", millis, micros)
108                      } else if micros > 0 {
109                          format!("{}.{:0>3}µs", micros, nanos)
110                      } else {
111                          format!("{}ns", elapsed.subsec_nanos())
112                      }
113                  };
114  
115                  writeln!(&mut self.buffer, "    {:<55} {:->25}", short, elapsed).unwrap();
116              }
117          }
118  
119          impl Drop for Timed {
120              fn drop(&mut self) {
121                  use std::fmt::Write;
122                  writeln!(&mut self.buffer, "End: `{}` took {:?}", self.name, self.start.elapsed()).unwrap();
123                  print!("{}", &self.buffer);
124              }
125          }
126  
127          let mut timed = Timed::new(#name);
128  
129      }};
130  
131      const LENGTH: usize = 45;
132  
133      let mut new_stmts = setup.stmts;
134  
135      let last = stmts.pop();
136  
137      for (index, stmt) in stmts.drain(..).enumerate() {
138          let short = truncate(&stmt, LENGTH);
139          let short = format!("L{index}: {short}");
140  
141          let next_stmt = parse_quote!(timed.mark_elapsed(#short););
142  
143          new_stmts.push(stmt);
144          new_stmts.push(next_stmt);
145      }
146  
147      if let Some(stmt) = last {
148          let short = truncate(&stmt, LENGTH);
149  
150          let new_stmt = parse_quote! {
151              let return_stmt = { #stmt };
152          };
153          let next_stmt = parse_quote!(timed.mark_elapsed(#short););
154          let return_stmt = parse_quote!(return return_stmt;);
155  
156          new_stmts.push(new_stmt);
157          new_stmts.push(next_stmt);
158          new_stmts.push(return_stmt);
159      }
160  
161      new_stmts
162  }
163  
164  #[cfg(not(feature = "timed"))]
165  fn rewrite_stmts(_name: String, stmts: &mut [Stmt]) -> Vec<Stmt> {
166      stmts.to_vec()
167  }