# test_trace_data.py
import json
from unittest import mock

import pytest

import mlflow
from mlflow.entities import SpanType, TraceData
from mlflow.entities.span_event import SpanEvent


def _expected_exception_event(span, message="Error!", exc_type="Exception"):
    """Build the expected ``to_dict`` form of the first event recorded on ``span``.

    Used for spans whose only event is the ``exception`` event produced when the
    traced code raises. The stacktrace is matched with ``mock.ANY`` because it
    embeds file paths and line numbers; callers verify its content separately.
    """
    return {
        "name": "exception",
        "time_unix_nano": span.events[0].timestamp,
        "attributes": {
            "exception.message": message,
            "exception.type": exc_type,
            "exception.stacktrace": mock.ANY,
        },
    }


def test_json_deserialization():
    """Serialize a trace containing error and custom events, then round-trip it.

    The trace has a root ``predict`` span, a child ``with_ok_event`` span that
    carries a custom event, and a child ``always_fail_name`` span that raises.
    The test checks the ``to_dict`` output and that ``TraceData.from_dict``
    reproduces it.
    """

    class TestModel:
        @mlflow.trace()
        def predict(self, x, y):
            z = x + y

            with mlflow.start_span(name="with_ok_event") as span:
                span.add_event(SpanEvent(name="ok_event", attributes={"foo": "bar"}))

            self.always_fail()
            return z

        @mlflow.trace(span_type=SpanType.LLM, name="always_fail_name", attributes={"delta": 1})
        def always_fail(self):
            raise Exception("Error!")

    model = TestModel()

    # Verify the exception is not absorbed by the context manager
    with pytest.raises(Exception, match="Error!"):
        model.predict(2, 5)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    trace_data = trace.data
    spans = trace_data.spans
    # Every span records the owning trace id as a JSON-encoded attribute.
    request_id_attr = json.dumps(trace.info.trace_id)

    trace_data_dict = trace_data.to_dict()
    span_to_events = {span["name"]: span.get("events") for span in trace_data_dict["spans"]}

    # Full structural comparison. Exception stacktraces are matched with mock.ANY
    # here; their content is asserted explicitly further below.
    assert trace_data_dict == {
        "spans": [
            {
                "name": "predict",
                "trace_id": mock.ANY,
                "span_id": mock.ANY,
                "parent_span_id": None,
                "start_time_unix_nano": spans[0].start_time_ns,
                "end_time_unix_nano": spans[0].end_time_ns,
                "status": {
                    "code": "STATUS_CODE_ERROR",
                    "message": "Exception: Error!",
                },
                "attributes": {
                    "mlflow.traceRequestId": request_id_attr,
                    "mlflow.spanType": '"UNKNOWN"',
                    "mlflow.spanFunctionName": '"predict"',
                    "mlflow.spanInputs": '{"x": 2, "y": 5}',
                },
                "events": [_expected_exception_event(spans[0])],
            },
            {
                "name": "with_ok_event",
                "trace_id": mock.ANY,
                "span_id": mock.ANY,
                "parent_span_id": mock.ANY,
                "start_time_unix_nano": spans[1].start_time_ns,
                "end_time_unix_nano": spans[1].end_time_ns,
                "status": {
                    "code": "STATUS_CODE_OK",
                    "message": "",
                },
                "attributes": {
                    "mlflow.traceRequestId": request_id_attr,
                    "mlflow.spanType": '"UNKNOWN"',
                },
                "events": [
                    {
                        "name": "ok_event",
                        "time_unix_nano": spans[1].events[0].timestamp,
                        "attributes": {"foo": "bar"},
                    }
                ],
            },
            {
                "name": "always_fail_name",
                "trace_id": mock.ANY,
                "span_id": mock.ANY,
                "parent_span_id": mock.ANY,
                "start_time_unix_nano": spans[2].start_time_ns,
                "end_time_unix_nano": spans[2].end_time_ns,
                "status": {
                    "code": "STATUS_CODE_ERROR",
                    "message": "Exception: Error!",
                },
                "attributes": {
                    "delta": "1",
                    "mlflow.traceRequestId": request_id_attr,
                    "mlflow.spanType": '"LLM"',
                    "mlflow.spanFunctionName": '"always_fail"',
                    "mlflow.spanInputs": "{}",
                },
                "events": [_expected_exception_event(spans[2])],
            },
        ],
    }

    # Per-span event checks, including the stacktrace details mock.ANY cannot pin.
    ok_events = span_to_events["with_ok_event"]
    assert len(ok_events) == 1
    assert ok_events[0]["name"] == "ok_event"
    assert ok_events[0]["attributes"] == {"foo": "bar"}

    error_events = span_to_events["always_fail_name"]
    assert len(error_events) == 1
    assert error_events[0]["name"] == "exception"
    assert error_events[0]["attributes"]["exception.message"] == "Error!"
    assert error_events[0]["attributes"]["exception.type"] == "Exception"
    # mock.ANY would also accept None, so check presence explicitly.
    assert error_events[0]["attributes"]["exception.stacktrace"] is not None

    parent_events = span_to_events["predict"]
    assert len(parent_events) == 1
    assert parent_events[0]["name"] == "exception"
    assert parent_events[0]["attributes"]["exception.message"] == "Error!"
    assert parent_events[0]["attributes"]["exception.type"] == "Exception"
    # Parent span includes exception event bubbled up from the child span, hence the
    # stack trace includes the function call
    assert "self.always_fail()" in parent_events[0]["attributes"]["exception.stacktrace"]

    # Convert back from dict to TraceData and compare
    trace_data_from_dict = TraceData.from_dict(trace_data_dict)
    assert trace_data.to_dict() == trace_data_from_dict.to_dict()


