/ tests / utils / test_proto_json_utils.py
test_proto_json_utils.py
  1  import base64
  2  import datetime
  3  import json
  4  
  5  import numpy as np
  6  import pandas as pd
  7  import pytest
  8  from google.protobuf.text_format import Parse as ParseTextIntoProto
  9  
 10  from mlflow.entities import Experiment, Metric
 11  from mlflow.entities.model_registry import ModelVersion, RegisteredModel
 12  from mlflow.exceptions import MlflowException
 13  from mlflow.protos.model_registry_pb2 import RegisteredModel as ProtoRegisteredModel
 14  from mlflow.protos.service_pb2 import Experiment as ProtoExperiment
 15  from mlflow.protos.service_pb2 import Metric as ProtoMetric
 16  from mlflow.types import ColSpec, DataType, Schema, TensorSpec
 17  from mlflow.types.schema import Array, Map, Object, Property
 18  from mlflow.types.utils import _infer_schema
 19  from mlflow.utils.proto_json_utils import (
 20      MlflowFailedTypeConversion,
 21      _CustomJsonEncoder,
 22      cast_df_types_according_to_schema,
 23      dataframe_from_parsed_json,
 24      dataframe_from_raw_json,
 25      message_to_json,
 26      parse_dict,
 27      parse_tf_serving_input,
 28  )
 29  
 30  from tests.protos.test_message_pb2 import SampleMessage
 31  
 32  
 33  def test_message_to_json():
 34      json_out = message_to_json(Experiment("123", "name", "arty", "active").to_proto())
 35      assert json.loads(json_out) == {
 36          "experiment_id": "123",
 37          "name": "name",
 38          "artifact_location": "arty",
 39          "lifecycle_stage": "active",
 40      }
 41  
 42      original_proto_message = RegisteredModel(
 43          name="model_1",
 44          creation_timestamp=111,
 45          last_updated_timestamp=222,
 46          description="Test model",
 47          latest_versions=[
 48              ModelVersion(
 49                  name="mv-1",
 50                  version="1",
 51                  creation_timestamp=333,
 52                  last_updated_timestamp=444,
 53                  description="v 1",
 54                  user_id="u1",
 55                  current_stage="Production",
 56                  source="A/B",
 57                  run_id="9245c6ce1e2d475b82af84b0d36b52f4",
 58                  status="READY",
 59                  status_message=None,
 60              ),
 61              ModelVersion(
 62                  name="mv-2",
 63                  version="2",
 64                  creation_timestamp=555,
 65                  last_updated_timestamp=666,
 66                  description="v 2",
 67                  user_id="u2",
 68                  current_stage="Staging",
 69                  source="A/C",
 70                  run_id="123",
 71                  status="READY",
 72                  status_message=None,
 73              ),
 74          ],
 75      ).to_proto()
 76      json_out = message_to_json(original_proto_message)
 77      json_dict = json.loads(json_out)
 78      assert json_dict == {
 79          "name": "model_1",
 80          "creation_timestamp": 111,
 81          "last_updated_timestamp": 222,
 82          "description": "Test model",
 83          "latest_versions": [
 84              {
 85                  "name": "mv-1",
 86                  "version": "1",
 87                  "creation_timestamp": 333,
 88                  "last_updated_timestamp": 444,
 89                  "current_stage": "Production",
 90                  "description": "v 1",
 91                  "user_id": "u1",
 92                  "source": "A/B",
 93                  "run_id": "9245c6ce1e2d475b82af84b0d36b52f4",
 94                  "status": "READY",
 95              },
 96              {
 97                  "name": "mv-2",
 98                  "version": "2",
 99                  "creation_timestamp": 555,
100                  "last_updated_timestamp": 666,
101                  "current_stage": "Staging",
102                  "description": "v 2",
103                  "user_id": "u2",
104                  "source": "A/C",
105                  "run_id": "123",
106                  "status": "READY",
107              },
108          ],
109      }
110      new_proto_message = ProtoRegisteredModel()
111      parse_dict(json_dict, new_proto_message)
112      assert original_proto_message == new_proto_message
113  
114      test_message = ParseTextIntoProto(
115          """
116          field_int32: 11
117          field_int64: 12
118          field_uint32: 13
119          field_uint64: 14
120          field_sint32: 15
121          field_sint64: 16
122          field_fixed32: 17
123          field_fixed64: 18
124          field_sfixed32: 19
125          field_sfixed64: 20
126          field_bool: true
127          field_string: "Im a string"
128          field_with_default1: 111
129          field_repeated_int64: [1, 2, 3]
130          field_enum: ENUM_VALUE1
131          field_inner_message {
132              field_inner_int64: 101
133              field_inner_repeated_int64: [102, 103]
134          }
135          field_inner_message {
136              field_inner_int64: 104
137              field_inner_repeated_int64: [105, 106]
138          }
139          oneof1: 207
140          [mlflow.ExtensionMessage.field_extended_int64]: 100
141          field_map1: [{key: 51 value: "52"}, {key: 53 value: "54"}]
142          field_map2: [{key: "61" value: 62}, {key: "63" value: 64}]
143          field_map3: [{key: 561 value: 562}, {key: 563 value: 564}]
144          field_map4: [{key: 71
145                        value: {field_inner_int64: 72
146                                field_inner_repeated_int64: [81, 82]
147                                field_inner_string: "str1"}},
148                       {key: 73
149                        value: {field_inner_int64: 74
150                                field_inner_repeated_int64: 83
151                                field_inner_string: "str2"}}]
152      """,
153          SampleMessage(),
154      )
155      json_out = message_to_json(test_message)
156      json_dict = json.loads(json_out)
157      assert json_dict == {
158          "field_int32": 11,
159          "field_int64": 12,
160          "field_uint32": 13,
161          "field_uint64": 14,
162          "field_sint32": 15,
163          "field_sint64": 16,
164          "field_fixed32": 17,
165          "field_fixed64": 18,
166          "field_sfixed32": 19,
167          "field_sfixed64": 20,
168          "field_bool": True,
169          "field_string": "Im a string",
170          "field_with_default1": 111,
171          "field_repeated_int64": [1, 2, 3],
172          "field_enum": "ENUM_VALUE1",
173          "field_inner_message": [
174              {"field_inner_int64": 101, "field_inner_repeated_int64": [102, 103]},
175              {"field_inner_int64": 104, "field_inner_repeated_int64": [105, 106]},
176          ],
177          "oneof1": 207,
178          # JSON doesn't support non-string keys, so the int keys will be converted to strings.
179          "field_map1": {"51": "52", "53": "54"},
180          "field_map2": {"63": 64, "61": 62},
181          "field_map3": {"561": 562, "563": 564},
182          "field_map4": {
183              "73": {
184                  "field_inner_int64": 74,
185                  "field_inner_repeated_int64": [83],
186                  "field_inner_string": "str2",
187              },
188              "71": {
189                  "field_inner_int64": 72,
190                  "field_inner_repeated_int64": [81, 82],
191                  "field_inner_string": "str1",
192              },
193          },
194          "[mlflow.ExtensionMessage.field_extended_int64]": "100",
195      }
196      new_test_message = SampleMessage()
197      parse_dict(json_dict, new_test_message)
198      assert new_test_message == test_message
199  
200  
201  def test_parse_dict():
202      in_json = {"experiment_id": "123", "name": "name", "unknown": "field"}
203      message = ProtoExperiment()
204      parse_dict(in_json, message)
205      experiment = Experiment.from_proto(message)
206      assert experiment.experiment_id == "123"
207      assert experiment.name == "name"
208      assert experiment.artifact_location == ""
209  
210  
211  def test_parse_dict_int_as_string_backcompat():
212      in_json = {"timestamp": "123"}
213      message = ProtoMetric()
214      parse_dict(in_json, message)
215      experiment = Metric.from_proto(message)
216      assert experiment.timestamp == 123
217  
218  
219  def assert_result(result, expected_result):
220      assert result.keys() == expected_result.keys()
221      for key in result:
222          assert (result[key] == expected_result[key]).all()
223          assert result[key].dtype == expected_result[key].dtype
224  
225  
226  def test_parse_tf_serving_dictionary():
227      # instances are correctly aggregated to dict of input name -> tensor
228      tfserving_input = {
229          "instances": [
230              {"a": "s1", "b": 1.1, "c": [1, 2, 3]},
231              {"a": "s2", "b": 2.2, "c": [4, 5, 6]},
232              {"a": "s3", "b": 3.3, "c": [7, 8, 9]},
233          ]
234      }
235      # Without Schema
236      result = parse_tf_serving_input(tfserving_input)
237      expected_result_no_schema = {
238          "a": np.array(["s1", "s2", "s3"]),
239          "b": np.array([1.1, 2.2, 3.3]),
240          "c": np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
241      }
242      assert_result(result, expected_result_no_schema)
243  
244      # With schema
245      schema = Schema([
246          TensorSpec(np.dtype("str"), [-1], "a"),
247          TensorSpec(np.dtype("float32"), [-1], "b"),
248          TensorSpec(np.dtype("int32"), [-1], "c"),
249      ])
250      df_schema = Schema([ColSpec("string", "a"), ColSpec("float", "b"), ColSpec("integer", "c")])
251      result = parse_tf_serving_input(tfserving_input, schema)
252      expected_result_schema = {
253          "a": np.array(["s1", "s2", "s3"], dtype=np.dtype("str")),
254          "b": np.array([1.1, 2.2, 3.3], dtype="float32"),
255          "c": np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int32"),
256      }
257      assert_result(result, expected_result_schema)
258      # With df Schema
259      result = parse_tf_serving_input(tfserving_input, df_schema)
260      assert_result(result, expected_result_schema)
261      # With df Schema containing array
262      new_schema = _infer_schema(tfserving_input["instances"])
263      result = parse_tf_serving_input(tfserving_input, new_schema)
264      expected_result = {
265          "a": np.array(["s1", "s2", "s3"]),
266          "b": np.array([1.1, 2.2, 3.3], dtype="float64"),
267          "c": np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int64"),
268      }
269      assert_result(result, expected_result)
270  
271      # input provided as a dict
272      tfserving_input = {
273          "inputs": {
274              "a": ["s1", "s2", "s3"],
275              "b": [1.1, 2.2, 3.3],
276              "c": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
277          }
278      }
279      # Without Schema
280      result = parse_tf_serving_input(tfserving_input)
281      assert_result(result, expected_result_no_schema)
282  
283      # With Schema
284      result = parse_tf_serving_input(tfserving_input, schema)
285      assert_result(result, expected_result_schema)
286  
287      # With df Schema
288      result = parse_tf_serving_input(tfserving_input, df_schema)
289      assert_result(result, expected_result_schema)
290  
291  
292  def test_parse_tf_serving_arbitrary_input_dictionary():
293      # input provided as a columnar dict with an arbitrary shape for each input, specifically a
294      # different 0th dimension.
295      tfserving_input_arbitrary = {
296          "inputs": {
297              "a": [["s1", "s2", "s3"], ["s4", "s5", "s6"]],  # [2, 3]
298              "b": [1.1, 2.2, 3.3],  # [3,  ]
299              "c": [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]],  # [4, 3]
300          }
301      }
302  
303      schema = Schema([
304          TensorSpec(np.dtype("str"), [-1, 3], "a"),
305          TensorSpec(np.dtype("float32"), [-1], "b"),
306          TensorSpec(np.dtype("int32"), [-1, 4], "c"),
307      ])
308      df_schema = Schema([ColSpec("string", "a"), ColSpec("float", "b"), ColSpec("integer", "c")])
309  
310      expected_result_no_schema_arbitrary = {
311          "a": np.array([["s1", "s2", "s3"], ["s4", "s5", "s6"]]),
312          "b": np.array([1.1, 2.2, 3.3]),
313          "c": np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]),
314      }
315      expected_result_schema_arbitrary = {
316          "a": np.array([["s1", "s2", "s3"], ["s4", "s5", "s6"]], dtype=np.dtype("str")),
317          "b": np.array([1.1, 2.2, 3.3], dtype="float32"),
318          "c": np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]], dtype="int32"),
319      }
320  
321      # Without Schema
322      result = parse_tf_serving_input(tfserving_input_arbitrary)
323      assert_result(result, expected_result_no_schema_arbitrary)
324  
325      # With Schema
326      result = parse_tf_serving_input(tfserving_input_arbitrary, schema)
327      assert_result(result, expected_result_schema_arbitrary)
328  
329      # With df Schema
330      result = parse_tf_serving_input(tfserving_input_arbitrary, df_schema)
331      assert_result(result, expected_result_schema_arbitrary)
332  
333  
334  def test_parse_tf_serving_single_array():
335      def assert_result(result, expected_result):
336          assert (result == expected_result).all()
337  
338      # values for each column are properly converted to a tensor
339      arr = [
340          [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
341          [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
342      ]
343      tfserving_instances = {"instances": arr}
344      tfserving_inputs = {"inputs": arr}
345  
346      # Without schema
347      instance_result = parse_tf_serving_input(tfserving_instances)
348      assert instance_result.shape == (2, 3, 3)
349      assert_result(instance_result, np.array(arr, dtype="int64"))
350  
351      input_result = parse_tf_serving_input(tfserving_inputs)
352      assert input_result.shape == (2, 3, 3)
353      assert_result(input_result, np.array(arr, dtype="int64"))
354  
355      # Unnamed schema
356      schema = Schema([TensorSpec(np.dtype("float32"), [-1])])
357      instance_result = parse_tf_serving_input(tfserving_instances, schema)
358      assert_result(instance_result, np.array(arr, dtype="float32"))
359  
360      input_result = parse_tf_serving_input(tfserving_inputs, schema)
361      assert_result(input_result, np.array(arr, dtype="float32"))
362  
363      # named schema
364      schema = Schema([TensorSpec(np.dtype("float32"), [-1], "a")])
365      instance_result = parse_tf_serving_input(tfserving_instances, schema)
366      assert isinstance(instance_result, dict)
367      assert len(instance_result.keys()) == 1
368      assert "a" in instance_result
369      assert_result(instance_result["a"], np.array(arr, dtype="float32"))
370  
371      input_result = parse_tf_serving_input(tfserving_inputs, schema)
372      assert isinstance(input_result, dict)
373      assert len(input_result.keys()) == 1
374      assert "a" in input_result
375      assert_result(input_result["a"], np.array(arr, dtype="float32"))
376  
377  
378  def test_parse_tf_serving_raises_expected_errors():
379      # input is bad if a column value is missing for a row/instance
380      tfserving_instances = {
381          "instances": [
382              {"a": "s1", "b": 1},
383              {"a": "s2", "b": 2, "c": [4, 5, 6]},
384              {"a": "s3", "b": 3, "c": [7, 8, 9]},
385          ]
386      }
387      with pytest.raises(
388          MlflowException, match="The length of values for each input/column name are not the same"
389      ):
390          parse_tf_serving_input(tfserving_instances)
391  
392      # cannot specify both instance and inputs
393      tfserving_input = {
394          "instances": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
395          "inputs": {"a": ["s1", "s2", "s3"], "b": [1, 2, 3], "c": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]},
396      }
397      match = 'Invalid input. One of "instances" and "inputs" must be specified'
398      with pytest.raises(MlflowException, match=match):
399          parse_tf_serving_input(tfserving_input)
400  
401      # cannot specify signature name
402      tfserving_input = {
403          "signature_name": "hello",
404          "inputs": {"a": ["s1", "s2", "s3"], "b": [1, 2, 3], "c": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]},
405      }
406      match = '"signature_name" parameter is currently not supported'
407      with pytest.raises(MlflowException, match=match):
408          parse_tf_serving_input(tfserving_input)
409  
410  
411  def test_dataframe_from_json():
412      source = pd.DataFrame(
413          {
414              "boolean": [True, False, True],
415              "string": ["a", "b", "c"],
416              "float": np.array([1.2, 2.3, 3.4], dtype=np.float32),
417              "double": np.array([1.2, 2.3, 3.4], dtype=np.float64),
418              "integer": np.array([3, 4, 5], dtype=np.int32),
419              "long": np.array([3, 4, 5], dtype=np.int64),
420              "binary": [bytes([1, 2, 3]), bytes([4, 5]), bytes([6])],
421              "date_string": ["2018-02-03", "1996-03-02", "2021-03-05"],
422          },
423          columns=[
424              "boolean",
425              "string",
426              "float",
427              "double",
428              "integer",
429              "long",
430              "binary",
431              "date_string",
432          ],
433      )
434  
435      jsonable_df = pd.DataFrame(source, copy=True)
436      jsonable_df["binary"] = jsonable_df["binary"].map(base64.b64encode)
437      schema = Schema([
438          ColSpec("boolean", "boolean"),
439          ColSpec("string", "string"),
440          ColSpec("float", "float"),
441          ColSpec("double", "double"),
442          ColSpec("integer", "integer"),
443          ColSpec("long", "long"),
444          ColSpec("binary", "binary"),
445          ColSpec("string", "date_string"),
446      ])
447      parsed = dataframe_from_raw_json(
448          jsonable_df.to_json(orient="split"), pandas_orient="split", schema=schema
449      )
450      pd.testing.assert_frame_equal(parsed, source)
451      parsed = dataframe_from_raw_json(
452          jsonable_df.to_json(orient="records"), pandas_orient="records", schema=schema
453      )
454      pd.testing.assert_frame_equal(parsed, source)
455      # try parsing with tensor schema
456      tensor_schema = Schema([
457          TensorSpec(np.dtype("bool"), [-1], "boolean"),
458          TensorSpec(np.dtype("str"), [-1], "string"),
459          TensorSpec(np.dtype("float32"), [-1], "float"),
460          TensorSpec(np.dtype("float64"), [-1], "double"),
461          TensorSpec(np.dtype("int32"), [-1], "integer"),
462          TensorSpec(np.dtype("int64"), [-1], "long"),
463          TensorSpec(np.dtype(bytes), [-1], "binary"),
464      ])
465      parsed = dataframe_from_raw_json(
466          jsonable_df.to_json(orient="split"), pandas_orient="split", schema=tensor_schema
467      )
468  
469      # NB: tensor schema does not automatically decode base64 encoded bytes.
470      pd.testing.assert_frame_equal(parsed, jsonable_df)
471      parsed = dataframe_from_raw_json(
472          jsonable_df.to_json(orient="records"), pandas_orient="records", schema=tensor_schema
473      )
474  
475      # NB: tensor schema does not automatically decode base64 encoded bytes.
476      pd.testing.assert_frame_equal(parsed, jsonable_df)
477  
478      # Test parse with TensorSchema with a single tensor
479      tensor_schema = Schema([TensorSpec(np.dtype("float32"), [-1, 3])])
480      source = pd.DataFrame(
481          {
482              "a": np.array([1, 2, 3], dtype=np.float32),
483              "b": np.array([4.1, 5.2, 6.3], dtype=np.float32),
484              "c": np.array([7, 8, 9], dtype=np.float32),
485          },
486          columns=["a", "b", "c"],
487      )
488      pd.testing.assert_frame_equal(
489          source,
490          dataframe_from_raw_json(
491              source.to_json(orient="split"), pandas_orient="split", schema=tensor_schema
492          ),
493      )
494      pd.testing.assert_frame_equal(
495          source,
496          dataframe_from_raw_json(
497              source.to_json(orient="records"), pandas_orient="records", schema=tensor_schema
498          ),
499      )
500  
501      schema = Schema([ColSpec("datetime", "datetime")])
502      parsed = dataframe_from_raw_json(
503          """
504  [
505      {"datetime": "2022-01-01T00:00:00"},
506      {"datetime": "2022-01-02T03:04:05"}
507  ]
508      """,
509          pandas_orient="records",
510          schema=schema,
511      )
512      expected = pd.DataFrame(
513          {
514              "datetime": pd.to_datetime([
515                  "2022-01-01T00:00:00",
516                  "2022-01-02T03:04:05",
517              ])
518          },
519      )
520      pd.testing.assert_frame_equal(parsed, expected)
521  
522  
523  @pytest.mark.parametrize(
524      ("dt", "expected"),
525      [
526          (datetime.datetime(2022, 1, 1), '"2022-01-01T00:00:00"'),
527          (datetime.datetime(2022, 1, 2, 3, 4, 5), '"2022-01-02T03:04:05"'),
528          (datetime.date(2022, 1, 1), '"2022-01-01"'),
529          (datetime.time(0, 0, 0), '"00:00:00"'),
530          (pd.Timestamp(2022, 1, 1), '"2022-01-01T00:00:00"'),
531      ],
532  )
533  def test_datetime_encoder(dt, expected):
534      assert json.dumps(dt, cls=_CustomJsonEncoder) == expected
535  
536  
537  @pytest.mark.parametrize(
538      ("dataframe", "schema", "expected"),
539      [
540          (
541              pd.DataFrame(columns=["foo"], data=[1, 2, 3]),
542              Schema([TensorSpec(np.dtype("float64"), [-1], "foo")]),
543              np.dtype("float64"),
544          ),
545          (
546              pd.DataFrame(columns=["foo"], data=[[[1, 2, 3]], [[4, 5, 6]]]),
547              Schema([TensorSpec(np.dtype("float64"), [-1, 1], "foo")]),
548              np.dtype("object"),
549          ),
550          (
551              pd.DataFrame(index=[1, 2, 3], columns=["foo"], data=[1, 2, 3]),
552              Schema([TensorSpec(np.dtype("float64"), [-1], "foo")]),
553              np.dtype("float64"),
554          ),
555          (
556              pd.DataFrame(columns=["foo"], data=[1, 2, 3]),
557              Schema([ColSpec("double", "foo")]),
558              np.dtype("float64"),
559          ),
560      ],
561  )
562  def test_cast_df_types_according_to_schema_success(dataframe, schema, expected):
563      casted_pdf = cast_df_types_according_to_schema(dataframe, schema)
564      assert casted_pdf["foo"].dtype == expected
565  
566  
567  @pytest.mark.parametrize(
568      ("dataframe", "schema", "error_message"),
569      [
570          (
571              pd.DataFrame(columns=["foo"], data=[1, 2, 3]),
572              Schema([ColSpec("binary", "foo")]),
573              r"TypeError\('encoding without a string argument'\)",
574          ),
575          (
576              pd.DataFrame(columns=["foo"], data=["a", "b", "c"]),
577              Schema([ColSpec("double", "foo")]),
578              r'ValueError\("could not convert string to float: \'a\'"\)',
579          ),
580      ],
581  )
582  def test_cast_df_types_according_to_schema_error_message(dataframe, schema, error_message):
583      with pytest.raises(MlflowFailedTypeConversion, match=error_message):
584          cast_df_types_according_to_schema(dataframe, schema)
585  
586  
587  @pytest.mark.parametrize(
588      ("data", "schema", "instances_data"),
589      [
590          ({"query": "sentence"}, Schema([ColSpec(DataType.string, name="query")]), None),
591          (
592              {"query": ["sentence_1", "sentence_2"]},
593              Schema([ColSpec(Array(DataType.string), name="query")]),
594              None,
595          ),
596          (
597              {"query": ["sentence_1", "sentence_2"], "table": "some_table"},
598              Schema([
599                  ColSpec(Array(DataType.string), name="query"),
600                  ColSpec(DataType.string, name="table"),
601              ]),
602              None,
603          ),
604          (
605              {"query": [{"name": "value", "age": 10}, {"name": "value"}], "table": ["some_table"]},
606              Schema([
607                  ColSpec(
608                      Array(
609                          Object([
610                              Property("name", DataType.string),
611                              Property("age", DataType.long, required=False),
612                          ])
613                      ),
614                      name="query",
615                  ),
616                  ColSpec(Array(DataType.string), name="table"),
617              ]),
618              None,
619          ),
620          (
621              [{"query": "sentence"}, {"query": "sentence"}],
622              Schema([ColSpec(DataType.string, name="query")]),
623              {"query": ["sentence", "sentence"]},
624          ),
625          (
626              [
627                  {"query": ["sentence_1", "sentence_2"], "table": "some_table"},
628                  {"query": ["sentence_1", "sentence_2"]},
629              ],
630              Schema([
631                  ColSpec(Array(DataType.string), name="query"),
632                  ColSpec(DataType.string, name="table", required=False),
633              ]),
634              {
635                  "query": [["sentence_1", "sentence_2"], ["sentence_1", "sentence_2"]],
636                  "table": ["some_table"],
637              },
638          ),
639          (
640              [
641                  {"query": {"a": "sentence_1", "b": "sentence_2"}, "table": "some_table"},
642                  {"query": {"a": "sentence_1"}, "table": "some_table"},
643              ],
644              Schema([
645                  ColSpec(
646                      Object([
647                          Property("a", DataType.string),
648                          Property("b", DataType.string, required=False),
649                      ]),
650                      name="query",
651                  ),
652                  ColSpec(DataType.string, name="table"),
653              ]),
654              {
655                  "query": [{"a": "sentence_1", "b": "sentence_2"}, {"a": "sentence_1"}],
656                  "table": ["some_table", "some_table"],
657              },
658          ),
659          (
660              {
661                  "query": [{"name": "value", "age": "10"}, {"name": "value"}],
662                  "table": {"k": "some_table"},
663                  "data": {"k1": ["a", "b"], "k2": ["c"]},
664              },
665              Schema([
666                  ColSpec(
667                      Array(Map(value_type=DataType.string)),
668                      name="query",
669                  ),
670                  ColSpec(Map(value_type=DataType.string), name="table"),
671                  ColSpec(Map(value_type=Array(DataType.string)), name="data"),
672              ]),
673              None,
674          ),
675      ],
676  )
677  def test_parse_tf_serving_input_for_dictionaries_and_lists_and_maps(data, schema, instances_data):
678      np.testing.assert_equal(parse_tf_serving_input({"inputs": data}, schema), data)
679      if instances_data is None:
680          np.testing.assert_equal(parse_tf_serving_input({"instances": data}, schema), data)
681      else:
682          np.testing.assert_equal(parse_tf_serving_input({"instances": data}, schema), instances_data)
683      df = pd.DataFrame(data) if isinstance(data, list) else pd.DataFrame([data])
684      df_split = df.to_dict(orient="split")
685      pd.testing.assert_frame_equal(dataframe_from_parsed_json(df_split, "split", schema), df)
686      df_records = df.to_dict(orient="records")
687      pd.testing.assert_frame_equal(dataframe_from_parsed_json(df_records, "records", schema), df)