/ tests / server / test_workspace_middleware.py
test_workspace_middleware.py
  1  from __future__ import annotations
  2  
  3  import pytest
  4  import werkzeug
  5  from fastapi import FastAPI
  6  from fastapi.testclient import TestClient
  7  from flask import Flask
  8  
  9  from mlflow.entities import Workspace
 10  from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES
 11  from mlflow.exceptions import MlflowException
 12  from mlflow.server import app as flask_app
 13  from mlflow.server.fastapi_app import add_fastapi_workspace_middleware
 14  from mlflow.server.job_api import job_api_router
 15  from mlflow.server.workspace_helpers import (
 16      WORKSPACE_HEADER_NAME,
 17      workspace_before_request_handler,
 18      workspace_teardown_request_handler,
 19  )
 20  from mlflow.utils import workspace_context
 21  
 22  
 23  @pytest.fixture
 24  def flask_workspace_app(monkeypatch):
 25      monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
 26      if not hasattr(werkzeug, "__version__"):
 27          werkzeug.__version__ = "tests"
 28  
 29      app = Flask(__name__)
 30      app.before_request(workspace_before_request_handler)
 31      app.teardown_request(workspace_teardown_request_handler)
 32  
 33      @app.route("/ping")
 34      def _ping():
 35          return workspace_context.get_request_workspace() or "none"
 36  
 37      return app
 38  
 39  
 40  def test_flask_workspace_middleware_sets_context(flask_workspace_app, monkeypatch):
 41      class DummyWorkspaceStore:
 42          def get_workspace(self, name):
 43              return Workspace(name=name)
 44  
 45      store = DummyWorkspaceStore()
 46      monkeypatch.setattr(
 47          "mlflow.server.workspace_helpers._get_workspace_store",
 48          lambda workspace_uri=None, tracking_uri=None: store,
 49      )
 50  
 51      client = flask_workspace_app.test_client()
 52      resp = client.get("/ping", headers={WORKSPACE_HEADER_NAME: "team-a"})
 53      assert resp.data.decode() == "team-a"
 54      assert workspace_context.get_request_workspace() is None
 55  
 56  
 57  def test_flask_workspace_middleware_requires_header(flask_workspace_app, monkeypatch):
 58      class DefaultlessWorkspaceStore:
 59          def get_default_workspace(self):
 60              raise MlflowException.invalid_parameter_value("Active workspace is required.")
 61  
 62      store = DefaultlessWorkspaceStore()
 63      monkeypatch.setattr(
 64          "mlflow.server.workspace_helpers._get_workspace_store",
 65          lambda workspace_uri=None, tracking_uri=None: store,
 66      )
 67  
 68      client = flask_workspace_app.test_client()
 69      resp = client.get("/ping")
 70      assert resp.status_code == 400
 71      assert "Active workspace is required" in resp.json["message"]
 72      assert workspace_context.get_request_workspace() is None
 73  
 74  
 75  def _fastapi_workspace_app(monkeypatch):
 76      monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
 77      app = FastAPI()
 78      add_fastapi_workspace_middleware(app)
 79  
 80      ping_path = f"{job_api_router.prefix}/ping"
 81  
 82      @app.get(ping_path)
 83      async def ping():
 84          return {"workspace": workspace_context.get_request_workspace()}
 85  
 86      return app, ping_path
 87  
 88  
 89  def test_fastapi_workspace_middleware_sets_context(monkeypatch):
 90      app, ping_path = _fastapi_workspace_app(monkeypatch)
 91      monkeypatch.setattr(
 92          "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
 93          lambda _path, header: Workspace(name=header),
 94      )
 95  
 96      client = TestClient(app)
 97      resp = client.get(ping_path, headers={WORKSPACE_HEADER_NAME: "team-fast"})
 98      assert resp.status_code == 200
 99      assert resp.json() == {"workspace": "team-fast"}
100      assert workspace_context.get_request_workspace() is None
101  
102  
def test_fastapi_workspace_middleware_handles_missing_header(monkeypatch):
    """When no workspace resolves, the route still succeeds and sees ``None``."""
    app, ping_path = _fastapi_workspace_app(monkeypatch)

    def _resolve_nothing(_path, _header):
        return None

    monkeypatch.setattr(
        "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
        _resolve_nothing,
    )

    response = TestClient(app).get(ping_path)

    assert response.status_code == 200
    assert response.json() == {"workspace": None}
    assert workspace_context.get_request_workspace() is None
