test_trace_utils.py
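"""Tests for trace utilities in ``mlflow.genai.utils.trace_utils``.

Covers predict_fn conversion for evaluation, extraction of retrieval context,
inputs/outputs, expectations, and available tools from traces, session-level
conversation and expectation resolution, and cleanup of extra traces generated
during evaluation.
"""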
import asyncio
import json
from typing import Any
from unittest import mock

import httpx
import numpy as np
import openai
import pandas as pd
import pytest
from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan

import mlflow
from mlflow.entities.assessment import Expectation
from mlflow.entities.assessment_source import AssessmentSource, AssessmentSourceType
from mlflow.entities.dataset_record_source import DatasetRecordSource, DatasetRecordSourceType
from mlflow.entities.span import Span, SpanType
from mlflow.entities.trace import Trace
from mlflow.entities.trace_data import TraceData
from mlflow.genai.evaluation.entities import EvalItem
from mlflow.genai.evaluation.utils import is_none_or_nan
from mlflow.genai.scorers.base import scorer
from mlflow.genai.utils.trace_utils import (
    _does_store_support_trace_linking,
    _extract_tool_name_from_span,
    _should_keep_trace,
    _try_extract_available_tools_with_llm,
    clean_up_extra_traces,
    convert_predict_fn,
    create_minimal_trace,
    extract_available_tools_from_trace,
    extract_expectations_from_trace,
    extract_inputs_from_trace,
    extract_outputs_from_trace,
    extract_request_from_trace,
    extract_response_from_trace,
    extract_retrieval_context_from_trace,
    parse_inputs_to_str,
    parse_outputs_to_str,
    parse_tool_call_messages_from_trace,
    resolve_conversation_from_session,
    resolve_expectations_from_session,
)
from mlflow.tracing import set_span_chat_tools
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.tracing.utils import build_otel_context
from mlflow.types.chat import ChatTool, FunctionToolDefinition

from tests.tracing.helper import create_test_trace_info, get_traces, purge_traces


def httpx_send_patch(request, *args, **kwargs):
    return httpx.Response(
        status_code=200,
        request=request,
        json={
            "id": "chatcmpl-Ax4UAd5xf32KjgLkS1SEEY9oorI9m",
            "object": "chat.completion",
            "created": 1738641958,
            "model": "gpt-4o-2024-08-06",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "test",
                        "refusal": None,
                    },
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
        },
    )


def get_openai_predict_fn(with_tracing=False):
    if with_tracing:
        mlflow.openai.autolog()

    def predict_fn(request):
        with mock.patch("httpx.Client.send", side_effect=httpx_send_patch):
            response = openai.OpenAI().chat.completions.create(
                messages=request["messages"],
                model="gpt-4o-mini",
            )
            return response.choices[0].message.content

    return predict_fn


def get_dummy_predict_fn(with_tracing=False):
    def predict_fn(request):
        return "test"

    if with_tracing:
        return mlflow.trace(predict_fn)

    return predict_fn


@pytest.fixture
def mock_openai_env(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "fake_api_key")


@pytest.mark.usefixtures("mock_openai_env")
@pytest.mark.parametrize(
    ("predict_fn_generator", "with_tracing", "should_be_wrapped"),
    [
        (get_dummy_predict_fn, False, True),
        # If the function is already traced, it should not be wrapped with @mlflow.trace.
        (get_dummy_predict_fn, True, False),
        # OpenAI autologging is automatically enabled during evaluation,
        # so we don't need to wrap the function with @mlflow.trace.
        (get_openai_predict_fn, False, False),
        (get_openai_predict_fn, True, False),
    ],
    ids=[
        "dummy predict_fn without tracing",
        "dummy predict_fn with tracing",
        "openai predict_fn without tracing",
        "openai predict_fn with tracing",
    ],
)
def test_convert_predict_fn(predict_fn_generator, with_tracing, should_be_wrapped):
    predict_fn = predict_fn_generator(with_tracing=with_tracing)
    sample_input = {"request": {"messages": [{"role": "user", "content": "test"}]}}

    # predict_fn is callable as is
    result = predict_fn(**sample_input)
    assert result == "test"
    assert len(get_traces()) == (1 if with_tracing else 0)
    purge_traces()

    converted_fn = convert_predict_fn(predict_fn, sample_input)

    # converted function takes a single 'request' argument
    result = converted_fn(request=sample_input)
    assert result == "test"

    # Trace should be generated if decorated or wrapped with @mlflow.trace
    assert len(get_traces()) == (1 if with_tracing or should_be_wrapped else 0)
    purge_traces()

    # All functions should generate a trace when executed through mlflow.genai.evaluate
    @scorer
    def dummy_scorer(inputs, outputs):
        return 0

    mlflow.genai.evaluate(
        data=[{"inputs": sample_input}],
        predict_fn=predict_fn,
        scorers=[dummy_scorer],
    )
    assert len(get_traces()) == 1


def test_convert_predict_fn_skip_validation(monkeypatch):
    monkeypatch.setenv("MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION", "true")

    call_count = 0

    def dummy_predict_fn(question: str, context: str):
        nonlocal call_count
        call_count += 1
        return question + context

    sample_input = {"question": "test", "context": "test"}
    converted_fn = convert_predict_fn(dummy_predict_fn, sample_input)
    # Predict function should not be validated when the env var is set to True
    assert call_count == 0

    # converted function takes a single 'request' argument
    result = converted_fn(request=sample_input)
    assert result == "testtest"


def create_span(
    span_id: int,
    parent_id: int,
    span_type: str,
    inputs: dict[str, Any],
    outputs: dict[str, Any],
) -> Span:
    otel_span = OTelReadableSpan(
        name="test",
        context=build_otel_context(123, span_id),
        parent=build_otel_context(123, parent_id) if parent_id else None,
        start_time=100,
        end_time=200,
        attributes={
            "mlflow.spanInputs": json.dumps(inputs),
            "mlflow.spanOutputs": json.dumps(outputs),
            "mlflow.spanType": json.dumps(span_type),
        },
    )
    return Span(otel_span)


@pytest.mark.parametrize(
    ("spans", "expected_retrieval_context"),
    [
        # multiple retrieval steps - only top-level retriever spans are kept
        (
            [
                create_span(
                    span_id=1,
                    parent_id=None,  # root span
                    inputs="question",
                    outputs={"generations": [[{"text": "some text"}]]},
                    span_type=SpanType.LLM,
                ),
                create_span(
                    span_id=2,
                    parent_id=1,
                    inputs="What is the capital of France?",
                    outputs=[
                        {
                            "page_content": "document content 3",
                            "metadata": {
                                "doc_uri": "uri3",
                                "chunk_id": "3",
                            },
                            "type": "Document",
                        },
                    ],
                    span_type=SpanType.RETRIEVER,
                ),
                create_span(
                    span_id=3,
                    parent_id=1,
                    inputs="What is the capital of France?",
                    outputs=[
                        {
                            "page_content": "document content 1",
                            "metadata": {
                                "doc_uri": "uri1",
                                "chunk_id": "1",
                            },
                            "type": "Document",
                        },
                        {
"page_content": "document content 2", 245 "metadata": { 246 "doc_uri": "uri2", 247 "chunk_id": "2", 248 }, 249 "type": "Document", 250 }, 251 ], 252 span_type=SpanType.RETRIEVER, 253 ), 254 create_span( 255 span_id=4, 256 parent_id=3, 257 inputs="This should be ignored because it's not a top-level retrieval span", 258 outputs=[ 259 { 260 "page_content": "document content 4", 261 "metadata": { 262 "doc_uri": "uri4", 263 "chunk_id": "4", 264 }, 265 "type": "Document", 266 }, 267 ], 268 span_type=SpanType.RETRIEVER, 269 ), 270 ], 271 { 272 "0000000000000002": [ 273 { 274 "doc_uri": "uri3", 275 "content": "document content 3", 276 }, 277 ], 278 "0000000000000003": [ 279 { 280 "doc_uri": "uri1", 281 "content": "document content 1", 282 }, 283 { 284 "doc_uri": "uri2", 285 "content": "document content 2", 286 }, 287 ], 288 }, 289 ), 290 # one retrieval step 291 ( 292 [ 293 create_span( 294 span_id=1, 295 parent_id=None, 296 inputs="What is the capital of France?", 297 outputs=[ 298 { 299 "page_content": "document content 1", 300 "metadata": { 301 "doc_uri": "uri1", 302 "chunk_id": "1", 303 }, 304 "type": "Document", 305 }, 306 # missing doc_uri 307 { 308 "page_content": "document content 2", 309 "metadata": { 310 "chunk_id": "2", 311 }, 312 "type": "Document", 313 }, 314 # missing content 315 { 316 "metadata": { 317 "doc_uri": "uri3", 318 "chunk_id": "3", 319 }, 320 "type": "Document", 321 }, 322 # missing metadata 323 { 324 "page_content": "document content 4", 325 "type": "Document", 326 }, 327 ], 328 span_type=SpanType.RETRIEVER, 329 ), 330 ], 331 { 332 "0000000000000001": [ 333 { 334 "doc_uri": "uri1", 335 "content": "document content 1", 336 }, 337 { 338 "content": "document content 2", 339 }, 340 { 341 "content": None, 342 "doc_uri": "uri3", 343 }, 344 { 345 "content": "document content 4", 346 }, 347 ], 348 }, 349 ), 350 # one retrieval step - string outputs (UC schema casts attributes to MAP<STRING, STRING>) 351 ( 352 [ 353 create_span( 354 span_id=1, 355 parent_id=None, 356 inputs="What is the capital of France?", 357 outputs=json.dumps([ 358 { 359 "page_content": "document content 1", 360 "metadata": {"doc_uri": "uri1"}, 361 }, 362 { 363 "page_content": "document content 2", 364 "metadata": {"doc_uri": "uri2"}, 365 }, 366 ]), 367 span_type=SpanType.RETRIEVER, 368 ), 369 ], 370 { 371 "0000000000000001": [ 372 {"doc_uri": "uri1", "content": "document content 1"}, 373 {"doc_uri": "uri2", "content": "document content 2"}, 374 ], 375 }, 376 ), 377 # one retrieval step - empty retrieval span outputs 378 ( 379 [ 380 create_span( 381 span_id=1, 382 parent_id=None, 383 inputs="What is the capital of France?", 384 outputs=[], 385 span_type=SpanType.RETRIEVER, 386 ), 387 ], 388 {"0000000000000001": []}, 389 ), 390 # one retrieval step - wrong format retrieval span outputs 391 ( 392 [ 393 create_span( 394 span_id=1, 395 parent_id=None, 396 inputs="What is the capital of France?", 397 outputs=["wrong output", "should be ignored"], 398 span_type=SpanType.RETRIEVER, 399 ), 400 ], 401 {"0000000000000001": []}, 402 ), 403 # no retrieval steps 404 ( 405 [ 406 create_span( 407 span_id=1, 408 parent_id=None, 409 inputs="What is the capital of France?", 410 outputs=[{"text": "some text"}], 411 span_type=SpanType.LLM, 412 ), 413 ], 414 {}, 415 ), 416 # None trace 417 ( 418 None, 419 {}, 420 ), 421 ], 422 ) 423 def test_get_retrieval_context_from_trace(spans, expected_retrieval_context): 424 trace = Trace(info=create_test_trace_info(trace_id="tr-123"), data=TraceData(spans=spans)) 425 assert 
extract_retrieval_context_from_trace(trace) == expected_retrieval_context 426 427 428 @pytest.mark.parametrize( 429 ("input_data", "expected"), 430 [ 431 # String input 432 ("Hello world", "Hello world"), 433 # Chat completion/ChatModel/ChatAgent request 434 ( 435 {"messages": [{"role": "user", "content": "User message"}]}, 436 "User message", 437 ), 438 # Multi-turn messages 439 ( 440 { 441 "messages": [ 442 {"role": "assistant", "content": "First"}, 443 {"role": "user", "content": "Second"}, 444 ] 445 }, 446 '[{"role": "assistant", "content": "First"}, {"role": "user", "content": "Second"}]', 447 ), 448 # Empty dict input 449 ( 450 {}, 451 "{}", 452 ), 453 # Dict input 454 ( 455 {"unsupported_key": "value"}, 456 "{'unsupported_key': 'value'}", 457 ), 458 # Non-standard messages 459 ( 460 { 461 "messages": [ 462 {"role": "assistant", "k": "First"}, 463 {"role": "user", "k": "Second"}, 464 ] 465 }, 466 "{'messages': [{'role': 'assistant', 'k': 'First'}, {'role': 'user', 'k': 'Second'}]}", 467 ), 468 # Strands format - list of messages with role and content 469 ( 470 [{"role": "user", "content": [{"text": "hello"}]}], 471 '[{"role": "user", "content": [{"text": "hello"}]}]', 472 ), 473 # Strands format - multiple messages with simple string content 474 ( 475 [ 476 {"role": "user", "content": "First"}, 477 {"role": "assistant", "content": "Second"}, 478 ], 479 '[{"role": "user", "content": "First"}, {"role": "assistant", "content": "Second"}]', 480 ), 481 # Strands format - single message with string content 482 ( 483 [{"role": "user", "content": "Single message"}], 484 '[{"role": "user", "content": "Single message"}]', 485 ), 486 ], 487 ) 488 def test_parse_inputs_to_str(input_data, expected): 489 assert parse_inputs_to_str(input_data) == expected 490 491 492 @pytest.mark.parametrize( 493 ("output_data", "expected"), 494 [ 495 # String output 496 ("Output string", "Output string"), 497 # Chat completion/ChatModel response 498 ( 499 { 500 "choices": [ 501 { 502 "index": 0, 503 "message": { 504 "role": "assistant", 505 "content": "Output content", 506 }, 507 } 508 ] 509 }, 510 "Output content", 511 ), 512 # ChatAgent response with multiple messages 513 ( 514 { 515 "messages": [ 516 { 517 "role": "user", 518 "content": "Input content", 519 }, 520 { 521 "role": "assistant", 522 "content": "Intermediate Output content", 523 }, 524 { 525 "role": "user", 526 "content": "Intermediate Input content", 527 }, 528 { 529 "role": "assistant", 530 "content": "Output content", 531 }, 532 ] 533 }, 534 "Output content", 535 ), 536 # List of strings 537 (["Response content"], "Response content"), 538 # ChatAgent response with multiple messages 539 ( 540 [ 541 { 542 "choices": [ 543 { 544 "index": 0, 545 "message": { 546 "role": "assistant", 547 "content": "Output content", 548 }, 549 } 550 ] 551 } 552 ], 553 "Output content", 554 ), 555 # List of direct string response 556 ( 557 {"unsupported_key": "value"}, 558 '{"unsupported_key": "value"}', 559 ), 560 # Handle custom messages array format 561 ( 562 {"messages": ["a", "b", "c"]}, 563 '{"messages": ["a", "b", "c"]}', 564 ), 565 # OpenAI Responses API format with output_text content type 566 ( 567 { 568 "output": [ 569 { 570 "id": "msg_123", 571 "type": "message", 572 "role": "assistant", 573 "content": [{"type": "output_text", "text": "Response from Responses API"}], 574 } 575 ] 576 }, 577 "Response from Responses API", 578 ), 579 # OpenAI Responses API format with text content type 580 ( 581 { 582 "output": [ 583 { 584 "id": "msg_456", 585 "type": 
"message", 586 "role": "assistant", 587 "content": [{"type": "text", "text": "Text type response"}], 588 } 589 ] 590 }, 591 "Text type response", 592 ), 593 # OpenAI Responses API format with string content 594 ( 595 { 596 "output": [ 597 { 598 "id": "msg_789", 599 "type": "message", 600 "role": "assistant", 601 "content": "Direct string content", 602 } 603 ] 604 }, 605 "Direct string content", 606 ), 607 # OpenAI Responses API format with multiple output items (gets last assistant message) 608 ( 609 { 610 "output": [ 611 { 612 "id": "item_1", 613 "type": "function_call", 614 "name": "get_weather", 615 }, 616 { 617 "id": "msg_final", 618 "type": "message", 619 "role": "assistant", 620 "content": [{"type": "output_text", "text": "Final response"}], 621 }, 622 ] 623 }, 624 "Final response", 625 ), 626 ], 627 ) 628 def test_parse_outputs_to_str(output_data, expected): 629 assert parse_outputs_to_str(output_data) == expected 630 631 632 @pytest.mark.parametrize( 633 ("input_value", "expected"), 634 [ 635 (None, True), 636 (np.nan, True), 637 (float("nan"), True), 638 ("Not NaN", False), 639 (123, False), 640 ([], False), 641 ({}, False), 642 (0.0, False), 643 (1.5, False), 644 ], 645 ) 646 def test_is_none_or_nan(input_value, expected): 647 assert is_none_or_nan(input_value) == expected 648 649 650 def test_extract_expectations_from_trace_with_source_filter(): 651 with mlflow.start_span(name="test_span") as span: 652 span.set_inputs({"question": "What is MLflow?"}) 653 span.set_outputs({"answer": "MLflow is an open source platform"}) 654 655 trace_id = span.trace_id 656 657 human_expectation = Expectation( 658 name="human_expectation", 659 value={"expected": "Answer from human"}, 660 source=AssessmentSource(source_type=AssessmentSourceType.HUMAN), 661 ) 662 mlflow.log_assessment(trace_id=trace_id, assessment=human_expectation) 663 664 llm_expectation = Expectation( 665 name="llm_expectation", 666 value="LLM generated expectation", 667 source=AssessmentSource(source_type=AssessmentSourceType.LLM_JUDGE), 668 ) 669 mlflow.log_assessment(trace_id=trace_id, assessment=llm_expectation) 670 671 code_expectation = Expectation( 672 name="code_expectation", 673 value=42, 674 source=AssessmentSource(source_type=AssessmentSourceType.CODE), 675 ) 676 mlflow.log_assessment(trace_id=trace_id, assessment=code_expectation) 677 678 trace = mlflow.get_trace(trace_id) 679 680 result = extract_expectations_from_trace(trace, source_type=None) 681 assert result == { 682 "human_expectation": {"expected": "Answer from human"}, 683 "llm_expectation": "LLM generated expectation", 684 "code_expectation": 42, 685 } 686 687 result = extract_expectations_from_trace(trace, source_type="HUMAN") 688 assert result == {"human_expectation": {"expected": "Answer from human"}} 689 690 result = extract_expectations_from_trace(trace, source_type="LLM_JUDGE") 691 assert result == {"llm_expectation": "LLM generated expectation"} 692 693 result = extract_expectations_from_trace(trace, source_type="CODE") 694 assert result == {"code_expectation": 42} 695 696 result = extract_expectations_from_trace(trace, source_type="human") 697 assert result == {"human_expectation": {"expected": "Answer from human"}} 698 699 with pytest.raises(mlflow.exceptions.MlflowException, match="Invalid assessment source type"): 700 extract_expectations_from_trace(trace, source_type="INVALID_SOURCE") 701 702 703 def test_extract_expectations_from_trace_returns_none_when_no_expectations(): 704 with mlflow.start_span(name="test_span") as span: 705 
span.set_inputs({"question": "What is MLflow?"}) 706 span.set_outputs({"answer": "MLflow is an open source platform"}) 707 708 trace = mlflow.get_trace(span.trace_id) 709 710 result = extract_expectations_from_trace(trace) 711 assert result is None 712 713 result = extract_expectations_from_trace(trace, source_type="HUMAN") 714 assert result is None 715 716 717 def test_extract_inputs_and_outputs_from_trace(): 718 test_inputs = {"question": "What is MLflow?", "context": "MLflow is a tool"} 719 test_outputs = {"answer": "MLflow is an open source platform", "confidence": 0.95} 720 721 with mlflow.start_span(name="test_span") as span: 722 span.set_inputs(test_inputs) 723 span.set_outputs(test_outputs) 724 725 trace = mlflow.get_trace(span.trace_id) 726 727 assert extract_inputs_from_trace(trace) == test_inputs 728 assert extract_outputs_from_trace(trace) == test_outputs 729 730 trace_without_data = Trace( 731 info=create_test_trace_info(trace_id="tr-123"), data=TraceData(spans=[]) 732 ) 733 assert extract_inputs_from_trace(trace_without_data) is None 734 assert extract_outputs_from_trace(trace_without_data) is None 735 736 737 def test_extract_request_and_response_from_trace(): 738 test_inputs = {"messages": [{"role": "user", "content": "What is MLflow?"}]} 739 test_outputs = { 740 "choices": [{"index": 0, "message": {"role": "assistant", "content": "MLflow is great"}}] 741 } 742 743 with mlflow.start_span(name="test_span") as span: 744 span.set_inputs(test_inputs) 745 span.set_outputs(test_outputs) 746 747 trace = mlflow.get_trace(span.trace_id) 748 749 assert extract_request_from_trace(trace) == "What is MLflow?" 750 assert extract_response_from_trace(trace) == "MLflow is great" 751 752 trace_without_data = Trace( 753 info=create_test_trace_info(trace_id="tr-123"), data=TraceData(spans=[]) 754 ) 755 assert extract_request_from_trace(trace_without_data) is None 756 assert extract_response_from_trace(trace_without_data) is None 757 758 759 def test_extract_request_and_response_with_string_inputs(): 760 test_inputs = "Simple string input" 761 test_outputs = "Simple string output" 762 763 with mlflow.start_span(name="test_span") as span: 764 span.set_inputs(test_inputs) 765 span.set_outputs(test_outputs) 766 767 trace = mlflow.get_trace(span.trace_id) 768 769 assert extract_request_from_trace(trace) == "Simple string input" 770 assert extract_response_from_trace(trace) == "Simple string output" 771 772 773 def test_does_store_support_trace_linking(): 774 test_trace = Trace(info=create_test_trace_info(trace_id="tr-123"), data=TraceData(spans=[])) 775 776 # Databricks backend support trace linking 777 assert _does_store_support_trace_linking( 778 tracking_uri="databricks", 779 trace=test_trace, 780 run_id="run-123", 781 ) 782 783 assert _does_store_support_trace_linking( 784 tracking_uri="databricks://test", 785 trace=test_trace, 786 run_id="run-123", 787 ) 788 789 mock_client = mock.MagicMock() 790 with mock.patch("mlflow.genai.utils.trace_utils.MlflowClient", return_value=mock_client): 791 # SQLAlchemy backend support trace linking 792 mock_client.link_traces_to_run.side_effect = None 793 794 assert _does_store_support_trace_linking( 795 tracking_uri="sqlalchemy://test", 796 trace=test_trace, 797 run_id="run-123", 798 ) 799 800 # File store doesn't support trace linking 801 mock_client.link_traces_to_run.side_effect = Exception("Test error") 802 803 assert not _does_store_support_trace_linking( 804 tracking_uri="file://test", 805 trace=test_trace, 806 run_id="run-123", 807 ) 808 809 # Result 
should be cached per tracking URI 810 mock_client.reset_mock() 811 mock_client.link_traces_to_run.side_effect = None 812 for _ in range(10): 813 assert _does_store_support_trace_linking( 814 tracking_uri="sqlalchemy://test2", 815 trace=test_trace, 816 run_id="run-123", 817 ) 818 mock_client.link_traces_to_run.assert_called_once() 819 820 821 def test_create_minimal_trace_restores_session_metadata(): 822 source = DatasetRecordSource( 823 source_type=DatasetRecordSourceType.TRACE, 824 source_data={"trace_id": "tr-original", "session_id": "session_1"}, 825 ) 826 827 eval_item = EvalItem( 828 request_id="req-123", 829 inputs={"question": "test"}, 830 outputs="answer", 831 expectations={}, 832 source=source, 833 ) 834 835 trace = create_minimal_trace(eval_item) 836 837 # Verify session metadata was restored 838 assert trace.info.trace_metadata.get("mlflow.trace.session") == "session_1" 839 assert trace.data._get_root_span().inputs == {"question": "test"} 840 assert trace.data._get_root_span().outputs == "answer" 841 842 843 def test_create_minimal_trace_without_source(): 844 eval_item = EvalItem( 845 request_id="req-123", 846 inputs={"question": "test"}, 847 outputs="answer", 848 expectations={}, 849 source=None, 850 ) 851 852 trace = create_minimal_trace(eval_item) 853 854 # Should create trace successfully without session metadata 855 assert trace is not None 856 assert "mlflow.trace.session" not in trace.info.trace_metadata 857 assert trace.data._get_root_span().inputs == {"question": "test"} 858 assert trace.data._get_root_span().outputs == "answer" 859 860 861 def test_create_minimal_trace_with_source_but_no_session(): 862 source = DatasetRecordSource( 863 source_type=DatasetRecordSourceType.TRACE, 864 source_data={"trace_id": "tr-original"}, # No session_id 865 ) 866 867 eval_item = EvalItem( 868 request_id="req-123", 869 inputs={"question": "test"}, 870 outputs="answer", 871 expectations={}, 872 source=source, 873 ) 874 875 trace = create_minimal_trace(eval_item) 876 877 # Should work without session metadata 878 assert trace is not None 879 assert "mlflow.trace.session" not in trace.info.trace_metadata 880 assert trace.data._get_root_span().inputs == {"question": "test"} 881 assert trace.data._get_root_span().outputs == "answer" 882 883 884 def test_parse_tool_call_messages_from_trace(): 885 with mlflow.start_span(name="root") as root_span: 886 root_span.set_inputs({"question": "What is the stock price?"}) 887 888 with mlflow.start_span(name="get_stock_price", span_type=SpanType.TOOL) as tool_span: 889 tool_span.set_inputs({"symbol": "AAPL"}) 890 tool_span.set_outputs({"price": 150.0}) 891 892 with mlflow.start_span(name="get_market_cap", span_type=SpanType.TOOL) as tool_span2: 893 tool_span2.set_inputs({"symbol": "AAPL"}) 894 tool_span2.set_outputs({"market_cap": "2.5T"}) 895 896 root_span.set_outputs("AAPL price is $150.") 897 898 trace = mlflow.get_trace(root_span.trace_id) 899 tool_messages = parse_tool_call_messages_from_trace(trace) 900 901 assert len(tool_messages) == 2 902 assert tool_messages[0] == { 903 "role": "tool", 904 "content": "Tool: get_stock_price\nInputs: {'symbol': 'AAPL'}\nOutputs: {'price': 150.0}", 905 } 906 assert tool_messages[1] == { 907 "role": "tool", 908 "content": ( 909 "Tool: get_market_cap\nInputs: {'symbol': 'AAPL'}\nOutputs: {'market_cap': '2.5T'}" 910 ), 911 } 912 913 914 def test_parse_tool_call_messages_from_trace_no_tools(): 915 with mlflow.start_span(name="root") as span: 916 span.set_inputs({"question": "Hello"}) 917 span.set_outputs("Hi there") 

    trace = mlflow.get_trace(span.trace_id)
    tool_messages = parse_tool_call_messages_from_trace(trace)

    assert tool_messages == []


def test_parse_tool_call_messages_from_trace_tool_without_outputs():
    with mlflow.start_span(name="root") as root_span:
        root_span.set_inputs({"query": "test"})

        with mlflow.start_span(name="my_tool", span_type=SpanType.TOOL) as tool_span:
            tool_span.set_inputs({"param": "value"})

        root_span.set_outputs("result")

    trace = mlflow.get_trace(root_span.trace_id)
    tool_messages = parse_tool_call_messages_from_trace(trace)

    assert len(tool_messages) == 1
    assert tool_messages[0] == {
        "role": "tool",
        "content": "Tool: my_tool\nInputs: {'param': 'value'}",
    }


def test_extract_tool_name_from_span_uses_span_name_by_default():
    with mlflow.start_span(name="root") as root_span:
        root_span.set_inputs({"query": "test"})

        with mlflow.start_span(name="my_tool", span_type=SpanType.TOOL) as tool_span:
            tool_span.set_inputs({"arg": "value"})

        root_span.set_outputs("result")

    trace = mlflow.get_trace(root_span.trace_id)
    tool_spans = trace.search_spans(span_type=SpanType.TOOL)

    assert _extract_tool_name_from_span(tool_spans[0]) == "my_tool"


def test_extract_tool_name_from_span_extracts_from_call_tool_name():
    with mlflow.start_span(name="root") as root_span:
        root_span.set_inputs({"query": "test"})

        with mlflow.start_span(
            name="ToolManager.handle_call", span_type=SpanType.TOOL
        ) as tool_span:
            tool_span.set_inputs({"call": {"tool_name": "list_client", "args": {"param": "value"}}})

        root_span.set_outputs("result")

    trace = mlflow.get_trace(root_span.trace_id)
    tool_spans = trace.search_spans(span_type=SpanType.TOOL)

    assert _extract_tool_name_from_span(tool_spans[0]) == "list_client"


def test_resolve_conversation_from_session():
    session_id = "test_session_resolve"
    traces = []

    with mlflow.start_span(name="turn_0") as span:
        span.set_inputs({"messages": [{"role": "user", "content": "What is AAPL price?"}]})
        span.set_outputs("AAPL is $150.")
        mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
    traces.append(mlflow.get_trace(span.trace_id))

    with mlflow.start_span(name="turn_1") as span:
        span.set_inputs({"messages": [{"role": "user", "content": "How about MSFT?"}]})
        span.set_outputs("MSFT is $300.")
        mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
    traces.append(mlflow.get_trace(span.trace_id))

    conversation = resolve_conversation_from_session(traces)

    assert len(conversation) == 4
    assert conversation[0] == {"role": "user", "content": "What is AAPL price?"}
    assert conversation[1] == {"role": "assistant", "content": "AAPL is $150."}
    assert conversation[2] == {"role": "user", "content": "How about MSFT?"}
    assert conversation[3] == {"role": "assistant", "content": "MSFT is $300."}


def test_resolve_conversation_from_session_with_tool_calls():
    session_id = "test_session_with_tools"
    traces = []

    with mlflow.start_span(name="turn_0") as root_span:
        root_span.set_inputs({"messages": [{"role": "user", "content": "Get AAPL price"}]})

        with mlflow.start_span(name="get_stock_price", span_type=SpanType.TOOL) as tool_span:
            tool_span.set_inputs({"symbol": "AAPL"})
            tool_span.set_outputs({"price": 150})

root_span.set_outputs("AAPL is $150.") 1013 mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) 1014 traces.append(mlflow.get_trace(root_span.trace_id)) 1015 1016 conversation = resolve_conversation_from_session(traces, include_tool_calls=False) 1017 assert len(conversation) == 2 1018 assert conversation[0]["role"] == "user" 1019 assert conversation[1]["role"] == "assistant" 1020 1021 conversation_with_tools = resolve_conversation_from_session(traces, include_tool_calls=True) 1022 assert len(conversation_with_tools) == 3 1023 assert conversation_with_tools[0] == {"role": "user", "content": "Get AAPL price"} 1024 assert conversation_with_tools[1] == { 1025 "role": "tool", 1026 "content": "Tool: get_stock_price\nInputs: {'symbol': 'AAPL'}\nOutputs: {'price': 150}", 1027 } 1028 assert conversation_with_tools[2] == {"role": "assistant", "content": "AAPL is $150."} 1029 1030 1031 def test_resolve_conversation_from_session_empty(): 1032 assert resolve_conversation_from_session([]) == [] 1033 1034 1035 @pytest.mark.parametrize("include_timing", [True, False]) 1036 def test_resolve_conversation_from_session_with_timing_parameter(include_timing): 1037 session_id = "test_session" 1038 traces = [] 1039 1040 with mlflow.start_span(name="turn_0") as span: 1041 span.set_inputs({"messages": [{"role": "user", "content": "What is MLflow?"}]}) 1042 span.set_outputs("MLflow is an ML platform.") 1043 mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) 1044 traces.append(mlflow.get_trace(span.trace_id)) 1045 1046 conversation = resolve_conversation_from_session(traces, include_timing=include_timing) 1047 1048 assert len(conversation) == 2 1049 assert conversation[0] == {"role": "user", "content": "What is MLflow?"} 1050 assert conversation[1]["role"] == "assistant" 1051 assert "MLflow is an ML platform." 
in conversation[1]["content"] 1052 assert ("[Response duration:" in conversation[1]["content"]) is include_timing 1053 assert ("slowest spans:" in conversation[1]["content"]) is include_timing 1054 1055 1056 def test_session_level_expectations_filtering(): 1057 session_id = "test-session" 1058 1059 with mlflow.start_span(name="test_span") as span: 1060 span.set_inputs({"question": "Test"}) 1061 span.set_outputs({"answer": "Test answer"}) 1062 1063 trace_id = span.trace_id 1064 1065 session_exp = Expectation( 1066 name="session_exp", 1067 value="session_value", 1068 source=AssessmentSource(source_type=AssessmentSourceType.HUMAN), 1069 metadata={TraceMetadataKey.TRACE_SESSION: session_id}, 1070 ) 1071 mlflow.log_assessment(trace_id=trace_id, assessment=session_exp) 1072 1073 trace_exp = Expectation( 1074 name="trace_exp", 1075 value="trace_value", 1076 source=AssessmentSource(source_type=AssessmentSourceType.HUMAN), 1077 metadata={}, 1078 ) 1079 mlflow.log_assessment(trace_id=trace_id, assessment=trace_exp) 1080 1081 trace = mlflow.get_trace(trace_id) 1082 1083 session_result = resolve_expectations_from_session(None, [trace]) 1084 assert session_result == {"session_exp": "session_value"} 1085 assert "trace_exp" not in session_result 1086 1087 1088 def test_resolve_expectations_from_session_with_provided_expectations(): 1089 with mlflow.start_span(name="test_span") as span: 1090 span.set_inputs({"question": "Test"}) 1091 span.set_outputs({"answer": "Test answer"}) 1092 1093 trace = mlflow.get_trace(span.trace_id) 1094 provided_expectations = {"provided": "value"} 1095 1096 result = resolve_expectations_from_session(provided_expectations, [trace]) 1097 assert result == provided_expectations 1098 1099 1100 @pytest.mark.parametrize( 1101 ("expectations", "has_session_exp", "expected"), 1102 [ 1103 (None, False, None), 1104 (None, True, {"session_exp": "session_value"}), 1105 ({"provided": "value"}, True, {"provided": "value"}), 1106 ], 1107 ) 1108 def test_resolve_expectations_from_session_edge_cases(expectations, has_session_exp, expected): 1109 session_id = "test-session" 1110 1111 with mlflow.start_span(name="test_span") as span: 1112 span.set_inputs({"question": "Test"}) 1113 span.set_outputs({"answer": "Test answer"}) 1114 mlflow.update_current_trace(metadata={TraceMetadataKey.TRACE_SESSION: session_id}) 1115 1116 if has_session_exp: 1117 exp = Expectation( 1118 name="session_exp", 1119 value="session_value", 1120 source=AssessmentSource(source_type=AssessmentSourceType.HUMAN), 1121 metadata={TraceMetadataKey.TRACE_SESSION: session_id}, 1122 ) 1123 mlflow.log_assessment(trace_id=span.trace_id, assessment=exp) 1124 1125 trace = mlflow.get_trace(span.trace_id) 1126 result = resolve_expectations_from_session(expectations, [trace]) 1127 assert result == expected 1128 1129 1130 def test_convert_predict_fn_async_function(): 1131 async def async_predict_fn(request): 1132 await asyncio.sleep(0.01) 1133 return "async test response" 1134 1135 sample_input = {"request": {"messages": [{"role": "user", "content": "test"}]}} 1136 1137 converted_fn = convert_predict_fn(async_predict_fn, sample_input) 1138 1139 result = converted_fn(request=sample_input) 1140 assert result == "async test response" 1141 1142 traces = get_traces() 1143 assert len(traces) == 1 1144 purge_traces() 1145 1146 1147 def test_evaluate_with_async_predict_fn(): 1148 async def async_predict_fn(request): 1149 await asyncio.sleep(0.01) 1150 return "async test response" 1151 1152 sample_input = {"request": {"messages": [{"role": "user", 
"content": "test"}]}} 1153 1154 @scorer 1155 def dummy_scorer(inputs, outputs): 1156 return 0 1157 1158 mlflow.genai.evaluate( 1159 data=[{"inputs": sample_input}], 1160 predict_fn=async_predict_fn, 1161 scorers=[dummy_scorer], 1162 ) 1163 assert len(get_traces()) == 1 1164 purge_traces() 1165 1166 1167 def test_convert_predict_fn_async_function_with_timeout(monkeypatch): 1168 monkeypatch.setenv("MLFLOW_GENAI_EVAL_ASYNC_TIMEOUT", "1") 1169 monkeypatch.setenv("MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION", "true") 1170 1171 async def slow_async_predict_fn(request): 1172 await asyncio.sleep(2) 1173 return "should timeout" 1174 1175 sample_input = {"request": {"messages": [{"role": "user", "content": "test"}]}} 1176 1177 converted_fn = convert_predict_fn(slow_async_predict_fn, sample_input) 1178 1179 with pytest.raises(asyncio.TimeoutError): # noqa: PT011 1180 converted_fn(request=sample_input) 1181 1182 assert len(get_traces()) == 0 1183 1184 1185 @pytest.mark.parametrize( 1186 ("span_type", "use_attribute", "tool_name", "tool_description"), 1187 [ 1188 ("LLM", True, "get_weather", "Get current weather"), 1189 ("CHAT_MODEL", False, "search", "Search the web"), 1190 ], 1191 ) 1192 def test_extract_available_tools_from_trace_basic( 1193 span_type, use_attribute, tool_name, tool_description 1194 ): 1195 tools = [ 1196 { 1197 "type": "function", 1198 "function": { 1199 "name": tool_name, 1200 "description": tool_description, 1201 "parameters": {"type": "object", "properties": {"param": {"type": "string"}}}, 1202 }, 1203 } 1204 ] 1205 1206 with mlflow.start_span(name="test_span", span_type=span_type) as span: 1207 if use_attribute: 1208 set_span_chat_tools(span, tools) 1209 span.set_inputs({"prompt": "test"}) 1210 else: 1211 span.set_inputs({"messages": [{"role": "user", "content": "test"}], "tools": tools}) 1212 span.set_outputs({"response": "result"}) 1213 1214 trace = mlflow.get_trace(span.trace_id) 1215 extracted_tools = extract_available_tools_from_trace(trace) 1216 1217 assert len(extracted_tools) == 1 1218 assert extracted_tools[0].model_dump(exclude_none=True) == { 1219 "type": "function", 1220 "function": { 1221 "name": tool_name, 1222 "description": tool_description, 1223 "parameters": {"type": "object", "properties": {"param": {"type": "string"}}}, 1224 }, 1225 } 1226 1227 1228 def test_extract_available_tools_from_trace_with_multiple_spans(): 1229 tool1 = [ 1230 { 1231 "type": "function", 1232 "function": { 1233 "name": "add", 1234 "description": "Add two numbers", 1235 "parameters": { 1236 "type": "object", 1237 "properties": { 1238 "a": {"type": "number"}, 1239 "b": {"type": "number"}, 1240 }, 1241 }, 1242 }, 1243 } 1244 ] 1245 1246 tool2 = [ 1247 { 1248 "type": "function", 1249 "function": { 1250 "name": "multiply", 1251 "description": "Multiply two numbers", 1252 "parameters": { 1253 "type": "object", 1254 "properties": { 1255 "x": {"type": "number"}, 1256 "y": {"type": "number"}, 1257 }, 1258 }, 1259 }, 1260 } 1261 ] 1262 1263 with mlflow.start_span(name="parent") as parent: 1264 with mlflow.start_span(name="llm1", span_type="LLM") as span1: 1265 set_span_chat_tools(span1, tool1) 1266 1267 with mlflow.start_span(name="llm2", span_type="CHAT_MODEL") as span2: 1268 set_span_chat_tools(span2, tool2) 1269 1270 trace = mlflow.get_trace(parent.trace_id) 1271 extracted_tools = extract_available_tools_from_trace(trace) 1272 1273 assert len(extracted_tools) == 2 1274 1275 extracted_tools_sorted = sorted(extracted_tools, key=lambda t: t.function.name) 1276 1277 assert 
extracted_tools_sorted[0].model_dump(exclude_none=True) == { 1278 "type": "function", 1279 "function": { 1280 "name": "add", 1281 "description": "Add two numbers", 1282 "parameters": { 1283 "type": "object", 1284 "properties": { 1285 "a": {"type": "number"}, 1286 "b": {"type": "number"}, 1287 }, 1288 }, 1289 }, 1290 } 1291 1292 assert extracted_tools_sorted[1].model_dump(exclude_none=True) == { 1293 "type": "function", 1294 "function": { 1295 "name": "multiply", 1296 "description": "Multiply two numbers", 1297 "parameters": { 1298 "type": "object", 1299 "properties": { 1300 "x": {"type": "number"}, 1301 "y": {"type": "number"}, 1302 }, 1303 }, 1304 }, 1305 } 1306 1307 1308 def test_extract_available_tools_from_trace_deduplication(): 1309 tools = [ 1310 { 1311 "type": "function", 1312 "function": { 1313 "name": "get_weather", 1314 "description": "Get weather info", 1315 "parameters": {"type": "object", "properties": {}}, 1316 }, 1317 } 1318 ] 1319 1320 with mlflow.start_span(name="parent") as parent: 1321 with mlflow.start_span(name="llm1", span_type="LLM") as span1: 1322 set_span_chat_tools(span1, tools) 1323 1324 with mlflow.start_span(name="llm2", span_type="LLM") as span2: 1325 set_span_chat_tools(span2, tools) 1326 1327 trace = mlflow.get_trace(parent.trace_id) 1328 extracted_tools = extract_available_tools_from_trace(trace) 1329 1330 assert len(extracted_tools) == 1 1331 assert extracted_tools[0].model_dump(exclude_none=True) == { 1332 "type": "function", 1333 "function": { 1334 "name": "get_weather", 1335 "description": "Get weather info", 1336 "parameters": {"type": "object", "properties": {}}, 1337 }, 1338 } 1339 1340 1341 def test_extract_available_tools_from_trace_different_descriptions(): 1342 tool1 = [ 1343 { 1344 "type": "function", 1345 "function": { 1346 "name": "search", 1347 "description": "Search the web", 1348 "parameters": {"type": "object", "properties": {}}, 1349 }, 1350 } 1351 ] 1352 1353 tool2 = [ 1354 { 1355 "type": "function", 1356 "function": { 1357 "name": "search", 1358 "description": "Search the database", 1359 "parameters": {"type": "object", "properties": {}}, 1360 }, 1361 } 1362 ] 1363 1364 with mlflow.start_span(name="parent") as parent: 1365 with mlflow.start_span(name="llm1", span_type="LLM") as span1: 1366 set_span_chat_tools(span1, tool1) 1367 1368 with mlflow.start_span(name="llm2", span_type="LLM") as span2: 1369 set_span_chat_tools(span2, tool2) 1370 1371 trace = mlflow.get_trace(parent.trace_id) 1372 extracted_tools = extract_available_tools_from_trace(trace) 1373 1374 assert len(extracted_tools) == 2 1375 1376 extracted_tools_sorted = sorted(extracted_tools, key=lambda t: t.function.description) 1377 1378 assert extracted_tools_sorted[0].model_dump(exclude_none=True) == { 1379 "type": "function", 1380 "function": { 1381 "name": "search", 1382 "description": "Search the database", 1383 "parameters": {"type": "object", "properties": {}}, 1384 }, 1385 } 1386 1387 assert extracted_tools_sorted[1].model_dump(exclude_none=True) == { 1388 "type": "function", 1389 "function": { 1390 "name": "search", 1391 "description": "Search the web", 1392 "parameters": {"type": "object", "properties": {}}, 1393 }, 1394 } 1395 1396 1397 def test_extract_available_tools_from_trace_returns_empty(): 1398 trace_fixture = Trace(info=create_test_trace_info(trace_id="tr-456"), data=TraceData(spans=[])) 1399 result = extract_available_tools_from_trace(trace_fixture) 1400 assert result == [] 1401 1402 1403 @pytest.mark.parametrize( 1404 ("has_valid_tool", "expected_count"), 
    [
        (False, 0),  # Only invalid tools
        (True, 1),  # Mix of valid and invalid tools
    ],
)
def test_extract_available_tools_from_trace_with_invalid_tools(has_valid_tool, expected_count):
    with mlflow.start_span(name="parent") as parent:
        if has_valid_tool:
            valid_tool = [
                {
                    "type": "function",
                    "function": {
                        "name": "valid_tool",
                        "description": "A valid tool",
                    },
                }
            ]
            with mlflow.start_span(name="llm1", span_type="LLM") as span1:
                set_span_chat_tools(span1, valid_tool)

        with mlflow.start_span(name="llm2", span_type="LLM") as span2:
            span2.set_inputs({
                "messages": [{"role": "user", "content": "test"}],
                "tools": [
                    {"invalid": "tool"},  # Missing required fields
                    {"type": "function"},  # Missing function field
                ],
            })

    trace = mlflow.get_trace(parent.trace_id)
    extracted_tools = extract_available_tools_from_trace(trace)

    assert len(extracted_tools) == expected_count
    if has_valid_tool:
        assert extracted_tools[0].model_dump(exclude_none=True) == {
            "type": "function",
            "function": {
                "name": "valid_tool",
                "description": "A valid tool",
            },
        }


def test_extract_available_tools_llm_fallback_triggered_when_no_tools_found(monkeypatch):
    with mlflow.start_span(name="llm_span", span_type=SpanType.LLM) as span:
        span.set_inputs({
            "messages": [{"role": "user", "content": "test"}],
            "tools": [
                {
                    "tool_name": "hard_to_extract_tool",
                    "description": "A tool that is hard to extract",
                }
            ],
        })
        span.set_outputs({"response": "result"})

    trace = mlflow.get_trace(span.trace_id)

    mock_tools = [
        ChatTool(
            type="function",
            function=FunctionToolDefinition(
                name="hard_to_extract_tool",
                description="A tool that is hard to extract",
                parameters={"type": "object", "properties": {"x": {"type": "string"}}},
            ),
        )
    ]

    mock_llm_fallback_called = []

    def mock_llm_fallback(trace_arg, model_arg):
        mock_llm_fallback_called.append({"trace": trace_arg, "model": model_arg})
        return mock_tools

    monkeypatch.setattr(
        "mlflow.genai.utils.trace_utils._try_extract_available_tools_with_llm",
        mock_llm_fallback,
    )

    extracted_tools = extract_available_tools_from_trace(trace, model="openai:/gpt-4")

    assert len(mock_llm_fallback_called) == 1
    assert mock_llm_fallback_called[0]["trace"] == trace
    assert mock_llm_fallback_called[0]["model"] == "openai:/gpt-4"
    assert len(extracted_tools) == 1
    assert extracted_tools[0].model_dump(exclude_none=True) == {
        "type": "function",
        "function": {
            "name": "hard_to_extract_tool",
            "description": "A tool that is hard to extract",
            "parameters": {"type": "object", "properties": {"x": {"type": "string"}}},
        },
    }


def test_try_extract_available_tools_with_llm_returns_empty_on_error(monkeypatch):
    with mlflow.start_span(name="llm_span", span_type=SpanType.LLM) as span:
        span.set_inputs({"messages": [{"role": "user", "content": "test"}]})
        span.set_outputs({"response": "result"})

    trace = mlflow.get_trace(span.trace_id)

    def mock_raise_error(*args, **kwargs):
        raise RuntimeError("LLM API error")

    monkeypatch.setattr(
        "mlflow.genai.utils.trace_utils.get_chat_completions_with_structured_output",
        mock_raise_error,
    )

    result = _try_extract_available_tools_with_llm(trace, model="openai:/gpt-4")
    assert result == []


def test_should_keep_trace_preserves_input_trace_ids():
    trace_info = create_test_trace_info(
        trace_id="tr-input-123",
        request_time=2000,
    )
    trace = Trace(info=trace_info, data=TraceData(spans=[]))

    eval_start_time = 1000
    input_trace_ids = {"tr-input-123"}

    result = _should_keep_trace(trace, eval_start_time, input_trace_ids)
    assert result is True


def test_should_keep_trace_deletes_non_input_traces_after_eval_start():
    trace_info = create_test_trace_info(
        trace_id="tr-extra-456",
        request_time=2000,
    )
    trace = Trace(info=trace_info, data=TraceData(spans=[]))

    eval_start_time = 1000
    input_trace_ids = {"tr-input-123"}

    result = _should_keep_trace(trace, eval_start_time, input_trace_ids)
    assert result is False


def test_clean_up_extra_traces_preserves_input_traces():
    experiment_id = mlflow.set_experiment("test_experiment").experiment_id

    with mlflow.start_span(name="input_trace_1") as span1:
        span1.set_inputs({"question": "test1"})
        span1.set_outputs({"answer": "answer1"})
    trace1 = mlflow.get_trace(span1.trace_id)

    with mlflow.start_span(name="input_trace_2") as span2:
        span2.set_inputs({"question": "test2"})
        span2.set_outputs({"answer": "answer2"})
    trace2 = mlflow.get_trace(span2.trace_id)

    eval_start_time = int(trace1.info.timestamp_ms - 1000)

    input_trace_ids = {trace1.info.trace_id, trace2.info.trace_id}
    all_traces = [trace1, trace2]

    clean_up_extra_traces(all_traces, eval_start_time, experiment_id, input_trace_ids)

    remaining_traces = get_traces()
    remaining_trace_ids = {t.info.trace_id for t in remaining_traces}
    assert trace1.info.trace_id in remaining_trace_ids
    assert trace2.info.trace_id in remaining_trace_ids


def test_clean_up_extra_traces_uses_correct_experiment_id():
    exp_1 = mlflow.set_experiment("cleanup_test_experiment").experiment_id
    with mlflow.start_span(name="input_trace") as span1:
        span1.set_inputs({"question": "test"})
        span1.set_outputs({"answer": "answer"})
    input_trace = mlflow.get_trace(span1.trace_id)

    with mlflow.start_span(name="extra_trace") as span2:
        span2.set_inputs({"question": "extra"})
        span2.set_outputs({"answer": "extra_answer"})
    extra_trace = mlflow.get_trace(span2.trace_id)

    mlflow.set_experiment("cleanup_test_experiment_2")
    clean_up_extra_traces([input_trace, extra_trace], 0, exp_1, {input_trace.info.trace_id})

    remaining_traces = mlflow.search_traces(locations=[exp_1], return_type="list")
    assert len(remaining_traces) == 1
    assert remaining_traces[0].info.trace_id == input_trace.info.trace_id


def test_evaluate_with_trace_column_preserves_traces():
    @scorer
    def dummy_scorer(inputs, outputs):
        return 1.0

    with mlflow.start_span(name="original_trace") as span:
        span.set_inputs({"question": "What is MLflow?"})
        span.set_outputs({"answer": "MLflow is an ML platform"})

    original_trace = mlflow.get_trace(span.trace_id)
    original_trace_id = original_trace.info.trace_id

    eval_df = pd.DataFrame([
        {
            "trace": original_trace,
            "inputs": {"question": "What is MLflow?"},
            "outputs": {"answer": "MLflow is an ML platform"},
        }
    ])

    mlflow.genai.evaluate(data=eval_df, scorers=[dummy_scorer])

    remaining_traces = get_traces()
    remaining_trace_ids = {t.info.trace_id for t in remaining_traces}
    assert original_trace_id in remaining_trace_ids