tests/server/test_prometheus_exporter.py
test_prometheus_exporter.py
 1  import pytest
 2  
 3  from mlflow.server.prometheus_exporter import activate_prometheus_exporter
 4  
 5  
 6  @pytest.fixture(autouse=True)
 7  def mock_settings_env_vars(tmp_path, monkeypatch):
 8      monkeypatch.setenv("PROMETHEUS_MULTIPROC_DIR", str(tmp_path))
 9  
10  
@pytest.fixture
def app():
    """Yield the MLflow server's Flask app with an application context pushed."""
    # Imported inside the fixture rather than at module level (presumably so the
    # server module is only loaded when a test actually requests it).
    from mlflow.server import app as flask_app

    with flask_app.app_context():
        yield flask_app
17  
18  
@pytest.fixture
def test_client(app):
    """Yield a Flask test client for the ``app`` fixture, closed after the test."""
    with app.test_client() as client:
        yield client
23  
24  
def test_metrics(app, test_client):
    """Check which requests increment ``mlflow_http_request_total``.

    A successful GET and a 404 GET are each counted under their status label.
    GETs to ``/metrics``, ``/health`` and ``/version`` are expected to leave
    the counter unchanged.
    """
    metrics = activate_prometheus_exporter(app)

    def request_count(labels):
        # Returns None while no sample exists for these labels.
        return metrics.registry.get_sample_value("mlflow_http_request_total", labels=labels)

    # A successful request is counted under status 200.
    ok = {"method": "GET", "status": "200"}
    assert request_count(ok) is None
    response = test_client.get("/")
    assert response.status_code == 200
    assert request_count(ok) == 1

    # The metrics, health and version endpoints must not increment the counter.
    for path in ("/metrics", "/health", "/version"):
        response = test_client.get(path)
        assert response.status_code == 200
        assert request_count(ok) == 1

    # A failed request is counted under its own status label.
    not_found = {"method": "GET", "status": "404"}
    assert request_count(not_found) is None
    response = test_client.get("/non-existent-endpoint")
    assert response.status_code == 404
    assert request_count(not_found) == 1