/ tests / server / test_security_integration.py
test_security_integration.py
  1  import json
  2  
  3  import pytest
  4  from werkzeug.test import Client
  5  
  6  
  7  @pytest.mark.parametrize(
  8      ("host", "origin", "expected_status", "should_block"),
  9      [
 10          ("evil.attacker.com:5000", "http://evil.attacker.com:5000", 403, True),
 11          ("localhost:5000", None, None, False),
 12      ],
 13  )
 14  def test_dns_rebinding_and_cors_protection(
 15      mlflow_app_client, host, origin, expected_status, should_block
 16  ):
 17      headers = {"Host": host, "Content-Type": "application/json"}
 18      if origin:
 19          headers["Origin"] = origin
 20  
 21      response = mlflow_app_client.post(
 22          "/api/2.0/mlflow/experiments/search",
 23          headers=headers,
 24          data=json.dumps({"order_by": ["creation_time DESC", "name ASC"], "max_results": 50}),
 25      )
 26  
 27      if should_block:
 28          assert response.status_code == expected_status
 29          assert (
 30              b"Invalid Host header" in response.data
 31              or b"Cross-origin request blocked" in response.data
 32          )
 33      else:
 34          assert response.status_code != 403
 35  
 36  
 37  @pytest.mark.parametrize(
 38      ("origin", "endpoint", "expected_blocked"),
 39      [
 40          ("http://malicious-site.com", "/api/2.0/mlflow/experiments/create", True),
 41          ("http://localhost:3000", "/api/2.0/mlflow/experiments/search", False),
 42      ],
 43  )
 44  def test_cors_for_state_changing_requests(mlflow_app_client, origin, endpoint, expected_blocked):
 45      response = mlflow_app_client.post(
 46          endpoint,
 47          headers={"Origin": origin, "Content-Type": "application/json"},
 48          data=json.dumps({"name": "test-experiment"} if "create" in endpoint else {}),
 49      )
 50  
 51      if expected_blocked:
 52          assert response.status_code == 403
 53          assert b"Cross-origin request blocked" in response.data
 54      else:
 55          assert response.status_code != 403
 56  
 57  
 58  def test_cors_with_configured_origins(monkeypatch: pytest.MonkeyPatch):
 59      monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "https://trusted-app.com")
 60  
 61      from flask import Flask
 62  
 63      from mlflow.server import handlers, security
 64  
 65      app = Flask(__name__)
 66      for http_path, handler, methods in handlers.get_endpoints():
 67          app.add_url_rule(http_path, handler.__name__, handler, methods=methods)
 68  
 69      security.init_security_middleware(app)
 70      client = Client(app)
 71  
 72      test_cases = [
 73          ("https://trusted-app.com", False),
 74          ("http://evil.com", True),
 75      ]
 76  
 77      for origin, should_block in test_cases:
 78          response = client.post(
 79              "/api/2.0/mlflow/experiments/search",
 80              headers={"Origin": origin, "Content-Type": "application/json"},
 81              data=json.dumps({}),
 82          )
 83  
 84          if should_block:
 85              assert response.status_code == 403
 86          else:
 87              assert response.status_code != 403
 88  
 89  
 90  def test_security_headers_on_responses(mlflow_app_client):
 91      response = mlflow_app_client.get("/health")
 92      assert response.headers.get("X-Content-Type-Options") == "nosniff"
 93      assert response.headers.get("X-Frame-Options") == "SAMEORIGIN"
 94  
 95  
 96  @pytest.mark.parametrize(
 97      ("origin", "expected_status", "should_have_cors"),
 98      [
 99          ("http://localhost:3000", 204, True),
100          ("http://evil.com", None, False),
101      ],
102  )
103  def test_preflight_options_requests(mlflow_app_client, origin, expected_status, should_have_cors):
104      response = mlflow_app_client.options(
105          "/api/2.0/mlflow/experiments/search",
106          headers={
107              "Origin": origin,
108              "Access-Control-Request-Method": "POST",
109              "Access-Control-Request-Headers": "Content-Type",
110          },
111      )
112  
113      if expected_status:
114          assert response.status_code == expected_status
115  
116      if should_have_cors:
117          assert response.headers.get("Access-Control-Allow-Origin") == origin
118          assert "POST" in response.headers.get("Access-Control-Allow-Methods", "")
119      else:
120          assert (
121              "Access-Control-Allow-Origin" not in response.headers
122              or response.headers.get("Access-Control-Allow-Origin") != origin
123          )