/ tests / server / test_security.py
test_security.py
  1  import pytest
  2  from fastapi import FastAPI
  3  from flask import Flask
  4  from starlette.testclient import TestClient
  5  from werkzeug.test import Client
  6  
  7  from mlflow.server import security
  8  from mlflow.server.fastapi_security import init_fastapi_security
  9  from mlflow.server.security_utils import is_allowed_host_header, is_api_endpoint
 10  
 11  
 12  def test_default_allowed_hosts():
 13      hosts = security.get_allowed_hosts()
 14      assert "localhost" in hosts
 15      assert "127.0.0.1" in hosts
 16      assert "[::1]" in hosts
 17      assert "localhost:*" in hosts
 18      assert "127.0.0.1:*" in hosts
 19      assert "[[]::1]:*" in hosts
 20      assert "192.168.*" in hosts
 21      assert "10.*" in hosts
 22  
 23  
 24  def test_custom_allowed_hosts(monkeypatch: pytest.MonkeyPatch):
 25      monkeypatch.setenv("MLFLOW_SERVER_ALLOWED_HOSTS", "example.com,app.example.com")
 26      hosts = security.get_allowed_hosts()
 27      assert "example.com" in hosts
 28      assert "app.example.com" in hosts
 29  
 30  
 31  @pytest.mark.parametrize(
 32      ("host_header", "expected_status", "expected_error"),
 33      [
 34          ("localhost", 200, None),
 35          ("127.0.0.1", 200, None),
 36          ("evil.attacker.com", 403, b"Invalid Host header"),
 37      ],
 38  )
 39  def test_dns_rebinding_protection(
 40      test_app, host_header, expected_status, expected_error, monkeypatch: pytest.MonkeyPatch
 41  ):
 42      monkeypatch.setenv("MLFLOW_SERVER_ALLOWED_HOSTS", "localhost,127.0.0.1")
 43      security.init_security_middleware(test_app)
 44      client = Client(test_app)
 45  
 46      response = client.get("/test", headers={"Host": host_header})
 47      assert response.status_code == expected_status
 48      if expected_error:
 49          assert expected_error in response.data
 50  
 51  
 52  @pytest.mark.parametrize(
 53      ("method", "origin", "expected_status", "expected_cors_header"),
 54      [
 55          ("POST", "http://localhost:3000", 200, "http://localhost:3000"),
 56          ("POST", "http://evil.com", 403, None),
 57          ("POST", None, 200, None),
 58          ("GET", "http://evil.com", 200, None),
 59      ],
 60  )
 61  def test_cors_protection(
 62      test_app, method, origin, expected_status, expected_cors_header, monkeypatch: pytest.MonkeyPatch
 63  ):
 64      monkeypatch.setenv(
 65          "MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "http://localhost:3000,https://app.example.com"
 66      )
 67      security.init_security_middleware(test_app)
 68      client = Client(test_app)
 69  
 70      headers = {"Origin": origin} if origin else {}
 71      response = getattr(client, method.lower())("/api/2.0/mlflow/experiments/list", headers=headers)
 72      assert response.status_code == expected_status
 73  
 74      if expected_cors_header:
 75          assert response.headers.get("Access-Control-Allow-Origin") == expected_cors_header
 76  
 77  
 78  def test_insecure_cors_mode(test_app, monkeypatch: pytest.MonkeyPatch):
 79      monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "*")
 80      security.init_security_middleware(test_app)
 81      client = Client(test_app)
 82  
 83      response = client.post(
 84          "/api/2.0/mlflow/experiments/list", headers={"Origin": "http://evil.com"}
 85      )
 86      assert response.status_code == 200
 87      assert response.headers.get("Access-Control-Allow-Origin") == "http://evil.com"
 88  
 89  
 90  @pytest.mark.parametrize(
 91      ("origin", "expected_cors_header"),
 92      [
 93          ("http://localhost:3000", "http://localhost:3000"),
 94          ("http://evil.com", None),
 95      ],
 96  )
 97  def test_preflight_options_request(
 98      test_app, origin, expected_cors_header, monkeypatch: pytest.MonkeyPatch
 99  ):
100      monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", "http://localhost:3000")
101      security.init_security_middleware(test_app)
102      client = Client(test_app)
103  
104      response = client.options(
105          "/api/2.0/mlflow/experiments/list",
106          headers={
107              "Origin": origin,
108              "Access-Control-Request-Method": "POST",
109              "Access-Control-Request-Headers": "Content-Type",
110          },
111      )
112      assert response.status_code == 204
113  
114      if expected_cors_header:
115          assert response.headers.get("Access-Control-Allow-Origin") == expected_cors_header
116  
117  
def test_security_headers(test_app):
    """``GET /test`` carries the expected ``X-Content-Type-Options`` and ``X-Frame-Options``."""
    security.init_security_middleware(test_app)

    resp_headers = Client(test_app).get("/test").headers

    expected_headers = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "SAMEORIGIN",
    }
    for name, value in expected_headers.items():
        assert resp_headers.get(name) == value
