# tests/telemetry/conftest.py
  1  from unittest.mock import Mock, patch
  2  
  3  import pytest
  4  
  5  import mlflow.telemetry.utils
  6  from mlflow.telemetry.client import (
  7      TelemetryClient,
  8      _fetch_server_info,
  9      _set_telemetry_client,
 10      get_telemetry_client,
 11  )
 12  from mlflow.version import VERSION
 13  
 14  
 15  @pytest.fixture(autouse=True)
 16  def clear_server_store_type_cache():
 17      _fetch_server_info.cache_clear()
 18      yield
 19      _fetch_server_info.cache_clear()
 20  
 21  
 22  @pytest.fixture(autouse=True)
 23  def terminate_telemetry_client():
 24      yield
 25      if client := get_telemetry_client():
 26          client._clean_up()
 27          # set to None to avoid side effect in other tests
 28          _set_telemetry_client(None)
 29  
 30  
 31  @pytest.fixture
 32  def mock_requests():
 33      """Fixture to mock requests.post and capture telemetry records."""
 34      captured_records = []
 35  
 36      url_status_code_map = {
 37          "http://127.0.0.1:9999/nonexistent": 404,
 38          "http://127.0.0.1:9999/unauthorized": 401,
 39          "http://127.0.0.1:9999/forbidden": 403,
 40          "http://127.0.0.1:9999/bad_request": 400,
 41      }
 42  
 43      def mock_post(url, json=None, **kwargs):
 44          if url in url_status_code_map:
 45              mock_response = Mock()
 46              mock_response.status_code = url_status_code_map[url]
 47              return mock_response
 48          if url == "http://localhost:9999":
 49              if json and "records" in json:
 50                  captured_records.extend(json["records"])
 51              mock_response = Mock()
 52              mock_response.status_code = 200
 53              mock_response.json.return_value = {
 54                  "status": "success",
 55                  "count": len(json.get("records", [])) if json else 0,
 56              }
 57              return mock_response
 58          return Mock(status_code=404)
 59  
 60      with patch("requests.post", side_effect=mock_post):
 61          yield captured_records
 62  
 63  
 64  @pytest.fixture(autouse=True)
 65  def mock_requests_get(request):
 66      if request.node.get_closest_marker("no_mock_requests_get"):
 67          yield
 68          return
 69  
 70      with patch("mlflow.telemetry.client.requests.get") as mock_get:
 71          mock_get.return_value = Mock(
 72              status_code=200,
 73              json=Mock(
 74                  return_value={
 75                      "mlflow_version": VERSION,
 76                      "disable_telemetry": False,
 77                      "ingestion_url": "http://localhost:9999",
 78                      "rollout_percentage": 100,
 79                      "disable_events": [],
 80                      "disable_sdks": [],
 81                  }
 82              ),
 83          )
 84          yield
 85  
 86  
 87  @pytest.fixture
 88  def mock_telemetry_client(mock_requests_get, mock_requests):
 89      with TelemetryClient() as client:
 90          client.activate()
 91          # ensure config is fetched before the test
 92          client._config_thread.join(timeout=1)
 93          yield client
 94  
 95  
@pytest.fixture(autouse=True)
def is_mlflow_testing(monkeypatch):
    """Force the telemetry testing flag on for every test in this suite."""
    # enable telemetry by default when running tests in local with dev version
    monkeypatch.setattr(mlflow.telemetry.utils, "_IS_MLFLOW_TESTING_TELEMETRY", True)
100  
101  
@pytest.fixture
def bypass_env_check(monkeypatch):
    """Clear every environment flag that would short-circuit telemetry checks."""
    flags = (
        "_IS_MLFLOW_TESTING_TELEMETRY",
        "_IS_IN_CI_ENV_OR_TESTING",
        "_IS_MLFLOW_DEV_VERSION",
    )
    for flag in flags:
        monkeypatch.setattr(mlflow.telemetry.utils, flag, False)