website_policy.py
"""Website access policy helpers for URL-capable tools.

This module loads a user-managed website blocklist from ~/.hermes/config.yaml
and optional shared list files. It is intentionally lightweight so web/browser
tools can enforce URL policy without pulling in the heavier CLI config stack.

Policy is cached in memory with a short TTL so config changes take effect
quickly without re-reading the file on every URL check.
"""

from __future__ import annotations

import fnmatch
import logging
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from hermes_constants import get_hermes_home

logger = logging.getLogger(__name__)

# Baseline policy. ``_load_policy_config`` copies this dict and overlays the
# ``security.website_blocklist`` mapping from config.yaml on top of it, so any
# key the user omits keeps the value given here (blocklist off, no rules).
_DEFAULT_WEBSITE_BLOCKLIST = {
    "enabled": False,
    "domains": [],
    "shared_files": [],
}

# Cache: parsed policy + timestamp. Avoids re-reading config.yaml on every
# URL check (a web_crawl with 50 pages would otherwise mean 51 YAML parses).
# Seconds a parsed policy stays valid before config.yaml is read again.
_CACHE_TTL_SECONDS = 30.0
# Guards every read and write of the three ``_cached_*`` globals below.
_cache_lock = threading.Lock()
_cached_policy: Optional[Dict[str, Any]] = None
_cached_policy_path: Optional[str] = None
# ``time.monotonic()`` reading taken when ``_cached_policy`` was stored.
_cached_policy_time: float = 0.0


def _get_default_config_path() -> Path:
    """Return the path of ``config.yaml`` inside the Hermes home directory."""
    return get_hermes_home() / "config.yaml"


class WebsitePolicyError(Exception):
    """Raised when config.yaml is unreadable or its blocklist section is malformed."""


def _normalize_host(host: str) -> str:
    """Lower-case *host*, trim whitespace and drop trailing dots.

    A falsy *host* (``None``, ``""``) yields ``""``.
    """
    return (host or "").strip().lower().rstrip(".")


def _normalize_rule(rule: Any) -> Optional[str]:
    """Reduce a blocklist entry to a bare, lower-case host pattern.

    Accepts plain hosts (``example.com``), wildcard hosts (``*.example.com``)
    and full URLs (``https://example.com:8443/path``).  Scheme, userinfo,
    port, path, query, trailing dots and a leading ``www.`` are removed so
    the result is comparable with the host produced by
    ``_extract_host_from_urlish`` (which never contains a port).

    Returns ``None`` for non-string values, blank entries, ``#`` comments and
    entries that cannot be parsed; a single bad entry must not abort loading
    of the rest of the policy.
    """
    if not isinstance(rule, str):
        return None
    value = rule.strip().lower()
    if not value or value.startswith("#"):
        return None

    try:
        if "://" in value:
            parsed = urlparse(value)
            value = parsed.netloc or parsed.path
        authority = value.split("/", 1)[0].strip()
        # Re-parse as a network location so userinfo and port are dropped.
        # ``hostname`` is None for inputs urlparse cannot split (for example
        # an unbracketed IPv6 literal); keep the raw text in that case.
        host = urlparse(f"//{authority}").hostname or authority
    except ValueError as exc:
        # urlparse raises ValueError on e.g. unbalanced IPv6 brackets.
        logger.warning("Ignoring malformed website blocklist rule %r: %s", rule, exc)
        return None

    host = host.rstrip(".")
    if host.startswith("www."):
        host = host[4:]
    return host or None


def _iter_blocklist_file_rules(path: Path) -> List[str]:
    """Load rules from a shared blocklist file.

    The file is read as UTF-8, one rule per line; blank lines and lines
    starting with ``#`` are skipped and every other line goes through
    ``_normalize_rule``.

    Missing or unreadable files log a warning and return an empty list
    rather than raising — a bad file path should not disable all web tools.
    """
    try:
        raw = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        logger.warning("Shared blocklist file not found (skipping): %s", path)
        return []
    except (OSError, UnicodeDecodeError) as exc:
        logger.warning("Failed to read shared blocklist file %s (skipping): %s", path, exc)
        return []

    rules: List[str] = []
    for line in raw.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        normalized = _normalize_rule(stripped)
        if normalized:
            rules.append(normalized)
    return rules


