# fake_ha_server.py
"""Fake Home Assistant server for integration testing.

Provides a real HTTP + WebSocket server (via aiohttp.web) that mimics the
Home Assistant API surface used by hermes-agent:

- ``/api/websocket`` -- WebSocket auth handshake + event push
- ``/api/states`` -- GET all entity states
- ``/api/states/{entity_id}`` -- GET single entity state
- ``/api/services/{domain}/{service}`` -- POST service call
- ``/api/services/persistent_notification/create`` -- POST notification

Every server instance works on its own deep copy of ``ENTITY_STATES``, so
state changes made by service calls never leak from one test into the next.

Usage::

    async with FakeHAServer(token="test-token") as server:
        url = server.url  # e.g. "http://127.0.0.1:54321"
        await server.push_event(event_data)
        assert server.received_notifications  # verify what arrived
"""

import asyncio
import copy
import json
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
from aiohttp import web
from aiohttp.test_utils import TestServer


# -- Sample entity data -------------------------------------------------------

# Pristine fixture data. Treat it as read-only: each FakeHAServer deep-copies
# it into ``entity_states`` and only ever mutates that copy.
ENTITY_STATES: List[Dict[str, Any]] = [
    {
        "entity_id": "light.bedroom",
        "state": "on",
        "attributes": {"friendly_name": "Bedroom Light", "brightness": 200},
        "last_changed": "2025-01-15T10:30:00+00:00",
        "last_updated": "2025-01-15T10:30:00+00:00",
    },
    {
        "entity_id": "light.kitchen",
        "state": "off",
        "attributes": {"friendly_name": "Kitchen Light"},
        "last_changed": "2025-01-15T09:00:00+00:00",
        "last_updated": "2025-01-15T09:00:00+00:00",
    },
    {
        "entity_id": "sensor.temperature",
        "state": "22.5",
        "attributes": {
            "friendly_name": "Kitchen Temperature",
            "unit_of_measurement": "C",
        },
        "last_changed": "2025-01-15T10:00:00+00:00",
        "last_updated": "2025-01-15T10:00:00+00:00",
    },
    {
        "entity_id": "switch.fan",
        "state": "on",
        "attributes": {"friendly_name": "Living Room Fan"},
        "last_changed": "2025-01-15T08:00:00+00:00",
        "last_updated": "2025-01-15T08:00:00+00:00",
    },
    {
        "entity_id": "climate.thermostat",
        "state": "heat",
        "attributes": {
            "friendly_name": "Main Thermostat",
            "current_temperature": 21,
            "temperature": 23,
        },
        "last_changed": "2025-01-15T07:00:00+00:00",
        "last_updated": "2025-01-15T07:00:00+00:00",
    },
]


