/ tests / server / test_otel_api.py
test_otel_api.py
  1  from unittest import mock
  2  
  3  from fastapi import FastAPI
  4  from fastapi.testclient import TestClient
  5  from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest
  6  
  7  from mlflow.entities import Workspace
  8  from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES
  9  from mlflow.server.fastapi_app import add_fastapi_workspace_middleware
 10  from mlflow.server.otel_api import otel_router
 11  from mlflow.tracing.utils.otlp import OTLP_TRACES_PATH
 12  from mlflow.utils import workspace_context
 13  from mlflow.utils.workspace_utils import WORKSPACE_HEADER_NAME
 14  
 15  
 16  def _build_otlp_payload():
 17      request = ExportTraceServiceRequest()
 18      span = request.resource_spans.add().scope_spans.add().spans.add()
 19      span.trace_id = b"\x00" * 16
 20      span.span_id = b"\x01" * 8
 21      span.name = "span"
 22      return request.SerializeToString()
 23  
 24  
 25  def _make_test_client():
 26      app = FastAPI()
 27      add_fastapi_workspace_middleware(app)
 28      app.include_router(otel_router)
 29      return TestClient(app)
 30  
 31  
 32  def test_workspace_scoped_otlp_endpoint_sets_workspace(monkeypatch):
 33      monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
 34  
 35      class DummyTrackingStore:
 36          def __init__(self):
 37              self.calls = []
 38  
 39          def log_spans(self, experiment_id, spans):
 40              self.calls.append((workspace_context.get_request_workspace(), experiment_id, spans))
 41  
 42      tracking_store = DummyTrackingStore()
 43      captured = {}
 44  
 45      def fake_resolve(_path, header_workspace):
 46          captured["requested"] = header_workspace
 47          return Workspace(name=header_workspace)
 48  
 49      monkeypatch.setattr(
 50          "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
 51          fake_resolve,
 52      )
 53      monkeypatch.setattr(
 54          "mlflow.server.otel_api._get_tracking_store",
 55          lambda: tracking_store,
 56      )
 57  
 58      client = _make_test_client()
 59      response = client.post(
 60          OTLP_TRACES_PATH,
 61          data=_build_otlp_payload(),
 62          headers={
 63              "Content-Type": "application/x-protobuf",
 64              "X-MLflow-Experiment-Id": "42",
 65              WORKSPACE_HEADER_NAME: "team-a",
 66          },
 67      )
 68  
 69      assert response.status_code == 200
 70      assert captured["requested"].strip() == "team-a"
 71      assert tracking_store.calls[0][0] == "team-a"
 72      # Workspace context should be cleared after the request
 73      assert workspace_context.get_request_workspace() is None
 74  
 75  
 76  def test_default_otlp_endpoint_uses_default_workspace(monkeypatch):
 77      monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
 78  
 79      class DummyTrackingStore:
 80          def __init__(self):
 81              self.calls = []
 82  
 83          def log_spans(self, experiment_id, spans):
 84              self.calls.append((workspace_context.get_request_workspace(), experiment_id, spans))
 85  
 86      tracking_store = DummyTrackingStore()
 87      captured = {}
 88  
 89      def fake_resolve(_path, header_workspace):
 90          captured["requested"] = header_workspace
 91          return Workspace(name="default")
 92  
 93      monkeypatch.setattr(
 94          "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
 95          fake_resolve,
 96      )
 97      monkeypatch.setattr(
 98          "mlflow.server.otel_api._get_tracking_store",
 99          lambda: tracking_store,
100      )
101  
102      client = _make_test_client()
103      response = client.post(
104          OTLP_TRACES_PATH,
105          data=_build_otlp_payload(),
106          headers={
107              "Content-Type": "application/x-protobuf",
108              "X-MLflow-Experiment-Id": "7",
109          },
110      )
111  
112      assert response.status_code == 200
113      assert captured["requested"] is None
114      assert tracking_store.calls[0][0] == "default"
115      assert workspace_context.get_request_workspace() is None
116  
117  
def test_otlp_endpoint_without_default_workspace_raises_error(monkeypatch):
    """Check that the endpoint answers 400 when no workspace could be resolved.

    The fake resolver returns ``None`` and the fake store asks the
    ``WorkspaceAwareMixin`` for the active workspace inside ``log_spans``; the
    response is expected to carry an "Active workspace is required" message.
    """
    from mlflow.store.workspace_aware_mixin import WorkspaceAwareMixin

    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")

    class DummyWorkspaceAwareStore(WorkspaceAwareMixin):
        """A dummy store that raises MlflowException when workspace is not set."""

        def log_spans(self, experiment_id, spans):
            # This will raise MlflowException if workspace context is not set
            self._get_active_workspace()

    def fake_resolve(_path, _header_workspace):
        return None

    monkeypatch.setattr(
        "mlflow.server.fastapi_app.resolve_workspace_for_request_if_enabled",
        fake_resolve,
    )
    monkeypatch.setattr(
        "mlflow.server.otel_api._get_tracking_store",
        lambda: DummyWorkspaceAwareStore(),
    )

    client = _make_test_client()
    response = client.post(
        OTLP_TRACES_PATH,
        # Raw protobuf bytes belong in `content=`; `data=` is for form fields and
        # triggers an httpx DeprecationWarning when given bytes.
        content=_build_otlp_payload(),
        headers={
            "Content-Type": "application/x-protobuf",
            "X-MLflow-Experiment-Id": "42",
        },
    )

    assert response.status_code == 400
    assert "Active workspace is required" in response.json()["message"]
