/ tests / agno / test_agno_tracing.py
test_agno_tracing.py
  1  import sys
  2  from unittest.mock import MagicMock, patch
  3  
  4  import agno
  5  import pytest
  6  from agno.agent import Agent
  7  from agno.exceptions import ModelProviderError
  8  from agno.models.anthropic import Claude
  9  from agno.tools.function import Function, FunctionCall
 10  from anthropic.types import Message, TextBlock, Usage
 11  from packaging.version import Version
 12  
 13  import mlflow
 14  import mlflow.agno
 15  from mlflow.entities import SpanType
 16  from mlflow.entities.span_status import SpanStatusCode
 17  from mlflow.tracing.constant import TokenUsageKey
 18  
 19  from tests.tracing.helper import get_traces, purge_traces
 20  
 21  AGNO_VERSION = Version(getattr(agno, "__version__", "1.0.0"))
 22  IS_AGNO_V2 = AGNO_VERSION >= Version("2.0.0")
 23  # In agno >= 2.3.14, errors are caught internally and returned as error status
 24  # instead of being raised as ModelProviderError
 25  AGNO_CATCHES_ERRORS = AGNO_VERSION >= Version("2.3.14")
 26  
 27  
 28  def get_v2_autolog_module():
 29      from mlflow.agno.autolog_v2 import _is_agno_v2  # noqa: F401
 30  
 31      return sys.modules["mlflow.agno.autolog_v2"]
 32  
 33  
 34  def _create_message(content):
 35      return Message(
 36          id="1",
 37          model="claude-sonnet-4-20250514",
 38          content=[TextBlock(text=content, type="text")],
 39          role="assistant",
 40          stop_reason="end_turn",
 41          stop_sequence=None,
 42          type="message",
 43          usage=Usage(input_tokens=5, output_tokens=7, total_tokens=12),
 44      )
 45  
 46  
 47  @pytest.fixture
 48  def simple_agent():
 49      return Agent(
 50          model=Claude(id="claude-sonnet-4-20250514"),
 51          instructions="Be concise.",
 52          markdown=True,
 53      )
 54  
 55  
 56  @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior")
 57  def test_run_simple_autolog(simple_agent):
 58      mlflow.agno.autolog()
 59  
 60      mock_client = MagicMock()
 61      mock_client.messages.create.return_value = _create_message("Paris")
 62      with patch.object(Claude, "get_client", return_value=mock_client):
 63          resp = simple_agent.run("Capital of France?")
 64      assert resp.content == "Paris"
 65  
 66      traces = get_traces()
 67      assert len(traces) == 1
 68      assert traces[0].info.status == "OK"
 69      assert traces[0].info.token_usage == {
 70          TokenUsageKey.INPUT_TOKENS: 5,
 71          TokenUsageKey.OUTPUT_TOKENS: 7,
 72          TokenUsageKey.TOTAL_TOKENS: 12,
 73      }
 74      spans = traces[0].data.spans
 75      assert len(spans) == 2
 76      assert spans[0].span_type == SpanType.AGENT
 77      assert spans[0].name == "Agent.run"
 78      assert spans[0].inputs == {"message": "Capital of France?"}
 79      assert spans[0].outputs["content"] == "Paris"
 80      assert spans[1].span_type == SpanType.LLM
 81      assert spans[1].name == "Claude.invoke"
 82      assert spans[1].inputs["messages"][-1]["content"] == "Capital of France?"
 83      assert spans[1].outputs["content"][0]["text"] == "Paris"
 84      assert spans[1].model_name == "claude-sonnet-4-20250514"
 85  
 86      purge_traces()
 87  
 88      mlflow.agno.autolog(disable=True)
 89      with patch.object(Claude, "get_client", return_value=mock_client):
 90          simple_agent.run("Again?")
 91      assert get_traces() == []
 92  
 93  
 94  @pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior")
 95  def test_run_failure_tracing(simple_agent):
 96      mlflow.agno.autolog()
 97  
 98      mock_client = MagicMock()
 99      mock_client.messages.create.side_effect = RuntimeError("bang")
