test_tracing_utils.py
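"""Tests for the stream-chunk aggregation and gateway tracing helpers in mlflow.gateway.tracing_utils."""
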
import json
from typing import Any

import pytest

import mlflow
from mlflow.entities import SpanType
from mlflow.gateway.schemas.chat import StreamResponsePayload
from mlflow.gateway.tracing_utils import (
    _get_model_span_info,
    aggregate_anthropic_messages_stream_chunks,
    aggregate_chat_stream_chunks,
    aggregate_gemini_stream_generate_content_chunks,
    aggregate_openai_responses_stream_chunks,
    maybe_traced_gateway_call,
)
from mlflow.store.tracking.gateway.entities import GatewayEndpointConfig
from mlflow.tracing.client import TracingClient
from mlflow.tracing.constant import SpanAttributeKey, TraceMetadataKey
from mlflow.tracing.distributed import get_tracing_context_headers_for_http_request
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.types.chat import ChatChoiceDelta, ChatChunkChoice, ChatUsage, Function, ToolCallDelta


@pytest.fixture
def gateway_experiment_id():
    experiment_name = "gateway-test-endpoint"
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if experiment is not None:
        return experiment.experiment_id
    return mlflow.create_experiment(experiment_name)


def get_traces():
    return TracingClient().search_traces(locations=[_get_experiment_id()])


@pytest.fixture
def endpoint_config():
    return GatewayEndpointConfig(
        endpoint_id="test-endpoint-id",
        endpoint_name="test-endpoint",
        experiment_id=_get_experiment_id(),
        usage_tracking=True,
        models=[],
    )


@pytest.fixture
def endpoint_config_no_experiment():
    return GatewayEndpointConfig(
        endpoint_id="test-endpoint-id",
        endpoint_name="test-endpoint",
        experiment_id=None,
        models=[],
    )


async def mock_async_func(payload):
    return {"result": "success", "payload": payload}


def _make_chunk(
    content=None,
    finish_reason=None,
    id="chunk-1",
    model="test-model",
    created=1700000000,
    usage=None,
    tool_calls=None,
    role="assistant",
    choice_index=0,
):
    delta = ChatChoiceDelta(role=role, content=content, tool_calls=tool_calls)
    choice = ChatChunkChoice(index=choice_index, finish_reason=finish_reason, delta=delta)
    return StreamResponsePayload(
        id=id,
        created=created,
        model=model,
        choices=[choice],
        usage=usage,
    )


def test_aggregate_chat_stream_chunks_aggregates_content():
    chunks = [
        _make_chunk(content="Hello"),
        _make_chunk(content=" "),
        _make_chunk(content="world"),
        _make_chunk(content=None, finish_reason="stop"),
    ]
    result = aggregate_chat_stream_chunks(chunks)

    assert result["object"] == "chat.completion"
    assert result["model"] == "test-model"
    assert result["choices"][0]["message"]["role"] == "assistant"
    assert result["choices"][0]["message"]["content"] == "Hello world"
    assert result["choices"][0]["finish_reason"] == "stop"


def test_aggregate_chat_stream_chunks_with_usage():
    usage = ChatUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
    chunks = [
        _make_chunk(content="Hi"),
        _make_chunk(content=None, finish_reason="stop", usage=usage),
    ]
    result = aggregate_chat_stream_chunks(chunks)

    assert result["choices"][0]["message"]["content"] == "Hi"
    assert result["usage"] == {
        "prompt_tokens": 10,
        "completion_tokens": 5,
        "total_tokens": 15,
    }


def test_aggregate_chat_stream_chunks_empty():
    assert aggregate_chat_stream_chunks([]) is None

def test_aggregate_chat_stream_chunks_defaults_finish_reason():
    chunks = [_make_chunk(content="Hi")]
    result = aggregate_chat_stream_chunks(chunks)

    assert result["choices"][0]["finish_reason"] == "stop"


def test_reduce_chat_stream_chunks_aggregates_tool_calls():
    chunks = [
        # First chunk: tool call id, type, and function name
        _make_chunk(
            tool_calls=[
                ToolCallDelta(
                    index=0,
                    id="call_abc",
                    type="function",
                    function=Function(name="get_weather", arguments=""),
                ),
            ],
        ),
        # Subsequent chunks: argument fragments
        _make_chunk(
            tool_calls=[
                ToolCallDelta(index=0, function=Function(arguments='{"loc')),
            ],
        ),
        _make_chunk(
            tool_calls=[
                ToolCallDelta(index=0, function=Function(arguments='ation": "SF"}')),
            ],
        ),
        _make_chunk(finish_reason="tool_calls"),
    ]
    result = aggregate_chat_stream_chunks(chunks)

    assert result["choices"][0]["message"]["content"] is None
    assert result["choices"][0]["finish_reason"] == "tool_calls"

    tool_calls = result["choices"][0]["message"]["tool_calls"]
    assert len(tool_calls) == 1
    assert tool_calls[0]["id"] == "call_abc"
    assert tool_calls[0]["type"] == "function"
    assert tool_calls[0]["function"]["name"] == "get_weather"
    assert tool_calls[0]["function"]["arguments"] == '{"location": "SF"}'


def test_reduce_chat_stream_chunks_derives_role_from_delta():
    chunks = [
        _make_chunk(role="developer", content="Hello"),
        _make_chunk(role=None, content=" world"),
        _make_chunk(role=None, finish_reason="stop"),
    ]
    result = aggregate_chat_stream_chunks(chunks)

    assert result["choices"][0]["message"]["role"] == "developer"


def test_reduce_chat_stream_chunks_multiple_choice_indices():
    chunks = [
        _make_chunk(content="Hi", choice_index=0),
        _make_chunk(content="Hey", choice_index=1),
        _make_chunk(content=" there", choice_index=0),
        _make_chunk(content=" you", choice_index=1),
        _make_chunk(finish_reason="stop", choice_index=0),
        _make_chunk(finish_reason="stop", choice_index=1),
    ]
    result = aggregate_chat_stream_chunks(chunks)

    assert len(result["choices"]) == 2
    assert result["choices"][0]["index"] == 0
    assert result["choices"][0]["message"]["content"] == "Hi there"
    assert result["choices"][1]["index"] == 1
    assert result["choices"][1]["message"]["content"] == "Hey you"


@pytest.mark.asyncio
async def test_maybe_traced_gateway_call_basic(endpoint_config):
    traced_func = maybe_traced_gateway_call(mock_async_func, endpoint_config)
    result = await traced_func({"input": "test"})

    assert result == {"result": "success", "payload": {"input": "test"}}

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]

    # Find the gateway span
    span_name_to_span = {span.name: span for span in trace.data.spans}
    assert f"gateway/{endpoint_config.endpoint_name}" in span_name_to_span

    gateway_span = span_name_to_span[f"gateway/{endpoint_config.endpoint_name}"]
    assert gateway_span.attributes.get("endpoint_id") == "test-endpoint-id"
    assert gateway_span.attributes.get("endpoint_name") == "test-endpoint"
    # Input should be unwrapped (not nested under "payload" key)
    assert gateway_span.inputs == {"input": "test"}
    # No user metadata should be present in trace
    assert trace.info.request_metadata.get(TraceMetadataKey.AUTH_USERNAME) is None
    assert trace.info.request_metadata.get(TraceMetadataKey.AUTH_USER_ID) is None


@pytest.mark.asyncio
async def test_maybe_traced_gateway_call_with_user_metadata(endpoint_config):
    traced_func = maybe_traced_gateway_call(
        mock_async_func,
        endpoint_config,
        metadata={
            TraceMetadataKey.AUTH_USERNAME: "alice",
            TraceMetadataKey.AUTH_USER_ID: "123",
        },
    )
    result = await traced_func({"input": "test"})

    assert result == {"result": "success", "payload": {"input": "test"}}

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]

    span_name_to_span = {span.name: span for span in trace.data.spans}
    gateway_span = span_name_to_span[f"gateway/{endpoint_config.endpoint_name}"]

    assert gateway_span.attributes.get("endpoint_id") == "test-endpoint-id"
    assert gateway_span.attributes.get("endpoint_name") == "test-endpoint"
    # Input should be unwrapped (not nested under "payload" key)
    assert gateway_span.inputs == {"input": "test"}
    # User metadata should be in trace info, not span attributes
    assert trace.info.request_metadata.get(TraceMetadataKey.AUTH_USERNAME) == "alice"
    assert trace.info.request_metadata.get(TraceMetadataKey.AUTH_USER_ID) == "123"


@pytest.mark.asyncio
async def test_maybe_traced_gateway_call_without_usage_tracking(endpoint_config_no_experiment):
    traced_func = maybe_traced_gateway_call(
        mock_async_func,
        endpoint_config_no_experiment,
        metadata={
            TraceMetadataKey.AUTH_USERNAME: "alice",
            TraceMetadataKey.AUTH_USER_ID: "123",
        },
    )

    # When usage_tracking is False, maybe_traced_gateway_call returns the original function
    assert traced_func is mock_async_func

    result = await traced_func({"input": "test"})
    assert result == {"result": "success", "payload": {"input": "test"}}

    # No traces should be created
    traces = get_traces()
    assert len(traces) == 0


@pytest.mark.asyncio
async def test_maybe_traced_gateway_call_with_output_reducer(endpoint_config):
    async def mock_async_stream(payload):
        yield _make_chunk(content="Hello")
        yield _make_chunk(content=" world")
        yield _make_chunk(
            content=None,
            finish_reason="stop",
            usage=ChatUsage(prompt_tokens=5, completion_tokens=2, total_tokens=7),
        )

    traced_func = maybe_traced_gateway_call(
        mock_async_stream,
        endpoint_config,
        output_reducer=aggregate_chat_stream_chunks,
    )

    # Consume the stream
    chunks = [
        chunk async for chunk in traced_func({"messages": [{"role": "user", "content": "hi"}]})
    ]
    assert len(chunks) == 3

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]

    span_name_to_span = {span.name: span for span in trace.data.spans}
    gateway_span = span_name_to_span[f"gateway/{endpoint_config.endpoint_name}"]

    # Input should be unwrapped (not nested under "payload" key)
    assert gateway_span.inputs == {"messages": [{"role": "user", "content": "hi"}]}

    # The output should be the reduced aggregated response, not raw chunks
    output = gateway_span.outputs
    assert output["object"] == "chat.completion"
    assert output["choices"][0]["message"]["content"] == "Hello world"
    assert output["choices"][0]["finish_reason"] == "stop"
    assert output["usage"]["total_tokens"] == 7


@pytest.mark.asyncio
async def test_maybe_traced_gateway_call_with_payload_kwarg(endpoint_config):
    async def mock_passthrough_func(action, payload, headers=None):
        return {"result": "success", "action": action, "payload": payload}

    traced_func = maybe_traced_gateway_call(mock_passthrough_func, endpoint_config)
    result = await traced_func(
        action="test_action", payload={"messages": [{"role": "user", "content": "hi"}]}, headers={}
    )

    assert result["result"] == "success"

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]

    span_name_to_span = {span.name: span for span in trace.data.spans}
    gateway_span = span_name_to_span[f"gateway/{endpoint_config.endpoint_name}"]

    # Input should be unwrapped to just the payload dict
    assert gateway_span.inputs == {"messages": [{"role": "user", "content": "hi"}]}


# ---------------------------------------------------------------------------
# Tests for distributed tracing helpers
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_get_model_span_info_reads_child_span(endpoint_config):
    async def func_with_child_span(payload):
        with mlflow.start_span("provider/openai/gpt-4", span_type=SpanType.LLM) as child:
            child.set_attributes({
                SpanAttributeKey.CHAT_USAGE: {
                    "input_tokens": 10,
                    "output_tokens": 5,
                    "total_tokens": 15,
                },
                SpanAttributeKey.MODEL: "gpt-4",
                SpanAttributeKey.MODEL_PROVIDER: "openai",
            })
        return {"result": "ok"}

    traced = maybe_traced_gateway_call(func_with_child_span, endpoint_config)
    await traced({"input": "test"})

    traces = get_traces()
    assert len(traces) == 1
    gateway_trace_id = traces[0].info.trace_id

    # After the trace is exported, spans are removed from InMemoryTraceManager,
    # so we expect empty here. The actual reading happens inside the wrapper
    # while the trace is still in memory.
    assert _get_model_span_info(gateway_trace_id) == []


# ---------------------------------------------------------------------------
# Integration tests for distributed tracing via traceparent
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_maybe_traced_gateway_call_with_traceparent(gateway_experiment_id):
    ep_config = GatewayEndpointConfig(
        endpoint_id="test-endpoint-id",
        endpoint_name="test-endpoint",
        experiment_id=gateway_experiment_id,
        usage_tracking=True,
        models=[],
    )

    async def func_with_usage(payload):
        with mlflow.start_span("provider/openai/gpt-4", span_type=SpanType.LLM) as child:
            child.set_attributes({
                SpanAttributeKey.CHAT_USAGE: {
                    "input_tokens": 10,
                    "output_tokens": 5,
                    "total_tokens": 15,
                },
                SpanAttributeKey.MODEL: "gpt-4",
                SpanAttributeKey.MODEL_PROVIDER: "openai",
            })
        return {"result": "ok"}

    # Step 1: Agent creates span and generates traceparent headers
    with mlflow.start_span("agent-root") as agent_span:
        headers = get_tracing_context_headers_for_http_request()
        agent_trace_id = agent_span.trace_id
        agent_span_id = agent_span.span_id

    # Step 2: Gateway processes request (no active agent span, simulating separate server)
    traced = maybe_traced_gateway_call(func_with_usage, ep_config, request_headers=headers)
    result = await traced({"input": "test"})

    assert result == {"result": "ok"}

    # Flush to ensure all spans are written (batch processor may be active)
    mlflow.flush_trace_async_logging(terminate=True)

    # Gateway trace should exist in the gateway experiment
    gateway_traces = TracingClient().search_traces(locations=[gateway_experiment_id])
    assert len(gateway_traces) == 1
    gateway_trace_id = gateway_traces[0].info.trace_id

    # The gateway trace should be separate from the agent trace
    assert gateway_trace_id != agent_trace_id

    # Agent trace should contain two distributed spans (gateway + provider)
    agent_trace = mlflow.get_trace(agent_trace_id)
    assert agent_trace is not None

    spans_by_name = {s.name: s for s in agent_trace.data.spans}
    assert "agent-root" in spans_by_name
    assert f"gateway/{ep_config.endpoint_name}" in spans_by_name
    assert "provider/openai/gpt-4" in spans_by_name

    # Gateway span: child of agent root, has endpoint attrs + link
    gw_span = spans_by_name[f"gateway/{ep_config.endpoint_name}"]
    assert gw_span.parent_id == agent_span_id
    assert gw_span.attributes.get("endpoint_id") == ep_config.endpoint_id
    assert gw_span.attributes.get("endpoint_name") == ep_config.endpoint_name
    assert gw_span.attributes.get(SpanAttributeKey.LINKED_GATEWAY_TRACE_ID) == gateway_trace_id

    # Provider span: child of gateway span, has provider attrs
    provider_span = spans_by_name["provider/openai/gpt-4"]
    assert provider_span.parent_id == gw_span.span_id
    assert provider_span.attributes.get(SpanAttributeKey.CHAT_USAGE) == {
        "input_tokens": 10,
        "output_tokens": 5,
        "total_tokens": 15,
    }
    assert provider_span.attributes.get(SpanAttributeKey.MODEL) == "gpt-4"
    assert provider_span.attributes.get(SpanAttributeKey.MODEL_PROVIDER) == "openai"

    # Provider span should preserve timing from the gateway trace
    gateway_provider_span = next(
        s for s in gateway_traces[0].data.spans if s.name == "provider/openai/gpt-4"
    )
    assert provider_span.start_time_ns == gateway_provider_span.start_time_ns
    assert provider_span.end_time_ns == gateway_provider_span.end_time_ns

    # Neither span should have request/response payloads
    assert gw_span.inputs is None
    assert gw_span.outputs is None
    assert provider_span.inputs is None
    assert provider_span.outputs is None


@pytest.mark.asyncio
async def test_maybe_traced_gateway_call_streaming_with_traceparent(gateway_experiment_id):
    ep_config = GatewayEndpointConfig(
        endpoint_id="test-endpoint-id",
        endpoint_name="test-endpoint",
        experiment_id=gateway_experiment_id,
        usage_tracking=True,
        models=[],
    )

    async def mock_stream_with_usage(payload):
        with mlflow.start_span("provider/openai/gpt-4", span_type=SpanType.LLM) as child:
            child.set_attributes({
                SpanAttributeKey.CHAT_USAGE: {
                    "input_tokens": 20,
                    "output_tokens": 10,
                    "total_tokens": 30,
                },
                SpanAttributeKey.MODEL: "gpt-4",
                SpanAttributeKey.MODEL_PROVIDER: "openai",
            })
        yield _make_chunk(content="Hello")
        yield _make_chunk(content=" world", finish_reason="stop")

    # Agent creates headers
    with mlflow.start_span("agent-root") as agent_span:
        headers = get_tracing_context_headers_for_http_request()
        agent_trace_id = agent_span.trace_id
        agent_span_id = agent_span.span_id

    # Gateway processes request (separate context)
    traced = maybe_traced_gateway_call(
        mock_stream_with_usage,
        ep_config,
        output_reducer=aggregate_chat_stream_chunks,
        request_headers=headers,
    )
    chunks = [chunk async for chunk in traced({"input": "test"})]

    assert len(chunks) == 2

    # Flush to ensure all spans are written (batch processor may be active)
    mlflow.flush_trace_async_logging(terminate=True)

    # Gateway trace should exist
    gateway_traces = TracingClient().search_traces(locations=[gateway_experiment_id])
    assert len(gateway_traces) == 1
    gateway_trace_id = gateway_traces[0].info.trace_id
    assert gateway_trace_id != agent_trace_id

    # Agent trace should contain two distributed spans (gateway + provider)
    agent_trace = mlflow.get_trace(agent_trace_id)
    assert agent_trace is not None

    spans_by_name = {s.name: s for s in agent_trace.data.spans}
    assert "agent-root" in spans_by_name
    assert f"gateway/{ep_config.endpoint_name}" in spans_by_name
    assert "provider/openai/gpt-4" in spans_by_name

    # Gateway span: child of agent root, has endpoint attrs + link
    gw_span = spans_by_name[f"gateway/{ep_config.endpoint_name}"]
    assert gw_span.parent_id == agent_span_id
    assert gw_span.attributes.get("endpoint_id") == ep_config.endpoint_id
    assert gw_span.attributes.get("endpoint_name") == ep_config.endpoint_name
    assert gw_span.attributes.get(SpanAttributeKey.LINKED_GATEWAY_TRACE_ID) == gateway_trace_id

    # Provider span: child of gateway span, has provider attrs
    provider_span = spans_by_name["provider/openai/gpt-4"]
    assert provider_span.parent_id == gw_span.span_id
    assert provider_span.attributes.get(SpanAttributeKey.CHAT_USAGE) == {
        "input_tokens": 20,
        "output_tokens": 10,
        "total_tokens": 30,
    }
    assert provider_span.attributes.get(SpanAttributeKey.MODEL) == "gpt-4"
    assert provider_span.attributes.get(SpanAttributeKey.MODEL_PROVIDER) == "openai"

    # Provider span should preserve timing from the gateway trace
    gateway_provider_span = next(
        s for s in gateway_traces[0].data.spans if s.name == "provider/openai/gpt-4"
    )
    assert provider_span.start_time_ns == gateway_provider_span.start_time_ns
    assert provider_span.end_time_ns == gateway_provider_span.end_time_ns

    # Neither span should have request/response payloads
    assert gw_span.inputs is None
    assert gw_span.outputs is None
    assert provider_span.inputs is None
    assert provider_span.outputs is None


@pytest.mark.asyncio
async def test_maybe_traced_gateway_call_with_traceparent_multiple_providers(gateway_experiment_id):
    ep_config = GatewayEndpointConfig(
        endpoint_id="test-endpoint-id",
        endpoint_name="test-endpoint",
        experiment_id=gateway_experiment_id,
        usage_tracking=True,
        models=[],
    )

    async def func_with_multiple_providers(payload):
        with mlflow.start_span("provider/openai/gpt-4", span_type=SpanType.LLM) as child:
            child.set_attributes({
                SpanAttributeKey.CHAT_USAGE: {
                    "input_tokens": 10,
                    "output_tokens": 5,
                    "total_tokens": 15,
                },
                SpanAttributeKey.MODEL: "gpt-4",
                SpanAttributeKey.MODEL_PROVIDER: "openai",
            })
        with mlflow.start_span("provider/anthropic/claude-3", span_type=SpanType.LLM) as child:
            child.set_attributes({
                SpanAttributeKey.CHAT_USAGE: {
                    "input_tokens": 20,
                    "output_tokens": 10,
                    "total_tokens": 30,
                },
                SpanAttributeKey.MODEL: "claude-3",
                SpanAttributeKey.MODEL_PROVIDER: "anthropic",
            })
        return {"result": "ok"}

    with mlflow.start_span("agent-root") as agent_span:
        headers = get_tracing_context_headers_for_http_request()
        agent_trace_id = agent_span.trace_id

    traced = maybe_traced_gateway_call(
        func_with_multiple_providers, ep_config, request_headers=headers
    )
    await traced({"input": "test"})

    mlflow.flush_trace_async_logging()
    agent_trace = mlflow.get_trace(agent_trace_id)
    assert agent_trace is not None

    spans_by_name = {s.name: s for s in agent_trace.data.spans}
    gw_span = spans_by_name[f"gateway/{ep_config.endpoint_name}"]

    # Both provider spans should be children of the gateway span
    provider_openai = spans_by_name["provider/openai/gpt-4"]
    assert provider_openai.parent_id == gw_span.span_id
    assert provider_openai.attributes.get(SpanAttributeKey.MODEL) == "gpt-4"
    assert provider_openai.attributes.get(SpanAttributeKey.MODEL_PROVIDER) == "openai"
    assert provider_openai.attributes.get(SpanAttributeKey.CHAT_USAGE) == {
        "input_tokens": 10,
        "output_tokens": 5,
        "total_tokens": 15,
    }

    provider_anthropic = spans_by_name["provider/anthropic/claude-3"]
    assert provider_anthropic.parent_id == gw_span.span_id
    assert provider_anthropic.attributes.get(SpanAttributeKey.MODEL) == "claude-3"
    assert provider_anthropic.attributes.get(SpanAttributeKey.MODEL_PROVIDER) == "anthropic"
    assert provider_anthropic.attributes.get(SpanAttributeKey.CHAT_USAGE) == {
        "input_tokens": 20,
        "output_tokens": 10,
        "total_tokens": 30,
    }

    # Provider spans should preserve timing from the gateway trace
    gateway_traces = TracingClient().search_traces(locations=[gateway_experiment_id])
    assert len(gateway_traces) == 1
    gw_spans_by_name = {s.name: s for s in gateway_traces[0].data.spans}

    gw_openai = gw_spans_by_name["provider/openai/gpt-4"]
    assert provider_openai.start_time_ns == gw_openai.start_time_ns
    assert provider_openai.end_time_ns == gw_openai.end_time_ns

    gw_anthropic = gw_spans_by_name["provider/anthropic/claude-3"]
    assert provider_anthropic.start_time_ns == gw_anthropic.start_time_ns
    assert provider_anthropic.end_time_ns == gw_anthropic.end_time_ns


# ---------------------------------------------------------------------------
# Tests for aggregate_anthropic_messages_stream_chunks
# ---------------------------------------------------------------------------


def _sse(event: dict[str, Any]) -> bytes:
    """Encode a single event dict as an SSE data line."""
    return f"data: {json.dumps(event)}\n".encode()


def _msg_start(msg_id: str, model: str, input_tokens: int | None = None) -> bytes:
    usage = {"input_tokens": input_tokens} if input_tokens is not None else {}
    return _sse({
        "type": "message_start",
        "message": {"id": msg_id, "model": model, "role": "assistant", "usage": usage},
    })


def _text_block_start(index: int) -> bytes:
    return _sse({
        "type": "content_block_start",
        "index": index,
        "content_block": {"type": "text", "text": ""},
    })


def _text_delta(index: int, text: str) -> bytes:
    return _sse({
        "type": "content_block_delta",
        "index": index,
        "delta": {"type": "text_delta", "text": text},
    })


def _tool_block_start(index: int, tool_id: str, name: str) -> bytes:
    return _sse({
        "type": "content_block_start",
        "index": index,
        "content_block": {"type": "tool_use", "id": tool_id, "name": name, "input": {}},
    })


def _tool_delta(index: int, partial_json: str) -> bytes:
    return _sse({
        "type": "content_block_delta",
        "index": index,
        "delta": {"type": "input_json_delta", "partial_json": partial_json},
    })


def _msg_delta(stop_reason: str, output_tokens: int, stop_sequence: str | None = None) -> bytes:
    return _sse({
        "type": "message_delta",
        "delta": {"stop_reason": stop_reason, "stop_sequence": stop_sequence},
        "usage": {"output_tokens": output_tokens},
    })


def test_aggregate_anthropic_messages_stream_chunks_empty():
    assert aggregate_anthropic_messages_stream_chunks([]) is None


def test_aggregate_anthropic_messages_stream_chunks_no_parseable_events():
    chunks = [b"event: ping\n", b"data: [DONE]\n"]
    assert aggregate_anthropic_messages_stream_chunks(chunks) is None


def test_aggregate_anthropic_messages_stream_chunks_text():
    chunks = [
        _msg_start("msg_1", "claude-3-5-sonnet-20241022", input_tokens=10),
        _text_block_start(0),
        _text_delta(0, "Hello"),
        _text_delta(0, " world"),
        _sse({"type": "content_block_stop", "index": 0}),
        _msg_delta("end_turn", output_tokens=5, stop_sequence=None),
        _sse({"type": "message_stop"}),
    ]
    result = aggregate_anthropic_messages_stream_chunks(chunks)

    assert result["id"] == "msg_1"
    assert result["type"] == "message"
    assert result["role"] == "assistant"
    assert result["model"] == "claude-3-5-sonnet-20241022"
    assert result["stop_reason"] == "end_turn"
    assert result["stop_sequence"] is None
    assert result["content"] == [{"type": "text", "text": "Hello world"}]
    assert result["usage"] == {"input_tokens": 10, "output_tokens": 5}


def test_aggregate_anthropic_messages_stream_chunks_tool_use():
    chunks = [
        _msg_start("msg_2", "claude-3-5-sonnet-20241022", input_tokens=20),
        _tool_block_start(0, "toolu_abc", "get_weather"),
        _tool_delta(0, '{"city"'),
        _tool_delta(0, ': "Paris"}'),
        _sse({"type": "content_block_stop", "index": 0}),
        _msg_delta("tool_use", output_tokens=15, stop_sequence=None),
    ]
    result = aggregate_anthropic_messages_stream_chunks(chunks)

    assert result["stop_reason"] == "tool_use"
    assert len(result["content"]) == 1
    block = result["content"][0]
    assert block["type"] == "tool_use"
    assert block["id"] == "toolu_abc"
    assert block["name"] == "get_weather"
    assert block["input"] == {"city": "Paris"}
    assert result["usage"] == {"input_tokens": 20, "output_tokens": 15}


def test_aggregate_anthropic_messages_stream_chunks_mixed_content():
    chunks = [
        _msg_start("msg_3", "claude-3-5-sonnet-20241022", input_tokens=30),
        _text_block_start(0),
        _text_delta(0, "Let me check that."),
        _sse({"type": "content_block_stop", "index": 0}),
        _tool_block_start(1, "toolu_xyz", "search"),
        _tool_delta(1, '{"q": "mlflow"}'),
        _sse({"type": "content_block_stop", "index": 1}),
        _msg_delta("tool_use", output_tokens=25, stop_sequence=None),
    ]
    result = aggregate_anthropic_messages_stream_chunks(chunks)

    assert len(result["content"]) == 2
    assert result["content"][0] == {"type": "text", "text": "Let me check that."}
    assert result["content"][1] == {
        "type": "tool_use",
        "id": "toolu_xyz",
        "name": "search",
        "input": {"q": "mlflow"},
    }


def test_aggregate_anthropic_messages_stream_chunks_multiple_chunks_per_sse():
    # Multiple SSE events packed into one bytes chunk (newline-separated)
    combined = (
        _msg_start("msg_4", "claude-3-5-sonnet-20241022", input_tokens=5)
        + _text_block_start(0)
        + _text_delta(0, "Hi")
        + _msg_delta("end_turn", output_tokens=2, stop_sequence=None)
    )
    result = aggregate_anthropic_messages_stream_chunks([combined])

    assert result["id"] == "msg_4"
    assert result["content"] == [{"type": "text", "text": "Hi"}]
    assert result["usage"] == {"input_tokens": 5, "output_tokens": 2}


@pytest.mark.parametrize(
    ("raw_json", "expected_input"),
    [
        ('{"key": "val"}', {"key": "val"}),
        ("", {}),
        ("not-valid-json", {}),
    ],
)
def test_aggregate_anthropic_messages_stream_chunks_tool_input_edge_cases(raw_json, expected_input):
    chunks = [
        _msg_start("msg_5", "claude-3-5-sonnet-20241022"),
        _tool_block_start(0, "t1", "fn"),
    ]
    if raw_json:
        chunks.append(_tool_delta(0, raw_json))
    chunks.append(_msg_delta("tool_use", output_tokens=1))

    result = aggregate_anthropic_messages_stream_chunks(chunks)
    assert result["content"][0]["input"] == expected_input


def test_aggregate_anthropic_messages_stream_chunks_split_sse_lines():
    # Simulate an aiohttp byte chunk that splits a "data:" SSE line mid-way.
    # All events should still be parsed correctly after concatenation.
    msg_start_bytes = _msg_start("msg_split", "claude-3-5-sonnet-20241022", input_tokens=3)
    mid = len(msg_start_bytes) // 2
    chunks = [
        msg_start_bytes[:mid],
        msg_start_bytes[mid:] + _msg_delta("end_turn", output_tokens=1),
    ]
    result = aggregate_anthropic_messages_stream_chunks(chunks)

    assert result is not None
    assert result["id"] == "msg_split"
    assert result["usage"] == {"input_tokens": 3, "output_tokens": 1}


def test_aggregate_anthropic_messages_stream_chunks_cache_tokens():
    chunks = [
        _sse({
            "type": "message_start",
            "message": {
                "id": "msg_cache",
                "model": "claude-3-5-sonnet-20241022",
                "role": "assistant",
                "usage": {
                    "input_tokens": 10,
                    "cache_read_input_tokens": 5,
                    "cache_creation_input_tokens": 2,
                },
            },
        }),
        _msg_delta("end_turn", output_tokens=8),
    ]
    result = aggregate_anthropic_messages_stream_chunks(chunks)

    assert result["usage"] == {
        "input_tokens": 10,
        "cache_read_input_tokens": 5,
        "cache_creation_input_tokens": 2,
        "output_tokens": 8,
    }


# ---------------------------------------------------------------------------
# Tests for aggregate_openai_responses_stream_chunks
# ---------------------------------------------------------------------------

_RESPONSES_CREATED = (
    b'data: {"type":"response.created","response":{"id":"resp_1","object":"response",'
    b'"created_at":1741290958,"status":"in_progress","output":[],"usage":null}}\n'
)
_RESPONSES_TEXT_DELTA = (
    b'data: {"type":"response.output_text.delta","item_id":"msg_1",'
    b'"output_index":0,"content_index":0,"delta":"Hi"}\n'
)
_RESPONSES_TEXT_DONE = (
    b'data: {"type":"response.output_text.done","item_id":"msg_1",'
    b'"output_index":0,"content_index":0,"text":"Hi there!"}\n'
)
_RESPONSES_COMPLETED = (
    b'data: {"type":"response.completed","response":{"id":"resp_1","object":"response",'
    b'"created_at":1741290958,"status":"completed",'
    b'"output":[{"id":"msg_1","type":"message","status":"completed","role":"assistant",'
    b'"content":[{"type":"output_text","text":"Hi there!","annotations":[]}]}],'
    b'"usage":{"input_tokens":37,"output_tokens":11,"total_tokens":48}}}\n'
)


def test_aggregate_openai_responses_stream_chunks_empty():
    assert aggregate_openai_responses_stream_chunks([]) is None


def test_aggregate_openai_responses_stream_chunks_no_completed_event():
    chunks = [_RESPONSES_CREATED, _RESPONSES_TEXT_DELTA]
    assert aggregate_openai_responses_stream_chunks(chunks) is None


def test_aggregate_openai_responses_stream_chunks_basic():
    chunks = [
        _RESPONSES_CREATED,
        _RESPONSES_TEXT_DELTA,
        _RESPONSES_TEXT_DONE,
        _RESPONSES_COMPLETED,
    ]
    result = aggregate_openai_responses_stream_chunks(chunks)

    assert result["id"] == "resp_1"
    assert result["object"] == "response"
    assert result["status"] == "completed"
    assert len(result["output"]) == 1
    assert result["output"][0]["role"] == "assistant"
    assert result["output"][0]["content"][0]["text"] == "Hi there!"
    assert result["usage"] == {"input_tokens": 37, "output_tokens": 11, "total_tokens": 48}


def test_aggregate_openai_responses_stream_chunks_split_sse_lines():
    # Simulate aiohttp yielding a chunk that splits the data: line mid-way.
    mid = len(_RESPONSES_COMPLETED) // 2
    chunks = [
        _RESPONSES_CREATED,
        _RESPONSES_COMPLETED[:mid],
        _RESPONSES_COMPLETED[mid:],
    ]
    result = aggregate_openai_responses_stream_chunks(chunks)

    assert result is not None
    assert result["id"] == "resp_1"
    assert result["status"] == "completed"


def test_aggregate_openai_responses_stream_chunks_returns_completed_response():
    # When multiple events are packed into a single bytes chunk, the
    # completed response is still extracted correctly.
    combined = _RESPONSES_CREATED + _RESPONSES_TEXT_DELTA + _RESPONSES_COMPLETED
    result = aggregate_openai_responses_stream_chunks([combined])

    assert result["status"] == "completed"
    assert result["usage"]["total_tokens"] == 48


# ---------------------------------------------------------------------------
# Tests for aggregate_gemini_stream_generate_content_chunks
# ---------------------------------------------------------------------------


def _gemini_sse(event: dict[str, Any]) -> bytes:
    return f"data: {json.dumps(event)}\n".encode()


def _gemini_text_chunk(text: str, finish_reason: str | None = None) -> bytes:
    candidate: dict[str, Any] = {"content": {"parts": [{"text": text}], "role": "model"}}
    if finish_reason:
        candidate["finishReason"] = finish_reason
    return _gemini_sse({"candidates": [candidate]})


def test_aggregate_gemini_stream_chunks_empty():
    assert aggregate_gemini_stream_generate_content_chunks([]) is None


def test_aggregate_gemini_stream_chunks_no_parseable_events():
    chunks = [b"event: ping\n", b"data: [DONE]\n"]
    assert aggregate_gemini_stream_generate_content_chunks(chunks) is None


def test_aggregate_gemini_stream_chunks_text():
    chunks = [
        _gemini_text_chunk("Hello"),
        _gemini_text_chunk(" world", finish_reason="STOP"),
        _gemini_sse({
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15,
            }
        }),
    ]
    result = aggregate_gemini_stream_generate_content_chunks(chunks)

    assert len(result["candidates"]) == 1
    cand = result["candidates"][0]
    assert cand["content"]["parts"] == [{"text": "Hello world"}]
    assert cand["content"]["role"] == "model"
    assert cand["finishReason"] == "STOP"
    assert result["usageMetadata"] == {
        "promptTokenCount": 10,
        "candidatesTokenCount": 5,
        "totalTokenCount": 15,
    }


def test_aggregate_gemini_stream_chunks_tool_call():
    chunks = [
        _gemini_sse({
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}}
                        ],
                        "role": "model",
                    },
                    "finishReason": "STOP",
                    "index": 0,
                }
            ]
        }),
        _gemini_sse({
            "usageMetadata": {
                "promptTokenCount": 8,
                "candidatesTokenCount": 12,
                "totalTokenCount": 20,
            }
        }),
    ]
    result = aggregate_gemini_stream_generate_content_chunks(chunks)

    cand = result["candidates"][0]
    assert cand["content"]["parts"] == [
        {"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}}
    ]
    assert cand["finishReason"] == "STOP"


def test_aggregate_gemini_stream_chunks_split_sse_lines():
    chunk_bytes = _gemini_text_chunk("Hi", finish_reason="STOP")
    mid = len(chunk_bytes) // 2
    result = aggregate_gemini_stream_generate_content_chunks([chunk_bytes[:mid], chunk_bytes[mid:]])

    assert result is not None
    assert result["candidates"][0]["content"]["parts"] == [{"text": "Hi"}]


@pytest.mark.parametrize(
    ("finish_reasons", "expected"),
    [
        ([None, None, "STOP"], "STOP"),
        ([None, "stop", None], "stop"),
        ([None, None, None], None),
    ],
)
def test_aggregate_gemini_stream_chunks_finish_reason(finish_reasons, expected):
    chunks = [_gemini_text_chunk(f"t{i}", finish_reason=fr) for i, fr in enumerate(finish_reasons)]
    result = aggregate_gemini_stream_generate_content_chunks(chunks)
    assert result["candidates"][0]["finishReason"] == expected