# tests/genai/utils/test_trace_utils.py
   1  import asyncio
   2  import json
   3  from typing import Any
   4  from unittest import mock
   5  
   6  import httpx
   7  import numpy as np
   8  import openai
   9  import pandas as pd
  10  import pytest
  11  from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
  12  
  13  import mlflow
  14  from mlflow.entities.assessment import Expectation
  15  from mlflow.entities.assessment_source import AssessmentSource, AssessmentSourceType
  16  from mlflow.entities.dataset_record_source import DatasetRecordSource, DatasetRecordSourceType
  17  from mlflow.entities.span import Span, SpanType
  18  from mlflow.entities.trace import Trace
  19  from mlflow.entities.trace_data import TraceData
  20  from mlflow.genai.evaluation.entities import EvalItem
  21  from mlflow.genai.evaluation.utils import is_none_or_nan
  22  from mlflow.genai.scorers.base import scorer
  23  from mlflow.genai.utils.trace_utils import (
  24      _does_store_support_trace_linking,
  25      _extract_tool_name_from_span,
  26      _should_keep_trace,
  27      _try_extract_available_tools_with_llm,
  28      clean_up_extra_traces,
  29      convert_predict_fn,
  30      create_minimal_trace,
  31      extract_available_tools_from_trace,
  32      extract_expectations_from_trace,
  33      extract_inputs_from_trace,
  34      extract_outputs_from_trace,
  35      extract_request_from_trace,
  36      extract_response_from_trace,
  37      extract_retrieval_context_from_trace,
  38      parse_inputs_to_str,
  39      parse_outputs_to_str,
  40      parse_tool_call_messages_from_trace,
  41      resolve_conversation_from_session,
  42      resolve_expectations_from_session,
  43  )
  44  from mlflow.tracing import set_span_chat_tools
  45  from mlflow.tracing.constant import TraceMetadataKey
  46  from mlflow.tracing.utils import build_otel_context
  47  from mlflow.types.chat import ChatTool, FunctionToolDefinition
  48  
  49  from tests.tracing.helper import create_test_trace_info, get_traces, purge_traces
  50  
  51  
  52  def httpx_send_patch(request, *args, **kwargs):
  53      return httpx.Response(
  54          status_code=200,
  55          request=request,
  56          json={
  57              "id": "chatcmpl-Ax4UAd5xf32KjgLkS1SEEY9oorI9m",
  58              "object": "chat.completion",
  59              "created": 1738641958,
  60              "model": "gpt-4o-2024-08-06",
  61              "choices": [
  62                  {
  63                      "index": 0,
  64                      "message": {
  65                          "role": "assistant",
  66                          "content": "test",
  67                          "refusal": None,
  68                      },
  69                      "logprobs": None,
  70                      "finish_reason": "stop",
  71                  }
  72              ],
  73          },
  74      )
  75  
  76  
  77  def get_openai_predict_fn(with_tracing=False):
  78      if with_tracing:
  79          mlflow.openai.autolog()
  80  
  81      def predict_fn(request):
  82          with mock.patch("httpx.Client.send", side_effect=httpx_send_patch):
  83              response = openai.OpenAI().chat.completions.create(
  84                  messages=request["messages"],
  85                  model="gpt-4o-mini",
  86              )
  87              return response.choices[0].message.content
  88  
  89      return predict_fn
  90  
  91  
  92  def get_dummy_predict_fn(with_tracing=False):
  93      def predict_fn(request):
  94          return "test"
  95  
  96      if with_tracing:
  97          return mlflow.trace(predict_fn)
  98  
  99      return predict_fn
 100  
 101  
@pytest.fixture
def mock_openai_env(monkeypatch):
    # Provide a dummy API key so the OpenAI client can be constructed without
    # real credentials; the HTTP layer itself is mocked separately.
    monkeypatch.setenv("OPENAI_API_KEY", "fake_api_key")
 105  
 106  
@pytest.mark.usefixtures("mock_openai_env")
@pytest.mark.parametrize(
    ("predict_fn_generator", "with_tracing", "should_be_wrapped"),
    [
        (get_dummy_predict_fn, False, True),
        # If the function is already traced, it should not be wrapped with @mlflow.trace.
        (get_dummy_predict_fn, True, False),
        # OpenAI autologging is automatically enabled during evaluation,
        # so we don't need to wrap the function with @mlflow.trace.
        (get_openai_predict_fn, False, False),
        (get_openai_predict_fn, True, False),
    ],
    ids=[
        "dummy predict_fn without tracing",
        "dummy predict_fn with tracing",
        "openai predict_fn without tracing",
        "openai predict_fn with tracing",
    ],
)
def test_convert_predict_fn(predict_fn_generator, with_tracing, should_be_wrapped):
    """convert_predict_fn wraps predict_fn with @mlflow.trace only when the
    function would not otherwise emit a trace; evaluation always yields one."""
    predict_fn = predict_fn_generator(with_tracing=with_tracing)
    sample_input = {"request": {"messages": [{"role": "user", "content": "test"}]}}

    # predict_fn is callable as is
    result = predict_fn(**sample_input)
    assert result == "test"
    assert len(get_traces()) == (1 if with_tracing else 0)
    purge_traces()

    converted_fn = convert_predict_fn(predict_fn, sample_input)

    # converted function takes a single 'request' argument
    result = converted_fn(request=sample_input)
    assert result == "test"

    # Trace should be generated if decorated or wrapped with @mlflow.trace
    assert len(get_traces()) == (1 if with_tracing or should_be_wrapped else 0)
    purge_traces()

    # All function should generate a trace when executed through mlflow.genai.evaluate
    @scorer
    def dummy_scorer(inputs, outputs):
        return 0

    mlflow.genai.evaluate(
        data=[{"inputs": sample_input}],
        predict_fn=predict_fn,
        scorers=[dummy_scorer],
    )
    assert len(get_traces()) == 1
 157  
 158  
 159  def test_convert_predict_fn_skip_validation(monkeypatch):
 160      monkeypatch.setenv("MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION", "true")
 161  
 162      call_count = 0
 163  
 164      def dummy_predict_fn(question: str, context: str):
 165          nonlocal call_count
 166          call_count += 1
 167          return question + context
 168  
 169      sample_input = {"question": "test", "context": "test"}
 170      converted_fn = convert_predict_fn(dummy_predict_fn, sample_input)
 171      # Predict function should not be validated when the env var is set to True
 172      assert call_count == 0
 173  
 174      # converted function takes a single 'request' argument
 175      result = converted_fn(request=sample_input)
 176      assert result == "testtest"
 177  
 178  
 179  def create_span(
 180      span_id: int,
 181      parent_id: int,
 182      span_type: str,
 183      inputs: dict[str, Any],
 184      outputs: dict[str, Any],
 185  ) -> Span:
 186      otel_span = OTelReadableSpan(
 187          name="test",
 188          context=build_otel_context(123, span_id),
 189          parent=build_otel_context(123, parent_id) if parent_id else None,
 190          start_time=100,
 191          end_time=200,
 192          attributes={
 193              "mlflow.spanInputs": json.dumps(inputs),
 194              "mlflow.spanOutputs": json.dumps(outputs),
 195              "mlflow.spanType": json.dumps(span_type),
 196          },
 197      )
 198      return Span(otel_span)
 199  
 200  
