/ tests / genai / conftest.py
conftest.py
  1  import functools
  2  import os
  3  from unittest import mock
  4  
  5  import pytest
  6  
  7  import mlflow
  8  import mlflow.telemetry.utils
  9  from mlflow.entities.assessment import Expectation
 10  from mlflow.entities.document import Document
 11  from mlflow.entities.span import SpanType
 12  from mlflow.genai.scorers.validation import IS_DBX_AGENTS_INSTALLED
 13  
 14  # Import telemetry test fixtures from tests/telemetry/conftest.py
 15  # This allows genai tests to use the same telemetry testing infrastructure
 16  from tests.telemetry.conftest import (  # noqa: F401
 17      mock_requests,
 18      mock_requests_get,
 19      mock_telemetry_client,
 20      terminate_telemetry_client,
 21  )
 22  
 23  
 24  @pytest.fixture
 25  def enable_telemetry_in_tests(monkeypatch):
 26      """
 27      Enable telemetry for tests that need to verify telemetry tracking.
 28      Use this fixture explicitly in tests that validate telemetry behavior.
 29      """
 30      monkeypatch.setattr(mlflow.telemetry.utils, "_IS_MLFLOW_TESTING_TELEMETRY", True)
 31  
 32  
 33  @pytest.fixture(autouse=True)
 34  def mock_init_auth():
 35      def mocked_init_auth(config_instance):
 36          config_instance.host = "https://databricks.com/"
 37          config_instance._header_factory = lambda: {}
 38  
 39      with mock.patch("databricks.sdk.config.Config.init_auth", new=mocked_init_auth):
 40          yield
 41  
 42  
 43  @pytest.fixture(params=[True, False], ids=["databricks", "oss"])
 44  def is_in_databricks(request):
 45      if request.param and not IS_DBX_AGENTS_INSTALLED:
 46          pytest.skip("Skipping Databricks test because `databricks-agents` is not installed.")
 47  
 48      # In CI, we run test twice, once without `databricks-agents` and once with.
 49      # To be effective, we skip OSS test when running with `databricks-agents`.
 50      if "GITHUB_ACTIONS" in os.environ:
 51          if not request.param and IS_DBX_AGENTS_INSTALLED:
 52              pytest.skip("Skipping OSS test in CI because `databricks-agents` is installed.")
 53  
 54      with (
 55          mock.patch("mlflow.genai.judges.utils.is_databricks_uri", return_value=request.param),
 56          mock.patch(
 57              "mlflow.utils.databricks_utils.is_databricks_default_tracking_uri",
 58              return_value=request.param,
 59          ),
 60      ):
 61          yield request.param
 62  
 63  
 64  def databricks_only(func):
 65      """Decorator that skips test if not in Databricks environment"""
 66  
 67      @functools.wraps(func)
 68      def wrapper(*args, **kwargs):
 69          if not IS_DBX_AGENTS_INSTALLED:
 70              pytest.skip("Skipping Databricks only test.")
 71  
 72          with mock.patch("mlflow.get_tracking_uri", return_value="databricks"):
 73              return func(*args, **kwargs)
 74  
 75      return wrapper
 76  
 77  
 78  @pytest.fixture
 79  def sample_rag_trace():
 80      @mlflow.trace(name="rag", span_type=SpanType.AGENT)
 81      def _predict(question):
 82          # Two retrievers calls
 83          _retrieve_1(question)
 84          _retrieve_2(question)
 85          return "answer"
 86  
 87      @mlflow.trace(span_type=SpanType.RETRIEVER)
 88      def _retrieve_1(question):
 89          return [
 90              Document(
 91                  page_content="content_1",
 92                  metadata={"doc_uri": "url_1"},
 93              ),
 94              Document(
 95                  page_content="content_2",
 96                  metadata={"doc_uri": "url_2"},
 97              ),
 98          ]
 99  
100      @mlflow.trace(span_type=SpanType.RETRIEVER)
101      def _retrieve_2(question):
102          return [Document(page_content="content_3")]
103  
104      _predict("query")
105  
106      trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
107  
108      # Add expectations. Directly append to the trace info because OSS backend doesn't
109      # support assessment logging yet.
110      trace.info.assessments = [
111          Expectation(name="expected_response", value="expected answer"),
112          Expectation(name="expected_facts", value=["fact1", "fact2"]),
113          Expectation(name="guidelines", value=["write in english"]),
114      ]
115      return trace