# tests/utils/test_jsonpath_utils.py
  1  import pytest
  2  
  3  from mlflow.utils.jsonpath_utils import (
  4      filter_json_by_fields,
  5      jsonpath_extract_values,
  6      split_path_respecting_backticks,
  7      validate_field_paths,
  8  )
  9  
 10  
 11  def test_jsonpath_extract_values_simple():
 12      data = {"info": {"trace_id": "tr-123", "state": "OK"}}
 13      values = jsonpath_extract_values(data, "info.trace_id")
 14      assert values == ["tr-123"]
 15  
 16  
 17  def test_jsonpath_extract_values_nested():
 18      data = {"info": {"metadata": {"user": "test@example.com"}}}
 19      values = jsonpath_extract_values(data, "info.metadata.user")
 20      assert values == ["test@example.com"]
 21  
 22  
 23  def test_jsonpath_extract_values_wildcard_array():
 24      data = {"info": {"assessments": [{"feedback": {"value": 0.8}}, {"feedback": {"value": 0.9}}]}}
 25      values = jsonpath_extract_values(data, "info.assessments.*.feedback.value")
 26      assert values == [0.8, 0.9]
 27  
 28  
 29  def test_jsonpath_extract_values_wildcard_dict():
 30      data = {"data": {"spans": {"span1": {"name": "first"}, "span2": {"name": "second"}}}}
 31      values = jsonpath_extract_values(data, "data.spans.*.name")
 32      assert set(values) == {"first", "second"}  # Order may vary with dict
 33  
 34  
 35  def test_jsonpath_extract_values_missing_field():
 36      data = {"info": {"trace_id": "tr-123"}}
 37      values = jsonpath_extract_values(data, "info.nonexistent")
 38      assert values == []
 39  
 40  
 41  def test_jsonpath_extract_values_partial_path_missing():
 42      data = {"info": {"trace_id": "tr-123"}}
 43      values = jsonpath_extract_values(data, "info.metadata.user")
 44      assert values == []
 45  
 46  
 47  @pytest.mark.parametrize(
 48      ("input_string", "expected"),
 49      [
 50          ("info.trace_id", ["info", "trace_id"]),
 51          ("info.tags.`mlflow.traceName`", ["info", "tags", "mlflow.traceName"]),
 52          ("`field.one`.middle.`field.two`", ["field.one", "middle", "field.two"]),
 53          ("`mlflow.traceName`.value", ["mlflow.traceName", "value"]),
 54          ("info.`mlflow.traceName`", ["info", "mlflow.traceName"]),
 55      ],
 56  )
 57  def test_split_path_respecting_backticks(input_string, expected):
 58      assert split_path_respecting_backticks(input_string) == expected
 59  
 60  
 61  def test_jsonpath_extract_values_with_backticks():
 62      # Field name with dot
 63      data = {"tags": {"mlflow.traceName": "test_trace"}}
 64      values = jsonpath_extract_values(data, "tags.`mlflow.traceName`")
 65      assert values == ["test_trace"]
 66  
 67      # Nested structure with dotted field names
 68      data = {"info": {"tags": {"mlflow.traceName": "my_trace", "user.id": "user123"}}}
 69      assert jsonpath_extract_values(data, "info.tags.`mlflow.traceName`") == ["my_trace"]
 70      assert jsonpath_extract_values(data, "info.tags.`user.id`") == ["user123"]
 71  
 72      # Mixed regular and backticked fields
 73      data = {"metadata": {"mlflow.source.type": "NOTEBOOK", "regular_field": "value"}}
 74      assert jsonpath_extract_values(data, "metadata.`mlflow.source.type`") == ["NOTEBOOK"]
 75      assert jsonpath_extract_values(data, "metadata.regular_field") == ["value"]
 76  
 77  
 78  def test_jsonpath_extract_values_empty_array():
 79      data = {"info": {"assessments": []}}
 80      values = jsonpath_extract_values(data, "info.assessments.*.feedback.value")
 81      assert values == []
 82  
 83  
 84  def test_jsonpath_extract_values_mixed_types():
 85      data = {
 86          "data": {
 87              "spans": [
 88                  {"attributes": {"key1": "value1"}},
 89                  {"attributes": {"key1": 42}},
 90                  {"attributes": {"key1": True}},
 91              ]
 92          }
 93      }
 94      values = jsonpath_extract_values(data, "data.spans.*.attributes.key1")
 95      assert values == ["value1", 42, True]
 96  
 97  
 98  def test_filter_json_by_fields_single_field():
 99      data = {"info": {"trace_id": "tr-123", "state": "OK"}, "data": {"spans": []}}
100      filtered = filter_json_by_fields(data, ["info.trace_id"])
101      expected = {"info": {"trace_id": "tr-123"}}
102      assert filtered == expected
103  
104  
def test_filter_json_by_fields_multiple_fields():
    # Two sibling fields are requested; the unrequested sibling and the whole
    # "data" branch are expected to be dropped.
    info = {"trace_id": "tr-123", "state": "OK", "unused": "value"}
    source = {"info": info, "data": {"spans": [], "metadata": {}}}
    kept = filter_json_by_fields(source, ["info.trace_id", "info.state"])
    assert kept == {"info": {"trace_id": "tr-123", "state": "OK"}}