@pytest.mark.parametrize(
    ("spans", "expected_retrieval_context"),
    [
        # multiple retrieval steps - only take the last top-level one
        (
            [
                create_span(
                    span_id=1,
                    parent_id=None,  # root span
                    inputs="question",
                    outputs={"generations": [[{"text": "some text"}]]},
                    span_type=SpanType.LLM,
                ),
                create_span(
                    span_id=2,
                    parent_id=1,
                    inputs="What is the capital of France?",
                    outputs=[
                        {
                            "page_content": "document content 3",
                            "metadata": {
                                "doc_uri": "uri3",
                                "chunk_id": "3",
                            },
                            "type": "Document",
                        },
                    ],
                    span_type=SpanType.RETRIEVER,
                ),
                create_span(
                    span_id=3,
                    parent_id=1,
                    inputs="What is the capital of France?",
                    outputs=[
                        {
                            "page_content": "document content 1",
                            "metadata": {
                                "doc_uri": "uri1",
                                "chunk_id": "1",
                            },
                            "type": "Document",
                        },
                        {
                            "page_content": "document content 2",
                            "metadata": {
                                "doc_uri": "uri2",
                                "chunk_id": "2",
                            },
                            "type": "Document",
                        },
                    ],
                    span_type=SpanType.RETRIEVER,
                ),
                create_span(
                    span_id=4,
                    parent_id=3,
                    inputs="This should be ignored because it's not a top-level retrieval span",
                    outputs=[
                        {
                            "page_content": "document content 4",
                            "metadata": {
                                "doc_uri": "uri4",
                                "chunk_id": "4",
                            },
                            "type": "Document",
                        },
                    ],
                    span_type=SpanType.RETRIEVER,
                ),
            ],
            {
                "0000000000000002": [
                    {
                        "doc_uri": "uri3",
                        "content": "document content 3",
                    },
                ],
                "0000000000000003": [
                    {
                        "doc_uri": "uri1",
                        "content": "document content 1",
                    },
                    {
                        "doc_uri": "uri2",
                        "content": "document content 2",
                    },
                ],
            },
        ),
        # one retrieval step
        (
            [
                create_span(
                    span_id=1,
                    parent_id=None,
                    inputs="What is the capital of France?",
                    outputs=[
                        {
                            "page_content": "document content 1",
                            "metadata": {
                                "doc_uri": "uri1",
                                "chunk_id": "1",
                            },
                            "type": "Document",
                        },
                        # missing doc_uri
                        {
                            "page_content": "document content 2",
                            "metadata": {
                                "chunk_id": "2",
                            },
                            "type": "Document",
                        },
                        # missing content
                        {
                            "metadata": {
                                "doc_uri": "uri3",
                                "chunk_id": "3",
                            },
                            "type": "Document",
                        },
                        # missing metadata
                        {
                            "page_content": "document content 4",
                            "type": "Document",
                        },
                    ],
                    span_type=SpanType.RETRIEVER,
                ),
            ],
            {
                "0000000000000001": [
                    {
                        "doc_uri": "uri1",
                        "content": "document content 1",
                    },
                    {
                        "content": "document content 2",
                    },
                    {
                        "content": None,
                        "doc_uri": "uri3",
                    },
                    {
                        "content": "document content 4",
                    },
                ],
            },
        ),
        # one retrieval step - string outputs (UC schema casts attributes to MAP<STRING, STRING>)
        (
            [
                create_span(
                    span_id=1,
                    parent_id=None,
                    inputs="What is the capital of France?",
                    outputs=json.dumps([
                        {
                            "page_content": "document content 1",
                            "metadata": {"doc_uri": "uri1"},
                        },
                        {
                            "page_content": "document content 2",
                            "metadata": {"doc_uri": "uri2"},
                        },
                    ]),
                    span_type=SpanType.RETRIEVER,
                ),
            ],
            {
                "0000000000000001": [
                    {"doc_uri": "uri1", "content": "document content 1"},
                    {"doc_uri": "uri2", "content": "document content 2"},
                ],
            },
        ),
        # one retrieval step - empty retrieval span outputs
        (
            [
                create_span(
                    span_id=1,
                    parent_id=None,
                    inputs="What is the capital of France?",
                    outputs=[],
                    span_type=SpanType.RETRIEVER,
                ),
            ],
            {"0000000000000001": []},
        ),
        # one retrieval step - wrong format retrieval span outputs
        (
            [
                create_span(
                    span_id=1,
                    parent_id=None,
                    inputs="What is the capital of France?",
                    outputs=["wrong output", "should be ignored"],
                    span_type=SpanType.RETRIEVER,
                ),
            ],
            {"0000000000000001": []},
        ),
        # no retrieval steps
        (
            [
                create_span(
                    span_id=1,
                    parent_id=None,
                    inputs="What is the capital of France?",
                    outputs=[{"text": "some text"}],
                    span_type=SpanType.LLM,
                ),
            ],
            {},
        ),
        # None trace
        (
            None,
            {},
        ),
    ],
)
def test_get_retrieval_context_from_trace(spans, expected_retrieval_context):
    """Retrieval context is keyed by hex span id; only top-level RETRIEVER
    spans contribute, and malformed document entries degrade gracefully."""
    trace = Trace(info=create_test_trace_info(trace_id="tr-123"), data=TraceData(spans=spans))
    assert extract_retrieval_context_from_trace(trace) == expected_retrieval_context
 426  
 427  