class FakeHAServer:
    """In-process fake Home Assistant for integration tests.

    Parameters
    ----------
    token : str
        The expected Bearer token for REST requests, and the expected
        ``access_token`` in the WebSocket auth message.

    Attributes
    ----------
    entity_states : list of dict
        This instance's private, mutable copy of ``ENTITY_STATES``. The REST
        handlers read it and service calls update it.
    received_service_calls : list of dict
        One ``{"domain", "service", "data"}`` record per service call.
    received_notifications : list of dict
        The JSON bodies posted to the persistent-notification endpoint.
    reject_auth : bool
        When true, every WebSocket auth attempt is answered ``auth_invalid``.
    force_500 : bool
        When true, correctly authenticated REST requests get a 500 response.
        The WebSocket endpoint is unaffected.
    """

    def __init__(self, token: str = "test-token-123"):
        self.token = token

        # Per-instance entity data: service calls mutate it, and a shared
        # module-level list would carry those mutations across tests.
        self.entity_states: List[Dict[str, Any]] = copy.deepcopy(ENTITY_STATES)

        # Observability -- tests inspect these after exercising the adapter.
        self.received_service_calls: List[Dict[str, Any]] = []
        self.received_notifications: List[Dict[str, Any]] = []

        # Control -- tests push events, server forwards them over WS.
        self._event_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()

        # Flag to simulate auth rejection.
        self.reject_auth = False

        # Flag to simulate server errors.
        self.force_500 = False

        # Internal bookkeeping.
        self._app: Optional[web.Application] = None
        self._server: Optional[TestServer] = None
        self._ws_connections: List[web.WebSocketResponse] = []

    # -- Public helpers --------------------------------------------------------

    @property
    def url(self) -> str:
        """Base URL of the running server, e.g. ``http://127.0.0.1:12345``."""
        assert self._server is not None, "Server not started"
        host = self._server.host
        port = self._server.port
        return f"http://{host}:{port}"

    async def push_event(self, event_data: Dict[str, Any]) -> None:
        """Enqueue an event (e.g. state_changed data) for WebSocket delivery.

        The next WebSocket handler that reads the queue sends *event_data* as
        the ``event`` payload of an HA ``event`` message.
        """
        await self._event_queue.put(event_data)

    # -- Lifecycle -------------------------------------------------------------

    async def start(self) -> None:
        """Build the aiohttp application and start serving it."""
        self._app = self._build_app()
        self._server = TestServer(self._app)
        await self._server.start_server()

    async def stop(self) -> None:
        """Close any open WebSocket connections, then shut the server down."""
        for ws in self._ws_connections:
            if not ws.closed:
                await ws.close()
        self._ws_connections.clear()
        if self._server is not None:
            await self._server.close()

    async def __aenter__(self) -> "FakeHAServer":
        await self.start()
        return self

    async def __aexit__(self, *exc) -> None:
        await self.stop()

    # -- Application construction ----------------------------------------------

    def _build_app(self) -> web.Application:
        """Create the aiohttp application with every fake HA route."""
        app = web.Application()
        app.router.add_get("/api/websocket", self._handle_ws)
        app.router.add_get("/api/states", self._handle_get_states)
        app.router.add_get("/api/states/{entity_id}", self._handle_get_state)
        # Notification endpoint must be registered before the generic service
        # route so that it takes priority.
        app.router.add_post(
            "/api/services/persistent_notification/create",
            self._handle_notification,
        )
        app.router.add_post(
            "/api/services/{domain}/{service}",
            self._handle_call_service,
        )
        return app

    # -- Request helpers -------------------------------------------------------

    def _check_rest_auth(self, request: web.Request) -> Optional[web.Response]:
        """Return an error response for a REST request, or None if it may proceed.

        The token is checked first, so a wrong Bearer token always yields 401.
        A correctly authenticated request yields 500 while ``force_500`` is set.
        """
        auth = request.headers.get("Authorization", "")
        if auth != f"Bearer {self.token}":
            return web.Response(status=401, text="Unauthorized")
        if self.force_500:
            return web.Response(status=500, text="Internal Server Error")
        return None

    @staticmethod
    async def _read_json_body(
        request: web.Request,
    ) -> Tuple[Optional[Any], Optional[web.Response]]:
        """Parse an optional JSON request body.

        Returns ``(data, None)`` on success, where *data* is None for an empty
        body. Returns ``(None, response)`` with a 400 response when the body
        is not valid JSON, instead of letting the parse error become a 500.
        """
        raw = await request.text()
        if not raw:
            return None, None
        try:
            return json.loads(raw), None
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return None, web.Response(
                status=400, text="Data should be valid JSON."
            )

    def _find_state(self, entity_id: str) -> Optional[Dict[str, Any]]:
        """Return this server's state dict for *entity_id*, or None."""
        return next(
            (s for s in self.entity_states if s["entity_id"] == entity_id),
            None,
        )

    @staticmethod
    def _target_entity_ids(data: Any) -> List[str]:
        """Return the entity ids a service-call payload targets.

        ``entity_id`` may be a single non-empty string or a list of strings.
        Anything else (including a non-dict payload) targets nothing.
        """
        if not isinstance(data, dict):
            return []
        target = data.get("entity_id")
        if isinstance(target, str):
            return [target] if target else []
        if isinstance(target, list):
            return [t for t in target if isinstance(t, str)]
        return []

    # -- WebSocket handler -----------------------------------------------------

    @staticmethod
    async def _receive_json(
        ws: web.WebSocketResponse,
    ) -> Optional[Dict[str, Any]]:
        """Receive one frame and return it as a dict.

        Returns None when the frame is not TEXT, or not a JSON object.
        """
        msg = await ws.receive()
        if msg.type != aiohttp.WSMsgType.TEXT:
            return None
        try:
            data = json.loads(msg.data)
        except ValueError:
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    async def _drain_incoming(ws: web.WebSocketResponse) -> None:
        """Read and discard client frames until the socket closes."""
        async for _msg in ws:
            pass

    async def _stream_events(
        self, ws: web.WebSocketResponse, sub_id: Any
    ) -> None:
        """Forward queued events to *ws* until the connection goes away.

        A background task keeps a ``receive()`` pending. Without it, aiohttp
        never processes the client's close frame, so ``ws.closed`` would stay
        False after a disconnect. A dead handler would then keep taking (and
        losing) events meant for a newer connection.
        """
        reader = asyncio.ensure_future(self._drain_incoming(ws))
        try:
            while not ws.closed and not reader.done():
                try:
                    event_data = await asyncio.wait_for(
                        self._event_queue.get(), timeout=0.1,
                    )
                except asyncio.TimeoutError:
                    continue
                await ws.send_json({
                    "id": sub_id,
                    "type": "event",
                    "event": event_data,
                })
        except ConnectionResetError:
            pass  # Client vanished mid-send; nothing left to deliver to.
        finally:
            # Also runs when the handler is cancelled, so no orphan task.
            reader.cancel()

    async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse:
        """Run the HA WebSocket handshake, then stream queued events.

        The sequence is:

        1. Send ``auth_required``.
        2. Read the auth message and answer ``auth_ok`` or ``auth_invalid``.
        3. Read the subscribe message and ACK it with a ``result``.
        4. Forward queued events, tagged with the subscription id, until the
           socket closes.

        A frame that is not a TEXT JSON object closes the socket.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self._ws_connections.append(ws)

        # Step 1: auth_required
        await ws.send_json({"type": "auth_required", "ha_version": "2025.1.0"})

        # Step 2: receive auth
        auth_msg = await self._receive_json(ws)
        if auth_msg is None:
            await ws.close()
            return ws

        # Step 3: validate
        if self.reject_auth or auth_msg.get("access_token") != self.token:
            await ws.send_json({"type": "auth_invalid", "message": "Invalid token"})
            await ws.close()
            return ws

        await ws.send_json({"type": "auth_ok", "ha_version": "2025.1.0"})

        # Step 4: subscribe_events
        sub_msg = await self._receive_json(ws)
        if sub_msg is None:
            await ws.close()
            return ws
        sub_id = sub_msg.get("id", 1)

        # Step 5: ACK
        await ws.send_json({
            "id": sub_id,
            "type": "result",
            "success": True,
            "result": None,
        })

        # Step 6: push events from queue until closed
        await self._stream_events(ws, sub_id)
        return ws

    # -- REST handlers ---------------------------------------------------------

    async def _handle_get_states(self, request: web.Request) -> web.Response:
        """GET /api/states -- return every entity state of this server."""
        err = self._check_rest_auth(request)
        if err:
            return err
        return web.json_response(self.entity_states)

    async def _handle_get_state(self, request: web.Request) -> web.Response:
        """GET /api/states/{entity_id} -- return one entity, or 404."""
        err = self._check_rest_auth(request)
        if err:
            return err
        entity_id = request.match_info["entity_id"]
        state = self._find_state(entity_id)
        if state is None:
            return web.Response(status=404, text=f"Entity {entity_id} not found")
        return web.json_response(state)

    async def _handle_notification(self, request: web.Request) -> web.Response:
        """POST persistent_notification/create -- record the body.

        An empty body is recorded as ``{}``.
        """
        err = self._check_rest_auth(request)
        if err:
            return err
        body, err = await self._read_json_body(request)
        if err:
            return err
        self.received_notifications.append({} if body is None else body)
        return web.json_response([])

    def _apply_service(
        self, state: Dict[str, Any], service: str, data: Dict[str, Any]
    ) -> None:
        """Mutate *state* the way HA would for *service*.

        Handles ``turn_on``, ``turn_off`` and ``set_temperature``; any other
        service leaves *state* unchanged.
        """
        if service == "turn_on":
            state["state"] = "on"
        elif service == "turn_off":
            state["state"] = "off"
        elif service == "set_temperature" and "temperature" in data:
            target = data["temperature"]
            state.setdefault("attributes", {})["temperature"] = target
            # Keep current state or set to heat if off
            if state["state"] == "off":
                state["state"] = "heat"
            # Simulate temperature sensor approaching the target.
            sensor = self._find_state("sensor.temperature")
            if sensor is not None:
                sensor["state"] = str(float(target) - 0.5)

    async def _handle_call_service(self, request: web.Request) -> web.Response:
        """POST /api/services/{domain}/{service} -- record and apply a call.

        Returns the states of the targeted entities that exist, which mimics
        real HA behaviour for light/switch/climate services.
        """
        err = self._check_rest_auth(request)
        if err:
            return err
        domain = request.match_info["domain"]
        service = request.match_info["service"]
        body, err = await self._read_json_body(request)
        if err:
            return err
        data = {} if body is None else body

        self.received_service_calls.append({
            "domain": domain,
            "service": service,
            "data": data,
        })

        affected = []
        for entity_id in self._target_entity_ids(data):
            state = self._find_state(entity_id)
            if state is None:
                continue
            self._apply_service(state, service, data)
            affected.append({
                "entity_id": entity_id,
                "state": state["state"],
                "attributes": state.get("attributes", {}),
            })

        return web.json_response(affected)