/ tools / homeassistant_tool.py
homeassistant_tool.py
  1  """Home Assistant tool for controlling smart home devices via REST API.
  2  
  3  Registers four LLM-callable tools:
  4  - ``ha_list_entities`` -- list/filter entities by domain or area
  5  - ``ha_get_state`` -- get detailed state of a single entity
  6  - ``ha_list_services`` -- list available services (actions) per domain
  7  - ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)
  8  
  9  Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
 10  The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123).
 11  """
 12  
 13  import asyncio
 14  import json
 15  import logging
 16  import os
 17  import re
 18  from typing import Any, Dict, Optional
 19  
 20  logger = logging.getLogger(__name__)
 21  
 22  # ---------------------------------------------------------------------------
 23  # Configuration
 24  # ---------------------------------------------------------------------------
 25  
 26  # Kept for backward compatibility (e.g. test monkeypatching); prefer _get_config().
 27  _HASS_URL: str = ""
 28  _HASS_TOKEN: str = ""
 29  
 30  
 31  def _get_config():
 32      """Return (hass_url, hass_token) from env vars at call time."""
 33      return (
 34          (_HASS_URL or os.getenv("HASS_URL", "http://homeassistant.local:8123")).rstrip("/"),
 35          _HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
 36      )
 37  
 38  # Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
 39  _ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
 40  
 41  # Regex for valid HA service/domain names (e.g. "light", "turn_on", "shell_command").
 42  # Only lowercase ASCII letters, digits, and underscores — no slashes, dots, or
 43  # other characters that could allow path traversal in URL construction.
 44  # The domain and service are interpolated into /api/services/{domain}/{service},
 45  # so allowing arbitrary strings would enable SSRF via path traversal
 46  # (e.g. domain="../../api/config") or blocked-domain bypass
 47  # (e.g. domain="shell_command/../light").
 48  _SERVICE_NAME_RE = re.compile(r"^[a-z][a-z0-9_]*$")
 49  
 50  # Service domains blocked for security -- these allow arbitrary code/command
 51  # execution on the HA host or enable SSRF attacks on the local network.
 52  # HA provides zero service-level access control; all safety must be in our layer.
 53  _BLOCKED_DOMAINS = frozenset({
 54      "shell_command",    # arbitrary shell commands as root in HA container
 55      "command_line",     # sensors/switches that execute shell commands
 56      "python_script",    # sandboxed but can escalate via hass.services.call()
 57      "pyscript",         # scripting integration with broader access
 58      "hassio",           # addon control, host shutdown/reboot, stdin to containers
 59      "rest_command",     # HTTP requests from HA server (SSRF vector)
 60  })
 61  
 62  
 63  def _get_headers(token: str = "") -> Dict[str, str]:
 64      """Return authorization headers for HA REST API."""
 65      if not token:
 66          _, token = _get_config()
 67      return {
 68          "Authorization": f"Bearer {token}",
 69          "Content-Type": "application/json",
 70      }
 71  
 72  
 73  # ---------------------------------------------------------------------------
# Async helpers (called from sync handlers via _run_async)
 75  # ---------------------------------------------------------------------------
 76  
 77  def _filter_and_summarize(
 78      states: list,
 79      domain: Optional[str] = None,
 80      area: Optional[str] = None,
 81  ) -> Dict[str, Any]:
 82      """Filter raw HA states by domain/area and return a compact summary."""
 83      if domain:
 84          states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
 85  
 86      if area:
 87          area_lower = area.lower()
 88          states = [
 89              s for s in states
 90              if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
 91              or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
 92          ]
 93  
 94      entities = []
 95      for s in states:
 96          entities.append({
 97              "entity_id": s["entity_id"],
 98              "state": s["state"],
 99              "friendly_name": s.get("attributes", {}).get("friendly_name", ""),
100          })
101  
102      return {"count": len(entities), "entities": entities}
103  
104  
async def _async_list_entities(
    domain: Optional[str] = None,
    area: Optional[str] = None,
) -> Dict[str, Any]:
    """Download all entity states from HA and summarise the matching ones.

    Issues ``GET {hass_url}/api/states`` with a 15 second total timeout and
    passes the decoded JSON through ``_filter_and_summarize``.

    Args:
        domain: Optional entity domain filter.
        area: Optional area/room text filter.

    Returns:
        The ``{"count": ..., "entities": [...]}`` summary.

    Raises:
        aiohttp.ClientResponseError: HA answered with an error status.
    """
    import aiohttp

    base_url, token = _get_config()
    endpoint = f"{base_url}/api/states"

    async with aiohttp.ClientSession() as http:
        async with http.get(
            endpoint,
            headers=_get_headers(token),
            timeout=aiohttp.ClientTimeout(total=15),
        ) as response:
            response.raise_for_status()
            raw_states = await response.json()

    return _filter_and_summarize(raw_states, domain, area)