@pytest.mark.parametrize(
    ("input_data", "expected"),
    [
        # String input
        ("Hello world", "Hello world"),
        # Chat completion/ChatModel/ChatAgent request
        (
            {"messages": [{"role": "user", "content": "User message"}]},
            "User message",
        ),
        # Multi-turn messages
        (
            {
                "messages": [
                    {"role": "assistant", "content": "First"},
                    {"role": "user", "content": "Second"},
                ]
            },
            '[{"role": "assistant", "content": "First"}, {"role": "user", "content": "Second"}]',
        ),
        # Empty dict input
        (
            {},
            "{}",
        ),
        # Dict input
        (
            {"unsupported_key": "value"},
            "{'unsupported_key': 'value'}",
        ),
        # Non-standard messages
        (
            {
                "messages": [
                    {"role": "assistant", "k": "First"},
                    {"role": "user", "k": "Second"},
                ]
            },
            "{'messages': [{'role': 'assistant', 'k': 'First'}, {'role': 'user', 'k': 'Second'}]}",
        ),
        # Strands format - list of messages with role and content
        (
            [{"role": "user", "content": [{"text": "hello"}]}],
            '[{"role": "user", "content": [{"text": "hello"}]}]',
        ),
        # Strands format - multiple messages with simple string content
        (
            [
                {"role": "user", "content": "First"},
                {"role": "assistant", "content": "Second"},
            ],
            '[{"role": "user", "content": "First"}, {"role": "assistant", "content": "Second"}]',
        ),
        # Strands format - single message with string content
        (
            [{"role": "user", "content": "Single message"}],
            '[{"role": "user", "content": "Single message"}]',
        ),
    ],
)
def test_parse_inputs_to_str(input_data, expected):
    """Various input payload shapes are flattened into a single string form."""
    assert parse_inputs_to_str(input_data) == expected
 490  
 491  
@pytest.mark.parametrize(
    ("output_data", "expected"),
    [
        # String output
        ("Output string", "Output string"),
        # Chat completion/ChatModel response
        (
            {
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Output content",
                        },
                    }
                ]
            },
            "Output content",
        ),
        # ChatAgent response with multiple messages
        (
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "Input content",
                    },
                    {
                        "role": "assistant",
                        "content": "Intermediate Output content",
                    },
                    {
                        "role": "user",
                        "content": "Intermediate Input content",
                    },
                    {
                        "role": "assistant",
                        "content": "Output content",
                    },
                ]
            },
            "Output content",
        ),
        # List of strings
        (["Response content"], "Response content"),
        # ChatAgent response with multiple messages
        (
            [
                {
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": "Output content",
                            },
                        }
                    ]
                }
            ],
            "Output content",
        ),
        # List of direct string response
        (
            {"unsupported_key": "value"},
            '{"unsupported_key": "value"}',
        ),
        # Handle custom messages array format
        (
            {"messages": ["a", "b", "c"]},
            '{"messages": ["a", "b", "c"]}',
        ),
        # OpenAI Responses API format with output_text content type
        (
            {
                "output": [
                    {
                        "id": "msg_123",
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "output_text", "text": "Response from Responses API"}],
                    }
                ]
            },
            "Response from Responses API",
        ),
        # OpenAI Responses API format with text content type
        (
            {
                "output": [
                    {
                        "id": "msg_456",
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "text", "text": "Text type response"}],
                    }
                ]
            },
            "Text type response",
        ),
        # OpenAI Responses API format with string content
        (
            {
                "output": [
                    {
                        "id": "msg_789",
                        "type": "message",
                        "role": "assistant",
                        "content": "Direct string content",
                    }
                ]
            },
            "Direct string content",
        ),
        # OpenAI Responses API format with multiple output items (gets last assistant message)
        (
            {
                "output": [
                    {
                        "id": "item_1",
                        "type": "function_call",
                        "name": "get_weather",
                    },
                    {
                        "id": "msg_final",
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "output_text", "text": "Final response"}],
                    },
                ]
            },
            "Final response",
        ),
    ],
)
def test_parse_outputs_to_str(output_data, expected):
    """Various model output payload shapes are reduced to the final assistant
    message (or a JSON string fallback for unrecognized formats)."""
    assert parse_outputs_to_str(output_data) == expected
 630  
 631  