154  
155  
def test_otlp_invalid_content_type(monkeypatch):
    """Check that an unsupported or missing Content-Type yields a 400 response.

    Workspaces are disabled and the tracking store is a plain ``Mock``; both
    requests are expected to report "Invalid Content-Type" in ``detail``.
    """
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")

    monkeypatch.setattr(
        "mlflow.server.otel_api._get_tracking_store",
        lambda: mock.Mock(),
    )

    client = _make_test_client()

    # Test with unsupported content type
    response = client.post(
        OTLP_TRACES_PATH,
        # Raw protobuf bytes belong in `content=`; `data=` is for form fields and
        # triggers an httpx DeprecationWarning when given bytes.
        content=_build_otlp_payload(),
        headers={
            "Content-Type": "text/plain",
            "X-MLflow-Experiment-Id": "42",
        },
    )
    assert response.status_code == 400
    assert "Invalid Content-Type" in response.json()["detail"]

    # Test with missing content type
    response = client.post(
        OTLP_TRACES_PATH,
        content=_build_otlp_payload(),
        headers={
            "X-MLflow-Experiment-Id": "42",
        },
    )
    assert response.status_code == 400
    assert "Invalid Content-Type" in response.json()["detail"]
188  
189  
def test_otlp_invalid_protobuf_data(monkeypatch):
    """Check that a body which is not valid protobuf yields a 400 response.

    Workspaces are disabled and the tracking store is a plain ``Mock``; the
    response is expected to report "Invalid OpenTelemetry format" in ``detail``.
    """
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")

    monkeypatch.setattr(
        "mlflow.server.otel_api._get_tracking_store",
        lambda: mock.Mock(),
    )

    client = _make_test_client()

    # Test with invalid protobuf data
    response = client.post(
        OTLP_TRACES_PATH,
        # Raw bytes belong in `content=`; `data=` is for form fields and
        # triggers an httpx DeprecationWarning when given bytes.
        content=b"this is not valid protobuf data",
        headers={
            "Content-Type": "application/x-protobuf",
            "X-MLflow-Experiment-Id": "42",
        },
    )
    assert response.status_code == 400
    assert "Invalid OpenTelemetry format" in response.json()["detail"]
211  
212  
def test_otlp_empty_resource_spans(monkeypatch):
    """Check that an export request containing no spans yields a 400 response.

    Workspaces are disabled and the tracking store is a plain ``Mock``; the
    response is expected to report "no spans found" in ``detail``.
    """
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")

    monkeypatch.setattr(
        "mlflow.server.otel_api._get_tracking_store",
        lambda: mock.Mock(),
    )

    client = _make_test_client()

    # Create request with no resource spans
    request = ExportTraceServiceRequest()

    response = client.post(
        OTLP_TRACES_PATH,
        # Raw protobuf bytes belong in `content=`; `data=` is for form fields and
        # triggers an httpx DeprecationWarning when given bytes.
        content=request.SerializeToString(),
        headers={
            "Content-Type": "application/x-protobuf",
            "X-MLflow-Experiment-Id": "42",
        },
    )
    assert response.status_code == 400
    assert "no spans found" in response.json()["detail"]
236  
237  
def test_otlp_conversion_error(monkeypatch):
    """Check that a failure while converting a span yields a 422 response.

    ``Span.from_otel_proto`` is replaced with a function that always raises;
    the response is expected to report "Cannot convert OpenTelemetry span" in
    ``detail``.
    """
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")

    monkeypatch.setattr(
        "mlflow.server.otel_api._get_tracking_store",
        lambda: mock.Mock(),
    )

    # Mock Span.from_otel_proto to raise exception
    def mock_from_otel_proto(proto_span):
        raise Exception("Cannot convert span")

    monkeypatch.setattr(
        "mlflow.entities.span.Span.from_otel_proto",
        mock_from_otel_proto,
    )

    client = _make_test_client()

    response = client.post(
        OTLP_TRACES_PATH,
        # Raw protobuf bytes belong in `content=`; `data=` is for form fields and
        # triggers an httpx DeprecationWarning when given bytes.
        content=_build_otlp_payload(),
        headers={
            "Content-Type": "application/x-protobuf",
            "X-MLflow-Experiment-Id": "42",
        },
    )
    assert response.status_code == 422
    assert "Cannot convert OpenTelemetry span" in response.json()["detail"]