# test_prometheus_exporter.py
import pytest

from mlflow.server.prometheus_exporter import activate_prometheus_exporter

# Name of the request counter sampled throughout these tests.
_REQUEST_TOTAL = "mlflow_http_request_total"


def _request_count(metrics, labels):
    """Return the current request-counter sample for *labels*.

    Reads ``mlflow_http_request_total`` from ``metrics.registry``. The test
    below expects ``None`` before any matching request has been served.
    """
    return metrics.registry.get_sample_value(_REQUEST_TOTAL, labels=labels)


@pytest.fixture(autouse=True)
def mock_settings_env_vars(tmp_path, monkeypatch):
    """Point ``PROMETHEUS_MULTIPROC_DIR`` at a fresh per-test temp directory."""
    monkeypatch.setenv("PROMETHEUS_MULTIPROC_DIR", str(tmp_path))


@pytest.fixture
def app():
    """Yield the MLflow server Flask app inside an active application context."""
    # Imported lazily, as in the original, so the app is only loaded when requested.
    from mlflow.server import app as flask_app

    with flask_app.app_context():
        yield flask_app


@pytest.fixture
def test_client(app):
    """Yield a Flask test client bound to the ``app`` fixture."""
    with app.test_client() as client:
        yield client


def test_metrics(app, test_client):
    """Check that the exporter counts ordinary GET requests by status code.

    The ``/metrics``, ``/health`` and ``/version`` endpoints must not
    increment the counter.
    """
    metrics = activate_prometheus_exporter(app)

    # A successful request is recorded under status "200".
    ok_labels = {"method": "GET", "status": "200"}
    assert _request_count(metrics, ok_labels) is None
    response = test_client.get("/")
    assert response.status_code == 200
    assert _request_count(metrics, ok_labels) == 1

    # Excluded endpoints succeed but leave the 200 counter untouched.
    for endpoint in ("/metrics", "/health", "/version"):
        response = test_client.get(endpoint)
        assert response.status_code == 200
        assert _request_count(metrics, ok_labels) == 1

    # A request for an unknown route is recorded separately under status "404".
    not_found_labels = {"method": "GET", "status": "404"}
    assert _request_count(metrics, not_found_labels) is None
    response = test_client.get("/non-existent-endpoint")
    assert response.status_code == 404
    assert _request_count(metrics, not_found_labels) == 1