@pytest.mark.parametrize(
    ("input_value", "expected"),
    [
        (None, True),
        (np.nan, True),
        (float("nan"), True),
        ("Not NaN", False),
        (123, False),
        ([], False),
        ({}, False),
        (0.0, False),
        (1.5, False),
    ],
)
def test_is_none_or_nan(input_value, expected):
    """Only None and NaN floats are treated as missing; falsy containers and
    zero are not."""
    assert is_none_or_nan(input_value) == expected
 648  
 649  
 650  def test_extract_expectations_from_trace_with_source_filter():
 651      with mlflow.start_span(name="test_span") as span:
 652          span.set_inputs({"question": "What is MLflow?"})
 653          span.set_outputs({"answer": "MLflow is an open source platform"})
 654  
 655      trace_id = span.trace_id
 656  
 657      human_expectation = Expectation(
 658          name="human_expectation",
 659          value={"expected": "Answer from human"},
 660          source=AssessmentSource(source_type=AssessmentSourceType.HUMAN),
 661      )
 662      mlflow.log_assessment(trace_id=trace_id, assessment=human_expectation)
 663  
 664      llm_expectation = Expectation(
 665          name="llm_expectation",
 666          value="LLM generated expectation",
 667          source=AssessmentSource(source_type=AssessmentSourceType.LLM_JUDGE),
 668      )
 669      mlflow.log_assessment(trace_id=trace_id, assessment=llm_expectation)
 670  
 671      code_expectation = Expectation(
 672          name="code_expectation",
 673          value=42,
 674          source=AssessmentSource(source_type=AssessmentSourceType.CODE),
 675      )
 676      mlflow.log_assessment(trace_id=trace_id, assessment=code_expectation)
 677  
 678      trace = mlflow.get_trace(trace_id)
 679  
 680      result = extract_expectations_from_trace(trace, source_type=None)
 681      assert result == {
 682          "human_expectation": {"expected": "Answer from human"},
 683          "llm_expectation": "LLM generated expectation",
 684          "code_expectation": 42,
 685      }
 686  
 687      result = extract_expectations_from_trace(trace, source_type="HUMAN")
 688      assert result == {"human_expectation": {"expected": "Answer from human"}}
 689  
 690      result = extract_expectations_from_trace(trace, source_type="LLM_JUDGE")
 691      assert result == {"llm_expectation": "LLM generated expectation"}
 692  
 693      result = extract_expectations_from_trace(trace, source_type="CODE")
 694      assert result == {"code_expectation": 42}
 695  
 696      result = extract_expectations_from_trace(trace, source_type="human")
 697      assert result == {"human_expectation": {"expected": "Answer from human"}}
 698  
 699      with pytest.raises(mlflow.exceptions.MlflowException, match="Invalid assessment source type"):
 700          extract_expectations_from_trace(trace, source_type="INVALID_SOURCE")
 701  
 702  
 703  def test_extract_expectations_from_trace_returns_none_when_no_expectations():
 704      with mlflow.start_span(name="test_span") as span:
 705          span.set_inputs({"question": "What is MLflow?"})
 706          span.set_outputs({"answer": "MLflow is an open source platform"})
 707  
 708      trace = mlflow.get_trace(span.trace_id)
 709  
 710      result = extract_expectations_from_trace(trace)
 711      assert result is None
 712  
 713      result = extract_expectations_from_trace(trace, source_type="HUMAN")
 714      assert result is None
 715  
 716  
 717  def test_extract_inputs_and_outputs_from_trace():
 718      test_inputs = {"question": "What is MLflow?", "context": "MLflow is a tool"}
 719      test_outputs = {"answer": "MLflow is an open source platform", "confidence": 0.95}
 720  
 721      with mlflow.start_span(name="test_span") as span:
 722          span.set_inputs(test_inputs)
 723          span.set_outputs(test_outputs)
 724  
 725      trace = mlflow.get_trace(span.trace_id)
 726  
 727      assert extract_inputs_from_trace(trace) == test_inputs
 728      assert extract_outputs_from_trace(trace) == test_outputs
 729  
 730      trace_without_data = Trace(
 731          info=create_test_trace_info(trace_id="tr-123"), data=TraceData(spans=[])
 732      )
 733      assert extract_inputs_from_trace(trace_without_data) is None
 734      assert extract_outputs_from_trace(trace_without_data) is None
 735  
 736  
 737  def test_extract_request_and_response_from_trace():
 738      test_inputs = {"messages": [{"role": "user", "content": "What is MLflow?"}]}
 739      test_outputs = {
 740          "choices": [{"index": 0, "message": {"role": "assistant", "content": "MLflow is great"}}]
 741      }
 742  
 743      with mlflow.start_span(name="test_span") as span:
 744          span.set_inputs(test_inputs)
 745          span.set_outputs(test_outputs)
 746  
 747      trace = mlflow.get_trace(span.trace_id)
 748  
 749      assert extract_request_from_trace(trace) == "What is MLflow?"
 750      assert extract_response_from_trace(trace) == "MLflow is great"
 751  
 752      trace_without_data = Trace(
 753          info=create_test_trace_info(trace_id="tr-123"), data=TraceData(spans=[])
 754      )
 755      assert extract_request_from_trace(trace_without_data) is None
 756      assert extract_response_from_trace(trace_without_data) is None
 757  
 758  
 759  def test_extract_request_and_response_with_string_inputs():
 760      test_inputs = "Simple string input"
 761      test_outputs = "Simple string output"
 762  
 763      with mlflow.start_span(name="test_span") as span:
 764          span.set_inputs(test_inputs)
 765          span.set_outputs(test_outputs)
 766  
 767      trace = mlflow.get_trace(span.trace_id)
 768  
 769      assert extract_request_from_trace(trace) == "Simple string input"
 770      assert extract_response_from_trace(trace) == "Simple string output"
 771  
 772  
