tests/strands/test_strands_tracing.py
import json
from collections.abc import AsyncIterator, Sequence
from typing import Any

from strands import Agent
from strands.models.model import Model
from strands.tools.tools import PythonAgentTool

import mlflow
from mlflow.entities import SpanType
from mlflow.environment_variables import MLFLOW_USE_DEFAULT_TRACER_PROVIDER
from mlflow.tracing.constant import SpanAttributeKey
from mlflow.tracing.provider import trace_disabled

from tests.tracing.helper import get_traces


async def sum_tool(tool_use, **_):
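    """Tool callback that returns a + b as a successful tool result."""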
    a = tool_use["input"]["a"]
    b = tool_use["input"]["b"]
    return {
        "toolUseId": tool_use["toolUseId"],
        "status": "success",
        "content": [{"json": a + b}],
    }


# Shared "sum" tool used by the tool-calling tests below.
tool = PythonAgentTool(
    "sum",
    {
        "name": "sum",
        "description": "add numbers 1 2",
        "inputSchema": {
            "type": "object",
            "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
            "required": ["a", "b"],
        },
    },
    sum_tool,
)


class DummyModel(Model):
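    """Minimal Model stub that streams one fixed text response and reports a
    configurable token usage."""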
    def __init__(self, response_text: str, in_tokens: int = 1, out_tokens: int = 1):
        self.response_text = response_text
        self.in_tokens = in_tokens
        self.out_tokens = out_tokens
        self.config = {}

    def update_config(self, **model_config):
        self.config.update(model_config)

    def get_config(self):
        return self.config

    async def structured_output(self, output_model, prompt, system_prompt=None, **kwargs):
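        # Structured output is not exercised by these tests; the unreachable
        # `yield` just turns this into an (empty) async generator.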
        if False:
            yield {}

    async def stream(
        self,
        messages: Sequence[dict[str, Any]],
        tool_specs: Any | None = None,
        system_prompt: str | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
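        # Emit a minimal event stream: message start, one text content block,
        # message stop, then usage metadata.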
        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockStart": {"start": {}}}
        yield {"contentBlockDelta": {"delta": {"text": self.response_text}}}
        yield {"contentBlockStop": {}}
        yield {"messageStop": {"stopReason": "end_turn"}}
        yield {
            "metadata": {
                "usage": {
                    "inputTokens": self.in_tokens,
                    "outputTokens": self.out_tokens,
                    "totalTokens": self.in_tokens + self.out_tokens,
                },
                "metrics": {"latencyMs": 0},
            }
        }


class ToolCallingModel(Model):
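    """Model stub that requests the configured tool on its first call and
    streams a plain text response on every later call."""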
    def __init__(
        self,
        response_text: str,
        tool_input: dict[str, Any] | None = None,
        tool_name: str = "sum",
    ):
        self.response_text = response_text
        self.tool_input = tool_input or {"a": 1, "b": 2}
        self.tool_name = tool_name
        self.config = {}
        self._call_count = 0

    def update_config(self, **model_config: Any) -> None:
        self.config.update(model_config)

    def get_config(self) -> dict[str, object]:
        return self.config

    async def structured_output(
        self,
        output_model: Any,
        prompt: Any,
        system_prompt: str | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
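        # Same trick as in DummyModel: an async generator that yields nothing.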
        if False:
            yield {}

    async def stream(
        self,
        messages: Sequence[dict[str, Any]],
        tool_specs: Any | None = None,
        system_prompt: str | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
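        # First call: emit a tool-use request; later calls: emit plain text.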
        if self._call_count == 0:
            self._call_count += 1
            yield {"messageStart": {"role": "assistant"}}
            yield {
                "contentBlockStart": {
                    "start": {
                        "toolUse": {
                            "toolUseId": "tool-1",
                            "name": self.tool_name,
                        }
                    }
                }
            }
            yield {
                "contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(self.tool_input)}}}
            }
            yield {"contentBlockStop": {}}
            yield {"messageStop": {"stopReason": "tool_use"}}
            yield {
                "metadata": {
                    "usage": {
                        "inputTokens": 1,
                        "outputTokens": 1,
                        "totalTokens": 2,
                    },
                    "metrics": {"latencyMs": 0},
                }
            }
        else:
            yield {"messageStart": {"role": "assistant"}}
            yield {"contentBlockStart": {"start": {}}}
            yield {"contentBlockDelta": {"delta": {"text": self.response_text}}}
            yield {"contentBlockStop": {}}
            yield {"messageStop": {"stopReason": "end_turn"}}
            yield {
                "metadata": {
                    "usage": {
                        "inputTokens": 1,
                        "outputTokens": 1,
                        "totalTokens": 2,
                    },
                    "metrics": {"latencyMs": 0},
                }
            }


def test_strands_autolog_single_trace():
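    """One agent call produces exactly one trace with agent inputs/outputs and
    aggregated token usage; disabling autolog stops new traces."""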
    mlflow.strands.autolog()

    agent = Agent(model=DummyModel("hi", 1, 2), name="agent")
    agent("hello")

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    agent_span = next(span for span in spans if span.span_type == SpanType.AGENT)
    assert agent_span.inputs == [{"role": "user", "content": [{"text": "hello"}]}]
    assert agent_span.outputs.strip() == "hi"

    usage_spans = [span for span in spans if span.attributes.get(SpanAttributeKey.CHAT_USAGE)]
    assert usage_spans, "expected at least one child span recording token usage"
    assert usage_spans[0].attributes[SpanAttributeKey.CHAT_USAGE] == {
        "input_tokens": 1,
        "output_tokens": 2,
        "total_tokens": 3,
    }
    assert traces[0].info.token_usage == {
        "input_tokens": 1,
        "output_tokens": 2,
        "total_tokens": 3,
    }

    mlflow.strands.autolog(disable=True)
    agent("bye")
    assert len(get_traces()) == 1


def test_function_calling_creates_single_trace():
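    """A tool-use round trip (model requests tool, tool runs, model replies)
    is captured as a single trace with AGENT and TOOL spans."""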
    mlflow.strands.autolog()

    agent = Agent(model=ToolCallingModel("3"), tools=[tool], name="agent")
    agent("add numbers 1 2 1 2")

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    agent_span = spans[0]
    assert agent_span.span_type == SpanType.AGENT
    tool_span = spans[3]
    assert tool_span.span_type == SpanType.TOOL
    assert agent_span.inputs == [{"role": "user", "content": [{"text": "add numbers 1 2 1 2"}]}]
    assert agent_span.outputs == 3
    assert tool_span.inputs == {"a": 1, "b": 2}
    assert tool_span.outputs == [{"json": 3}]


def test_multiple_agents_single_trace():
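    """An agent whose tool invokes a second agent yields one trace; token
    usage rolls up per agent span."""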
    mlflow.strands.autolog()

    agent2 = Agent(model=DummyModel("hi"), name="agent2")

    async def sum_and_call_agent2(
        tool_use: dict[str, Any],
        **_: Any,
    ) -> dict[str, Any]:
        a = tool_use["input"]["a"]
        b = tool_use["input"]["b"]
        await agent2.invoke_async("hello")
        return {
            "toolUseId": tool_use["toolUseId"],
            "status": "success",
            "content": [{"json": a + b}],
        }

    tool_with_agent2 = PythonAgentTool(
        "sum",
        {
            "name": "sum",
            "description": "add numbers 1 2",
            "inputSchema": {
                "type": "object",
                "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
                "required": ["a", "b"],
            },
        },
        sum_and_call_agent2,
    )

    agent1 = Agent(model=ToolCallingModel("3"), tools=[tool_with_agent2], name="agent1")
    agent1("add numbers 1 2")

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    agent1_span = spans[0]
    assert agent1_span.name == "invoke_agent agent1"
    tool_span = spans[3]
    assert tool_span.span_type == SpanType.TOOL
    agent2_span = spans[4]
    assert agent2_span.name == "invoke_agent agent2"
    assert agent1_span.inputs == [{"role": "user", "content": [{"text": "add numbers 1 2"}]}]
    assert agent1_span.outputs == 3
    assert tool_span.inputs == {"a": 1, "b": 2}
    assert tool_span.outputs == [{"json": 3}]
    assert agent2_span.inputs == [{"role": "user", "content": [{"text": "hello"}]}]
    assert agent2_span.outputs.strip() == "hi"
    # The top-level span should contain the summed usage of both chat spans;
    # this is set when the GenAI semantic conventions are translated into
    # MLflow span attributes.
    assert agent1_span.attributes[SpanAttributeKey.CHAT_USAGE] == {
        "input_tokens": 2,
        "output_tokens": 2,
        "total_tokens": 4,
    }
    # The agent2 span should contain the token usage of its single chat span.
    assert agent2_span.attributes[SpanAttributeKey.CHAT_USAGE] == {
        "input_tokens": 1,
        "output_tokens": 1,
        "total_tokens": 2,
    }


def test_autolog_disable_prevents_new_traces():
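    """autolog(disable=True) prevents subsequent agent calls from creating traces."""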
    mlflow.strands.autolog()

    agent1 = Agent(model=DummyModel("hi"), name="agent1")
    agent2 = Agent(model=DummyModel("cya"), name="agent2")

    agent1("hello")
    assert len(get_traces()) == 1

    mlflow.strands.autolog(disable=True)
    agent2("bye")
    assert len(get_traces()) == 1


def test_autolog_does_not_raise_npe_when_tracing_disabled():
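    """Running an agent inside a @trace_disabled function records no trace
    and raises no error."""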
    mlflow.strands.autolog()

    agent = Agent(model=DummyModel("hi"), name="agent")

    @trace_disabled
    def run():
        agent("hello")

    run()
    assert len(get_traces()) == 0


def test_strands_autolog_shared_provider_no_recursion(monkeypatch):
    # Verify strands.autolog() works with a shared tracer provider (no RecursionError)
    monkeypatch.setenv(MLFLOW_USE_DEFAULT_TRACER_PROVIDER.name, "false")

    mlflow.strands.autolog()

    agent = Agent(model=DummyModel("hi"), name="agent")
    agent("hello")

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    agent_span = next(span for span in spans if span.span_type == SpanType.AGENT)
    assert agent_span.inputs == [{"role": "user", "content": [{"text": "hello"}]}]
    assert agent_span.outputs.strip() == "hi"