debug.py
 1  import json
 2  from typing import Awaitable, Callable, Type
 3  
 4  import pydantic
 5  from fastapi import Request
 6  from fastapi.responses import StreamingResponse
 7  from fastapi.routing import APIRoute
 8  from pydantic import BaseModel
 9  from starlette.concurrency import iterate_in_threadpool
10  
11  from api.logger import get_logger
12  
13  
# Module-level logger used by the response-schema checks below to report mismatches.
logger = get_logger(__name__)
15  
16  
def _check_response_schema(method: str, route: APIRoute, status_code: int, body: bytes) -> None:
    """Log an error when a response does not match what its route documents.

    This is a debugging aid: every problem is reported via ``logger.error`` and
    nothing is raised, so a malformed body never turns into a failed request.

    The documented response for ``status_code`` is looked up in
    ``route.responses``:

    * if it declares a ``"model"``, the body is parsed against that model;
    * otherwise the decoded JSON body must equal one of the documented
      ``application/json`` examples (``examples[*].value`` or ``example``).

    Args:
        method: HTTP method of the request (used only in log messages).
        route: The route that handled the request.
        status_code: Status code of the outgoing response.
        body: Raw response body, expected to be JSON.
    """
    # 405 and 422 are not expected to be documented per route, so they are skipped.
    if status_code in (405, 422):
        return
    if not route.include_in_schema:
        return

    # FastAPI accepts both int and str keys in `responses` (e.g. 404 or "404").
    response = route.responses.get(status_code)
    if response is None:
        response = route.responses.get(str(status_code))
    if response is None:
        logger.error(f"[{method} {route.path}] no response schema defined for status code {status_code}")
        return

    if "model" in response:
        response_schema: Type[BaseModel] = response["model"]
        try:
            pydantic.parse_raw_as(response_schema, body)
        except Exception as e:
            logger.error(f"[{method} {route.path}] response schema validation failed ({status_code}):\n{e}")
        return

    # No model: the body must match one of the documented JSON examples.
    # An empty or non-JSON body is a validation failure, not a crash.
    try:
        payload = json.loads(body)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
        logger.error(f"[{method} {route.path}] response schema validation failed ({status_code}):\n{e}")
        return

    json_content = response.get("content", {}).get("application/json", {})
    documented = [v.get("value") for v in json_content.get("examples", {}).values()]
    # OpenAPI also allows a single `example` instead of the `examples` map.
    if "example" in json_content:
        documented.append(json_content["example"])

    if payload not in documented:
        logger.error(f"[{method} {route.path}] response schema validation failed ({status_code})")
38  
39  
async def check_responses(
    request: Request, call_next: Callable[..., Awaitable[StreamingResponse]]
) -> StreamingResponse:
    """HTTP middleware that checks JSON responses against the route's documented responses.

    The downstream response body is buffered so it can be inspected. The body is
    then handed back to the response unchanged. Problems are only logged; the
    response is always returned as the application produced it.

    Args:
        request: The incoming request.
        call_next: The next handler in the middleware chain.

    Returns:
        The downstream response, with its body iterator replaced by one that
        yields the same buffered chunks.
    """
    response: StreamingResponse = await call_next(request)

    # Compare only the media type, so parameters such as "; charset=utf-8"
    # do not cause JSON responses to be skipped.
    content_type = response.headers.get("content-type", "")
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        return response

    chunks = [chunk async for chunk in response.body_iterator]
    body = b"".join(chunks)

    # The original iterator was consumed above; replace it with one over the
    # same chunks so the client still receives the full body.
    response.body_iterator = iterate_in_threadpool(iter(chunks))

    if route := request.scope.get("route"):
        try:
            _check_response_schema(request.method, route, response.status_code, body)
        except Exception as e:
            # A diagnostic check must never turn a valid response into a 500.
            logger.error(
                f"[{request.method} {request.url.path}] response schema check crashed "
                f"({response.status_code}): {e!r}"
            )

    return response
55      return response