def test_does_store_support_trace_linking():
    """Trace-linking support is probed per backend and the probe result is
    cached per tracking URI."""
    test_trace = Trace(info=create_test_trace_info(trace_id="tr-123"), data=TraceData(spans=[]))

    # Databricks backend support trace linking
    assert _does_store_support_trace_linking(
        tracking_uri="databricks",
        trace=test_trace,
        run_id="run-123",
    )

    assert _does_store_support_trace_linking(
        tracking_uri="databricks://test",
        trace=test_trace,
        run_id="run-123",
    )

    mock_client = mock.MagicMock()
    with mock.patch("mlflow.genai.utils.trace_utils.MlflowClient", return_value=mock_client):
        # SQLAlchemy backend support trace linking
        mock_client.link_traces_to_run.side_effect = None

        assert _does_store_support_trace_linking(
            tracking_uri="sqlalchemy://test",
            trace=test_trace,
            run_id="run-123",
        )

        # File store doesn't support trace linking
        mock_client.link_traces_to_run.side_effect = Exception("Test error")

        assert not _does_store_support_trace_linking(
            tracking_uri="file://test",
            trace=test_trace,
            run_id="run-123",
        )

        # Result should be cached per tracking URI
        mock_client.reset_mock()
        mock_client.link_traces_to_run.side_effect = None
        for _ in range(10):
            assert _does_store_support_trace_linking(
                tracking_uri="sqlalchemy://test2",
                trace=test_trace,
                run_id="run-123",
            )
        mock_client.link_traces_to_run.assert_called_once()
 819  
 820  
 821  def test_create_minimal_trace_restores_session_metadata():
 822      source = DatasetRecordSource(
 823          source_type=DatasetRecordSourceType.TRACE,
 824          source_data={"trace_id": "tr-original", "session_id": "session_1"},
 825      )
 826  
 827      eval_item = EvalItem(
 828          request_id="req-123",
 829          inputs={"question": "test"},
 830          outputs="answer",
 831          expectations={},
 832          source=source,
 833      )
 834  
 835      trace = create_minimal_trace(eval_item)
 836  
 837      # Verify session metadata was restored
 838      assert trace.info.trace_metadata.get("mlflow.trace.session") == "session_1"
 839      assert trace.data._get_root_span().inputs == {"question": "test"}
 840      assert trace.data._get_root_span().outputs == "answer"
 841  
 842  
 843  def test_create_minimal_trace_without_source():
 844      eval_item = EvalItem(
 845          request_id="req-123",
 846          inputs={"question": "test"},
 847          outputs="answer",
 848          expectations={},
 849          source=None,
 850      )
 851  
 852      trace = create_minimal_trace(eval_item)
 853  
 854      # Should create trace successfully without session metadata
 855      assert trace is not None
 856      assert "mlflow.trace.session" not in trace.info.trace_metadata
 857      assert trace.data._get_root_span().inputs == {"question": "test"}
 858      assert trace.data._get_root_span().outputs == "answer"
 859  
 860  
 861  def test_create_minimal_trace_with_source_but_no_session():
 862      source = DatasetRecordSource(
 863          source_type=DatasetRecordSourceType.TRACE,
 864          source_data={"trace_id": "tr-original"},  # No session_id
 865      )
 866  
 867      eval_item = EvalItem(
 868          request_id="req-123",
 869          inputs={"question": "test"},
 870          outputs="answer",
 871          expectations={},
 872          source=source,
 873      )
 874  
 875      trace = create_minimal_trace(eval_item)
 876  
 877      # Should work without session metadata
 878      assert trace is not None
 879      assert "mlflow.trace.session" not in trace.info.trace_metadata
 880      assert trace.data._get_root_span().inputs == {"question": "test"}
 881      assert trace.data._get_root_span().outputs == "answer"
 882  
 883  
 884  def test_parse_tool_call_messages_from_trace():
 885      with mlflow.start_span(name="root") as root_span:
 886          root_span.set_inputs({"question": "What is the stock price?"})
 887  
 888          with mlflow.start_span(name="get_stock_price", span_type=SpanType.TOOL) as tool_span:
 889              tool_span.set_inputs({"symbol": "AAPL"})
 890              tool_span.set_outputs({"price": 150.0})
 891  
 892          with mlflow.start_span(name="get_market_cap", span_type=SpanType.TOOL) as tool_span2:
 893              tool_span2.set_inputs({"symbol": "AAPL"})
 894              tool_span2.set_outputs({"market_cap": "2.5T"})
 895  
 896          root_span.set_outputs("AAPL price is $150.")
 897  
 898      trace = mlflow.get_trace(root_span.trace_id)
 899      tool_messages = parse_tool_call_messages_from_trace(trace)
 900  
 901      assert len(tool_messages) == 2
 902      assert tool_messages[0] == {
 903          "role": "tool",
 904          "content": "Tool: get_stock_price\nInputs: {'symbol': 'AAPL'}\nOutputs: {'price': 150.0}",
 905      }
 906      assert tool_messages[1] == {
 907          "role": "tool",
 908          "content": (
 909              "Tool: get_market_cap\nInputs: {'symbol': 'AAPL'}\nOutputs: {'market_cap': '2.5T'}"
 910          ),
 911      }
 912  
 913  
 914  def test_parse_tool_call_messages_from_trace_no_tools():
 915      with mlflow.start_span(name="root") as span:
 916          span.set_inputs({"question": "Hello"})
 917          span.set_outputs("Hi there")
 918  
 919      trace = mlflow.get_trace(span.trace_id)
 920      tool_messages = parse_tool_call_messages_from_trace(trace)
 921  
 922      assert tool_messages == []
 923  
 924  
 925  def test_parse_tool_call_messages_from_trace_tool_without_outputs():
 926      with mlflow.start_span(name="root") as root_span:
 927          root_span.set_inputs({"query": "test"})
 928  
 929          with mlflow.start_span(name="my_tool", span_type=SpanType.TOOL) as tool_span:
 930              tool_span.set_inputs({"param": "value"})
 931  
 932          root_span.set_outputs("result")
 933  
 934      trace = mlflow.get_trace(root_span.trace_id)
 935      tool_messages = parse_tool_call_messages_from_trace(trace)
 936  
 937      assert len(tool_messages) == 1
 938      assert tool_messages[0] == {
 939          "role": "tool",
 940          "content": "Tool: my_tool\nInputs: {'param': 'value'}",
 941      }
 942  
 943  
 944  def test_extract_tool_name_from_span_uses_span_name_by_default():
 945      with mlflow.start_span(name="root") as root_span:
 946          root_span.set_inputs({"query": "test"})
 947  
 948          with mlflow.start_span(name="my_tool", span_type=SpanType.TOOL) as tool_span:
 949              tool_span.set_inputs({"arg": "value"})
 950  
 951          root_span.set_outputs("result")
 952  
 953      trace = mlflow.get_trace(root_span.trace_id)
 954      tool_spans = trace.search_spans(span_type=SpanType.TOOL)
 955  
 956      assert _extract_tool_name_from_span(tool_spans[0]) == "my_tool"
 957  
 958  
 959  def test_extract_tool_name_from_span_extracts_from_call_tool_name():
 960      with mlflow.start_span(name="root") as root_span:
 961          root_span.set_inputs({"query": "test"})
 962  
 963          with mlflow.start_span(
 964              name="ToolManager.handle_call", span_type=SpanType.TOOL
 965          ) as tool_span:
 966              tool_span.set_inputs({"call": {"tool_name": "list_client", "args": {"param": "value"}}})
 967  
 968          root_span.set_outputs("result")
 969  
 970      trace = mlflow.get_trace(root_span.trace_id)
 971      tool_spans = trace.search_spans(span_type=SpanType.TOOL)
 972  
 973      assert _extract_tool_name_from_span(tool_spans[0]) == "list_client"
 974  
 975  
 976  def test_resolve_conversation_from_session():
 977      session_id = "test_session_resolve"
 978      traces = []
 979  
 980      with mlflow.start_span(name="turn_0") as span:
 981          span.set_inputs({"messages": [{"role": "user", "content": "What is AAPL price?"}]})
 982          span.set_outputs("AAPL is $150.")
 983          mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
 984      traces.append(mlflow.get_trace(span.trace_id))
 985  
 986      with mlflow.start_span(name="turn_1") as span:
 987          span.set_inputs({"messages": [{"role": "user", "content": "How about MSFT?"}]})
 988          span.set_outputs("MSFT is $300.")
 989          mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
 990      traces.append(mlflow.get_trace(span.trace_id))
 991  
 992      conversation = resolve_conversation_from_session(traces)
 993  
 994      assert len(conversation) == 4
 995      assert conversation[0] == {"role": "user", "content": "What is AAPL price?"}
 996      assert conversation[1] == {"role": "assistant", "content": "AAPL is $150."}
 997      assert conversation[2] == {"role": "user", "content": "How about MSFT?"}
 998      assert conversation[3] == {"role": "assistant", "content": "MSFT is $300."}
 999  
