/ tools / website_policy.py
website_policy.py
  1  """Website access policy helpers for URL-capable tools.
  2  
  3  This module loads a user-managed website blocklist from ~/.hermes/config.yaml
  4  and optional shared list files. It is intentionally lightweight so web/browser
  5  tools can enforce URL policy without pulling in the heavier CLI config stack.
  6  
  7  Policy is cached in memory with a short TTL so config changes take effect
  8  quickly without re-reading the file on every URL check.
  9  """
 10  
 11  from __future__ import annotations
 12  
 13  import fnmatch
 14  import logging
 15  import threading
 16  import time
 17  from pathlib import Path
 18  from typing import Any, Dict, List, Optional, Tuple
 19  from urllib.parse import urlparse
 20  
 21  from hermes_constants import get_hermes_home
 22  
 23  logger = logging.getLogger(__name__)
 24  
 25  _DEFAULT_WEBSITE_BLOCKLIST = {
 26      "enabled": False,
 27      "domains": [],
 28      "shared_files": [],
 29  }
 30  
 31  # Cache: parsed policy + timestamp.  Avoids re-reading config.yaml on every
 32  # URL check (a web_crawl with 50 pages would otherwise mean 51 YAML parses).
 33  _CACHE_TTL_SECONDS = 30.0
 34  _cache_lock = threading.Lock()
 35  _cached_policy: Optional[Dict[str, Any]] = None
 36  _cached_policy_path: Optional[str] = None
 37  _cached_policy_time: float = 0.0
 38  
 39  
 40  def _get_default_config_path() -> Path:
 41      return get_hermes_home() / "config.yaml"
 42  
 43  
 44  class WebsitePolicyError(Exception):
 45      """Raised when a website policy file is malformed."""
 46  
 47  
 48  def _normalize_host(host: str) -> str:
 49      return (host or "").strip().lower().rstrip(".")
 50  
 51  
 52  def _normalize_rule(rule: Any) -> Optional[str]:
 53      if not isinstance(rule, str):
 54          return None
 55      value = rule.strip().lower()
 56      if not value or value.startswith("#"):
 57          return None
 58      if "://" in value:
 59          parsed = urlparse(value)
 60          value = parsed.netloc or parsed.path
 61      value = value.split("/", 1)[0].strip().rstrip(".")
 62      if value.startswith("www."):
 63          value = value[4:]
 64      return value or None
 65  
 66  
 67  def _iter_blocklist_file_rules(path: Path) -> List[str]:
 68      """Load rules from a shared blocklist file.
 69  
 70      Missing or unreadable files log a warning and return an empty list
 71      rather than raising — a bad file path should not disable all web tools.
 72      """
 73      try:
 74          raw = path.read_text(encoding="utf-8")
 75      except FileNotFoundError:
 76          logger.warning("Shared blocklist file not found (skipping): %s", path)
 77          return []
 78      except (OSError, UnicodeDecodeError) as exc:
 79          logger.warning("Failed to read shared blocklist file %s (skipping): %s", path, exc)
 80          return []
 81  
 82      rules: List[str] = []
 83      for line in raw.splitlines():
 84          stripped = line.strip()
 85          if not stripped or stripped.startswith("#"):
 86              continue
 87          normalized = _normalize_rule(stripped)
 88          if normalized:
 89              rules.append(normalized)
 90      return rules
 91  
 92  
 93  def _load_policy_config(config_path: Optional[Path] = None) -> Dict[str, Any]:
 94      config_path = config_path or _get_default_config_path()
 95      if not config_path.exists():
 96          return dict(_DEFAULT_WEBSITE_BLOCKLIST)
 97  
 98      try:
 99          import yaml
