/ tests / telemetry / test_tracked_events.py
test_tracked_events.py
   1  import json
   2  import time
   3  from unittest import mock
   4  from unittest.mock import AsyncMock, MagicMock, patch
   5  
   6  import pandas as pd
   7  import pytest
   8  import sklearn.neighbors as knn
   9  from click.testing import CliRunner
  10  from fastapi import Request
  11  
  12  import mlflow
  13  from mlflow import MlflowClient
  14  from mlflow.entities import (
  15      EvaluationDataset,
  16      Expectation,
  17      Feedback,
  18      GatewayEndpointModelConfig,
  19      IssueSeverity,
  20      IssueStatus,
  21      Metric,
  22      Param,
  23      RunTag,
  24  )
  25  from mlflow.entities.assessment_source import AssessmentSource, AssessmentSourceType
  26  from mlflow.entities.gateway_budget_policy import (
  27      BudgetAction,
  28      BudgetDuration,
  29      BudgetDurationUnit,
  30      BudgetTargetScope,
  31      BudgetUnit,
  32  )
  33  from mlflow.entities.gateway_endpoint import GatewayModelLinkageType
  34  from mlflow.entities.gateway_guardrail import GuardrailAction, GuardrailStage
  35  from mlflow.entities.trace import Trace
  36  from mlflow.entities.webhook import WebhookAction, WebhookEntity, WebhookEvent
  37  from mlflow.gateway.cli import start
  38  from mlflow.gateway.constants import MLFLOW_GATEWAY_CALLER_HEADER
  39  from mlflow.gateway.schemas import chat
  40  from mlflow.genai.datasets import create_dataset
  41  from mlflow.genai.discovery.entities import _TriageResult
  42  from mlflow.genai.discovery.pipeline import discover_issues
  43  from mlflow.genai.judges import make_judge
  44  from mlflow.genai.judges.base import AlignmentOptimizer
  45  from mlflow.genai.scorers import scorer
  46  from mlflow.genai.scorers.base import Scorer
  47  from mlflow.genai.scorers.builtin_scorers import (
  48      Completeness,
  49      Guidelines,
  50      RelevanceToQuery,
  51      Safety,
  52      UserFrustration,
  53  )
  54  from mlflow.genai.simulators import ConversationSimulator
  55  from mlflow.pyfunc.model import (
  56      ResponsesAgent,
  57      ResponsesAgentRequest,
  58      ResponsesAgentResponse,
  59  )
  60  from mlflow.server.gateway_api import chat_completions, invocations
  61  from mlflow.store.tracking.gateway.entities import GatewayEndpointConfig, GatewayModelConfig
  62  from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
  63  from mlflow.telemetry.client import TelemetryClient
  64  from mlflow.telemetry.events import (
  65      AiCommandRunEvent,
  66      AlignJudgeEvent,
  67      AutologgingEvent,
  68      CreateDatasetEvent,
  69      CreateExperimentEvent,
  70      CreateLoggedModelEvent,
  71      CreateModelVersionEvent,
  72      CreatePromptEvent,
  73      CreateRegisteredModelEvent,
  74      CreateRunEvent,
  75      CreateWebhookEvent,
  76      DiscoverIssuesEvent,
  77      EvaluateEvent,
  78      GatewayCreateBudgetPolicyEvent,
  79      GatewayCreateEndpointEvent,
  80      GatewayCreateGuardrailEvent,
  81      GatewayCreateModelDefinitionEvent,
  82      GatewayCreateSecretEvent,
  83      GatewayDeleteBudgetPolicyEvent,
  84      GatewayDeleteEndpointEvent,
  85      GatewayDeleteGuardrailEvent,
  86      GatewayDeleteSecretEvent,
  87      GatewayGetEndpointEvent,
  88      GatewayInvocationEvent,
  89      GatewayListBudgetPoliciesEvent,
  90      GatewayListEndpointsEvent,
  91      GatewayListSecretsEvent,
  92      GatewayStartEvent,
  93      GatewayUpdateBudgetPolicyEvent,
  94      GatewayUpdateEndpointEvent,
  95      GatewayUpdateGuardrailEvent,
  96      GatewayUpdateSecretEvent,
  97      GenAIEvaluateEvent,
  98      GetLoggedModelEvent,
  99      GitModelVersioningEvent,
 100      InvokeCustomJudgeModelEvent,
 101      LoadPromptEvent,
 102      LogAssessmentEvent,
 103      LogBatchEvent,
 104      LogDatasetEvent,
 105      LogMetricEvent,
 106      LogParamEvent,
 107      MakeJudgeEvent,
 108      McpRunEvent,
 109      MergeRecordsEvent,
 110      PromptOptimizationEvent,
 111      ScorerCallEvent,
 112      SimulateConversationEvent,
 113      StartTraceEvent,
 114      TracingContextPropagation,
 115      TrackingServerStartEvent,
 116      UpdateIssueEvent,
 117  )
 118  from mlflow.tracing.distributed import (
 119      get_tracing_context_headers_for_http_request,
 120      set_tracing_context_from_http_request_headers,
 121  )
 122  from mlflow.tracking.fluent import (
 123      _create_dataset_input,
 124      _create_logged_model,
 125      _initialize_logged_model,
 126  )
 127  from mlflow.utils.os import is_windows
 128  
 129  from tests.telemetry.helper_functions import validate_telemetry_record
 130  
 131  
 132  class TestModel(mlflow.pyfunc.PythonModel):
 133      def predict(self, model_input: list[str]) -> str:
 134          return "test"
 135  
 136  
 137  @pytest.fixture
 138  def mlflow_client():
 139      return MlflowClient()
 140  
 141  
 142  @pytest.fixture(autouse=True)
 143  def mock_get_telemetry_client(mock_telemetry_client: TelemetryClient):
 144      with mock.patch(
 145          "mlflow.telemetry.track.get_telemetry_client",
 146          return_value=mock_telemetry_client,
 147      ):
 148          yield
 149  
 150  
 151  def test_create_logged_model(mock_requests, mock_telemetry_client: TelemetryClient):
 152      event_name = CreateLoggedModelEvent.name
 153      mlflow.create_external_model(name="model")
 154      validate_telemetry_record(
 155          mock_telemetry_client, mock_requests, event_name, {"flavor": "external"}
 156      )
 157  
 158      mlflow.initialize_logged_model(name="model", tags={"key": "value"})
 159      validate_telemetry_record(
 160          mock_telemetry_client, mock_requests, event_name, {"flavor": "initialize"}
 161      )
 162  
 163      _initialize_logged_model(name="model", flavor="keras")
 164      validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, {"flavor": "keras"})
 165  
 166      mlflow.pyfunc.log_model(
 167          name="model",
 168          python_model=TestModel(),
 169      )
 170      validate_telemetry_record(
 171          mock_telemetry_client,
 172          mock_requests,
 173          event_name,
 174          {"flavor": "pyfunc.CustomPythonModel"},
 175      )
 176  
 177      mlflow.sklearn.log_model(
 178          knn.KNeighborsClassifier(),
 179          name="model",
 180      )
 181      validate_telemetry_record(
 182          mock_telemetry_client,
 183          mock_requests,
 184          event_name,
 185          {"flavor": "sklearn", "serialization_format": "cloudpickle"},
 186      )
 187  
 188      mlflow.sklearn.log_model(
 189          knn.KNeighborsClassifier(),
 190          name="model",
 191          serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE,
 192      )
 193      validate_telemetry_record(
 194          mock_telemetry_client,
 195          mock_requests,
 196          event_name,
 197          {"flavor": "sklearn", "serialization_format": "pickle"},
 198      )
 199  
 200      class SimpleResponsesAgent(ResponsesAgent):
 201          def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
 202              mock_response = {
 203                  "output": [
 204                      {
 205                          "type": "message",
 206                          "id": "1234",
 207                          "status": "completed",
 208                          "role": "assistant",
 209                          "content": [
 210                              {
 211                                  "type": "output_text",
 212                                  "text": request.input[0].content,
 213                              }
 214                          ],
 215                      }
 216                  ],
 217              }
 218              return ResponsesAgentResponse(**mock_response)
 219  
 220      mlflow.pyfunc.log_model(
 221          name="model",
 222          python_model=SimpleResponsesAgent(),
 223      )
 224      validate_telemetry_record(
 225          mock_telemetry_client,
 226          mock_requests,
 227          event_name,
 228          {"flavor": "pyfunc.ResponsesAgent"},
 229      )
 230  
 231      _create_logged_model(name="model", flavor="pyfunc", uses_uv=True)
 232      validate_telemetry_record(
 233          mock_telemetry_client,
 234          mock_requests,
 235          event_name,
 236          {"flavor": "pyfunc", "uses_uv": True},
 237      )
 238  
 239      _create_logged_model(name="model", flavor="pyfunc", uses_uv=False)
 240      validate_telemetry_record(
 241          mock_telemetry_client,
 242          mock_requests,
 243          event_name,
 244          {"flavor": "pyfunc"},
 245      )
 246  
 247  
 248  def test_create_experiment(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient):
 249      event_name = CreateExperimentEvent.name
 250      exp_id = mlflow.create_experiment(name="test_experiment")
 251      validate_telemetry_record(
 252          mock_telemetry_client, mock_requests, event_name, {"experiment_id": exp_id}
 253      )
 254  
 255      exp_id = mlflow_client.create_experiment(name="test_experiment1")
 256      validate_telemetry_record(
 257          mock_telemetry_client, mock_requests, event_name, {"experiment_id": exp_id}
 258      )
 259  
 260  
 261  def test_create_run(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient):
 262      event_name = CreateRunEvent.name
 263      exp_id = mlflow.create_experiment(name="test_experiment")
 264      with mlflow.start_run(experiment_id=exp_id):
 265          record = validate_telemetry_record(
 266              mock_telemetry_client, mock_requests, event_name, check_params=False
 267          )
 268          assert json.loads(record["params"])["experiment_id"] == exp_id
 269  
 270      mlflow_client.create_run(experiment_id=exp_id)
 271      validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, check_params=False)
 272  
 273      exp_id = mlflow.create_experiment(name="test_experiment2")
 274      mlflow.set_experiment(experiment_id=exp_id)
 275      with mlflow.start_run():
 276          record = validate_telemetry_record(
 277              mock_telemetry_client, mock_requests, event_name, check_params=False
 278          )
 279          params = json.loads(record["params"])
 280          assert params["mlflow_experiment_id"] == exp_id
 281  
 282  
 283  def test_create_run_with_imports(mock_requests, mock_telemetry_client: TelemetryClient):
 284      event_name = CreateRunEvent.name
 285      import pyspark.ml  # noqa: F401
 286  
 287      with mlflow.start_run():
 288          data = validate_telemetry_record(
 289              mock_telemetry_client, mock_requests, event_name, check_params=False
 290          )
 291          assert "pyspark.ml" in json.loads(data["params"])["imports"]
 292  
 293  
 294  def test_create_registered_model(
 295      mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient
 296  ):
 297      event_name = CreateRegisteredModelEvent.name
 298      mlflow_client.create_registered_model(name="test_model1")
 299      validate_telemetry_record(
 300          mock_telemetry_client,
 301          mock_requests,
 302          event_name,
 303          {"is_prompt": False},
 304      )
 305  
 306      mlflow.pyfunc.log_model(
 307          name="model",
 308          python_model=TestModel(),
 309          registered_model_name="test_model",
 310      )
 311      validate_telemetry_record(
 312          mock_telemetry_client,
 313          mock_requests,
 314          event_name,
 315          {"is_prompt": False},
 316      )
 317  
 318  
 319  def test_create_model_version(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient):
 320      event_name = CreateModelVersionEvent.name
 321      mlflow_client.create_registered_model(name="test_model")
 322      mlflow_client.create_model_version(
 323          name="test_model", source="test_source", run_id="test_run_id"
 324      )
 325      validate_telemetry_record(
 326          mock_telemetry_client,
 327          mock_requests,
 328          event_name,
 329          {"is_prompt": False},
 330      )
 331  
 332      mlflow.pyfunc.log_model(
 333          name="model",
 334          python_model=TestModel(),
 335          registered_model_name="test_model",
 336      )
 337      validate_telemetry_record(
 338          mock_telemetry_client,
 339          mock_requests,
 340          event_name,
 341          {"is_prompt": False},
 342      )
 343  
 344      mlflow.genai.register_prompt(
 345          name="ai_assistant_prompt",
 346          template="Respond to the user's message as a {{style}} AI. {{greeting}}",
 347          commit_message="Initial version of AI assistant",
 348      )
 349      validate_telemetry_record(
 350          mock_telemetry_client,
 351          mock_requests,
 352          event_name,
 353          {"is_prompt": True},
 354      )
 355  
 356  
 357  def test_start_trace(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient):
 358      event_name = StartTraceEvent.name
 359      with mlflow.start_span(name="test_span"):
 360          pass
 361      validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, check_params=False)
 362  
 363      @mlflow.trace
 364      def test_func():
 365          pass
 366  
 367      test_func()
 368      validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, check_params=False)
 369  
 370      trace_id = mlflow_client.start_trace(name="test_trace").trace_id
 371      mlflow_client.end_trace(trace_id=trace_id)
 372      validate_telemetry_record(mock_telemetry_client, mock_requests, event_name, check_params=False)
 373  
 374      import openai  # noqa: F401
 375  
 376      test_func()
 377      data = validate_telemetry_record(
 378          mock_telemetry_client, mock_requests, event_name, check_params=False
 379      )
 380      params = json.loads(data["params"])
 381      assert "openai" in params["imports"]
 382      assert params["format"] == "native"
 383  
 384  
 385  def test_start_trace_genai_semconv(
 386      mock_requests, monkeypatch, mock_telemetry_client: TelemetryClient
 387  ):
 388      monkeypatch.setenv("MLFLOW_ENABLE_OTEL_GENAI_SEMCONV", "true")
 389      event_name = StartTraceEvent.name
 390  
 391      @mlflow.trace
 392      def test_func():
 393          pass
 394  
 395      test_func()
 396      data = validate_telemetry_record(
 397          mock_telemetry_client, mock_requests, event_name, check_params=False
 398      )
 399      assert json.loads(data["params"])["format"] == "genai_semconv"
 400  
 401  
 402  def test_create_prompt(mock_requests, mlflow_client, mock_telemetry_client: TelemetryClient):
 403      mlflow_client.create_prompt(name="test_prompt")
 404      validate_telemetry_record(mock_telemetry_client, mock_requests, CreatePromptEvent.name)
 405  
 406      # OSS prompt registry uses create_registered_model with a special tag
 407      mlflow.genai.register_prompt(
 408          name="greeting_prompt",
 409          template="Respond to the user's message as a {{style}} AI. {{greeting}}",
 410      )
 411      expected_params = {"is_prompt": True}
 412      validate_telemetry_record(
 413          mock_telemetry_client,
 414          mock_requests,
 415          CreateRegisteredModelEvent.name,
 416          expected_params,
 417      )
 418  
 419  
 420  def test_log_assessment(mock_requests, mock_telemetry_client: TelemetryClient):
 421      with mlflow.start_span(name="test_span") as span:
 422          feedback = Feedback(
 423              name="faithfulness",
 424              value=0.9,
 425              rationale="The model is faithful to the input.",
 426              metadata={"model": "gpt-4o-mini"},
 427          )
 428  
 429          mlflow.log_assessment(trace_id=span.trace_id, assessment=feedback)
 430      validate_telemetry_record(
 431          mock_telemetry_client,
 432          mock_requests,
 433          LogAssessmentEvent.name,
 434          {"type": "feedback", "source_type": "CODE"},
 435      )
 436      mlflow.log_feedback(trace_id=span.trace_id, value=0.9, name="faithfulness")
 437      validate_telemetry_record(
 438          mock_telemetry_client,
 439          mock_requests,
 440          LogAssessmentEvent.name,
 441          {"type": "feedback", "source_type": "CODE"},
 442      )
 443  
 444      with mlflow.start_span(name="test_span2") as span:
 445          expectation = Expectation(
 446              name="expected_answer",
 447              value="MLflow",
 448          )
 449  
 450          mlflow.log_assessment(trace_id=span.trace_id, assessment=expectation)
 451      validate_telemetry_record(
 452          mock_telemetry_client,
 453          mock_requests,
 454          LogAssessmentEvent.name,
 455          {"type": "expectation", "source_type": "HUMAN"},
 456      )
 457      mlflow.log_expectation(trace_id=span.trace_id, value="MLflow", name="expected_answer")
 458      validate_telemetry_record(
 459          mock_telemetry_client,
 460          mock_requests,
 461          LogAssessmentEvent.name,
 462          {"type": "expectation", "source_type": "HUMAN"},
 463      )
 464  
 465  
 466  def test_evaluate(mock_requests, mock_telemetry_client: TelemetryClient):
 467      mlflow.models.evaluate(
 468          data=pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}),
 469          model=lambda x: x["x"] * 2,
 470          extra_metrics=[mlflow.metrics.latency()],
 471      )
 472      validate_telemetry_record(mock_telemetry_client, mock_requests, EvaluateEvent.name)
 473  
 474  
 475  def test_create_webhook(mock_requests, mock_telemetry_client: TelemetryClient):
 476      client = MlflowClient()
 477      client.create_webhook(
 478          name="test_webhook",
 479          url="https://example.com/webhook",
 480          events=[WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)],
 481      )
 482      expected_params = {"events": ["model_version.created"]}
 483      validate_telemetry_record(
 484          mock_telemetry_client, mock_requests, CreateWebhookEvent.name, expected_params
 485      )
 486  
 487  
 488  def test_genai_evaluate(mock_requests, mock_telemetry_client: TelemetryClient):
 489      @mlflow.genai.scorer
 490      def decorator_scorer():
 491          return 1.0
 492  
 493      instructions_judge = make_judge(
 494          name="quality_judge",
 495          instructions="Evaluate if {{ outputs }} is high quality",
 496          model="openai:/gpt-4",
 497      )
 498  
 499      session_level_instruction_judge = make_judge(
 500          name="conversation_quality",
 501          instructions="Evaluate if the {{ conversation }} is engaging and coherent",
 502          model="openai:/gpt-4",
 503      )
 504  
 505      guidelines_scorer = Guidelines(
 506          name="politeness",
 507          guidelines=["Be polite", "Be respectful"],
 508      )
 509  
 510      builtin_scorer = RelevanceToQuery(name="relevance_check")
 511  
 512      session_level_builtin_scorer = UserFrustration(name="frustration_check")
 513  
 514      data = [
 515          {
 516              "inputs": {"model_input": ["What is MLflow?"]},
 517              "outputs": "MLflow is an open source platform.",
 518          }
 519      ]
 520  
 521      model = TestModel()
 522  
 523      with (
 524          mock.patch("mlflow.genai.judges.utils.invocation_utils.invoke_judge_model"),
 525          mock.patch("mlflow.genai.judges.builtin.invoke_judge_model"),
 526          mock.patch("mlflow.genai.judges.instructions_judge.invoke_judge_model"),
 527      ):
 528          # Test with all scorer kinds and scopes, without predict_fn
 529          mlflow.genai.evaluate(
 530              data=data,
 531              scorers=[
 532                  decorator_scorer,
 533                  instructions_judge,
 534                  session_level_instruction_judge,
 535                  guidelines_scorer,
 536                  builtin_scorer,
 537                  session_level_builtin_scorer,
 538              ],
 539          )
 540  
 541          expected_params = {
 542              "predict_fn_provided": False,
 543              "scorer_info": [
 544                  {
 545                      "class": "UserDefinedScorer",
 546                      "kind": "decorator",
 547                      "scope": "trace",
 548                  },
 549                  {
 550                      "class": "UserDefinedScorer",
 551                      "kind": "instructions",
 552                      "scope": "trace",
 553                  },
 554                  {
 555                      "class": "UserDefinedScorer",
 556                      "kind": "instructions",
 557                      "scope": "session",
 558                  },
 559                  {"class": "Guidelines", "kind": "guidelines", "scope": "trace"},
 560                  {"class": "RelevanceToQuery", "kind": "builtin", "scope": "trace"},
 561                  {"class": "UserFrustration", "kind": "builtin", "scope": "session"},
 562              ],
 563              "eval_data_type": "list[dict]",
 564              "eval_data_size": 1,
 565              "eval_data_provided_fields": ["inputs", "outputs"],
 566          }
 567          validate_telemetry_record(
 568              mock_telemetry_client,
 569              mock_requests,
 570              GenAIEvaluateEvent.name,
 571              expected_params,
 572          )
 573  
 574          # Test with predict_fn
 575          mlflow.genai.evaluate(
 576              data=data,
 577              scorers=[builtin_scorer, guidelines_scorer],
 578              predict_fn=model.predict,
 579          )
 580          expected_params = {
 581              "predict_fn_provided": True,
 582              "scorer_info": [
 583                  {"class": "RelevanceToQuery", "kind": "builtin", "scope": "trace"},
 584                  {"class": "Guidelines", "kind": "guidelines", "scope": "trace"},
 585              ],
 586              "eval_data_type": "list[dict]",
 587              "eval_data_size": 1,
 588              "eval_data_provided_fields": ["inputs", "outputs"],
 589          }
 590          validate_telemetry_record(
 591              mock_telemetry_client,
 592              mock_requests,
 593              GenAIEvaluateEvent.name,
 594              expected_params,
 595          )
 596  
 597  
 598  def test_genai_evaluate_telemetry_data_fields(
 599      mock_requests, mock_telemetry_client: TelemetryClient
 600  ):
 601      @mlflow.genai.scorer
 602      def sample_scorer():
 603          return 1.0
 604  
 605      with mock.patch("mlflow.genai.judges.utils.invocation_utils.invoke_judge_model"):
 606          # Test with list of dicts
 607          data_list = [
 608              {
 609                  "inputs": {"question": "Q1"},
 610                  "outputs": "A1",
 611                  "expectations": {"answer": "Expected1"},
 612              },
 613              {
 614                  "inputs": {"question": "Q2"},
 615                  "outputs": "A2",
 616                  "expectations": {"answer": "Expected2"},
 617              },
 618          ]
 619          mlflow.genai.evaluate(data=data_list, scorers=[sample_scorer])
 620          expected_params = {
 621              "predict_fn_provided": False,
 622              "scorer_info": [
 623                  {
 624                      "class": "UserDefinedScorer",
 625                      "kind": "decorator",
 626                      "scope": "trace",
 627                  },
 628              ],
 629              "eval_data_type": "list[dict]",
 630              "eval_data_size": 2,
 631              "eval_data_provided_fields": ["expectations", "inputs", "outputs"],
 632          }
 633          validate_telemetry_record(
 634              mock_telemetry_client,
 635              mock_requests,
 636              GenAIEvaluateEvent.name,
 637              expected_params,
 638          )
 639  
 640          # Test with pandas DataFrame
 641          df_data = pd.DataFrame([
 642              {"inputs": {"question": "Q1"}, "outputs": "A1"},
 643              {"inputs": {"question": "Q2"}, "outputs": "A2"},
 644              {"inputs": {"question": "Q3"}, "outputs": "A3"},
 645          ])
 646          mlflow.genai.evaluate(data=df_data, scorers=[sample_scorer])
 647          expected_params = {
 648              "predict_fn_provided": False,
 649              "scorer_info": [
 650                  {
 651                      "class": "UserDefinedScorer",
 652                      "kind": "decorator",
 653                      "scope": "trace",
 654                  },
 655              ],
 656              "eval_data_type": "pd.DataFrame",
 657              "eval_data_size": 3,
 658              "eval_data_provided_fields": ["inputs", "outputs"],
 659          }
 660          validate_telemetry_record(
 661              mock_telemetry_client,
 662              mock_requests,
 663              GenAIEvaluateEvent.name,
 664              expected_params,
 665          )
 666  
 667          # Test with list of Traces
 668          trace_ids = []
 669          for i in range(2):
 670              with mlflow.start_span(name=f"test_span_{i}") as span:
 671                  span.set_inputs({"question": f"Q{i}"})
 672                  span.set_outputs({"answer": f"A{i}"})
 673                  trace_ids.append(span.trace_id)
 674  
 675          traces = [mlflow.get_trace(trace_id) for trace_id in trace_ids]
 676          mlflow.genai.evaluate(data=traces, scorers=[sample_scorer])
 677          expected_params = {
 678              "predict_fn_provided": False,
 679              "scorer_info": [
 680                  {
 681                      "class": "UserDefinedScorer",
 682                      "kind": "decorator",
 683                      "scope": "trace",
 684                  },
 685              ],
 686              "eval_data_type": "list[Trace]",
 687              "eval_data_size": 2,
 688              "eval_data_provided_fields": ["inputs", "outputs", "trace"],
 689          }
 690          validate_telemetry_record(
 691              mock_telemetry_client,
 692              mock_requests,
 693              GenAIEvaluateEvent.name,
 694              expected_params,
 695          )
 696  
 697          # Test with EvaluationDataset
 698          from mlflow.genai.datasets import create_dataset
 699  
 700          dataset = create_dataset("test_dataset")
 701          dataset_data = [
 702              {
 703                  "inputs": {"question": "Q1"},
 704                  "outputs": "A1",
 705                  "expectations": {"answer": "Expected1"},
 706              },
 707              {
 708                  "inputs": {"question": "Q2"},
 709                  "outputs": "A2",
 710                  "expectations": {"answer": "Expected2"},
 711              },
 712          ]
 713          dataset.merge_records(dataset_data)
 714          mlflow.genai.evaluate(data=dataset, scorers=[sample_scorer])
 715          expected_params = {
 716              "predict_fn_provided": False,
 717              "scorer_info": [
 718                  {
 719                      "class": "UserDefinedScorer",
 720                      "kind": "decorator",
 721                      "scope": "trace",
 722                  },
 723              ],
 724              "eval_data_type": "EvaluationDataset",
 725              "eval_data_size": 2,
 726              "eval_data_provided_fields": ["expectations", "inputs", "outputs"],
 727          }
 728          validate_telemetry_record(
 729              mock_telemetry_client,
 730              mock_requests,
 731              GenAIEvaluateEvent.name,
 732              expected_params,
 733          )
 734  
 735  
 736  def test_simulate_conversation(mock_requests, mock_telemetry_client: TelemetryClient):
 737      simulator = ConversationSimulator(
 738          test_cases=[
 739              {"goal": "Learn about MLflow"},
 740              {"goal": "Debug an issue"},
 741          ],
 742          max_turns=2,
 743      )
 744  
 745      def mock_predict_fn(input, **kwargs):
 746          return {"role": "assistant", "content": "Mock response"}
 747  
 748      mock_trace = mock.Mock()
 749      with (
 750          mock.patch(
 751              "mlflow.genai.simulators.simulator.invoke_model_without_tracing",
 752              return_value="Mock user message",
 753          ),
 754          mock.patch(
 755              "mlflow.genai.simulators.simulator.ConversationSimulator._check_goal_achieved",
 756              return_value=False,
 757          ),
 758          mock.patch(
 759              "mlflow.genai.simulators.simulator.mlflow.get_trace",
 760              return_value=mock_trace,
 761          ),
 762      ):
 763          result = simulator.simulate(predict_fn=mock_predict_fn)
 764  
 765      assert len(result) == 2
 766  
 767      validate_telemetry_record(
 768          mock_telemetry_client,
 769          mock_requests,
 770          SimulateConversationEvent.name,
 771          {
 772              "callsite": "conversation_simulator",
 773              "simulated_conversation_info": [
 774                  {"turn_count": len(result[0])},
 775                  {"turn_count": len(result[1])},
 776              ],
 777          },
 778      )
 779  
 780  
 781  def test_simulate_conversation_from_genai_evaluate(
 782      mock_requests, mock_telemetry_client: TelemetryClient
 783  ):
 784      simulator = ConversationSimulator(
 785          test_cases=[
 786              {"goal": "Learn about MLflow"},
 787          ],
 788          max_turns=1,
 789      )
 790  
 791      def mock_predict_fn(input, **kwargs):
 792          return {"role": "assistant", "content": "Mock response"}
 793  
 794      @scorer
 795      def simple_scorer(outputs) -> bool:
 796          return len(outputs) > 0
 797  
 798      with (
 799          mock.patch(
 800              "mlflow.genai.simulators.simulator.invoke_model_without_tracing",
 801              return_value="Mock user message",
 802          ),
 803          mock.patch(
 804              "mlflow.genai.simulators.simulator.ConversationSimulator._check_goal_achieved",
 805              return_value=True,
 806          ),
 807      ):
 808          mlflow.genai.evaluate(data=simulator, predict_fn=mock_predict_fn, scorers=[simple_scorer])
 809  
 810      mock_telemetry_client.flush()
 811  
 812      simulate_events = [
 813          record
 814          for record in mock_requests
 815          if record["data"]["event_name"] == SimulateConversationEvent.name
 816      ]
 817      assert len(simulate_events) == 1
 818  
 819      event_params = json.loads(simulate_events[0]["data"]["params"])
 820      assert event_params == {
 821          "callsite": "genai_evaluate",
 822          "simulated_conversation_info": [{"turn_count": 1}],
 823      }
 824  
 825  
 826  def test_prompt_optimization(mock_requests, mock_telemetry_client: TelemetryClient):
 827      from mlflow.genai.optimize import optimize_prompts
 828      from mlflow.genai.optimize.optimizers import BasePromptOptimizer
 829      from mlflow.genai.optimize.types import PromptOptimizerOutput
 830  
 831      class MockAdapter(BasePromptOptimizer):
 832          def __init__(self):
 833              self.model_name = "openai:/gpt-4o-mini"
 834  
 835          def optimize(self, eval_fn, train_data, target_prompts, enable_tracking):
 836              return PromptOptimizerOutput(optimized_prompts=target_prompts)
 837  
 838      sample_prompt = mlflow.genai.register_prompt(
 839          name="test_prompt_for_adaptation",
 840          template="Translate {{input_text}} to {{language}}",
 841      )
 842  
 843      sample_data = [
 844          {"inputs": {"input_text": "Hello", "language": "Spanish"}, "outputs": "Hola"},
 845          {"inputs": {"input_text": "World", "language": "French"}, "outputs": "Monde"},
 846      ]
 847  
 848      @mlflow.genai.scorers.scorer
 849      def exact_match_scorer(outputs, expectations):
 850          return 1.0 if outputs == expectations["expected_response"] else 0.0
 851  
 852      def predict_fn(input_text, language):
 853          mlflow.genai.load_prompt(f"prompts:/{sample_prompt.name}/{sample_prompt.version}")
 854          return "translated"
 855  
 856      optimize_prompts(
 857          predict_fn=predict_fn,
 858          train_data=sample_data,
 859          prompt_uris=[f"prompts:/{sample_prompt.name}/{sample_prompt.version}"],
 860          optimizer=MockAdapter(),
 861          scorers=[exact_match_scorer],
 862      )
 863      validate_telemetry_record(
 864          mock_telemetry_client,
 865          mock_requests,
 866          PromptOptimizationEvent.name,
 867          {
 868              "optimizer_type": "MockAdapter",
 869              "prompt_count": 1,
 870              "scorer_count": 1,
 871              "custom_aggregation": False,
 872          },
 873      )
 874  
 875  
 876  def test_create_dataset(mock_requests, mock_telemetry_client: TelemetryClient):
 877      with mock.patch("mlflow.tracking._tracking_service.utils._get_store") as mock_store:
 878          mock_store_instance = mock.MagicMock()
 879          mock_store.return_value = mock_store_instance
 880          mock_store_instance.create_dataset.return_value = mock.MagicMock(
 881              dataset_id="test-dataset-id", name="test_dataset", tags={"test": "value"}
 882          )
 883  
 884          create_dataset(name="test_dataset", tags={"test": "value"})
 885          validate_telemetry_record(mock_telemetry_client, mock_requests, CreateDatasetEvent.name)
 886  
 887  
 888  def test_merge_records(mock_requests, mock_telemetry_client: TelemetryClient):
 889      with mock.patch("mlflow.tracking._tracking_service.utils._get_store") as mock_store:
 890          mock_store_instance = mock.MagicMock()
 891          mock_store.return_value = mock_store_instance
 892          mock_store_instance.get_dataset.return_value = mock.MagicMock(dataset_id="test-id")
 893          mock_store_instance.upsert_dataset_records.return_value = {
 894              "inserted": 2,
 895              "updated": 0,
 896          }
 897  
 898          evaluation_dataset = EvaluationDataset(
 899              dataset_id="test-id",
 900              name="test",
 901              digest="digest",
 902              created_time=123,
 903              last_update_time=456,
 904          )
 905  
 906          records = [
 907              {"inputs": {"q": "Q1"}, "expectations": {"a": "A1"}},
 908              {"inputs": {"q": "Q2"}, "expectations": {"a": "A2"}},
 909          ]
 910          evaluation_dataset.merge_records(records)
 911  
 912          expected_params = {
 913              "record_count": 2,
 914              "input_type": "list[dict]",
 915              "dataset_type": "trace",
 916          }
 917          validate_telemetry_record(
 918              mock_telemetry_client,
 919              mock_requests,
 920              MergeRecordsEvent.name,
 921              expected_params,
 922          )
 923  
 924  
 925  def test_log_dataset(mock_requests, mock_telemetry_client: TelemetryClient):
 926      with mlflow.start_run() as run:
 927          dataset = mlflow.data.from_pandas(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}))
 928          mlflow.log_input(dataset)
 929          validate_telemetry_record(mock_telemetry_client, mock_requests, LogDatasetEvent.name)
 930  
 931          mlflow.log_inputs(datasets=[dataset], contexts=["training"], tags_list=[None])
 932          validate_telemetry_record(mock_telemetry_client, mock_requests, LogDatasetEvent.name)
 933  
 934          client = MlflowClient()
 935          client.log_inputs(run_id=run.info.run_id, datasets=[_create_dataset_input(dataset)])
 936          validate_telemetry_record(mock_telemetry_client, mock_requests, LogDatasetEvent.name)
 937  
 938  
 939  def test_log_metric(mock_requests, mock_telemetry_client: TelemetryClient):
 940      with mlflow.start_run():
 941          mlflow.log_metric("test_metric", 1.0)
 942          validate_telemetry_record(
 943              mock_telemetry_client,
 944              mock_requests,
 945              LogMetricEvent.name,
 946              {"synchronous": True},
 947          )
 948  
 949          mlflow.log_metric("test_metric", 1.0, synchronous=False)
 950          validate_telemetry_record(
 951              mock_telemetry_client,
 952              mock_requests,
 953              LogMetricEvent.name,
 954              {"synchronous": False},
 955          )
 956  
 957          client = MlflowClient()
 958          client.log_metric(
 959              run_id=mlflow.active_run().info.run_id,
 960              key="test_metric",
 961              value=1.0,
 962              timestamp=int(time.time()),
 963              step=0,
 964          )
 965          validate_telemetry_record(
 966              mock_telemetry_client,
 967              mock_requests,
 968              LogMetricEvent.name,
 969              {"synchronous": True},
 970          )
 971  
 972          client.log_metric(
 973              run_id=mlflow.active_run().info.run_id,
 974              key="test_metric",
 975              value=1.0,
 976              timestamp=int(time.time()),
 977              step=0,
 978              synchronous=False,
 979          )
 980          validate_telemetry_record(
 981              mock_telemetry_client,
 982              mock_requests,
 983              LogMetricEvent.name,
 984              {"synchronous": False},
 985          )
 986  
 987  
 988  def test_log_param(mock_requests, mock_telemetry_client: TelemetryClient):
 989      with mlflow.start_run():
 990          mlflow.log_param("test_param", "test_value")
 991          validate_telemetry_record(
 992              mock_telemetry_client,
 993              mock_requests,
 994              LogParamEvent.name,
 995              {"synchronous": True},
 996          )
 997  
 998          mlflow.log_param("test_param", "test_value", synchronous=False)
 999          validate_telemetry_record(
1000              mock_telemetry_client,
1001              mock_requests,
1002              LogParamEvent.name,
1003              {"synchronous": False},
1004          )
1005  
1006          client = mlflow.MlflowClient()
1007          client.log_param(
1008              run_id=mlflow.active_run().info.run_id,
1009              key="test_param",
1010              value="test_value",
1011          )
1012          validate_telemetry_record(
1013              mock_telemetry_client,
1014              mock_requests,
1015              LogParamEvent.name,
1016              {"synchronous": True},
1017          )
1018  
1019  
1020  def test_log_batch(mock_requests, mock_telemetry_client: TelemetryClient):
1021      with mlflow.start_run():
1022          mlflow.log_params(params={"test_param": "test_value"})
1023          validate_telemetry_record(
1024              mock_telemetry_client,
1025              mock_requests,
1026              LogBatchEvent.name,
1027              {"metrics": False, "params": True, "tags": False, "synchronous": True},
1028          )
1029  
1030          mlflow.log_params(params={"test_param": "test_value"}, synchronous=False)
1031          validate_telemetry_record(
1032              mock_telemetry_client,
1033              mock_requests,
1034              LogBatchEvent.name,
1035              {"metrics": False, "params": True, "tags": False, "synchronous": False},
1036          )
1037  
1038          mlflow.log_metrics(metrics={"test_metric": 1.0})
1039          validate_telemetry_record(
1040              mock_telemetry_client,
1041              mock_requests,
1042              LogBatchEvent.name,
1043              {"metrics": True, "params": False, "tags": False, "synchronous": True},
1044          )
1045  
1046          mlflow.log_metrics(metrics={"test_metric": 1.0}, synchronous=False)
1047          validate_telemetry_record(
1048              mock_telemetry_client,
1049              mock_requests,
1050              LogBatchEvent.name,
1051              {"metrics": True, "params": False, "tags": False, "synchronous": False},
1052          )
1053  
1054          mlflow.set_tags(tags={"test_tag": "test_value"})
1055          validate_telemetry_record(
1056              mock_telemetry_client,
1057              mock_requests,
1058              LogBatchEvent.name,
1059              {"metrics": False, "params": False, "tags": True, "synchronous": True},
1060          )
1061  
1062          mlflow.set_tags(tags={"test_tag": "test_value"}, synchronous=False)
1063  
1064          validate_telemetry_record(
1065              mock_telemetry_client,
1066              mock_requests,
1067              LogBatchEvent.name,
1068              {"metrics": False, "params": False, "tags": True, "synchronous": False},
1069          )
1070  
1071          client = mlflow.MlflowClient()
1072          client.log_batch(
1073              run_id=mlflow.active_run().info.run_id,
1074              metrics=[Metric(key="test_metric", value=1.0, timestamp=int(time.time()), step=0)],
1075              params=[Param(key="test_param", value="test_value")],
1076              tags=[RunTag(key="test_tag", value="test_value")],
1077          )
1078          validate_telemetry_record(
1079              mock_telemetry_client,
1080              mock_requests,
1081              LogBatchEvent.name,
1082              {"metrics": True, "params": True, "tags": True, "synchronous": True},
1083          )
1084  
1085  
1086  def test_get_logged_model(mock_requests, mock_telemetry_client: TelemetryClient, tmp_path):
1087      model_info = mlflow.sklearn.log_model(
1088          knn.KNeighborsClassifier(),
1089          name="model",
1090      )
1091      mock_telemetry_client.flush()
1092  
1093      mlflow.sklearn.load_model(model_info.model_uri)
1094      data = validate_telemetry_record(
1095          mock_telemetry_client,
1096          mock_requests,
1097          GetLoggedModelEvent.name,
1098          check_params=False,
1099      )
1100      assert "sklearn" in json.loads(data["params"])["imports"]
1101  
1102      mlflow.pyfunc.load_model(model_info.model_uri)
1103      data = validate_telemetry_record(
1104          mock_telemetry_client,
1105          mock_requests,
1106          GetLoggedModelEvent.name,
1107          check_params=False,
1108      )
1109  
1110      model_def = """
1111  import mlflow
1112  from mlflow.models import set_model
1113  
1114  class TestModel(mlflow.pyfunc.PythonModel):
1115      def predict(self, context, model_input: list[str], params=None) -> list[str]:
1116          return model_input
1117  
1118  set_model(TestModel())
1119  """
1120      model_path = tmp_path / "model.py"
1121      model_path.write_text(model_def)
1122      model_info = mlflow.pyfunc.log_model(
1123          name="model",
1124          python_model=model_path,
1125      )
1126      mock_telemetry_client.flush()
1127  
1128      mlflow.pyfunc.load_model(model_info.model_uri)
1129      data = validate_telemetry_record(
1130          mock_telemetry_client,
1131          mock_requests,
1132          GetLoggedModelEvent.name,
1133          check_params=False,
1134      )
1135  
1136      # test load model after registry
1137      mlflow.register_model(model_info.model_uri, name="test")
1138      mock_telemetry_client.flush()
1139  
1140      mlflow.pyfunc.load_model("models:/test/1")
1141      data = validate_telemetry_record(
1142          mock_telemetry_client,
1143          mock_requests,
1144          GetLoggedModelEvent.name,
1145          check_params=False,
1146      )
1147  
1148  
1149  def test_mcp_run(mock_requests, mock_telemetry_client: TelemetryClient):
1150      from mlflow.mcp.cli import run
1151  
1152      runner = CliRunner(catch_exceptions=False)
1153      with mock.patch("mlflow.mcp.cli.run_server") as mock_run_server:
1154          runner.invoke(run)
1155  
1156      mock_run_server.assert_called_once()
1157      mock_telemetry_client.flush()
1158      validate_telemetry_record(mock_telemetry_client, mock_requests, McpRunEvent.name)
1159  
1160  
1161  @pytest.mark.skipif(is_windows(), reason="Windows does not support gateway start")
1162  def test_gateway_start(tmp_path, mock_requests, mock_telemetry_client: TelemetryClient):
1163      config = tmp_path.joinpath("config.yml")
1164      config.write_text(
1165          """
1166  endpoints:
1167    - name: test-endpoint
1168      endpoint_type: llm/v1/completions
1169      model:
1170        provider: openai
1171        name: gpt-3.5-turbo
1172        config:
1173          openai_api_key: test-key
1174  """
1175      )
1176  
1177      def assert_event_recorded_before_run_app(**kwargs):
1178          mock_telemetry_client.flush()
1179          validate_telemetry_record(mock_telemetry_client, mock_requests, GatewayStartEvent.name)
1180  
1181      runner = CliRunner(catch_exceptions=False)
1182      with mock.patch("mlflow.gateway.cli.run_app", side_effect=assert_event_recorded_before_run_app):
1183          runner.invoke(start, ["--config-path", str(config)])
1184  
1185  
1186  @pytest.mark.parametrize(
1187      ("cli_args", "expected_params"),
1188      [
1189          (
1190              ["--backend-store-uri", "sqlite:///test.db"],
1191              {
1192                  "auth_enabled": False,
1193                  "app_name": None,
1194                  "backend_store_type": "sqlite",
1195                  "serve_artifacts": True,
1196                  "artifacts_only": False,
1197                  "expose_prometheus": False,
1198                  "enable_workspaces": False,
1199                  "workers": None,
1200                  "dev": False,
1201              },
1202          ),
1203          (
1204              ["--backend-store-uri", "sqlite:///test.db", "--app-name", "basic-auth"],
1205              {
1206                  "auth_enabled": True,
1207                  "app_name": "basic-auth",
1208                  "backend_store_type": "sqlite",
1209                  "serve_artifacts": True,
1210                  "artifacts_only": False,
1211                  "expose_prometheus": False,
1212                  "enable_workspaces": False,
1213                  "workers": None,
1214                  "dev": False,
1215              },
1216          ),
1217          (
1218              [
1219                  "--backend-store-uri",
1220                  "sqlite:///test.db",
1221                  "--no-serve-artifacts",
1222                  "--expose-prometheus",
1223                  "/tmp/metrics",
1224                  "--enable-workspaces",
1225              ],
1226              {
1227                  "auth_enabled": False,
1228                  "app_name": None,
1229                  "backend_store_type": "sqlite",
1230                  "serve_artifacts": False,
1231                  "artifacts_only": False,
1232                  "expose_prometheus": True,
1233                  "enable_workspaces": True,
1234                  "workers": None,
1235                  "dev": False,
1236              },
1237          ),
1238      ],
1239  )
1240  def test_tracking_server_start(
1241      tmp_path,
1242      mock_requests,
1243      mock_telemetry_client: TelemetryClient,
1244      monkeypatch,
1245      cli_args,
1246      expected_params,
1247  ):
1248  
1249      from mlflow.cli import server
1250  
1251      # Isolate env vars that server() mutates so they don't leak into other tests
1252      for key in (
1253          "MLFLOW_ENABLE_WORKSPACES",
1254          "MLFLOW_WORKSPACE_STORE_URI",
1255          "MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE",
1256          "MLFLOW_SERVER_ALLOWED_HOSTS",
1257          "MLFLOW_SERVER_CORS_ALLOWED_ORIGINS",
1258          "MLFLOW_SERVER_X_FRAME_OPTIONS",
1259      ):
1260          monkeypatch.delenv(key, raising=False)
1261  
1262      def assert_event_recorded_before_run_server(**kwargs):
1263          mock_telemetry_client.flush()
1264          validate_telemetry_record(
1265              mock_telemetry_client,
1266              mock_requests,
1267              TrackingServerStartEvent.name,
1268              expected_params,
1269          )
1270  
1271      runner = CliRunner(catch_exceptions=False)
1272      with (
1273          mock.patch(
1274              "mlflow.server._run_server", side_effect=assert_event_recorded_before_run_server
1275          ),
1276          mock.patch("mlflow.server.handlers.initialize_backend_stores"),
1277      ):
1278          runner.invoke(server, cli_args)
1279  
1280  
1281  def test_ai_command_run(mock_requests, mock_telemetry_client: TelemetryClient):
1282      from mlflow.ai_commands import commands
1283  
1284      runner = CliRunner(catch_exceptions=False)
1285      # Test CLI context
1286      with mock.patch("mlflow.ai_commands.get_command", return_value="---\ntest\n---\nTest command"):
1287          result = runner.invoke(commands, ["run", "test_command"])
1288          assert result.exit_code == 0
1289  
1290      mock_telemetry_client.flush()
1291      validate_telemetry_record(
1292          mock_telemetry_client,
1293          mock_requests,
1294          AiCommandRunEvent.name,
1295          {"command_key": "test_command", "context": "cli"},
1296      )
1297  
1298  
1299  def test_git_model_versioning(mock_requests, mock_telemetry_client):
1300      from mlflow.genai import enable_git_model_versioning
1301  
1302      with enable_git_model_versioning():
1303          pass
1304  
1305      mock_telemetry_client.flush()
1306      validate_telemetry_record(mock_telemetry_client, mock_requests, GitModelVersioningEvent.name)
1307  
1308  
1309  @pytest.mark.parametrize(
1310      ("model_uri", "expected_provider"),
1311      [
1312          ("databricks:/llama-3.1-70b", "databricks"),
1313          ("openai:/gpt-4o-mini", "openai"),
1314          ("endpoints:/my-endpoint", "endpoints"),
1315          ("anthropic:/claude-3-opus", "anthropic"),
1316      ],
1317  )
1318  def test_invoke_custom_judge_model(
1319      mock_requests,
1320      mock_telemetry_client: TelemetryClient,
1321      model_uri,
1322      expected_provider,
1323  ):
1324      from mlflow.genai.judges.utils import invoke_judge_model
1325  
1326      mock_response = json.dumps({"result": 0.8, "rationale": "Test rationale"})
1327  
1328      with mock.patch(
1329          "mlflow.genai.judges.adapters.gateway_adapter._invoke_via_gateway",
1330          return_value=mock_response,
1331      ):
1332          invoke_judge_model(
1333              model_uri=model_uri,
1334              prompt="Test prompt",
1335              assessment_name="test_assessment",
1336          )
1337  
1338          expected_params = {"model_provider": expected_provider}
1339          validate_telemetry_record(
1340              mock_telemetry_client,
1341              mock_requests,
1342              InvokeCustomJudgeModelEvent.name,
1343              expected_params,
1344          )
1345  
1346  
1347  def test_make_judge(mock_requests, mock_telemetry_client: TelemetryClient):
1348      make_judge(
1349          name="test_judge",
1350          instructions="Evaluate the {{ inputs }} and {{ outputs }}",
1351          model="openai:/gpt-4",
1352          feedback_value_type=str,
1353      )
1354      expected_params = {"model_provider": "openai"}
1355      validate_telemetry_record(
1356          mock_telemetry_client, mock_requests, MakeJudgeEvent.name, expected_params
1357      )
1358  
1359      make_judge(
1360          name="test_judge",
1361          instructions="Evaluate the {{ inputs }} and {{ outputs }}",
1362          feedback_value_type=str,
1363      )
1364      expected_params = {"model_provider": None}
1365      validate_telemetry_record(
1366          mock_telemetry_client, mock_requests, MakeJudgeEvent.name, expected_params
1367      )
1368  
1369  
1370  def test_align_judge(mock_requests, mock_telemetry_client: TelemetryClient):
1371      judge = make_judge(
1372          name="test_judge",
1373          instructions="Evaluate the {{ inputs }} and {{ outputs }}",
1374          model="openai:/gpt-4",
1375          feedback_value_type=str,
1376      )
1377  
1378      traces = [
1379          mock.MagicMock(spec=Trace),
1380          mock.MagicMock(spec=Trace),
1381      ]
1382  
1383      class MockOptimizer(AlignmentOptimizer):
1384          def align(self, judge, traces):
1385              return judge
1386  
1387      custom_optimizer = MockOptimizer()
1388      judge.align(traces, optimizer=custom_optimizer)
1389  
1390      expected_params = {"trace_count": 2, "optimizer_type": "MockOptimizer"}
1391      validate_telemetry_record(
1392          mock_telemetry_client, mock_requests, AlignJudgeEvent.name, expected_params
1393      )
1394  
1395  
1396  def test_discover_issues(mock_requests, mock_telemetry_client: TelemetryClient):
1397      traces = [
1398          mock.MagicMock(spec=Trace),
1399          mock.MagicMock(spec=Trace),
1400          mock.MagicMock(spec=Trace),
1401      ]
1402  
1403      mock_triage_run_id = "abc123"
1404      mock_eval_result = mock.MagicMock()
1405      mock_eval_result.run_id = mock_triage_run_id
1406  
1407      with (
1408          patch("mlflow.genai.discovery.pipeline.get_session_id", return_value=None),
1409          patch("mlflow.genai.discovery.pipeline.verify_scorer"),
1410          patch(
1411              "mlflow.genai.discovery.pipeline.mlflow.genai.evaluate",
1412              return_value=mock_eval_result,
1413          ),
1414          patch(
1415              "mlflow.genai.discovery.pipeline.extract_failing_traces",
1416              return_value=_TriageResult([], {}, {}),
1417          ),
1418          patch("mlflow.genai.discovery.pipeline.mlflow.MlflowClient"),
1419          patch("mlflow.genai.discovery.pipeline.mlflow.set_experiment"),
1420      ):
1421          discover_issues(
1422              traces=traces,
1423              model="openai:/gpt-4",
1424              categories=["hallucination", "accuracy"],
1425          )
1426  
1427      expected_params = {
1428          "model": "openai:/gpt-4",
1429          "trace_count": 3,
1430          "categories": ["hallucination", "accuracy"],
1431          "source_run_id": None,
1432          "issue_count": 0,
1433          "total_traces_analyzed": 3,
1434          "total_cost_usd": None,
1435          "triage_run_id": mock_triage_run_id,
1436      }
1437      validate_telemetry_record(
1438          mock_telemetry_client, mock_requests, DiscoverIssuesEvent.name, expected_params
1439      )
1440  
1441  
1442  def test_autologging(mock_requests, mock_telemetry_client: TelemetryClient):
1443      try:
1444          mlflow.openai.autolog()
1445  
1446          mlflow.autolog()
1447          mock_telemetry_client.flush()
1448          data = [record["data"] for record in mock_requests]
1449          params = [event["params"] for event in data if event["event_name"] == AutologgingEvent.name]
1450          assert (
1451              json.dumps({
1452                  "flavor": mlflow.openai.FLAVOR_NAME,
1453                  "log_traces": True,
1454                  "disable": False,
1455              })
1456              in params
1457          )
1458          assert json.dumps({"flavor": "all", "log_traces": True, "disable": False}) in params
1459      finally:
1460          mlflow.autolog(disable=True)
1461  
1462  
1463  def test_load_prompt(mock_requests, mock_telemetry_client: TelemetryClient):
1464      # Register a prompt first
1465      prompt = mlflow.genai.register_prompt(
1466          name="test_prompt",
1467          template="Hello {{name}}",
1468      )
1469      mock_telemetry_client.flush()
1470  
1471      # Set an alias for testing
1472      mlflow.genai.set_prompt_alias(name="test_prompt", version=prompt.version, alias="production")
1473  
1474      # Test load_prompt with version (no alias)
1475      mlflow.genai.load_prompt(name_or_uri="test_prompt", version=prompt.version)
1476      validate_telemetry_record(
1477          mock_telemetry_client,
1478          mock_requests,
1479          LoadPromptEvent.name,
1480          {"uses_alias": False},
1481      )
1482  
1483      # Test load_prompt with URI and version (no alias)
1484      mlflow.genai.load_prompt(name_or_uri=f"prompts:/test_prompt/{prompt.version}")
1485      validate_telemetry_record(
1486          mock_telemetry_client,
1487          mock_requests,
1488          LoadPromptEvent.name,
1489          {"uses_alias": False},
1490      )
1491  
1492      # Test load_prompt with alias
1493      mlflow.genai.load_prompt(name_or_uri="prompts:/test_prompt@production")
1494      validate_telemetry_record(
1495          mock_telemetry_client, mock_requests, LoadPromptEvent.name, {"uses_alias": True}
1496      )
1497  
1498      # Test load_prompt with @latest (special alias)
1499      mlflow.genai.load_prompt(name_or_uri="prompts:/test_prompt@latest")
1500      validate_telemetry_record(
1501          mock_telemetry_client, mock_requests, LoadPromptEvent.name, {"uses_alias": True}
1502      )
1503  
1504  
1505  def test_scorer_call_direct(mock_requests, mock_telemetry_client: TelemetryClient):
1506      @scorer
1507      def custom_scorer(outputs) -> bool:
1508          return len(outputs) > 0
1509  
1510      result = custom_scorer(outputs="test output")
1511      assert result is True
1512  
1513      validate_telemetry_record(
1514          mock_telemetry_client,
1515          mock_requests,
1516          ScorerCallEvent.name,
1517          {
1518              "scorer_class": "UserDefinedScorer",
1519              "scorer_kind": "decorator",
1520              "scope": "trace",
1521              "callsite": "direct_scorer_call",
1522              "has_feedback_error": False,
1523          },
1524      )
1525  
1526      safety_scorer = Safety()
1527  
1528      mock_feedback = Feedback(
1529          name="test_feedback",
1530          value="yes",
1531          rationale="Test rationale",
1532      )
1533  
1534      with mock.patch(
1535          "mlflow.genai.judges.builtin.invoke_judge_model",
1536          return_value=mock_feedback,
1537      ):
1538          safety_scorer(outputs="test output")
1539  
1540      validate_telemetry_record(
1541          mock_telemetry_client,
1542          mock_requests,
1543          ScorerCallEvent.name,
1544          {
1545              "scorer_class": "Safety",
1546              "scorer_kind": "builtin",
1547              "scope": "trace",
1548              "callsite": "direct_scorer_call",
1549              "has_feedback_error": False,
1550          },
1551      )
1552  
1553      mock_requests.clear()
1554  
1555      guidelines_scorer = Guidelines(guidelines="The response must be in English")
1556      with mock.patch(
1557          "mlflow.genai.judges.builtin.invoke_judge_model",
1558          return_value=mock_feedback,
1559      ):
1560          guidelines_scorer(
1561              inputs={"question": "What is MLflow?"}, outputs="MLflow is an ML platform"
1562          )
1563  
1564      validate_telemetry_record(
1565          mock_telemetry_client,
1566          mock_requests,
1567          ScorerCallEvent.name,
1568          {
1569              "scorer_class": "Guidelines",
1570              "scorer_kind": "guidelines",
1571              "scope": "trace",
1572              "callsite": "direct_scorer_call",
1573              "has_feedback_error": False,
1574          },
1575      )
1576  
1577      mock_requests.clear()
1578  
1579      class CustomClassScorer(Scorer):
1580          name: str = "custom_class"
1581  
1582          def __call__(self, *, outputs) -> bool:
1583              return len(outputs) > 0
1584  
1585      custom_class_scorer = CustomClassScorer()
1586      result = custom_class_scorer(outputs="test output")
1587      assert result is True
1588  
1589      validate_telemetry_record(
1590          mock_telemetry_client,
1591          mock_requests,
1592          ScorerCallEvent.name,
1593          {
1594              "scorer_class": "UserDefinedScorer",
1595              "scorer_kind": "class",
1596              "scope": "trace",
1597              "callsite": "direct_scorer_call",
1598              "has_feedback_error": False,
1599          },
1600      )
1601  
1602  
1603  def test_scorer_call_from_genai_evaluate(mock_requests, mock_telemetry_client: TelemetryClient):
1604      @scorer
1605      def simple_length_checker(outputs) -> bool:
1606          return len(outputs) > 0
1607  
1608      session_judge = make_judge(
1609          name="conversation_quality",
1610          instructions="Evaluate if the {{ conversation }} is engaging and coherent",
1611          model="openai:/gpt-4",
1612      )
1613  
1614      # Create traces with session metadata for session-level scorer testing
1615      @mlflow.trace(span_type=mlflow.entities.SpanType.CHAT_MODEL)
1616      def model(question, session_id):
1617          mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
1618          return f"Answer to: {question}"
1619  
1620      model("What is MLflow?", session_id="test_session")
1621      trace_1 = mlflow.get_trace(mlflow.get_last_active_trace_id())
1622  
1623      model("How does MLflow work?", session_id="test_session")
1624      trace_2 = mlflow.get_trace(mlflow.get_last_active_trace_id())
1625  
1626      test_data = pd.DataFrame([
1627          {
1628              "trace": trace_1,
1629          },
1630          {
1631              "trace": trace_2,
1632          },
1633      ])
1634  
1635      mock_feedback = Feedback(
1636          name="test_feedback",
1637          value="yes",
1638          rationale="Test",
1639      )
1640  
1641      with mock.patch(
1642          "mlflow.genai.judges.instructions_judge.invoke_judge_model",
1643          return_value=mock_feedback,
1644      ):
1645          mlflow.genai.evaluate(data=test_data, scorers=[simple_length_checker, session_judge])
1646  
1647      mock_telemetry_client.flush()
1648  
1649      scorer_call_events = [
1650          record for record in mock_requests if record["data"]["event_name"] == ScorerCallEvent.name
1651      ]
1652  
1653      # Should have 3 events: 2 response-level calls (one per trace)
1654      # + 1 session-level call (one per session)
1655      assert len(scorer_call_events) == 3
1656  
1657      event_params = [json.loads(event["data"]["params"]) for event in scorer_call_events]
1658  
1659      # Validate response-level scorer was called twice (once per trace)
1660      response_level_events = [
1661          params
1662          for params in event_params
1663          if params["scorer_class"] == "UserDefinedScorer"
1664          and params["scorer_kind"] == "decorator"
1665          and params["scope"] == "trace"
1666          and params["callsite"] == "genai_evaluate"
1667          and params["has_feedback_error"] is False
1668      ]
1669      assert len(response_level_events) == 2
1670  
1671      # Validate session-level scorer was called once (once per session)
1672      session_level_events = [
1673          params
1674          for params in event_params
1675          if params["scorer_class"] == "UserDefinedScorer"
1676          and params["scorer_kind"] == "instructions"
1677          and params["scope"] == "session"
1678          and params["callsite"] == "genai_evaluate"
1679          and params["has_feedback_error"] is False
1680      ]
1681      assert len(session_level_events) == 1
1682  
1683      mock_requests.clear()
1684  
1685  
1686  @pytest.mark.parametrize(
1687      ("job_name", "expected_callsite"),
1688      [
1689          ("run_online_trace_scorer", "online_scoring"),
1690          ("run_online_session_scorer", "online_scoring"),
1691          # Counterexample: non-online-scoring job should be treated as direct call
1692          ("invoke_scorer", "direct_scorer_call"),
1693      ],
1694  )
1695  def test_scorer_call_online_scoring_callsite(
1696      mock_requests, mock_telemetry_client: TelemetryClient, monkeypatch, job_name, expected_callsite
1697  ):
1698      # Import here to avoid circular imports
1699      from mlflow.server.jobs.utils import MLFLOW_SERVER_JOB_NAME_ENV_VAR
1700  
1701      monkeypatch.setenv(MLFLOW_SERVER_JOB_NAME_ENV_VAR, job_name)
1702  
1703      @scorer
1704      def custom_scorer(outputs: str) -> bool:
1705          return True
1706  
1707      custom_scorer(outputs="test output")
1708  
1709      validate_telemetry_record(
1710          mock_telemetry_client,
1711          mock_requests,
1712          ScorerCallEvent.name,
1713          {
1714              "scorer_class": "UserDefinedScorer",
1715              "scorer_kind": "decorator",
1716              "scope": "trace",
1717              "callsite": expected_callsite,
1718              "has_feedback_error": False,
1719          },
1720      )
1721  
1722  
1723  def test_scorer_call_tracks_feedback_errors(mock_requests, mock_telemetry_client: TelemetryClient):
1724      error_judge = make_judge(
1725          name="quality_judge",
1726          instructions="Evaluate if {{ outputs }} is high quality",
1727          model="openai:/gpt-4",
1728      )
1729  
1730      error_feedback = Feedback(
1731          name="quality_judge",
1732          error="Model invocation failed",
1733          source=AssessmentSource(
1734              source_type=AssessmentSourceType.LLM_JUDGE, source_id="openai:/gpt-4"
1735          ),
1736      )
1737      with mock.patch(
1738          "mlflow.genai.judges.instructions_judge.invoke_judge_model",
1739          return_value=error_feedback,
1740      ):
1741          result = error_judge(outputs="test output")
1742          assert result.error is not None
1743  
1744      validate_telemetry_record(
1745          mock_telemetry_client,
1746          mock_requests,
1747          ScorerCallEvent.name,
1748          {
1749              "scorer_class": "UserDefinedScorer",
1750              "scorer_kind": "instructions",
1751              "scope": "trace",
1752              "callsite": "direct_scorer_call",
1753              "has_feedback_error": True,
1754          },
1755      )
1756  
1757      mock_requests.clear()
1758  
1759      # Test Scorer returns list of Feedback with mixed errors
1760      @scorer
1761      def multi_feedback_scorer(outputs) -> list[Feedback]:
1762          return [
1763              Feedback(name="feedback1", value=1.0),
1764              Feedback(name="feedback2", error=ValueError("Error in feedback 2")),
1765              Feedback(name="feedback3", value=0.5),
1766          ]
1767  
1768      multi_feedback_scorer(outputs="test")
1769      validate_telemetry_record(
1770          mock_telemetry_client,
1771          mock_requests,
1772          ScorerCallEvent.name,
1773          {
1774              "scorer_class": "UserDefinedScorer",
1775              "scorer_kind": "decorator",
1776              "scope": "trace",
1777              "callsite": "direct_scorer_call",
1778              "has_feedback_error": True,
1779          },
1780      )
1781  
1782      mock_requests.clear()
1783  
1784      # Test Scorer returns primitive type (no Feedback error possible)
1785      @scorer
1786      def primitive_scorer(outputs) -> bool:
1787          return True
1788  
1789      primitive_scorer(outputs="test")
1790      validate_telemetry_record(
1791          mock_telemetry_client,
1792          mock_requests,
1793          ScorerCallEvent.name,
1794          {
1795              "scorer_class": "UserDefinedScorer",
1796              "scorer_kind": "decorator",
1797              "scope": "trace",
1798              "callsite": "direct_scorer_call",
1799              "has_feedback_error": False,
1800          },
1801      )
1802  
1803  
1804  def test_scorer_call_wrapped_builtin_scorer_direct(
1805      mock_requests, mock_telemetry_client: TelemetryClient
1806  ):
1807      completeness_scorer = Completeness()
1808  
1809      mock_feedback = Feedback(
1810          name="completeness",
1811          value="yes",
1812          rationale="Test rationale",
1813      )
1814  
1815      with mock.patch(
1816          "mlflow.genai.judges.instructions_judge.invoke_judge_model",
1817          return_value=mock_feedback,
1818      ):
1819          completeness_scorer(inputs={"question": "What is MLflow?"}, outputs="MLflow is a platform")
1820  
1821      mock_telemetry_client.flush()
1822  
1823      # Verify exactly 1 scorer_call event was created
1824      # (only top-level Completeness, not nested InstructionsJudge)
1825      scorer_call_events = [
1826          record for record in mock_requests if record["data"]["event_name"] == ScorerCallEvent.name
1827      ]
1828      assert len(scorer_call_events) == 1, (
1829          f"Expected 1 scorer call event for Completeness scorer (nested calls should be skipped), "
1830          f"got {len(scorer_call_events)}"
1831      )
1832  
1833      validate_telemetry_record(
1834          mock_telemetry_client,
1835          mock_requests,
1836          ScorerCallEvent.name,
1837          {
1838              "scorer_class": "Completeness",
1839              "scorer_kind": "builtin",
1840              "scope": "trace",
1841              "callsite": "direct_scorer_call",
1842              "has_feedback_error": False,
1843          },
1844      )
1845  
1846  
1847  def test_scorer_call_wrapped_builtin_scorer_from_genai_evaluate(
1848      mock_requests, mock_telemetry_client: TelemetryClient
1849  ):
1850      user_frustration_scorer = UserFrustration()
1851  
1852      @mlflow.trace(span_type=mlflow.entities.SpanType.CHAT_MODEL)
1853      def model(question, session_id):
1854          mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id})
1855          return f"Answer to: {question}"
1856  
1857      model("What is MLflow?", session_id="test_session")
1858      trace_1 = mlflow.get_trace(mlflow.get_last_active_trace_id())
1859  
1860      model("How does MLflow work?", session_id="test_session")
1861      trace_2 = mlflow.get_trace(mlflow.get_last_active_trace_id())
1862  
1863      test_data = pd.DataFrame([
1864          {"trace": trace_1},
1865          {"trace": trace_2},
1866      ])
1867  
1868      mock_feedback = Feedback(
1869          name="user_frustration",
1870          value="no",
1871          rationale="Test rationale",
1872      )
1873  
1874      with mock.patch(
1875          "mlflow.genai.judges.instructions_judge.invoke_judge_model",
1876          return_value=mock_feedback,
1877      ):
1878          mlflow.genai.evaluate(data=test_data, scorers=[user_frustration_scorer])
1879  
1880      mock_telemetry_client.flush()
1881  
1882      # Verify exactly 1 scorer_call event was created for the session-level scorer
1883      # (one call at the session level and no nested InstructionsJudge event)
1884      scorer_call_events = [
1885          record for record in mock_requests if record["data"]["event_name"] == ScorerCallEvent.name
1886      ]
1887      assert len(scorer_call_events) == 1, (
1888          f"Expected 1 scorer call event for UserFrustration scorer "
1889          f"(nested calls should be skipped), got {len(scorer_call_events)}"
1890      )
1891  
1892      validate_telemetry_record(
1893          mock_telemetry_client,
1894          mock_requests,
1895          ScorerCallEvent.name,
1896          {
1897              "scorer_class": "UserFrustration",
1898              "scorer_kind": "builtin",
1899              "scope": "session",
1900              "callsite": "genai_evaluate",
1901              "has_feedback_error": False,
1902          },
1903      )
1904  
1905  
1906  def test_gateway_crud_telemetry(mock_requests, mock_telemetry_client: TelemetryClient, tmp_path):
1907      db_path = tmp_path / "mlflow.db"
1908      store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix())
1909  
1910      secret = store.create_gateway_secret(
1911          secret_name="test-secret",
1912          secret_value={"api_key": "test-api-key"},
1913          provider="openai",
1914          created_by="test-user",
1915      )
1916  
1917      model_def = store.create_gateway_model_definition(
1918          name="test-model",
1919          provider="openai",
1920          model_name="gpt-4",
1921          secret_id=secret.secret_id,
1922          created_by="test-user",
1923      )
1924      validate_telemetry_record(
1925          mock_telemetry_client,
1926          mock_requests,
1927          GatewayCreateModelDefinitionEvent.name,
1928          {"model_name": "gpt-4", "provider": "openai"},
1929      )
1930  
1931      model_config = GatewayEndpointModelConfig(
1932          model_definition_id=model_def.model_definition_id,
1933          linkage_type=GatewayModelLinkageType.PRIMARY,
1934          weight=100,
1935      )
1936      endpoint = store.create_gateway_endpoint(
1937          name="test-endpoint",
1938          model_configs=[model_config],
1939          created_by="test-user",
1940      )
1941      validate_telemetry_record(
1942          mock_telemetry_client,
1943          mock_requests,
1944          GatewayCreateEndpointEvent.name,
1945          {
1946              "has_fallback_config": False,
1947              "routing_strategy": None,
1948              "num_model_configs": 1,
1949              "usage_tracking": True,
1950          },
1951      )
1952  
1953      store.get_gateway_endpoint(endpoint_id=endpoint.endpoint_id)
1954      validate_telemetry_record(
1955          mock_telemetry_client,
1956          mock_requests,
1957          GatewayGetEndpointEvent.name,
1958      )
1959  
1960      store.list_gateway_endpoints()
1961      validate_telemetry_record(
1962          mock_telemetry_client,
1963          mock_requests,
1964          GatewayListEndpointsEvent.name,
1965          {"filter_by_provider": False},
1966      )
1967  
1968      store.list_gateway_endpoints(provider="openai")
1969      validate_telemetry_record(
1970          mock_telemetry_client,
1971          mock_requests,
1972          GatewayListEndpointsEvent.name,
1973          {"filter_by_provider": True},
1974      )
1975  
1976      store.update_gateway_endpoint(
1977          endpoint_id=endpoint.endpoint_id,
1978          name="updated-endpoint",
1979      )
1980      validate_telemetry_record(
1981          mock_telemetry_client,
1982          mock_requests,
1983          GatewayUpdateEndpointEvent.name,
1984          {
1985              "has_fallback_config": False,
1986              "routing_strategy": None,
1987              "num_model_configs": None,
1988              "usage_tracking": None,
1989          },
1990      )
1991  
1992      store.delete_gateway_endpoint(endpoint_id=endpoint.endpoint_id)
1993      validate_telemetry_record(
1994          mock_telemetry_client,
1995          mock_requests,
1996          GatewayDeleteEndpointEvent.name,
1997      )
1998  
1999  
2000  def test_gateway_secret_crud_telemetry(
2001      mock_requests, mock_telemetry_client: TelemetryClient, tmp_path
2002  ):
2003      db_path = tmp_path / "mlflow.db"
2004      store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix())
2005  
2006      secret = store.create_gateway_secret(
2007          secret_name="test-secret",
2008          secret_value={"api_key": "test-api-key"},
2009          provider="openai",
2010          created_by="test-user",
2011      )
2012      validate_telemetry_record(
2013          mock_telemetry_client,
2014          mock_requests,
2015          GatewayCreateSecretEvent.name,
2016          {"provider": "openai"},
2017      )
2018  
2019      secret2 = store.create_gateway_secret(
2020          secret_name="test-secret-2",
2021          secret_value={"api_key": "test-api-key-2"},
2022          created_by="test-user",
2023      )
2024      validate_telemetry_record(
2025          mock_telemetry_client,
2026          mock_requests,
2027          GatewayCreateSecretEvent.name,
2028          {"provider": None},
2029      )
2030  
2031      store.list_secret_infos()
2032      validate_telemetry_record(
2033          mock_telemetry_client,
2034          mock_requests,
2035          GatewayListSecretsEvent.name,
2036          {"filter_by_provider": False},
2037      )
2038  
2039      store.list_secret_infos(provider="openai")
2040      validate_telemetry_record(
2041          mock_telemetry_client,
2042          mock_requests,
2043          GatewayListSecretsEvent.name,
2044          {"filter_by_provider": True},
2045      )
2046  
2047      store.update_gateway_secret(
2048          secret_id=secret.secret_id,
2049          secret_value={"api_key": "updated-api-key"},
2050          updated_by="test-user",
2051      )
2052      validate_telemetry_record(
2053          mock_telemetry_client,
2054          mock_requests,
2055          GatewayUpdateSecretEvent.name,
2056      )
2057  
2058      store.delete_gateway_secret(secret_id=secret.secret_id)
2059      validate_telemetry_record(
2060          mock_telemetry_client,
2061          mock_requests,
2062          GatewayDeleteSecretEvent.name,
2063      )
2064  
2065      store.delete_gateway_secret(secret_id=secret2.secret_id)
2066  
2067  
2068  def test_gateway_budget_policy_crud_telemetry(
2069      mock_requests, mock_telemetry_client: TelemetryClient, tmp_path
2070  ):
2071      db_path = tmp_path / "mlflow.db"
2072      store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix())
2073  
2074      policy = store.create_budget_policy(
2075          budget_unit=BudgetUnit.USD,
2076          budget_amount=100.0,
2077          duration=BudgetDuration(unit=BudgetDurationUnit.DAYS, value=30),
2078          target_scope=BudgetTargetScope.GLOBAL,
2079          budget_action=BudgetAction.ALERT,
2080          created_by="test-user",
2081      )
2082      validate_telemetry_record(
2083          mock_telemetry_client,
2084          mock_requests,
2085          GatewayCreateBudgetPolicyEvent.name,
2086          {
2087              "budget_unit": "USD",
2088              "duration_unit": "DAYS",
2089              "target_scope": "GLOBAL",
2090              "budget_action": "ALERT",
2091          },
2092      )
2093  
2094      store.list_budget_policies()
2095      validate_telemetry_record(
2096          mock_telemetry_client,
2097          mock_requests,
2098          GatewayListBudgetPoliciesEvent.name,
2099      )
2100  
2101      store.update_budget_policy(
2102          budget_policy_id=policy.budget_policy_id,
2103          budget_amount=200.0,
2104      )
2105      validate_telemetry_record(
2106          mock_telemetry_client,
2107          mock_requests,
2108          GatewayUpdateBudgetPolicyEvent.name,
2109      )
2110  
2111      store.delete_budget_policy(budget_policy_id=policy.budget_policy_id)
2112      validate_telemetry_record(
2113          mock_telemetry_client,
2114          mock_requests,
2115          GatewayDeleteBudgetPolicyEvent.name,
2116      )
2117  
2118  
2119  def test_gateway_guardrail_crud_telemetry(
2120      mock_requests, mock_telemetry_client: TelemetryClient, tmp_path
2121  ):
2122      db_path = tmp_path / "mlflow.db"
2123      store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix())
2124  
2125      secret = store.create_gateway_secret(
2126          secret_name="test-secret",
2127          secret_value={"api_key": "test-api-key"},
2128          provider="openai",
2129          created_by="test-user",
2130      )
2131      model_def = store.create_gateway_model_definition(
2132          name="test-model",
2133          provider="openai",
2134          model_name="gpt-4",
2135          secret_id=secret.secret_id,
2136          created_by="test-user",
2137      )
2138      endpoint = store.create_gateway_endpoint(
2139          name="test-endpoint",
2140          model_configs=[
2141              GatewayEndpointModelConfig(
2142                  model_definition_id=model_def.model_definition_id,
2143                  linkage_type=GatewayModelLinkageType.PRIMARY,
2144                  weight=100,
2145              )
2146          ],
2147          created_by="test-user",
2148          usage_tracking=False,
2149      )
2150      scorer_experiment_id = store.create_experiment("guardrail-scorer-exp")
2151      serialized_scorer = json.dumps({
2152          "instructions_judge_pydantic_data": {
2153              "model": "gateway:/test-endpoint",
2154              "instructions": "Is this input safe?",
2155          }
2156      })
2157      scorer = store.register_scorer(
2158          experiment_id=scorer_experiment_id,
2159          name="safety-judge",
2160          serialized_scorer=serialized_scorer,
2161      )
2162      guardrail = store.create_gateway_guardrail(
2163          name="guardrail-1",
2164          scorer_id=scorer.scorer_id,
2165          scorer_version=scorer.scorer_version,
2166          stage=GuardrailStage.BEFORE,
2167          action=GuardrailAction.VALIDATION,
2168          created_by="test-user",
2169      )
2170      validate_telemetry_record(
2171          mock_telemetry_client,
2172          mock_requests,
2173          GatewayCreateGuardrailEvent.name,
2174          {
2175              "stage": "BEFORE",
2176              "action": "VALIDATION",
2177          },
2178      )
2179  
2180      # Guardrail update telemetry is emitted by endpoint guardrail config updates.
2181      store.add_guardrail_to_endpoint(endpoint.endpoint_id, guardrail.guardrail_id, execution_order=1)
2182      store.update_endpoint_guardrail_config(
2183          endpoint_id=endpoint.endpoint_id,
2184          guardrail_id=guardrail.guardrail_id,
2185          execution_order=2,
2186      )
2187      validate_telemetry_record(
2188          mock_telemetry_client,
2189          mock_requests,
2190          GatewayUpdateGuardrailEvent.name,
2191          {"stage": None, "action": None},
2192      )
2193  
2194      store.delete_gateway_guardrail(guardrail.guardrail_id)
2195      validate_telemetry_record(
2196          mock_telemetry_client,
2197          mock_requests,
2198          GatewayDeleteGuardrailEvent.name,
2199      )
2200  
2201  
2202  @pytest.mark.asyncio
2203  async def test_gateway_invocation_telemetry(
2204      mock_requests, mock_telemetry_client: TelemetryClient, tmp_path
2205  ):
2206      db_path = tmp_path / "mlflow.db"
2207      store = SqlAlchemyStore(f"sqlite:///{db_path}", tmp_path.as_posix())
2208  
2209      secret = store.create_gateway_secret(
2210          secret_name="test-secret",
2211          secret_value={"api_key": "test-api-key"},
2212          provider="openai",
2213          created_by="test-user",
2214      )
2215      mock_telemetry_client.flush()
2216      mock_requests.clear()
2217  
2218      model_def = store.create_gateway_model_definition(
2219          name="test-model",
2220          provider="openai",
2221          model_name="gpt-4",
2222          secret_id=secret.secret_id,
2223          created_by="test-user",
2224      )
2225      endpoint = store.create_gateway_endpoint(
2226          name="test-endpoint",
2227          model_configs=[
2228              GatewayEndpointModelConfig(
2229                  model_definition_id=model_def.model_definition_id,
2230                  linkage_type=GatewayModelLinkageType.PRIMARY,
2231                  weight=100,
2232              )
2233          ],
2234          created_by="test-user",
2235      )
2236      mock_telemetry_client.flush()
2237      mock_requests.clear()
2238  
2239      mock_response = chat.ResponsePayload(
2240          id="test-id",
2241          object="chat.completion",
2242          created=1234567890,
2243          model="gpt-4",
2244          choices=[
2245              chat.Choice(
2246                  index=0,
2247                  message=chat.ResponseMessage(role="assistant", content="Hello!"),
2248                  finish_reason="stop",
2249              )
2250          ],
2251          usage=chat.ChatUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
2252      )
2253  
2254      mock_model = GatewayModelConfig(
2255          model_definition_id="test-model-def",
2256          provider="openai",
2257          model_name="gpt-4",
2258          secret_value={"api_key": "test"},
2259          linkage_type=GatewayModelLinkageType.PRIMARY,
2260      )
2261  
2262      # Test invocations endpoint (chat)
2263      mock_request = MagicMock(spec=Request)
2264      mock_request.headers = {}
2265      mock_request.json = AsyncMock(
2266          return_value={
2267              "messages": [{"role": "user", "content": "Hi"}],
2268              "temperature": 0.7,
2269              "stream": False,
2270          }
2271      )
2272  
2273      with (
2274          patch("mlflow.server.gateway_api._get_store", return_value=store),
2275          patch(
2276              "mlflow.server.gateway_api._create_provider_from_endpoint_name"
2277          ) as mock_create_provider,
2278      ):
2279          mock_provider = MagicMock()
2280          mock_provider.chat = AsyncMock(return_value=mock_response)
2281          mock_endpoint_config = GatewayEndpointConfig(
2282              endpoint_id=endpoint.endpoint_id,
2283              endpoint_name=endpoint.name,
2284              models=[mock_model],
2285          )
2286          mock_create_provider.return_value = (mock_provider, mock_endpoint_config)
2287  
2288          await invocations(endpoint.name, mock_request)
2289  
2290      data = validate_telemetry_record(
2291          mock_telemetry_client,
2292          mock_requests,
2293          GatewayInvocationEvent.name,
2294          check_params=False,
2295      )
2296      params = json.loads(data["params"])
2297      assert params["is_streaming"] is False
2298      assert params["invocation_type"] == "mlflow_invocations"
2299      assert params["has_traceparent"] is False
2300      assert params["auth_enabled"] is False
2301      assert params["endpoint_id"] == endpoint.endpoint_id
2302      assert params["provider"] == "openai"
2303      # Non-streaming includes timing fields
2304      assert "provider_duration_ms" in params
2305      assert "gateway_overhead_ms" in params
2306  
2307      # Test chat_completions endpoint
2308      mock_request = MagicMock(spec=Request)
2309      mock_request.headers = {}
2310      mock_request.json = AsyncMock(
2311          return_value={
2312              "model": endpoint.name,
2313              "messages": [{"role": "user", "content": "Hi"}],
2314              "temperature": 0.7,
2315              "stream": False,
2316          }
2317      )
2318  
2319      with (
2320          patch("mlflow.server.gateway_api._get_store", return_value=store),
2321          patch(
2322              "mlflow.server.gateway_api._create_provider_from_endpoint_name"
2323          ) as mock_create_provider,
2324      ):
2325          mock_provider = MagicMock()
2326          mock_provider.chat = AsyncMock(return_value=mock_response)
2327          mock_endpoint_config = GatewayEndpointConfig(
2328              endpoint_id=endpoint.endpoint_id,
2329              endpoint_name=endpoint.name,
2330              models=[mock_model],
2331          )
2332          mock_create_provider.return_value = (mock_provider, mock_endpoint_config)
2333  
2334          await chat_completions(mock_request)
2335  
2336      data = validate_telemetry_record(
2337          mock_telemetry_client,
2338          mock_requests,
2339          GatewayInvocationEvent.name,
2340          check_params=False,
2341      )
2342      params = json.loads(data["params"])
2343      assert params["is_streaming"] is False
2344      assert params["invocation_type"] == "mlflow_chat_completions"
2345      assert params["endpoint_id"] == endpoint.endpoint_id
2346      assert params["provider"] == "openai"
2347  
2348      # Test streaming invocation — timing fields should be absent
2349      mock_request = MagicMock(spec=Request)
2350      mock_request.headers = {}
2351      mock_request.json = AsyncMock(
2352          return_value={
2353              "model": endpoint.name,
2354              "messages": [{"role": "user", "content": "Hi"}],
2355              "stream": True,
2356          }
2357      )
2358  
2359      async def mock_stream():
2360          yield chat.StreamResponsePayload(
2361              id="test-id",
2362              object="chat.completion.chunk",
2363              created=1234567890,
2364              model="gpt-4",
2365              choices=[
2366                  chat.StreamChoice(
2367                      index=0,
2368                      delta=chat.StreamDelta(role="assistant", content="Hello"),
2369                      finish_reason=None,
2370                  )
2371              ],
2372          )
2373  
2374      with (
2375          patch("mlflow.server.gateway_api._get_store", return_value=store),
2376          patch(
2377              "mlflow.server.gateway_api._create_provider_from_endpoint_name"
2378          ) as mock_create_provider,
2379      ):
2380          mock_provider = MagicMock()
2381          mock_provider.chat_stream = MagicMock(return_value=mock_stream())
2382          mock_endpoint_config = GatewayEndpointConfig(
2383              endpoint_id=endpoint.endpoint_id,
2384              endpoint_name=endpoint.name,
2385              models=[mock_model],
2386          )
2387          mock_create_provider.return_value = (mock_provider, mock_endpoint_config)
2388  
2389          await chat_completions(mock_request)
2390  
2391      data = validate_telemetry_record(
2392          mock_telemetry_client,
2393          mock_requests,
2394          GatewayInvocationEvent.name,
2395          check_params=False,
2396      )
2397      params = json.loads(data["params"])
2398      assert params["is_streaming"] is True
2399      assert params["invocation_type"] == "mlflow_chat_completions"
2400      # Streaming responses should NOT include timing fields
2401      assert "provider_duration_ms" not in params
2402      assert "gateway_overhead_ms" not in params
2403  
2404      # Test that caller header and traceparent are included in telemetry when present
2405      mock_request = MagicMock(spec=Request)
2406      mock_request.json = AsyncMock(
2407          return_value={
2408              "model": endpoint.name,
2409              "messages": [{"role": "user", "content": "Hi"}],
2410              "stream": False,
2411          }
2412      )
2413      mock_request.headers = {MLFLOW_GATEWAY_CALLER_HEADER: "judge", "traceparent": "00-abc-def-01"}
2414  
2415      mock_auth_module = MagicMock()
2416      mock_auth_module.is_auth_enabled = MagicMock(return_value=True)
2417  
2418      with (
2419          patch("mlflow.server.gateway_api._get_store", return_value=store),
2420          patch(
2421              "mlflow.server.gateway_api._create_provider_from_endpoint_name"
2422          ) as mock_create_provider,
2423          patch.dict("sys.modules", {"mlflow.server.auth": mock_auth_module}),
2424      ):
2425          mock_provider = MagicMock()
2426          mock_provider.chat = AsyncMock(return_value=mock_response)
2427          mock_endpoint_config = GatewayEndpointConfig(
2428              endpoint_id=endpoint.endpoint_id,
2429              endpoint_name=endpoint.name,
2430              models=[mock_model],
2431          )
2432          mock_create_provider.return_value = (mock_provider, mock_endpoint_config)
2433  
2434          await chat_completions(mock_request)
2435  
2436      data = validate_telemetry_record(
2437          mock_telemetry_client,
2438          mock_requests,
2439          GatewayInvocationEvent.name,
2440          check_params=False,
2441      )
2442      params = json.loads(data["params"])
2443      assert params["is_streaming"] is False
2444      assert params["invocation_type"] == "mlflow_chat_completions"
2445      assert params["caller"] == "judge"
2446      assert params["has_traceparent"] is True
2447      assert params["auth_enabled"] is True
2448  
2449  
2450  def test_tracing_context_propagation_get_and_set_success(
2451      mock_requests, mock_telemetry_client: TelemetryClient
2452  ):
2453      with mock.patch(
2454          "mlflow.telemetry.track.get_telemetry_client", return_value=mock_telemetry_client
2455      ):
2456          with mlflow.start_span("client span"):
2457              headers = get_tracing_context_headers_for_http_request()
2458  
2459      validate_telemetry_record(
2460          mock_telemetry_client,
2461          mock_requests,
2462          TracingContextPropagation.name,
2463      )
2464  
2465      with mock.patch(
2466          "mlflow.telemetry.track.get_telemetry_client", return_value=mock_telemetry_client
2467      ):
2468          with set_tracing_context_from_http_request_headers(headers):
2469              with mlflow.start_span("server span"):
2470                  pass
2471  
2472      validate_telemetry_record(
2473          mock_telemetry_client,
2474          mock_requests,
2475          TracingContextPropagation.name,
2476      )
2477  
2478  
2479  def test_update_issue_telemetry(mock_requests, mock_telemetry_client: TelemetryClient, db_uri):
2480      store = SqlAlchemyStore(db_uri, "/tmp")
2481  
2482      exp_id = store.create_experiment("test-exp")
2483      issue = store.create_issue(
2484          experiment_id=exp_id,
2485          name="Original name",
2486          description="Original description",
2487          status=IssueStatus.PENDING,
2488      )
2489      mock_telemetry_client.flush()
2490      mock_requests.clear()
2491  
2492      store.update_issue(
2493          issue_id=issue.issue_id,
2494          status=IssueStatus.RESOLVED,
2495          name="Updated name",
2496          description="Updated description",
2497          severity=IssueSeverity.HIGH,
2498      )
2499  
2500      validate_telemetry_record(
2501          mock_telemetry_client,
2502          mock_requests,
2503          UpdateIssueEvent.name,
2504          {
2505              "status": "resolved",
2506              "has_name": True,
2507              "has_description": True,
2508              "severity": "high",
2509              "source_run_id": None,
2510          },
2511      )