/ tests / fakes / fake_ha_server.py
fake_ha_server.py
  1  """Fake Home Assistant server for integration testing.
  2  
  3  Provides a real HTTP + WebSocket server (via aiohttp.web) that mimics the
  4  Home Assistant API surface used by hermes-agent:
  5  
  6  - ``/api/websocket``  -- WebSocket auth handshake + event push
  7  - ``/api/states``     -- GET all entity states
  8  - ``/api/states/{entity_id}`` -- GET single entity state
  9  - ``/api/services/{domain}/{service}`` -- POST service call
 10  - ``/api/services/persistent_notification/create`` -- POST notification
 11  
 12  Usage::
 13  
 14      async with FakeHAServer(token="test-token") as server:
 15          url = server.url            # e.g. "http://127.0.0.1:54321"
 16          await server.push_event(event_data)
 17          assert server.received_notifications  # verify what arrived
 18  """
 19  
import asyncio
import copy
import json
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
from aiohttp import web
from aiohttp.test_utils import TestServer
 27  
 28  
 29  # -- Sample entity data -------------------------------------------------------
 30  
 31  ENTITY_STATES: List[Dict[str, Any]] = [
 32      {
 33          "entity_id": "light.bedroom",
 34          "state": "on",
 35          "attributes": {"friendly_name": "Bedroom Light", "brightness": 200},
 36          "last_changed": "2025-01-15T10:30:00+00:00",
 37          "last_updated": "2025-01-15T10:30:00+00:00",
 38      },
 39      {
 40          "entity_id": "light.kitchen",
 41          "state": "off",
 42          "attributes": {"friendly_name": "Kitchen Light"},
 43          "last_changed": "2025-01-15T09:00:00+00:00",
 44          "last_updated": "2025-01-15T09:00:00+00:00",
 45      },
 46      {
 47          "entity_id": "sensor.temperature",
 48          "state": "22.5",
 49          "attributes": {
 50              "friendly_name": "Kitchen Temperature",
 51              "unit_of_measurement": "C",
 52          },
 53          "last_changed": "2025-01-15T10:00:00+00:00",
 54          "last_updated": "2025-01-15T10:00:00+00:00",
 55      },
 56      {
 57          "entity_id": "switch.fan",
 58          "state": "on",
 59          "attributes": {"friendly_name": "Living Room Fan"},
 60          "last_changed": "2025-01-15T08:00:00+00:00",
 61          "last_updated": "2025-01-15T08:00:00+00:00",
 62      },
 63      {
 64          "entity_id": "climate.thermostat",
 65          "state": "heat",
 66          "attributes": {
 67              "friendly_name": "Main Thermostat",
 68              "current_temperature": 21,
 69              "temperature": 23,
 70          },
 71          "last_changed": "2025-01-15T07:00:00+00:00",
 72          "last_updated": "2025-01-15T07:00:00+00:00",
 73      },
 74  ]
 75  
 76  
 77  class FakeHAServer:
 78      """In-process fake Home Assistant for integration tests.
 79  
 80      Parameters
 81      ----------
 82      token : str
 83          The expected Bearer token for authentication.
 84      """
 85  
 86      def __init__(self, token: str = "test-token-123"):
 87          self.token = token
 88  
 89          # Observability -- tests inspect these after exercising the adapter.
 90          self.received_service_calls: List[Dict[str, Any]] = []
 91          self.received_notifications: List[Dict[str, Any]] = []
 92  
 93          # Control -- tests push events, server forwards them over WS.
 94          self._event_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
 95  
 96          # Flag to simulate auth rejection.
 97          self.reject_auth = False
 98  
 99          # Flag to simulate server errors.
