# homeassistant_tool.py
"""Home Assistant tool for controlling smart home devices via REST API.

Registers four LLM-callable tools:
- ``ha_list_entities`` -- list/filter entities by domain or area
- ``ha_get_state`` -- get detailed state of a single entity
- ``ha_list_services`` -- list available services (actions) per domain
- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)

Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123).
"""

import asyncio
import json
import logging
import os
import re
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Base URL used when neither the _HASS_URL override nor HASS_URL provide one.
_DEFAULT_HASS_URL = "http://homeassistant.local:8123"

# Kept for backward compatibility (e.g. test monkeypatching); prefer _get_config().
_HASS_URL: str = ""
_HASS_TOKEN: str = ""


def _get_config():
    """Return ``(hass_url, hass_token)`` resolved at call time.

    Precedence for each value: the module-level override (``_HASS_URL`` /
    ``_HASS_TOKEN``), then the environment variable (``HASS_URL`` /
    ``HASS_TOKEN``), then the default (``_DEFAULT_HASS_URL`` / ``""``).

    Empty environment values are treated as unset, so ``HASS_URL=`` (as often
    left in .env templates) still falls back to the default URL instead of
    producing an empty base URL. Trailing slashes are stripped from the URL so
    callers can append ``/api/...`` directly.
    """
    hass_url = _HASS_URL or os.getenv("HASS_URL") or _DEFAULT_HASS_URL
    hass_token = _HASS_TOKEN or os.getenv("HASS_TOKEN") or ""
    return hass_url.rstrip("/"), hass_token


# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1").
# Anchored with \Z rather than $: "$" also matches just before a trailing
# newline, which would let "light.x\n" pass validation.
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+\Z")

# Regex for valid HA service/domain names (e.g. "light", "turn_on", "shell_command").
# Only lowercase ASCII letters, digits, and underscores — no slashes, dots, or
# other characters that could allow path traversal in URL construction.
# The domain and service are interpolated into /api/services/{domain}/{service},
# so allowing arbitrary strings would enable SSRF via path traversal
# (e.g. domain="../../api/config") or blocked-domain bypass
# (e.g. domain="shell_command/../light").
# \Z (not $) so a trailing newline cannot slip past validation and the
# blocklist check below ("shell_command\n" is not in _BLOCKED_DOMAINS).
_SERVICE_NAME_RE = re.compile(r"^[a-z][a-z0-9_]*\Z")

# Service domains blocked for security -- these allow arbitrary code/command
# execution on the HA host or enable SSRF attacks on the local network.
# HA provides zero service-level access control; all safety must be in our layer.
_BLOCKED_DOMAINS = frozenset({
    "shell_command",  # arbitrary shell commands as root in HA container
    "command_line",   # sensors/switches that execute shell commands
    "python_script",  # sandboxed but can escalate via hass.services.call()
    "pyscript",       # scripting integration with broader access
    "hassio",         # addon control, host shutdown/reboot, stdin to containers
    "rest_command",   # HTTP requests from HA server (SSRF vector)
})


def _get_headers(token: str = "") -> Dict[str, str]:
    """Return authorization headers for the HA REST API.

    Args:
        token: Bearer token to send. When empty, the token is resolved via
            ``_get_config()`` at call time.
    """
    if not token:
        _, token = _get_config()
    return {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }


# ---------------------------------------------------------------------------
# Async helpers (called from sync handlers via _run_async)
# ---------------------------------------------------------------------------

def _filter_and_summarize(
    states: list,
    domain: Optional[str] = None,
    area: Optional[str] = None,
) -> Dict[str, Any]:
    """Filter raw HA states by domain/area and return a compact summary.

    Args:
        states: State objects as returned by ``GET /api/states``.
        domain: Keep only entities whose entity_id starts with ``"{domain}."``.
        area: Case-insensitive substring matched against the entity's
            ``friendly_name`` or ``area`` attribute.

    Returns:
        ``{"count": int, "entities": [{"entity_id", "state", "friendly_name"}, ...]}``
    """
    def _attrs(s: dict) -> dict:
        # Treat a missing or null "attributes" field as empty instead of crashing.
        return s.get("attributes") or {}

    if domain:
        prefix = f"{domain}."
        states = [s for s in states if s.get("entity_id", "").startswith(prefix)]

    if area:
        area_lower = area.lower()
        states = [
            s for s in states
            if area_lower in (_attrs(s).get("friendly_name", "") or "").lower()
            or area_lower in (_attrs(s).get("area", "") or "").lower()
        ]

    entities = [
        {
            "entity_id": s["entity_id"],
            "state": s["state"],
            "friendly_name": _attrs(s).get("friendly_name", ""),
        }
        for s in states
    ]
    return {"count": len(entities), "entities": entities}