125  
126  
def test_disable_security_middleware(test_app, monkeypatch: pytest.MonkeyPatch):
    """With ``MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE=true`` no protections are applied.

    Neither security header is added, and a request with Host ``evil.com`` returns 200.
    """
    monkeypatch.setenv("MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE", "true")
    security.init_security_middleware(test_app)
    http = Client(test_app)

    plain_resp = http.get("/test")
    for header_name in ("X-Content-Type-Options", "X-Frame-Options"):
        assert header_name not in plain_resp.headers

    spoofed_resp = http.get("/test", headers={"Host": "evil.com"})
    assert spoofed_resp.status_code == 200
138  
139  
def test_x_frame_options_configuration(monkeypatch: pytest.MonkeyPatch):
    """``MLFLOW_SERVER_X_FRAME_OPTIONS`` controls the ``X-Frame-Options`` response header.

    ``DENY`` is sent as-is, while ``NONE`` results in no header at all.
    """

    def fetch_headers(setting):
        # Each setting gets a fresh Flask app with its own middleware installation.
        flask_app = Flask(__name__)

        @flask_app.route("/test")
        def ok_view():
            return "OK"

        monkeypatch.setenv("MLFLOW_SERVER_X_FRAME_OPTIONS", setting)
        security.init_security_middleware(flask_app)
        return Client(flask_app).get("/test").headers

    deny_headers = fetch_headers("DENY")
    assert deny_headers.get("X-Frame-Options") == "DENY"

    none_headers = fetch_headers("NONE")
    assert "X-Frame-Options" not in none_headers
165  
166  
def test_notebook_trace_renderer_skips_x_frame_options(monkeypatch: pytest.MonkeyPatch):
    """Routes under ``TRACE_RENDERER_ASSET_PATH`` are served without ``X-Frame-Options``.

    The header is configured as ``DENY``; a route outside the renderer path must still
    receive it.
    """
    from mlflow.tracing.constant import TRACE_RENDERER_ASSET_PATH

    renderer_index_url = f"{TRACE_RENDERER_ASSET_PATH}/index.html"
    renderer_script_url = f"{TRACE_RENDERER_ASSET_PATH}/js/main.js"
    unrelated_url = "/static-files/other-page.html"

    app = Flask(__name__)

    @app.route(renderer_index_url)
    def notebook_renderer():
        return "<html>trace renderer</html>"

    @app.route(renderer_script_url)
    def notebook_renderer_js():
        return "console.log('trace renderer');"

    @app.route(unrelated_url)
    def other_page():
        return "<html>other page</html>"

    # DENY makes it observable whether the header was skipped or applied.
    monkeypatch.setenv("MLFLOW_SERVER_X_FRAME_OPTIONS", "DENY")
    security.init_security_middleware(app)
    http = Client(app)

    for asset_url in (renderer_index_url, renderer_script_url):
        asset_resp = http.get(asset_url)
        assert asset_resp.status_code == 200
        assert "X-Frame-Options" not in asset_resp.headers

    other_resp = http.get(unrelated_url)
    assert other_resp.status_code == 200
    assert other_resp.headers.get("X-Frame-Options") == "DENY"
200  
201  
@pytest.mark.parametrize(
    ("allowed_hosts", "host_header", "expected_status"),
    [
        ("*", "any.domain.com", 200),
        ("*.example.com", "app.example.com", 200),
        ("*.example.com", "sub.app.example.com", 200),
        ("*.example.com", "evil.com", 403),
    ],
)
def test_wildcard_hosts(
    test_app, allowed_hosts, host_header, expected_status, monkeypatch: pytest.MonkeyPatch
):
    """Check the ``GET /test`` status for a Host header against a wildcard allow-list."""
    monkeypatch.setenv("MLFLOW_SERVER_ALLOWED_HOSTS", allowed_hosts)
    security.init_security_middleware(test_app)

    resp = Client(test_app).get("/test", headers={"Host": host_header})

    assert resp.status_code == expected_status
220  
221  
@pytest.mark.parametrize(
    ("allowed_origins", "origin", "expected_status"),
    [
        ("*", "http://any.domain.com", 200),
        ("http://*.example.com", "http://app.example.com", 200),
        ("http://*.example.com", "http://sub.app.example.com", 200),
        ("http://*.example.com", "http://evil.com", 403),
    ],
)
def test_wildcard_origins(
    test_app, allowed_origins, origin, expected_status, monkeypatch: pytest.MonkeyPatch
):
    """Check the POST status for an Origin header against a wildcard origin allow-list."""
    monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", allowed_origins)
    security.init_security_middleware(test_app)

    request_headers = {"Origin": origin}
    resp = Client(test_app).post("/api/2.0/mlflow/experiments/list", headers=request_headers)

    assert resp.status_code == expected_status
