/ tests / entities / test_trace_data.py
test_trace_data.py
  1  import json
  2  from unittest import mock
  3  
  4  import pytest
  5  
  6  import mlflow
  7  from mlflow.entities import SpanType, TraceData
  8  from mlflow.entities.span_event import SpanEvent
  9  
 10  
 11  def test_json_deserialization():
 12      class TestModel:
 13          @mlflow.trace()
 14          def predict(self, x, y):
 15              z = x + y
 16  
 17              with mlflow.start_span(name="with_ok_event") as span:
 18                  span.add_event(SpanEvent(name="ok_event", attributes={"foo": "bar"}))
 19  
 20              self.always_fail()
 21              return z
 22  
 23          @mlflow.trace(span_type=SpanType.LLM, name="always_fail_name", attributes={"delta": 1})
 24          def always_fail(self):
 25              raise Exception("Error!")
 26  
 27      model = TestModel()
 28  
 29      # Verify the exception is not absorbed by the context manager
 30      with pytest.raises(Exception, match="Error!"):
 31          model.predict(2, 5)
 32  
 33      trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
 34      trace_data = trace.data
 35  
 36      # Compare events separately as it includes exception stacktrace which is hard to hardcode
 37      trace_data_dict = trace_data.to_dict()
 38      span_to_events = {span["name"]: span.get("events") for span in trace_data_dict["spans"]}
 39  
 40      assert trace_data_dict == {
 41          "spans": [
 42              {
 43                  "name": "predict",
 44                  "trace_id": mock.ANY,
 45                  "span_id": mock.ANY,
 46                  "parent_span_id": None,
 47                  "start_time_unix_nano": trace.data.spans[0].start_time_ns,
 48                  "end_time_unix_nano": trace.data.spans[0].end_time_ns,
 49                  "status": {
 50                      "code": "STATUS_CODE_ERROR",
 51                      "message": "Exception: Error!",
 52                  },
 53                  "attributes": {
 54                      "mlflow.traceRequestId": json.dumps(trace.info.trace_id),
 55                      "mlflow.spanType": '"UNKNOWN"',
 56                      "mlflow.spanFunctionName": '"predict"',
 57                      "mlflow.spanInputs": '{"x": 2, "y": 5}',
 58                  },
 59                  "events": [
 60                      {
 61                          "name": "exception",
 62                          "time_unix_nano": trace.data.spans[0].events[0].timestamp,
 63                          "attributes": {
 64                              "exception.message": "Error!",
 65                              "exception.type": "Exception",
 66                              "exception.stacktrace": mock.ANY,
 67                          },
 68                      }
 69                  ],
 70              },
 71              {
 72                  "name": "with_ok_event",
 73                  "trace_id": mock.ANY,
 74                  "span_id": mock.ANY,
 75                  "parent_span_id": mock.ANY,
 76                  "start_time_unix_nano": trace.data.spans[1].start_time_ns,
 77                  "end_time_unix_nano": trace.data.spans[1].end_time_ns,
 78                  "status": {
 79                      "code": "STATUS_CODE_OK",
 80                      "message": "",
 81                  },
 82                  "attributes": {
 83                      "mlflow.traceRequestId": json.dumps(trace.info.trace_id),
 84                      "mlflow.spanType": '"UNKNOWN"',
 85                  },
 86                  "events": [
 87                      {
 88                          "name": "ok_event",
 89                          "time_unix_nano": trace.data.spans[1].events[0].timestamp,
 90                          "attributes": {"foo": "bar"},
 91                      }
 92                  ],
 93              },
 94              {
 95                  "name": "always_fail_name",
 96                  "trace_id": mock.ANY,
 97                  "span_id": mock.ANY,
 98                  "parent_span_id": mock.ANY,
 99                  "start_time_unix_nano": trace.data.spans[2].start_time_ns,
100                  "end_time_unix_nano": trace.data.spans[2].end_time_ns,
101                  "status": {
102                      "code": "STATUS_CODE_ERROR",
103                      "message": "Exception: Error!",
104                  },
105                  "attributes": {
106                      "delta": "1",
107                      "mlflow.traceRequestId": json.dumps(trace.info.trace_id),
108                      "mlflow.spanType": '"LLM"',
109                      "mlflow.spanFunctionName": '"always_fail"',
110                      "mlflow.spanInputs": "{}",
111                  },
112                  "events": [
113                      {
114                          "name": "exception",
115                          "time_unix_nano": trace.data.spans[2].events[0].timestamp,
116                          "attributes": {
117                              "exception.message": "Error!",
118                              "exception.type": "Exception",
119                              "exception.stacktrace": mock.ANY,
120                          },
121                      }
122                  ],
123              },
124          ],
125      }
126  
127      ok_events = span_to_events["with_ok_event"]
128      assert len(ok_events) == 1
129      assert ok_events[0]["name"] == "ok_event"
130      assert ok_events[0]["attributes"] == {"foo": "bar"}
131  
132      error_events = span_to_events["always_fail_name"]
133      assert len(error_events) == 1
134      assert error_events[0]["name"] == "exception"
135      assert error_events[0]["attributes"]["exception.message"] == "Error!"
136      assert error_events[0]["attributes"]["exception.type"] == "Exception"
137      assert error_events[0]["attributes"]["exception.stacktrace"] is not None
138  
139      parent_events = span_to_events["predict"]
140      assert len(parent_events) == 1
141      assert parent_events[0]["name"] == "exception"
142      assert parent_events[0]["attributes"]["exception.message"] == "Error!"
143      assert parent_events[0]["attributes"]["exception.type"] == "Exception"
144      # Parent span includes exception event bubbled up from the child span, hence the
145      # stack trace includes the function call
146      assert "self.always_fail()" in parent_events[0]["attributes"]["exception.stacktrace"]
147  
148      # Convert back from dict to TraceData and compare
149      trace_data_from_dict = TraceData.from_dict(trace_data_dict)
150      assert trace_data.to_dict() == trace_data_from_dict.to_dict()
151  
152  
def test_intermediate_outputs_from_attribute():
    """Check ``intermediate_outputs`` against an explicitly set span attribute.

    A span named ``run`` is given the ``mlflow.trace.intermediate_outputs``
    attribute; the resulting trace data must expose that same value.
    """
    expected = {
        "retrieved_documents": ["document 1", "document 2"],
        "generative_prompt": "prompt",
    }

    with mlflow.start_span(name="run") as root:
        root.set_attribute("mlflow.trace.intermediate_outputs", expected)

    last_trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert last_trace.data.intermediate_outputs == expected
