# test_security.py
import pytest
from fastapi import FastAPI
from flask import Flask
from starlette.testclient import TestClient
from werkzeug.test import Client

from mlflow.server import security
from mlflow.server.fastapi_security import init_fastapi_security
from mlflow.server.security_utils import is_allowed_host_header, is_api_endpoint

# Environment variables read by the server security middleware in this module's tests.
# They are cleared before every test so that results never depend on the caller's shell.
_SECURITY_ENV_VARS = (
    "MLFLOW_SERVER_ALLOWED_HOSTS",
    "MLFLOW_SERVER_CORS_ALLOWED_ORIGINS",
    "MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE",
    "MLFLOW_SERVER_X_FRAME_OPTIONS",
)


@pytest.fixture(autouse=True)
def _isolate_security_env(monkeypatch: pytest.MonkeyPatch):
    """Unset ambient security configuration before each test.

    Tests that need a specific configuration set it explicitly with ``monkeypatch.setenv``;
    tests that check defaults (e.g. ``test_default_allowed_hosts``) then really see defaults.
    """
    for name in _SECURITY_ENV_VARS:
        monkeypatch.delenv(name, raising=False)


def test_default_allowed_hosts():
    """The default Host allow-list covers loopback and private-network addresses."""
    hosts = security.get_allowed_hosts()
    assert "localhost" in hosts
    assert "127.0.0.1" in hosts
    assert "[::1]" in hosts
    assert "localhost:*" in hosts
    assert "127.0.0.1:*" in hosts
    # NOTE(review): presumably the glob-escaped form of "[::1]:*" ("[[]" matches a
    # literal "["); confirm against the matcher in security_utils.
    assert "[[]::1]:*" in hosts
    assert "192.168.*" in hosts
    assert "10.*" in hosts


def test_custom_allowed_hosts(monkeypatch: pytest.MonkeyPatch):
    """Hosts listed in MLFLOW_SERVER_ALLOWED_HOSTS end up in the allow-list."""
    monkeypatch.setenv("MLFLOW_SERVER_ALLOWED_HOSTS", "example.com,app.example.com")
    hosts = security.get_allowed_hosts()
    assert "example.com" in hosts
    assert "app.example.com" in hosts


@pytest.mark.parametrize(
    ("host_header", "expected_status", "expected_error"),
    [
        ("localhost", 200, None),
        ("127.0.0.1", 200, None),
        ("evil.attacker.com", 403, b"Invalid Host header"),
    ],
)
def test_dns_rebinding_protection(
    test_app, host_header, expected_status, expected_error, monkeypatch: pytest.MonkeyPatch
):
    """Requests whose Host header is not allow-listed are rejected with 403."""
    monkeypatch.setenv("MLFLOW_SERVER_ALLOWED_HOSTS", "localhost,127.0.0.1")
    security.init_security_middleware(test_app)
    client = Client(test_app)

    response = client.get("/test", headers={"Host": host_header})
    assert response.status_code == expected_status
    if expected_error:
        assert expected_error in response.data


@pytest.mark.parametrize(
    ("method", "origin", "expected_status", "expected_cors_header"),
    [
        ("POST", "http://localhost:3000", 200, "http://localhost:3000"),
        ("POST", "http://evil.com", 403, None),
        ("POST", None, 200, None),
        ("GET", "http://evil.com", 200, None),
    ],
)
def test_cors_protection(
    test_app, method, origin, expected_status, expected_cors_header, monkeypatch: pytest.MonkeyPatch
):
    """Cross-origin POSTs from unlisted origins are blocked; GETs and origin-less requests pass."""
    monkeypatch.setenv(
        "MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://app.example.com"
    )
    security.init_security_middleware(test_app)
    client = Client(test_app)

    headers = {"Origin": origin} if origin else {}
    response = getattr(client, method.lower())("/api/2.0/mlflow/experiments/list", headers=headers)
    assert response.status_code == expected_status

    # expected_cors_header=None means the CORS header is not checked for that case.
    if expected_cors_header:
        assert response.headers.get("Access-Control-Allow-Origin") == expected_cors_header


def test_insecure_cors_mode(test_app, monkeypatch: pytest.MonkeyPatch):
    """A ``*`` origin allow-list accepts any origin and echoes it back."""
    monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "*")
    security.init_security_middleware(test_app)
    client = Client(test_app)

    response = client.post(
        "/api/2.0/mlflow/experiments/list", headers={"Origin": "http://evil.com"}
    )
    assert response.status_code == 200
    assert response.headers.get("Access-Control-Allow-Origin") == "http://evil.com"