1000  
def test_resolve_conversation_from_session_with_tool_calls():
    """Tool spans appear in the resolved conversation only when include_tool_calls=True."""
    session_id = "test_session_with_tools"

    with mlflow.start_span(name="turn_0") as root:
        root.set_inputs({"messages": [{"role": "user", "content": "Get AAPL price"}]})

        with mlflow.start_span(name="get_stock_price", span_type=SpanType.TOOL) as tool:
            tool.set_inputs({"symbol": "AAPL"})
            tool.set_outputs({"price": 150})

        root.set_outputs("AAPL is $150.")
        mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
    traces = [mlflow.get_trace(root.trace_id)]

    # Without tool calls: just the user turn and the assistant answer.
    without_tools = resolve_conversation_from_session(traces, include_tool_calls=False)
    assert [m["role"] for m in without_tools] == ["user", "assistant"]

    # With tool calls: the tool message is inserted between user and assistant.
    with_tools = resolve_conversation_from_session(traces, include_tool_calls=True)
    assert with_tools == [
        {"role": "user", "content": "Get AAPL price"},
        {
            "role": "tool",
            "content": "Tool: get_stock_price\nInputs: {'symbol': 'AAPL'}\nOutputs: {'price': 150}",
        },
        {"role": "assistant", "content": "AAPL is $150."},
    ]
1029  
1030  
def test_resolve_conversation_from_session_empty():
    """An empty trace list resolves to an empty conversation."""
    conversation = resolve_conversation_from_session([])
    assert conversation == []
1033  
1034  
@pytest.mark.parametrize("include_timing", [True, False])
def test_resolve_conversation_from_session_with_timing_parameter(include_timing):
    """Timing annotations appear in assistant messages iff include_timing is set."""
    with mlflow.start_span(name="turn_0") as span:
        span.set_inputs({"messages": [{"role": "user", "content": "What is MLflow?"}]})
        span.set_outputs("MLflow is an ML platform.")
        mlflow.update_current_trace(metadata={"mlflow.trace.session": "test_session"})
    traces = [mlflow.get_trace(span.trace_id)]

    conversation = resolve_conversation_from_session(traces, include_timing=include_timing)

    # Exactly one user message and one assistant message.
    user_msg, assistant_msg = conversation
    assert user_msg == {"role": "user", "content": "What is MLflow?"}
    assert assistant_msg["role"] == "assistant"
    content = assistant_msg["content"]
    assert "MLflow is an ML platform." in content
    # Timing markers must be present exactly when requested.
    assert ("[Response duration:" in content) is include_timing
    assert ("slowest spans:" in content) is include_timing
1054  
1055  
def test_session_level_expectations_filtering():
    """Only expectations tagged with a session id resolve from the session."""
    session_id = "test-session"

    with mlflow.start_span(name="test_span") as span:
        span.set_inputs({"question": "Test"})
        span.set_outputs({"answer": "Test answer"})
    trace_id = span.trace_id

    # Session-scoped expectation: carries the session id in its metadata.
    mlflow.log_assessment(
        trace_id=trace_id,
        assessment=Expectation(
            name="session_exp",
            value="session_value",
            source=AssessmentSource(source_type=AssessmentSourceType.HUMAN),
            metadata={TraceMetadataKey.TRACE_SESSION: session_id},
        ),
    )

    # Trace-scoped expectation: no session metadata, so it must be filtered out.
    mlflow.log_assessment(
        trace_id=trace_id,
        assessment=Expectation(
            name="trace_exp",
            value="trace_value",
            source=AssessmentSource(source_type=AssessmentSourceType.HUMAN),
            metadata={},
        ),
    )

    trace = mlflow.get_trace(trace_id)
    resolved = resolve_expectations_from_session(None, [trace])
    assert resolved == {"session_exp": "session_value"}
    assert "trace_exp" not in resolved
1086  
1087  
def test_resolve_expectations_from_session_with_provided_expectations():
    """Explicitly provided expectations are returned unchanged."""
    with mlflow.start_span(name="test_span") as span:
        span.set_inputs({"question": "Test"})
        span.set_outputs({"answer": "Test answer"})

    explicit = {"provided": "value"}
    resolved = resolve_expectations_from_session(explicit, [mlflow.get_trace(span.trace_id)])
    assert resolved == explicit
1098  
1099  
@pytest.mark.parametrize(
    ("expectations", "has_session_exp", "expected"),
    [
        (None, False, None),
        (None, True, {"session_exp": "session_value"}),
        ({"provided": "value"}, True, {"provided": "value"}),
    ],
)
def test_resolve_expectations_from_session_edge_cases(expectations, has_session_exp, expected):
    """Provided expectations, session-derived expectations, and neither all resolve."""
    session_id = "test-session"

    with mlflow.start_span(name="test_span") as span:
        span.set_inputs({"question": "Test"})
        span.set_outputs({"answer": "Test answer"})
        mlflow.update_current_trace(metadata={TraceMetadataKey.TRACE_SESSION: session_id})

    if has_session_exp:
        mlflow.log_assessment(
            trace_id=span.trace_id,
            assessment=Expectation(
                name="session_exp",
                value="session_value",
                source=AssessmentSource(source_type=AssessmentSourceType.HUMAN),
                metadata={TraceMetadataKey.TRACE_SESSION: session_id},
            ),
        )

    resolved = resolve_expectations_from_session(
        expectations, [mlflow.get_trace(span.trace_id)]
    )
    assert resolved == expected
1128  
1129  
def test_convert_predict_fn_async_function():
    """convert_predict_fn wraps a coroutine function into a synchronous callable."""
    async def respond(request):
        await asyncio.sleep(0.01)
        return "async test response"

    payload = {"request": {"messages": [{"role": "user", "content": "test"}]}}
    wrapped = convert_predict_fn(respond, payload)

    assert wrapped(request=payload) == "async test response"

    # The synchronous call still records exactly one trace.
    assert len(get_traces()) == 1
    purge_traces()
1145  
1146  
def test_evaluate_with_async_predict_fn():
    """mlflow.genai.evaluate accepts an async predict_fn directly."""
    async def respond(request):
        await asyncio.sleep(0.01)
        return "async test response"

    @scorer
    def dummy_scorer(inputs, outputs):
        return 0

    payload = {"request": {"messages": [{"role": "user", "content": "test"}]}}
    mlflow.genai.evaluate(
        data=[{"inputs": payload}],
        predict_fn=respond,
        scorers=[dummy_scorer],
    )
    assert len(get_traces()) == 1
    purge_traces()
1165  
1166  
def test_convert_predict_fn_async_function_with_timeout(monkeypatch):
    """An async predict_fn that outlives the configured timeout raises, leaving no trace."""
    monkeypatch.setenv("MLFLOW_GENAI_EVAL_ASYNC_TIMEOUT", "1")
    monkeypatch.setenv("MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION", "true")

    async def stall(request):
        # Sleeps past the 1-second timeout configured above.
        await asyncio.sleep(2)
        return "should timeout"

    payload = {"request": {"messages": [{"role": "user", "content": "test"}]}}
    wrapped = convert_predict_fn(stall, payload)

    with pytest.raises(asyncio.TimeoutError):  # noqa: PT011
        wrapped(request=payload)

    assert len(get_traces()) == 0
