server.py
  1  """
  2  Basic Test LLM Server mimicking an OpenAI-compatible API endpoint.
  3  Provides configurable static responses and captures incoming requests for verification.
  4  """

from fastapi import FastAPI, Request, HTTPException
from starlette.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Union, Literal, AsyncGenerator
import uvicorn
import json
import threading
import time
import asyncio
import logging
import os
import re
import base64
import binascii


class ToolCallFunction(BaseModel):
    name: str
    arguments: str


class ToolCall(BaseModel):
    id: str
    type: Literal["function"] = "function"
    function: ToolCallFunction


class Message(BaseModel):
    role: str
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None


class ToolCallDeltaFunction(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ToolCallDelta(BaseModel):
    index: int
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    function: Optional[ToolCallDeltaFunction] = None


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCallDelta]] = None


class StreamingChoice(BaseModel):
    index: int = 0
    delta: DeltaMessage
    finish_reason: Optional[str] = None


class ChatCompletionChunk(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-test-stream-{int(time.time())}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamingChoice]


class Choice(BaseModel):
    index: int = 0
    message: Message
    finish_reason: Optional[str] = "stop"


class Usage(BaseModel):
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ChatCompletionResponse(BaseModel):
    id: str = "chatcmpl-test"
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "test-llm-model"
    choices: List[Choice]
    usage: Optional[Usage] = Field(default_factory=Usage)


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    stream: Optional[bool] = False


app = FastAPI()


class TestLLMServer:
    DEFAULT_RESPONSE_DELAY_SECONDS: float = 0.01

    def __init__(self, host: str = "127.0.0.1", port: int = 8088):
        self.host = host
        self.port = port
        self._server_thread: Optional[threading.Thread] = None
        self._static_response: Optional[ChatCompletionResponse] = None
        self._primed_responses: List[Union[ChatCompletionResponse, Dict[str, Any]]] = []
        self._primed_image_responses: List[Dict[str, Any]] = []
        self._primed_response_lock = threading.Lock()
        self.captured_requests: List[ChatCompletionRequest] = []
        self._app = app  # Keep a reference to the FastAPI app
        self._uvicorn_server: Optional[uvicorn.Server] = None  # To store the server instance
        self.response_delay_seconds: float = self.DEFAULT_RESPONSE_DELAY_SECONDS
        self._setup_logger()
        self._setup_routes()
        self._stateful_responses_cache: Dict[str, List[Any]] = {}
        self._stateful_cache_lock = threading.Lock()

    def _setup_logger(self):
        """Sets up a dedicated logger for the TestLLMServer."""
        self.logger = logging.getLogger("TestLLMServer")
        self.logger.setLevel(logging.DEBUG)

        self.logger.propagate = False

        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)

        log_file_path = os.path.join(os.getcwd(), "test_llm_server.log")
        file_handler = logging.FileHandler(log_file_path, mode="a")
        file_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        self.logger.addHandler(file_handler)
        self.logger.info(
            f"TestLLMServer logger initialized. Logging to: {log_file_path}"
        )

    @property
    def started(self) -> bool:
        """Checks if the uvicorn server instance is started."""
        return self._uvicorn_server is not None and self._uvicorn_server.started

    def _setup_routes(self):
        @self._app.post("/v1/images/generations")
        async def image_generations(request: Request):
            await asyncio.sleep(0.01)

            response_data = None
            with self._primed_response_lock:
                if self._primed_image_responses:
                    response_data = self._primed_image_responses.pop(0)

            if response_data:
                status_code = response_data.get("status_code", 200)
                response_json_str = response_data.get("response", "{}")
                # Honor the primed status code instead of always returning 200.
                return JSONResponse(
                    content=json.loads(response_json_str), status_code=status_code
                )
            else:
                raise HTTPException(status_code=404, detail="No primed image response")

        @self._app.post("/v1/chat/completions")
        async def chat_completions(
            request: ChatCompletionRequest, raw_request: Request
        ):
            raw_body_bytes = await raw_request.body()
            raw_body_str = raw_body_bytes.decode("utf-8")
            self.logger.debug(f"Received raw request body:\n{raw_body_str}")
            self.logger.debug(
                f"Parsed ChatCompletionRequest model:\n{request.model_dump_json(indent=2)}"
            )

            if request.messages:
                for i, msg in enumerate(request.messages):
                    self.logger.debug(f"Message {i} - Role: {msg.role}")
                    self.logger.debug(
                        f"Message {i} - Content Type: {type(msg.content)}"
                    )
                    self.logger.debug(f"Message {i} - Content Value: {msg.content}")
                    if msg.tool_calls:
                        self.logger.debug(f"Message {i} - Tool Calls: {msg.tool_calls}")

            self.captured_requests.append(request.model_copy(deep=True))

            # Add a small delay to simulate network latency and force the event
            # loop to yield, ensuring true concurrency in stress tests.
            await asyncio.sleep(self.response_delay_seconds)

            initial_prompt = request.messages[0].content if request.messages else ""
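            # Stateful test cases are driven by directives embedded in the first
            # message content. An illustrative prompt (the case id and payload
            # shown here are made up):
            #   "[test_case_id=my-case-1][responses_json=<base64 of a JSON list>] ..."
            # Each element of the decoded list is a dict containing either an error
            # spec ({"status_code": ..., "json_body": {"error": ...}}) or a
            # "static_response" (plus an optional "expected_request" to verify).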
            if isinstance(initial_prompt, str):
                case_id_match = re.search(r"\[test_case_id=([\w.-]+)\]", initial_prompt)
                if case_id_match:
                    case_id = case_id_match.group(1)
                    self.logger.info(f"Stateful test case detected: {case_id}")

                    with self._stateful_cache_lock:
                        if case_id not in self._stateful_responses_cache:
                            self.logger.info(
                                f"Caching responses for new test case: {case_id}"
                            )
                            responses_match = re.search(
                                r"\[responses_json=([\w=+/]+)\]", initial_prompt
                            )
                            if responses_match:
                                b64_str = responses_match.group(1)
                                try:
                                    json_str = base64.b64decode(b64_str).decode("utf-8")
                                    self._stateful_responses_cache[case_id] = (
                                        json.loads(json_str)
                                    )
                                    self.logger.info(
                                        f"Cached {len(self._stateful_responses_cache[case_id])} responses for {case_id}"
                                    )
                                except (
                                    binascii.Error,
                                    json.JSONDecodeError,
                                    UnicodeDecodeError,
                                ) as e:
                                    self.logger.error(
                                        f"Failed to decode stateful responses for {case_id}: {e}"
                                    )
                                    raise HTTPException(
                                        status_code=500,
                                        detail=f"Stateful test case '{case_id}' has invalid [responses_json] directive.",
                                    )
                            else:
                                self.logger.error(
                                    f"No [responses_json] directive found for stateful test case: {case_id}"
                                )
                                raise HTTPException(
                                    status_code=500,
                                    detail=f"Stateful test case '{case_id}' found but no [responses_json] directive.",
                                )

                    # Infer the turn index from the conversation length: the first
                    # request carries a single message (turn 0) and each subsequent
                    # exchange appends two more.
                    turn_index = (len(request.messages) - 1) // 2
                    self.logger.info(
                        f"Request for turn {turn_index} of test case {case_id}"
                    )

                    with self._stateful_cache_lock:
                        if turn_index < len(self._stateful_responses_cache[case_id]):
                            response_spec = self._stateful_responses_cache[case_id][
                                turn_index
                            ]
                            self.logger.info(
                                f"Serving response for turn {turn_index} of test case {case_id}"
                            )
                        else:
                            self.logger.error(
                                f"Test case {case_id} ran out of responses. Requested turn {turn_index}, but only {len(self._stateful_responses_cache[case_id])} defined."
                            )
                            raise HTTPException(
                                status_code=500,
                                detail=f"Stateful test case '{case_id}' ran out of responses. Requested turn {turn_index}, but only {len(self._stateful_responses_cache[case_id])} are defined.",
                            )

                    if isinstance(response_spec, dict) and response_spec.get(
                        "status_code"
                    ):
                        status_code = response_spec["status_code"]
                        detail = response_spec.get("json_body", {}).get(
                            "error", "Test server error"
                        )
                        self.logger.info(
                            f"Simulating HTTP error with status code {status_code} and detail '{detail}'"
                        )
                        raise HTTPException(status_code=status_code, detail=detail)

                    if isinstance(response_spec, dict):
                        if "expected_request" in response_spec:
                            self._verify_expected_request(
                                request,
                                response_spec["expected_request"],
                                case_id,
                                turn_index,
                            )
                        response_to_serve = ChatCompletionResponse(
                            **response_spec.get("static_response", {})
                        )
                    else:
                        response_to_serve = response_spec

                    if request.stream:
                        self.logger.info(
                            f"Handling stream request for model {request.model}"
                        )
                        return StreamingResponse(
                            self._generate_stream_chunks(
                                response_to_serve, request.model
                            ),
                            media_type="text/event-stream",
                        )
                    else:
                        self.logger.info(
                            f"Serving non-streamed response for model {request.model}"
                        )
                        return response_to_serve

            response_spec = None
            with self._primed_response_lock:
                if self._primed_responses:
                    response_spec = self._primed_responses.pop(0)
                    self.logger.info(
                        f"Using primed response. {len(self._primed_responses)} remaining."
                    )
                elif self._static_response:
                    response_spec = self._static_response
                    self.logger.info("Using globally configured static response.")
                else:
                    self.logger.info("Using default response.")
                    default_message = Message(
                        role="assistant",
                        content="Default response from Test LLM Server (no specific response primed or configured)",
                    )
                    default_choice = Choice(
                        message=default_message, finish_reason="stop"
                    )
                    response_spec = ChatCompletionResponse(choices=[default_choice])

            if not response_spec:
                self.logger.error(
                    "No response configured and default failed to generate."
                )
                raise HTTPException(
                    status_code=500, detail="TestLLMServer: No response configured."
                )

            if isinstance(response_spec, dict) and response_spec.get("status_code"):
                status_code = response_spec["status_code"]
                detail = response_spec.get("json_body", {}).get(
                    "error", "Test server error"
                )
                self.logger.info(
                    f"Simulating HTTP error with status code {status_code} and detail '{detail}'"
                )
                raise HTTPException(status_code=status_code, detail=detail)

            if isinstance(response_spec, dict):
                response_to_serve = ChatCompletionResponse(**response_spec)
            else:
                response_to_serve = response_spec

            if request.stream:
                self.logger.info(f"Handling stream request for model {request.model}")
                return StreamingResponse(
                    self._generate_stream_chunks(response_to_serve, request.model),
                    media_type="text/event-stream",
                )
            else:
                self.logger.info(
                    f"Serving non-streamed response for model {request.model}"
                )
                return response_to_serve

    async def _generate_stream_chunks(
        self, full_response: ChatCompletionResponse, request_model: str
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronously generates SSE-formatted delta chunks from a full ChatCompletionResponse.
        """
        try:
            if (
                full_response.choices
                and full_response.choices[0].message.role == "assistant"
            ):
                role_chunk = ChatCompletionChunk(
                    model=request_model,
                    choices=[StreamingChoice(delta=DeltaMessage(role="assistant"))],
                )
                yield f"data: {role_chunk.model_dump_json()}\n\n"
                await asyncio.sleep(0.01)

            full_content = full_response.choices[0].message.content
            if isinstance(full_content, str) and full_content:
                # Split the content into a few deltas to exercise client-side
                # stream assembly.
                num_chunks = 3
                content_len = len(full_content)
                if content_len < num_chunks:
                    num_chunks = 1

                approx_chunk_size = (content_len + num_chunks - 1) // num_chunks

                for i in range(num_chunks):
                    start_idx = i * approx_chunk_size
                    end_idx = min((i + 1) * approx_chunk_size, content_len)
                    content_delta = full_content[start_idx:end_idx]

                    if content_delta:
                        content_chunk_obj = ChatCompletionChunk(
                            model=request_model,
                            choices=[
                                StreamingChoice(
                                    delta=DeltaMessage(content=content_delta)
                                )
                            ],
                        )
                        yield f"data: {content_chunk_obj.model_dump_json()}\n\n"
                        await asyncio.sleep(0.01)

            tool_calls_from_full_response = full_response.choices[0].message.tool_calls
            if tool_calls_from_full_response:
                for tc_idx, complete_tool_call in enumerate(
                    tool_calls_from_full_response
                ):
                    chunk1_delta = DeltaMessage(
                        tool_calls=[
                            ToolCallDelta(
                                index=tc_idx,
                                id=complete_tool_call.id,
                                type="function",
                                function=ToolCallDeltaFunction(
                                    name=complete_tool_call.function.name, arguments=""
                                ),
                            )
                        ]
                    )
                    chunk1_obj = ChatCompletionChunk(
                        model=request_model,
                        choices=[StreamingChoice(delta=chunk1_delta)],
                    )
                    yield f"data: {chunk1_obj.model_dump_json()}\n\n"
                    await asyncio.sleep(0.01)

                    chunk2_delta = DeltaMessage(
                        tool_calls=[
                            ToolCallDelta(
                                index=tc_idx,
                                id=complete_tool_call.id,
                                type="function",
                                function=ToolCallDeltaFunction(
                                    arguments=complete_tool_call.function.arguments
                                ),
                            )
                        ]
                    )
                    chunk2_obj = ChatCompletionChunk(
                        model=request_model,
                        choices=[StreamingChoice(delta=chunk2_delta)],
                    )
                    yield f"data: {chunk2_obj.model_dump_json()}\n\n"
                    await asyncio.sleep(0.01)

            finish_reason = full_response.choices[0].finish_reason
            final_delta_message = DeltaMessage()

            if finish_reason:
                final_choice = StreamingChoice(
                    delta=final_delta_message, finish_reason=finish_reason
                )
                final_chunk_dict = ChatCompletionChunk(
                    model=request_model, choices=[final_choice]
                ).model_dump(exclude_none=True)

                if full_response.usage:
                    final_chunk_dict["usage"] = full_response.usage.model_dump()
                    self.logger.info(
                        f"Adding usage data to final stream chunk: {final_chunk_dict['usage']}"
                    )

                yield f"data: {json.dumps(final_chunk_dict)}\n\n"
                await asyncio.sleep(0.01)

        except Exception as e:
            self.logger.error(f"Error during stream generation: {e}", exc_info=True)
            error_payload = {
                "error": {
                    "message": f"Stream generation error: {str(e)}",
                    "type": "server_error",
                    "code": 500,
                }
            }
            yield f"data: {json.dumps(error_payload)}\n\n"
        finally:
            yield "data: [DONE]\n\n"
            self.logger.info("Stream finished, sent [DONE].")

    def _verify_tool_declarations(
        self,
        actual_tools: List[Dict],
        expected_declarations: List[Dict],
        case_id: str,
        turn_index: int,
    ):
        """Verifies that the tool declarations sent to the LLM match expectations."""
        actual_tool_map = {
            tool.get("function", {}).get("name"): tool.get("function", {})
            for tool in actual_tools
        }

        for expected_decl in expected_declarations:
            expected_name = expected_decl.get("name")
            if not expected_name:
                raise HTTPException(
                    status_code=500,
                    detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                    f"expected_tool_declarations_contain item is missing 'name'.",
                )

            if expected_name not in actual_tool_map:
                raise HTTPException(
                    status_code=500,
                    detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                    f"Expected tool '{expected_name}' was not declared to the LLM. "
                    f"Actual tools: {list(actual_tool_map.keys())}",
                )

            actual_decl = actual_tool_map[expected_name]
            if "description_contains" in expected_decl:
                expected_desc_substr = expected_decl["description_contains"]
                actual_desc = actual_decl.get("description", "")
                if expected_desc_substr not in actual_desc:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                        f"Description for tool '{expected_name}' did not match. "
                        f"Expected to contain: '{expected_desc_substr}'. "
                        f"Actual: '{actual_desc}'",
                    )

    def _verify_tool_responses(
        self,
        actual_messages: List[Message],
        expected_responses: List[Dict],
        case_id: str,
        turn_index: int,
    ):
        """Verifies that tool responses in the LLM history match expectations."""
        tool_messages = [
            msg for msg in actual_messages if msg.role == "tool" and msg.tool_call_id
        ]

        if len(tool_messages) != len(expected_responses):
            raise HTTPException(
                status_code=500,
                detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                f"Mismatch in number of tool responses. "
                f"Expected {len(expected_responses)}, Got {len(tool_messages)}.",
            )

        # Find the previous request to match tool_call_ids.
        # The current request is the last one in captured_requests.
        # The one that *made* the tool call is the one before that.
        if len(self.captured_requests) < 2:
            raise HTTPException(
                status_code=500,
                detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                f"Cannot verify tool responses, not enough request history captured.",
            )
        prior_request = self.captured_requests[-2]
        prior_tool_calls = (
            prior_request.messages[-1].tool_calls
            if prior_request.messages and prior_request.messages[-1].tool_calls
            else []
        )

        for expected_resp in expected_responses:
            tool_call_id_to_match = None
            prior_request_index = expected_resp.get(
                "tool_call_id_matches_prior_request_index"
            )
            if prior_request_index is not None:
                if prior_request_index < len(prior_tool_calls):
                    tool_call_id_to_match = prior_tool_calls[prior_request_index].id
                else:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                        f"Invalid tool_call_id_matches_prior_request_index: {prior_request_index}. "
                        f"Prior request only had {len(prior_tool_calls)} tool calls.",
                    )

            if not tool_call_id_to_match:
                raise HTTPException(
                    status_code=500,
                    detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                    f"Could not determine tool_call_id for expected response: {expected_resp}",
                )

            actual_tool_msg = next(
                (
                    msg
                    for msg in tool_messages
                    if msg.tool_call_id == tool_call_id_to_match
                ),
                None,
            )

            if not actual_tool_msg:
                raise HTTPException(
                    status_code=500,
                    detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                    f"No tool response found for tool_call_id '{tool_call_id_to_match}'.",
                )

            if "response_json_matches" in expected_resp:
                expected_json = expected_resp["response_json_matches"]
                try:
                    actual_json = json.loads(actual_tool_msg.content)
                    if actual_json != expected_json:
                        raise HTTPException(
                            status_code=500,
                            detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                            f"JSON content for tool '{tool_call_id_to_match}' did not match.\n"
                            f"Expected: {json.dumps(expected_json)}\n"
                            f"Actual:   {json.dumps(actual_json)}",
                        )
                except json.JSONDecodeError:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                        f"Tool response for '{tool_call_id_to_match}' was not valid JSON. "
                        f"Content: {actual_tool_msg.content}",
                    )

            if "response_contains" in expected_resp:
                expected_substr = expected_resp["response_contains"]
                if expected_substr not in str(actual_tool_msg.content):
                    raise HTTPException(
                        status_code=500,
                        detail=f"Stateful test case '{case_id}' turn {turn_index}: "
                        f"Content for tool '{tool_call_id_to_match}' did not contain expected substring.\n"
                        f"Expected to contain: '{expected_substr}'\n"
                        f"Actual:              '{actual_tool_msg.content}'",
                    )

    def _verify_expected_request(
        self,
        request: ChatCompletionRequest,
        expected_request_spec: Dict,
        case_id: str,
        turn_index: int,
    ):
        """Dispatches verification checks based on keys in the expected_request spec."""
        if "expected_tool_declarations_contain" in expected_request_spec:
            self._verify_tool_declarations(
                request.tools or [],
                expected_request_spec["expected_tool_declarations_contain"],
                case_id,
                turn_index,
            )
        if "expected_tool_responses_in_llm_messages" in expected_request_spec:
            self._verify_tool_responses(
                request.messages,
                expected_request_spec["expected_tool_responses_in_llm_messages"],
                case_id,
                turn_index,
            )

    def configure_static_response(
        self, response: Union[Dict[str, Any], ChatCompletionResponse]
    ):
        """
        Configures a single static response that the server will return if no
        dynamically primed responses are available.
        Accepts either a dictionary (which will be parsed into a ChatCompletionResponse)
        or a ChatCompletionResponse object directly.
        """
        if isinstance(response, dict):
            self._static_response = ChatCompletionResponse(**response)
        elif isinstance(response, ChatCompletionResponse):
            self._static_response = response
        else:
            raise TypeError(
                "Static response must be a dict or ChatCompletionResponse object."
            )
        self.logger.info("Global static response configured.")

    def prime_responses(
        self, responses: List[Union[Dict[str, Any], ChatCompletionResponse]]
    ):
        """
        Primes the server with a sequence of responses to serve for subsequent requests.
        Each call to this method overwrites any previously primed responses.
        """
        with self._primed_response_lock:
            self._primed_responses = []
            for rsp in responses:
                if isinstance(rsp, dict):
                    if rsp.get("status_code"):
                        self._primed_responses.append(rsp)
                    else:
                        self._primed_responses.append(ChatCompletionResponse(**rsp))
                elif isinstance(rsp, ChatCompletionResponse):
                    self._primed_responses.append(rsp)
                else:
                    raise TypeError(
                        "Each response in the list must be a dict or ChatCompletionResponse object."
                    )
            self.logger.info(f"Primed with {len(self._primed_responses)} responses.")

    def prime_image_generation_responses(self, responses: List[Dict[str, Any]]):
        """Primes responses for the /v1/images/generations endpoint. Each dict may
        provide a 'status_code' and a 'response' JSON string body."""
        with self._primed_response_lock:
            self._primed_image_responses = responses
            self.logger.info(
                f"Primed with {len(self._primed_image_responses)} image generation responses."
            )

    def set_response_delay(self, seconds: float):
        """Sets a delay for all responses from the chat_completions endpoint."""
        self.response_delay_seconds = seconds
        self.logger.info(f"LLM server response delay set to {seconds} seconds.")

    def clear_all_configurations(self):
        """Clears primed responses, the static response, captured requests, and the
        stateful cache, and resets the response delay."""
        with self._primed_response_lock:
            self._primed_responses = []
            self._primed_image_responses = []
        self._static_response = None
        self.captured_requests = []
        with self._stateful_cache_lock:
            self._stateful_responses_cache.clear()
        self.response_delay_seconds = self.DEFAULT_RESPONSE_DELAY_SECONDS
        self.logger.info(
            "All configurations (primed, static, captured requests, response delay) cleared."
        )

    def clear_stateful_cache_for_id(self, case_id: str):
        """Removes a specific test case ID from the stateful response cache."""
        with self._stateful_cache_lock:
            if case_id in self._stateful_responses_cache:
                del self._stateful_responses_cache[case_id]
                self.logger.info(f"Cleared stateful cache for test case ID: {case_id}")

    def get_captured_requests(self) -> List[ChatCompletionRequest]:
        """Returns all requests captured by the chat_completions endpoint."""
        return self.captured_requests

    def clear_captured_requests(self):
        """Clears the list of captured requests."""
        self.captured_requests = []

    def start(self):
        """Starts the FastAPI server in a separate thread."""
        if self._server_thread is not None and self._server_thread.is_alive():
            self.logger.warning("TestLLMServer is already running.")
            return

        self.clear_all_configurations()

        config = uvicorn.Config(
            self._app, host=self.host, port=self.port, log_level="warning"
        )
        self._uvicorn_server = uvicorn.Server(config)

        async def async_serve_wrapper():
            """Coroutine to run the server's serve() method and handle potential errors."""
            try:
                if self._uvicorn_server:
                    await self._uvicorn_server.serve()
            except asyncio.CancelledError:
                self.logger.info("Server.serve() task was cancelled.")
            except Exception as e:
                self.logger.error(f"Error during server.serve(): {e}", exc_info=True)

        def run_server_in_new_loop():
            """Target function for the server thread. Sets up and runs an event loop."""
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(async_serve_wrapper())
            except KeyboardInterrupt:
                self.logger.info("KeyboardInterrupt in server thread.")
            finally:
                try:
                    all_tasks = asyncio.all_tasks(loop)
                    if all_tasks:
                        for task in all_tasks:
                            task.cancel()
                        loop.run_until_complete(
                            asyncio.gather(*all_tasks, return_exceptions=True)
                        )

                    if hasattr(loop, "shutdown_asyncgens"):
                        loop.run_until_complete(loop.shutdown_asyncgens())
                except Exception as e:
                    self.logger.error(
                        f"Error during loop shutdown tasks: {e}", exc_info=True
                    )
                finally:
                    loop.close()
                    self.logger.info("Event loop in server thread closed.")

        self._server_thread = threading.Thread(
            target=run_server_in_new_loop, daemon=True
        )
        self._server_thread.start()

        self.logger.info(f"TestLLMServer starting on http://{self.host}:{self.port}...")

    def stop(self):
        """Stops the FastAPI server."""
        if self._uvicorn_server:
            self._uvicorn_server.should_exit = True

        if self._server_thread and self._server_thread.is_alive():
            self.logger.info("TestLLMServer stopping, joining thread...")
            self._server_thread.join(timeout=5.0)
            if self._server_thread.is_alive():
                self.logger.warning("Server thread did not exit cleanly.")
        self._server_thread = None
        self._uvicorn_server = None
        self.logger.info("TestLLMServer stopped.")

    @property
    def url(self) -> str:
        """Base URL of the running server."""
        return f"http://{self.host}:{self.port}"


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    server = TestLLMServer()
    server.start()

    sample_response_data = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "Hello from the Test LLM!",
                },
                "finish_reason": "stop",
            }
        ]
    }
    server.configure_static_response(sample_response_data)
    server.logger.info(
        f"Test LLM Server running at {server.url}. Configured with a static response."
    )
    server.logger.info(
        'Try: curl -X POST -H "Content-Type: application/json" -d \'{"model": "test", "messages": [{"role": "user", "content": "Hi"}]}\' http://127.0.0.1:8088/v1/chat/completions'
    )

    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        server.logger.info("Shutting down Test LLM Server...")
    finally:
        server.stop()