100      with patch.object(Claude, "get_client", return_value=mock_client):
101          with pytest.raises(ModelProviderError, match="bang"):
102              simple_agent.run("fail")
103  
104      trace = get_traces()[0]
105      assert trace.info.status == "ERROR"
106      assert trace.info.token_usage is None
107      spans = trace.data.spans
108      assert spans[0].name == "Agent.run"
109      assert spans[1].name == "Claude.invoke"
110      assert spans[1].status.status_code == SpanStatusCode.ERROR
111      assert spans[1].status.description == "ModelProviderError: bang"
112  
113  
@pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior")
@pytest.mark.asyncio
async def test_arun_simple_autolog(simple_agent):
    """``Agent.arun`` yields one OK trace containing an agent span and an LLM span."""
    mlflow.agno.autolog()

    async def _fake_create(*_args, **_kwargs):
        return _create_message("Paris")

    fake_client = MagicMock()
    fake_client.messages.create.side_effect = _fake_create
    with patch.object(Claude, "get_async_client", return_value=fake_client):
        response = await simple_agent.arun("Capital of France?")

    assert response.content == "Paris"

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    expected_usage = {
        TokenUsageKey.INPUT_TOKENS: 5,
        TokenUsageKey.OUTPUT_TOKENS: 7,
        TokenUsageKey.TOTAL_TOKENS: 12,
    }
    assert trace.info.token_usage == expected_usage

    spans = trace.data.spans
    assert len(spans) == 2
    agent_span, llm_span = spans

    assert agent_span.span_type == SpanType.AGENT
    assert agent_span.name == "Agent.arun"
    assert agent_span.inputs == {"message": "Capital of France?"}
    assert agent_span.outputs["content"] == "Paris"

    assert llm_span.span_type == SpanType.LLM
    assert llm_span.name == "Claude.ainvoke"
    assert llm_span.inputs["messages"][-1]["content"] == "Capital of France?"
    assert llm_span.outputs["content"][0]["text"] == "Paris"
    assert llm_span.model_name == "claude-sonnet-4-20250514"
148  
149  
@pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior")
@pytest.mark.asyncio
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
async def test_failure_tracing(simple_agent, is_async):
    """A client error during ``run``/``arun`` is recorded as an ERROR trace.

    The Claude client is replaced by a mock whose ``messages.create`` raises
    ``RuntimeError("bang")``. The run is expected to raise ``ModelProviderError``,
    and the LLM span must carry the error status and description.
    """
    mlflow.agno.autolog()

    mock_client = MagicMock()
    mock_client.messages.create.side_effect = RuntimeError("bang")
    mock_method = "get_async_client" if is_async else "get_client"
    with patch.object(Claude, mock_method, return_value=mock_client):
        with pytest.raises(ModelProviderError, match="bang"):  # noqa: PT012
            if is_async:
                await simple_agent.arun("fail")
            else:
                simple_agent.run("fail")

    # Resolve expected names up front. Writing `x == a if c else b` inline parses as
    # `(x == a) if c else b`, which asserts a bare (always truthy) string in one
    # branch and silently skips the name check.
    expected_agent_name = "Agent.arun" if is_async else "Agent.run"
    expected_llm_name = "Claude.ainvoke" if is_async else "Claude.invoke"

    trace = get_traces()[0]
    assert trace.info.status == "ERROR"
    assert trace.info.token_usage is None
    spans = trace.data.spans
    assert spans[0].name == expected_agent_name
    assert spans[1].name == expected_llm_name
    assert spans[1].status.status_code == SpanStatusCode.ERROR
    assert spans[1].status.description == "ModelProviderError: bang"
174  
175  
@pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior")
def test_function_execute_tracing():
    """``FunctionCall.execute`` on a sync tool produces exactly one TOOL span."""

    def dummy(x):
        return x + 1

    tool = Function.from_callable(dummy, name="dummy")
    call = FunctionCall(function=tool, arguments={"x": 1})

    mlflow.agno.autolog(log_traces=True)
    outcome = call.execute()
    assert outcome.result == 2

    spans = get_traces()[0].data.spans
    assert len(spans) == 1
    tool_span = spans[0]
    assert tool_span.span_type == SpanType.TOOL
    assert tool_span.name == "dummy"
    assert tool_span.inputs == {"x": 1}
    assert tool_span.attributes["entrypoint"] is not None
    assert tool_span.outputs["result"] == 2