113  
114  
def test_filter_json_by_fields_wildcards():
    # Every list element carries an extra "unused" key that the wildcard path
    # does not select; it is expected to be stripped from each element.
    scores = [0.8, 0.9]
    noisy = [{"feedback": {"value": score}, "unused": "data"} for score in scores]
    trimmed = [{"feedback": {"value": score}} for score in scores]

    kept = filter_json_by_fields(
        {"info": {"assessments": noisy}}, ["info.assessments.*.feedback.value"]
    )
    assert kept == {"info": {"assessments": trimmed}}
129  
130  
def test_filter_json_by_fields_nested_arrays():
    # Two wildcard levels (spans, then events): only each event's "name" is
    # expected to remain, with both list layers preserved.
    events = [{"name": "event1", "data": "d1"}, {"name": "event2", "data": "d2"}]
    span = {"name": "span1", "events": events, "unused": "value"}
    source = {"data": {"spans": [span]}}

    kept = filter_json_by_fields(source, ["data.spans.*.events.*.name"])
    assert kept == {"data": {"spans": [{"events": [{"name": "event1"}, {"name": "event2"}]}]}}
149  
150  
def test_filter_json_by_fields_missing_paths():
    # Neither requested path exists in the document, so an empty dict is expected.
    source = {"info": {"trace_id": "tr-123"}}
    assert filter_json_by_fields(source, ["info.nonexistent", "missing.path"]) == {}
155  
156  
def test_filter_json_by_fields_partial_matches():
    # One path exists and one does not; only the existing one is expected in the output.
    source = {"info": {"trace_id": "tr-123", "state": "OK"}}
    kept = filter_json_by_fields(source, ["info.trace_id", "info.nonexistent"])
    assert kept == {"info": {"trace_id": "tr-123"}}
162  
163  
def test_validate_field_paths_valid():
    payload = {"info": {"trace_id": "tr-123", "assessments": [{"feedback": {"value": 0.8}}]}}
    paths = ["info.trace_id", "info.assessments.*.feedback.value"]
    # The test passes as long as no exception escapes this call.
    validate_field_paths(paths, payload)
168  
169  
def test_validate_field_paths_invalid():
    payload = {"info": {"trace_id": "tr-123"}}

    with pytest.raises(ValueError, match="Invalid field path") as exc_info:
        validate_field_paths(["info.nonexistent"], payload)

    # The message is expected to carry both the generic prefix and the offending path.
    message = str(exc_info.value)
    assert "Invalid field path" in message
    assert "info.nonexistent" in message
178  
179  
def test_validate_field_paths_multiple_invalid():
    data = {"info": {"trace_id": "tr-123"}}
    invalid_paths = ["info.missing", "other.invalid"]

    with pytest.raises(ValueError, match="Invalid field path") as exc_info:
        validate_field_paths(invalid_paths, data)

    error_msg = str(exc_info.value)
    assert "Invalid field path" in error_msg
    # Should mention both invalid paths. Each path is asserted separately (rather
    # than joined with `or`) so that a message reporting only one of them fails,
    # and so the failure output shows which path is missing.
    for path in invalid_paths:
        assert path in error_msg
190  
191  
def test_validate_field_paths_suggestions():
    payload = {"info": {"trace_id": "tr-123", "assessments": [], "metadata": {}}}

    # "info.traces" does not exist but is close to the real "info.trace_id".
    with pytest.raises(ValueError, match="Invalid field path") as exc_info:
        validate_field_paths(["info.traces"], payload)

    # The error is expected to list available fields, including the real one.
    message = str(exc_info.value)
    assert "Available fields" in message
    assert "info.trace_id" in message
201  
202  
def test_complex_trace_structure():
    # Build a trace-shaped document from named parts: one assessment, one span.
    assessment = {
        "assessment_id": "a-123",
        "feedback": {"value": 0.85},
        "source": {"source_type": "HUMAN", "source_id": "user@example.com"},
    }
    span = {
        "span_id": "span-1",
        "name": "root_span",
        "attributes": {"mlflow.spanType": "AGENT"},
        "events": [{"name": "start", "attributes": {"key": "value"}}],
    }
    trace_data = {
        "info": {
            "trace_id": "tr-abc123def",
            "state": "OK",
            "execution_duration": 1500,
            "assessments": [assessment],
            "tags": {"environment": "production", "mlflow.traceName": "test_trace"},
        },
        "data": {"spans": [span]},
    }

    # Test various field extractions
    expected_by_path = {
        "info.trace_id": ["tr-abc123def"],
        "info.assessments.*.feedback.value": [0.85],
        "data.spans.*.name": ["root_span"],
        "data.spans.*.events.*.name": ["start"],
    }
    for path, expected in expected_by_path.items():
        assert jsonpath_extract_values(trace_data, path) == expected

    # Test filtering preserves structure
    requested = ["info.trace_id", "info.assessments.*.feedback.value", "data.spans.*.name"]
    filtered = filter_json_by_fields(trace_data, requested)

    assert "info" in filtered
    kept_info = filtered["info"]
    assert kept_info["trace_id"] == "tr-abc123def"
    assert len(kept_info["assessments"]) == 1
    kept_assessment = kept_info["assessments"][0]
    assert kept_assessment["feedback"]["value"] == 0.85

    assert "data" in filtered
    kept_spans = filtered["data"]["spans"]
    assert len(kept_spans) == 1
    assert kept_spans[0]["name"] == "root_span"

    # Should not contain other fields
    assert "source" not in kept_assessment
    assert "attributes" not in kept_spans[0]
250      assert "attributes" not in filtered["data"]["spans"][0]