server.py
1 """ 2 Basic Test LLM Server mimicking an OpenAI-compatible API endpoint. 3 Provides configurable static responses and captures incoming requests for verification. 4 """ 5 6 from fastapi import FastAPI, Request, HTTPException 7 from starlette.responses import StreamingResponse 8 from pydantic import BaseModel, Field 9 from typing import List, Dict, Any, Optional, Union, Literal, AsyncGenerator 10 import uvicorn 11 import json 12 import threading 13 import time 14 import asyncio 15 import logging 16 import os 17 import re 18 import base64 19 20 21 class ToolCallFunction(BaseModel): 22 name: str 23 arguments: str 24 25 26 class ToolCall(BaseModel): 27 id: str 28 type: Literal["function"] = "function" 29 function: ToolCallFunction 30 31 32 class Message(BaseModel): 33 role: str 34 content: Optional[Union[str, List[Dict[str, Any]]]] = None 35 tool_calls: Optional[List[ToolCall]] = None 36 tool_call_id: Optional[str] = None 37 38 39 class ToolCallDeltaFunction(BaseModel): 40 name: Optional[str] = None 41 arguments: Optional[str] = None 42 43 44 class ToolCallDelta(BaseModel): 45 index: int 46 id: Optional[str] = None 47 type: Optional[Literal["function"]] = None 48 function: Optional[ToolCallDeltaFunction] = None 49 50 51 class DeltaMessage(BaseModel): 52 role: Optional[str] = None 53 content: Optional[str] = None 54 tool_calls: Optional[List[ToolCallDelta]] = None 55 56 57 class StreamingChoice(BaseModel): 58 index: int = 0 59 delta: DeltaMessage 60 finish_reason: Optional[str] = None 61 62 63 class ChatCompletionChunk(BaseModel): 64 id: str = Field(default_factory=lambda: f"chatcmpl-test-stream-{int(time.time())}") 65 object: str = "chat.completion.chunk" 66 created: int = Field(default_factory=lambda: int(time.time())) 67 model: str 68 choices: List[StreamingChoice] 69 70 71 class Choice(BaseModel): 72 index: int = 0 73 message: Message 74 finish_reason: Optional[str] = "stop" 75 76 77 class Usage(BaseModel): 78 prompt_tokens: int = 0 79 completion_tokens: int = 0 80 total_tokens: int = 0 81 82 83 class ChatCompletionResponse(BaseModel): 84 id: str = "chatcmpl-test" 85 object: str = "chat.completion" 86 created: int = Field(default_factory=lambda: int(time.time())) 87 model: str = "test-llm-model" 88 choices: List[Choice] 89 usage: Optional[Usage] = Field(default_factory=Usage) 90 91 92 class ChatCompletionRequest(BaseModel): 93 model: str 94 messages: List[Message] 95 tools: Optional[List[Dict[str, Any]]] = None 96 tool_choice: Optional[Union[str, Dict[str, Any]]] = None 97 stream: Optional[bool] = False 98 99 100 app = FastAPI() 101 102 103 class TestLLMServer: 104 DEFAULT_RESPONSE_DELAY_SECONDS: float = 0.01 105 106 def __init__(self, host: str = "127.0.0.1", port: int = 8088): 107 self.host = host 108 self.port = port 109 self._server_thread: Optional[threading.Thread] = None 110 self._static_response: Optional[ChatCompletionResponse] = None 111 self._primed_responses: List[ChatCompletionResponse] = [] 112 self._primed_image_responses: List[Dict[str, Any]] = [] 113 self._primed_response_lock = threading.Lock() 114 self.captured_requests: List[ChatCompletionRequest] = [] 115 self._app = app # Keep a reference to the FastAPI app 116 self._uvicorn_server: Optional[uvicorn.Server] = None # To store the server instance 117 self.response_delay_seconds: float = self.DEFAULT_RESPONSE_DELAY_SECONDS 118 self._setup_logger() 119 self._setup_routes() 120 self._stateful_responses_cache: Dict[str, List[Any]] = {} 121 self._stateful_cache_lock = threading.Lock() 122 123 def _setup_logger(self): 124 """Sets up a dedicated logger for the TestLLMServer.""" 125 self.logger = logging.getLogger("TestLLMServer") 126 self.logger.setLevel(logging.DEBUG) 127 128 self.logger.propagate = False 129 130 for handler in self.logger.handlers[:]: 131 self.logger.removeHandler(handler) 132 133 log_file_path = os.path.join(os.getcwd(), "test_llm_server.log") 134 file_handler = logging.FileHandler(log_file_path, mode="a") 135 file_handler.setFormatter( 136 logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 137 ) 138 self.logger.addHandler(file_handler) 139 self.logger.info( 140 f"TestLLMServer logger initialized. Logging to: {log_file_path}" 141 ) 142 143 @property 144 def started(self) -> bool: 145 """Checks if the uvicorn server instance is started.""" 146 return self._uvicorn_server is not None and self._uvicorn_server.started 147 148 def _setup_routes(self): 149 @self._app.post("/v1/images/generations") 150 async def image_generations(request: Request): 151 await asyncio.sleep(0.01) 152 153 response_data = None 154 with self._primed_response_lock: 155 if self._primed_image_responses: 156 response_data = self._primed_image_responses.pop(0) 157 158 if response_data: 159 status_code = response_data.get("status_code", 200) 160 response_json_str = response_data.get("response", "{}") 161 return json.loads(response_json_str) 162 else: 163 raise HTTPException(status_code=404, detail="No primed image response") 164 165 @self._app.post("/v1/chat/completions") 166 async def chat_completions( 167 request: ChatCompletionRequest, raw_request: Request 168 ): 169 raw_body_bytes = await raw_request.body() 170 raw_body_str = raw_body_bytes.decode("utf-8") 171 self.logger.debug(f"Received raw request body:\n{raw_body_str}") 172 self.logger.debug( 173 f"Parsed ChatCompletionRequest model:\n{request.model_dump_json(indent=2)}" 174 ) 175 176 if request.messages: 177 for i, msg in enumerate(request.messages): 178 self.logger.debug(f"Message {i} - Role: {msg.role}") 179 self.logger.debug( 180 f"Message {i} - Content Type: {type(msg.content)}" 181 ) 182 self.logger.debug(f"Message {i} - Content Value: {msg.content}") 183 if msg.tool_calls: 184 self.logger.debug(f"Message {i} - Tool Calls: {msg.tool_calls}") 185 186 self.captured_requests.append(request.model_copy(deep=True)) 187 188 # Add a small delay to simulate network latency and force the event 189 # loop to yield, ensuring true concurrency in stress tests. 190 await asyncio.sleep(self.response_delay_seconds) 191 192 initial_prompt = request.messages[0].content if request.messages else "" 193 if isinstance(initial_prompt, str): 194 case_id_match = re.search(r"\[test_case_id=([\w.-]+)\]", initial_prompt) 195 if case_id_match: 196 case_id = case_id_match.group(1) 197 self.logger.info(f"Stateful test case detected: {case_id}") 198 199 with self._stateful_cache_lock: 200 if case_id not in self._stateful_responses_cache: 201 self.logger.info( 202 f"Caching responses for new test case: {case_id}" 203 ) 204 responses_match = re.search( 205 r"\[responses_json=([\w=+/]+)\]", initial_prompt 206 ) 207 if responses_match: 208 b64_str = responses_match.group(1) 209 try: 210 json_str = base64.b64decode(b64_str).decode("utf-8") 211 self._stateful_responses_cache[case_id] = ( 212 json.loads(json_str) 213 ) 214 self.logger.info( 215 f"Cached {len(self._stateful_responses_cache[case_id])} responses for {case_id}" 216 ) 217 except ( 218 base64.binascii.Error, 219 json.JSONDecodeError, 220 UnicodeDecodeError, 221 ) as e: 222 self.logger.error( 223 f"Failed to decode stateful responses for {case_id}: {e}" 224 ) 225 raise HTTPException( 226 status_code=500, 227 detail=f"Stateful test case '{case_id}' has invalid [responses_json] directive.", 228 ) 229 else: 230 self.logger.error( 231 f"No [responses_json] directive found for stateful test case: {case_id}" 232 ) 233 raise HTTPException( 234 status_code=500, 235 detail=f"Stateful test case '{case_id}' found but no [responses_json] directive.", 236 ) 237 238 turn_index = (len(request.messages) - 1) // 2 239 self.logger.info( 240 f"Request for turn {turn_index} of test case {case_id}" 241 ) 242 243 with self._stateful_cache_lock: 244 if turn_index < len(self._stateful_responses_cache[case_id]): 245 response_spec = self._stateful_responses_cache[case_id][ 246 turn_index 247 ] 248 self.logger.info( 249 f"Serving response for turn {turn_index} of test case {case_id}" 250 ) 251 else: 252 self.logger.error( 253 f"Test case {case_id} ran out of responses. Requested turn {turn_index}, but only {len(self._stateful_responses_cache[case_id])} defined." 254 ) 255 raise HTTPException( 256 status_code=500, 257 detail=f"Stateful test case '{case_id}' ran out of responses. Requested turn {turn_index}, but only {len(self._stateful_responses_cache[case_id])} are defined.", 258 ) 259 260 if isinstance(response_spec, dict) and response_spec.get( 261 "status_code" 262 ): 263 status_code = response_spec["status_code"] 264 detail = response_spec.get("json_body", {}).get( 265 "error", "Test server error" 266 ) 267 self.logger.info( 268 f"Simulating HTTP error with status code {status_code} and detail '{detail}'" 269 ) 270 raise HTTPException(status_code=status_code, detail=detail) 271 272 if isinstance(response_spec, dict): 273 if "expected_request" in response_spec: 274 self._verify_expected_request( 275 request, 276 response_spec["expected_request"], 277 case_id, 278 turn_index, 279 ) 280 response_to_serve = ChatCompletionResponse( 281 **response_spec.get("static_response", {}) 282 ) 283 else: 284 response_to_serve = response_spec 285 286 if request.stream: 287 self.logger.info( 288 f"Handling stream request for model {request.model}" 289 ) 290 return StreamingResponse( 291 self._generate_stream_chunks( 292 response_to_serve, request.model 293 ), 294 media_type="text/event-stream", 295 ) 296 else: 297 self.logger.info( 298 f"Serving non-streamed response for model {request.model}" 299 ) 300 return response_to_serve 301 302 response_spec = None 303 with self._primed_response_lock: 304 if self._primed_responses: 305 response_spec = self._primed_responses.pop(0) 306 self.logger.info( 307 f"Using primed response. {len(self._primed_responses)} remaining." 308 ) 309 elif self._static_response: 310 response_spec = self._static_response 311 self.logger.info("Using globally configured static response.") 312 else: 313 self.logger.info("Using default response.") 314 default_message = Message( 315 role="assistant", 316 content="Default response from Test LLM Server (no specific response primed or configured)", 317 ) 318 default_choice = Choice( 319 message=default_message, finish_reason="stop" 320 ) 321 response_spec = ChatCompletionResponse(choices=[default_choice]) 322 323 if not response_spec: 324 self.logger.error( 325 "No response configured and default failed to generate." 326 ) 327 raise HTTPException( 328 status_code=500, detail="TestLLMServer: No response configured." 329 ) 330 331 if isinstance(response_spec, dict) and response_spec.get("status_code"): 332 status_code = response_spec["status_code"] 333 detail = response_spec.get("json_body", {}).get( 334 "error", "Test server error" 335 ) 336 self.logger.info( 337 f"Simulating HTTP error with status code {status_code} and detail '{detail}'" 338 ) 339 raise HTTPException(status_code=status_code, detail=detail) 340 341 if isinstance(response_spec, dict): 342 response_to_serve = ChatCompletionResponse(**response_spec) 343 else: 344 response_to_serve = response_spec 345 346 if request.stream: 347 self.logger.info(f"Handling stream request for model {request.model}") 348 return StreamingResponse( 349 self._generate_stream_chunks(response_to_serve, request.model), 350 media_type="text/event-stream", 351 ) 352 else: 353 self.logger.info( 354 f"Serving non-streamed response for model {request.model}" 355 ) 356 return response_to_serve 357 358 async def _generate_stream_chunks( 359 self, full_response: ChatCompletionResponse, request_model: str 360 ) -> AsyncGenerator[str, None]: 361 """ 362 Asynchronously generates SSE formatted delta chunks from a full ChatCompletionResponse. 363 """ 364 try: 365 if ( 366 full_response.choices 367 and full_response.choices[0].message.role == "assistant" 368 ): 369 role_chunk = ChatCompletionChunk( 370 model=request_model, 371 choices=[StreamingChoice(delta=DeltaMessage(role="assistant"))], 372 ) 373 yield f"data: {role_chunk.model_dump_json()}\n\n" 374 await asyncio.sleep(0.01) 375 376 full_content = full_response.choices[0].message.content 377 if isinstance(full_content, str) and full_content: 378 num_chunks = 3 379 content_len = len(full_content) 380 if content_len == 0: 381 pass 382 elif content_len < num_chunks: 383 num_chunks = 1 384 385 approx_chunk_size = (content_len + num_chunks - 1) // num_chunks 386 387 for i in range(num_chunks): 388 start_idx = i * approx_chunk_size 389 end_idx = min((i + 1) * approx_chunk_size, content_len) 390 content_delta = full_content[start_idx:end_idx] 391 392 if content_delta: 393 content_chunk_obj = ChatCompletionChunk( 394 model=request_model, 395 choices=[ 396 StreamingChoice( 397 delta=DeltaMessage(content=content_delta) 398 ) 399 ], 400 ) 401 yield f"data: {content_chunk_obj.model_dump_json()}\n\n" 402 await asyncio.sleep(0.01) 403 404 tool_calls_from_full_response = full_response.choices[0].message.tool_calls 405 if tool_calls_from_full_response: 406 for tc_idx, complete_tool_call in enumerate( 407 tool_calls_from_full_response 408 ): 409 chunk1_delta = DeltaMessage( 410 tool_calls=[ 411 ToolCallDelta( 412 index=tc_idx, 413 id=complete_tool_call.id, 414 type="function", 415 function=ToolCallDeltaFunction( 416 name=complete_tool_call.function.name, arguments="" 417 ), 418 ) 419 ] 420 ) 421 chunk1_obj = ChatCompletionChunk( 422 model=request_model, 423 choices=[StreamingChoice(delta=chunk1_delta)], 424 ) 425 yield f"data: {chunk1_obj.model_dump_json()}\n\n" 426 await asyncio.sleep(0.01) 427 428 chunk2_delta = DeltaMessage( 429 tool_calls=[ 430 ToolCallDelta( 431 index=tc_idx, 432 id=complete_tool_call.id, 433 type="function", 434 function=ToolCallDeltaFunction( 435 arguments=complete_tool_call.function.arguments 436 ), 437 ) 438 ] 439 ) 440 chunk2_obj = ChatCompletionChunk( 441 model=request_model, 442 choices=[StreamingChoice(delta=chunk2_delta)], 443 ) 444 yield f"data: {chunk2_obj.model_dump_json()}\n\n" 445 await asyncio.sleep(0.01) 446 447 finish_reason = full_response.choices[0].finish_reason 448 final_delta_message = DeltaMessage() 449 450 if finish_reason: 451 final_choice = StreamingChoice( 452 delta=final_delta_message, finish_reason=finish_reason 453 ) 454 final_chunk_dict = ChatCompletionChunk( 455 model=request_model, choices=[final_choice] 456 ).model_dump(exclude_none=True) 457 458 if full_response.usage: 459 final_chunk_dict["usage"] = full_response.usage.model_dump() 460 self.logger.info( 461 f"Adding usage data to final stream chunk: {final_chunk_dict['usage']}" 462 ) 463 464 yield f"data: {json.dumps(final_chunk_dict)}\n\n" 465 await asyncio.sleep(0.01) 466 467 except Exception as e: 468 self.logger.error(f"Error during stream generation: {e}", exc_info=True) 469 error_payload = { 470 "error": { 471 "message": f"Stream generation error: {str(e)}", 472 "type": "server_error", 473 "code": 500, 474 } 475 } 476 yield f"data: {json.dumps(error_payload)}\n\n" 477 finally: 478 yield "data: [DONE]\n\n" 479 self.logger.info("Stream finished, sent [DONE].") 480 481 def _verify_tool_declarations( 482 self, 483 actual_tools: List[Dict], 484 expected_declarations: List[Dict], 485 case_id: str, 486 turn_index: int, 487 ): 488 """Verifies that the tool declarations sent to the LLM match expectations.""" 489 actual_tool_map = { 490 tool.get("function", {}).get("name"): tool.get("function", {}) 491 for tool in actual_tools 492 } 493 494 for expected_decl in expected_declarations: 495 expected_name = expected_decl.get("name") 496 if not expected_name: 497 raise HTTPException( 498 status_code=500, 499 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 500 f"expected_tool_declarations_contain item is missing 'name'.", 501 ) 502 503 if expected_name not in actual_tool_map: 504 raise HTTPException( 505 status_code=500, 506 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 507 f"Expected tool '{expected_name}' was not declared to the LLM. " 508 f"Actual tools: {list(actual_tool_map.keys())}", 509 ) 510 511 actual_decl = actual_tool_map[expected_name] 512 if "description_contains" in expected_decl: 513 expected_desc_substr = expected_decl["description_contains"] 514 actual_desc = actual_decl.get("description", "") 515 if expected_desc_substr not in actual_desc: 516 raise HTTPException( 517 status_code=500, 518 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 519 f"Description for tool '{expected_name}' did not match. " 520 f"Expected to contain: '{expected_desc_substr}'. " 521 f"Actual: '{actual_desc}'", 522 ) 523 524 def _verify_tool_responses( 525 self, 526 actual_messages: List[Message], 527 expected_responses: List[Dict], 528 case_id: str, 529 turn_index: int, 530 ): 531 """Verifies that tool responses in the LLM history match expectations.""" 532 tool_messages = [ 533 msg for msg in actual_messages if msg.role == "tool" and msg.tool_call_id 534 ] 535 536 if len(tool_messages) != len(expected_responses): 537 raise HTTPException( 538 status_code=500, 539 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 540 f"Mismatch in number of tool responses. " 541 f"Expected {len(expected_responses)}, Got {len(tool_messages)}.", 542 ) 543 544 # Find the previous request to match tool_call_ids 545 # The current request is the last one in captured_requests. 546 # The one that *made* the tool call is the one before that. 547 if len(self.captured_requests) < 2: 548 raise HTTPException( 549 status_code=500, 550 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 551 f"Cannot verify tool responses, not enough request history captured.", 552 ) 553 prior_request = self.captured_requests[-2] 554 prior_tool_calls = ( 555 prior_request.messages[-1].tool_calls 556 if prior_request.messages and prior_request.messages[-1].tool_calls 557 else [] 558 ) 559 560 for expected_resp in expected_responses: 561 tool_call_id_to_match = None 562 prior_request_index = expected_resp.get( 563 "tool_call_id_matches_prior_request_index" 564 ) 565 if prior_request_index is not None: 566 if prior_request_index < len(prior_tool_calls): 567 tool_call_id_to_match = prior_tool_calls[prior_request_index].id 568 else: 569 raise HTTPException( 570 status_code=500, 571 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 572 f"Invalid tool_call_id_matches_prior_request_index: {prior_request_index}. " 573 f"Prior request only had {len(prior_tool_calls)} tool calls.", 574 ) 575 576 if not tool_call_id_to_match: 577 raise HTTPException( 578 status_code=500, 579 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 580 f"Could not determine tool_call_id for expected response: {expected_resp}", 581 ) 582 583 actual_tool_msg = next( 584 ( 585 msg 586 for msg in tool_messages 587 if msg.tool_call_id == tool_call_id_to_match 588 ), 589 None, 590 ) 591 592 if not actual_tool_msg: 593 raise HTTPException( 594 status_code=500, 595 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 596 f"No tool response found for tool_call_id '{tool_call_id_to_match}'.", 597 ) 598 599 if "response_json_matches" in expected_resp: 600 expected_json = expected_resp["response_json_matches"] 601 try: 602 actual_json = json.loads(actual_tool_msg.content) 603 if actual_json != expected_json: 604 raise HTTPException( 605 status_code=500, 606 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 607 f"JSON content for tool '{tool_call_id_to_match}' did not match.\n" 608 f"Expected: {json.dumps(expected_json)}\n" 609 f"Actual: {json.dumps(actual_json)}", 610 ) 611 except json.JSONDecodeError: 612 raise HTTPException( 613 status_code=500, 614 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 615 f"Tool response for '{tool_call_id_to_match}' was not valid JSON. " 616 f"Content: {actual_tool_msg.content}", 617 ) 618 619 if "response_contains" in expected_resp: 620 expected_substr = expected_resp["response_contains"] 621 if expected_substr not in str(actual_tool_msg.content): 622 raise HTTPException( 623 status_code=500, 624 detail=f"Stateful test case '{case_id}' turn {turn_index}: " 625 f"Content for tool '{tool_call_id_to_match}' did not contain expected substring.\n" 626 f"Expected to contain: '{expected_substr}'\n" 627 f"Actual: '{actual_tool_msg.content}'", 628 ) 629 630 def _verify_expected_request( 631 self, 632 request: ChatCompletionRequest, 633 expected_request_spec: Dict, 634 case_id: str, 635 turn_index: int, 636 ): 637 """Dispatches verification checks based on keys in the expected_request spec.""" 638 if "expected_tool_declarations_contain" in expected_request_spec: 639 self._verify_tool_declarations( 640 request.tools or [], 641 expected_request_spec["expected_tool_declarations_contain"], 642 case_id, 643 turn_index, 644 ) 645 if "expected_tool_responses_in_llm_messages" in expected_request_spec: 646 self._verify_tool_responses( 647 request.messages, 648 expected_request_spec["expected_tool_responses_in_llm_messages"], 649 case_id, 650 turn_index, 651 ) 652 653 def configure_static_response( 654 self, response: Union[Dict[str, Any], ChatCompletionResponse] 655 ): 656 """ 657 Configures a single static response that the server will return if no 658 dynamically primed responses are available. 659 Accepts either a dictionary (which will be parsed into ChatCompletionResponse) 660 or a ChatCompletionResponse object directly. 661 """ 662 if isinstance(response, dict): 663 self._static_response = ChatCompletionResponse(**response) 664 elif isinstance(response, ChatCompletionResponse): 665 self._static_response = response 666 else: 667 raise TypeError( 668 "Static response must be a dict or ChatCompletionResponse object." 669 ) 670 self.logger.info("Global static response configured.") 671 672 def prime_responses( 673 self, responses: List[Union[Dict[str, Any], ChatCompletionResponse]] 674 ): 675 """ 676 Primes the server with a sequence of responses to serve for subsequent requests. 677 Each call to this method overwrites any previously primed responses. 678 """ 679 with self._primed_response_lock: 680 self._primed_responses = [] 681 for rsp in responses: 682 if isinstance(rsp, dict): 683 if rsp.get("status_code"): 684 self._primed_responses.append(rsp) 685 else: 686 self._primed_responses.append(ChatCompletionResponse(**rsp)) 687 elif isinstance(rsp, ChatCompletionResponse): 688 self._primed_responses.append(rsp) 689 else: 690 raise TypeError( 691 "Each response in the list must be a dict or ChatCompletionResponse object." 692 ) 693 self.logger.info(f"Primed with {len(self._primed_responses)} responses.") 694 695 def prime_image_generation_responses(self, responses: List[Dict[str, Any]]): 696 with self._primed_response_lock: 697 self._primed_image_responses = responses 698 self.logger.info( 699 f"Primed with {len(self._primed_image_responses)} image generation responses." 700 ) 701 702 def set_response_delay(self, seconds: float): 703 """Sets a delay for all responses from the chat_completions endpoint.""" 704 self.response_delay_seconds = seconds 705 self.logger.info(f"LLM server response delay set to {seconds} seconds.") 706 707 def clear_all_configurations(self): 708 """Clears primed responses, the global static response, captured requests, and resets response delay.""" 709 with self._primed_response_lock: 710 self._primed_responses = [] 711 self._primed_image_responses = [] 712 self._static_response = None 713 self.captured_requests = [] 714 with self._stateful_cache_lock: 715 self._stateful_responses_cache.clear() 716 self.response_delay_seconds = self.DEFAULT_RESPONSE_DELAY_SECONDS 717 self.logger.info( 718 "All configurations (primed, static, captured requests, response delay) cleared." 719 ) 720 721 def clear_stateful_cache_for_id(self, case_id: str): 722 """Removes a specific test case ID from the stateful response cache.""" 723 with self._stateful_cache_lock: 724 if case_id in self._stateful_responses_cache: 725 del self._stateful_responses_cache[case_id] 726 self.logger.info(f"Cleared stateful cache for test case ID: {case_id}") 727 728 def get_captured_requests(self) -> List[ChatCompletionRequest]: 729 return self.captured_requests 730 731 def clear_captured_requests(self): 732 self.captured_requests = [] 733 734 def start(self): 735 """Starts the FastAPI server in a separate thread.""" 736 if self._server_thread is not None and self._server_thread.is_alive(): 737 self.logger.warning("TestLLMServer is already running.") 738 return 739 740 self.clear_all_configurations() 741 742 config = uvicorn.Config( 743 self._app, host=self.host, port=self.port, log_level="warning" 744 ) 745 self._uvicorn_server = uvicorn.Server(config) 746 747 async def async_serve_wrapper(): 748 """Coroutine to run the server's serve() method and handle potential errors.""" 749 try: 750 if self._uvicorn_server: 751 await self._uvicorn_server.serve() 752 except asyncio.CancelledError: 753 self.logger.info("Server.serve() task was cancelled.") 754 except Exception as e: 755 self.logger.error(f"Error during server.serve(): {e}", exc_info=True) 756 757 def run_server_in_new_loop(): 758 """Target function for the server thread. Sets up and runs an event loop.""" 759 loop = asyncio.new_event_loop() 760 asyncio.set_event_loop(loop) 761 try: 762 loop.run_until_complete(async_serve_wrapper()) 763 except KeyboardInterrupt: 764 print("TestLLMServer: KeyboardInterrupt in server thread.") 765 finally: 766 try: 767 all_tasks = asyncio.all_tasks(loop) 768 if all_tasks: 769 for task in all_tasks: 770 task.cancel() 771 loop.run_until_complete( 772 asyncio.gather(*all_tasks, return_exceptions=True) 773 ) 774 775 if hasattr(loop, "shutdown_asyncgens"): 776 loop.run_until_complete(loop.shutdown_asyncgens()) 777 except Exception as e: 778 self.logger.error( 779 f"Error during loop shutdown tasks: {e}", exc_info=True 780 ) 781 finally: 782 loop.close() 783 self.logger.info("Event loop in server thread closed.") 784 785 self._server_thread = threading.Thread( 786 target=run_server_in_new_loop, daemon=True 787 ) 788 self._server_thread.start() 789 790 self.logger.info(f"TestLLMServer starting on http://{self.host}:{self.port}...") 791 792 def stop(self): 793 """Stops the FastAPI server.""" 794 if self._uvicorn_server: 795 self._uvicorn_server.should_exit = True 796 797 if self._server_thread and self._server_thread.is_alive(): 798 self.logger.info("TestLLMServer stopping, joining thread...") 799 self._server_thread.join(timeout=5.0) 800 if self._server_thread.is_alive(): 801 self.logger.warning("Server thread did not exit cleanly.") 802 self._server_thread = None 803 self._uvicorn_server = None 804 self.logger.info("TestLLMServer stopped.") 805 806 @property 807 def url(self) -> str: 808 return f"http://{self.host}:{self.port}" 809 810 811 if __name__ == "__main__": 812 if __name__ == "__main__": 813 logging.basicConfig( 814 level=logging.INFO, 815 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 816 ) 817 818 server = TestLLMServer() 819 server.start() 820 821 sample_response_data = { 822 "choices": [ 823 { 824 "message": { 825 "role": "assistant", 826 "content": "Hello from the Test LLM!", 827 }, 828 "finish_reason": "stop", 829 } 830 ] 831 } 832 server.configure_static_response(sample_response_data) 833 server.logger.info( 834 f"Test LLM Server running at {server.url}. Configured with a static response." 835 ) 836 server.logger.info( 837 'Try: curl -X POST -H "Content-Type: application/json" -d \'{"model": "test", "messages": [{"role": "user", "content": "Hi"}]}\' http://127.0.0.1:8088/v1/chat/completions' 838 ) 839 840 try: 841 while True: 842 time.sleep(1) 843 except KeyboardInterrupt: 844 server.logger.info("Shutting down Test LLM Server...") 845 finally: 846 server.stop()