# test_strands_tracing.py
"""Tests for MLflow tracing autologging of strands agents.

The strands ``Model`` interface is mocked with two deterministic stubs
(:class:`DummyModel` and :class:`ToolCallingModel`) so the tests can assert
the exact spans, inputs/outputs, and token-usage attributes that
``mlflow.strands.autolog()`` records.
"""

import json
from collections.abc import AsyncIterator, Sequence
from typing import Any

from strands import Agent
from strands.models.model import Model
from strands.tools.tools import PythonAgentTool

import mlflow
from mlflow.entities import SpanType
from mlflow.environment_variables import MLFLOW_USE_DEFAULT_TRACER_PROVIDER
from mlflow.tracing.constant import SpanAttributeKey
from mlflow.tracing.provider import trace_disabled

from tests.tracing.helper import get_traces


async def sum_tool(tool_use, **_):
    """Tool callback that adds the two numeric inputs and returns a strands
    tool-result payload keyed by the originating ``toolUseId``."""
    a = tool_use["input"]["a"]
    b = tool_use["input"]["b"]
    return {
        "toolUseId": tool_use["toolUseId"],
        "status": "success",
        "content": [{"json": a + b}],
    }


# Module-level "sum" tool shared by the tests below.
tool = PythonAgentTool(
    "sum",
    {
        "name": "sum",
        "description": "add numbers 1 2",
        "inputSchema": {
            "type": "object",
            "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
            "required": ["a", "b"],
        },
    },
    sum_tool,
)