195  
196  
@pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior")
@pytest.mark.asyncio
async def test_function_aexecute_tracing():
    """``FunctionCall.aexecute`` on an async tool produces exactly one TOOL span."""

    async def dummy(x):
        return x + 1

    tool = Function.from_callable(dummy, name="dummy")
    call = FunctionCall(function=tool, arguments={"x": 1})

    mlflow.agno.autolog(log_traces=True)
    outcome = await call.aexecute()
    assert outcome.result == 2

    spans = get_traces()[0].data.spans
    assert len(spans) == 1
    tool_span = spans[0]
    assert tool_span.span_type == SpanType.TOOL
    assert tool_span.name == "dummy"
    assert tool_span.inputs == {"x": 1}
    assert tool_span.attributes["entrypoint"] is not None
    assert tool_span.outputs["result"] == 2
217  
218  
@pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior")
def test_function_execute_failure_tracing():
    """A tool raising AgentRunException yields an ERROR trace with no span outputs."""
    from agno.exceptions import AgentRunException

    def boom(x):
        raise AgentRunException("bad")

    tool = Function.from_callable(boom, name="boom")
    call = FunctionCall(function=tool, arguments={"x": 1})

    mlflow.agno.autolog(log_traces=True)
    with pytest.raises(AgentRunException, match="bad"):
        call.execute()

    trace = get_traces()[0]
    assert trace.info.status == "ERROR"

    tool_span = trace.data.spans[0]
    assert tool_span.span_type == SpanType.TOOL
    assert tool_span.status.status_code == SpanStatusCode.ERROR
    assert tool_span.inputs == {"x": 1}
    # The tool never returned, so nothing is recorded as output.
    assert tool_span.outputs is None
239  
240  
@pytest.mark.skipif(IS_AGNO_V2, reason="Test uses V1 patching behavior")
@pytest.mark.asyncio
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
async def test_agno_and_anthropic_autolog_single_trace(simple_agent, is_async):
    """Enabling agno and anthropic autolog together produces a single trace.

    The anthropic HTTP client's ``post`` is patched to return a canned message.
    The resulting trace must contain, in order, the agno agent span, the agno
    LLM span and the anthropic chat-model span.
    """
    mlflow.agno.autolog()
    mlflow.anthropic.autolog()

    client = "AsyncAPIClient" if is_async else "SyncAPIClient"
    with patch(f"anthropic._base_client.{client}.post", return_value=_create_message("Paris")):
        if is_async:
            await simple_agent.arun("hi")
        else:
            simple_agent.run("hi")

    # Resolve expected names up front. Writing `x == a if c else b` inline parses as
    # `(x == a) if c else b`, which asserts a bare (always truthy) string in one
    # branch and silently skips the name check.
    if is_async:
        agent_name, llm_name, chat_name = "Agent.arun", "Claude.ainvoke", "AsyncMessages.create"
    else:
        agent_name, llm_name, chat_name = "Agent.run", "Claude.invoke", "Messages.create"

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    assert spans[0].span_type == SpanType.AGENT
    assert spans[0].name == agent_name
    assert spans[1].span_type == SpanType.LLM
    assert spans[1].name == llm_name
    assert spans[2].span_type == SpanType.CHAT_MODEL
    assert spans[2].name == chat_name
