/ tests / models / test_signature.py
test_signature.py
  1  import json
  2  from dataclasses import asdict, dataclass
  3  
  4  import numpy as np
  5  import pandas as pd
  6  import pydantic
  7  import pyspark
  8  import pytest
  9  from sklearn.ensemble import RandomForestRegressor
 10  
 11  import mlflow
 12  from mlflow.exceptions import MlflowException
 13  from mlflow.models import Model, ModelSignature, infer_signature, rag_signatures, set_signature
 14  from mlflow.models.model import get_model_info
 15  from mlflow.types import DataType
 16  from mlflow.types.schema import (
 17      Array,
 18      ColSpec,
 19      ParamSchema,
 20      ParamSpec,
 21      Schema,
 22      TensorSpec,
 23      convert_dataclass_to_schema,
 24  )
 25  from mlflow.types.utils import InvalidDataForSignatureInferenceError
 26  
 27  
 28  def test_model_signature_with_colspec():
 29      signature1 = ModelSignature(
 30          inputs=Schema([ColSpec(DataType.boolean), ColSpec(DataType.binary)]),
 31          outputs=Schema([
 32              ColSpec(name=None, type=DataType.double),
 33              ColSpec(name=None, type=DataType.double),
 34          ]),
 35      )
 36      signature2 = ModelSignature(
 37          inputs=Schema([ColSpec(DataType.boolean), ColSpec(DataType.binary)]),
 38          outputs=Schema([
 39              ColSpec(name=None, type=DataType.double),
 40              ColSpec(name=None, type=DataType.double),
 41          ]),
 42      )
 43      assert signature1 == signature2
 44      signature3 = ModelSignature(
 45          inputs=Schema([ColSpec(DataType.boolean), ColSpec(DataType.binary)]),
 46          outputs=Schema([
 47              ColSpec(name=None, type=DataType.float),
 48              ColSpec(name=None, type=DataType.double),
 49          ]),
 50      )
 51      assert signature3 != signature1
 52      as_json = json.dumps(signature1.to_dict())
 53      signature4 = ModelSignature.from_dict(json.loads(as_json))
 54      assert signature1 == signature4
 55      signature5 = ModelSignature(
 56          inputs=Schema([ColSpec(DataType.boolean), ColSpec(DataType.binary)]), outputs=None
 57      )
 58      as_json = json.dumps(signature5.to_dict())
 59      signature6 = ModelSignature.from_dict(json.loads(as_json))
 60      assert signature5 == signature6
 61  
 62  
 63  def test_model_signature_with_tensorspec():
 64      signature1 = ModelSignature(
 65          inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
 66          outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10))]),
 67      )
 68      signature2 = ModelSignature(
 69          inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
 70          outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10))]),
 71      )
 72      # Single type mismatch
 73      assert signature1 == signature2
 74      signature3 = ModelSignature(
 75          inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
 76          outputs=Schema([TensorSpec(np.dtype("int"), (-1, 10))]),
 77      )
 78      assert signature3 != signature1
 79      # Name mismatch
 80      signature4 = ModelSignature(
 81          inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
 82          outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10), "mismatch")]),
 83      )
 84      assert signature3 != signature4
 85      as_json = json.dumps(signature1.to_dict())
 86      signature5 = ModelSignature.from_dict(json.loads(as_json))
 87      assert signature1 == signature5
 88  
 89      # Test with name
 90      signature6 = ModelSignature(
 91          inputs=Schema([
 92              TensorSpec(np.dtype("float"), (-1, 28, 28), name="image"),
 93              TensorSpec(np.dtype("int"), (-1, 10), name="metadata"),
 94          ]),
 95          outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10), name="outputs")]),
 96      )
 97      signature7 = ModelSignature(
 98          inputs=Schema([
 99              TensorSpec(np.dtype("float"), (-1, 28, 28), name="image"),
100              TensorSpec(np.dtype("int"), (-1, 10), name="metadata"),
101          ]),
102          outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10), name="outputs")]),
103      )
104      assert signature6 == signature7
105      assert signature1 != signature6
106  
107      # Test w/o output
108      signature8 = ModelSignature(
109          inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]), outputs=None
110      )
111      as_json = json.dumps(signature8.to_dict())
112      signature9 = ModelSignature.from_dict(json.loads(as_json))
113      assert signature8 == signature9
114  
115  
def test_model_signature_with_colspec_and_tensorspec():
    """Signatures built from ColSpec and TensorSpec never compare equal, in either direction."""

    def col_schema():
        return Schema([ColSpec(DataType.double)])

    def tensor_schema():
        return Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))])

    cols_in = ModelSignature(inputs=col_schema())
    tensors_in = ModelSignature(inputs=tensor_schema())
    for left, right in ((cols_in, tensors_in), (tensors_in, cols_in)):
        assert left != right

    # Same inputs, but the outputs differ in spec kind.
    tensor_out = ModelSignature(inputs=col_schema(), outputs=tensor_schema())
    col_out = ModelSignature(inputs=col_schema(), outputs=col_schema())
    for left, right in ((tensor_out, col_out), (col_out, tensor_out)):
        assert left != right