100      except ImportError:
101          logger.debug("PyYAML not installed — website blocklist disabled")
102          return dict(_DEFAULT_WEBSITE_BLOCKLIST)
103  
104      try:
105          with open(config_path, encoding="utf-8") as f:
106              config = yaml.safe_load(f) or {}
107      except yaml.YAMLError as exc:
108          raise WebsitePolicyError(f"Invalid config YAML at {config_path}: {exc}") from exc
109      except OSError as exc:
110          raise WebsitePolicyError(f"Failed to read config file {config_path}: {exc}") from exc
111      if not isinstance(config, dict):
112          raise WebsitePolicyError("config root must be a mapping")
113  
114      security = config.get("security", {})
115      if security is None:
116          security = {}
117      if not isinstance(security, dict):
118          raise WebsitePolicyError("security must be a mapping")
119  
120      website_blocklist = security.get("website_blocklist", {})
121      if website_blocklist is None:
122          website_blocklist = {}
123      if not isinstance(website_blocklist, dict):
124          raise WebsitePolicyError("security.website_blocklist must be a mapping")
125  
126      policy = dict(_DEFAULT_WEBSITE_BLOCKLIST)
127      policy.update(website_blocklist)
128      return policy
129  
130  
def load_website_blocklist(config_path: Optional[Path] = None) -> Dict[str, Any]:
    """Build the effective blocklist policy from config and shared files.

    Args:
        config_path: Config file to read.  ``None`` selects the default
            location and allows a cached result younger than
            ``_CACHE_TTL_SECONDS`` to be returned without touching disk.
            An explicit path always skips the cache lookup (used by tests).

    Returns:
        ``{"enabled": bool, "rules": [...]}`` where each rule is a dict with
        a normalized ``pattern`` and the ``source`` it came from
        (``"config"`` or the shared file's path).

    Raises:
        WebsitePolicyError: If ``domains`` or ``shared_files`` is not a list
            or ``enabled`` is not a boolean.  Errors raised while loading
            the config propagate unchanged.
    """
    global _cached_policy, _cached_policy_path, _cached_policy_time

    cache_key = "__default__" if not config_path else str(config_path)
    started = time.monotonic()

    # Only implicit (default-path) lookups may be served from the cache.
    if config_path is None:
        with _cache_lock:
            is_fresh = (
                _cached_policy is not None
                and _cached_policy_path == cache_key
                and (started - _cached_policy_time) < _CACHE_TTL_SECONDS
            )
            if is_fresh:
                return _cached_policy

    effective_path = config_path or _get_default_config_path()
    policy = _load_policy_config(effective_path)

    domain_entries = policy.get("domains", []) or []
    if not isinstance(domain_entries, list):
        raise WebsitePolicyError("security.website_blocklist.domains must be a list")

    shared_file_entries = policy.get("shared_files", []) or []
    if not isinstance(shared_file_entries, list):
        raise WebsitePolicyError("security.website_blocklist.shared_files must be a list")

    enabled = policy.get("enabled", True)
    if not isinstance(enabled, bool):
        raise WebsitePolicyError("security.website_blocklist.enabled must be a boolean")

    rules: List[Dict[str, str]] = []
    seen = set()

    def _add_rule(pattern: str, source: str) -> None:
        # De-duplicate per source: the same pattern may appear once for
        # each source that lists it, but repeats within a source are dropped.
        key = (source, pattern)
        if key in seen:
            return
        seen.add(key)
        rules.append({"pattern": pattern, "source": source})

    for entry in domain_entries:
        pattern = _normalize_rule(entry)
        if pattern:
            _add_rule(pattern, "config")

    for entry in shared_file_entries:
        if not (isinstance(entry, str) and entry.strip()):
            continue
        file_path = Path(entry).expanduser()
        # Relative paths are anchored at the directory from get_hermes_home().
        if not file_path.is_absolute():
            file_path = (get_hermes_home() / file_path).resolve()
        source = str(file_path)
        for pattern in _iter_blocklist_file_rules(file_path):
            _add_rule(pattern, source)

    result = {"enabled": enabled, "rules": rules}

    # Remember the result only when it was loaded from the default location.
    if effective_path == _get_default_config_path():
        with _cache_lock:
            _cached_policy_time = started
            _cached_policy_path = "__default__"
            _cached_policy = result

    return result
