"""Tests for MLflow model signatures (``test_signature.py``).

Covers ``ModelSignature`` construction, equality and JSON round-tripping,
``infer_signature`` on numpy / pandas / Spark / nested-list / dataclass
inputs, and ``set_signature`` on logged and saved sklearn models.
"""

import json
from dataclasses import asdict, dataclass

import numpy as np
import pandas as pd
import pydantic
import pyspark
import pytest
from sklearn.ensemble import RandomForestRegressor

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelSignature, infer_signature, rag_signatures, set_signature
from mlflow.models.model import get_model_info
from mlflow.types import DataType
from mlflow.types.schema import (
    Array,
    ColSpec,
    ParamSchema,
    ParamSpec,
    Schema,
    TensorSpec,
    convert_dataclass_to_schema,
)
from mlflow.types.utils import InvalidDataForSignatureInferenceError


def test_model_signature_with_colspec():
    """Column-based signatures: equality, inequality on dtype, and JSON round trip."""
    signature1 = ModelSignature(
        inputs=Schema([ColSpec(DataType.boolean), ColSpec(DataType.binary)]),
        outputs=Schema([
            ColSpec(name=None, type=DataType.double),
            ColSpec(name=None, type=DataType.double),
        ]),
    )
    signature2 = ModelSignature(
        inputs=Schema([ColSpec(DataType.boolean), ColSpec(DataType.binary)]),
        outputs=Schema([
            ColSpec(name=None, type=DataType.double),
            ColSpec(name=None, type=DataType.double),
        ]),
    )
    assert signature1 == signature2
    # Only the first output column's dtype differs (float vs double).
    signature3 = ModelSignature(
        inputs=Schema([ColSpec(DataType.boolean), ColSpec(DataType.binary)]),
        outputs=Schema([
            ColSpec(name=None, type=DataType.float),
            ColSpec(name=None, type=DataType.double),
        ]),
    )
    assert signature3 != signature1
    as_json = json.dumps(signature1.to_dict())
    signature4 = ModelSignature.from_dict(json.loads(as_json))
    assert signature1 == signature4
    # Round trip must also work when the signature has no outputs.
    signature5 = ModelSignature(
        inputs=Schema([ColSpec(DataType.boolean), ColSpec(DataType.binary)]), outputs=None
    )
    as_json = json.dumps(signature5.to_dict())
    signature6 = ModelSignature.from_dict(json.loads(as_json))
    assert signature5 == signature6


def test_model_signature_with_tensorspec():
    """Tensor-based signatures: equality, dtype/name mismatches, and JSON round trip."""
    signature1 = ModelSignature(
        inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
        outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10))]),
    )
    signature2 = ModelSignature(
        inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
        outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10))]),
    )
    assert signature1 == signature2
    # Single type mismatch: identical to signature1 except the output dtype.
    signature3 = ModelSignature(
        inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
        outputs=Schema([TensorSpec(np.dtype("int"), (-1, 10))]),
    )
    assert signature3 != signature1
    # Name mismatch: identical to signature1 except the output tensor name, so
    # comparing against signature1 isolates the name. (signature3 also differs
    # by dtype, so comparing against it alone cannot detect a name-only mismatch.)
    signature4 = ModelSignature(
        inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
        outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10), "mismatch")]),
    )
    assert signature1 != signature4
    assert signature3 != signature4
    as_json = json.dumps(signature1.to_dict())
    signature5 = ModelSignature.from_dict(json.loads(as_json))
    assert signature1 == signature5

    # Test with name
    signature6 = ModelSignature(
        inputs=Schema([
            TensorSpec(np.dtype("float"), (-1, 28, 28), name="image"),
            TensorSpec(np.dtype("int"), (-1, 10), name="metadata"),
        ]),
        outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10), name="outputs")]),
    )
    signature7 = ModelSignature(
        inputs=Schema([
            TensorSpec(np.dtype("float"), (-1, 28, 28), name="image"),
            TensorSpec(np.dtype("int"), (-1, 10), name="metadata"),
        ]),
        outputs=Schema([TensorSpec(np.dtype("float"), (-1, 10), name="outputs")]),
    )
    assert signature6 == signature7
    assert signature1 != signature6

    # Test w/o output
    signature8 = ModelSignature(
        inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]), outputs=None
    )
    as_json = json.dumps(signature8.to_dict())
    signature9 = ModelSignature.from_dict(json.loads(as_json))
    assert signature8 == signature9


def test_model_signature_with_colspec_and_tensorspec():
    """A ColSpec-based schema never equals a TensorSpec-based one (checked both ways)."""
    signature1 = ModelSignature(inputs=Schema([ColSpec(DataType.double)]))
    signature2 = ModelSignature(inputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]))
    assert signature1 != signature2
    assert signature2 != signature1

    signature3 = ModelSignature(
        inputs=Schema([ColSpec(DataType.double)]),
        outputs=Schema([TensorSpec(np.dtype("float"), (-1, 28, 28))]),
    )
    signature4 = ModelSignature(
        inputs=Schema([ColSpec(DataType.double)]),
        outputs=Schema([ColSpec(DataType.double)]),
    )
    assert signature3 != signature4
    assert signature4 != signature3