132  
133  
def test_signature_inference_infers_input_and_output_as_expected():
    """infer_signature only populates outputs when a model output example is supplied."""
    input_only = infer_signature(np.array([1]))
    assert input_only.inputs is not None
    assert input_only.outputs is None

    with_output = infer_signature(np.array([1]), np.array([1]))
    # The same array passed as input and output must yield identical schemas on both sides.
    assert with_output.inputs == input_only.inputs
    assert with_output.outputs == input_only.inputs
141  
142  
def test_infer_signature_on_nested_array():
    """Nested lists and lists of numpy arrays are inferred as nested Array column types."""
    ragged = infer_signature(
        model_input=[{"queries": [["a", "b", "c"], ["d", "e"], []]}],
        model_output=[{"answers": [["f", "g"], ["h"]]}],
    )
    for schema, column in ((ragged.inputs, "queries"), (ragged.outputs, "answers")):
        assert schema == Schema([ColSpec(Array(Array(DataType.string)), name=column)])

    # A list of 2-D string arrays adds one more level of nesting; numpy int32 maps to integer.
    matrices = [
        np.array([["a", "b"], ["c", "d"]]),
        np.array([["e", "f"], ["g", "h"]]),
    ]
    array_valued = infer_signature(
        model_input=[{"inputs": matrices}],
        model_output=[{"outputs": [np.int32(5), np.int32(6)]}],
    )
    expected_inputs = Schema([ColSpec(Array(Array(Array(DataType.string))), name="inputs")])
    assert array_valued.inputs == expected_inputs
    assert array_valued.outputs == Schema([ColSpec(Array(DataType.integer), name="outputs")])
166  
167  
def test_infer_signature_on_list_of_dictionaries():
    """A list of flat dicts yields one column per key; list-of-string values become Arrays."""
    llm_output = {
        "output": "Output from the LLM",
        "candidate_ids": ["412", "1233"],
        "candidate_sources": ["file1.md", "file201.md"],
    }
    signature = infer_signature(model_input=[{"query": "test query"}], model_output=[llm_output])

    assert signature.inputs == Schema([ColSpec(DataType.string, name="query")])

    array_columns = [
        ColSpec(Array(DataType.string), name=key)
        for key in ("candidate_ids", "candidate_sources")
    ]
    expected_outputs = Schema([ColSpec(DataType.string, name="output"), *array_columns])
    assert signature.outputs == expected_outputs
185  
186  
def test_signature_inference_infers_datime_types_as_expected():
    """Datetime data from pandas and Spark (timestamp and date columns) maps to DataType.datetime."""
    column = "datetime_col"
    series = pd.Series(pd.to_datetime([np.datetime64("2021-01-01")]))
    frame = series.to_frame(column)

    assert infer_signature(series).inputs == Schema([ColSpec(DataType.datetime)])
    assert infer_signature(frame).inputs == Schema([ColSpec(DataType.datetime, name=column)])

    with pyspark.sql.SparkSession.builder.getOrCreate() as spark:
        spark_df = spark.range(1).selectExpr(
            "current_timestamp() as timestamp", "current_date() as date"
        )
        spark_signature = infer_signature(spark_df)
        expected = Schema([
            ColSpec(DataType.datetime, name=name) for name in ("timestamp", "date")
        ])
        assert spark_signature.inputs == expected