def _load_policy_config(config_path: Optional[Path] = None) -> Dict[str, Any]:
    """Read the raw ``security.website_blocklist`` mapping from config.yaml.

    Returns a copy of ``_DEFAULT_WEBSITE_BLOCKLIST`` overlaid with whatever
    the config file provides.  A missing config file or a missing PyYAML
    installation yields the unmodified defaults.

    Raises:
        WebsitePolicyError: the file cannot be read or parsed, or the config
            root, ``security`` or ``security.website_blocklist`` is not a
            mapping.
    """
    config_path = config_path or _get_default_config_path()
    if not config_path.exists():
        return dict(_DEFAULT_WEBSITE_BLOCKLIST)

    try:
        import yaml
    except ImportError:
        logger.debug("PyYAML not installed — website blocklist disabled")
        return dict(_DEFAULT_WEBSITE_BLOCKLIST)

    try:
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f) or {}
    except yaml.YAMLError as exc:
        raise WebsitePolicyError(f"Invalid config YAML at {config_path}: {exc}") from exc
    except OSError as exc:
        raise WebsitePolicyError(f"Failed to read config file {config_path}: {exc}") from exc
    if not isinstance(config, dict):
        raise WebsitePolicyError("config root must be a mapping")

    # An explicit ``security:`` / ``website_blocklist:`` key with no value
    # parses as None; treat that the same as the key being absent.
    security = config.get("security", {})
    if security is None:
        security = {}
    if not isinstance(security, dict):
        raise WebsitePolicyError("security must be a mapping")

    website_blocklist = security.get("website_blocklist", {})
    if website_blocklist is None:
        website_blocklist = {}
    if not isinstance(website_blocklist, dict):
        raise WebsitePolicyError("security.website_blocklist must be a mapping")

    policy = dict(_DEFAULT_WEBSITE_BLOCKLIST)
    policy.update(website_blocklist)
    return policy


def load_website_blocklist(config_path: Optional[Path] = None) -> Dict[str, Any]:
    """Load and return the parsed website blocklist policy.

    The result has the shape ``{"enabled": bool, "rules": [...]}`` where each
    rule is ``{"pattern": <host pattern>, "source": "config" | <file path>}``.
    Rules are de-duplicated per source.

    When ``config_path`` is ``None`` the result is served from an in-memory
    cache for up to ``_CACHE_TTL_SECONDS``.  An explicit ``config_path``
    (used by tests) never reads the cache; the result is stored in the cache
    only if that path equals the default config path.

    Raises:
        WebsitePolicyError: the config cannot be read or parsed, or
            ``domains`` / ``shared_files`` / ``enabled`` have the wrong type.
    """
    global _cached_policy, _cached_policy_path, _cached_policy_time

    now = time.monotonic()

    # Return cached policy if still fresh (default path only).
    if config_path is None:
        with _cache_lock:
            if (
                _cached_policy is not None
                and _cached_policy_path == "__default__"
                and (now - _cached_policy_time) < _CACHE_TTL_SECONDS
            ):
                return _cached_policy

    config_path = config_path or _get_default_config_path()
    policy = _load_policy_config(config_path)

    raw_domains = policy.get("domains", []) or []
    if not isinstance(raw_domains, list):
        raise WebsitePolicyError("security.website_blocklist.domains must be a list")

    raw_shared_files = policy.get("shared_files", []) or []
    if not isinstance(raw_shared_files, list):
        raise WebsitePolicyError("security.website_blocklist.shared_files must be a list")

    # Fall back to the documented default (disabled) rather than a literal,
    # so this stays consistent with _DEFAULT_WEBSITE_BLOCKLIST.
    enabled = policy.get("enabled", _DEFAULT_WEBSITE_BLOCKLIST["enabled"])
    if not isinstance(enabled, bool):
        raise WebsitePolicyError("security.website_blocklist.enabled must be a boolean")

    rules: List[Dict[str, str]] = []
    seen: set[Tuple[str, str]] = set()

    for raw_rule in raw_domains:
        normalized = _normalize_rule(raw_rule)
        if normalized and ("config", normalized) not in seen:
            rules.append({"pattern": normalized, "source": "config"})
            seen.add(("config", normalized))

    for shared_file in raw_shared_files:
        if not isinstance(shared_file, str) or not shared_file.strip():
            continue
        path = Path(shared_file).expanduser()
        if not path.is_absolute():
            # Relative entries are resolved against the Hermes home directory.
            path = (get_hermes_home() / path).resolve()
        for normalized in _iter_blocklist_file_rules(path):
            key = (str(path), normalized)
            if key in seen:
                continue
            rules.append({"pattern": normalized, "source": str(path)})
            seen.add(key)

    result = {"enabled": enabled, "rules": rules}

    # Cache the result (only for the default path — explicit paths are tests)
    if config_path == _get_default_config_path():
        with _cache_lock:
            _cached_policy = result
            _cached_policy_path = "__default__"
            _cached_policy_time = now

    return result