def test_signature_inference_infers_input_and_output_as_expected():
    """Outputs are inferred only when a model output is passed to infer_signature."""
    sig0 = infer_signature(np.array([1]))
    assert sig0.inputs is not None
    assert sig0.outputs is None
    sig1 = infer_signature(np.array([1]), np.array([1]))
    assert sig1.inputs == sig0.inputs
    # The same array used as output yields the same schema as when used as input.
    assert sig1.outputs == sig0.inputs


def test_infer_signature_on_nested_array():
    """Nested lists / numpy arrays inside records are inferred as nested ``Array`` types."""
    signature = infer_signature(
        model_input=[{"queries": [["a", "b", "c"], ["d", "e"], []]}],
        model_output=[{"answers": [["f", "g"], ["h"]]}],
    )
    assert signature.inputs == Schema([ColSpec(Array(Array(DataType.string)), name="queries")])
    assert signature.outputs == Schema([ColSpec(Array(Array(DataType.string)), name="answers")])

    signature = infer_signature(
        model_input=[
            {
                "inputs": [
                    np.array([["a", "b"], ["c", "d"]]),
                    np.array([["e", "f"], ["g", "h"]]),
                ]
            }
        ],
        model_output=[{"outputs": [np.int32(5), np.int32(6)]}],
    )
    assert signature.inputs == Schema([
        ColSpec(Array(Array(Array(DataType.string))), name="inputs")
    ])
    assert signature.outputs == Schema([ColSpec(Array(DataType.integer), name="outputs")])


def test_infer_signature_on_list_of_dictionaries():
    """Each dict key becomes a named column; list-valued keys become ``Array`` columns."""
    signature = infer_signature(
        model_input=[{"query": "test query"}],
        model_output=[
            {
                "output": "Output from the LLM",
                "candidate_ids": ["412", "1233"],
                "candidate_sources": ["file1.md", "file201.md"],
            }
        ],
    )
    assert signature.inputs == Schema([ColSpec(DataType.string, name="query")])
    assert signature.outputs == Schema([
        ColSpec(DataType.string, name="output"),
        ColSpec(Array(DataType.string), name="candidate_ids"),
        ColSpec(Array(DataType.string), name="candidate_sources"),
    ])


def test_signature_inference_infers_datime_types_as_expected():
    """pandas datetimes and Spark timestamp/date columns all infer as ``DataType.datetime``."""
    col_name = "datetime_col"
    test_datetime = np.datetime64("2021-01-01")
    test_series = pd.Series(pd.to_datetime([test_datetime]))
    test_df = test_series.to_frame(col_name)

    signature = infer_signature(test_series)
    assert signature.inputs == Schema([ColSpec(DataType.datetime)])

    signature = infer_signature(test_df)
    assert signature.inputs == Schema([ColSpec(DataType.datetime, name=col_name)])

    with pyspark.sql.SparkSession.builder.getOrCreate() as spark:
        spark_df = spark.range(1).selectExpr(
            "current_timestamp() as timestamp", "current_date() as date"
        )
        signature = infer_signature(spark_df)
        assert signature.inputs == Schema([
            ColSpec(DataType.datetime, name="timestamp"),
            ColSpec(DataType.datetime, name="date"),
        ])


def test_set_signature_to_logged_model():
    """set_signature attaches a signature to a model logged without one."""
    artifact_path = "regr-model"
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(RandomForestRegressor(), name=artifact_path)
    signature = infer_signature(np.array([1]))
    set_signature(model_info.model_uri, signature)
    model_info = get_model_info(model_info.model_uri)
    assert model_info.signature == signature


def test_set_signature_to_saved_model(tmp_path):
    """set_signature works on a model saved to a local directory path."""
    model_path = str(tmp_path)
    mlflow.sklearn.save_model(
        RandomForestRegressor(),
        model_path,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
    )
    signature = infer_signature(np.array([1]))
    set_signature(model_path, signature)
    assert Model.load(model_path).signature == signature


def test_set_signature_overwrite():
    """set_signature replaces a signature that was supplied at logging time."""
    artifact_path = "regr-model"
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            RandomForestRegressor(),
            name=artifact_path,
            signature=infer_signature(np.array([1])),
        )
    new_signature = infer_signature(np.array([1]), np.array([1]))
    set_signature(model_info.model_uri, new_signature)
    model_info = get_model_info(model_info.model_uri)
    assert model_info.signature == new_signature


def test_cannot_set_signature_on_models_scheme_uris():
    """set_signature rejects ``models:/`` URIs with an MlflowException."""
    signature = infer_signature(np.array([1]))
    with pytest.raises(
        MlflowException,
        match="Model URIs with the `models:/<name>/<version>` scheme are not supported.",
    ):
        set_signature("models:/dummy_model@champion", signature)


