# conftest.py
import functools
import os
from unittest import mock

import pytest

import mlflow
import mlflow.telemetry.utils
from mlflow.entities.assessment import Expectation
from mlflow.entities.document import Document
from mlflow.entities.span import SpanType
from mlflow.genai.scorers.validation import IS_DBX_AGENTS_INSTALLED

# Re-export the shared telemetry fixtures so genai tests can reuse the same
# telemetry testing infrastructure defined in tests/telemetry/conftest.py.
from tests.telemetry.conftest import (  # noqa: F401
    mock_requests,
    mock_requests_get,
    mock_telemetry_client,
    terminate_telemetry_client,
)


@pytest.fixture
def enable_telemetry_in_tests(monkeypatch):
    """Opt-in fixture that turns telemetry on so a test can assert tracking behavior.

    Request this fixture explicitly in tests that validate telemetry.
    """
    monkeypatch.setattr(mlflow.telemetry.utils, "_IS_MLFLOW_TESTING_TELEMETRY", True)


@pytest.fixture(autouse=True)
def mock_init_auth():
    """Stub out Databricks SDK auth so no test ever performs a real credential lookup."""

    def _fake_init_auth(cfg):
        cfg.host = "https://databricks.com/"
        cfg._header_factory = lambda: {}

    with mock.patch("databricks.sdk.config.Config.init_auth", new=_fake_init_auth):
        yield


@pytest.fixture(params=[True, False], ids=["databricks", "oss"])
def is_in_databricks(request):
    """Run the test twice: once as-if-on-Databricks, once as OSS.

    Yields the boolean flavor currently active, while patching the URI checks
    that downstream code uses to detect a Databricks environment. Flavors that
    cannot be exercised in the current environment are skipped.
    """
    in_dbx = request.param
    if in_dbx and not IS_DBX_AGENTS_INSTALLED:
        pytest.skip("Skipping Databricks test because `databricks-agents` is not installed.")

    # CI runs the suite twice, once without `databricks-agents` and once with it.
    # To avoid duplicate coverage, drop the OSS flavor on the run that has it.
    if "GITHUB_ACTIONS" in os.environ and not in_dbx and IS_DBX_AGENTS_INSTALLED:
        pytest.skip("Skipping OSS test in CI because `databricks-agents` is installed.")

    with (
        mock.patch("mlflow.genai.judges.utils.is_databricks_uri", return_value=in_dbx),
        mock.patch(
            "mlflow.utils.databricks_utils.is_databricks_default_tracking_uri",
            return_value=in_dbx,
        ),
    ):
        yield in_dbx


def databricks_only(func):
    """Decorator that skips the wrapped test unless `databricks-agents` is installed."""

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        if not IS_DBX_AGENTS_INSTALLED:
            pytest.skip("Skipping Databricks only test.")
        # Force a Databricks tracking URI for the duration of the test body.
        with mock.patch("mlflow.get_tracking_uri", return_value="databricks"):
            return func(*args, **kwargs)

    return _wrapper


@pytest.fixture
def sample_rag_trace():
    """Build a small RAG-shaped trace (one agent span + two retriever spans).

    Returns the trace with three Expectation assessments attached to its info.
    NOTE: the traced helper function names double as span names, so they must
    stay stable.
    """

    @mlflow.trace(span_type=SpanType.RETRIEVER)
    def _retrieve_1(question):
        return [
            Document(page_content=content, metadata={"doc_uri": uri})
            for content, uri in (("content_1", "url_1"), ("content_2", "url_2"))
        ]

    @mlflow.trace(span_type=SpanType.RETRIEVER)
    def _retrieve_2(question):
        return [Document(page_content="content_3")]

    @mlflow.trace(name="rag", span_type=SpanType.AGENT)
    def _predict(question):
        # Two retriever calls nested under the agent span.
        _retrieve_1(question)
        _retrieve_2(question)
        return "answer"

    _predict("query")

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    # Attach expectations directly to the trace info because the OSS backend
    # doesn't support assessment logging yet.
    trace.info.assessments = [
        Expectation(name="expected_response", value="expected answer"),
        Expectation(name="expected_facts", value=["fact1", "fact2"]),
        Expectation(name="guidelines", value=["write in english"]),
    ]
    return trace