208  
209  
def test_set_signature_to_logged_model():
    """set_signature attaches a signature to a model that was logged without one."""
    with mlflow.start_run():
        logged = mlflow.sklearn.log_model(RandomForestRegressor(), name="regr-model")

    signature = infer_signature(np.array([1]))
    set_signature(logged.model_uri, signature)

    reloaded = get_model_info(logged.model_uri)
    assert reloaded.signature == signature
218  
219  
def test_set_signature_to_saved_model(tmp_path):
    """set_signature also works on a local model directory produced by save_model."""
    model_dir = str(tmp_path)
    mlflow.sklearn.save_model(
        RandomForestRegressor(),
        model_dir,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
    )

    expected = infer_signature(np.array([1]))
    set_signature(model_dir, expected)

    reloaded = Model.load(model_dir)
    assert reloaded.signature == expected
230  
231  
def test_set_signature_overwrite():
    """set_signature replaces a signature that was already recorded when the model was logged."""
    artifact_path = "regr-model"
    original_signature = infer_signature(np.array([1]))
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            RandomForestRegressor(),
            name=artifact_path,
            signature=original_signature,
        )

    # Guard against a vacuous test: the logged model must start out carrying the original
    # signature, and the replacement must actually differ from it.
    assert get_model_info(model_info.model_uri).signature == original_signature
    new_signature = infer_signature(np.array([1]), np.array([1]))
    assert new_signature != original_signature

    set_signature(model_info.model_uri, new_signature)
    model_info = get_model_info(model_info.model_uri)
    assert model_info.signature == new_signature
244  
245  
def test_cannot_set_signature_on_models_scheme_uris():
    """set_signature rejects ``models:/`` registry URIs with an MlflowException."""
    signature = infer_signature(np.array([1]))
    # pytest.raises(match=...) applies re.search, so the trailing period is escaped to match a
    # literal "." rather than any character.
    with pytest.raises(
        MlflowException,
        match=r"Model URIs with the `models:/<name>/<version>` scheme are not supported\.",
    ):
        set_signature("models:/dummy_model@champion", signature)
253  
254  
def test_signature_construction():
    """A signature built from only one of inputs/outputs/params serializes the others as None."""

    def assert_serializes_to(signature, *, inputs=None, outputs=None, params=None):
        assert signature.to_dict() == {"inputs": inputs, "outputs": outputs, "params": params}

    assert_serializes_to(
        ModelSignature(inputs=Schema([ColSpec(DataType.binary)])),
        inputs='[{"type": "binary", "required": true}]',
    )
    assert_serializes_to(
        ModelSignature(outputs=Schema([ColSpec(DataType.double)])),
        outputs='[{"type": "double", "required": true}]',
    )
    assert_serializes_to(
        ModelSignature(params=ParamSchema([ParamSpec("param1", DataType.string, "test")])),
        params='[{"name": "param1", "default": "test", "shape": null, "type": "string"}]',
    )
276  
277  
def test_signature_with_errors():
    """ModelSignature type-checks ``inputs`` and requires at least one of inputs/outputs/params."""
    # ``match`` is a regular expression searched with re.search; the dots in the dotted class
    # path are escaped so they only match literal periods.
    with pytest.raises(
        TypeError,
        match=r"inputs must be either None, mlflow\.models\.signature\.Schema, or a dataclass",
    ):
        ModelSignature(inputs=1)

    with pytest.raises(
        ValueError, match=r"At least one of inputs, outputs or params must be provided"
    ):
        ModelSignature()
289  
290  
def test_signature_for_rag():
    """The RAG chat-completion dataclasses serialize to the expected JSON column specs."""
    expected_inputs = (
        '[{"type": "array", "items": {"type": "object", "properties": '
        '{"content": {"type": "string", "required": true}, '
        '"role": {"type": "string", "required": true}}}, '
        '"name": "messages", "required": true}]'
    )
    expected_outputs = (
        '[{"type": "array", "items": {"type": "object", "properties": '
        '{"finish_reason": {"type": "string", "required": true}, '
        '"index": {"type": "long", "required": true}, '
        '"message": {"type": "object", "properties": '
        '{"content": {"type": "string", "required": true}, '
        '"role": {"type": "string", "required": true}}, '
        '"required": true}}}, "name": "choices", "required": true}, '
        '{"type": "string", "name": "object", "required": true}]'
    )

    serialized = ModelSignature(
        inputs=rag_signatures.ChatCompletionRequest(),
        outputs=rag_signatures.ChatCompletionResponse(),
    ).to_dict()

    assert serialized == {
        "inputs": expected_inputs,
        "outputs": expected_outputs,
        "params": None,
    }