async def _async_list_entities(
    domain: Optional[str] = None,
    area: Optional[str] = None,
) -> Dict[str, Any]:
    """Fetch all entity states from HA and optionally filter by domain/area."""
    import aiohttp

    hass_url, hass_token = _get_config()
    url = f"{hass_url}/api/states"
    async with aiohttp.ClientSession() as session:
        async with session.get(
            url,
            headers=_get_headers(hass_token),
            timeout=aiohttp.ClientTimeout(total=15),
        ) as resp:
            resp.raise_for_status()
            states = await resp.json()

    return _filter_and_summarize(states, domain, area)


async def _async_get_state(entity_id: str) -> Dict[str, Any]:
    """Fetch the detailed state of a single entity.

    ``entity_id`` must already be validated against ``_ENTITY_ID_RE`` because
    it is interpolated into the request path.
    """
    import aiohttp

    hass_url, hass_token = _get_config()
    url = f"{hass_url}/api/states/{entity_id}"
    async with aiohttp.ClientSession() as session:
        async with session.get(
            url,
            headers=_get_headers(hass_token),
            timeout=aiohttp.ClientTimeout(total=10),
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

    return {
        "entity_id": data["entity_id"],
        "state": data["state"],
        "attributes": data.get("attributes", {}),
        "last_changed": data.get("last_changed"),
        "last_updated": data.get("last_updated"),
    }


def _build_service_payload(
    entity_id: Optional[str] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the JSON payload for a HA service call.

    The explicit ``entity_id`` argument takes precedence over any
    ``data["entity_id"]``.
    """
    payload: Dict[str, Any] = {}
    if data:
        payload.update(data)
    if entity_id:
        payload["entity_id"] = entity_id
    return payload


def _parse_service_response(
    domain: str,
    service: str,
    result: Any,
) -> Dict[str, Any]:
    """Parse a HA service-call response into a structured result.

    A list response is treated as the affected entity states; any other
    response shape yields an empty ``affected_entities`` list.
    """
    affected = []
    if isinstance(result, list):
        affected = [
            {"entity_id": s.get("entity_id", ""), "state": s.get("state", "")}
            for s in result
            if isinstance(s, dict)  # skip unexpected non-object entries
        ]

    return {
        "success": True,
        "service": f"{domain}.{service}",
        "affected_entities": affected,
    }


async def _async_call_service(
    domain: str,
    service: str,
    entity_id: Optional[str] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Call a Home Assistant service.

    ``domain`` and ``service`` must already be validated against
    ``_SERVICE_NAME_RE`` because they are interpolated into the request path.
    """
    import aiohttp

    hass_url, hass_token = _get_config()
    url = f"{hass_url}/api/services/{domain}/{service}"
    payload = _build_service_payload(entity_id, data)

    async with aiohttp.ClientSession() as session:
        async with session.post(
            url,
            headers=_get_headers(hass_token),
            json=payload,
            timeout=aiohttp.ClientTimeout(total=15),
        ) as resp:
            resp.raise_for_status()
            result = await resp.json()

    return _parse_service_response(domain, service, result)


# ---------------------------------------------------------------------------
# Sync wrappers (handler signature: (args, **kw) -> str)
# ---------------------------------------------------------------------------

def _run_async(coro, timeout: float = 30):
    """Run *coro* to completion from synchronous code and return its result.

    With no event loop running in the current thread, the coroutine runs via
    ``asyncio.run``. When called from inside a running loop (where
    ``asyncio.run`` would raise), it runs via ``asyncio.run`` on a worker
    thread, and the caller waits at most *timeout* seconds.

    Raises:
        TimeoutError: the worker-thread path did not finish within *timeout*.
    """
    import concurrent.futures

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread -- safe to start one here.
        return asyncio.run(coro)

    # Already inside a running loop: run the coroutine on its own loop in a
    # separate thread.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(asyncio.run, coro).result(timeout=timeout)
    finally:
        # wait=False: a ``with`` block would call shutdown(wait=True) and block
        # until the worker finished, defeating the timeout above.
        pool.shutdown(wait=False)


def _handle_list_entities(args: dict, **kw) -> str:
    """Handler for ha_list_entities tool."""
    domain = args.get("domain")
    area = args.get("area")
    try:
        result = _run_async(_async_list_entities(domain=domain, area=area))
        return json.dumps({"result": result})
    except Exception as e:
        logger.error("ha_list_entities error: %s", e)
        return tool_error(f"Failed to list entities: {e}")


def _handle_get_state(args: dict, **kw) -> str:
    """Handler for ha_get_state tool."""
    entity_id = args.get("entity_id", "")
    if not entity_id:
        return tool_error("Missing required parameter: entity_id")
    # isinstance guard: re.match on a non-str raises TypeError outside the try.
    if not isinstance(entity_id, str) or not _ENTITY_ID_RE.match(entity_id):
        return tool_error(f"Invalid entity_id format: {entity_id}")
    try:
        result = _run_async(_async_get_state(entity_id))
        return json.dumps({"result": result})
    except Exception as e:
        logger.error("ha_get_state error: %s", e)
        return tool_error(f"Failed to get state for {entity_id}: {e}")


def _handle_call_service(args: dict, **kw) -> str:
    """Handler for ha_call_service tool.

    Validates every argument before any network call. ``domain``/``service``
    must match ``_SERVICE_NAME_RE`` and not be blocked. ``entity_id``, if
    given, must match ``_ENTITY_ID_RE``. ``data``, if given, must be (or
    decode to) a JSON object.
    """
    domain = args.get("domain", "")
    service = args.get("service", "")
    if not domain or not service:
        return tool_error("Missing required parameters: domain and service")

    # Validate domain/service format BEFORE the blocklist check — prevents
    # path traversal in /api/services/{domain}/{service} and blocklist bypass
    # via payloads like "shell_command/../light". The isinstance guards keep
    # non-string arguments from raising TypeError inside re.match.
    if not isinstance(domain, str) or not _SERVICE_NAME_RE.match(domain):
        return tool_error(f"Invalid domain format: {domain!r}")
    if not isinstance(service, str) or not _SERVICE_NAME_RE.match(service):
        return tool_error(f"Invalid service format: {service!r}")

    if domain in _BLOCKED_DOMAINS:
        return json.dumps({
            "error": f"Service domain '{domain}' is blocked for security. "
                     f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
        })

    entity_id = args.get("entity_id")
    if entity_id and (not isinstance(entity_id, str) or not _ENTITY_ID_RE.match(entity_id)):
        return tool_error(f"Invalid entity_id format: {entity_id}")

    data = args.get("data")
    if isinstance(data, str):
        try:
            data = json.loads(data) if data.strip() else None
        except json.JSONDecodeError as e:
            return tool_error(f"Invalid JSON string in 'data' parameter: {e}")
    # Service data must be a JSON object; reject e.g. "5" or "[1, 2]" up front
    # instead of failing later inside dict.update().
    if data and not isinstance(data, dict):
        return tool_error(
            f"Invalid 'data' parameter: expected a JSON object, got {type(data).__name__}"
        )

    try:
        result = _run_async(_async_call_service(domain, service, entity_id, data))
        return json.dumps({"result": result})
    except Exception as e:
        logger.error("ha_call_service error: %s", e)
        return tool_error(f"Failed to call {domain}.{service}: {e}")


# ---------------------------------------------------------------------------
# List services
# ---------------------------------------------------------------------------

async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
    """Fetch available services from HA and optionally filter by domain.

    The result is compacted for context efficiency. Each service keeps only its
    description and a ``field name -> description`` map.
    """
    import aiohttp

    hass_url, hass_token = _get_config()
    url = f"{hass_url}/api/services"
    async with aiohttp.ClientSession() as session:
        async with session.get(
            url,
            headers=_get_headers(hass_token),
            timeout=aiohttp.ClientTimeout(total=15),
        ) as resp:
            resp.raise_for_status()
            services = await resp.json()

    if domain:
        services = [s for s in services if s.get("domain") == domain]

    result = []
    for svc_domain in services:
        domain_services = {}
        for svc_name, svc_info in svc_domain.get("services", {}).items():
            svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
            fields = svc_info.get("fields", {})
            if fields:
                svc_entry["fields"] = {
                    k: v.get("description", "")
                    for k, v in fields.items()
                    if isinstance(v, dict)
                }
            domain_services[svc_name] = svc_entry
        result.append({"domain": svc_domain.get("domain", ""), "services": domain_services})

    return {"count": len(result), "domains": result}


def _handle_list_services(args: dict, **kw) -> str:
    """Handler for ha_list_services tool."""
    domain = args.get("domain")
    try:
        result = _run_async(_async_list_services(domain=domain))
        return json.dumps({"result": result})
    except Exception as e:
        logger.error("ha_list_services error: %s", e)
        return tool_error(f"Failed to list services: {e}")


# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------

def _check_ha_available() -> bool:
    """Tool is only available when a HA token is configured.

    Resolved through ``_get_config()`` so the ``_HASS_TOKEN`` override is
    honoured the same way the request helpers honour it.
    """
    _, hass_token = _get_config()
    return bool(hass_token)


# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------

HA_LIST_ENTITIES_SCHEMA = {
    "name": "ha_list_entities",
    "description": (
        "List Home Assistant entities. Optionally filter by domain "
        "(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) "
        "or by area name (living room, kitchen, bedroom, etc.)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": (
                    "Entity domain to filter by (e.g. 'light', 'switch', 'climate', "
                    "'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). "
                    "Omit to list all entities."
                ),
            },
            "area": {
                "type": "string",
                "description": (
                    "Area/room name to filter by (e.g. 'living room', 'kitchen'). "
                    "Matches against entity friendly names. Omit to list all."
                ),
            },
        },
        "required": [],
    },
}