167  
168  
def test_intermediate_outputs_from_spans():
    """Check ``intermediate_outputs`` when no explicit attribute is set.

    A traced ``predict`` calls one traced retrieval function and a traced
    ``llm`` function twice. The trace data is expected to report the child
    outputs under ``retrieved_documents``, ``llm_1`` and ``llm_2``.
    """

    @mlflow.trace()
    def retrieved_documents():
        return ["document 1", "document 2"]

    @mlflow.trace()
    def llm(i):
        return f"Hi, this is LLM {i}"

    @mlflow.trace()
    def predict():
        retrieved_documents()
        for call_number in (1, 2):
            llm(call_number)

    predict()

    expected = {
        "retrieved_documents": ["document 1", "document 2"],
        "llm_1": "Hi, this is LLM 1",
        "llm_2": "Hi, this is LLM 2",
    }
    last_trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert last_trace.data.intermediate_outputs == expected
192  
193  
def test_intermediate_outputs_no_value():
    """Check that a trace with a single span reports no intermediate outputs.

    Only one span named ``run`` is created, with its outputs set to ``1``;
    ``intermediate_outputs`` on the trace data must then be ``None``.
    """
    with mlflow.start_span(name="run") as root:
        root.set_outputs(1)

    last_trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert last_trace.data.intermediate_outputs is None
203  
204  
def test_to_dict():
    """Check that ``TraceData.to_dict()`` has one span and no legacy keys."""
    with mlflow.start_span():
        pass

    serialized = mlflow.get_trace(mlflow.get_last_active_trace_id()).data.to_dict()

    assert len(serialized["spans"]) == 1
    # The legacy request/response properties must not leak into the dict form.
    for legacy_key in ("request", "response"):
        assert legacy_key not in serialized
214  
215  
def test_request_and_response_are_still_available():
    """Check the ``request``/``response`` properties on ``TraceData``.

    With inputs and outputs set on the span, the properties return their
    JSON-encoded strings; with neither set, both properties are ``None``.
    """
    # Scenario 1: the span records both inputs and outputs.
    with mlflow.start_span() as span:
        span.set_inputs("foo")
        span.set_outputs("bar")

    populated = mlflow.get_trace(mlflow.get_last_active_trace_id()).data
    assert populated.request == '"foo"'
    assert populated.response == '"bar"'

    # Scenario 2: the span records nothing.
    with mlflow.start_span():
        pass

    empty = mlflow.get_trace(mlflow.get_last_active_trace_id()).data
    assert empty.request is None
    assert empty.response is None
231      assert trace_data.request is None
232      assert trace_data.response is None