# test_security_integration.py
import json

import pytest
from werkzeug.test import Client


@pytest.mark.parametrize(
    ("host", "origin", "expected_status", "should_block"),
    [
        ("evil.attacker.com:5000", "http://evil.attacker.com:5000", 403, True),
        ("localhost:5000", None, None, False),
    ],
)
def test_dns_rebinding_and_cors_protection(
    mlflow_app_client, host, origin, expected_status, should_block
):
    """A request with a foreign Host/Origin is rejected; a localhost request is not.

    The blocked case accepts either rejection message, since the request may be
    refused by the Host-header check or by the cross-origin check.
    """
    headers = {"Host": host, "Content-Type": "application/json"}
    if origin:
        headers["Origin"] = origin

    response = mlflow_app_client.post(
        "/api/2.0/mlflow/experiments/search",
        headers=headers,
        data=json.dumps({"order_by": ["creation_time DESC", "name ASC"], "max_results": 50}),
    )

    if should_block:
        assert response.status_code == expected_status
        assert (
            b"Invalid Host header" in response.data
            or b"Cross-origin request blocked" in response.data
        )
    else:
        assert response.status_code != 403


@pytest.mark.parametrize(
    ("origin", "endpoint", "expected_blocked"),
    [
        ("http://malicious-site.com", "/api/2.0/mlflow/experiments/create", True),
        ("http://localhost:3000", "/api/2.0/mlflow/experiments/search", False),
    ],
)
def test_cors_for_state_changing_requests(mlflow_app_client, origin, endpoint, expected_blocked):
    """A cross-origin POST from an untrusted origin is refused with 403.

    A POST carrying the ``http://localhost:3000`` origin must not be refused.
    """
    # Only the create endpoint needs a body with an experiment name.
    payload = {"name": "test-experiment"} if "create" in endpoint else {}
    response = mlflow_app_client.post(
        endpoint,
        headers={"Origin": origin, "Content-Type": "application/json"},
        data=json.dumps(payload),
    )

    if expected_blocked:
        assert response.status_code == 403
        assert b"Cross-origin request blocked" in response.data
    else:
        assert response.status_code != 403


@pytest.mark.parametrize(
    ("origin", "should_block"),
    [
        ("https://trusted-app.com", False),
        ("http://evil.com", True),
    ],
)
def test_cors_with_configured_origins(monkeypatch: pytest.MonkeyPatch, origin, should_block):
    """Origins listed in MLFLOW_SERVER_CORS_ALLOWED_ORIGINS pass; any other origin gets 403.

    Each origin is its own parametrized case, so a failure for one origin is
    reported independently instead of hiding the remaining ones.
    """
    monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "https://trusted-app.com")

    # Imported only after the env var is set. NOTE(review): this ordering matters
    # only if the security module reads the variable at import or middleware-init
    # time — confirm against mlflow.server.security.
    from flask import Flask

    from mlflow.server import handlers, security

    # Build a bare app with every MLflow endpoint plus the security middleware.
    app = Flask(__name__)
    for http_path, handler, methods in handlers.get_endpoints():
        app.add_url_rule(http_path, handler.__name__, handler, methods=methods)

    security.init_security_middleware(app)
    client = Client(app)

    response = client.post(
        "/api/2.0/mlflow/experiments/search",
        headers={"Origin": origin, "Content-Type": "application/json"},
        data=json.dumps({}),
    )

    if should_block:
        assert response.status_code == 403
    else:
        assert response.status_code != 403


def test_security_headers_on_responses(mlflow_app_client):
    """Responses carry the anti-MIME-sniffing and anti-framing headers."""
    response = mlflow_app_client.get("/health")
    assert response.headers.get("X-Content-Type-Options") == "nosniff"
    assert response.headers.get("X-Frame-Options") == "SAMEORIGIN"


@pytest.mark.parametrize(
    ("origin", "expected_status", "should_have_cors"),
    [
        ("http://localhost:3000", 204, True),
        ("http://evil.com", None, False),
    ],
)
def test_preflight_options_requests(mlflow_app_client, origin, expected_status, should_have_cors):
    """A CORS preflight grants access to an allowed origin and withholds it otherwise.

    ``expected_status`` is ``None`` when the test does not constrain the status code.
    """
    response = mlflow_app_client.options(
        "/api/2.0/mlflow/experiments/search",
        headers={
            "Origin": origin,
            "Access-Control-Request-Method": "POST",
            "Access-Control-Request-Headers": "Content-Type",
        },
    )

    if expected_status is not None:
        assert response.status_code == expected_status

    allow_origin = response.headers.get("Access-Control-Allow-Origin")
    if should_have_cors:
        assert allow_origin == origin
        assert "POST" in response.headers.get("Access-Control-Allow-Methods", "")
    else:
        # A missing header (``None``) and a header naming a different origin both
        # fail to grant access to ``origin``; one comparison covers both.
        assert allow_origin != origin