def test_signature_construction():
    """A signature may be built from any single one of inputs / outputs / params."""
    signature = ModelSignature(inputs=Schema([ColSpec(DataType.binary)]))
    assert signature.to_dict() == {
        "inputs": '[{"type": "binary", "required": true}]',
        "outputs": None,
        "params": None,
    }

    signature = ModelSignature(outputs=Schema([ColSpec(DataType.double)]))
    assert signature.to_dict() == {
        "inputs": None,
        "outputs": '[{"type": "double", "required": true}]',
        "params": None,
    }

    signature = ModelSignature(params=ParamSchema([ParamSpec("param1", DataType.string, "test")]))
    assert signature.to_dict() == {
        "inputs": None,
        "outputs": None,
        "params": '[{"name": "param1", "default": "test", "shape": null, "type": "string"}]',
    }


def test_signature_with_errors():
    """Invalid inputs raise TypeError; an entirely empty signature raises ValueError."""
    with pytest.raises(
        TypeError,
        match=r"inputs must be either None, mlflow.models.signature.Schema, or a dataclass",
    ):
        ModelSignature(inputs=1)

    with pytest.raises(
        ValueError, match=r"At least one of inputs, outputs or params must be provided"
    ):
        ModelSignature()


def test_signature_for_rag():
    """The RAG request/response dataclasses serialize to the expected schema JSON."""
    signature = ModelSignature(
        inputs=rag_signatures.ChatCompletionRequest(),
        outputs=rag_signatures.ChatCompletionResponse(),
    )
    signature_dict = signature.to_dict()
    assert signature_dict == {
        "inputs": (
            '[{"type": "array", "items": {"type": "object", "properties": '
            '{"content": {"type": "string", "required": true}, '
            '"role": {"type": "string", "required": true}}}, '
            '"name": "messages", "required": true}]'
        ),
        "outputs": (
            '[{"type": "array", "items": {"type": "object", "properties": '
            '{"finish_reason": {"type": "string", "required": true}, '
            '"index": {"type": "long", "required": true}, '
            '"message": {"type": "object", "properties": '
            '{"content": {"type": "string", "required": true}, '
            '"role": {"type": "string", "required": true}}, '
            '"required": true}}}, "name": "choices", "required": true}, '
            '{"type": "string", "name": "object", "required": true}]'
        ),
        "params": None,
    }


def test_infer_signature_and_convert_dataclass_to_schema_for_rag():
    """Inferring from ``asdict`` of the RAG dataclasses matches convert_dataclass_to_schema."""
    inferred_signature = infer_signature(
        asdict(rag_signatures.ChatCompletionRequest()),
        asdict(rag_signatures.ChatCompletionResponse()),
    )
    input_schema = convert_dataclass_to_schema(rag_signatures.ChatCompletionRequest())
    output_schema = convert_dataclass_to_schema(rag_signatures.ChatCompletionResponse())
    assert inferred_signature.inputs == input_schema
    assert inferred_signature.outputs == output_schema


def test_infer_signature_with_dataclass():
    """Inferring directly from dataclass instances matches convert_dataclass_to_schema."""
    inferred_signature = infer_signature(
        rag_signatures.ChatCompletionRequest(),
        rag_signatures.ChatCompletionResponse(),
    )
    input_schema = convert_dataclass_to_schema(rag_signatures.ChatCompletionRequest())
    output_schema = convert_dataclass_to_schema(rag_signatures.ChatCompletionResponse())
    assert inferred_signature.inputs == input_schema
    assert inferred_signature.outputs == output_schema


@dataclass
class CustomInput:
    """Minimal nested input payload used to extend the RAG request dataclass."""

    id: int = 0


@dataclass
class CustomOutput:
    """Minimal nested output payload used to extend the RAG response dataclass."""

    id: int = 0


@dataclass
class FlexibleChatCompletionRequest(rag_signatures.ChatCompletionRequest):
    """ChatCompletionRequest subclass with an extra optional nested dataclass field."""

    custom_input: CustomInput | None = None


@dataclass
class FlexibleChatCompletionResponse(rag_signatures.ChatCompletionResponse):
    """ChatCompletionResponse subclass with an extra optional nested dataclass field."""

    custom_output: CustomOutput | None = None


def test_infer_signature_with_optional_and_child_dataclass():
    """Optional fields added by a subclass are non-required; inherited fields remain."""
    inferred_signature = infer_signature(
        FlexibleChatCompletionRequest(),
        FlexibleChatCompletionResponse(),
    )
    custom_input_schema = next(
        schema for schema in inferred_signature.inputs.to_dict() if schema["name"] == "custom_input"
    )
    assert custom_input_schema["required"] is False
    assert "id" in custom_input_schema["properties"]
    # The inherited ``messages`` field from the base request must still be present.
    assert any(
        schema for schema in inferred_signature.inputs.to_dict() if schema["name"] == "messages"
    )


def test_infer_signature_for_pydantic_objects_error():
    """infer_signature refuses input examples containing Pydantic model instances."""

    class Message(pydantic.BaseModel):
        content: str
        role: str

    m = Message(content="test", role="user")
    with pytest.raises(
        InvalidDataForSignatureInferenceError,
        match=r"MLflow does not support inferring model signature from "
        r"input example with Pydantic objects",
    ):
        infer_signature([m])