/ src / config.rs
config.rs
  1  use std::path::PathBuf;
  2  
  3  use chrono::FixedOffset;
  4  use config::{File, FileFormat};
  5  use matrix_sdk::ruma::{OwnedRoomId, OwnedUserId};
  6  use regex::Regex;
  7  use serde::{Deserialize, Deserializer};
  8  
  9  use crate::{
 10      aoc::models::AocId,
 11      utils::{self, regex_set_replacer::RegexSetReplacer},
 12  };
 13  
 14  pub fn load<'a>(config_path: impl Iterator<Item = &'a str>) -> anyhow::Result<Config> {
 15      load_with_defaults(std::iter::empty(), config_path)
 16  }
 17  
 18  fn load_with_defaults<'a>(
 19      defaults: impl Iterator<Item = &'a str>,
 20      config_path: impl Iterator<Item = &'a str>,
 21  ) -> anyhow::Result<Config> {
 22      let mut builder = config::Config::builder();
 23  
 24      for content in defaults.chain([include_str!("../config.toml")]) {
 25          let source = File::from_str(content, FileFormat::Toml);
 26          builder = builder.add_source(source);
 27      }
 28  
 29      for path in config_path {
 30          builder = builder.add_source(File::with_name(path.trim()));
 31      }
 32  
 33      builder.build()?.try_deserialize().map_err(Into::into)
 34  }
 35  
 36  #[derive(Debug, Deserialize)]
 37  #[serde(deny_unknown_fields)]
 38  pub struct Config {
 39      #[serde(with = "utils::serde::via_string")]
 40      pub local_timezone: FixedOffset,
 41      pub matrix: MatrixConfig,
 42      pub aoc: AocConfig,
 43      pub garygrady: GarygradyConfig,
 44      pub users: Vec<User>,
 45  }
 46  
 47  #[derive(Debug, Deserialize)]
 48  #[serde(deny_unknown_fields)]
 49  pub struct MatrixConfig {
 50      pub homeserver: String,
 51      pub store_path: PathBuf,
 52      pub admin_ids: Vec<OwnedUserId>,
 53      pub room_id: OwnedRoomId,
 54      pub command_prefix: String,
 55      pub link_prefix: String,
 56  }
 57  
 58  #[derive(Debug, Deserialize)]
 59  #[serde(deny_unknown_fields)]
 60  pub struct AocConfig {
 61      pub session_file: PathBuf,
 62      pub leaderboard_rows: usize,
 63      pub default_cache_ttl: u64,
 64      pub cache_ttl_rules: Vec<CacheTtlRule>,
 65      #[serde(deserialize_with = "deserialize_repo_rules")]
 66      pub repo_rules: RegexSetReplacer,
 67  }
 68  
 69  #[derive(Debug, Deserialize)]
 70  #[serde(deny_unknown_fields)]
 71  pub struct CacheTtlRule {
 72      pub minutes_after_unlock: i64,
 73      pub ttl: u64,
 74  }
 75  
 76  #[derive(Debug, Deserialize)]
 77  #[serde(deny_unknown_fields)]
 78  pub struct GarygradyConfig {
 79      pub interval: u64,
 80      pub max_age: u64,
 81  }
 82  
 83  #[derive(Debug, Clone, Deserialize)]
 84  #[serde(deny_unknown_fields)]
 85  pub struct User {
 86      pub aoc: Option<AocId>,
 87      pub matrix: Option<OwnedUserId>,
 88      pub repo: Option<String>,
 89  }
 90  
 91  fn deserialize_repo_rules<'de, D>(deserializer: D) -> Result<RegexSetReplacer, D::Error>
 92  where
 93      D: Deserializer<'de>,
 94  {
 95      #[derive(Deserialize)]
 96      struct RepoRule {
 97          #[serde(deserialize_with = "utils::serde::deserialize_regex")]
 98          regex: Regex,
 99          title: Option<String>,
100      }
101  
102      let rules = Vec::<RepoRule>::deserialize(deserializer)?
103          .into_iter()
104          .map(|rule| (rule.regex, rule.title.unwrap_or_else(|| "$0".into())))
105          .collect();
106  
107      Ok(RegexSetReplacer::new(rules))
108  }
109  
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the embedded `config.toml`, combined with the inline
    /// values below and the repository's `users.toml`, must deserialize into
    /// a [`Config`] without error.
    #[test]
    fn load() {
        // Inline TOML sources; presumably these are the keys the embedded
        // config.toml leaves unset — confirm against that file.
        let defaults = [
            "matrix.homeserver = \"https://matrix.example.com\"",
            "matrix.store_path = \".store\"",
            "matrix.admin_ids = []",
            "matrix.room_id = \"!xoXcjSEJPUfQmzETtS:matrix.example.com\"",
            "aoc.session_file = \".session\"",
        ];
        // Path is anchored at the crate root so the test does not depend on
        // the current working directory.
        let users_file = concat!(env!("CARGO_MANIFEST_DIR"), "/users.toml");

        let outcome = load_with_defaults(defaults.into_iter(), std::iter::once(users_file));
        outcome.unwrap();
    }
}
129  }