100          self.force_500 = False
101  
102          # Internal bookkeeping.
103          self._app: Optional[web.Application] = None
104          self._server: Optional[TestServer] = None
105          self._ws_connections: List[web.WebSocketResponse] = []
106  
107      # -- Public helpers --------------------------------------------------------
108  
109      @property
110      def url(self) -> str:
111          """Base URL of the running server, e.g. ``http://127.0.0.1:12345``."""
112          assert self._server is not None, "Server not started"
113          host = self._server.host
114          port = self._server.port
115          return f"http://{host}:{port}"
116  
117      async def push_event(self, event_data: Dict[str, Any]) -> None:
118          """Enqueue a state_changed event for delivery over WebSocket."""
119          await self._event_queue.put(event_data)
120  
121      # -- Lifecycle -------------------------------------------------------------
122  
123      async def start(self) -> None:
124          self._app = self._build_app()
125          self._server = TestServer(self._app)
126          await self._server.start_server()
127  
128      async def stop(self) -> None:
129          # Close any remaining WS connections.
130          for ws in self._ws_connections:
131              if not ws.closed:
132                  await ws.close()
133          self._ws_connections.clear()
134          if self._server is not None:
135              await self._server.close()
136  
137      async def __aenter__(self) -> "FakeHAServer":
138          await self.start()
139          return self
140  
141      async def __aexit__(self, *exc) -> None:
142          await self.stop()
143  
144      # -- Application construction ----------------------------------------------
145  
146      def _build_app(self) -> web.Application:
147          app = web.Application()
148          app.router.add_get("/api/websocket", self._handle_ws)
149          app.router.add_get("/api/states", self._handle_get_states)
150          app.router.add_get("/api/states/{entity_id}", self._handle_get_state)
151          # Notification endpoint must be registered before the generic service
152          # route so that it takes priority.
153          app.router.add_post(
154              "/api/services/persistent_notification/create",
155              self._handle_notification,
156          )
157          app.router.add_post(
158              "/api/services/{domain}/{service}",
159              self._handle_call_service,
160          )
161          return app
162  
163      # -- Auth helper -----------------------------------------------------------
164  
165      def _check_rest_auth(self, request: web.Request) -> Optional[web.Response]:
166          """Return a 401 response if the Bearer token is wrong, else None."""
167          auth = request.headers.get("Authorization", "")
168          if auth != f"Bearer {self.token}":
169              return web.Response(status=401, text="Unauthorized")
170          if self.force_500:
171              return web.Response(status=500, text="Internal Server Error")
172          return None
173  
174      # -- WebSocket handler -----------------------------------------------------
175  
176      async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse:
177          ws = web.WebSocketResponse()
178          await ws.prepare(request)
179          self._ws_connections.append(ws)
180  
181          # Step 1: auth_required
182          await ws.send_json({"type": "auth_required", "ha_version": "2025.1.0"})
183  
184          # Step 2: receive auth
185          msg = await ws.receive()
186          if msg.type != aiohttp.WSMsgType.TEXT:
187              await ws.close()
188              return ws
189          auth_msg = json.loads(msg.data)
190  
191          # Step 3: validate
192          if self.reject_auth or auth_msg.get("access_token") != self.token:
193              await ws.send_json({"type": "auth_invalid", "message": "Invalid token"})
194              await ws.close()
195              return ws
196  
197          await ws.send_json({"type": "auth_ok", "ha_version": "2025.1.0"})
198  
199          # Step 4: subscribe_events
200          msg = await ws.receive()
201          if msg.type != aiohttp.WSMsgType.TEXT:
202              await ws.close()
203              return ws
204          sub_msg = json.loads(msg.data)
205          sub_id = sub_msg.get("id", 1)
206  
207          # Step 5: ACK
208          await ws.send_json({
209              "id": sub_id,
210              "type": "result",
211              "success": True,
212              "result": None,
213          })
214  
215          # Step 6: push events from queue until closed
216          try:
217              while not ws.closed:
218                  try:
219                      event_data = await asyncio.wait_for(
220                          self._event_queue.get(), timeout=0.1,
221                      )
222                      await ws.send_json({
223                          "id": sub_id,
224                          "type": "event",
225                          "event": event_data,
226                      })
227                  except asyncio.TimeoutError:
228                      continue
229          except (ConnectionResetError, asyncio.CancelledError):
230              pass
231  
232          return ws
233  
234      # -- REST handlers ---------------------------------------------------------
235  
236      async def _handle_get_states(self, request: web.Request) -> web.Response:
237          err = self._check_rest_auth(request)
238          if err:
239              return err
240          return web.json_response(ENTITY_STATES)
241  
242      async def _handle_get_state(self, request: web.Request) -> web.Response:
243          err = self._check_rest_auth(request)
244          if err:
245              return err
246          entity_id = request.match_info["entity_id"]
247          for s in ENTITY_STATES:
248              if s["entity_id"] == entity_id:
249                  return web.json_response(s)
250          return web.Response(status=404, text=f"Entity {entity_id} not found")
251  
252      async def _handle_notification(self, request: web.Request) -> web.Response:
253          err = self._check_rest_auth(request)
254          if err:
255              return err
256          body = await request.json()
257          self.received_notifications.append(body)
258          return web.json_response([])
259  
260      async def _handle_call_service(self, request: web.Request) -> web.Response:
261          err = self._check_rest_auth(request)
262          if err:
263              return err
264          domain = request.match_info["domain"]
265          service = request.match_info["service"]
266          body = await request.json()
267  
268          self.received_service_calls.append({
269              "domain": domain,
270              "service": service,
271              "data": body,
272          })
273  
274          # Return affected entities (mimics real HA behaviour for light/switch).
275          affected = []
276          entity_id = body.get("entity_id")
277          if entity_id:
278              for s in ENTITY_STATES:
279                  if s["entity_id"] == entity_id:
280                      if service == "turn_on":
281                          s["state"] = "on"
282                      elif service == "turn_off":
283                          s["state"] = "off"
284                      elif service == "set_temperature" and "temperature" in body:
285                          s["attributes"]["temperature"] = body["temperature"]
286                          # Keep current state or set to heat if off
287                          if s["state"] == "off":
288                              s["state"] = "heat"
289                          # Simulate temperature sensor approaching the target
290                          for ts in ENTITY_STATES:
291                              if ts["entity_id"] == "sensor.temperature":
292                                  ts["state"] = str(body["temperature"] - 0.5)
293                                  break
294                      affected.append({
295                          "entity_id": entity_id,
296                          "state": s["state"],
297                          "attributes": s.get("attributes", {}),
298                      })
299                      break
300  
301          return web.json_response(affected)