test_agno_tracing.py
1 import sys 2 from unittest.mock import MagicMock, patch 3 4 import agno 5 import pytest 6 from agno.agent import Agent 7 from agno.exceptions import ModelProviderError 8 from agno.models.anthropic import Claude 9 from agno.tools.function import Function, FunctionCall 10 from anthropic.types import Message, TextBlock, Usage 11 from packaging.version import Version 12 13 import mlflow 14 import mlflow.agno 15 from mlflow.entities import SpanType 16 from mlflow.entities.span_status import SpanStatusCode 17 from mlflow.tracing.constant import TokenUsageKey 18 19 from tests.tracing.helper import get_traces, purge_traces 20 21 AGNO_VERSION = Version(getattr(agno, "__version__", "1.0.0")) 22 IS_AGNO_V2 = AGNO_VERSION >= Version("2.0.0") 23 # In agno >= 2.3.14, errors are caught internally and returned as error status 24 # instead of being raised as ModelProviderError 25 AGNO_CATCHES_ERRORS = AGNO_VERSION >= Version("2.3.14") 26 27 28 def get_v2_autolog_module(): 29 from mlflow.agno.autolog_v2 import _is_agno_v2 # noqa: F401 30 31 return sys.modules["mlflow.agno.autolog_v2"] 32 33 34 def _create_message(content): 35 return Message( 36 id="1", 37 model="claude-sonnet-4-20250514", 38 content=[TextBlock(text=content, type="text")], 39 role="assistant", 40 stop_reason="end_turn", 41 stop_sequence=None, 42 type="message", 43 usage=Usage(input_tokens=5, output_tokens=7, total_tokens=12), 44 ) 45 46 47 @pytest.fixture 48 def simple_agent(): 49 return Agent( 50 model=Claude(id="claude-sonnet-4-20250514"), 51 instructions="Be concise.", 52 markdown=True, 53 ) 54 55 56 @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior") 57 def test_run_simple_autolog(simple_agent): 58 mlflow.agno.autolog() 59 60 mock_client = MagicMock() 61 mock_client.messages.create.return_value = _create_message("Paris") 62 with patch.object(Claude, "get_client", return_value=mock_client): 63 resp = simple_agent.run("Capital of France?") 64 assert resp.content == "Paris" 65 66 traces = get_traces() 67 assert len(traces) == 1 68 assert traces[0].info.status == "OK" 69 assert traces[0].info.token_usage == { 70 TokenUsageKey.INPUT_TOKENS: 5, 71 TokenUsageKey.OUTPUT_TOKENS: 7, 72 TokenUsageKey.TOTAL_TOKENS: 12, 73 } 74 spans = traces[0].data.spans 75 assert len(spans) == 2 76 assert spans[0].span_type == SpanType.AGENT 77 assert spans[0].name == "Agent.run" 78 assert spans[0].inputs == {"message": "Capital of France?"} 79 assert spans[0].outputs["content"] == "Paris" 80 assert spans[1].span_type == SpanType.LLM 81 assert spans[1].name == "Claude.invoke" 82 assert spans[1].inputs["messages"][-1]["content"] == "Capital of France?" 83 assert spans[1].outputs["content"][0]["text"] == "Paris" 84 assert spans[1].model_name == "claude-sonnet-4-20250514" 85 86 purge_traces() 87 88 mlflow.agno.autolog(disable=True) 89 with patch.object(Claude, "get_client", return_value=mock_client): 90 simple_agent.run("Again?") 91 assert get_traces() == [] 92 93 94 @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior") 95 def test_run_failure_tracing(simple_agent): 96 mlflow.agno.autolog() 97 98 mock_client = MagicMock() 99 mock_client.messages.create.side_effect = RuntimeError("bang") 100 with patch.object(Claude, "get_client", return_value=mock_client): 101 with pytest.raises(ModelProviderError, match="bang"): 102 simple_agent.run("fail") 103 104 trace = get_traces()[0] 105 assert trace.info.status == "ERROR" 106 assert trace.info.token_usage is None 107 spans = trace.data.spans 108 assert spans[0].name == "Agent.run" 109 assert spans[1].name == "Claude.invoke" 110 assert spans[1].status.status_code == SpanStatusCode.ERROR 111 assert spans[1].status.description == "ModelProviderError: bang" 112 113 114 @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior") 115 @pytest.mark.asyncio 116 async def test_arun_simple_autolog(simple_agent): 117 mlflow.agno.autolog() 118 119 async def _mock_create(*args, **kwargs): 120 return _create_message("Paris") 121 122 mock_client = MagicMock() 123 mock_client.messages.create.side_effect = _mock_create 124 with patch.object(Claude, "get_async_client", return_value=mock_client): 125 resp = await simple_agent.arun("Capital of France?") 126 127 assert resp.content == "Paris" 128 129 traces = get_traces() 130 assert len(traces) == 1 131 assert traces[0].info.status == "OK" 132 assert traces[0].info.token_usage == { 133 TokenUsageKey.INPUT_TOKENS: 5, 134 TokenUsageKey.OUTPUT_TOKENS: 7, 135 TokenUsageKey.TOTAL_TOKENS: 12, 136 } 137 spans = traces[0].data.spans 138 assert len(spans) == 2 139 assert spans[0].span_type == SpanType.AGENT 140 assert spans[0].name == "Agent.arun" 141 assert spans[0].inputs == {"message": "Capital of France?"} 142 assert spans[0].outputs["content"] == "Paris" 143 assert spans[1].span_type == SpanType.LLM 144 assert spans[1].name == "Claude.ainvoke" 145 assert spans[1].inputs["messages"][-1]["content"] == "Capital of France?" 146 assert spans[1].outputs["content"][0]["text"] == "Paris" 147 assert spans[1].model_name == "claude-sonnet-4-20250514" 148 149 150 @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior") 151 @pytest.mark.asyncio 152 @pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"]) 153 async def test_failure_tracing(simple_agent, is_async): 154 mlflow.agno.autolog() 155 156 mock_client = MagicMock() 157 mock_client.messages.create.side_effect = RuntimeError("bang") 158 mock_method = "get_async_client" if is_async else "get_client" 159 with patch.object(Claude, mock_method, return_value=mock_client): 160 with pytest.raises(ModelProviderError, match="bang"): # noqa: PT012 161 if is_async: 162 await simple_agent.arun("fail") 163 else: 164 simple_agent.run("fail") 165 166 trace = get_traces()[0] 167 assert trace.info.status == "ERROR" 168 assert trace.info.token_usage is None 169 spans = trace.data.spans 170 assert spans[0].name == "Agent.run" if not is_async else "Agent.arun" 171 assert spans[1].name == "Claude.invoke" if not is_async else "Claude.ainvoke" 172 assert spans[1].status.status_code == SpanStatusCode.ERROR 173 assert spans[1].status.description == "ModelProviderError: bang" 174 175 176 @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior") 177 def test_function_execute_tracing(): 178 def dummy(x): 179 return x + 1 180 181 fc = FunctionCall(function=Function.from_callable(dummy, name="dummy"), arguments={"x": 1}) 182 183 mlflow.agno.autolog(log_traces=True) 184 result = fc.execute() 185 assert result.result == 2 186 187 spans = get_traces()[0].data.spans 188 assert len(spans) == 1 189 span = spans[0] 190 assert span.span_type == SpanType.TOOL 191 assert span.name == "dummy" 192 assert span.inputs == {"x": 1} 193 assert span.attributes["entrypoint"] is not None 194 assert span.outputs["result"] == 2 195 196 197 @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior") 198 @pytest.mark.asyncio 199 async def test_function_aexecute_tracing(): 200 async def dummy(x): 201 return x + 1 202 203 fc = FunctionCall(function=Function.from_callable(dummy, name="dummy"), arguments={"x": 1}) 204 205 mlflow.agno.autolog(log_traces=True) 206 result = await fc.aexecute() 207 assert result.result == 2 208 209 spans = get_traces()[0].data.spans 210 assert len(spans) == 1 211 span = spans[0] 212 assert span.span_type == SpanType.TOOL 213 assert span.name == "dummy" 214 assert span.inputs == {"x": 1} 215 assert span.attributes["entrypoint"] is not None 216 assert span.outputs["result"] == 2 217 218 219 @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior") 220 def test_function_execute_failure_tracing(): 221 from agno.exceptions import AgentRunException 222 223 def boom(x): 224 raise AgentRunException("bad") 225 226 fc = FunctionCall(function=Function.from_callable(boom, name="boom"), arguments={"x": 1}) 227 228 mlflow.agno.autolog(log_traces=True) 229 with pytest.raises(AgentRunException, match="bad"): 230 fc.execute() 231 232 trace = get_traces()[0] 233 assert trace.info.status == "ERROR" 234 span = trace.data.spans[0] 235 assert span.span_type == SpanType.TOOL 236 assert span.status.status_code == SpanStatusCode.ERROR 237 assert span.inputs == {"x": 1} 238 assert span.outputs is None 239 240 241 @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior") 242 @pytest.mark.asyncio 243 @pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"]) 244 async def test_agno_and_anthropic_autolog_single_trace(simple_agent, is_async): 245 mlflow.agno.autolog() 246 mlflow.anthropic.autolog() 247 248 client = "AsyncAPIClient" if is_async else "SyncAPIClient" 249 with patch(f"anthropic._base_client.{client}.post", return_value=_create_message("Paris")): 250 if is_async: 251 await simple_agent.arun("hi") 252 else: 253 simple_agent.run("hi") 254 255 traces = get_traces() 256 assert len(traces) == 1 257 spans = traces[0].data.spans 258 assert spans[0].span_type == SpanType.AGENT 259 assert spans[0].name == "Agent.arun" if is_async else "Agent.run" 260 assert spans[1].span_type == SpanType.LLM 261 assert spans[1].name == "Claude.ainvoke" if is_async else "Claude.invoke" 262 assert spans[2].span_type == SpanType.CHAT_MODEL 263 assert spans[2].name == "AsyncMessages.create" if is_async else "Messages.create" 264 265 266 @pytest.mark.skipif(not IS_AGNO_V2, reason="Test requires V2 functionality") 267 def test_v2_autolog_setup_teardown(): 268 autolog_module = get_v2_autolog_module() 269 original_instrumentor = autolog_module._agno_instrumentor 270 271 try: 272 autolog_module._agno_instrumentor = None 273 274 with patch("mlflow.get_tracking_uri", return_value="http://localhost:5000"): 275 mlflow.agno.autolog(log_traces=True) 276 assert autolog_module._agno_instrumentor is not None 277 278 mlflow.agno.autolog(log_traces=False) 279 finally: 280 autolog_module._agno_instrumentor = original_instrumentor 281 282 283 @pytest.mark.skipif(not IS_AGNO_V2, reason="Test requires V2 functionality") 284 @pytest.mark.asyncio 285 @pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"]) 286 async def test_v2_creates_otel_spans(simple_agent, is_async): 287 from opentelemetry import trace 288 from opentelemetry.sdk.trace import TracerProvider 289 from opentelemetry.sdk.trace.export import SimpleSpanProcessor 290 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter 291 292 memory_exporter = InMemorySpanExporter() 293 tracer_provider = TracerProvider() 294 tracer_provider.add_span_processor(SimpleSpanProcessor(memory_exporter)) 295 trace.set_tracer_provider(tracer_provider) 296 297 try: 298 with patch("mlflow.get_tracking_uri", return_value="http://localhost:5000"): 299 mlflow.agno.autolog(log_traces=True) 300 301 mock_client = MagicMock() 302 if is_async: 303 304 async def _mock_create(*args, **kwargs): 305 return _create_message("Paris") 306 307 mock_client.messages.create.side_effect = _mock_create 308 else: 309 mock_client.messages.create.return_value = _create_message("Paris") 310 311 mock_method = "get_async_client" if is_async else "get_client" 312 with patch.object(Claude, mock_method, return_value=mock_client): 313 if is_async: 314 resp = await simple_agent.arun("Capital of France?") 315 else: 316 resp = simple_agent.run("Capital of France?") 317 318 assert resp.content == "Paris" 319 spans = memory_exporter.get_finished_spans() 320 assert len(spans) > 0 321 finally: 322 mlflow.agno.autolog(disable=True) 323 324 325 @pytest.mark.skipif(not IS_AGNO_V2, reason="Test requires V2 functionality") 326 def test_v2_failure_creates_spans(simple_agent): 327 from opentelemetry import trace 328 from opentelemetry.sdk.trace import TracerProvider 329 from opentelemetry.sdk.trace.export import SimpleSpanProcessor 330 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter 331 from opentelemetry.trace import StatusCode 332 333 memory_exporter = InMemorySpanExporter() 334 tracer_provider = TracerProvider() 335 tracer_provider.add_span_processor(SimpleSpanProcessor(memory_exporter)) 336 trace.set_tracer_provider(tracer_provider) 337 338 try: 339 with patch("mlflow.get_tracking_uri", return_value="http://localhost:5000"): 340 mlflow.agno.autolog(log_traces=True) 341 342 mock_client = MagicMock() 343 mock_client.messages.create.side_effect = RuntimeError("bang") 344 with patch.object(Claude, "get_client", return_value=mock_client): 345 if AGNO_CATCHES_ERRORS: 346 # In agno >= 2.3.14, errors are caught internally and returned as error status 347 from agno.run import RunStatus 348 349 result = simple_agent.run("fail") 350 assert result.status == RunStatus.error 351 assert "bang" in result.content 352 else: 353 # In agno < 2.3.14, errors are raised as ModelProviderError 354 with pytest.raises(ModelProviderError, match="bang"): 355 simple_agent.run("fail") 356 357 spans = memory_exporter.get_finished_spans() 358 assert len(spans) > 0 359 if not AGNO_CATCHES_ERRORS: 360 # Error spans are only created when exceptions propagate 361 error_spans = [s for s in spans if s.status.status_code == StatusCode.ERROR] 362 assert len(error_spans) > 0 363 finally: 364 mlflow.agno.autolog(disable=True)