316  
317  
def test_infer_signature_and_convert_dataclass_to_schema_for_rag():
    """Inferring from the dict form of the RAG dataclasses matches converting them directly."""
    request_cls = rag_signatures.ChatCompletionRequest
    response_cls = rag_signatures.ChatCompletionResponse

    inferred = infer_signature(asdict(request_cls()), asdict(response_cls()))
    expected_inputs = convert_dataclass_to_schema(request_cls())
    expected_outputs = convert_dataclass_to_schema(response_cls())

    assert inferred.inputs == expected_inputs
    assert inferred.outputs == expected_outputs
327  
328  
def test_infer_signature_with_dataclass():
    """infer_signature accepts dataclass instances and agrees with convert_dataclass_to_schema."""
    inferred = infer_signature(
        rag_signatures.ChatCompletionRequest(),
        rag_signatures.ChatCompletionResponse(),
    )
    expected_inputs, expected_outputs = (
        convert_dataclass_to_schema(cls())
        for cls in (rag_signatures.ChatCompletionRequest, rag_signatures.ChatCompletionResponse)
    )
    assert inferred.inputs == expected_inputs
    assert inferred.outputs == expected_outputs
338  
339  
@dataclass
class CustomInput:
    """Small nested dataclass used as the optional ``custom_input`` field of a request."""

    # Integer field with a default, so the dataclass can be instantiated without arguments.
    id: int = 0
343  
344  
@dataclass
class CustomOutput:
    """Small nested dataclass used as the optional ``custom_output`` field of a response."""

    # Integer field with a default, so the dataclass can be instantiated without arguments.
    id: int = 0
348  
349  
@dataclass
class FlexibleChatCompletionRequest(rag_signatures.ChatCompletionRequest):
    """ChatCompletionRequest subclass that adds an optional nested ``CustomInput`` field.

    Exercises signature inference on a child dataclass with an ``X | None`` field.
    """

    # Optional nested dataclass; defaults to None so the base request fields stay the only
    # required ones.
    custom_input: CustomInput | None = None
353  
354  
@dataclass
class FlexibleChatCompletionResponse(rag_signatures.ChatCompletionResponse):
    """ChatCompletionResponse subclass that adds an optional nested ``CustomOutput`` field.

    Exercises signature inference on a child dataclass with an ``X | None`` field.
    """

    # Optional nested dataclass; defaults to None so the base response fields stay the only
    # required ones.
    custom_output: CustomOutput | None = None
358  
359  
def test_infer_signature_with_optional_and_child_dataclass():
    """Child dataclasses keep the inherited columns and infer ``X | None`` fields as optional."""
    inferred_signature = infer_signature(
        FlexibleChatCompletionRequest(),
        FlexibleChatCompletionResponse(),
    )

    # Index column specs by name: a missing column then fails with a clear assertion/KeyError
    # instead of a bare StopIteration leaking out of next().
    input_specs = {spec.get("name"): spec for spec in inferred_signature.inputs.to_dict()}
    assert "messages" in input_specs
    custom_input_schema = input_specs["custom_input"]
    assert custom_input_schema["required"] is False
    assert "id" in custom_input_schema["properties"]

    # The response side adds an optional field in the same way and must behave identically.
    output_specs = {spec.get("name"): spec for spec in inferred_signature.outputs.to_dict()}
    assert "choices" in output_specs
    custom_output_schema = output_specs["custom_output"]
    assert custom_output_schema["required"] is False
    assert "id" in custom_output_schema["properties"]
373  
374  
def test_infer_signature_for_pydantic_objects_error():
    """Pydantic model instances inside an input example are rejected with a dedicated error."""

    class Message(pydantic.BaseModel):
        content: str
        role: str

    message = Message(content="test", role="user")
    expected_error = (
        r"MLflow does not support inferring model signature from "
        r"input example with Pydantic objects"
    )
    with pytest.raises(InvalidDataForSignatureInferenceError, match=expected_error):
        infer_signature([message])