test_tracked_events.py
1 import json 2 import time 3 from unittest import mock 4 from unittest.mock import AsyncMock, MagicMock, patch 5 6 import pandas as pd 7 import pytest 8 import sklearn.neighbors as knn 9 from click.testing import CliRunner 10 from fastapi import Request 11 12 import mlflow 13 from mlflow import MlflowClient 14 from mlflow.entities import ( 15 EvaluationDataset, 16 Expectation, 17 Feedback, 18 GatewayEndpointModelConfig, 19 IssueSeverity, 20 IssueStatus, 21 Metric, 22 Param, 23 RunTag, 24 ) 25 from mlflow.entities.assessment_source import AssessmentSource, AssessmentSourceType 26 from mlflow.entities.gateway_budget_policy import ( 27 BudgetAction, 28 BudgetDuration, 29 BudgetDurationUnit, 30 BudgetTargetScope, 31 BudgetUnit, 32 ) 33 from mlflow.entities.gateway_endpoint import GatewayModelLinkageType 34 from mlflow.entities.gateway_guardrail import GuardrailAction, GuardrailStage 35 from mlflow.entities.trace import Trace 36 from mlflow.entities.webhook import WebhookAction, WebhookEntity, WebhookEvent 37 from mlflow.gateway.cli import start 38 from mlflow.gateway.constants import MLFLOW_GATEWAY_CALLER_HEADER 39 from mlflow.gateway.schemas import chat 40 from mlflow.genai.datasets import create_dataset 41 from mlflow.genai.discovery.entities import _TriageResult 42 from mlflow.genai.discovery.pipeline import discover_issues 43 from mlflow.genai.judges import make_judge 44 from mlflow.genai.judges.base import AlignmentOptimizer 45 from mlflow.genai.scorers import scorer 46 from mlflow.genai.scorers.base import Scorer 47 from mlflow.genai.scorers.builtin_scorers import ( 48 Completeness, 49 Guidelines, 50 RelevanceToQuery, 51 Safety, 52 UserFrustration, 53 ) 54 from mlflow.genai.simulators import ConversationSimulator 55 from mlflow.pyfunc.model import ( 56 ResponsesAgent, 57 ResponsesAgentRequest, 58 ResponsesAgentResponse, 59 ) 60 from mlflow.server.gateway_api import chat_completions, invocations 61 from mlflow.store.tracking.gateway.entities import GatewayEndpointConfig, GatewayModelConfig 62 from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore 63 from mlflow.telemetry.client import TelemetryClient 64 from mlflow.telemetry.events import ( 65 AiCommandRunEvent, 66 AlignJudgeEvent, 67 AutologgingEvent, 68 CreateDatasetEvent, 69 CreateExperimentEvent, 70 CreateLoggedModelEvent, 71 CreateModelVersionEvent, 72 CreatePromptEvent, 73 CreateRegisteredModelEvent, 74 CreateRunEvent, 75 CreateWebhookEvent, 76 DiscoverIssuesEvent, 77 EvaluateEvent, 78 GatewayCreateBudgetPolicyEvent, 79 GatewayCreateEndpointEvent, 80 GatewayCreateGuardrailEvent, 81 GatewayCreateModelDefinitionEvent, 82 GatewayCreateSecretEvent, 83 GatewayDeleteBudgetPolicyEvent, 84 GatewayDeleteEndpointEvent, 85 GatewayDeleteGuardrailEvent, 86 GatewayDeleteSecretEvent, 87 GatewayGetEndpointEvent, 88 GatewayInvocationEvent, 89 GatewayListBudgetPoliciesEvent, 90 GatewayListEndpointsEvent, 91 GatewayListSecretsEvent, 92 GatewayStartEvent, 93 GatewayUpdateBudgetPolicyEvent, 94 GatewayUpdateEndpointEvent, 95 GatewayUpdateGuardrailEvent, 96 GatewayUpdateSecretEvent, 97 GenAIEvaluateEvent, 98 GetLoggedModelEvent, 99 GitModelVersioningEvent, 100 InvokeCustomJudgeModelEvent, 101 LoadPromptEvent, 102 LogAssessmentEvent, 103 LogBatchEvent, 104 LogDatasetEvent, 105 LogMetricEvent, 106 LogParamEvent, 107 MakeJudgeEvent, 108 McpRunEvent, 109 MergeRecordsEvent, 110 PromptOptimizationEvent, 111 ScorerCallEvent, 112 SimulateConversationEvent, 113 StartTraceEvent, 114 TracingContextPropagation, 115 TrackingServerStartEvent, 116 UpdateIssueEvent, 117 ) 118 from mlflow.tracing.distributed import ( 119 get_tracing_context_headers_for_http_request, 120 set_tracing_context_from_http_request_headers, 121 ) 122 from mlflow.tracking.fluent import ( 123 _create_dataset_input, 124 _create_logged_model, 125 _initialize_logged_model, 126 ) 127 from mlflow.utils.os import is_windows 128 129 from tests.telemetry.helper_functions import validate_telemetry_record 130 131 132 class TestModel(mlflow.pyfunc.PythonModel): 133 def predict(self, model_input: list[str]) -> str: 134 return "test" 135 136 137 @pytest.fixture 138 def mlflow_client(): 139 return MlflowClient() 140 141 142 @pytest.fixture(autouse=True) 143 def mock_get_telemetry_client(mock_telemetry_client: TelemetryClient): 144 with mock.patch( 145 "mlflow.telemetry.track.get_telemetry_client", 146 return_value=mock_telemetry_client, 147 ): 148 yield 149 150 151 def test_create_logged_model(mock_requests, mock_telemetry_client: TelemetryClient): 152 event_name = CreateLoggedModelEvent.name 153 mlflow.create_external_model(name="model") 154 validate_telemetry_record( 155 mock_telemetry_client, mock_requests, event_name, {"flavor": "external"} 156 ) 157 158 mlflow.initialize_logged_model(name="model", tags={"key": "value"}) 159 validate_telemetry_record( 160 mock_telemetry_client, mock_requests, event_name, {"flavor": "initialize"} 161 ) 162 163 _initialize_logged_model(name="model", flavor="keras") 164 validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, {"flavor": "keras"}) 165 166 mlflow.pyfunc.log_model( 167 name="model", 168 python_model=TestModel(), 169 ) 170 validate_telemetry_record( 171 mock_telemetry_client, 172 mock_requests, 173 event_name, 174 {"flavor": "pyfunc.CustomPythonModel"}, 175 ) 176 177 mlflow.sklearn.log_model( 178 knn.KNeighborsClassifier(), 179 name="model", 180 ) 181 validate_telemetry_record( 182 mock_telemetry_client, 183 mock_requests, 184 event_name, 185 {"flavor": "sklearn", "serialization_format": "cloudpickle"}, 186 ) 187 188 mlflow.sklearn.log_model( 189 knn.KNeighborsClassifier(), 190 name="model", 191 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE, 192 ) 193 validate_telemetry_record( 194 mock_telemetry_client, 195 mock_requests, 196 event_name, 197 {"flavor": "sklearn", "serialization_format": "pickle"}, 198 ) 199 200 class SimpleResponsesAgent(ResponsesAgent): 201 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 202 mock_response = { 203 "output": [ 204 { 205 "type": "message", 206 "id": "1234", 207 "status": "completed", 208 "role": "assistant", 209 "content": [ 210 { 211 "type": "output_text", 212 "text": request.input[0].content, 213 } 214 ], 215 } 216 ], 217 } 218 return ResponsesAgentResponse(**mock_response) 219 220 mlflow.pyfunc.log_model( 221 name="model", 222 python_model=SimpleResponsesAgent(), 223 ) 224 validate_telemetry_record( 225 mock_telemetry_client, 226 mock_requests, 227 event_name, 228 {"flavor": "pyfunc.ResponsesAgent"}, 229 ) 230 231 _create_logged_model(name="model", flavor="pyfunc", uses_uv=True) 232 validate_telemetry_record( 233 mock_telemetry_client, 234 mock_requests, 235 event_name, 236 {"flavor": "pyfunc", "uses_uv": True}, 237 ) 238 239 _create_logged_model(name="model", flavor="pyfunc", uses_uv=False) 240 validate_telemetry_record( 241 mock_telemetry_client, 242 mock_requests, 243 event_name, 244 {"flavor": "pyfunc"}, 245 ) 246 247 248 def test_create_experiment(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient): 249 event_name = CreateExperimentEvent.name 250 exp_id = mlflow.create_experiment(name="test_experiment") 251 validate_telemetry_record( 252 mock_telemetry_client, mock_requests, event_name, {"experiment_id": exp_id} 253 ) 254 255 exp_id = mlflow_client.create_experiment(name="test_experiment1") 256 validate_telemetry_record( 257 mock_telemetry_client, mock_requests, event_name, {"experiment_id": exp_id} 258 ) 259 260 261 def test_create_run(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient): 262 event_name = CreateRunEvent.name 263 exp_id = mlflow.create_experiment(name="test_experiment") 264 with mlflow.start_run(experiment_id=exp_id): 265 record = validate_telemetry_record( 266 mock_telemetry_client, mock_requests, event_name, check_params=False 267 ) 268 assert json.loads(record["params"])["experiment_id"] == exp_id 269 270 mlflow_client.create_run(experiment_id=exp_id) 271 validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, check_params=False) 272 273 exp_id = mlflow.create_experiment(name="test_experiment2") 274 mlflow.set_experiment(experiment_id=exp_id) 275 with mlflow.start_run(): 276 record = validate_telemetry_record( 277 mock_telemetry_client, mock_requests, event_name, check_params=False 278 ) 279 params = json.loads(record["params"]) 280 assert params["mlflow_experiment_id"] == exp_id 281 282 283 def test_create_run_with_imports(mock_requests, mock_telemetry_client: TelemetryClient): 284 event_name = CreateRunEvent.name 285 import pyspark.ml # noqa: F401 286 287 with mlflow.start_run(): 288 data = validate_telemetry_record( 289 mock_telemetry_client, mock_requests, event_name, check_params=False 290 ) 291 assert "pyspark.ml" in json.loads(data["params"])["imports"] 292 293 294 def test_create_registered_model( 295 mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient 296 ): 297 event_name = CreateRegisteredModelEvent.name 298 mlflow_client.create_registered_model(name="test_model1") 299 validate_telemetry_record( 300 mock_telemetry_client, 301 mock_requests, 302 event_name, 303 {"is_prompt": False}, 304 ) 305 306 mlflow.pyfunc.log_model( 307 name="model", 308 python_model=TestModel(), 309 registered_model_name="test_model", 310 ) 311 validate_telemetry_record( 312 mock_telemetry_client, 313 mock_requests, 314 event_name, 315 {"is_prompt": False}, 316 ) 317 318 319 def test_create_model_version(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient): 320 event_name = CreateModelVersionEvent.name 321 mlflow_client.create_registered_model(name="test_model") 322 mlflow_client.create_model_version( 323 name="test_model", source="test_source", run_id="test_run_id" 324 ) 325 validate_telemetry_record( 326 mock_telemetry_client, 327 mock_requests, 328 event_name, 329 {"is_prompt": False}, 330 ) 331 332 mlflow.pyfunc.log_model( 333 name="model", 334 python_model=TestModel(), 335 registered_model_name="test_model", 336 ) 337 validate_telemetry_record( 338 mock_telemetry_client, 339 mock_requests, 340 event_name, 341 {"is_prompt": False}, 342 ) 343 344 mlflow.genai.register_prompt( 345 name="ai_assistant_prompt", 346 template="Respond to the user's message as a {{style}} AI. {{greeting}}", 347 commit_message="Initial version of AI assistant", 348 ) 349 validate_telemetry_record( 350 mock_telemetry_client, 351 mock_requests, 352 event_name, 353 {"is_prompt": True}, 354 ) 355 356 357 def test_start_trace(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient): 358 event_name = StartTraceEvent.name 359 with mlflow.start_span(name="test_span"): 360 pass 361 validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, check_params=False) 362 363 @mlflow.trace 364 def test_func(): 365 pass 366 367 test_func() 368 validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, check_params=False) 369 370 trace_id = mlflow_client.start_trace(name="test_trace").trace_id 371 mlflow_client.end_trace(trace_id=trace_id) 372 validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, check_params=False) 373 374 import openai # noqa: F401 375 376 test_func() 377 data = validate_telemetry_record( 378 mock_telemetry_client, mock_requests, event_name, check_params=False 379 ) 380 params = json.loads(data["params"]) 381 assert "openai" in params["imports"] 382 assert params["format"] == "native" 383 384 385 def test_start_trace_genai_semconv( 386 mock_requests, monkeypatch, mock_telemetry_client: TelemetryClient 387 ): 388 monkeypatch.setenv("MLFLOW_ENABLE_OTEL_GENAI_SEMCONV", "true") 389 event_name = StartTraceEvent.name 390 391 @mlflow.trace 392 def test_func(): 393 pass 394 395 test_func() 396 data = validate_telemetry_record( 397 mock_telemetry_client, mock_requests, event_name, check_params=False 398 ) 399 assert json.loads(data["params"])["format"] == "genai_semconv" 400 401 402 def test_create_prompt(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient): 403 mlflow_client.create_prompt(name="test_prompt") 404 validate_telemetry_record(mock_telemetry_client, mock_requests, CreatePromptEvent.name) 405 406 # OSS prompt registry uses create_registered_model with a special tag 407 mlflow.genai.register_prompt( 408 name="greeting_prompt", 409 template="Respond to the user's message as a {{style}} AI. {{greeting}}", 410 ) 411 expected_params = {"is_prompt": True} 412 validate_telemetry_record( 413 mock_telemetry_client, 414 mock_requests, 415 CreateRegisteredModelEvent.name, 416 expected_params, 417 ) 418 419 420 def test_log_assessment(mock_requests, mock_telemetry_client: TelemetryClient): 421 with mlflow.start_span(name="test_span") as span: 422 feedback = Feedback( 423 name="faithfulness", 424 value=0.9, 425 rationale="The model is faithful to the input.", 426 metadata={"model": "gpt-4o-mini"}, 427 ) 428 429 mlflow.log_assessment(trace_id=span.trace_id, assessment=feedback) 430 validate_telemetry_record( 431 mock_telemetry_client, 432 mock_requests, 433 LogAssessmentEvent.name, 434 {"type": "feedback", "source_type": "CODE"}, 435 ) 436 mlflow.log_feedback(trace_id=span.trace_id, value=0.9, name="faithfulness") 437 validate_telemetry_record( 438 mock_telemetry_client, 439 mock_requests, 440 LogAssessmentEvent.name, 441 {"type": "feedback", "source_type": "CODE"}, 442 ) 443 444 with mlflow.start_span(name="test_span2") as span: 445 expectation = Expectation( 446 name="expected_answer", 447 value="MLflow", 448 ) 449 450 mlflow.log_assessment(trace_id=span.trace_id, assessment=expectation) 451 validate_telemetry_record( 452 mock_telemetry_client, 453 mock_requests, 454 LogAssessmentEvent.name, 455 {"type": "expectation", "source_type": "HUMAN"}, 456 ) 457 mlflow.log_expectation(trace_id=span.trace_id, value="MLflow", name="expected_answer") 458 validate_telemetry_record( 459 mock_telemetry_client, 460 mock_requests, 461 LogAssessmentEvent.name, 462 {"type": "expectation", "source_type": "HUMAN"}, 463 ) 464 465 466 def test_evaluate(mock_requests, mock_telemetry_client: TelemetryClient): 467 mlflow.models.evaluate( 468 data=pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}), 469 model=lambda x: x["x"] * 2, 470 extra_metrics=[mlflow.metrics.latency()], 471 ) 472 validate_telemetry_record(mock_telemetry_client, mock_requests, EvaluateEvent.name) 473 474 475 def test_create_webhook(mock_requests, mock_telemetry_client: TelemetryClient): 476 client = MlflowClient() 477 client.create_webhook( 478 name="test_webhook", 479 url="https://example.com/webhook", 480 events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)], 481 ) 482 expected_params = {"events": ["model_version.created"]} 483 validate_telemetry_record( 484 mock_telemetry_client, mock_requests, CreateWebhookEvent.name, expected_params 485 ) 486 487 488 def test_genai_evaluate(mock_requests, mock_telemetry_client: TelemetryClient): 489 @mlflow.genai.scorer 490 def decorator_scorer(): 491 return 1.0 492 493 instructions_judge = make_judge( 494 name="quality_judge", 495 instructions="Evaluate if {{ outputs }} is high quality", 496 model="openai:/gpt-4", 497 ) 498 499 session_level_instruction_judge = make_judge( 500 name="conversation_quality", 501 instructions="Evaluate if the {{ conversation }} is engaging and coherent", 502 model="openai:/gpt-4", 503 ) 504 505 guidelines_scorer = Guidelines( 506 name="politeness", 507 guidelines=["Be polite", "Be respectful"], 508 ) 509 510 builtin_scorer = RelevanceToQuery(name="relevance_check") 511 512 session_level_builtin_scorer = UserFrustration(name="frustration_check") 513 514 data = [ 515 { 516 "inputs": {"model_input": ["What is MLflow?"]}, 517 "outputs": "MLflow is an open source platform.", 518 } 519 ] 520 521 model = TestModel() 522 523 with ( 524 mock.patch("mlflow.genai.judges.utils.invocation_utils.invoke_judge_model"), 525 mock.patch("mlflow.genai.judges.builtin.invoke_judge_model"), 526 mock.patch("mlflow.genai.judges.instructions_judge.invoke_judge_model"), 527 ): 528 # Test with all scorer kinds and scopes, without predict_fn 529 mlflow.genai.evaluate( 530 data=data, 531 scorers=[ 532 decorator_scorer, 533 instructions_judge, 534 session_level_instruction_judge, 535 guidelines_scorer, 536 builtin_scorer, 537 session_level_builtin_scorer, 538 ], 539 ) 540 541 expected_params = { 542 "predict_fn_provided": False, 543 "scorer_info": [ 544 { 545 "class": "UserDefinedScorer", 546 "kind": "decorator", 547 "scope": "trace", 548 }, 549 { 550 "class": "UserDefinedScorer", 551 "kind": "instructions", 552 "scope": "trace", 553 }, 554 { 555 "class": "UserDefinedScorer", 556 "kind": "instructions", 557 "scope": "session", 558 }, 559 {"class": "Guidelines", "kind": "guidelines", "scope": "trace"}, 560 {"class": "RelevanceToQuery", "kind": "builtin", "scope": "trace"}, 561 {"class": "UserFrustration", "kind": "builtin", "scope": "session"}, 562 ], 563 "eval_data_type": "list[dict]", 564 "eval_data_size": 1, 565 "eval_data_provided_fields": ["inputs", "outputs"], 566 } 567 validate_telemetry_record( 568 mock_telemetry_client, 569 mock_requests, 570 GenAIEvaluateEvent.name, 571 expected_params, 572 ) 573 574 # Test with predict_fn 575 mlflow.genai.evaluate( 576 data=data, 577 scorers=[builtin_scorer, guidelines_scorer], 578 predict_fn=model.predict, 579 ) 580 expected_params = { 581 "predict_fn_provided": True, 582 "scorer_info": [ 583 {"class": "RelevanceToQuery", "kind": "builtin", "scope": "trace"}, 584 {"class": "Guidelines", "kind": "guidelines", "scope": "trace"}, 585 ], 586 "eval_data_type": "list[dict]", 587 "eval_data_size": 1, 588 "eval_data_provided_fields": ["inputs", "outputs"], 589 } 590 validate_telemetry_record( 591 mock_telemetry_client, 592 mock_requests, 593 GenAIEvaluateEvent.name, 594 expected_params, 595 ) 596 597 598 def test_genai_evaluate_telemetry_data_fields( 599 mock_requests, mock_telemetry_client: TelemetryClient 600 ): 601 @mlflow.genai.scorer 602 def sample_scorer(): 603 return 1.0 604 605 with mock.patch("mlflow.genai.judges.utils.invocation_utils.invoke_judge_model"): 606 # Test with list of dicts 607 data_list = [ 608 { 609 "inputs": {"question": "Q1"}, 610 "outputs": "A1", 611 "expectations": {"answer": "Expected1"}, 612 }, 613 { 614 "inputs": {"question": "Q2"}, 615 "outputs": "A2", 616 "expectations": {"answer": "Expected2"}, 617 }, 618 ] 619 mlflow.genai.evaluate(data=data_list, scorers=[sample_scorer]) 620 expected_params = { 621 "predict_fn_provided": False, 622 "scorer_info": [ 623 { 624 "class": "UserDefinedScorer", 625 "kind": "decorator", 626 "scope": "trace", 627 }, 628 ], 629 "eval_data_type": "list[dict]", 630 "eval_data_size": 2, 631 "eval_data_provided_fields": ["expectations", "inputs", "outputs"], 632 } 633 validate_telemetry_record( 634 mock_telemetry_client, 635 mock_requests, 636 GenAIEvaluateEvent.name, 637 expected_params, 638 ) 639 640 # Test with pandas DataFrame 641 df_data = pd.DataFrame([ 642 {"inputs": {"question": "Q1"}, "outputs": "A1"}, 643 {"inputs": {"question": "Q2"}, "outputs": "A2"}, 644 {"inputs": {"question": "Q3"}, "outputs": "A3"}, 645 ]) 646 mlflow.genai.evaluate(data=df_data, scorers=[sample_scorer]) 647 expected_params = { 648 "predict_fn_provided": False, 649 "scorer_info": [ 650 { 651 "class": "UserDefinedScorer", 652 "kind": "decorator", 653 "scope": "trace", 654 }, 655 ], 656 "eval_data_type": "pd.DataFrame", 657 "eval_data_size": 3, 658 "eval_data_provided_fields": ["inputs", "outputs"], 659 } 660 validate_telemetry_record( 661 mock_telemetry_client, 662 mock_requests, 663 GenAIEvaluateEvent.name, 664 expected_params, 665 ) 666 667 # Test with list of Traces 668 trace_ids = [] 669 for i in range(2): 670 with mlflow.start_span(name=f"test_span_{i}") as span: 671 span.set_inputs({"question": f"Q{i}"}) 672 span.set_outputs({"answer": f"A{i}"}) 673 trace_ids.append(span.trace_id) 674 675 traces = [mlflow.get_trace(trace_id) for trace_id in trace_ids] 676 mlflow.genai.evaluate(data=traces, scorers=[sample_scorer]) 677 expected_params = { 678 "predict_fn_provided": False, 679 "scorer_info": [ 680 { 681 "class": "UserDefinedScorer", 682 "kind": "decorator", 683 "scope": "trace", 684 }, 685 ], 686 "eval_data_type": "list[Trace]", 687 "eval_data_size": 2, 688 "eval_data_provided_fields": ["inputs", "outputs", "trace"], 689 } 690 validate_telemetry_record( 691 mock_telemetry_client, 692 mock_requests, 693 GenAIEvaluateEvent.name, 694 expected_params, 695 ) 696 697 # Test with EvaluationDataset 698 from mlflow.genai.datasets import create_dataset 699 700 dataset = create_dataset("test_dataset") 701 dataset_data = [ 702 { 703 "inputs": {"question": "Q1"}, 704 "outputs": "A1", 705 "expectations": {"answer": "Expected1"}, 706 }, 707 { 708 "inputs": {"question": "Q2"}, 709 "outputs": "A2", 710 "expectations": {"answer": "Expected2"}, 711 }, 712 ] 713 dataset.merge_records(dataset_data) 714 mlflow.genai.evaluate(data=dataset, scorers=[sample_scorer]) 715 expected_params = { 716 "predict_fn_provided": False, 717 "scorer_info": [ 718 { 719 "class": "UserDefinedScorer", 720 "kind": "decorator", 721 "scope": "trace", 722 }, 723 ], 724 "eval_data_type": "EvaluationDataset", 725 "eval_data_size": 2, 726 "eval_data_provided_fields": ["expectations", "inputs", "outputs"], 727 } 728 validate_telemetry_record( 729 mock_telemetry_client, 730 mock_requests, 731 GenAIEvaluateEvent.name, 732 expected_params, 733 ) 734 735 736 def test_simulate_conversation(mock_requests, mock_telemetry_client: TelemetryClient): 737 simulator = ConversationSimulator( 738 test_cases=[ 739 {"goal": "Learn about MLflow"}, 740 {"goal": "Debug an issue"}, 741 ], 742 max_turns=2, 743 ) 744 745 def mock_predict_fn(input, **kwargs): 746 return {"role": "assistant", "content": "Mock response"} 747 748 mock_trace = mock.Mock() 749 with ( 750 mock.patch( 751 "mlflow.genai.simulators.simulator.invoke_model_without_tracing", 752 return_value="Mock user message", 753 ), 754 mock.patch( 755 "mlflow.genai.simulators.simulator.ConversationSimulator._check_goal_achieved", 756 return_value=False, 757 ), 758 mock.patch( 759 "mlflow.genai.simulators.simulator.mlflow.get_trace", 760 return_value=mock_trace, 761 ), 762 ): 763 result = simulator.simulate(predict_fn=mock_predict_fn) 764 765 assert len(result) == 2 766 767 validate_telemetry_record( 768 mock_telemetry_client, 769 mock_requests, 770 SimulateConversationEvent.name, 771 { 772 "callsite": "conversation_simulator", 773 "simulated_conversation_info": [ 774 {"turn_count": len(result[0])}, 775 {"turn_count": len(result[1])}, 776 ], 777 }, 778 ) 779 780 781 def test_simulate_conversation_from_genai_evaluate( 782 mock_requests, mock_telemetry_client: TelemetryClient 783 ): 784 simulator = ConversationSimulator( 785 test_cases=[ 786 {"goal": "Learn about MLflow"}, 787 ], 788 max_turns=1, 789 ) 790 791 def mock_predict_fn(input, **kwargs): 792 return {"role": "assistant", "content": "Mock response"} 793 794 @scorer 795 def simple_scorer(outputs) -> bool: 796 return len(outputs) > 0 797 798 with ( 799 mock.patch( 800 "mlflow.genai.simulators.simulator.invoke_model_without_tracing", 801 return_value="Mock user message", 802 ), 803 mock.patch( 804 "mlflow.genai.simulators.simulator.ConversationSimulator._check_goal_achieved", 805 return_value=True, 806 ), 807 ): 808 mlflow.genai.evaluate(data=simulator, predict_fn=mock_predict_fn, scorers=[simple_scorer]) 809 810 mock_telemetry_client.flush() 811 812 simulate_events = [ 813 record 814 for record in mock_requests 815 if record["data"]["event_name"] == SimulateConversationEvent.name 816 ] 817 assert len(simulate_events) == 1 818 819 event_params = json.loads(simulate_events[0]["data"]["params"]) 820 assert event_params == { 821 "callsite": "genai_evaluate", 822 "simulated_conversation_info": [{"turn_count": 1}], 823 } 824 825 826 def test_prompt_optimization(mock_requests, mock_telemetry_client: TelemetryClient): 827 from mlflow.genai.optimize import optimize_prompts 828 from mlflow.genai.optimize.optimizers import BasePromptOptimizer 829 from mlflow.genai.optimize.types import PromptOptimizerOutput 830 831 class MockAdapter(BasePromptOptimizer): 832 def __init__(self): 833 self.model_name = "openai:/gpt-4o-mini" 834 835 def optimize(self, eval_fn, train_data, target_prompts, enable_tracking): 836 return PromptOptimizerOutput(optimized_prompts=target_prompts) 837 838 sample_prompt = mlflow.genai.register_prompt( 839 name="test_prompt_for_adaptation", 840 template="Translate {{input_text}} to {{language}}", 841 ) 842 843 sample_data = [ 844 {"inputs": {"input_text": "Hello", "language": "Spanish"}, "outputs": "Hola"}, 845 {"inputs": {"input_text": "World", "language": "French"}, "outputs": "Monde"}, 846 ] 847 848 @mlflow.genai.scorers.scorer 849 def exact_match_scorer(outputs, expectations): 850 return 1.0 if outputs == expectations["expected_response"] else 0.0 851 852 def predict_fn(input_text, language): 853 mlflow.genai.load_prompt(f"prompts:/{sample_prompt.name}/{sample_prompt.version}") 854 return "translated" 855 856 optimize_prompts( 857 predict_fn=predict_fn, 858 train_data=sample_data, 859 prompt_uris=[f"prompts:/{sample_prompt.name}/{sample_prompt.version}"], 860 optimizer=MockAdapter(), 861 scorers=[exact_match_scorer], 862 ) 863 validate_telemetry_record( 864 mock_telemetry_client, 865 mock_requests, 866 PromptOptimizationEvent.name, 867 { 868 "optimizer_type": "MockAdapter", 869 "prompt_count": 1, 870 "scorer_count": 1, 871 "custom_aggregation": False, 872 }, 873 ) 874 875 876 def test_create_dataset(mock_requests, mock_telemetry_client: TelemetryClient): 877 with mock.patch("mlflow.tracking._tracking_service.utils._get_store") as mock_store: 878 mock_store_instance = mock.MagicMock() 879 mock_store.return_value = mock_store_instance 880 mock_store_instance.create_dataset.return_value = mock.MagicMock( 881 dataset_id="test-dataset-id", name="test_dataset", tags={"test": "value"} 882 ) 883 884 create_dataset(name="test_dataset", tags={"test": "value"}) 885 validate_telemetry_record(mock_telemetry_client, mock_requests, CreateDatasetEvent.name) 886 887 888 def test_merge_records(mock_requests, mock_telemetry_client: TelemetryClient): 889 with mock.patch("mlflow.tracking._tracking_service.utils._get_store") as mock_store: 890 mock_store_instance = mock.MagicMock() 891 mock_store.return_value = mock_store_instance 892 mock_store_instance.get_dataset.return_value = mock.MagicMock(dataset_id="test-id") 893 mock_store_instance.upsert_dataset_records.return_value = { 894 "inserted": 2, 895 "updated": 0, 896 } 897 898 evaluation_dataset = EvaluationDataset( 899 dataset_id="test-id", 900 name="test", 901 digest="digest", 902 created_time=123, 903 last_update_time=456, 904 ) 905 906 records = [ 907 {"inputs": {"q": "Q1"}, "expectations": {"a": "A1"}}, 908 {"inputs": {"q": "Q2"}, "expectations": {"a": "A2"}}, 909 ] 910 evaluation_dataset.merge_records(records) 911 912 expected_params = { 913 "record_count": 2, 914 "input_type": "list[dict]", 915 "dataset_type": "trace", 916 } 917 validate_telemetry_record( 918 mock_telemetry_client, 919 mock_requests, 920 MergeRecordsEvent.name, 921 expected_params, 922 ) 923 924 925 def test_log_dataset(mock_requests, mock_telemetry_client: TelemetryClient): 926 with mlflow.start_run() as run: 927 dataset = mlflow.data.from_pandas(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) 928 mlflow.log_input(dataset) 929 validate_telemetry_record(mock_telemetry_client, mock_requests, LogDatasetEvent.name) 930 931 mlflow.log_inputs(datasets=[dataset], contexts=["training"], tags_list=[None]) 932 validate_telemetry_record(mock_telemetry_client, mock_requests, LogDatasetEvent.name) 933 934 client = MlflowClient() 935 client.log_inputs(run_id=run.info.run_id, datasets=[_create_dataset_input(dataset)]) 936 validate_telemetry_record(mock_telemetry_client, mock_requests, LogDatasetEvent.name) 937 938 939 def test_log_metric(mock_requests, mock_telemetry_client: TelemetryClient): 940 with mlflow.start_run(): 941 mlflow.log_metric("test_metric", 1.0) 942 validate_telemetry_record( 943 mock_telemetry_client, 944 mock_requests, 945 LogMetricEvent.name, 946 {"synchronous": True}, 947 ) 948 949 mlflow.log_metric("test_metric", 1.0, synchronous=False) 950 validate_telemetry_record( 951 mock_telemetry_client, 952 mock_requests, 953 LogMetricEvent.name, 954 {"synchronous": False}, 955 ) 956 957 client = MlflowClient() 958 client.log_metric( 959 run_id=mlflow.active_run().info.run_id, 960 key="test_metric", 961 value=1.0, 962 timestamp=int(time.time()), 963 step=0, 964 ) 965 validate_telemetry_record( 966 mock_telemetry_client, 967 mock_requests, 968 LogMetricEvent.name, 969 {"synchronous": True}, 970 ) 971 972 client.log_metric( 973 run_id=mlflow.active_run().info.run_id, 974 key="test_metric", 975 value=1.0, 976 timestamp=int(time.time()), 977 step=0, 978 synchronous=False, 979 ) 980 validate_telemetry_record( 981 mock_telemetry_client, 982 mock_requests, 983 LogMetricEvent.name, 984 {"synchronous": False}, 985 ) 986 987 988 def test_log_param(mock_requests, mock_telemetry_client: TelemetryClient): 989 with mlflow.start_run(): 990 mlflow.log_param("test_param", "test_value") 991 validate_telemetry_record( 992 mock_telemetry_client, 993 mock_requests, 994 LogParamEvent.name, 995 {"synchronous": True}, 996 ) 997 998 mlflow.log_param("test_param", "test_value", synchronous=False) 999 validate_telemetry_record( 1000 mock_telemetry_client, 1001 mock_requests, 1002 LogParamEvent.name, 1003 {"synchronous": False}, 1004 ) 1005 1006 client = mlflow.MlflowClient() 1007 client.log_param( 1008 run_id=mlflow.active_run().info.run_id, 1009 key="test_param", 1010 value="test_value", 1011 ) 1012 validate_telemetry_record( 1013 mock_telemetry_client, 1014 mock_requests, 1015 LogParamEvent.name, 1016 {"synchronous": True}, 1017 ) 1018 1019 1020 def test_log_batch(mock_requests, mock_telemetry_client: TelemetryClient): 1021 with mlflow.start_run(): 1022 mlflow.log_params(params={"test_param": "test_value"}) 1023 validate_telemetry_record( 1024 mock_telemetry_client, 1025 mock_requests, 1026 LogBatchEvent.name, 1027 {"metrics": False, "params": True, "tags": False, "synchronous": True}, 1028 ) 1029 1030 mlflow.log_params(params={"test_param": "test_value"}, synchronous=False) 1031 validate_telemetry_record( 1032 mock_telemetry_client, 1033 mock_requests, 1034 LogBatchEvent.name, 1035 {"metrics": False, "params": True, "tags": False, "synchronous": False}, 1036 ) 1037 1038 mlflow.log_metrics(metrics={"test_metric": 1.0}) 1039 validate_telemetry_record( 1040 mock_telemetry_client, 1041 mock_requests, 1042 LogBatchEvent.name, 1043 {"metrics": True, "params": False, "tags": False, "synchronous": True}, 1044 ) 1045 1046 mlflow.log_metrics(metrics={"test_metric": 1.0}, synchronous=False) 1047 validate_telemetry_record( 1048 mock_telemetry_client, 1049 mock_requests, 1050 LogBatchEvent.name, 1051 {"metrics": True, "params": False, "tags": False, "synchronous": False}, 1052 ) 1053 1054 mlflow.set_tags(tags={"test_tag": "test_value"}) 1055 validate_telemetry_record( 1056 mock_telemetry_client, 1057 mock_requests, 1058 LogBatchEvent.name, 1059 {"metrics": False, "params": False, "tags": True, "synchronous": True}, 1060 ) 1061 1062 mlflow.set_tags(tags={"test_tag": "test_value"}, synchronous=False) 1063 1064 validate_telemetry_record( 1065 mock_telemetry_client, 1066 mock_requests, 1067 LogBatchEvent.name, 1068 {"metrics": False, "params": False, "tags": True, "synchronous": False}, 1069 ) 1070 1071 client = mlflow.MlflowClient() 1072 client.log_batch( 1073 run_id=mlflow.active_run().info.run_id, 1074 metrics=[Metric(key="test_metric", value=1.0, timestamp=int(time.time()), step=0)], 1075 params=[Param(key="test_param", value="test_value")], 1076 tags=[RunTag(key="test_tag", value="test_value")], 1077 ) 1078 validate_telemetry_record( 1079 mock_telemetry_client, 1080 mock_requests, 1081 LogBatchEvent.name, 1082 {"metrics": True, "params": True, "tags": True, "synchronous": True}, 1083 ) 1084 1085 1086 def test_get_logged_model(mock_requests, mock_telemetry_client: TelemetryClient, tmp_path): 1087 model_info = mlflow.sklearn.log_model( 1088 knn.KNeighborsClassifier(), 1089 name="model", 1090 ) 1091 mock_telemetry_client.flush() 1092 1093 mlflow.sklearn.load_model(model_info.model_uri) 1094 data = validate_telemetry_record( 1095 mock_telemetry_client, 1096 mock_requests, 1097 GetLoggedModelEvent.name, 1098 check_params=False, 1099 ) 1100 assert "sklearn" in json.loads(data["params"])["imports"] 1101 1102 mlflow.pyfunc.load_model(model_info.model_uri) 1103 data = validate_telemetry_record( 1104 mock_telemetry_client, 1105 mock_requests, 1106 GetLoggedModelEvent.name, 1107 check_params=False, 1108 ) 1109 1110 model_def = """ 1111 import mlflow 1112 from mlflow.models import set_model 1113 1114 class TestModel(mlflow.pyfunc.PythonModel): 1115 def predict(self, context, model_input: list[str], params=None) -> list[str]: 1116 return model_input 1117 1118 set_model(TestModel()) 1119 """ 1120 model_path = tmp_path / "model.py" 1121 model_path.write_text(model_def) 1122 model_info = mlflow.pyfunc.log_model( 1123 name="model", 1124 python_model=model_path, 1125 ) 1126 mock_telemetry_client.flush() 1127 1128 mlflow.pyfunc.load_model(model_info.model_uri) 1129 data = validate_telemetry_record( 1130 mock_telemetry_client, 1131 mock_requests, 1132 GetLoggedModelEvent.name, 1133 check_params=False, 1134 ) 1135 1136 # test load model after registry 1137 mlflow.register_model(model_info.model_uri, name="test") 1138 mock_telemetry_client.flush() 1139 1140 mlflow.pyfunc.load_model("models:/test/1") 1141 data = validate_telemetry_record( 1142 mock_telemetry_client, 1143 mock_requests, 1144 GetLoggedModelEvent.name, 1145 check_params=False, 1146 ) 1147 1148 1149 def test_mcp_run(mock_requests, mock_telemetry_client: TelemetryClient): 1150 from mlflow.mcp.cli import run 1151 1152 runner = CliRunner(catch_exceptions=False) 1153 with mock.patch("mlflow.mcp.cli.run_server") as mock_run_server: 1154 runner.invoke(run) 1155 1156 mock_run_server.assert_called_once() 1157 mock_telemetry_client.flush() 1158 validate_telemetry_record(mock_telemetry_client, mock_requests, McpRunEvent.name) 1159 1160 1161 @pytest.mark.skipif(is_windows(), reason="Windows does not support gateway start") 1162 def test_gateway_start(tmp_path, mock_requests, mock_telemetry_client: TelemetryClient): 1163 config = tmp_path.joinpath("config.yml") 1164 config.write_text( 1165 """ 1166 endpoints: 1167 - name: test-endpoint 1168 endpoint_type: llm/v1/completions 1169 model: 1170 provider: openai 1171 name: gpt-3.5-turbo 1172 config: 1173 openai_api_key: test-key 1174 """ 1175 ) 1176 1177 def assert_event_recorded_before_run_app(**kwargs): 1178 mock_telemetry_client.flush() 1179 validate_telemetry_record(mock_telemetry_client, mock_requests, GatewayStartEvent.name) 1180 1181 runner = CliRunner(catch_exceptions=False) 1182 with mock.patch("mlflow.gateway.cli.run_app", side_effect=assert_event_recorded_before_run_app): 1183 runner.invoke(start, ["--config-path", str(config)]) 1184 1185 1186 @pytest.mark.parametrize( 1187 ("cli_args", "expected_params"), 1188 [ 1189 ( 1190 ["--backend-store-uri", "sqlite:///test.db"], 1191 { 1192 "auth_enabled": False, 1193 "app_name": None, 1194 "backend_store_type": "sqlite", 1195 "serve_artifacts": True, 1196 "artifacts_only": False, 1197 "expose_prometheus": False, 1198 "enable_workspaces": False, 1199 "workers": None, 1200 "dev": False, 1201 }, 1202 ), 1203 ( 1204 ["--backend-store-uri", "sqlite:///test.db", "--app-name", "basic-auth"], 1205 { 1206 "auth_enabled": True, 1207 "app_name": "basic-auth", 1208 "backend_store_type": "sqlite", 1209 "serve_artifacts": True, 1210 "artifacts_only": False, 1211 "expose_prometheus": False, 1212 "enable_workspaces": False, 1213 "workers": None, 1214 "dev": False, 1215 }, 1216 ), 1217 ( 1218 [ 1219 "--backend-store-uri", 1220 "sqlite:///test.db", 1221 "--no-serve-artifacts", 1222 "--expose-prometheus", 1223 "/tmp/metrics", 1224 "--enable-workspaces", 1225 ], 1226 { 1227 "auth_enabled": False, 1228 "app_name": None, 1229 "backend_store_type": "sqlite", 1230 "serve_artifacts": False, 1231 "artifacts_only": False, 1232 "expose_prometheus": True, 1233 "enable_workspaces": True, 1234 "workers": None, 1235 "dev": False, 1236 }, 1237 ), 1238 ], 1239 ) 1240 def test_tracking_server_start( 1241 tmp_path, 1242 mock_requests, 1243 mock_telemetry_client: TelemetryClient, 1244 monkeypatch, 1245 cli_args, 1246 expected_params, 1247 ): 1248 1249 from mlflow.cli import server 1250 1251 # Isolate env vars that server() mutates so they don't leak into other tests 1252 for key in ( 1253 "MLFLOW_ENABLE_WORKSPACES", 1254 "MLFLOW_WORKSPACE_STORE_URI", 1255 "MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE", 1256 "MLFLOW_SERVER_ALLOWED_HOSTS", 1257 "MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", 1258 "MLFLOW_SERVER_X_FRAME_OPTIONS", 1259 ): 1260 monkeypatch.delenv(key, raising=False) 1261 1262 def assert_event_recorded_before_run_server(**kwargs): 1263 mock_telemetry_client.flush() 1264 validate_telemetry_record( 1265 mock_telemetry_client, 1266 mock_requests, 1267 TrackingServerStartEvent.name, 1268 expected_params, 1269 ) 1270 1271 runner = CliRunner(catch_exceptions=False) 1272 with ( 1273 mock.patch( 1274 "mlflow.server._run_server", side_effect=assert_event_recorded_before_run_server 1275 ), 1276 mock.patch("mlflow.server.handlers.initialize_backend_stores"), 1277 ): 1278 runner.invoke(server, cli_args) 1279 1280 1281 def test_ai_command_run(mock_requests, mock_telemetry_client: TelemetryClient): 1282 from mlflow.ai_commands import commands 1283 1284 runner = CliRunner(catch_exceptions=False) 1285 # Test CLI context 1286 with mock.patch("mlflow.ai_commands.get_command", return_value="---\ntest\n---\nTest command"): 1287 result = runner.invoke(commands, ["run", "test_command"]) 1288 assert result.exit_code == 0 1289 1290 mock_telemetry_client.flush() 1291 validate_telemetry_record( 1292 mock_telemetry_client, 1293 mock_requests, 1294 AiCommandRunEvent.name, 1295 {"command_key": "test_command", "context": "cli"}, 1296 ) 1297 1298 1299 def test_git_model_versioning(mock_requests, mock_telemetry_client): 1300 from mlflow.genai import enable_git_model_versioning 1301 1302 with enable_git_model_versioning(): 1303 pass 1304 1305 mock_telemetry_client.flush() 1306 validate_telemetry_record(mock_telemetry_client, mock_requests, GitModelVersioningEvent.name) 1307 1308 1309 @pytest.mark.parametrize( 1310 ("model_uri", "expected_provider"), 1311 [ 1312 ("databricks:/llama-3.1-70b", "databricks"), 1313 ("openai:/gpt-4o-mini", "openai"), 1314 ("endpoints:/my-endpoint", "endpoints"), 1315 ("anthropic:/claude-3-opus", "anthropic"), 1316 ], 1317 ) 1318 def test_invoke_custom_judge_model( 1319 mock_requests, 1320 mock_telemetry_client: TelemetryClient, 1321 model_uri, 1322 expected_provider, 1323 ): 1324 from mlflow.genai.judges.utils import invoke_judge_model 1325 1326 mock_response = json.dumps({"result": 0.8, "rationale": "Test rationale"}) 1327 1328 with mock.patch( 1329 "mlflow.genai.judges.adapters.gateway_adapter._invoke_via_gateway", 1330 return_value=mock_response, 1331 ): 1332 invoke_judge_model( 1333 model_uri=model_uri, 1334 prompt="Test prompt", 1335 assessment_name="test_assessment", 1336 ) 1337 1338 expected_params = {"model_provider": expected_provider} 1339 validate_telemetry_record( 1340 mock_telemetry_client, 1341 mock_requests, 1342 InvokeCustomJudgeModelEvent.name, 1343 expected_params, 1344 ) 1345 1346 1347 def test_make_judge(mock_requests, mock_telemetry_client: TelemetryClient): 1348 make_judge( 1349 name="test_judge", 1350 instructions="Evaluate the {{ inputs }} and {{ outputs }}", 1351 model="openai:/gpt-4", 1352 feedback_value_type=str, 1353 ) 1354 expected_params = {"model_provider": "openai"} 1355 validate_telemetry_record( 1356 mock_telemetry_client, mock_requests, MakeJudgeEvent.name, expected_params 1357 ) 1358 1359 make_judge( 1360 name="test_judge", 1361 instructions="Evaluate the {{ inputs }} and {{ outputs }}", 1362 feedback_value_type=str, 1363 ) 1364 expected_params = {"model_provider": None} 1365 validate_telemetry_record( 1366 mock_telemetry_client, mock_requests, MakeJudgeEvent.name, expected_params 1367 ) 1368 1369 1370 def test_align_judge(mock_requests, mock_telemetry_client: TelemetryClient): 1371 judge = make_judge( 1372 name="test_judge", 1373 instructions="Evaluate the {{ inputs }} and {{ outputs }}", 1374 model="openai:/gpt-4", 1375 feedback_value_type=str, 1376 ) 1377 1378 traces = [ 1379 mock.MagicMock(spec=Trace), 1380 mock.MagicMock(spec=Trace), 1381 ] 1382 1383 class MockOptimizer(AlignmentOptimizer): 1384 def align(self, judge, traces): 1385 return judge 1386 1387 custom_optimizer = MockOptimizer() 1388 judge.align(traces, optimizer=custom_optimizer) 1389 1390 expected_params = {"trace_count": 2, "optimizer_type": "MockOptimizer"} 1391 validate_telemetry_record( 1392 mock_telemetry_client, mock_requests, AlignJudgeEvent.name, expected_params 1393 ) 1394 1395 1396 def test_discover_issues(mock_requests, mock_telemetry_client: TelemetryClient): 1397 traces = [ 1398 mock.MagicMock(spec=Trace), 1399 mock.MagicMock(spec=Trace), 1400 mock.MagicMock(spec=Trace), 1401 ] 1402 1403 mock_triage_run_id = "abc123" 1404 mock_eval_result = mock.MagicMock() 1405 mock_eval_result.run_id = mock_triage_run_id 1406 1407 with ( 1408 patch("mlflow.genai.discovery.pipeline.get_session_id", return_value=None), 1409 patch("mlflow.genai.discovery.pipeline.verify_scorer"), 1410 patch( 1411 "mlflow.genai.discovery.pipeline.mlflow.genai.evaluate", 1412 return_value=mock_eval_result, 1413 ), 1414 patch( 1415 "mlflow.genai.discovery.pipeline.extract_failing_traces", 1416 return_value=_TriageResult([], {}, {}), 1417 ), 1418 patch("mlflow.genai.discovery.pipeline.mlflow.MlflowClient"), 1419 patch("mlflow.genai.discovery.pipeline.mlflow.set_experiment"), 1420 ): 1421 discover_issues( 1422 traces=traces, 1423 model="openai:/gpt-4", 1424 categories=["hallucination", "accuracy"], 1425 ) 1426 1427 expected_params = { 1428 "model": "openai:/gpt-4", 1429 "trace_count": 3, 1430 "categories": ["hallucination", "accuracy"], 1431 "source_run_id": None, 1432 "issue_count": 0, 1433 "total_traces_analyzed": 3, 1434 "total_cost_usd": None, 1435 "triage_run_id": mock_triage_run_id, 1436 } 1437 validate_telemetry_record( 1438 mock_telemetry_client, mock_requests, DiscoverIssuesEvent.name, expected_params 1439 ) 1440 1441 1442 def test_autologging(mock_requests, mock_telemetry_client: TelemetryClient): 1443 try: 1444 mlflow.openai.autolog() 1445 1446 mlflow.autolog() 1447 mock_telemetry_client.flush() 1448 data = [record["data"] for record in mock_requests] 1449 params = [event["params"] for event in data if event["event_name"] == AutologgingEvent.name] 1450 assert ( 1451 json.dumps({ 1452 "flavor": mlflow.openai.FLAVOR_NAME, 1453 "log_traces": True, 1454 "disable": False, 1455 }) 1456 in params 1457 ) 1458 assert json.dumps({"flavor": "all", "log_traces": True, "disable": False}) in params 1459 finally: 1460 mlflow.autolog(disable=True) 1461 1462 1463 def test_load_prompt(mock_requests, mock_telemetry_client: TelemetryClient): 1464 # Register a prompt first 1465 prompt = mlflow.genai.register_prompt( 1466 name="test_prompt", 1467 template="Hello {{name}}", 1468 ) 1469 mock_telemetry_client.flush() 1470 1471 # Set an alias for testing 1472 mlflow.genai.set_prompt_alias(name="test_prompt", version=prompt.version, alias="production") 1473 1474 # Test load_prompt with version (no alias) 1475 mlflow.genai.load_prompt(name_or_uri="test_prompt", version=prompt.version) 1476 validate_telemetry_record( 1477 mock_telemetry_client, 1478 mock_requests, 1479 LoadPromptEvent.name, 1480 {"uses_alias": False}, 1481 ) 1482 1483 # Test load_prompt with URI and version (no alias) 1484 mlflow.genai.load_prompt(name_or_uri=f"prompts:/test_prompt/{prompt.version}") 1485 validate_telemetry_record( 1486 mock_telemetry_client, 1487 mock_requests, 1488 LoadPromptEvent.name, 1489 {"uses_alias": False}, 1490 ) 1491 1492 # Test load_prompt with alias 1493 mlflow.genai.load_prompt(name_or_uri="prompts:/test_prompt@production") 1494 validate_telemetry_record( 1495 mock_telemetry_client, mock_requests, LoadPromptEvent.name, {"uses_alias": True} 1496 ) 1497 1498 # Test load_prompt with @latest (special alias) 1499 mlflow.genai.load_prompt(name_or_uri="prompts:/test_prompt@latest") 1500 validate_telemetry_record( 1501 mock_telemetry_client, mock_requests, LoadPromptEvent.name, {"uses_alias": True} 1502 ) 1503 1504 1505 def test_scorer_call_direct(mock_requests, mock_telemetry_client: TelemetryClient): 1506 @scorer 1507 def custom_scorer(outputs) -> bool: 1508 return len(outputs) > 0 1509 1510 result = custom_scorer(outputs="test output") 1511 assert result is True 1512 1513 validate_telemetry_record( 1514 mock_telemetry_client, 1515 mock_requests, 1516 ScorerCallEvent.name, 1517 { 1518 "scorer_class": "UserDefinedScorer", 1519 "scorer_kind": "decorator", 1520 "scope": "trace", 1521 "callsite": "direct_scorer_call", 1522 "has_feedback_error": False, 1523 }, 1524 ) 1525 1526 safety_scorer = Safety() 1527 1528 mock_feedback = Feedback( 1529 name="test_feedback", 1530 value="yes", 1531 rationale="Test rationale", 1532 ) 1533 1534 with mock.patch( 1535 "mlflow.genai.judges.builtin.invoke_judge_model", 1536 return_value=mock_feedback, 1537 ): 1538 safety_scorer(outputs="test output") 1539 1540 validate_telemetry_record( 1541 mock_telemetry_client, 1542 mock_requests, 1543 ScorerCallEvent.name, 1544 { 1545 "scorer_class": "Safety", 1546 "scorer_kind": "builtin", 1547 "scope": "trace", 1548 "callsite": "direct_scorer_call", 1549 "has_feedback_error": False, 1550 }, 1551 ) 1552 1553 mock_requests.clear() 1554 1555 guidelines_scorer = Guidelines(guidelines="The response must be in English") 1556 with mock.patch( 1557 "mlflow.genai.judges.builtin.invoke_judge_model", 1558 return_value=mock_feedback, 1559 ): 1560 guidelines_scorer( 1561 inputs={"question": "What is MLflow?"}, outputs="MLflow is an ML platform" 1562 ) 1563 1564 validate_telemetry_record( 1565 mock_telemetry_client, 1566 mock_requests, 1567 ScorerCallEvent.name, 1568 { 1569 "scorer_class": "Guidelines", 1570 "scorer_kind": "guidelines", 1571 "scope": "trace", 1572 "callsite": "direct_scorer_call", 1573 "has_feedback_error": False, 1574 }, 1575 ) 1576 1577 mock_requests.clear() 1578 1579 class CustomClassScorer(Scorer): 1580 name: str = "custom_class" 1581 1582 def __call__(self, *, outputs) -> bool: 1583 return len(outputs) > 0 1584 1585 custom_class_scorer = CustomClassScorer() 1586 result = custom_class_scorer(outputs="test output") 1587 assert result is True 1588 1589 validate_telemetry_record( 1590 mock_telemetry_client, 1591 mock_requests, 1592 ScorerCallEvent.name, 1593 { 1594 "scorer_class": "UserDefinedScorer", 1595 "scorer_kind": "class", 1596 "scope": "trace", 1597 "callsite": "direct_scorer_call", 1598 "has_feedback_error": False, 1599 }, 1600 ) 1601 1602 1603 def test_scorer_call_from_genai_evaluate(mock_requests, mock_telemetry_client: TelemetryClient): 1604 @scorer 1605 def simple_length_checker(outputs) -> bool: 1606 return len(outputs) > 0 1607 1608 session_judge = make_judge( 1609 name="conversation_quality", 1610 instructions="Evaluate if the {{ conversation }} is engaging and coherent", 1611 model="openai:/gpt-4", 1612 ) 1613 1614 # Create traces with session metadata for session-level scorer testing 1615 @mlflow.trace(span_type=mlflow.entities.SpanType.CHAT_MODEL) 1616 def model(question, session_id): 1617 mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) 1618 return f"Answer to: {question}" 1619 1620 model("What is MLflow?", session_id="test_session") 1621 trace_1 = mlflow.get_trace(mlflow.get_last_active_trace_id()) 1622 1623 model("How does MLflow work?", session_id="test_session") 1624 trace_2 = mlflow.get_trace(mlflow.get_last_active_trace_id()) 1625 1626 test_data = pd.DataFrame([ 1627 { 1628 "trace": trace_1, 1629 }, 1630 { 1631 "trace": trace_2, 1632 }, 1633 ]) 1634 1635 mock_feedback = Feedback( 1636 name="test_feedback", 1637 value="yes", 1638 rationale="Test", 1639 ) 1640 1641 with mock.patch( 1642 "mlflow.genai.judges.instructions_judge.invoke_judge_model", 1643 return_value=mock_feedback, 1644 ): 1645 mlflow.genai.evaluate(data=test_data, scorers=[simple_length_checker, session_judge]) 1646 1647 mock_telemetry_client.flush() 1648 1649 scorer_call_events = [ 1650 record for record in mock_requests if record["data"]["event_name"] == ScorerCallEvent.name 1651 ] 1652 1653 # Should have 3 events: 2 response-level calls (one per trace) 1654 # + 1 session-level call (one per session) 1655 assert len(scorer_call_events) == 3 1656 1657 event_params = [json.loads(event["data"]["params"]) for event in scorer_call_events] 1658 1659 # Validate response-level scorer was called twice (once per trace) 1660 response_level_events = [ 1661 params 1662 for params in event_params 1663 if params["scorer_class"] == "UserDefinedScorer" 1664 and params["scorer_kind"] == "decorator" 1665 and params["scope"] == "trace" 1666 and params["callsite"] == "genai_evaluate" 1667 and params["has_feedback_error"] is False 1668 ] 1669 assert len(response_level_events) == 2 1670 1671 # Validate session-level scorer was called once (once per session) 1672 session_level_events = [ 1673 params 1674 for params in event_params 1675 if params["scorer_class"] == "UserDefinedScorer" 1676 and params["scorer_kind"] == "instructions" 1677 and params["scope"] == "session" 1678 and params["callsite"] == "genai_evaluate" 1679 and params["has_feedback_error"] is False 1680 ] 1681 assert len(session_level_events) == 1 1682 1683 mock_requests.clear() 1684 1685 1686 @pytest.mark.parametrize( 1687 ("job_name", "expected_callsite"), 1688 [ 1689 ("run_online_trace_scorer", "online_scoring"), 1690 ("run_online_session_scorer", "online_scoring"), 1691 # Counterexample: non-online-scoring job should be treated as direct call 1692 ("invoke_scorer", "direct_scorer_call"), 1693 ], 1694 ) 1695 def test_scorer_call_online_scoring_callsite( 1696 mock_requests, mock_telemetry_client: TelemetryClient, monkeypatch, job_name, expected_callsite 1697 ): 1698 # Import here to avoid circular imports 1699 from mlflow.server.jobs.utils import MLFLOW_SERVER_JOB_NAME_ENV_VAR 1700 1701 monkeypatch.setenv(MLFLOW_SERVER_JOB_NAME_ENV_VAR, job_name) 1702 1703 @scorer 1704 def custom_scorer(outputs: str) -> bool: 1705 return True 1706 1707 custom_scorer(outputs="test output") 1708 1709 validate_telemetry_record( 1710 mock_telemetry_client, 1711 mock_requests, 1712 ScorerCallEvent.name, 1713 { 1714 "scorer_class": "UserDefinedScorer", 1715 "scorer_kind": "decorator", 1716 "scope": "trace", 1717 "callsite": expected_callsite, 1718 "has_feedback_error": False, 1719 }, 1720 ) 1721 1722 1723 def test_scorer_call_tracks_feedback_errors(mock_requests, mock_telemetry_client: TelemetryClient): 1724 error_judge = make_judge( 1725 name="quality_judge", 1726 instructions="Evaluate if {{ outputs }} is high quality", 1727 model="openai:/gpt-4", 1728 ) 1729 1730 error_feedback = Feedback( 1731 name="quality_judge", 1732 error="Model invocation failed", 1733 source=AssessmentSource( 1734 source_type=AssessmentSourceType.LLM_JUDGE, source_id="openai:/gpt-4" 1735 ), 1736 ) 1737 with mock.patch( 1738 "mlflow.genai.judges.instructions_judge.invoke_judge_model", 1739 return_value=error_feedback, 1740 ): 1741 result = error_judge(outputs="test output") 1742 assert result.error is not None 1743 1744 validate_telemetry_record( 1745 mock_telemetry_client, 1746 mock_requests, 1747 ScorerCallEvent.name, 1748 { 1749 "scorer_class": "UserDefinedScorer", 1750 "scorer_kind": "instructions", 1751 "scope": "trace", 1752 "callsite": "direct_scorer_call", 1753 "has_feedback_error": True, 1754 }, 1755 ) 1756 1757 mock_requests.clear() 1758 1759 # Test Scorer returns list of Feedback with mixed errors 1760 @scorer 1761 def multi_feedback_scorer(outputs) -> list[Feedback]: 1762 return [ 1763 Feedback(name="feedback1", value=1.0), 1764 Feedback(name="feedback2", error=ValueError("Error in feedback 2")), 1765 Feedback(name="feedback3", value=0.5), 1766 ] 1767 1768 multi_feedback_scorer(outputs="test") 1769 validate_telemetry_record( 1770 mock_telemetry_client, 1771 mock_requests, 1772 ScorerCallEvent.name, 1773 { 1774 "scorer_class": "UserDefinedScorer", 1775 "scorer_kind": "decorator", 1776 "scope": "trace", 1777 "callsite": "direct_scorer_call", 1778 "has_feedback_error": True, 1779 }, 1780 ) 1781 1782 mock_requests.clear() 1783 1784 # Test Scorer returns primitive type (no Feedback error possible) 1785 @scorer 1786 def primitive_scorer(outputs) -> bool: 1787 return True 1788 1789 primitive_scorer(outputs="test") 1790 validate_telemetry_record( 1791 mock_telemetry_client, 1792 mock_requests, 1793 ScorerCallEvent.name, 1794 { 1795 "scorer_class": "UserDefinedScorer", 1796 "scorer_kind": "decorator", 1797 "scope": "trace", 1798 "callsite": "direct_scorer_call", 1799 "has_feedback_error": False, 1800 }, 1801 ) 1802 1803 1804 def test_scorer_call_wrapped_builtin_scorer_direct( 1805 mock_requests, mock_telemetry_client: TelemetryClient 1806 ): 1807 completeness_scorer = Completeness() 1808 1809 mock_feedback = Feedback( 1810 name="completeness", 1811 value="yes", 1812 rationale="Test rationale", 1813 ) 1814 1815 with mock.patch( 1816 "mlflow.genai.judges.instructions_judge.invoke_judge_model", 1817 return_value=mock_feedback, 1818 ): 1819 completeness_scorer(inputs={"question": "What is MLflow?"}, outputs="MLflow is a platform") 1820 1821 mock_telemetry_client.flush() 1822 1823 # Verify exactly 1 scorer_call event was created 1824 # (only top-level Completeness, not nested InstructionsJudge) 1825 scorer_call_events = [ 1826 record for record in mock_requests if record["data"]["event_name"] == ScorerCallEvent.name 1827 ] 1828 assert len(scorer_call_events) == 1, ( 1829 f"Expected 1 scorer call event for Completeness scorer (nested calls should be skipped), " 1830 f"got {len(scorer_call_events)}" 1831 ) 1832 1833 validate_telemetry_record( 1834 mock_telemetry_client, 1835 mock_requests, 1836 ScorerCallEvent.name, 1837 { 1838 "scorer_class": "Completeness", 1839 "scorer_kind": "builtin", 1840 "scope": "trace", 1841 "callsite": "direct_scorer_call", 1842 "has_feedback_error": False, 1843 }, 1844 ) 1845 1846 1847 def test_scorer_call_wrapped_builtin_scorer_from_genai_evaluate( 1848 mock_requests, mock_telemetry_client: TelemetryClient 1849 ): 1850 user_frustration_scorer = UserFrustration() 1851 1852 @mlflow.trace(span_type=mlflow.entities.SpanType.CHAT_MODEL) 1853 def model(question, session_id): 1854 mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) 1855 return f"Answer to: {question}" 1856 1857 model("What is MLflow?", session_id="test_session") 1858 trace_1 = mlflow.get_trace(mlflow.get_last_active_trace_id()) 1859 1860 model("How does MLflow work?", session_id="test_session") 1861 trace_2 = mlflow.get_trace(mlflow.get_last_active_trace_id()) 1862 1863 test_data = pd.DataFrame([ 1864 {"trace": trace_1}, 1865 {"trace": trace_2}, 1866 ]) 1867 1868 mock_feedback = Feedback( 1869 name="user_frustration", 1870 value="no", 1871 rationale="Test rationale", 1872 ) 1873 1874 with mock.patch( 1875 "mlflow.genai.judges.instructions_judge.invoke_judge_model", 1876 return_value=mock_feedback, 1877 ): 1878 mlflow.genai.evaluate(data=test_data, scorers=[user_frustration_scorer]) 1879 1880 mock_telemetry_client.flush() 1881 1882 # Verify exactly 1 scorer_call event was created for the session-level scorer 1883 # (one call at the session level and no nested InstructionsJudge event) 1884 scorer_call_events = [ 1885 record for record in mock_requests if record["data"]["event_name"] == ScorerCallEvent.name 1886 ] 1887 assert len(scorer_call_events) == 1, ( 1888 f"Expected 1 scorer call event for UserFrustration scorer " 1889 f"(nested calls should be skipped), got {len(scorer_call_events)}" 1890 ) 1891 1892 validate_telemetry_record( 1893 mock_telemetry_client, 1894 mock_requests, 1895 ScorerCallEvent.name, 1896 { 1897 "scorer_class": "UserFrustration", 1898 "scorer_kind": "builtin", 1899 "scope": "session", 1900 "callsite": "genai_evaluate", 1901 "has_feedback_error": False, 1902 }, 1903 ) 1904 1905 1906 def test_gateway_crud_telemetry(mock_requests, mock_telemetry_client: TelemetryClient, tmp_path): 1907 db_path = tmp_path / "mlflow.db" 1908 store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix()) 1909 1910 secret = store.create_gateway_secret( 1911 secret_name="test-secret", 1912 secret_value={"api_key": "test-api-key"}, 1913 provider="openai", 1914 created_by="test-user", 1915 ) 1916 1917 model_def = store.create_gateway_model_definition( 1918 name="test-model", 1919 provider="openai", 1920 model_name="gpt-4", 1921 secret_id=secret.secret_id, 1922 created_by="test-user", 1923 ) 1924 validate_telemetry_record( 1925 mock_telemetry_client, 1926 mock_requests, 1927 GatewayCreateModelDefinitionEvent.name, 1928 {"model_name": "gpt-4", "provider": "openai"}, 1929 ) 1930 1931 model_config = GatewayEndpointModelConfig( 1932 model_definition_id=model_def.model_definition_id, 1933 linkage_type=GatewayModelLinkageType.PRIMARY, 1934 weight=100, 1935 ) 1936 endpoint = store.create_gateway_endpoint( 1937 name="test-endpoint", 1938 model_configs=[model_config], 1939 created_by="test-user", 1940 ) 1941 validate_telemetry_record( 1942 mock_telemetry_client, 1943 mock_requests, 1944 GatewayCreateEndpointEvent.name, 1945 { 1946 "has_fallback_config": False, 1947 "routing_strategy": None, 1948 "num_model_configs": 1, 1949 "usage_tracking": True, 1950 }, 1951 ) 1952 1953 store.get_gateway_endpoint(endpoint_id=endpoint.endpoint_id) 1954 validate_telemetry_record( 1955 mock_telemetry_client, 1956 mock_requests, 1957 GatewayGetEndpointEvent.name, 1958 ) 1959 1960 store.list_gateway_endpoints() 1961 validate_telemetry_record( 1962 mock_telemetry_client, 1963 mock_requests, 1964 GatewayListEndpointsEvent.name, 1965 {"filter_by_provider": False}, 1966 ) 1967 1968 store.list_gateway_endpoints(provider="openai") 1969 validate_telemetry_record( 1970 mock_telemetry_client, 1971 mock_requests, 1972 GatewayListEndpointsEvent.name, 1973 {"filter_by_provider": True}, 1974 ) 1975 1976 store.update_gateway_endpoint( 1977 endpoint_id=endpoint.endpoint_id, 1978 name="updated-endpoint", 1979 ) 1980 validate_telemetry_record( 1981 mock_telemetry_client, 1982 mock_requests, 1983 GatewayUpdateEndpointEvent.name, 1984 { 1985 "has_fallback_config": False, 1986 "routing_strategy": None, 1987 "num_model_configs": None, 1988 "usage_tracking": None, 1989 }, 1990 ) 1991 1992 store.delete_gateway_endpoint(endpoint_id=endpoint.endpoint_id) 1993 validate_telemetry_record( 1994 mock_telemetry_client, 1995 mock_requests, 1996 GatewayDeleteEndpointEvent.name, 1997 ) 1998 1999 2000 def test_gateway_secret_crud_telemetry( 2001 mock_requests, mock_telemetry_client: TelemetryClient, tmp_path 2002 ): 2003 db_path = tmp_path / "mlflow.db" 2004 store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix()) 2005 2006 secret = store.create_gateway_secret( 2007 secret_name="test-secret", 2008 secret_value={"api_key": "test-api-key"}, 2009 provider="openai", 2010 created_by="test-user", 2011 ) 2012 validate_telemetry_record( 2013 mock_telemetry_client, 2014 mock_requests, 2015 GatewayCreateSecretEvent.name, 2016 {"provider": "openai"}, 2017 ) 2018 2019 secret2 = store.create_gateway_secret( 2020 secret_name="test-secret-2", 2021 secret_value={"api_key": "test-api-key-2"}, 2022 created_by="test-user", 2023 ) 2024 validate_telemetry_record( 2025 mock_telemetry_client, 2026 mock_requests, 2027 GatewayCreateSecretEvent.name, 2028 {"provider": None}, 2029 ) 2030 2031 store.list_secret_infos() 2032 validate_telemetry_record( 2033 mock_telemetry_client, 2034 mock_requests, 2035 GatewayListSecretsEvent.name, 2036 {"filter_by_provider": False}, 2037 ) 2038 2039 store.list_secret_infos(provider="openai") 2040 validate_telemetry_record( 2041 mock_telemetry_client, 2042 mock_requests, 2043 GatewayListSecretsEvent.name, 2044 {"filter_by_provider": True}, 2045 ) 2046 2047 store.update_gateway_secret( 2048 secret_id=secret.secret_id, 2049 secret_value={"api_key": "updated-api-key"}, 2050 updated_by="test-user", 2051 ) 2052 validate_telemetry_record( 2053 mock_telemetry_client, 2054 mock_requests, 2055 GatewayUpdateSecretEvent.name, 2056 ) 2057 2058 store.delete_gateway_secret(secret_id=secret.secret_id) 2059 validate_telemetry_record( 2060 mock_telemetry_client, 2061 mock_requests, 2062 GatewayDeleteSecretEvent.name, 2063 ) 2064 2065 store.delete_gateway_secret(secret_id=secret2.secret_id) 2066 2067 2068 def test_gateway_budget_policy_crud_telemetry( 2069 mock_requests, mock_telemetry_client: TelemetryClient, tmp_path 2070 ): 2071 db_path = tmp_path / "mlflow.db" 2072 store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix()) 2073 2074 policy = store.create_budget_policy( 2075 budget_unit=BudgetUnit.USD, 2076 budget_amount=100.0, 2077 duration=BudgetDuration(unit=BudgetDurationUnit.DAYS, value=30), 2078 target_scope=BudgetTargetScope.GLOBAL, 2079 budget_action=BudgetAction.ALERT, 2080 created_by="test-user", 2081 ) 2082 validate_telemetry_record( 2083 mock_telemetry_client, 2084 mock_requests, 2085 GatewayCreateBudgetPolicyEvent.name, 2086 { 2087 "budget_unit": "USD", 2088 "duration_unit": "DAYS", 2089 "target_scope": "GLOBAL", 2090 "budget_action": "ALERT", 2091 }, 2092 ) 2093 2094 store.list_budget_policies() 2095 validate_telemetry_record( 2096 mock_telemetry_client, 2097 mock_requests, 2098 GatewayListBudgetPoliciesEvent.name, 2099 ) 2100 2101 store.update_budget_policy( 2102 budget_policy_id=policy.budget_policy_id, 2103 budget_amount=200.0, 2104 ) 2105 validate_telemetry_record( 2106 mock_telemetry_client, 2107 mock_requests, 2108 GatewayUpdateBudgetPolicyEvent.name, 2109 ) 2110 2111 store.delete_budget_policy(budget_policy_id=policy.budget_policy_id) 2112 validate_telemetry_record( 2113 mock_telemetry_client, 2114 mock_requests, 2115 GatewayDeleteBudgetPolicyEvent.name, 2116 ) 2117 2118 2119 def test_gateway_guardrail_crud_telemetry( 2120 mock_requests, mock_telemetry_client: TelemetryClient, tmp_path 2121 ): 2122 db_path = tmp_path / "mlflow.db" 2123 store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix()) 2124 2125 secret = store.create_gateway_secret( 2126 secret_name="test-secret", 2127 secret_value={"api_key": "test-api-key"}, 2128 provider="openai", 2129 created_by="test-user", 2130 ) 2131 model_def = store.create_gateway_model_definition( 2132 name="test-model", 2133 provider="openai", 2134 model_name="gpt-4", 2135 secret_id=secret.secret_id, 2136 created_by="test-user", 2137 ) 2138 endpoint = store.create_gateway_endpoint( 2139 name="test-endpoint", 2140 model_configs=[ 2141 GatewayEndpointModelConfig( 2142 model_definition_id=model_def.model_definition_id, 2143 linkage_type=GatewayModelLinkageType.PRIMARY, 2144 weight=100, 2145 ) 2146 ], 2147 created_by="test-user", 2148 usage_tracking=False, 2149 ) 2150 scorer_experiment_id = store.create_experiment("guardrail-scorer-exp") 2151 serialized_scorer = json.dumps({ 2152 "instructions_judge_pydantic_data": { 2153 "model": "gateway:/test-endpoint", 2154 "instructions": "Is this input safe?", 2155 } 2156 }) 2157 scorer = store.register_scorer( 2158 experiment_id=scorer_experiment_id, 2159 name="safety-judge", 2160 serialized_scorer=serialized_scorer, 2161 ) 2162 guardrail = store.create_gateway_guardrail( 2163 name="guardrail-1", 2164 scorer_id=scorer.scorer_id, 2165 scorer_version=scorer.scorer_version, 2166 stage=GuardrailStage.BEFORE, 2167 action=GuardrailAction.VALIDATION, 2168 created_by="test-user", 2169 ) 2170 validate_telemetry_record( 2171 mock_telemetry_client, 2172 mock_requests, 2173 GatewayCreateGuardrailEvent.name, 2174 { 2175 "stage": "BEFORE", 2176 "action": "VALIDATION", 2177 }, 2178 ) 2179 2180 # Guardrail update telemetry is emitted by endpoint guardrail config updates. 2181 store.add_guardrail_to_endpoint(endpoint.endpoint_id, guardrail.guardrail_id, execution_order=1) 2182 store.update_endpoint_guardrail_config( 2183 endpoint_id=endpoint.endpoint_id, 2184 guardrail_id=guardrail.guardrail_id, 2185 execution_order=2, 2186 ) 2187 validate_telemetry_record( 2188 mock_telemetry_client, 2189 mock_requests, 2190 GatewayUpdateGuardrailEvent.name, 2191 {"stage": None, "action": None}, 2192 ) 2193 2194 store.delete_gateway_guardrail(guardrail.guardrail_id) 2195 validate_telemetry_record( 2196 mock_telemetry_client, 2197 mock_requests, 2198 GatewayDeleteGuardrailEvent.name, 2199 ) 2200 2201 2202 @pytest.mark.asyncio 2203 async def test_gateway_invocation_telemetry( 2204 mock_requests, mock_telemetry_client: TelemetryClient, tmp_path 2205 ): 2206 db_path = tmp_path / "mlflow.db" 2207 store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix()) 2208 2209 secret = store.create_gateway_secret( 2210 secret_name="test-secret", 2211 secret_value={"api_key": "test-api-key"}, 2212 provider="openai", 2213 created_by="test-user", 2214 ) 2215 mock_telemetry_client.flush() 2216 mock_requests.clear() 2217 2218 model_def = store.create_gateway_model_definition( 2219 name="test-model", 2220 provider="openai", 2221 model_name="gpt-4", 2222 secret_id=secret.secret_id, 2223 created_by="test-user", 2224 ) 2225 endpoint = store.create_gateway_endpoint( 2226 name="test-endpoint", 2227 model_configs=[ 2228 GatewayEndpointModelConfig( 2229 model_definition_id=model_def.model_definition_id, 2230 linkage_type=GatewayModelLinkageType.PRIMARY, 2231 weight=100, 2232 ) 2233 ], 2234 created_by="test-user", 2235 ) 2236 mock_telemetry_client.flush() 2237 mock_requests.clear() 2238 2239 mock_response = chat.ResponsePayload( 2240 id="test-id", 2241 object="chat.completion", 2242 created=1234567890, 2243 model="gpt-4", 2244 choices=[ 2245 chat.Choice( 2246 index=0, 2247 message=chat.ResponseMessage(role="assistant", content="Hello!"), 2248 finish_reason="stop", 2249 ) 2250 ], 2251 usage=chat.ChatUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), 2252 ) 2253 2254 mock_model = GatewayModelConfig( 2255 model_definition_id="test-model-def", 2256 provider="openai", 2257 model_name="gpt-4", 2258 secret_value={"api_key": "test"}, 2259 linkage_type=GatewayModelLinkageType.PRIMARY, 2260 ) 2261 2262 # Test invocations endpoint (chat) 2263 mock_request = MagicMock(spec=Request) 2264 mock_request.headers = {} 2265 mock_request.json = AsyncMock( 2266 return_value={ 2267 "messages": [{"role": "user", "content": "Hi"}], 2268 "temperature": 0.7, 2269 "stream": False, 2270 } 2271 ) 2272 2273 with ( 2274 patch("mlflow.server.gateway_api._get_store", return_value=store), 2275 patch( 2276 "mlflow.server.gateway_api._create_provider_from_endpoint_name" 2277 ) as mock_create_provider, 2278 ): 2279 mock_provider = MagicMock() 2280 mock_provider.chat = AsyncMock(return_value=mock_response) 2281 mock_endpoint_config = GatewayEndpointConfig( 2282 endpoint_id=endpoint.endpoint_id, 2283 endpoint_name=endpoint.name, 2284 models=[mock_model], 2285 ) 2286 mock_create_provider.return_value = (mock_provider, mock_endpoint_config) 2287 2288 await invocations(endpoint.name, mock_request) 2289 2290 data = validate_telemetry_record( 2291 mock_telemetry_client, 2292 mock_requests, 2293 GatewayInvocationEvent.name, 2294 check_params=False, 2295 ) 2296 params = json.loads(data["params"]) 2297 assert params["is_streaming"] is False 2298 assert params["invocation_type"] == "mlflow_invocations" 2299 assert params["has_traceparent"] is False 2300 assert params["auth_enabled"] is False 2301 assert params["endpoint_id"] == endpoint.endpoint_id 2302 assert params["provider"] == "openai" 2303 # Non-streaming includes timing fields 2304 assert "provider_duration_ms" in params 2305 assert "gateway_overhead_ms" in params 2306 2307 # Test chat_completions endpoint 2308 mock_request = MagicMock(spec=Request) 2309 mock_request.headers = {} 2310 mock_request.json = AsyncMock( 2311 return_value={ 2312 "model": endpoint.name, 2313 "messages": [{"role": "user", "content": "Hi"}], 2314 "temperature": 0.7, 2315 "stream": False, 2316 } 2317 ) 2318 2319 with ( 2320 patch("mlflow.server.gateway_api._get_store", return_value=store), 2321 patch( 2322 "mlflow.server.gateway_api._create_provider_from_endpoint_name" 2323 ) as mock_create_provider, 2324 ): 2325 mock_provider = MagicMock() 2326 mock_provider.chat = AsyncMock(return_value=mock_response) 2327 mock_endpoint_config = GatewayEndpointConfig( 2328 endpoint_id=endpoint.endpoint_id, 2329 endpoint_name=endpoint.name, 2330 models=[mock_model], 2331 ) 2332 mock_create_provider.return_value = (mock_provider, mock_endpoint_config) 2333 2334 await chat_completions(mock_request) 2335 2336 data = validate_telemetry_record( 2337 mock_telemetry_client, 2338 mock_requests, 2339 GatewayInvocationEvent.name, 2340 check_params=False, 2341 ) 2342 params = json.loads(data["params"]) 2343 assert params["is_streaming"] is False 2344 assert params["invocation_type"] == "mlflow_chat_completions" 2345 assert params["endpoint_id"] == endpoint.endpoint_id 2346 assert params["provider"] == "openai" 2347 2348 # Test streaming invocation — timing fields should be absent 2349 mock_request = MagicMock(spec=Request) 2350 mock_request.headers = {} 2351 mock_request.json = AsyncMock( 2352 return_value={ 2353 "model": endpoint.name, 2354 "messages": [{"role": "user", "content": "Hi"}], 2355 "stream": True, 2356 } 2357 ) 2358 2359 async def mock_stream(): 2360 yield chat.StreamResponsePayload( 2361 id="test-id", 2362 object="chat.completion.chunk", 2363 created=1234567890, 2364 model="gpt-4", 2365 choices=[ 2366 chat.StreamChoice( 2367 index=0, 2368 delta=chat.StreamDelta(role="assistant", content="Hello"), 2369 finish_reason=None, 2370 ) 2371 ], 2372 ) 2373 2374 with ( 2375 patch("mlflow.server.gateway_api._get_store", return_value=store), 2376 patch( 2377 "mlflow.server.gateway_api._create_provider_from_endpoint_name" 2378 ) as mock_create_provider, 2379 ): 2380 mock_provider = MagicMock() 2381 mock_provider.chat_stream = MagicMock(return_value=mock_stream()) 2382 mock_endpoint_config = GatewayEndpointConfig( 2383 endpoint_id=endpoint.endpoint_id, 2384 endpoint_name=endpoint.name, 2385 models=[mock_model], 2386 ) 2387 mock_create_provider.return_value = (mock_provider, mock_endpoint_config) 2388 2389 await chat_completions(mock_request) 2390 2391 data = validate_telemetry_record( 2392 mock_telemetry_client, 2393 mock_requests, 2394 GatewayInvocationEvent.name, 2395 check_params=False, 2396 ) 2397 params = json.loads(data["params"]) 2398 assert params["is_streaming"] is True 2399 assert params["invocation_type"] == "mlflow_chat_completions" 2400 # Streaming responses should NOT include timing fields 2401 assert "provider_duration_ms" not in params 2402 assert "gateway_overhead_ms" not in params 2403 2404 # Test that caller header and traceparent are included in telemetry when present 2405 mock_request = MagicMock(spec=Request) 2406 mock_request.json = AsyncMock( 2407 return_value={ 2408 "model": endpoint.name, 2409 "messages": [{"role": "user", "content": "Hi"}], 2410 "stream": False, 2411 } 2412 ) 2413 mock_request.headers = {MLFLOW_GATEWAY_CALLER_HEADER: "judge", "traceparent": "00-abc-def-01"} 2414 2415 mock_auth_module = MagicMock() 2416 mock_auth_module.is_auth_enabled = MagicMock(return_value=True) 2417 2418 with ( 2419 patch("mlflow.server.gateway_api._get_store", return_value=store), 2420 patch( 2421 "mlflow.server.gateway_api._create_provider_from_endpoint_name" 2422 ) as mock_create_provider, 2423 patch.dict("sys.modules", {"mlflow.server.auth": mock_auth_module}), 2424 ): 2425 mock_provider = MagicMock() 2426 mock_provider.chat = AsyncMock(return_value=mock_response) 2427 mock_endpoint_config = GatewayEndpointConfig( 2428 endpoint_id=endpoint.endpoint_id, 2429 endpoint_name=endpoint.name, 2430 models=[mock_model], 2431 ) 2432 mock_create_provider.return_value = (mock_provider, mock_endpoint_config) 2433 2434 await chat_completions(mock_request) 2435 2436 data = validate_telemetry_record( 2437 mock_telemetry_client, 2438 mock_requests, 2439 GatewayInvocationEvent.name, 2440 check_params=False, 2441 ) 2442 params = json.loads(data["params"]) 2443 assert params["is_streaming"] is False 2444 assert params["invocation_type"] == "mlflow_chat_completions" 2445 assert params["caller"] == "judge" 2446 assert params["has_traceparent"] is True 2447 assert params["auth_enabled"] is True 2448 2449 2450 def test_tracing_context_propagation_get_and_set_success( 2451 mock_requests, mock_telemetry_client: TelemetryClient 2452 ): 2453 with mock.patch( 2454 "mlflow.telemetry.track.get_telemetry_client", return_value=mock_telemetry_client 2455 ): 2456 with mlflow.start_span("client span"): 2457 headers = get_tracing_context_headers_for_http_request() 2458 2459 validate_telemetry_record( 2460 mock_telemetry_client, 2461 mock_requests, 2462 TracingContextPropagation.name, 2463 ) 2464 2465 with mock.patch( 2466 "mlflow.telemetry.track.get_telemetry_client", return_value=mock_telemetry_client 2467 ): 2468 with set_tracing_context_from_http_request_headers(headers): 2469 with mlflow.start_span("server span"): 2470 pass 2471 2472 validate_telemetry_record( 2473 mock_telemetry_client, 2474 mock_requests, 2475 TracingContextPropagation.name, 2476 ) 2477 2478 2479 def test_update_issue_telemetry(mock_requests, mock_telemetry_client: TelemetryClient, db_uri): 2480 store = SqlAlchemyStore(db_uri, "/tmp") 2481 2482 exp_id = store.create_experiment("test-exp") 2483 issue = store.create_issue( 2484 experiment_id=exp_id, 2485 name="Original name", 2486 description="Original description", 2487 status=IssueStatus.PENDING, 2488 ) 2489 mock_telemetry_client.flush() 2490 mock_requests.clear() 2491 2492 store.update_issue( 2493 issue_id=issue.issue_id, 2494 status=IssueStatus.RESOLVED, 2495 name="Updated name", 2496 description="Updated description", 2497 severity=IssueSeverity.HIGH, 2498 ) 2499 2500 validate_telemetry_record( 2501 mock_telemetry_client, 2502 mock_requests, 2503 UpdateIssueEvent.name, 2504 { 2505 "status": "resolved", 2506 "has_name": True, 2507 "has_description": True, 2508 "severity": "high", 2509 "source_run_id": None, 2510 }, 2511 )