1183  
1184  
@pytest.mark.parametrize(
    ("span_type", "use_attribute", "tool_name", "tool_description"),
    [
        ("LLM", True, "get_weather", "Get current weather"),
        ("CHAT_MODEL", False, "search", "Search the web"),
    ],
)
def test_extract_available_tools_from_trace_basic(
    span_type, use_attribute, tool_name, tool_description
):
    """Tools are found both via the chat-tools span attribute and raw span inputs."""
    tool_spec = {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {"type": "object", "properties": {"param": {"type": "string"}}},
        },
    }

    with mlflow.start_span(name="test_span", span_type=span_type) as span:
        if use_attribute:
            # Tools recorded through the dedicated chat-tools span attribute.
            set_span_chat_tools(span, [tool_spec])
            span.set_inputs({"prompt": "test"})
        else:
            # Tools embedded directly inside the span's input payload.
            span.set_inputs({
                "messages": [{"role": "user", "content": "test"}],
                "tools": [tool_spec],
            })
        span.set_outputs({"response": "result"})

    extracted = extract_available_tools_from_trace(mlflow.get_trace(span.trace_id))

    assert len(extracted) == 1
    assert extracted[0].model_dump(exclude_none=True) == {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {"type": "object", "properties": {"param": {"type": "string"}}},
        },
    }
1226  
1227  
def test_extract_available_tools_from_trace_with_multiple_spans():
    """Tools attached to distinct LLM/CHAT_MODEL child spans are all collected."""
    def binary_op_tool(name, description, first_arg, second_arg):
        # One-entry tool list describing a two-argument numeric function.
        return [
            {
                "type": "function",
                "function": {
                    "name": name,
                    "description": description,
                    "parameters": {
                        "type": "object",
                        "properties": {
                            first_arg: {"type": "number"},
                            second_arg: {"type": "number"},
                        },
                    },
                },
            }
        ]

    add_tool = binary_op_tool("add", "Add two numbers", "a", "b")
    multiply_tool = binary_op_tool("multiply", "Multiply two numbers", "x", "y")

    with mlflow.start_span(name="parent") as parent:
        with mlflow.start_span(name="llm1", span_type="LLM") as llm_span:
            set_span_chat_tools(llm_span, add_tool)

        with mlflow.start_span(name="llm2", span_type="CHAT_MODEL") as chat_span:
            set_span_chat_tools(chat_span, multiply_tool)

    extracted = extract_available_tools_from_trace(mlflow.get_trace(parent.trace_id))

    assert len(extracted) == 2
    # Order of extraction is unspecified; compare after sorting by name.
    by_name = sorted(extracted, key=lambda t: t.function.name)
    assert by_name[0].model_dump(exclude_none=True) == add_tool[0]
    assert by_name[1].model_dump(exclude_none=True) == multiply_tool[0]
1306  
1307  
def test_extract_available_tools_from_trace_deduplication():
    """An identical tool definition on multiple spans is reported only once."""
    weather_tool = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get weather info",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    with mlflow.start_span(name="parent") as parent:
        # Attach the exact same tool definition to two separate LLM spans.
        with mlflow.start_span(name="llm1", span_type="LLM") as first:
            set_span_chat_tools(first, weather_tool)

        with mlflow.start_span(name="llm2", span_type="LLM") as second:
            set_span_chat_tools(second, weather_tool)

    extracted = extract_available_tools_from_trace(mlflow.get_trace(parent.trace_id))

    assert len(extracted) == 1
    assert extracted[0].model_dump(exclude_none=True) == weather_tool[0]