@pytest.mark.parametrize(
    ("origin", "expected_cors_header"),
    [
        ("http://localhost:3000", "http://localhost:3000"),
        ("http://evil.com", None),
    ],
)
def test_preflight_options_request(
    test_app, origin, expected_cors_header, monkeypatch: pytest.MonkeyPatch
):
    """Preflight OPTIONS requests to API endpoints answer 204 for any origin."""
    monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "http://localhost:3000")
    security.init_security_middleware(test_app)
    client = Client(test_app)

    response = client.options(
        "/api/2.0/mlflow/experiments/list",
        headers={
            "Origin": origin,
            "Access-Control-Request-Method": "POST",
            "Access-Control-Request-Headers": "Content-Type",
        },
    )
    assert response.status_code == 204

    # expected_cors_header=None means the CORS header is not checked for that case.
    if expected_cors_header:
        assert response.headers.get("Access-Control-Allow-Origin") == expected_cors_header


def test_security_headers(test_app):
    """With default settings, responses carry nosniff and SAMEORIGIN framing headers."""
    security.init_security_middleware(test_app)
    client = Client(test_app)

    response = client.get("/test")
    assert response.headers.get("X-Content-Type-Options") == "nosniff"
    assert response.headers.get("X-Frame-Options") == "SAMEORIGIN"


def test_disable_security_middleware(test_app, monkeypatch: pytest.MonkeyPatch):
    """The disable flag skips both security headers and Host validation."""
    monkeypatch.setenv("MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE", "true")
    security.init_security_middleware(test_app)
    client = Client(test_app)

    response = client.get("/test")
    assert "X-Content-Type-Options" not in response.headers
    assert "X-Frame-Options" not in response.headers

    response = client.get("/test", headers={"Host": "evil.com"})
    assert response.status_code == 200


@pytest.mark.parametrize(
    ("x_frame_options", "expected_header"),
    [
        ("DENY", "DENY"),
        ("NONE", None),
    ],
)
def test_x_frame_options_configuration(
    x_frame_options, expected_header, monkeypatch: pytest.MonkeyPatch
):
    """MLFLOW_SERVER_X_FRAME_OPTIONS sets the header value; ``NONE`` omits the header."""
    monkeypatch.setenv("MLFLOW_SERVER_X_FRAME_OPTIONS", x_frame_options)

    # A fresh app per case, so middleware configured for one value cannot affect the other.
    app = Flask(__name__)

    @app.route("/test")
    def index():
        return "OK"

    security.init_security_middleware(app)
    response = Client(app).get("/test")
    # Headers.get returns None when absent, so this also covers the "omitted" case.
    assert response.headers.get("X-Frame-Options") == expected_header


def test_notebook_trace_renderer_skips_x_frame_options(monkeypatch: pytest.MonkeyPatch):
    """Trace renderer assets omit X-Frame-Options even under DENY; other pages keep it."""
    from mlflow.tracing.constant import TRACE_RENDERER_ASSET_PATH

    app = Flask(__name__)

    @app.route(f"{TRACE_RENDERER_ASSET_PATH}/index.html")
    def notebook_renderer():
        return "<html>trace renderer</html>"

    @app.route(f"{TRACE_RENDERER_ASSET_PATH}/js/main.js")
    def notebook_renderer_js():
        return "console.log('trace renderer');"

    @app.route("/static-files/other-page.html")
    def other_page():
        return "<html>other page</html>"

    # Set X-Frame-Options to DENY to test that it's skipped for the notebook renderer.
    monkeypatch.setenv("MLFLOW_SERVER_X_FRAME_OPTIONS", "DENY")
    security.init_security_middleware(app)
    client = Client(app)

    response = client.get(f"{TRACE_RENDERER_ASSET_PATH}/index.html")
    assert response.status_code == 200
    assert "X-Frame-Options" not in response.headers

    response = client.get(f"{TRACE_RENDERER_ASSET_PATH}/js/main.js")
    assert response.status_code == 200
    assert "X-Frame-Options" not in response.headers

    response = client.get("/static-files/other-page.html")
    assert response.status_code == 200
    assert response.headers.get("X-Frame-Options") == "DENY"


@pytest.mark.parametrize(
    ("allowed_hosts", "host_header", "expected_status"),
    [
        ("*", "any.domain.com", 200),
        ("*.example.com", "app.example.com", 200),
        ("*.example.com", "sub.app.example.com", 200),
        ("*.example.com", "evil.com", 403),
    ],
)
def test_wildcard_hosts(
    test_app, allowed_hosts, host_header, expected_status, monkeypatch: pytest.MonkeyPatch
):
    """Host allow-list entries support ``*`` wildcards, including nested subdomains."""
    monkeypatch.setenv("MLFLOW_SERVER_ALLOWED_HOSTS", allowed_hosts)
    security.init_security_middleware(test_app)
    client = Client(test_app)

    response = client.get("/test", headers={"Host": host_header})
    assert response.status_code == expected_status


