/ timed / src / lib.rs
lib.rs
  1  // Copyright (C) 2019-2021 Alpha-Delta Network Inc.
  2  // This file is part of the alphastd library.
  3  
  4  // The alphastd library is free software: you can redistribute it and/or modify
  5  // it under the terms of the GNU General Public License as published by
  6  // the Free Software Foundation, either version 3 of the License, or
  7  // (at your option) any later version.
  8  
  9  // The alphastd library is distributed in the hope that it will be useful,
 10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
 11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 12  // GNU General Public License for more details.
 13  
 14  // You should have received a copy of the GNU General Public License
 15  // along with the alphastd library. If not, see <https://www.gnu.org/licenses/>.
 16  
 17  // With credits to kardeiz/funtime.
 18  
 19  extern crate proc_macro;
 20  
 21  use proc_macro::TokenStream;
 22  use quote::quote;
 23  use syn::*;
 24  
 25  #[proc_macro_attribute]
 26  pub fn timed(_attrs: TokenStream, item: TokenStream) -> TokenStream {
 27      if let Ok(mut fun) = parse::<ItemFn>(item.clone()) {
 28          let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut fun.block.stmts);
 29          fun.block.stmts = new_stmts;
 30          return quote!(#fun).into();
 31      }
 32  
 33      if let Ok(mut fun) = parse::<TraitItemMethod>(item.clone()) {
 34          if let Some(block) = fun.default.as_mut() {
 35              let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut block.stmts);
 36              block.stmts = new_stmts;
 37              return quote!(#fun).into();
 38          }
 39      }
 40  
 41      if let Ok(mut fun) = parse::<ImplItemMethod>(item) {
 42          let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut fun.block.stmts);
 43          fun.block.stmts = new_stmts;
 44          return quote!(#fun).into();
 45      }
 46  
 47      panic!("`timed` only works on functions")
 48  }
 49  
 50  #[cfg(feature = "timed")]
 51  fn rewrite_stmts(name: String, stmts: &mut Vec<Stmt>) -> Vec<Stmt> {
 52      /// Truncates the given statement to the specified number of characters.
 53      fn truncate(stmt: &Stmt, len: usize) -> String {
 54          // Convert the statement to a string.
 55          let string = quote::ToTokens::to_token_stream(stmt).to_string().replace("\n", " ");
 56          // If the statement is too long, truncate it.
 57          match string.chars().count() > len {
 58              // Truncate the statement and append "..." to the end.
 59              true => string.chars().take(len).chain("...".chars()).collect::<String>(),
 60              // Otherwise, return the statement as-is.
 61              false => string,
 62          }
 63      }
 64  
 65      let setup: Block = parse_quote! {{
 66          struct Timed {
 67              start: std::time::Instant,
 68              name: &'static str,
 69              buffer: String,
 70              prev_mark: Option<std::time::Duration>,
 71          }
 72  
 73          impl Timed {
 74              fn new(name: &'static str) -> Self {
 75                  use std::fmt::Write;
 76                  let mut buffer = String::new();
 77                  writeln!(&mut buffer, "Start: `{}`", name).unwrap();
 78                  Timed {
 79                      start: std::time::Instant::now(),
 80                      name,
 81                      buffer,
 82                      prev_mark: None,
 83                  }
 84              }
 85  
 86              fn mark_elapsed(&mut self, short: &str) {
 87                  use std::fmt::Write;
 88  
 89                  let mut elapsed = self.start.elapsed();
 90                  if let Some(prev) = self.prev_mark.replace(elapsed) {
 91                      elapsed -= prev;
 92                  }
 93  
 94                  let elapsed = {
 95                      let secs = elapsed.as_secs();
 96                      let millis = elapsed.subsec_millis();
 97                      let micros = elapsed.subsec_micros() % 1000;
 98                      let nanos = elapsed.subsec_nanos() % 1000;
 99                      if secs != 0 {
100                          format!("{}.{:0>3}s", secs, millis)
101                      } else if millis > 0 {
102                          format!("{}.{:0>3}ms", millis, micros)
103                      } else if micros > 0 {
104                          format!("{}.{:0>3}µs", micros, nanos)
105                      } else {
106                          format!("{}ns", elapsed.subsec_nanos())
107                      }
108                  };
109  
110                  writeln!(&mut self.buffer, "    {:<55} {:->25}", short, elapsed).unwrap();
111              }
112          }
113  
114          impl Drop for Timed {
115              fn drop(&mut self) {
116                  use std::fmt::Write;
117                  writeln!(&mut self.buffer, "End: `{}` took {:?}", self.name, self.start.elapsed()).unwrap();
118                  print!("{}", &self.buffer);
119              }
120          }
121  
122          let mut timed = Timed::new(#name);
123  
124      }};
125  
126      const LENGTH: usize = 45;
127  
128      let mut new_stmts = setup.stmts;
129  
130      let last = stmts.pop();
131  
132      for (index, stmt) in stmts.drain(..).enumerate() {
133          let short = truncate(&stmt, LENGTH);
134          let short = format!("L{index}: {short}");
135  
136          let next_stmt = parse_quote!(timed.mark_elapsed(#short););
137  
138          new_stmts.push(stmt);
139          new_stmts.push(next_stmt);
140      }
141  
142      if let Some(stmt) = last {
143          let short = truncate(&stmt, LENGTH);
144  
145          let new_stmt = parse_quote! {
146              let return_stmt = { #stmt };
147          };
148          let next_stmt = parse_quote!(timed.mark_elapsed(#short););
149          let return_stmt = parse_quote!(return return_stmt;);
150  
151          new_stmts.push(new_stmt);
152          new_stmts.push(next_stmt);
153          new_stmts.push(return_stmt);
154      }
155  
156      new_stmts
157  }
158  
159  #[cfg(not(feature = "timed"))]
160  fn rewrite_stmts(_name: String, stmts: &mut [Stmt]) -> Vec<Stmt> {
161      stmts.to_vec()
162  }