264  
265  
@pytest.mark.skipif(not IS_AGNO_V2, reason="Test requires V2 functionality")
def test_v2_autolog_setup_teardown():
    """Enabling V2 autolog creates the module-level agno instrumentor.

    The module's ``_agno_instrumentor`` is reset to ``None`` beforehand and restored
    afterwards, so the test does not leak state into other tests.
    """
    autolog_module = get_v2_autolog_module()
    saved_instrumentor = autolog_module._agno_instrumentor

    try:
        autolog_module._agno_instrumentor = None

        with patch("mlflow.get_tracking_uri", return_value="http://localhost:5000"):
            mlflow.agno.autolog(log_traces=True)
            assert autolog_module._agno_instrumentor is not None

            # Exercise the log_traces=False path; nothing is asserted about its effect.
            mlflow.agno.autolog(log_traces=False)
    finally:
        autolog_module._agno_instrumentor = saved_instrumentor
281  
282  
@pytest.mark.skipif(not IS_AGNO_V2, reason="Test requires V2 functionality")
@pytest.mark.asyncio
@pytest.mark.parametrize("is_async", [True, False], ids=["async", "sync"])
async def test_v2_creates_otel_spans(simple_agent, is_async):
    """With V2 autolog enabled, a successful agent run emits OpenTelemetry spans.

    Finished spans are captured by an ``InMemorySpanExporter`` attached to a
    ``TracerProvider`` that is installed as the global provider.

    NOTE(review): OpenTelemetry's ``set_tracer_provider`` only takes effect the
    first time it is called in a process. Confirm that the exporter installed here
    is the one that actually receives spans across parametrizations and tests.
    """
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

    exporter = InMemorySpanExporter()
    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(exporter))
    trace.set_tracer_provider(provider)

    async def _fake_async_create(*_args, **_kwargs):
        return _create_message("Paris")

    try:
        with patch("mlflow.get_tracking_uri", return_value="http://localhost:5000"):
            mlflow.agno.autolog(log_traces=True)

            fake_client = MagicMock()
            if is_async:
                fake_client.messages.create.side_effect = _fake_async_create
                client_getter = "get_async_client"
            else:
                fake_client.messages.create.return_value = _create_message("Paris")
                client_getter = "get_client"

            with patch.object(Claude, client_getter, return_value=fake_client):
                if is_async:
                    response = await simple_agent.arun("Capital of France?")
                else:
                    response = simple_agent.run("Capital of France?")

            assert response.content == "Paris"
            assert exporter.get_finished_spans()
    finally:
        mlflow.agno.autolog(disable=True)
323  
324  
@pytest.mark.skipif(not IS_AGNO_V2, reason="Test requires V2 functionality")
def test_v2_failure_creates_spans(simple_agent):
    """A failing model call still emits OpenTelemetry spans under V2 autolog.

    On agno >= 2.3.14 the failure is reported through an error ``RunStatus`` on
    the returned result. On older versions it propagates as ``ModelProviderError``,
    and at least one finished span must then carry an ERROR status.

    NOTE(review): OpenTelemetry's ``set_tracer_provider`` only takes effect the
    first time it is called in a process. Confirm that the exporter installed here
    is the one that actually receives spans.
    """
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
    from opentelemetry.trace import StatusCode

    exporter = InMemorySpanExporter()
    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(exporter))
    trace.set_tracer_provider(provider)

    try:
        with patch("mlflow.get_tracking_uri", return_value="http://localhost:5000"):
            mlflow.agno.autolog(log_traces=True)

            failing_client = MagicMock()
            failing_client.messages.create.side_effect = RuntimeError("bang")
            with patch.object(Claude, "get_client", return_value=failing_client):
                if not AGNO_CATCHES_ERRORS:
                    # agno < 2.3.14 propagates the provider error to the caller.
                    with pytest.raises(ModelProviderError, match="bang"):
                        simple_agent.run("fail")
                else:
                    # agno >= 2.3.14 catches the error and reports it via the run status.
                    from agno.run import RunStatus

                    run_output = simple_agent.run("fail")
                    assert run_output.status == RunStatus.error
                    assert "bang" in run_output.content

            finished = exporter.get_finished_spans()
            assert finished
            # ERROR-status spans are only expected when the exception propagates.
            if not AGNO_CATCHES_ERRORS:
                assert any(s.status.status_code == StatusCode.ERROR for s in finished)
    finally:
        mlflow.agno.autolog(disable=True)