def invalidate_cache() -> None:
    """Force the next ``check_website_access`` call to re-read config."""
    global _cached_policy, _cached_policy_path, _cached_policy_time
    with _cache_lock:
        # Reset all three globals together so the cache state is never
        # half-cleared.
        _cached_policy = None
        _cached_policy_path = None
        _cached_policy_time = 0.0


def _match_host_against_rule(host: str, pattern: str) -> bool:
    """Return True if *host* is covered by the blocklist *pattern*.

    ``*.``-prefixed patterns are shell-style wildcards (so ``*.example.com``
    matches subdomains but not ``example.com`` itself).  Any other pattern
    matches the exact host and every subdomain of it.  Both arguments are
    expected to be normalized (lower-case) already.
    """
    if not host or not pattern:
        return False
    if pattern.startswith("*."):
        # fnmatchcase: hostnames are not filesystem paths, so skip the
        # platform-dependent os.path.normcase() that fnmatch() applies.
        return fnmatch.fnmatchcase(host, pattern)
    return host == pattern or host.endswith(f".{pattern}")


def _extract_host_from_urlish(url: str) -> str:
    """Return the normalized host of *url*, or ``""`` if none can be found.

    Handles full URLs as well as scheme-less input such as
    ``example.com/path``.  Non-string input and URLs that ``urlparse``
    rejects (it raises ``ValueError`` on e.g. unbalanced IPv6 brackets)
    yield ``""`` instead of propagating the error to the calling tool.
    """
    if not isinstance(url, str) or not url:
        return ""
    try:
        parsed = urlparse(url)
        host = _normalize_host(parsed.hostname or parsed.netloc)
        if not host and "://" not in url:
            # No scheme: re-parse as a network-path reference so the leading
            # component is treated as the host instead of the path.
            schemeless = urlparse(f"//{url}")
            host = _normalize_host(schemeless.hostname or schemeless.netloc)
    except ValueError:
        return ""
    return host


def check_website_access(url: str, config_path: Optional[Path] = None) -> Optional[Dict[str, str]]:
    """Check whether a URL is allowed by the website blocklist policy.

    Returns ``None`` if access is allowed, or a dict with block metadata
    (``url``, ``host``, ``rule``, ``source``, ``message``) if blocked.
    URLs from which no host can be extracted are allowed.

    Never raises on policy errors — logs a warning and returns ``None``
    (fail-open) so a config typo doesn't break all web tools.  Pass
    ``config_path`` explicitly (tests) to get strict error propagation of
    ``WebsitePolicyError``.
    """
    # Fast path: if no explicit config_path and a *fresh* cached policy is
    # disabled, skip all work (no YAML read, no host extraction).  The TTL
    # must be honoured here as well, otherwise a cached "disabled" policy
    # would never expire and enabling the blocklist in config.yaml would not
    # take effect until the process restarts.
    if config_path is None:
        with _cache_lock:
            cache_is_fresh = (
                _cached_policy is not None
                and (time.monotonic() - _cached_policy_time) < _CACHE_TTL_SECONDS
            )
            if cache_is_fresh and not _cached_policy.get("enabled"):
                return None

    host = _extract_host_from_urlish(url)
    if not host:
        return None

    try:
        policy = load_website_blocklist(config_path)
    except WebsitePolicyError as exc:
        if config_path is not None:
            raise  # Tests pass explicit paths — let errors propagate
        logger.warning("Website policy config error (failing open): %s", exc)
        return None
    except Exception as exc:
        # Deliberate catch-all: policy loading must never break web tools.
        logger.warning("Unexpected error loading website policy (failing open): %s", exc)
        return None

    if not policy.get("enabled"):
        return None

    for rule in policy.get("rules", []):
        pattern = rule.get("pattern", "")
        if not _match_host_against_rule(host, pattern):
            continue
        source = rule.get("source", "config")
        logger.info("Blocked URL %s — matched rule '%s' from %s", url, pattern, source)
        return {
            "url": url,
            "host": host,
            "rule": pattern,
            "source": source,
            "message": (
                f"Blocked by website policy: '{host}' matched rule '{pattern}'"
                f" from {source}"
            ),
        }
    return None