115  
116  
def test_server_info_workspaces_enabled(monkeypatch):
    """``server-info`` reports ``workspaces_enabled`` live from the environment.

    The same client is reused across both settings, so the flag must be re-read on
    every request rather than fixed at startup.
    """
    http = flask_app.test_client()

    for env_value, expected_flag in (("true", True), ("false", False)):
        monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, env_value)
        response = http.get("/api/3.0/mlflow/server-info")
        assert response.status_code == 200
        body = response.get_json()
        assert body["workspaces_enabled"] is expected_flag
131  
132  
def test_server_info_skips_workspace_resolution(monkeypatch):
    """``server-info`` must answer without resolving the workspace header.

    ``resolve_workspace_from_header`` is replaced with a stub that fails loudly if it
    is invoked. The request, which carries a workspace header, must still succeed.
    """
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")

    def _raise_if_called(*_args, **_kwargs):
        # Accept any call signature so that an unexpected call always surfaces as this
        # assertion, rather than as a TypeError from an argument mismatch.
        raise AssertionError("workspace resolution should not run for server-info")

    monkeypatch.setattr(
        "mlflow.server.workspace_helpers.resolve_workspace_from_header", _raise_if_called
    )

    client = flask_app.test_client()
    resp = client.get("/api/3.0/mlflow/server-info", headers={WORKSPACE_HEADER_NAME: "missing"})
    assert resp.status_code == 200
    data = resp.get_json()
    assert data["workspaces_enabled"] is True
148  
149  
def test_server_info_with_workspace_header_when_workspaces_disabled(monkeypatch):
    """A workspace header is tolerated by ``server-info`` when workspaces are disabled."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")

    request_headers = {WORKSPACE_HEADER_NAME: "some-workspace"}
    response = flask_app.test_client().get(
        "/api/3.0/mlflow/server-info", headers=request_headers
    )

    assert response.status_code == 200
    assert response.get_json()["workspaces_enabled"] is False
159  
160  
def test_fastapi_wsgi_flask_workspace_propagation(monkeypatch):
    """A workspace resolved by the FastAPI middleware reaches a Flask app mounted via WSGI.

    The test wraps ``is_request_workspace_resolved`` (as seen from
    ``mlflow.server.workspace_helpers``) to record every call. It then expects exactly
    one call, reporting that the workspace is already resolved, and expects the Flask
    route to see the header's workspace.
    """
    from fastapi.middleware.wsgi import WSGIMiddleware

    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    # Install the ``werkzeug.__version__`` shim through monkeypatch so it is undone
    # after this test instead of leaking into the rest of the session.
    if not hasattr(werkzeug, "__version__"):
        monkeypatch.setattr(werkzeug, "__version__", "tests", raising=False)

    # Return value of every ``is_request_workspace_resolved`` call made during the request.
    resolved_checks = []

    original_is_resolved = workspace_context.is_request_workspace_resolved

    def tracking_is_resolved():
        result = original_is_resolved()
        resolved_checks.append(result)
        return result

    monkeypatch.setattr(
        "mlflow.server.workspace_helpers.workspace_context.is_request_workspace_resolved",
        tracking_is_resolved,
    )

    test_flask_app = Flask(__name__)
    test_flask_app.before_request(workspace_before_request_handler)
    test_flask_app.teardown_request(workspace_teardown_request_handler)

    @test_flask_app.route("/flask-ping")
    def _flask_ping():
        return workspace_context.get_request_workspace() or "none"

    fastapi_app = FastAPI()
    add_fastapi_workspace_middleware(fastapi_app)
    fastapi_app.mount("/", WSGIMiddleware(test_flask_app))

    monkeypatch.setattr(
        "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
        lambda _path, header: Workspace(name=header) if header else None,
    )

    client = TestClient(fastapi_app)
    resp = client.get("/flask-ping", headers={WORKSPACE_HEADER_NAME: "team-wsgi"})

    assert resp.status_code == 200
    assert resp.text == "team-wsgi"
    # Exactly one check, and it must see the workspace already resolved upstream.
    assert len(resolved_checks) == 1, resolved_checks
    assert resolved_checks[0] is True