def test_intermediate_outputs_from_attribute():
    """Intermediate outputs set via the dedicated span attribute are exposed as-is."""
    intermediate_outputs = {
        "retrieved_documents": ["document 1", "document 2"],
        "generative_prompt": "prompt",
    }

    def run():
        with mlflow.start_span(name="run") as span:
            span.set_attribute("mlflow.trace.intermediate_outputs", intermediate_outputs)

    run()
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    assert trace.data.intermediate_outputs == intermediate_outputs


def test_intermediate_outputs_from_spans():
    """Intermediate outputs are collected from child span outputs, keyed by span name.

    The expected keys show repeated span names being disambiguated as
    ``llm_1`` / ``llm_2``.
    """

    @mlflow.trace()
    def retrieved_documents():
        return ["document 1", "document 2"]

    @mlflow.trace()
    def llm(i):
        return f"Hi, this is LLM {i}"

    @mlflow.trace()
    def predict():
        retrieved_documents()
        llm(1)
        llm(2)

    predict()
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    assert trace.data.intermediate_outputs == {
        "retrieved_documents": ["document 1", "document 2"],
        "llm_1": "Hi, this is LLM 1",
        "llm_2": "Hi, this is LLM 2",
    }


def test_intermediate_outputs_no_value():
    """A single root span with outputs yields no intermediate outputs."""

    def run():
        with mlflow.start_span(name="run") as span:
            span.set_outputs(1)

    run()
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    assert trace.data.intermediate_outputs is None


def test_to_dict():
    """``to_dict`` emits spans only, without the legacy request/response keys."""
    with mlflow.start_span():
        pass
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    trace_dict = trace.data.to_dict()
    assert len(trace_dict["spans"]) == 1
    # Ensure the legacy properties are not present
    assert "request" not in trace_dict
    assert "response" not in trace_dict


def test_request_and_response_are_still_available():
    """Legacy ``request``/``response`` properties are still readable on ``TraceData``.

    They hold the JSON-encoded root span inputs/outputs, and are None when unset.
    """
    with mlflow.start_span() as s:
        s.set_inputs("foo")
        s.set_outputs("bar")

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    trace_data = trace.data
    assert trace_data.request == '"foo"'
    assert trace_data.response == '"bar"'

    with mlflow.start_span():
        pass

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    trace_data = trace.data
    assert trace_data.request is None
    assert trace_data.response is None