1339  
1340  
def test_extract_available_tools_from_trace_different_descriptions():
    """Same-named tools with different descriptions are NOT deduplicated."""
    def search_tool(description):
        # One-entry tool list named "search" with the given description.
        return [
            {
                "type": "function",
                "function": {
                    "name": "search",
                    "description": description,
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        ]

    web_tool = search_tool("Search the web")
    db_tool = search_tool("Search the database")

    with mlflow.start_span(name="parent") as parent:
        with mlflow.start_span(name="llm1", span_type="LLM") as first:
            set_span_chat_tools(first, web_tool)

        with mlflow.start_span(name="llm2", span_type="LLM") as second:
            set_span_chat_tools(second, db_tool)

    extracted = extract_available_tools_from_trace(mlflow.get_trace(parent.trace_id))

    assert len(extracted) == 2
    # "Search the database" sorts lexicographically before "Search the web".
    by_description = sorted(extracted, key=lambda t: t.function.description)
    assert by_description[0].model_dump(exclude_none=True) == db_tool[0]
    assert by_description[1].model_dump(exclude_none=True) == web_tool[0]
1395  
1396  
def test_extract_available_tools_from_trace_returns_empty():
    """A trace containing no spans yields no tools."""
    empty_trace = Trace(info=create_test_trace_info(trace_id="tr-456"), data=TraceData(spans=[]))
    assert extract_available_tools_from_trace(empty_trace) == []
1401  
1402  
@pytest.mark.parametrize(
    ("has_valid_tool", "expected_count"),
    [
        (False, 0),  # Only invalid tools
        (True, 1),  # Mix of valid and invalid tools
    ],
)
def test_extract_available_tools_from_trace_with_invalid_tools(has_valid_tool, expected_count):
    """Malformed tool entries are skipped; well-formed ones are still extracted."""
    valid_spec = {
        "type": "function",
        "function": {
            "name": "valid_tool",
            "description": "A valid tool",
        },
    }

    with mlflow.start_span(name="parent") as parent:
        if has_valid_tool:
            with mlflow.start_span(name="llm1", span_type="LLM") as valid_span:
                set_span_chat_tools(valid_span, [valid_spec])

        with mlflow.start_span(name="llm2", span_type="LLM") as invalid_span:
            invalid_span.set_inputs({
                "messages": [{"role": "user", "content": "test"}],
                "tools": [
                    {"invalid": "tool"},  # Missing required fields
                    {"type": "function"},  # Missing function field
                ],
            })

    extracted = extract_available_tools_from_trace(mlflow.get_trace(parent.trace_id))

    assert len(extracted) == expected_count
    if has_valid_tool:
        assert extracted[0].model_dump(exclude_none=True) == {
            "type": "function",
            "function": {
                "name": "valid_tool",
                "description": "A valid tool",
            },
        }
1446  
1447  
def test_extract_available_tools_llm_fallback_triggered_when_no_tools_found(monkeypatch):
    """When structural extraction finds nothing, the LLM fallback runs exactly once."""
    with mlflow.start_span(name="llm_span", span_type=SpanType.LLM) as span:
        # Tool schema deliberately does not match any recognized format.
        span.set_inputs({
            "messages": [{"role": "user", "content": "test"}],
            "tools": [
                {
                    "tool_name": "hard_to_extract_tool",
                    "description": "A tool that is hard to extract",
                }
            ],
        })
        span.set_outputs({"response": "result"})

    trace = mlflow.get_trace(span.trace_id)

    fallback_tools = [
        ChatTool(
            type="function",
            function=FunctionToolDefinition(
                name="hard_to_extract_tool",
                description="A tool that is hard to extract",
                parameters={"type": "object", "properties": {"x": {"type": "string"}}},
            ),
        )
    ]
    fallback_calls = []

    def fake_fallback(trace_arg, model_arg):
        # Record the arguments so the call can be asserted on afterwards.
        fallback_calls.append((trace_arg, model_arg))
        return fallback_tools

    monkeypatch.setattr(
        "mlflow.genai.utils.trace_utils._try_extract_available_tools_with_llm",
        fake_fallback,
    )

    extracted = extract_available_tools_from_trace(trace, model="openai:/gpt-4")

    assert len(fallback_calls) == 1
    called_trace, called_model = fallback_calls[0]
    assert called_trace == trace
    assert called_model == "openai:/gpt-4"
    assert len(extracted) == 1
    assert extracted[0].model_dump(exclude_none=True) == {
        "type": "function",
        "function": {
            "name": "hard_to_extract_tool",
            "description": "A tool that is hard to extract",
            "parameters": {"type": "object", "properties": {"x": {"type": "string"}}},
        },
    }
1499  
1500  
def test_try_extract_available_tools_with_llm_returns_empty_on_error(monkeypatch):
    """A failure inside the LLM call degrades gracefully to an empty tool list."""
    with mlflow.start_span(name="llm_span", span_type=SpanType.LLM) as span:
        span.set_inputs({"messages": [{"role": "user", "content": "test"}]})
        span.set_outputs({"response": "result"})

    def boom(*args, **kwargs):
        raise RuntimeError("LLM API error")

    monkeypatch.setattr(
        "mlflow.genai.utils.trace_utils.get_chat_completions_with_structured_output",
        boom,
    )

    trace = mlflow.get_trace(span.trace_id)
    assert _try_extract_available_tools_with_llm(trace, model="openai:/gpt-4") == []
1518  
1519  
def test_should_keep_trace_preserves_input_trace_ids():
    """A trace listed as an eval input is kept even though it postdates eval start."""
    trace = Trace(
        info=create_test_trace_info(trace_id="tr-input-123", request_time=2000),
        data=TraceData(spans=[]),
    )

    # Created at 2000, after eval start at 1000, but its id is in the input set.
    assert _should_keep_trace(trace, 1000, {"tr-input-123"}) is True
1532  
1533  
def test_should_keep_trace_deletes_non_input_traces_after_eval_start():
    """A non-input trace created after eval start is flagged for deletion."""
    trace = Trace(
        info=create_test_trace_info(trace_id="tr-extra-456", request_time=2000),
        data=TraceData(spans=[]),
    )

    # Created at 2000, after eval start at 1000, and not in the input set.
    assert _should_keep_trace(trace, 1000, {"tr-input-123"}) is False
1546  
1547  
def test_clean_up_extra_traces_preserves_input_traces():
    """Cleanup never deletes the traces the evaluation was given as input."""
    experiment_id = mlflow.set_experiment("test_experiment").experiment_id

    input_traces = []
    for idx in (1, 2):
        with mlflow.start_span(name=f"input_trace_{idx}") as span:
            span.set_inputs({"question": f"test{idx}"})
            span.set_outputs({"answer": f"answer{idx}"})
        input_traces.append(mlflow.get_trace(span.trace_id))

    first, second = input_traces
    # Eval started before either trace was created.
    eval_start_time = int(first.info.timestamp_ms - 1000)
    input_trace_ids = {t.info.trace_id for t in input_traces}

    clean_up_extra_traces(input_traces, eval_start_time, experiment_id, input_trace_ids)

    remaining_ids = {t.info.trace_id for t in get_traces()}
    assert first.info.trace_id in remaining_ids
    assert second.info.trace_id in remaining_ids
1572  
1573  
def test_clean_up_extra_traces_uses_correct_experiment_id():
    """Cleanup targets the experiment id it is given, not the currently-active one."""
    experiment_id = mlflow.set_experiment("cleanup_test_experiment").experiment_id

    with mlflow.start_span(name="input_trace") as kept_span:
        kept_span.set_inputs({"question": "test"})
        kept_span.set_outputs({"answer": "answer"})
    kept_trace = mlflow.get_trace(kept_span.trace_id)

    with mlflow.start_span(name="extra_trace") as extra_span:
        extra_span.set_inputs({"question": "extra"})
        extra_span.set_outputs({"answer": "extra_answer"})
    extra_trace = mlflow.get_trace(extra_span.trace_id)

    # Switch the active experiment; cleanup must still act on the first one.
    mlflow.set_experiment("cleanup_test_experiment_2")
    clean_up_extra_traces(
        [kept_trace, extra_trace], 0, experiment_id, {kept_trace.info.trace_id}
    )

    remaining = mlflow.search_traces(locations=[experiment_id], return_type="list")
    assert len(remaining) == 1
    assert remaining[0].info.trace_id == kept_trace.info.trace_id
1592  
1593  
def test_evaluate_with_trace_column_preserves_traces():
    """Evaluating a dataframe with a 'trace' column keeps the original traces."""
    @scorer
    def dummy_scorer(inputs, outputs):
        return 1.0

    with mlflow.start_span(name="original_trace") as span:
        span.set_inputs({"question": "What is MLflow?"})
        span.set_outputs({"answer": "MLflow is an ML platform"})
    original_trace = mlflow.get_trace(span.trace_id)

    eval_data = pd.DataFrame([
        {
            "trace": original_trace,
            "inputs": {"question": "What is MLflow?"},
            "outputs": {"answer": "MLflow is an ML platform"},
        }
    ])

    mlflow.genai.evaluate(data=eval_data, scorers=[dummy_scorer])

    surviving_ids = {t.info.trace_id for t in get_traces()}
    assert original_trace.info.trace_id in surviving_ids