/ tests / gateway / test_tracing_utils.py
test_tracing_utils.py
   1  import json
   2  from typing import Any
   3  
   4  import pytest
   5  
   6  import mlflow
   7  from mlflow.entities import SpanType
   8  from mlflow.gateway.schemas.chat import StreamResponsePayload
   9  from mlflow.gateway.tracing_utils import (
  10      _get_model_span_info,
  11      aggregate_anthropic_messages_stream_chunks,
  12      aggregate_chat_stream_chunks,
  13      aggregate_gemini_stream_generate_content_chunks,
  14      aggregate_openai_responses_stream_chunks,
  15      maybe_traced_gateway_call,
  16  )
  17  from mlflow.store.tracking.gateway.entities import GatewayEndpointConfig
  18  from mlflow.tracing.client import TracingClient
  19  from mlflow.tracing.constant import SpanAttributeKey, TraceMetadataKey
  20  from mlflow.tracing.distributed import get_tracing_context_headers_for_http_request
  21  from mlflow.tracking.fluent import _get_experiment_id
  22  from mlflow.types.chat import ChatChoiceDelta, ChatChunkChoice, ChatUsage, Function, ToolCallDelta
  23  
  24  
  25  @pytest.fixture
  26  def gateway_experiment_id():
  27      experiment_name = "gateway-test-endpoint"
  28      experiment = mlflow.get_experiment_by_name(experiment_name)
  29      if experiment is not None:
  30          return experiment.experiment_id
  31      return mlflow.create_experiment(experiment_name)
  32  
  33  
  34  def get_traces():
  35      return TracingClient().search_traces(locations=[_get_experiment_id()])
  36  
  37  
  38  @pytest.fixture
  39  def endpoint_config():
  40      return GatewayEndpointConfig(
  41          endpoint_id="test-endpoint-id",
  42          endpoint_name="test-endpoint",
  43          experiment_id=_get_experiment_id(),
  44          usage_tracking=True,
  45          models=[],
  46      )
  47  
  48  
  49  @pytest.fixture
  50  def endpoint_config_no_experiment():
  51      return GatewayEndpointConfig(
  52          endpoint_id="test-endpoint-id",
  53          endpoint_name="test-endpoint",
  54          experiment_id=None,
  55          models=[],
  56      )
  57  
  58  
  59  async def mock_async_func(payload):
  60      return {"result": "success", "payload": payload}
  61  
  62  
  63  def _make_chunk(
  64      content=None,
  65      finish_reason=None,
  66      id="chunk-1",
  67      model="test-model",
  68      created=1700000000,
  69      usage=None,
  70      tool_calls=None,
  71      role="assistant",
  72      choice_index=0,
  73  ):
  74      delta = ChatChoiceDelta(role=role, content=content, tool_calls=tool_calls)
  75      choice = ChatChunkChoice(index=choice_index, finish_reason=finish_reason, delta=delta)
  76      return StreamResponsePayload(
  77          id=id,
  78          created=created,
  79          model=model,
  80          choices=[choice],
  81          usage=usage,
  82      )
  83  
  84  
  85  def test_aggregate_chat_stream_chunks_aggregates_content():
  86      chunks = [
  87          _make_chunk(content="Hello"),
  88          _make_chunk(content=" "),
  89          _make_chunk(content="world"),
  90          _make_chunk(content=None, finish_reason="stop"),
  91      ]
  92      result = aggregate_chat_stream_chunks(chunks)
  93  
  94      assert result["object"] == "chat.completion"
  95      assert result["model"] == "test-model"
  96      assert result["choices"][0]["message"]["role"] == "assistant"
  97      assert result["choices"][0]["message"]["content"] == "Hello world"
  98      assert result["choices"][0]["finish_reason"] == "stop"
  99  
 100  
 101  def test_aggregate_chat_stream_chunks_with_usage():
 102      usage = ChatUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
 103      chunks = [
 104          _make_chunk(content="Hi"),
 105          _make_chunk(content=None, finish_reason="stop", usage=usage),
 106      ]
 107      result = aggregate_chat_stream_chunks(chunks)
 108  
 109      assert result["choices"][0]["message"]["content"] == "Hi"
 110      assert result["usage"] == {
 111          "prompt_tokens": 10,
 112          "completion_tokens": 5,
 113          "total_tokens": 15,
 114      }
 115  
 116  
 117  def test_aggregate_chat_stream_chunks_empty():
 118      assert aggregate_chat_stream_chunks([]) is None
 119  
 120  
 121  def test_aggregate_chat_stream_chunks_defaults_finish_reason():
 122      chunks = [_make_chunk(content="Hi")]
 123      result = aggregate_chat_stream_chunks(chunks)
 124  
 125      assert result["choices"][0]["finish_reason"] == "stop"
 126  
 127  
 128  def test_reduce_chat_stream_chunks_aggregates_tool_calls():
 129      chunks = [
 130          # First chunk: tool call id, type, and function name
 131          _make_chunk(
 132              tool_calls=[
 133                  ToolCallDelta(
 134                      index=0,
 135                      id="call_abc",
 136                      type="function",
 137                      function=Function(name="get_weather", arguments=""),
 138                  ),
 139              ],
 140          ),
 141          # Subsequent chunks: argument fragments
 142          _make_chunk(
 143              tool_calls=[
 144                  ToolCallDelta(index=0, function=Function(arguments='{"loc')),
 145              ],
 146          ),
 147          _make_chunk(
 148              tool_calls=[
 149                  ToolCallDelta(index=0, function=Function(arguments='ation": "SF"}')),
 150              ],
 151          ),
 152          _make_chunk(finish_reason="tool_calls"),
 153      ]
 154      result = aggregate_chat_stream_chunks(chunks)
 155  
 156      assert result["choices"][0]["message"]["content"] is None
 157      assert result["choices"][0]["finish_reason"] == "tool_calls"
 158  
 159      tool_calls = result["choices"][0]["message"]["tool_calls"]
 160      assert len(tool_calls) == 1
 161      assert tool_calls[0]["id"] == "call_abc"
 162      assert tool_calls[0]["type"] == "function"
 163      assert tool_calls[0]["function"]["name"] == "get_weather"
 164      assert tool_calls[0]["function"]["arguments"] == '{"location": "SF"}'
 165  
 166  
 167  def test_reduce_chat_stream_chunks_derives_role_from_delta():
 168      chunks = [
 169          _make_chunk(role="developer", content="Hello"),
 170          _make_chunk(role=None, content=" world"),
 171          _make_chunk(role=None, finish_reason="stop"),
 172      ]
 173      result = aggregate_chat_stream_chunks(chunks)
 174  
 175      assert result["choices"][0]["message"]["role"] == "developer"
 176  
 177  
 178  def test_reduce_chat_stream_chunks_multiple_choice_indices():
 179      chunks = [
 180          _make_chunk(content="Hi", choice_index=0),
 181          _make_chunk(content="Hey", choice_index=1),
 182          _make_chunk(content=" there", choice_index=0),
 183          _make_chunk(content=" you", choice_index=1),
 184          _make_chunk(finish_reason="stop", choice_index=0),
 185          _make_chunk(finish_reason="stop", choice_index=1),
 186      ]
 187      result = aggregate_chat_stream_chunks(chunks)
 188  
 189      assert len(result["choices"]) == 2
 190      assert result["choices"][0]["index"] == 0
 191      assert result["choices"][0]["message"]["content"] == "Hi there"
 192      assert result["choices"][1]["index"] == 1
 193      assert result["choices"][1]["message"]["content"] == "Hey you"
 194  
 195  
 196  @pytest.mark.asyncio
 197  async def test_maybe_traced_gateway_call_basic(endpoint_config):
 198      traced_func = maybe_traced_gateway_call(mock_async_func, endpoint_config)
 199      result = await traced_func({"input": "test"})
 200  
 201      assert result == {"result": "success", "payload": {"input": "test"}}
 202  
 203      traces = get_traces()
 204      assert len(traces) == 1
 205      trace = traces[0]
 206  
 207      # Find the gateway span
 208      span_name_to_span = {span.name: span for span in trace.data.spans}
 209      assert f"gateway/{endpoint_config.endpoint_name}" in span_name_to_span
 210  
 211      gateway_span = span_name_to_span[f"gateway/{endpoint_config.endpoint_name}"]
 212      assert gateway_span.attributes.get("endpoint_id") == "test-endpoint-id"
 213      assert gateway_span.attributes.get("endpoint_name") == "test-endpoint"
 214      # Input should be unwrapped (not nested under "payload" key)
 215      assert gateway_span.inputs == {"input": "test"}
 216      # No user metadata should be present in trace
 217      assert trace.info.request_metadata.get(TraceMetadataKey.AUTH_USERNAME) is None
 218      assert trace.info.request_metadata.get(TraceMetadataKey.AUTH_USER_ID) is None
 219  
 220  
 221  @pytest.mark.asyncio
 222  async def test_maybe_traced_gateway_call_with_user_metadata(endpoint_config):
 223      traced_func = maybe_traced_gateway_call(
 224          mock_async_func,
 225          endpoint_config,
 226          metadata={
 227              TraceMetadataKey.AUTH_USERNAME: "alice",
 228              TraceMetadataKey.AUTH_USER_ID: "123",
 229          },
 230      )
 231      result = await traced_func({"input": "test"})
 232  
 233      assert result == {"result": "success", "payload": {"input": "test"}}
 234  
 235      traces = get_traces()
 236      assert len(traces) == 1
 237      trace = traces[0]
 238  
 239      span_name_to_span = {span.name: span for span in trace.data.spans}
 240      gateway_span = span_name_to_span[f"gateway/{endpoint_config.endpoint_name}"]
 241  
 242      assert gateway_span.attributes.get("endpoint_id") == "test-endpoint-id"
 243      assert gateway_span.attributes.get("endpoint_name") == "test-endpoint"
 244      # Input should be unwrapped (not nested under "payload" key)
 245      assert gateway_span.inputs == {"input": "test"}
 246      # User metadata should be in trace info, not span attributes
 247      assert trace.info.request_metadata.get(TraceMetadataKey.AUTH_USERNAME) == "alice"
 248      assert trace.info.request_metadata.get(TraceMetadataKey.AUTH_USER_ID) == "123"
 249  
 250  
 251  @pytest.mark.asyncio
 252  async def test_maybe_traced_gateway_call_without_usage_tracking(endpoint_config_no_experiment):
 253      traced_func = maybe_traced_gateway_call(
 254          mock_async_func,
 255          endpoint_config_no_experiment,
 256          metadata={
 257              TraceMetadataKey.AUTH_USERNAME: "alice",
 258              TraceMetadataKey.AUTH_USER_ID: "123",
 259          },
 260      )
 261  
 262      # When usage_tracking is False, maybe_traced_gateway_call returns the original function
 263      assert traced_func is mock_async_func
 264  
 265      result = await traced_func({"input": "test"})
 266      assert result == {"result": "success", "payload": {"input": "test"}}
 267  
 268      # No traces should be created
 269      traces = get_traces()
 270      assert len(traces) == 0
 271  
 272  
 273  @pytest.mark.asyncio
 274  async def test_maybe_traced_gateway_call_with_output_reducer(endpoint_config):
 275      async def mock_async_stream(payload):
 276          yield _make_chunk(content="Hello")
 277          yield _make_chunk(content=" world")
 278          yield _make_chunk(
 279              content=None,
 280              finish_reason="stop",
 281              usage=ChatUsage(prompt_tokens=5, completion_tokens=2, total_tokens=7),
 282          )
 283  
 284      traced_func = maybe_traced_gateway_call(
 285          mock_async_stream,
 286          endpoint_config,
 287          output_reducer=aggregate_chat_stream_chunks,
 288      )
 289  
 290      # Consume the stream
 291      chunks = [
 292          chunk async for chunk in traced_func({"messages": [{"role": "user", "content": "hi"}]})
 293      ]
 294      assert len(chunks) == 3
 295  
 296      traces = get_traces()
 297      assert len(traces) == 1
 298      trace = traces[0]
 299  
 300      span_name_to_span = {span.name: span for span in trace.data.spans}
 301      gateway_span = span_name_to_span[f"gateway/{endpoint_config.endpoint_name}"]
 302  
 303      # Input should be unwrapped (not nested under "payload" key)
 304      assert gateway_span.inputs == {"messages": [{"role": "user", "content": "hi"}]}
 305  
 306      # The output should be the reduced aggregated response, not raw chunks
 307      output = gateway_span.outputs
 308      assert output["object"] == "chat.completion"
 309      assert output["choices"][0]["message"]["content"] == "Hello world"
 310      assert output["choices"][0]["finish_reason"] == "stop"
 311      assert output["usage"]["total_tokens"] == 7
 312  
 313  
 314  @pytest.mark.asyncio
 315  async def test_maybe_traced_gateway_call_with_payload_kwarg(endpoint_config):
 316      async def mock_passthrough_func(action, payload, headers=None):
 317          return {"result": "success", "action": action, "payload": payload}
 318  
 319      traced_func = maybe_traced_gateway_call(mock_passthrough_func, endpoint_config)
 320      result = await traced_func(
 321          action="test_action", payload={"messages": [{"role": "user", "content": "hi"}]}, headers={}
 322      )
 323  
 324      assert result["result"] == "success"
 325  
 326      traces = get_traces()
 327      assert len(traces) == 1
 328      trace = traces[0]
 329  
 330      span_name_to_span = {span.name: span for span in trace.data.spans}
 331      gateway_span = span_name_to_span[f"gateway/{endpoint_config.endpoint_name}"]
 332  
 333      # Input should be unwrapped to just the payload dict
 334      assert gateway_span.inputs == {"messages": [{"role": "user", "content": "hi"}]}
 335  
 336  
 337  # ---------------------------------------------------------------------------
 338  # Tests for distributed tracing helpers
 339  # ---------------------------------------------------------------------------
 340  
 341  
 342  @pytest.mark.asyncio
 343  async def test_get_model_span_info_reads_child_span(endpoint_config):
 344      async def func_with_child_span(payload):
 345          with mlflow.start_span("provider/openai/gpt-4", span_type=SpanType.LLM) as child:
 346              child.set_attributes({
 347                  SpanAttributeKey.CHAT_USAGE: {
 348                      "input_tokens": 10,
 349                      "output_tokens": 5,
 350                      "total_tokens": 15,
 351                  },
 352                  SpanAttributeKey.MODEL: "gpt-4",
 353                  SpanAttributeKey.MODEL_PROVIDER: "openai",
 354              })
 355          return {"result": "ok"}
 356  
 357      traced = maybe_traced_gateway_call(func_with_child_span, endpoint_config)
 358      await traced({"input": "test"})
 359  
 360      traces = get_traces()
 361      assert len(traces) == 1
 362      gateway_trace_id = traces[0].info.trace_id
 363  
 364      # After the trace is exported, spans are removed from InMemoryTraceManager,
 365      # so we expect empty here. The actual reading happens inside the wrapper
 366      # while the trace is still in memory.
 367      assert _get_model_span_info(gateway_trace_id) == []
 368  
 369  
 370  # ---------------------------------------------------------------------------
 371  # Integration tests for distributed tracing via traceparent
 372  # ---------------------------------------------------------------------------
 373  
 374  
 375  @pytest.mark.asyncio
 376  async def test_maybe_traced_gateway_call_with_traceparent(gateway_experiment_id):
 377      ep_config = GatewayEndpointConfig(
 378          endpoint_id="test-endpoint-id",
 379          endpoint_name="test-endpoint",
 380          experiment_id=gateway_experiment_id,
 381          usage_tracking=True,
 382          models=[],
 383      )
 384  
 385      async def func_with_usage(payload):
 386          with mlflow.start_span("provider/openai/gpt-4", span_type=SpanType.LLM) as child:
 387              child.set_attributes({
 388                  SpanAttributeKey.CHAT_USAGE: {
 389                      "input_tokens": 10,
 390                      "output_tokens": 5,
 391                      "total_tokens": 15,
 392                  },
 393                  SpanAttributeKey.MODEL: "gpt-4",
 394                  SpanAttributeKey.MODEL_PROVIDER: "openai",
 395              })
 396          return {"result": "ok"}
 397  
 398      # Step 1: Agent creates span and generates traceparent headers
 399      with mlflow.start_span("agent-root") as agent_span:
 400          headers = get_tracing_context_headers_for_http_request()
 401          agent_trace_id = agent_span.trace_id
 402          agent_span_id = agent_span.span_id
 403  
 404      # Step 2: Gateway processes request (no active agent span, simulating separate server)
 405      traced = maybe_traced_gateway_call(func_with_usage, ep_config, request_headers=headers)
 406      result = await traced({"input": "test"})
 407  
 408      assert result == {"result": "ok"}
 409  
 410      # Flush to ensure all spans are written (batch processor may be active)
 411      mlflow.flush_trace_async_logging(terminate=True)
 412  
 413      # Gateway trace should exist in the gateway experiment
 414      gateway_traces = TracingClient().search_traces(locations=[gateway_experiment_id])
 415      assert len(gateway_traces) == 1
 416      gateway_trace_id = gateway_traces[0].info.trace_id
 417  
 418      # The gateway trace should be separate from the agent trace
 419      assert gateway_trace_id != agent_trace_id
 420  
 421      # Agent trace should contain two distributed spans (gateway + provider)
 422      agent_trace = mlflow.get_trace(agent_trace_id)
 423      assert agent_trace is not None
 424  
 425      spans_by_name = {s.name: s for s in agent_trace.data.spans}
 426      assert "agent-root" in spans_by_name
 427      assert f"gateway/{ep_config.endpoint_name}" in spans_by_name
 428      assert "provider/openai/gpt-4" in spans_by_name
 429  
 430      # Gateway span: child of agent root, has endpoint attrs + link
 431      gw_span = spans_by_name[f"gateway/{ep_config.endpoint_name}"]
 432      assert gw_span.parent_id == agent_span_id
 433      assert gw_span.attributes.get("endpoint_id") == ep_config.endpoint_id
 434      assert gw_span.attributes.get("endpoint_name") == ep_config.endpoint_name
 435      assert gw_span.attributes.get(SpanAttributeKey.LINKED_GATEWAY_TRACE_ID) == gateway_trace_id
 436  
 437      # Provider span: child of gateway span, has provider attrs
 438      provider_span = spans_by_name["provider/openai/gpt-4"]
 439      assert provider_span.parent_id == gw_span.span_id
 440      assert provider_span.attributes.get(SpanAttributeKey.CHAT_USAGE) == {
 441          "input_tokens": 10,
 442          "output_tokens": 5,
 443          "total_tokens": 15,
 444      }
 445      assert provider_span.attributes.get(SpanAttributeKey.MODEL) == "gpt-4"
 446      assert provider_span.attributes.get(SpanAttributeKey.MODEL_PROVIDER) == "openai"
 447  
 448      # Provider span should preserve timing from the gateway trace
 449      gateway_provider_span = next(
 450          s for s in gateway_traces[0].data.spans if s.name == "provider/openai/gpt-4"
 451      )
 452      assert provider_span.start_time_ns == gateway_provider_span.start_time_ns
 453      assert provider_span.end_time_ns == gateway_provider_span.end_time_ns
 454  
 455      # Neither span should have request/response payloads
 456      assert gw_span.inputs is None
 457      assert gw_span.outputs is None
 458      assert provider_span.inputs is None
 459      assert provider_span.outputs is None
 460  
 461  
 462  @pytest.mark.asyncio
 463  async def test_maybe_traced_gateway_call_streaming_with_traceparent(gateway_experiment_id):
 464      ep_config = GatewayEndpointConfig(
 465          endpoint_id="test-endpoint-id",
 466          endpoint_name="test-endpoint",
 467          experiment_id=gateway_experiment_id,
 468          usage_tracking=True,
 469          models=[],
 470      )
 471  
 472      async def mock_stream_with_usage(payload):
 473          with mlflow.start_span("provider/openai/gpt-4", span_type=SpanType.LLM) as child:
 474              child.set_attributes({
 475                  SpanAttributeKey.CHAT_USAGE: {
 476                      "input_tokens": 20,
 477                      "output_tokens": 10,
 478                      "total_tokens": 30,
 479                  },
 480                  SpanAttributeKey.MODEL: "gpt-4",
 481                  SpanAttributeKey.MODEL_PROVIDER: "openai",
 482              })
 483          yield _make_chunk(content="Hello")
 484          yield _make_chunk(content=" world", finish_reason="stop")
 485  
 486      # Agent creates headers
 487      with mlflow.start_span("agent-root") as agent_span:
 488          headers = get_tracing_context_headers_for_http_request()
 489          agent_trace_id = agent_span.trace_id
 490          agent_span_id = agent_span.span_id
 491  
 492      # Gateway processes request (separate context)
 493      traced = maybe_traced_gateway_call(
 494          mock_stream_with_usage,
 495          ep_config,
 496          output_reducer=aggregate_chat_stream_chunks,
 497          request_headers=headers,
 498      )
 499      chunks = [chunk async for chunk in traced({"input": "test"})]
 500  
 501      assert len(chunks) == 2
 502  
 503      # Flush to ensure all spans are written (batch processor may be active)
 504      mlflow.flush_trace_async_logging(terminate=True)
 505  
 506      # Gateway trace should exist
 507      gateway_traces = TracingClient().search_traces(locations=[gateway_experiment_id])
 508      assert len(gateway_traces) == 1
 509      gateway_trace_id = gateway_traces[0].info.trace_id
 510      assert gateway_trace_id != agent_trace_id
 511  
 512      # Agent trace should contain two distributed spans (gateway + provider)
 513      agent_trace = mlflow.get_trace(agent_trace_id)
 514      assert agent_trace is not None
 515  
 516      spans_by_name = {s.name: s for s in agent_trace.data.spans}
 517      assert "agent-root" in spans_by_name
 518      assert f"gateway/{ep_config.endpoint_name}" in spans_by_name
 519      assert "provider/openai/gpt-4" in spans_by_name
 520  
 521      # Gateway span: child of agent root, has endpoint attrs + link
 522      gw_span = spans_by_name[f"gateway/{ep_config.endpoint_name}"]
 523      assert gw_span.parent_id == agent_span_id
 524      assert gw_span.attributes.get("endpoint_id") == ep_config.endpoint_id
 525      assert gw_span.attributes.get("endpoint_name") == ep_config.endpoint_name
 526      assert gw_span.attributes.get(SpanAttributeKey.LINKED_GATEWAY_TRACE_ID) == gateway_trace_id
 527  
 528      # Provider span: child of gateway span, has provider attrs
 529      provider_span = spans_by_name["provider/openai/gpt-4"]
 530      assert provider_span.parent_id == gw_span.span_id
 531      assert provider_span.attributes.get(SpanAttributeKey.CHAT_USAGE) == {
 532          "input_tokens": 20,
 533          "output_tokens": 10,
 534          "total_tokens": 30,
 535      }
 536      assert provider_span.attributes.get(SpanAttributeKey.MODEL) == "gpt-4"
 537      assert provider_span.attributes.get(SpanAttributeKey.MODEL_PROVIDER) == "openai"
 538  
 539      # Provider span should preserve timing from the gateway trace
 540      gateway_provider_span = next(
 541          s for s in gateway_traces[0].data.spans if s.name == "provider/openai/gpt-4"
 542      )
 543      assert provider_span.start_time_ns == gateway_provider_span.start_time_ns
 544      assert provider_span.end_time_ns == gateway_provider_span.end_time_ns
 545  
 546      # Neither span should have request/response payloads
 547      assert gw_span.inputs is None
 548      assert gw_span.outputs is None
 549      assert provider_span.inputs is None
 550      assert provider_span.outputs is None
 551  
 552  
 553  @pytest.mark.asyncio
 554  async def test_maybe_traced_gateway_call_with_traceparent_multiple_providers(gateway_experiment_id):
 555      ep_config = GatewayEndpointConfig(
 556          endpoint_id="test-endpoint-id",
 557          endpoint_name="test-endpoint",
 558          experiment_id=gateway_experiment_id,
 559          usage_tracking=True,
 560          models=[],
 561      )
 562  
 563      async def func_with_multiple_providers(payload):
 564          with mlflow.start_span("provider/openai/gpt-4", span_type=SpanType.LLM) as child:
 565              child.set_attributes({
 566                  SpanAttributeKey.CHAT_USAGE: {
 567                      "input_tokens": 10,
 568                      "output_tokens": 5,
 569                      "total_tokens": 15,
 570                  },
 571                  SpanAttributeKey.MODEL: "gpt-4",
 572                  SpanAttributeKey.MODEL_PROVIDER: "openai",
 573              })
 574          with mlflow.start_span("provider/anthropic/claude-3", span_type=SpanType.LLM) as child:
 575              child.set_attributes({
 576                  SpanAttributeKey.CHAT_USAGE: {
 577                      "input_tokens": 20,
 578                      "output_tokens": 10,
 579                      "total_tokens": 30,
 580                  },
 581                  SpanAttributeKey.MODEL: "claude-3",
 582                  SpanAttributeKey.MODEL_PROVIDER: "anthropic",
 583              })
 584          return {"result": "ok"}
 585  
 586      with mlflow.start_span("agent-root") as agent_span:
 587          headers = get_tracing_context_headers_for_http_request()
 588          agent_trace_id = agent_span.trace_id
 589  
 590      traced = maybe_traced_gateway_call(
 591          func_with_multiple_providers, ep_config, request_headers=headers
 592      )
 593      await traced({"input": "test"})
 594  
 595      mlflow.flush_trace_async_logging()
 596      agent_trace = mlflow.get_trace(agent_trace_id)
 597      assert agent_trace is not None
 598  
 599      spans_by_name = {s.name: s for s in agent_trace.data.spans}
 600      gw_span = spans_by_name[f"gateway/{ep_config.endpoint_name}"]
 601  
 602      # Both provider spans should be children of the gateway span
 603      provider_openai = spans_by_name["provider/openai/gpt-4"]
 604      assert provider_openai.parent_id == gw_span.span_id
 605      assert provider_openai.attributes.get(SpanAttributeKey.MODEL) == "gpt-4"
 606      assert provider_openai.attributes.get(SpanAttributeKey.MODEL_PROVIDER) == "openai"
 607      assert provider_openai.attributes.get(SpanAttributeKey.CHAT_USAGE) == {
 608          "input_tokens": 10,
 609          "output_tokens": 5,
 610          "total_tokens": 15,
 611      }
 612  
 613      provider_anthropic = spans_by_name["provider/anthropic/claude-3"]
 614      assert provider_anthropic.parent_id == gw_span.span_id
 615      assert provider_anthropic.attributes.get(SpanAttributeKey.MODEL) == "claude-3"
 616      assert provider_anthropic.attributes.get(SpanAttributeKey.MODEL_PROVIDER) == "anthropic"
 617      assert provider_anthropic.attributes.get(SpanAttributeKey.CHAT_USAGE) == {
 618          "input_tokens": 20,
 619          "output_tokens": 10,
 620          "total_tokens": 30,
 621      }
 622  
 623      # Provider spans should preserve timing from the gateway trace
 624      gateway_traces = TracingClient().search_traces(locations=[gateway_experiment_id])
 625      assert len(gateway_traces) == 1
 626      gw_spans_by_name = {s.name: s for s in gateway_traces[0].data.spans}
 627  
 628      gw_openai = gw_spans_by_name["provider/openai/gpt-4"]
 629      assert provider_openai.start_time_ns == gw_openai.start_time_ns
 630      assert provider_openai.end_time_ns == gw_openai.end_time_ns
 631  
 632      gw_anthropic = gw_spans_by_name["provider/anthropic/claude-3"]
 633      assert provider_anthropic.start_time_ns == gw_anthropic.start_time_ns
 634      assert provider_anthropic.end_time_ns == gw_anthropic.end_time_ns
 635  
 636  
 637  # ---------------------------------------------------------------------------
 638  # Tests for aggregate_anthropic_messages_stream_chunks
 639  # ---------------------------------------------------------------------------
 640  
 641  
 642  def _sse(event: dict[str, Any]) -> bytes:
 643      """Encode a single event dict as an SSE data line."""
 644      return f"data: {json.dumps(event)}\n".encode()
 645  
 646  
 647  def _msg_start(msg_id: str, model: str, input_tokens: int | None = None) -> bytes:
 648      usage = {"input_tokens": input_tokens} if input_tokens is not None else {}
 649      return _sse({
 650          "type": "message_start",
 651          "message": {"id": msg_id, "model": model, "role": "assistant", "usage": usage},
 652      })
 653  
 654  
 655  def _text_block_start(index: int) -> bytes:
 656      return _sse({
 657          "type": "content_block_start",
 658          "index": index,
 659          "content_block": {"type": "text", "text": ""},
 660      })
 661  
 662  
 663  def _text_delta(index: int, text: str) -> bytes:
 664      return _sse({
 665          "type": "content_block_delta",
 666          "index": index,
 667          "delta": {"type": "text_delta", "text": text},
 668      })
 669  
 670  
 671  def _tool_block_start(index: int, tool_id: str, name: str) -> bytes:
 672      return _sse({
 673          "type": "content_block_start",
 674          "index": index,
 675          "content_block": {"type": "tool_use", "id": tool_id, "name": name, "input": {}},
 676      })
 677  
 678  
 679  def _tool_delta(index: int, partial_json: str) -> bytes:
 680      return _sse({
 681          "type": "content_block_delta",
 682          "index": index,
 683          "delta": {"type": "input_json_delta", "partial_json": partial_json},
 684      })
 685  
 686  
 687  def _msg_delta(stop_reason: str, output_tokens: int, stop_sequence: str | None = None) -> bytes:
 688      return _sse({
 689          "type": "message_delta",
 690          "delta": {"stop_reason": stop_reason, "stop_sequence": stop_sequence},
 691          "usage": {"output_tokens": output_tokens},
 692      })
 693  
 694  
 695  def test_aggregate_anthropic_messages_stream_chunks_empty():
 696      assert aggregate_anthropic_messages_stream_chunks([]) is None
 697  
 698  
 699  def test_aggregate_anthropic_messages_stream_chunks_no_parseable_events():
 700      chunks = [b"event: ping\n", b"data: [DONE]\n"]
 701      assert aggregate_anthropic_messages_stream_chunks(chunks) is None
 702  
 703  
 704  def test_aggregate_anthropic_messages_stream_chunks_text():
 705      chunks = [
 706          _msg_start("msg_1", "claude-3-5-sonnet-20241022", input_tokens=10),
 707          _text_block_start(0),
 708          _text_delta(0, "Hello"),
 709          _text_delta(0, " world"),
 710          _sse({"type": "content_block_stop", "index": 0}),
 711          _msg_delta("end_turn", output_tokens=5, stop_sequence=None),
 712          _sse({"type": "message_stop"}),
 713      ]
 714      result = aggregate_anthropic_messages_stream_chunks(chunks)
 715  
 716      assert result["id"] == "msg_1"
 717      assert result["type"] == "message"
 718      assert result["role"] == "assistant"
 719      assert result["model"] == "claude-3-5-sonnet-20241022"
 720      assert result["stop_reason"] == "end_turn"
 721      assert result["stop_sequence"] is None
 722      assert result["content"] == [{"type": "text", "text": "Hello world"}]
 723      assert result["usage"] == {"input_tokens": 10, "output_tokens": 5}
 724  
 725  
 726  def test_aggregate_anthropic_messages_stream_chunks_tool_use():
 727      chunks = [
 728          _msg_start("msg_2", "claude-3-5-sonnet-20241022", input_tokens=20),
 729          _tool_block_start(0, "toolu_abc", "get_weather"),
 730          _tool_delta(0, '{"city"'),
 731          _tool_delta(0, ': "Paris"}'),
 732          _sse({"type": "content_block_stop", "index": 0}),
 733          _msg_delta("tool_use", output_tokens=15, stop_sequence=None),
 734      ]
 735      result = aggregate_anthropic_messages_stream_chunks(chunks)
 736  
 737      assert result["stop_reason"] == "tool_use"
 738      assert len(result["content"]) == 1
 739      block = result["content"][0]
 740      assert block["type"] == "tool_use"
 741      assert block["id"] == "toolu_abc"
 742      assert block["name"] == "get_weather"
 743      assert block["input"] == {"city": "Paris"}
 744      assert result["usage"] == {"input_tokens": 20, "output_tokens": 15}
 745  
 746  
 747  def test_aggregate_anthropic_messages_stream_chunks_mixed_content():
 748      chunks = [
 749          _msg_start("msg_3", "claude-3-5-sonnet-20241022", input_tokens=30),
 750          _text_block_start(0),
 751          _text_delta(0, "Let me check that."),
 752          _sse({"type": "content_block_stop", "index": 0}),
 753          _tool_block_start(1, "toolu_xyz", "search"),
 754          _tool_delta(1, '{"q": "mlflow"}'),
 755          _sse({"type": "content_block_stop", "index": 1}),
 756          _msg_delta("tool_use", output_tokens=25, stop_sequence=None),
 757      ]
 758      result = aggregate_anthropic_messages_stream_chunks(chunks)
 759  
 760      assert len(result["content"]) == 2
 761      assert result["content"][0] == {"type": "text", "text": "Let me check that."}
 762      assert result["content"][1] == {
 763          "type": "tool_use",
 764          "id": "toolu_xyz",
 765          "name": "search",
 766          "input": {"q": "mlflow"},
 767      }
 768  
 769  
 770  def test_aggregate_anthropic_messages_stream_chunks_multiple_chunks_per_sse():
 771      # Multiple SSE events packed into one bytes chunk (newline-separated)
 772      combined = (
 773          _msg_start("msg_4", "claude-3-5-sonnet-20241022", input_tokens=5)
 774          + _text_block_start(0)
 775          + _text_delta(0, "Hi")
 776          + _msg_delta("end_turn", output_tokens=2, stop_sequence=None)
 777      )
 778      result = aggregate_anthropic_messages_stream_chunks([combined])
 779  
 780      assert result["id"] == "msg_4"
 781      assert result["content"] == [{"type": "text", "text": "Hi"}]
 782      assert result["usage"] == {"input_tokens": 5, "output_tokens": 2}
 783  
 784  
 785  @pytest.mark.parametrize(
 786      ("raw_json", "expected_input"),
 787      [
 788          ('{"key": "val"}', {"key": "val"}),
 789          ("", {}),
 790          ("not-valid-json", {}),
 791      ],
 792  )
 793  def test_aggregate_anthropic_messages_stream_chunks_tool_input_edge_cases(raw_json, expected_input):
 794      chunks = [
 795          _msg_start("msg_5", "claude-3-5-sonnet-20241022"),
 796          _tool_block_start(0, "t1", "fn"),
 797      ]
 798      if raw_json:
 799          chunks.append(_tool_delta(0, raw_json))
 800      chunks.append(_msg_delta("tool_use", output_tokens=1))
 801  
 802      result = aggregate_anthropic_messages_stream_chunks(chunks)
 803      assert result["content"][0]["input"] == expected_input
 804  
 805  
 806  def test_aggregate_anthropic_messages_stream_chunks_split_sse_lines():
 807      # Simulate an aiohttp byte chunk that splits a "data:" SSE line mid-way.
 808      # All events should still be parsed correctly after concatenation.
 809      msg_start_bytes = _msg_start("msg_split", "claude-3-5-sonnet-20241022", input_tokens=3)
 810      mid = len(msg_start_bytes) // 2
 811      chunks = [
 812          msg_start_bytes[:mid],
 813          msg_start_bytes[mid:] + _msg_delta("end_turn", output_tokens=1),
 814      ]
 815      result = aggregate_anthropic_messages_stream_chunks(chunks)
 816  
 817      assert result is not None
 818      assert result["id"] == "msg_split"
 819      assert result["usage"] == {"input_tokens": 3, "output_tokens": 1}
 820  
 821  
 822  def test_aggregate_anthropic_messages_stream_chunks_cache_tokens():
 823      chunks = [
 824          _sse({
 825              "type": "message_start",
 826              "message": {
 827                  "id": "msg_cache",
 828                  "model": "claude-3-5-sonnet-20241022",
 829                  "role": "assistant",
 830                  "usage": {
 831                      "input_tokens": 10,
 832                      "cache_read_input_tokens": 5,
 833                      "cache_creation_input_tokens": 2,
 834                  },
 835              },
 836          }),
 837          _msg_delta("end_turn", output_tokens=8),
 838      ]
 839      result = aggregate_anthropic_messages_stream_chunks(chunks)
 840  
 841      assert result["usage"] == {
 842          "input_tokens": 10,
 843          "cache_read_input_tokens": 5,
 844          "cache_creation_input_tokens": 2,
 845          "output_tokens": 8,
 846      }
 847  
 848  
 849  # ---------------------------------------------------------------------------
 850  # Tests for aggregate_openai_responses_stream_chunks
 851  # ---------------------------------------------------------------------------
 852  
 853  _RESPONSES_CREATED = (
 854      b'data: {"type":"response.created","response":{"id":"resp_1","object":"response",'
 855      b'"created_at":1741290958,"status":"in_progress","output":[],"usage":null}}\n'
 856  )
 857  _RESPONSES_TEXT_DELTA = (
 858      b'data: {"type":"response.output_text.delta","item_id":"msg_1",'
 859      b'"output_index":0,"content_index":0,"delta":"Hi"}\n'
 860  )
 861  _RESPONSES_TEXT_DONE = (
 862      b'data: {"type":"response.output_text.done","item_id":"msg_1",'
 863      b'"output_index":0,"content_index":0,"text":"Hi there!"}\n'
 864  )
 865  _RESPONSES_COMPLETED = (
 866      b'data: {"type":"response.completed","response":{"id":"resp_1","object":"response",'
 867      b'"created_at":1741290958,"status":"completed",'
 868      b'"output":[{"id":"msg_1","type":"message","status":"completed","role":"assistant",'
 869      b'"content":[{"type":"output_text","text":"Hi there!","annotations":[]}]}],'
 870      b'"usage":{"input_tokens":37,"output_tokens":11,"total_tokens":48}}}\n'
 871  )
 872  
 873  
 874  def test_aggregate_openai_responses_stream_chunks_empty():
 875      assert aggregate_openai_responses_stream_chunks([]) is None
 876  
 877  
 878  def test_aggregate_openai_responses_stream_chunks_no_completed_event():
 879      chunks = [_RESPONSES_CREATED, _RESPONSES_TEXT_DELTA]
 880      assert aggregate_openai_responses_stream_chunks(chunks) is None
 881  
 882  
 883  def test_aggregate_openai_responses_stream_chunks_basic():
 884      chunks = [
 885          _RESPONSES_CREATED,
 886          _RESPONSES_TEXT_DELTA,
 887          _RESPONSES_TEXT_DONE,
 888          _RESPONSES_COMPLETED,
 889      ]
 890      result = aggregate_openai_responses_stream_chunks(chunks)
 891  
 892      assert result["id"] == "resp_1"
 893      assert result["object"] == "response"
 894      assert result["status"] == "completed"
 895      assert len(result["output"]) == 1
 896      assert result["output"][0]["role"] == "assistant"
 897      assert result["output"][0]["content"][0]["text"] == "Hi there!"
 898      assert result["usage"] == {"input_tokens": 37, "output_tokens": 11, "total_tokens": 48}
 899  
 900  
 901  def test_aggregate_openai_responses_stream_chunks_split_sse_lines():
 902      # Simulate aiohttp yielding a chunk that splits the data: line mid-way.
 903      mid = len(_RESPONSES_COMPLETED) // 2
 904      chunks = [
 905          _RESPONSES_CREATED,
 906          _RESPONSES_COMPLETED[:mid],
 907          _RESPONSES_COMPLETED[mid:],
 908      ]
 909      result = aggregate_openai_responses_stream_chunks(chunks)
 910  
 911      assert result is not None
 912      assert result["id"] == "resp_1"
 913      assert result["status"] == "completed"
 914  
 915  
 916  def test_aggregate_openai_responses_stream_chunks_returns_completed_response():
 917      # When multiple events are packed into a single bytes chunk, the
 918      # completed response is still extracted correctly.
 919      combined = _RESPONSES_CREATED + _RESPONSES_TEXT_DELTA + _RESPONSES_COMPLETED
 920      result = aggregate_openai_responses_stream_chunks([combined])
 921  
 922      assert result["status"] == "completed"
 923      assert result["usage"]["total_tokens"] == 48
 924  
 925  
 926  # ---------------------------------------------------------------------------
 927  # Tests for aggregate_gemini_stream_generate_content_chunks
 928  # ---------------------------------------------------------------------------
 929  
 930  
 931  def _gemini_sse(event: dict[str, Any]) -> bytes:
 932      return f"data: {json.dumps(event)}\n".encode()
 933  
 934  
 935  def _gemini_text_chunk(text: str, finish_reason: str | None = None) -> bytes:
 936      candidate: dict[str, Any] = {"content": {"parts": [{"text": text}], "role": "model"}}
 937      if finish_reason:
 938          candidate["finishReason"] = finish_reason
 939      return _gemini_sse({"candidates": [candidate]})
 940  
 941  
 942  def test_aggregate_gemini_stream_chunks_empty():
 943      assert aggregate_gemini_stream_generate_content_chunks([]) is None
 944  
 945  
 946  def test_aggregate_gemini_stream_chunks_no_parseable_events():
 947      chunks = [b"event: ping\n", b"data: [DONE]\n"]
 948      assert aggregate_gemini_stream_generate_content_chunks(chunks) is None
 949  
 950  
 951  def test_aggregate_gemini_stream_chunks_text():
 952      chunks = [
 953          _gemini_text_chunk("Hello"),
 954          _gemini_text_chunk(" world", finish_reason="STOP"),
 955          _gemini_sse({
 956              "usageMetadata": {
 957                  "promptTokenCount": 10,
 958                  "candidatesTokenCount": 5,
 959                  "totalTokenCount": 15,
 960              }
 961          }),
 962      ]
 963      result = aggregate_gemini_stream_generate_content_chunks(chunks)
 964  
 965      assert len(result["candidates"]) == 1
 966      cand = result["candidates"][0]
 967      assert cand["content"]["parts"] == [{"text": "Hello world"}]
 968      assert cand["content"]["role"] == "model"
 969      assert cand["finishReason"] == "STOP"
 970      assert result["usageMetadata"] == {
 971          "promptTokenCount": 10,
 972          "candidatesTokenCount": 5,
 973          "totalTokenCount": 15,
 974      }
 975  
 976  
 977  def test_aggregate_gemini_stream_chunks_tool_call():
 978      chunks = [
 979          _gemini_sse({
 980              "candidates": [
 981                  {
 982                      "content": {
 983                          "parts": [
 984                              {"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}}
 985                          ],
 986                          "role": "model",
 987                      },
 988                      "finishReason": "STOP",
 989                      "index": 0,
 990                  }
 991              ]
 992          }),
 993          _gemini_sse({
 994              "usageMetadata": {
 995                  "promptTokenCount": 8,
 996                  "candidatesTokenCount": 12,
 997                  "totalTokenCount": 20,
 998              }
 999          }),
1000      ]
1001      result = aggregate_gemini_stream_generate_content_chunks(chunks)
1002  
1003      cand = result["candidates"][0]
1004      assert cand["content"]["parts"] == [
1005          {"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}}
1006      ]
1007      assert cand["finishReason"] == "STOP"
1008  
1009  
1010  def test_aggregate_gemini_stream_chunks_split_sse_lines():
1011      chunk_bytes = _gemini_text_chunk("Hi", finish_reason="STOP")
1012      mid = len(chunk_bytes) // 2
1013      result = aggregate_gemini_stream_generate_content_chunks([chunk_bytes[:mid], chunk_bytes[mid:]])
1014  
1015      assert result is not None
1016      assert result["candidates"][0]["content"]["parts"] == [{"text": "Hi"}]
1017  
1018  
1019  @pytest.mark.parametrize(
1020      ("finish_reasons", "expected"),
1021      [
1022          ([None, None, "STOP"], "STOP"),
1023          ([None, "stop", None], "stop"),
1024          ([None, None, None], None),
1025      ],
1026  )
1027  def test_aggregate_gemini_stream_chunks_finish_reason(finish_reasons, expected):
1028      chunks = [_gemini_text_chunk(f"t{i}", finish_reason=fr) for i, fr in enumerate(finish_reasons)]
1029      result = aggregate_gemini_stream_generate_content_chunks(chunks)
1030      assert result["candidates"][0]["finishReason"] == expected