120  
121  
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
    """Retrieve the full state record of one entity.

    Issues ``GET {hass_url}/api/states/{entity_id}`` with a 10 second total
    timeout.  ``entity_id`` is placed in the URL as-is, so callers are
    expected to have validated it.

    Returns:
        A dict with ``entity_id``, ``state``, ``attributes``,
        ``last_changed`` and ``last_updated``.

    Raises:
        aiohttp.ClientResponseError: HA answered with an error status.
        KeyError: the response lacks ``entity_id`` or ``state``.
    """
    import aiohttp

    base_url, token = _get_config()
    endpoint = f"{base_url}/api/states/{entity_id}"

    async with aiohttp.ClientSession() as http:
        async with http.get(
            endpoint,
            headers=_get_headers(token),
            timeout=aiohttp.ClientTimeout(total=10),
        ) as response:
            response.raise_for_status()
            record = await response.json()

    # The two mandatory keys come first; the optional ones fall back to
    # an empty dict / None when HA omits them.
    state = {"entity_id": record["entity_id"], "state": record["state"]}
    state["attributes"] = record.get("attributes", {})
    for timestamp_key in ("last_changed", "last_updated"):
        state[timestamp_key] = record.get(timestamp_key)
    return state
140  
141  
def _build_service_payload(
    entity_id: Optional[str] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble the JSON body for a HA service call.

    Args:
        entity_id: Target entity; when given it overrides any ``entity_id``
            key present in ``data``.
        data: Extra service data, copied into the payload (not mutated).

    Returns:
        A new dict combining ``data`` and ``entity_id``.
    """
    body: Dict[str, Any] = {}
    body.update(data or {})
    if not entity_id:
        return body
    # The explicit parameter wins over data["entity_id"].
    body["entity_id"] = entity_id
    return body
154  
155  
def _parse_service_response(
    domain: str,
    service: str,
    result: Any,
) -> Dict[str, Any]:
    """Convert the raw HA service-call response into a structured result.

    Args:
        domain: Service domain that was called.
        service: Service name that was called.
        result: Decoded JSON response.  A list is read as the changed state
            objects; any other type yields no affected entities.

    Returns:
        ``{"success": True, "service": "<domain>.<service>",
        "affected_entities": [{"entity_id": ..., "state": ...}, ...]}``.
    """
    changed_states = result if isinstance(result, list) else []
    affected = [
        {"entity_id": entry.get("entity_id", ""), "state": entry.get("state", "")}
        for entry in changed_states
    ]
    return {
        "success": True,
        "service": f"{domain}.{service}",
        "affected_entities": affected,
    }
175  
176  
async def _async_call_service(
    domain: str,
    service: str,
    entity_id: Optional[str] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Invoke a Home Assistant service over the REST API.

    Issues ``POST {hass_url}/api/services/{domain}/{service}`` with a
    15 second total timeout.  ``domain`` and ``service`` are placed in the
    URL as-is, so callers are expected to have validated them.

    Args:
        domain: Service domain.
        service: Service name.
        entity_id: Optional target entity.
        data: Optional extra service data.

    Returns:
        The structured result produced by ``_parse_service_response``.

    Raises:
        aiohttp.ClientResponseError: HA answered with an error status.
    """
    import aiohttp

    base_url, token = _get_config()
    endpoint = f"{base_url}/api/services/{domain}/{service}"
    body = _build_service_payload(entity_id, data)

    async with aiohttp.ClientSession() as http:
        async with http.post(
            endpoint,
            headers=_get_headers(token),
            json=body,
            timeout=aiohttp.ClientTimeout(total=15),
        ) as response:
            response.raise_for_status()
            ha_result = await response.json()

    return _parse_service_response(domain, service, ha_result)
201  
202  
203  # ---------------------------------------------------------------------------
204  # Sync wrappers (handler signature: (args, **kw) -> str)
205  # ---------------------------------------------------------------------------
206  
def _run_async(coro):
    """Run *coro* to completion from synchronous code and return its result.

    When no event loop is running in the calling thread, the coroutine is
    driven directly with ``asyncio.run``.  When a loop is already running,
    ``asyncio.run`` cannot be nested, so the coroutine is run on its own
    loop in a worker thread while the caller blocks for at most 30 seconds.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever the coroutine returns.

    Raises:
        concurrent.futures.TimeoutError: The worker thread did not finish
            within 30 seconds.
        Exception: Any exception raised by the coroutine is re-raised.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if not (loop and loop.is_running()):
        return asyncio.run(coro)

    # Already inside an event loop -- run on a separate thread.
    import concurrent.futures

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(asyncio.run, coro).result(timeout=30)
    finally:
        # Do not use the executor as a context manager here: its __exit__
        # calls shutdown(wait=True), which would block until the worker
        # finishes and defeat the 30 second timeout above.
        pool.shutdown(wait=False)
222  
223  
def _handle_list_entities(args: dict, **kw) -> str:
    """Tool handler for ``ha_list_entities``.

    Reads the optional ``domain`` and ``area`` filters from ``args`` and
    returns ``{"result": <summary>}`` as JSON.  Any failure is logged and
    returned through ``tool_error`` instead of being raised.
    """
    filters = {"domain": args.get("domain"), "area": args.get("area")}
    try:
        summary = _run_async(_async_list_entities(**filters))
        return json.dumps({"result": summary})
    except Exception as exc:
        logger.error("ha_list_entities error: %s", exc)
        return tool_error(f"Failed to list entities: {exc}")
234  
235  
def _handle_get_state(args: dict, **kw) -> str:
    """Handler for ha_get_state tool.

    Args:
        args: Tool arguments; ``entity_id`` is required.

    Returns:
        ``{"result": <state>}`` as JSON on success, otherwise the string
        produced by ``tool_error``.  Validation problems and request failures
        are reported through the return value, never raised.
    """
    entity_id = args.get("entity_id", "")
    if not entity_id:
        return tool_error("Missing required parameter: entity_id")
    # entity_id is interpolated into the request URL, so it must be a string
    # that matches the pattern in full.  The isinstance guard keeps a
    # non-string argument from raising TypeError inside the regex call, and
    # fullmatch rejects a trailing newline that "$" would tolerate.
    if not isinstance(entity_id, str) or not _ENTITY_ID_RE.fullmatch(entity_id):
        return tool_error(f"Invalid entity_id format: {entity_id}")
    try:
        result = _run_async(_async_get_state(entity_id))
        return json.dumps({"result": result})
    except Exception as e:
        logger.error("ha_get_state error: %s", e)
        return tool_error(f"Failed to get state for {entity_id}: {e}")
249  
250  
def _handle_call_service(args: dict, **kw) -> str:
    """Handler for ha_call_service tool.

    Args:
        args: Tool arguments.  ``domain`` and ``service`` are required;
            ``entity_id`` and ``data`` are optional.  ``data`` may be a dict
            or a JSON string that decodes to an object.

    Returns:
        ``{"result": <parsed response>}`` as JSON on success, otherwise an
        error string.  Validation problems and request failures are reported
        through the return value, never raised.
    """
    domain = args.get("domain", "")
    service = args.get("service", "")
    if not domain or not service:
        return tool_error("Missing required parameters: domain and service")

    # Validate domain/service format BEFORE the blocklist check — prevents
    # path traversal in /api/services/{domain}/{service} and blocklist bypass
    # via payloads like "shell_command/../light".  fullmatch (not match) is
    # used so a trailing newline, which "$" tolerates, cannot slip a blocked
    # domain such as "shell_command\n" past the blocklist.  The isinstance
    # guards keep non-string arguments from raising TypeError in the regex.
    if not isinstance(domain, str) or not _SERVICE_NAME_RE.fullmatch(domain):
        return tool_error(f"Invalid domain format: {domain!r}")
    if not isinstance(service, str) or not _SERVICE_NAME_RE.fullmatch(service):
        return tool_error(f"Invalid service format: {service!r}")

    if domain in _BLOCKED_DOMAINS:
        return json.dumps({
            "error": f"Service domain '{domain}' is blocked for security. "
            f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
        })

    entity_id = args.get("entity_id")
    if entity_id and not (
        isinstance(entity_id, str) and _ENTITY_ID_RE.fullmatch(entity_id)
    ):
        return tool_error(f"Invalid entity_id format: {entity_id}")

    data = args.get("data")
    if isinstance(data, str):
        # The schema declares ``data`` as a JSON string; blank means "none".
        try:
            data = json.loads(data) if data.strip() else None
        except json.JSONDecodeError as e:
            return tool_error(f"Invalid JSON string in 'data' parameter: {e}")
    # Service data is merged into a JSON object body, so anything other than
    # an object (e.g. a JSON array or number) is rejected here with a clear
    # message.  Falsy values keep meaning "no extra data".
    if data and not isinstance(data, dict):
        return tool_error(
            f"Invalid 'data' parameter: expected a JSON object, got {type(data).__name__}"
        )

    try:
        result = _run_async(_async_call_service(domain, service, entity_id, data))
        return json.dumps({"result": result})
    except Exception as e:
        logger.error("ha_call_service error: %s", e)
        return tool_error(f"Failed to call {domain}.{service}: {e}")
289  
290  
291  # ---------------------------------------------------------------------------
292  # List services
293  # ---------------------------------------------------------------------------
294  
def _summarize_services(
    services: list,
    domain: Optional[str] = None,
) -> Dict[str, Any]:
    """Filter raw HA service listings by domain and return a compact summary."""
    if domain:
        services = [s for s in services if s.get("domain") == domain]

    # Compact the output for context efficiency: keep only each service's
    # description and a {field name: field description} mapping.
    result = []
    for svc_domain in services:
        domain_services = {}
        for svc_name, svc_info in svc_domain.get("services", {}).items():
            svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
            fields = svc_info.get("fields", {})
            if fields:
                svc_entry["fields"] = {
                    k: v.get("description", "") for k, v in fields.items()
                    if isinstance(v, dict)
                }
            domain_services[svc_name] = svc_entry
        result.append({
            "domain": svc_domain.get("domain", ""),
            "services": domain_services,
        })

    return {"count": len(result), "domains": result}


async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
    """Fetch available services from HA and optionally filter by domain.

    Issues ``GET {hass_url}/api/services`` with a 15 second total timeout,
    using the same ``_get_headers`` helper as the other requests in this
    module.

    Args:
        domain: When given, only the entry for this domain is kept.

    Returns:
        ``{"count": <number of domains>, "domains": [{"domain": ...,
        "services": {<name>: {"description": ..., "fields": {...}}}}]}``.

    Raises:
        aiohttp.ClientResponseError: HA answered with an error status.
    """
    import aiohttp

    hass_url, hass_token = _get_config()
    url = f"{hass_url}/api/services"
    async with aiohttp.ClientSession() as session:
        async with session.get(
            url,
            headers=_get_headers(hass_token),
            timeout=aiohttp.ClientTimeout(total=15),
        ) as resp:
            resp.raise_for_status()
            services = await resp.json()

    return _summarize_services(services, domain)
327  
328  
def _handle_list_services(args: dict, **kw) -> str:
    """Tool handler for ``ha_list_services``.

    Reads the optional ``domain`` filter from ``args`` and returns
    ``{"result": <summary>}`` as JSON.  Any failure is logged and returned
    through ``tool_error`` instead of being raised.
    """
    wanted_domain = args.get("domain")
    try:
        listing = _run_async(_async_list_services(domain=wanted_domain))
        return json.dumps({"result": listing})
    except Exception as exc:
        logger.error("ha_list_services error: %s", exc)
        return tool_error(f"Failed to list services: {exc}")
338  
339  
340  # ---------------------------------------------------------------------------
341  # Availability check
342  # ---------------------------------------------------------------------------
343  
def _check_ha_available() -> bool:
    """Report whether the tools should be offered.

    Returns:
        True only when the ``HASS_TOKEN`` environment variable is set to a
        non-empty value.
    """
    token = os.environ.get("HASS_TOKEN", "")
    return token != ""
347  
348  
349  # ---------------------------------------------------------------------------
350  # Tool schemas
351  # ---------------------------------------------------------------------------
352  
# Schema for the ``ha_list_entities`` tool; both filters are optional.
HA_LIST_ENTITIES_SCHEMA = {
    "name": "ha_list_entities",
    "description": (
        "List Home Assistant entities. "
        "Optionally filter by domain (light, switch, climate, sensor, "
        "binary_sensor, cover, fan, etc.) or by area name "
        "(living room, kitchen, bedroom, etc.)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            # Filters on the entity_id prefix.
            "domain": {
                "type": "string",
                "description": (
                    "Entity domain to filter by "
                    "(e.g. 'light', 'switch', 'climate', 'sensor', 'binary_sensor', "
                    "'cover', 'fan', 'media_player'). "
                    "Omit to list all entities."
                ),
            },
            # Free-text filter.
            "area": {
                "type": "string",
                "description": (
                    "Area/room name to filter by "
                    "(e.g. 'living room', 'kitchen'). "
                    "Matches against entity friendly names. "
                    "Omit to list all."
                ),
            },
        },
        "required": [],
    },
}
382  
# Schema for the ``ha_get_state`` tool; ``entity_id`` is mandatory.
HA_GET_STATE_SCHEMA = {
    "name": "ha_get_state",
    "description": (
        "Get the detailed state of a single Home Assistant entity, "
        "including all attributes "
        "(brightness, color, temperature setpoint, sensor readings, etc.)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "entity_id": {
                "type": "string",
                "description": (
                    "The entity ID to query "
                    "(e.g. 'light.living_room', 'climate.thermostat', "
                    "'sensor.temperature')."
                ),
            },
        },
        "required": ["entity_id"],
    },
}
403  
# Schema for the ``ha_list_services`` tool; the domain filter is optional.
HA_LIST_SERVICES_SCHEMA = {
    "name": "ha_list_services",
    "description": (
        "List available Home Assistant services (actions) for device control. "
        "Shows what actions can be performed on each device type "
        "and what parameters they accept. "
        "Use this to discover how to control devices found via ha_list_entities."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": (
                    "Filter by domain "
                    "(e.g. 'light', 'climate', 'switch'). "
                    "Omit to list services for all domains."
                ),
            },
        },
        "required": [],
    },
}
426  
# Schema for the ``ha_call_service`` tool; ``domain`` and ``service`` are
# mandatory, ``entity_id`` and ``data`` are optional.
HA_CALL_SERVICE_SCHEMA = {
    "name": "ha_call_service",
    "description": (
        "Call a Home Assistant service to control a device. "
        "Use ha_list_services to discover available services "
        "and their parameters for each domain."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "domain": {
                "type": "string",
                "description": (
                    "Service domain "
                    "(e.g. 'light', 'switch', 'climate', 'cover', "
                    "'media_player', 'fan', 'scene', 'script')."
                ),
            },
            "service": {
                "type": "string",
                "description": (
                    "Service name "
                    "(e.g. 'turn_on', 'turn_off', 'toggle', 'set_temperature', "
                    "'set_hvac_mode', 'open_cover', 'close_cover', "
                    "'set_volume_level')."
                ),
            },
            "entity_id": {
                "type": "string",
                "description": (
                    "Target entity ID "
                    "(e.g. 'light.living_room'). "
                    "Some services (like scene.turn_on) may not need this."
                ),
            },
            # Declared as a string: the handler decodes it as JSON.
            "data": {
                "type": "string",
                "description": (
                    "Additional service data as a JSON string. "
                    "Examples: "
                    '{"brightness": 255, "color_name": "blue"} for lights, '
                    '{"temperature": 22, "hvac_mode": "heat"} for climate, '
                    '{"volume_level": 0.5} for media players.'
                ),
            },
        },
        "required": ["domain", "service"],
    },
}
471  
472  
473  # ---------------------------------------------------------------------------
474  # Registration
475  # ---------------------------------------------------------------------------
476  
477  from tools.registry import registry, tool_error
478  
# Register the four tools under the "homeassistant" toolset.  Each tool's
# registered name is taken from its schema's "name" entry, and all share the
# same availability check and emoji.
for _schema, _handler in (
    (HA_LIST_ENTITIES_SCHEMA, _handle_list_entities),
    (HA_GET_STATE_SCHEMA, _handle_get_state),
    (HA_LIST_SERVICES_SCHEMA, _handle_list_services),
    (HA_CALL_SERVICE_SCHEMA, _handle_call_service),
):
    registry.register(
        name=_schema["name"],
        toolset="homeassistant",
        schema=_schema,
        handler=_handler,
        check_fn=_check_ha_available,
        emoji="🏠",
    )
# Keep the loop variables out of the module namespace.
del _schema, _handler