# test_workspace_middleware.py
"""Tests for MLflow's workspace request handling.

Covers the Flask before/teardown request handlers, the FastAPI workspace
middleware, the ``server-info`` endpoint, and workspace propagation from the
FastAPI middleware into a Flask app mounted through ``WSGIMiddleware``.
"""

from __future__ import annotations

import pytest
import werkzeug
from fastapi import FastAPI
from fastapi.testclient import TestClient
from flask import Flask

from mlflow.entities import Workspace
from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES
from mlflow.exceptions import MlflowException
from mlflow.server import app as flask_app
from mlflow.server.fastapi_app import add_fastapi_workspace_middleware
from mlflow.server.job_api import job_api_router
from mlflow.server.workspace_helpers import (
    WORKSPACE_HEADER_NAME,
    workspace_before_request_handler,
    workspace_teardown_request_handler,
)
from mlflow.utils import workspace_context


@pytest.fixture(autouse=True)
def _werkzeug_version_shim(monkeypatch):
    """Ensure ``werkzeug.__version__`` exists for every test in this module.

    NOTE(review): presumably needed because some Flask test-client versions
    read ``werkzeug.__version__`` while newer Werkzeug releases no longer
    define it -- confirm.

    Two choices here are deliberate:
    - The shim applies to every test, not only selected ones. This keeps the
      ``flask_app.test_client()`` tests independent of test order.
    - It is set through ``monkeypatch`` (with ``raising=False`` because the
      attribute may be absent). The placeholder is therefore removed at
      teardown instead of staying on the module for the rest of the session.
    """
    if not hasattr(werkzeug, "__version__"):
        monkeypatch.setattr(werkzeug, "__version__", "tests", raising=False)


def _build_flask_workspace_app(route):
    """Return a bare Flask app wired with the workspace request handlers.

    Args:
        route: URL rule for the only view. The view returns the current request
            workspace, or ``"none"`` when no workspace is set.
    """
    app = Flask(__name__)
    app.before_request(workspace_before_request_handler)
    app.teardown_request(workspace_teardown_request_handler)

    @app.route(route)
    def _ping():
        return workspace_context.get_request_workspace() or "none"

    return app


def _patch_workspace_store(monkeypatch, store):
    """Make ``workspace_helpers._get_workspace_store`` return ``store`` for this test."""
    monkeypatch.setattr(
        "mlflow.server.workspace_helpers._get_workspace_store",
        lambda workspace_uri=None, tracking_uri=None: store,
    )


@pytest.fixture
def flask_workspace_app(monkeypatch):
    """Flask app with workspaces enabled and a ``/ping`` route echoing the workspace."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    return _build_flask_workspace_app("/ping")


def test_flask_workspace_middleware_sets_context(flask_workspace_app, monkeypatch):
    """The workspace header is visible inside the view and cleared after the request."""

    class DummyWorkspaceStore:
        def get_workspace(self, name):
            return Workspace(name=name)

    _patch_workspace_store(monkeypatch, DummyWorkspaceStore())

    client = flask_workspace_app.test_client()
    resp = client.get("/ping", headers={WORKSPACE_HEADER_NAME: "team-a"})
    assert resp.data.decode() == "team-a"
    assert workspace_context.get_request_workspace() is None


def test_flask_workspace_middleware_requires_header(flask_workspace_app, monkeypatch):
    """Without a header, an error from ``get_default_workspace`` surfaces as HTTP 400."""

    class DefaultlessWorkspaceStore:
        def get_default_workspace(self):
            raise MlflowException.invalid_parameter_value("Active workspace is required.")

    _patch_workspace_store(monkeypatch, DefaultlessWorkspaceStore())

    client = flask_workspace_app.test_client()
    resp = client.get("/ping")
    assert resp.status_code == 400
    assert "Active workspace is required" in resp.json["message"]
    assert workspace_context.get_request_workspace() is None


def _fastapi_workspace_app(monkeypatch):
    """Build a FastAPI app with the workspace middleware and a ping route.

    The ping route lives under ``job_api_router.prefix`` and returns the current
    request workspace as JSON.

    Returns:
        A ``(app, ping_path)`` tuple.
    """
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    app = FastAPI()
    add_fastapi_workspace_middleware(app)

    ping_path = f"{job_api_router.prefix}/ping"

    @app.get(ping_path)
    async def ping():
        return {"workspace": workspace_context.get_request_workspace()}

    return app, ping_path


def test_fastapi_workspace_middleware_sets_context(monkeypatch):
    """The FastAPI middleware exposes the resolved workspace to the endpoint."""
    app, ping_path = _fastapi_workspace_app(monkeypatch)
    monkeypatch.setattr(
        "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
        lambda _path, header: Workspace(name=header),
    )

    client = TestClient(app)
    resp = client.get(ping_path, headers={WORKSPACE_HEADER_NAME: "team-fast"})
    assert resp.status_code == 200
    assert resp.json() == {"workspace": "team-fast"}
    assert workspace_context.get_request_workspace() is None


def test_fastapi_workspace_middleware_handles_missing_header(monkeypatch):
    """When the resolver returns ``None``, the endpoint sees no workspace."""
    app, ping_path = _fastapi_workspace_app(monkeypatch)
    monkeypatch.setattr(
        "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
        lambda _path, _header: None,
    )

    client = TestClient(app)
    resp = client.get(ping_path)
    assert resp.status_code == 200
    assert resp.json() == {"workspace": None}
    assert workspace_context.get_request_workspace() is None


def test_server_info_workspaces_enabled(monkeypatch):
    """``server-info`` reports the current value of ``MLFLOW_ENABLE_WORKSPACES``."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    client = flask_app.test_client()
    resp = client.get("/api/3.0/mlflow/server-info")
    assert resp.status_code == 200
    data = resp.get_json()
    assert data["workspaces_enabled"] is True

    # Disable workspaces and ensure the endpoint reflects the change.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    resp = client.get("/api/3.0/mlflow/server-info")
    assert resp.status_code == 200
    data = resp.get_json()
    assert data["workspaces_enabled"] is False


