/ agent-scan / test_llm_error_handling.py
test_llm_error_handling.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

"""Regression tests for LLM error handling: chat retries and agent fallbacks."""

import sys
from pathlib import Path

import pytest

# Make the project packages importable when tests run from this directory.
sys.path.insert(0, str(Path(__file__).resolve().parent))

import core.base_agent as base_agent_module
import utils.llm as llm_module
from core.base_agent import BaseAgent
from utils.llm import LLM, LLM_ERROR_PREFIX


class DummyConnectionError(Exception):
    """Stand-in for openai.APIConnectionError."""


class DummyTimeoutError(Exception):
    """Stand-in for openai.APITimeoutError."""


class DummyAPIError(Exception):
    """Stand-in for openai.APIError."""


class DummyBadRequestError(Exception):
    """Stand-in for openai.BadRequestError."""


@pytest.fixture
def llm(monkeypatch):
    """Build an LLM whose OpenAI client and exception types are all faked."""
    monkeypatch.setattr(llm_module.openai, "OpenAI", lambda **kwargs: object())
    for attr_name, dummy_exc in (
        ("APIConnectionError", DummyConnectionError),
        ("APITimeoutError", DummyTimeoutError),
        ("APIError", DummyAPIError),
        ("BadRequestError", DummyBadRequestError),
    ):
        monkeypatch.setattr(llm_module.openai, attr_name, dummy_exc)
    return LLM(model="test-model", api_key="test-key", base_url="https://example.com")


def test_chat_resets_buffer_before_retry(monkeypatch, llm):
    """A partial chunk from a failed attempt must not leak into the retry result."""
    call_count = [0]

    def fake_chat_stream(_message):
        call_count[0] += 1
        if call_count[0] > 1:
            yield "full"
            return
        # First attempt: emit a partial chunk, then simulate a dropped connection.
        yield "partial"
        raise DummyConnectionError("connection dropped")

    monkeypatch.setattr(llm, "chat_stream", fake_chat_stream)
    monkeypatch.setattr(llm_module.time, "sleep", lambda _seconds: None)

    assert llm.chat([]) == "full"


def test_chat_returns_prefixed_error_for_empty_responses(monkeypatch, llm):
    """Exhausting retries on empty streams must yield an LLM_ERROR_PREFIX message."""
    monkeypatch.setattr(llm, "chat_stream", lambda _message: iter([]))
    monkeypatch.setattr(llm_module.time, "sleep", lambda _seconds: None)

    response = llm.chat([], language="en")

    assert response.startswith(LLM_ERROR_PREFIX)


@pytest.mark.asyncio
async def test_compact_history_skips_on_llm_error(monkeypatch):
    """compact_history must leave history untouched when the LLM call errors."""

    class FailingLLM:
        async def chat_async(self, _history, language="zh"):
            return f"{LLM_ERROR_PREFIX} compact failed]"

    # Bypass __init__ so no real client is constructed.
    agent = BaseAgent.__new__(BaseAgent)
    agent.llm = FailingLLM()
    agent.language = "en"
    agent.history = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "original task"},
        {"role": "assistant", "content": "intermediate context"},
    ]

    snapshot = agent.history.copy()
    monkeypatch.setattr(base_agent_module.prompt_manager, "load_template", lambda _name: "compact prompt")

    assert await BaseAgent.compact_history(agent) is False
    assert agent.history == snapshot


@pytest.mark.asyncio
async def test_format_final_output_falls_back_to_last_assistant(monkeypatch):
    """On an LLM error, final formatting falls back to the last assistant message."""

    class FailingLLM:
        async def chat_async(self, _history, language="zh"):
            return f"{LLM_ERROR_PREFIX} format failed]"

    # Bypass __init__ so no real client is constructed.
    agent = BaseAgent.__new__(BaseAgent)
    agent.llm = FailingLLM()
    agent.language = "en"
    agent.instruction = "format output"
    agent.history = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "original task"},
        {"role": "assistant", "content": "final report content"},
    ]

    monkeypatch.setattr(base_agent_module.prompt_manager, "format_prompt", lambda *_args, **_kwargs: "format prompt")

    assert await BaseAgent._format_final_output(agent) == "final report content"