# conftest.py
from unittest.mock import Mock, patch

import pytest

import mlflow.telemetry.utils
from mlflow.telemetry.client import (
    TelemetryClient,
    _fetch_server_info,
    _set_telemetry_client,
    get_telemetry_client,
)
from mlflow.version import VERSION

# Ingestion endpoint advertised by the mocked config response (``mock_requests_get``)
# and the only URL the mocked ``requests.post`` (``mock_requests``) accepts with 200.
# One constant keeps the two fixtures in agreement.
_INGESTION_URL = "http://localhost:9999"

# URLs for which the mocked ``requests.post`` answers with the given error status.
_ERROR_STATUS_BY_URL = {
    "http://127.0.0.1:9999/nonexistent": 404,
    "http://127.0.0.1:9999/unauthorized": 401,
    "http://127.0.0.1:9999/forbidden": 403,
    "http://127.0.0.1:9999/bad_request": 400,
}


@pytest.fixture(autouse=True)
def clear_server_store_type_cache():
    """Clear the cache of ``_fetch_server_info`` before and after every test.

    Despite the fixture's name, the cache being reset is the one exposed by
    ``_fetch_server_info.cache_clear``; clearing on both sides keeps cached
    server info from leaking into or out of a test.
    """
    _fetch_server_info.cache_clear()
    yield
    _fetch_server_info.cache_clear()


@pytest.fixture(autouse=True)
def terminate_telemetry_client():
    """After each test, clean up and unregister the global telemetry client, if any."""
    yield
    active_client = get_telemetry_client()
    if active_client:
        active_client._clean_up()
        # Drop the global reference so this test's client cannot affect later tests.
        _set_telemetry_client(None)


@pytest.fixture
def mock_requests():
    """Patch ``requests.post`` and yield the list of telemetry records it receives.

    Routing of the fake ``post``:
      * a URL in ``_ERROR_STATUS_BY_URL`` -> a Mock response with that status code;
      * ``_INGESTION_URL`` -> status 200, whose ``json()`` returns
        ``{"status": "success", "count": <number of posted records>}``; any
        ``"records"`` in the JSON payload are appended to the yielded list;
      * anything else -> ``Mock(status_code=404)``.
    """
    received_records = []

    def fake_post(url, json=None, **kwargs):
        error_status = _ERROR_STATUS_BY_URL.get(url)
        if error_status is not None:
            error_response = Mock()
            error_response.status_code = error_status
            return error_response

        if url != _INGESTION_URL:
            return Mock(status_code=404)

        if json and "records" in json:
            received_records.extend(json["records"])
        ok_response = Mock()
        ok_response.status_code = 200
        ok_response.json.return_value = {
            "status": "success",
            "count": len(json.get("records", [])) if json else 0,
        }
        return ok_response

    with patch("requests.post", side_effect=fake_post):
        yield received_records


@pytest.fixture(autouse=True)
def mock_requests_get(request):
    """Serve a fixed telemetry config from ``mlflow.telemetry.client.requests.get``.

    The mocked response has status 200, and its ``json()`` reports telemetry as
    enabled for this ``VERSION`` with ``_INGESTION_URL`` as the ingestion URL,
    100% rollout and nothing disabled. Tests marked ``no_mock_requests_get`` opt
    out and get no patch at all.
    """
    opted_out = request.node.get_closest_marker("no_mock_requests_get")
    if opted_out:
        yield
        return

    server_config = {
        "mlflow_version": VERSION,
        "disable_telemetry": False,
        "ingestion_url": _INGESTION_URL,
        "rollout_percentage": 100,
        "disable_events": [],
        "disable_sdks": [],
    }
    config_response = Mock(status_code=200, json=Mock(return_value=server_config))
    with patch("mlflow.telemetry.client.requests.get", return_value=config_response):
        yield


@pytest.fixture
def mock_telemetry_client(mock_requests_get, mock_requests):
    """Yield an activated ``TelemetryClient`` backed by the mocked HTTP fixtures.

    Waits up to one second for the client's config thread before the test runs.
    ``join(timeout=1)`` does not raise if the thread is still alive afterwards.
    """
    with TelemetryClient() as telemetry_client:
        telemetry_client.activate()
        telemetry_client._config_thread.join(timeout=1)
        yield telemetry_client


@pytest.fixture(autouse=True)
def is_mlflow_testing(monkeypatch):
    """Force ``_IS_MLFLOW_TESTING_TELEMETRY`` to True for every test.

    Tests that run locally on a dev version of MLflow therefore exercise
    telemetry by default.
    """
    monkeypatch.setattr(mlflow.telemetry.utils, "_IS_MLFLOW_TESTING_TELEMETRY", True)


@pytest.fixture
def bypass_env_check(monkeypatch):
    """Set all three telemetry environment flags in ``mlflow.telemetry.utils`` to False.

    This overrides the value set by the autouse ``is_mlflow_testing`` fixture.
    Pytest instantiates autouse fixtures before explicitly requested fixtures
    of the same scope, so these values are the ones that take effect.
    """
    for flag_name in (
        "_IS_MLFLOW_TESTING_TELEMETRY",
        "_IS_IN_CI_ENV_OR_TESTING",
        "_IS_MLFLOW_DEV_VERSION",
    ):
        monkeypatch.setattr(mlflow.telemetry.utils, flag_name, False)