@pytest.mark.parametrize(
    ("allowed_origins", "origin", "expected_status"),
    [
        ("*", "http://any.domain.com", 200),
        ("http://*.example.com", "http://app.example.com", 200),
        ("http://*.example.com", "http://sub.app.example.com", 200),
        ("http://*.example.com", "http://evil.com", 403),
    ],
)
def test_wildcard_origins(
    test_app, allowed_origins, origin, expected_status, monkeypatch: pytest.MonkeyPatch
):
    """CORS origin allow-list entries support ``*`` wildcards, including nested subdomains."""
    monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", allowed_origins)
    security.init_security_middleware(test_app)
    client = Client(test_app)

    response = client.post("/api/2.0/mlflow/experiments/list", headers={"Origin": origin})
    assert response.status_code == expected_status


@pytest.mark.parametrize(
    ("endpoint", "host_header", "expected_status"),
    [
        ("/health", "evil.com", 200),
        ("/test", "evil.com", 403),
    ],
)
def test_endpoint_security_bypass(
    test_app, endpoint, host_header, expected_status, monkeypatch: pytest.MonkeyPatch
):
    """``/health`` is exempt from Host validation; other routes are not."""
    monkeypatch.setenv("MLFLOW_SERVER_ALLOWED_HOSTS", "localhost")
    security.init_security_middleware(test_app)
    client = Client(test_app)

    response = client.get(endpoint, headers={"Host": host_header})
    assert response.status_code == expected_status


@pytest.mark.parametrize(
    ("hostname", "expected_valid"),
    [
        ("192.168.1.1", True),
        ("10.0.0.1", True),
        ("172.16.0.1", True),
        ("127.0.0.1", True),
        ("localhost", True),
        ("[::1]", True),
        ("192.168.1.1:8080", True),
        ("[::1]:8080", True),
        ("evil.com", False),
    ],
)
def test_host_validation(hostname, expected_valid):
    """The default allow-list accepts loopback/private hosts (with or without port) only."""
    hosts = security.get_allowed_hosts()
    assert is_allowed_host_header(hosts, hostname) == expected_valid


@pytest.mark.parametrize(
    ("env_var", "env_value", "expected_result"),
    [
        (
            "MLFLOW_SERVER_CORS_ALLOWED_ORIGINS",
            "http://app1.com,http://app2.com",
            ["http://app1.com", "http://app2.com"],
        ),
        ("MLFLOW_SERVER_ALLOWED_HOSTS", "app1.com,app2.com:8080", ["app1.com", "app2.com:8080"]),
    ],
)
def test_environment_variable_configuration(
    env_var, env_value, expected_result, monkeypatch: pytest.MonkeyPatch
):
    """Comma-separated env var values are split into the matching allow-list."""
    monkeypatch.setenv(env_var, env_value)
    getter = security.get_allowed_origins if "ORIGINS" in env_var else security.get_allowed_hosts
    result = getter()
    for expected in expected_result:
        assert expected in result


@pytest.mark.parametrize(
    ("path", "expected"),
    [
        ("/api/2.0/mlflow/experiments/list", True),
        ("/ajax-api/2.0/mlflow/experiments/list", True),
        ("/ajax-api/3.0/mlflow/runs/search", True),
        ("/api/test", False),
        ("/test", False),
        ("/health", False),
        ("/static/index.html", False),
    ],
)
def test_is_api_endpoint(path, expected):
    """MLflow REST paths under /api or /ajax-api are API endpoints; others (even /api/test) are not."""
    assert is_api_endpoint(path) == expected


@pytest.mark.parametrize(
    ("origin", "expect_cors_header"),
    [
        ("http://localhost:3000", True),
        ("http://127.0.0.1:5000", True),
        ("http://[::1]:8080", True),
        ("http://evil.com", False),
    ],
)
def test_fastapi_cors_allows_localhost_origins(fastapi_client, origin, expect_cors_header):
    """The FastAPI middleware echoes loopback origins and omits the header for others."""
    response = fastapi_client.get(
        "/api/2.0/mlflow/experiments/list", headers={"Host": "localhost", "Origin": origin}
    )
    if expect_cors_header:
        assert response.headers.get("access-control-allow-origin") == origin
    else:
        assert response.headers.get("access-control-allow-origin") is None


def test_fastapi_cors_allows_configured_origin(monkeypatch: pytest.MonkeyPatch):
    """Origins from MLFLOW_SERVER_CORS_ALLOWED_ORIGINS are allowed by the FastAPI middleware."""
    monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "https://trusted.com")

    app = FastAPI()

    @app.api_route("/api/2.0/mlflow/experiments/list", methods=["GET", "POST", "OPTIONS"])
    async def api_endpoint():
        return {"ok": True}

    init_fastapi_security(app)
    client = TestClient(app, raise_server_exceptions=False)

    response = client.get(
        "/api/2.0/mlflow/experiments/list",
        headers={"Host": "localhost", "Origin": "https://trusted.com"},
    )
    assert response.headers.get("access-control-allow-origin") == "https://trusted.com"

    response = client.get(
        "/api/2.0/mlflow/experiments/list",
        headers={"Host": "localhost", "Origin": "http://evil.com"},
    )
    assert response.headers.get("access-control-allow-origin") is None