200  
201  
def invalidate_cache() -> None:
    """Discard the cached policy so the next lookup reloads it from config.

    Only the cached policy object is cleared (under ``_cache_lock``); a
    ``None`` policy is never treated as a cache hit.
    """
    global _cached_policy
    with _cache_lock:
        _cached_policy = None
207  
208  
209  def _match_host_against_rule(host: str, pattern: str) -> bool:
210      if not host or not pattern:
211          return False
212      if pattern.startswith("*."):
213          return fnmatch.fnmatch(host, pattern)
214      return host == pattern or host.endswith(f".{pattern}")
215  
216  
def _extract_host_from_urlish(url: str) -> str:
    """Extract the normalized host from a URL or URL-like string.

    The input is parsed as given first.  If that yields no host and the
    input has no ``://``, it is re-parsed with a ``//`` prefix so bare forms
    such as ``example.com/path`` or ``localhost:3000`` still resolve.

    Returns:
        The host as normalized by ``_normalize_host``, or ``""`` when no
        host can be determined — including when ``urlparse`` rejects the
        input (e.g. unbalanced IPv6 brackets) instead of raising.
    """
    candidates = [url]
    if "://" not in url:
        candidates.append(f"//{url}")

    for candidate in candidates:
        try:
            parsed = urlparse(candidate)
            raw_host = parsed.hostname or parsed.netloc
        except ValueError:
            # Malformed input: treat as "no host" rather than crashing
            # the caller's access check.
            continue
        host = _normalize_host(raw_host)
        if host:
            return host

    return ""
230  
231  
def check_website_access(url: str, config_path: Optional[Path] = None) -> Optional[Dict[str, str]]:
    """Check whether a URL is allowed by the website blocklist policy.

    Args:
        url: URL (or URL-like string) about to be accessed.
        config_path: Explicit config file.  ``None`` uses the default config
            and the in-memory cache; an explicit path (tests) makes
            ``WebsitePolicyError`` propagate instead of failing open.

    Returns:
        ``None`` if access is allowed, otherwise a dict with block metadata:
        ``url``, ``host``, ``rule``, ``source`` and ``message``.

    Raises:
        WebsitePolicyError: Only when ``config_path`` is given and the
            policy is malformed.  With the default config, policy errors are
            logged and the check fails open so a config typo does not break
            every web tool.
    """
    # Fast path: a still-fresh cached policy that is disabled or has no
    # rules cannot block anything, so skip host extraction and loading.
    # The freshness test matters: without it a cached "disabled" policy
    # would short-circuit forever and config changes would never be seen.
    if config_path is None:
        with _cache_lock:
            cache_is_fresh = (
                _cached_policy is not None
                and (time.monotonic() - _cached_policy_time) < _CACHE_TTL_SECONDS
            )
            if cache_is_fresh and not (
                _cached_policy.get("enabled") and _cached_policy.get("rules")
            ):
                return None

    try:
        host = _extract_host_from_urlish(url)
    except ValueError as exc:
        # urlparse rejects some malformed URLs; a URL without a usable
        # host is treated like any other host-less input (allowed).
        logger.warning("Could not parse URL for website policy check (failing open): %s", exc)
        return None
    if not host:
        return None

    try:
        policy = load_website_blocklist(config_path)
    except WebsitePolicyError as exc:
        if config_path is not None:
            raise  # Tests pass explicit paths — let errors propagate
        logger.warning("Website policy config error (failing open): %s", exc)
        return None
    except Exception as exc:
        logger.warning("Unexpected error loading website policy (failing open): %s", exc)
        return None

    if not policy.get("enabled"):
        return None

    for rule in policy.get("rules", []):
        pattern = rule.get("pattern", "")
        if not _match_host_against_rule(host, pattern):
            continue
        source = rule.get("source", "config")
        logger.info("Blocked URL %s — matched rule '%s' from %s", url, pattern, source)
        return {
            "url": url,
            "host": host,
            "rule": pattern,
            "source": source,
            "message": (
                f"Blocked by website policy: '{host}' matched rule '{pattern}'"
                f" from {source}"
            ),
        }
    return None