class DummyModel(Model):
    """Minimal strands model stub that streams a fixed text response and
    reports a configurable token usage."""

    def __init__(self, response_text: str, in_tokens: int = 1, out_tokens: int = 1):
        self.response_text = response_text
        self.in_tokens = in_tokens
        self.out_tokens = out_tokens
        self.config = {}

    def update_config(self, **model_config):
        self.config.update(model_config)

    def get_config(self):
        return self.config

    async def structured_output(self, output_model, prompt, system_prompt=None, **kwargs):
        # Unreachable yield makes this an (empty) async generator, as the
        # Model interface requires, without producing any events.
        if False:
            yield {}

    async def stream(
        self,
        messages: Sequence[dict[str, Any]],
        tool_specs: Any | None = None,
        system_prompt: str | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
        # Emit the full streaming event sequence for a single text-only
        # assistant turn, ending with usage metadata.
        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockStart": {"start": {}}}
        yield {"contentBlockDelta": {"delta": {"text": self.response_text}}}
        yield {"contentBlockStop": {}}
        yield {"messageStop": {"stopReason": "end_turn"}}
        yield {
            "metadata": {
                "usage": {
                    "inputTokens": self.in_tokens,
                    "outputTokens": self.out_tokens,
                    "totalTokens": self.in_tokens + self.out_tokens,
                },
                "metrics": {"latencyMs": 0},
            }
        }


class ToolCallingModel(Model):
    """Model stub that requests the configured tool on the first ``stream``
    call and streams a plain text response on every subsequent call."""

    def __init__(
        self,
        response_text: str,
        tool_input: dict[str, Any] | None = None,
        tool_name: str = "sum",
    ):
        self.response_text = response_text
        self.tool_input = tool_input or {"a": 1, "b": 2}
        self.tool_name = tool_name
        self.config = {}
        # Tracks how many times ``stream`` has been invoked so the first
        # call can trigger tool use and later calls return text.
        self._call_count = 0

    def update_config(self, **model_config: Any) -> None:
        self.config.update(model_config)

    def get_config(self) -> dict[str, object]:
        return self.config

    async def structured_output(
        self,
        output_model: Any,
        prompt: Any,
        system_prompt: str | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
        # Unreachable yield keeps this an (empty) async generator.
        if False:
            yield {}

    async def stream(
        self,
        messages: Sequence[dict[str, Any]],
        tool_specs: Any | None = None,
        system_prompt: str | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any]]:
        if self._call_count == 0:
            # First turn: ask the agent to invoke the tool with the
            # configured JSON input.
            self._call_count += 1
            yield {"messageStart": {"role": "assistant"}}
            yield {
                "contentBlockStart": {
                    "start": {
                        "toolUse": {
                            "toolUseId": "tool-1",
                            "name": self.tool_name,
                        }
                    }
                }
            }
            yield {
                "contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(self.tool_input)}}}
            }
            yield {"contentBlockStop": {}}
            yield {"messageStop": {"stopReason": "tool_use"}}
            yield {
                "metadata": {
                    "usage": {
                        "inputTokens": 1,
                        "outputTokens": 1,
                        "totalTokens": 2,
                    },
                    "metrics": {"latencyMs": 0},
                }
            }
        else:
            # Follow-up turn(s): plain text answer that ends the conversation.
            yield {"messageStart": {"role": "assistant"}}
            yield {"contentBlockStart": {"start": {}}}
            yield {"contentBlockDelta": {"delta": {"text": self.response_text}}}
            yield {"contentBlockStop": {}}
            yield {"messageStop": {"stopReason": "end_turn"}}
            yield {
                "metadata": {
                    "usage": {
                        "inputTokens": 1,
                        "outputTokens": 1,
                        "totalTokens": 2,
                    },
                    "metrics": {"latencyMs": 0},
                }
            }


def test_strands_autolog_single_trace():
    """A single agent call produces one trace with the expected agent span,
    token-usage child span, and trace-level usage; disabling stops logging."""
    mlflow.strands.autolog()

    agent = Agent(model=DummyModel("hi", 1, 2), name="agent")
    agent("hello")

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    agent_span = next(span for span in spans if span.span_type == SpanType.AGENT)
    assert agent_span.inputs == [{"role": "user", "content": [{"text": "hello"}]}]
    assert agent_span.outputs.strip() == "hi"

    usage_spans = [span for span in spans if span.attributes.get(SpanAttributeKey.CHAT_USAGE)]
    assert usage_spans, "expected at least one child span recording token usage"
    assert usage_spans[0].attributes[SpanAttributeKey.CHAT_USAGE] == {
        "input_tokens": 1,
        "output_tokens": 2,
        "total_tokens": 3,
    }
    assert traces[0].info.token_usage == {
        "input_tokens": 1,
        "output_tokens": 2,
        "total_tokens": 3,
    }

    # After disabling autolog, further agent calls must not create new traces.
    mlflow.strands.autolog(disable=True)
    agent("bye")
    assert len(get_traces()) == 1


def test_function_calling_creates_single_trace():
    """An agent turn that triggers a tool call still yields exactly one trace
    containing both the agent span and the tool span."""
    mlflow.strands.autolog()

    agent = Agent(model=ToolCallingModel("3"), tools=[tool], name="agent")
    agent("add numbers 1 2 1 2")

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    # NOTE(review): span indices assume a fixed ordering produced by the
    # strands instrumentation (agent first, tool at index 3).
    agent_span = spans[0]
    assert agent_span.span_type == SpanType.AGENT
    tool_span = spans[3]
    assert tool_span.span_type == SpanType.TOOL
    assert agent_span.inputs == [{"role": "user", "content": [{"text": "add numbers 1 2 1 2"}]}]
    assert agent_span.outputs == 3
    assert tool_span.inputs == {"a": 1, "b": 2}
    assert tool_span.outputs == [{"json": 3}]


def test_multiple_agents_single_trace():
    """A tool that internally invokes a second agent keeps everything in a
    single trace, with per-agent token usage aggregated correctly."""
    mlflow.strands.autolog()

    agent2 = Agent(model=DummyModel("hi"), name="agent2")

    async def sum_and_call_agent2(
        tool_use: dict[str, Any],
        **_: Any,
    ) -> dict[str, Any]:
        # Same as ``sum_tool`` but also invokes agent2, so agent2's span
        # should nest inside agent1's trace rather than start a new one.
        a = tool_use["input"]["a"]
        b = tool_use["input"]["b"]
        await agent2.invoke_async("hello")
        return {
            "toolUseId": tool_use["toolUseId"],
            "status": "success",
            "content": [{"json": a + b}],
        }

    tool_with_agent2 = PythonAgentTool(
        "sum",
        {
            "name": "sum",
            "description": "add numbers 1 2",
            "inputSchema": {
                "type": "object",
                "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
                "required": ["a", "b"],
            },
        },
        sum_and_call_agent2,
    )

    agent1 = Agent(model=ToolCallingModel("3"), tools=[tool_with_agent2], name="agent1")
    agent1("add numbers 1 2")

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    agent1_span = spans[0]
    assert agent1_span.name == "invoke_agent agent1"
    tool_span = spans[3]
    assert tool_span.span_type == SpanType.TOOL
    agent2_span = spans[4]
    assert agent2_span.name == "invoke_agent agent2"
    assert agent1_span.inputs == [{"role": "user", "content": [{"text": "add numbers 1 2"}]}]
    assert agent1_span.outputs == 3
    assert tool_span.inputs == {"a": 1, "b": 2}
    assert tool_span.outputs == [{"json": 3}]
    assert agent2_span.inputs == [{"role": "user", "content": [{"text": "hello"}]}]
    assert agent2_span.outputs.strip() == "hi"
    # top-level span should contain the sum of both the chat spans. this is set
    # when we translate the genai semantic conventions into mlflow attributes.
    assert agent1_span.attributes[SpanAttributeKey.CHAT_USAGE] == {
        "input_tokens": 2,
        "output_tokens": 2,
        "total_tokens": 4,
    }
    # agent2 span should contain the token usage for its single chat span
    assert agent2_span.attributes[SpanAttributeKey.CHAT_USAGE] == {
        "input_tokens": 1,
        "output_tokens": 1,
        "total_tokens": 2,
    }


def test_autolog_disable_prevents_new_traces():
    """Once autolog is disabled, calls from any agent stop producing traces."""
    mlflow.strands.autolog()

    agent1 = Agent(model=DummyModel("hi"), name="agent1")
    agent2 = Agent(model=DummyModel("cya"), name="agent2")

    agent1("hello")
    assert len(get_traces()) == 1

    mlflow.strands.autolog(disable=True)
    agent2("bye")
    assert len(get_traces()) == 1


def test_autolog_does_not_raise_npe_when_tracing_disabled():
    """Running an agent while tracing is disabled must neither raise nor
    record a trace."""
    mlflow.strands.autolog()

    agent = Agent(model=DummyModel("hi"), name="agent")

    @trace_disabled
    def run():
        agent("hello")

    run()
    assert len(get_traces()) == 0


def test_strands_autolog_shared_provider_no_recursion(monkeypatch):
    # Verify strands.autolog() works with shared tracer provider (no RecursionError)
    monkeypatch.setenv(MLFLOW_USE_DEFAULT_TRACER_PROVIDER.name, "false")

    mlflow.strands.autolog()

    agent = Agent(model=DummyModel("hi"), name="agent")
    agent("hello")

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    agent_span = next(span for span in spans if span.span_type == SpanType.AGENT)
    assert agent_span.inputs == [{"role": "user", "content": [{"text": "hello"}]}]
    assert agent_span.outputs.strip() == "hi"