240  
241  
@pytest.mark.parametrize(
    ("endpoint", "host_header", "expected_status"),
    [
        ("/health", "evil.com", 200),
        ("/test", "evil.com", 403),
    ],
)
def test_endpoint_security_bypass(
    test_app, endpoint, host_header, expected_status, monkeypatch: pytest.MonkeyPatch
):
    """With only ``localhost`` allowed, ``/health`` answers 200 for a foreign Host.

    ``/test`` is rejected with 403 for the same Host header.
    """
    monkeypatch.setenv("MLFLOW_SERVER_ALLOWED_HOSTS", "localhost")
    security.init_security_middleware(test_app)

    resp = Client(test_app).get(endpoint, headers={"Host": host_header})

    assert resp.status_code == expected_status
258  
259  
@pytest.mark.parametrize(
    ("hostname", "expected_valid"),
    [
        ("192.168.1.1", True),
        ("10.0.0.1", True),
        ("172.16.0.1", True),
        ("127.0.0.1", True),
        ("localhost", True),
        ("[::1]", True),
        ("192.168.1.1:8080", True),
        ("[::1]:8080", True),
        ("evil.com", False),
    ],
)
def test_host_validation(hostname, expected_valid):
    """Check ``is_allowed_host_header`` against the list from ``get_allowed_hosts()``."""
    allowed = security.get_allowed_hosts()
    verdict = is_allowed_host_header(allowed, hostname)
    assert verdict == expected_valid
277  
278  
@pytest.mark.parametrize(
    ("env_var", "env_value", "expected_result"),
    [
        (
            "MLFLOW_SERVER_CORS_ALLOWED_ORIGINS",
            "http://app1.com,http://app2.com",
            ["http://app1.com", "http://app2.com"],
        ),
        ("MLFLOW_SERVER_ALLOWED_HOSTS", "app1.com,app2.com:8080", ["app1.com", "app2.com:8080"]),
    ],
)
def test_environment_variable_configuration(
    env_var, env_value, expected_result, monkeypatch: pytest.MonkeyPatch
):
    """Comma-separated values set via the environment appear in the matching getter.

    Args:
        env_var: Name of the environment variable to set.
        env_value: Comma-separated value assigned to ``env_var``.
        expected_result: Entries that must be present in the getter's result.
        monkeypatch: Pytest fixture used to set the environment variable.
    """
    monkeypatch.setenv(env_var, env_value)

    # Pick the getter once so the assertion loop is not repeated per branch.
    read_setting = (
        security.get_allowed_origins if "ORIGINS" in env_var else security.get_allowed_hosts
    )
    result = read_setting()

    for expected in expected_result:
        assert expected in result
302  
303  
@pytest.mark.parametrize(
    ("path", "expected"),
    [
        ("/api/2.0/mlflow/experiments/list", True),
        ("/ajax-api/2.0/mlflow/experiments/list", True),
        ("/ajax-api/3.0/mlflow/runs/search", True),
        ("/api/test", False),
        ("/test", False),
        ("/health", False),
        ("/static/index.html", False),
    ],
)
def test_is_api_endpoint(path, expected):
    """Check the ``is_api_endpoint`` classification for each parametrized path."""
    classified = is_api_endpoint(path)
    assert classified == expected
318  
319  
@pytest.mark.parametrize(
    ("origin", "expect_cors_header"),
    [
        ("http://localhost:3000", True),
        ("http://127.0.0.1:5000", True),
        ("http://[::1]:8080", True),
        ("http://evil.com", False),
    ],
)
def test_fastapi_cors_allows_localhost_origins(fastapi_client, origin, expect_cors_header):
    """GET the experiments-list API and check whether the origin is echoed back.

    When ``expect_cors_header`` is true, ``access-control-allow-origin`` must equal
    the request origin; otherwise the header must be absent.
    """
    request_headers = {"Host": "localhost", "Origin": origin}
    resp = fastapi_client.get("/api/2.0/mlflow/experiments/list", headers=request_headers)

    allow_origin = resp.headers.get("access-control-allow-origin")
    if not expect_cors_header:
        assert allow_origin is None
        return
    assert allow_origin == origin
337  
338  
def test_fastapi_cors_allows_configured_origin(monkeypatch: pytest.MonkeyPatch):
    """An origin from ``MLFLOW_SERVER_CORS_ALLOWED_ORIGINS`` is echoed by the FastAPI app.

    A request from ``http://evil.com`` must get no ``access-control-allow-origin`` header.
    """
    trusted_origin = "https://trusted.com"
    api_path = "/api/2.0/mlflow/experiments/list"
    monkeypatch.setenv("MLFLOW_SERVER_CORS_ALLOWED_ORIGINS", trusted_origin)

    app = FastAPI()

    @app.api_route(api_path, methods=["GET", "POST", "OPTIONS"])
    async def api_endpoint():
        return {"ok": True}

    init_fastapi_security(app)
    http = TestClient(app, raise_server_exceptions=False)

    def allow_origin_for(request_origin):
        # Both requests use Host "localhost" and differ only in the Origin header.
        reply = http.get(api_path, headers={"Host": "localhost", "Origin": request_origin})
        return reply.headers.get("access-control-allow-origin")

    assert allow_origin_for(trusted_origin) == "https://trusted.com"
    assert allow_origin_for("http://evil.com") is None