HA_GET_STATE_SCHEMA = {
    "name": "ha_get_state",
    "description": (
        "Get the detailed state of a single Home Assistant entity, including all "
        "attributes (brightness, color, temperature setpoint, sensor readings, etc.)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "entity_id": {
                "type": "string",
                "description": (
                    "The entity ID to query (e.g. 'light.living_room', "
                    "'climate.thermostat', 'sensor.temperature')."
                ),
            },
        },
        "required": ["entity_id"],
    },
}

HA_LIST_SERVICES_SCHEMA = {
    "name": "ha_list_services",
    "description": (
        "List available Home Assistant services (actions) for device control. "
        "Shows what actions can be performed on each device type and what "
        "parameters they accept. Use this to discover how to control devices "
        "found via ha_list_entities."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": (
                    "Filter by domain (e.g. 'light', 'climate', 'switch'). "
                    "Omit to list services for all domains."
                ),
            },
        },
        "required": [],
    },
}

HA_CALL_SERVICE_SCHEMA = {
    "name": "ha_call_service",
    "description": (
        "Call a Home Assistant service to control a device. Use ha_list_services "
        "to discover available services and their parameters for each domain."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": (
                    "Service domain (e.g. 'light', 'switch', 'climate', "
                    "'cover', 'media_player', 'fan', 'scene', 'script')."
                ),
            },
            "service": {
                "type": "string",
                "description": (
                    "Service name (e.g. 'turn_on', 'turn_off', 'toggle', "
                    "'set_temperature', 'set_hvac_mode', 'open_cover', "
                    "'close_cover', 'set_volume_level')."
                ),
            },
            "entity_id": {
                "type": "string",
                "description": (
                    "Target entity ID (e.g. 'light.living_room'). "
                    "Some services (like scene.turn_on) may not need this."
                ),
            },
            "data": {
                "type": "string",
                "description": (
                    "Additional service data as a JSON string. Examples: "
                    '{"brightness": 255, "color_name": "blue"} for lights, '
                    '{"temperature": 22, "hvac_mode": "heat"} for climate, '
                    '{"volume_level": 0.5} for media players.'
                ),
            },
        },
        "required": ["domain", "service"],
    },
}


# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------

# Imported here rather than at the top of the file; ``tool_error`` is only
# looked up when the handlers above run, after this import has executed.
from tools.registry import registry, tool_error

registry.register(
    name="ha_list_entities",
    toolset="homeassistant",
    schema=HA_LIST_ENTITIES_SCHEMA,
    handler=_handle_list_entities,
    check_fn=_check_ha_available,
    emoji="🏠",
)

registry.register(
    name="ha_get_state",
    toolset="homeassistant",
    schema=HA_GET_STATE_SCHEMA,
    handler=_handle_get_state,
    check_fn=_check_ha_available,
    emoji="🏠",
)

registry.register(
    name="ha_list_services",
    toolset="homeassistant",
    schema=HA_LIST_SERVICES_SCHEMA,
    handler=_handle_list_services,
    check_fn=_check_ha_available,
    emoji="🏠",
)

registry.register(
    name="ha_call_service",
    toolset="homeassistant",
    schema=HA_CALL_SERVICE_SCHEMA,
    handler=_handle_call_service,
    check_fn=_check_ha_available,
    emoji="🏠",
)