def test_server_info_skips_workspace_resolution(monkeypatch):
    """``server-info`` succeeds without resolving the workspace header."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")

    # Accept any call shape, so a call fails with this message rather than
    # a TypeError from a mismatched signature.
    def _raise_if_called(*_args, **_kwargs):
        raise AssertionError("workspace resolution should not run for server-info")

    monkeypatch.setattr(
        "mlflow.server.workspace_helpers.resolve_workspace_from_header", _raise_if_called
    )

    client = flask_app.test_client()
    resp = client.get("/api/3.0/mlflow/server-info", headers={WORKSPACE_HEADER_NAME: "missing"})
    assert resp.status_code == 200
    data = resp.get_json()
    assert data["workspaces_enabled"] is True


def test_server_info_with_workspace_header_when_workspaces_disabled(monkeypatch):
    """A workspace header is tolerated by ``server-info`` when workspaces are disabled."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    client = flask_app.test_client()
    resp = client.get(
        "/api/3.0/mlflow/server-info", headers={WORKSPACE_HEADER_NAME: "some-workspace"}
    )
    assert resp.status_code == 200
    data = resp.get_json()
    assert data["workspaces_enabled"] is False


def test_fastapi_wsgi_flask_workspace_propagation(monkeypatch):
    """A Flask app mounted under FastAPI sees the workspace the FastAPI middleware set."""
    from fastapi.middleware.wsgi import WSGIMiddleware

    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")

    # Record every result of is_request_workspace_resolved() during the request.
    resolved_checks = []
    original_is_resolved = workspace_context.is_request_workspace_resolved

    def tracking_is_resolved():
        result = original_is_resolved()
        resolved_checks.append(result)
        return result

    monkeypatch.setattr(
        "mlflow.server.workspace_helpers.workspace_context.is_request_workspace_resolved",
        tracking_is_resolved,
    )

    fastapi_app = FastAPI()
    add_fastapi_workspace_middleware(fastapi_app)
    fastapi_app.mount("/", WSGIMiddleware(_build_flask_workspace_app("/flask-ping")))

    monkeypatch.setattr(
        "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
        lambda _path, header: Workspace(name=header) if header else None,
    )

    client = TestClient(fastapi_app)
    resp = client.get("/flask-ping", headers={WORKSPACE_HEADER_NAME: "team-wsgi"})

    assert resp.status_code == 200
    assert resp.text == "team-wsgi"
    # The resolved state was queried exactly once, and the workspace was
    # already resolved at that point. Presumably the query comes from the
    # Flask before-request handler.
    assert len(resolved_checks) == 1
    assert resolved_checks[0] is True
    assert workspace_context.get_request_workspace() is None