# tests/pyfunc/test_pyfunc_schema_enforcement_pyspark.py
  1  from datetime import datetime
  2  
  3  import pytest
  4  from pyspark.sql import Row, SparkSession
  5  from pyspark.sql.types import (
  6      ArrayType,
  7      BinaryType,
  8      BooleanType,
  9      DateType,
 10      DoubleType,
 11      FloatType,
 12      IntegerType,
 13      LongType,
 14      ShortType,
 15      StringType,
 16      StructField,
 17      StructType,
 18      TimestampType,
 19  )
 20  from pyspark.testing import assertDataFrameEqual
 21  
 22  from mlflow.exceptions import MlflowException
 23  from mlflow.models.utils import _enforce_schema
 24  from mlflow.types import ColSpec, DataType, Schema
 25  from mlflow.types.schema import Array, Object, Property
 26  
 27  
 28  @pytest.fixture(scope="module")
 29  def spark():
 30      with SparkSession.builder.getOrCreate() as spark:
 31          yield spark
 32  
 33  
 34  def test_enforce_schema_spark_dataframe(spark):
 35      spark_df_schema = StructType([
 36          StructField("smallint", ShortType(), True),
 37          StructField("int", IntegerType(), True),
 38          StructField("bigint", LongType(), True),
 39          StructField("float", FloatType(), True),
 40          StructField("double", DoubleType(), True),
 41          StructField("boolean", BooleanType(), True),
 42          StructField("date", DateType(), True),
 43          StructField("timestamp", TimestampType(), True),
 44          StructField("string", StringType(), True),
 45          StructField("binary", BinaryType(), True),
 46      ])
 47  
 48      data = [
 49          (
 50              1,  # smallint
 51              2,  # int
 52              1234567890123456789,  # bigint
 53              1.23,  # float
 54              3.456789,  # double
 55              True,  # boolean
 56              datetime(2020, 1, 1),  # date
 57              datetime.now(),  # timestamp
 58              "example string",  # string
 59              bytearray("example binary", "utf-8"),  # binary
 60          )
 61      ]
 62  
 63      input_schema = Schema([
 64          ColSpec(DataType.integer, "smallint"),
 65          ColSpec(DataType.integer, "int"),
 66          ColSpec(DataType.long, "bigint"),
 67          ColSpec(DataType.float, "float"),
 68          ColSpec(DataType.double, "double"),
 69          ColSpec(DataType.boolean, "boolean"),
 70          ColSpec(DataType.datetime, "date"),
 71          ColSpec(DataType.datetime, "timestamp"),
 72          ColSpec(DataType.string, "string"),
 73          ColSpec(DataType.binary, "binary"),
 74      ])
 75  
 76      input_df = spark.createDataFrame(data, spark_df_schema)
 77      result = _enforce_schema(input_df, input_schema)
 78      assertDataFrameEqual(input_df, result)
 79  
 80  
 81  @pytest.mark.parametrize(
 82      ("spark_df_schema", "data", "input_schema"),
 83      [
 84          (
 85              StructType([StructField("query", ArrayType(StringType()), True)]),
 86              [(["sentence_1", "sentence_2"],)],
 87              Schema([ColSpec(Array(DataType.string), name="query")]),
 88          ),
 89          (
 90              StructType([
 91                  StructField(
 92                      "teststruct",
 93                      StructType([
 94                          StructField("smallint", ShortType(), True),
 95                          StructField("int", IntegerType(), True),
 96                          StructField("bigint", LongType(), True),
 97                          StructField("float", FloatType(), True),
 98                          StructField("double", DoubleType(), True),
 99                          StructField("boolean", BooleanType(), True),
100                          StructField("date", DateType(), True),
101                          StructField("timestamp", TimestampType(), True),
102                          StructField("string", StringType(), True),
103                          StructField("binary", BinaryType(), True),
104                      ]),
105                      True,
106                  )
107              ]),
108              [
109                  Row(
110                      teststruct=Row(
111                          smallint=100,
112                          int=1000,
113                          bigint=10000000000,
114                          float=10.5,
115                          double=20.5,
116                          boolean=True,
117                          date=datetime(2020, 1, 1),
118                          timestamp=datetime.now(),
119                          string="example",
120                          binary=b"binary_data",
121                      )
122                  ),
123                  Row(
124                      teststruct=Row(
125                          smallint=200,
126                          int=2000,
127                          bigint=20000000000,
128                          float=20.5,
129                          double=30.5,
130                          boolean=False,
131                          date=datetime(2020, 1, 1),
132                          timestamp=datetime.now(),
133                          string="sample",
134                          binary=b"sample_data",
135                      )
136                  ),
137                  Row(
138                      teststruct=Row(
139                          smallint=300,
140                          int=3000,
141                          bigint=30000000000,
142                          float=30.5,
143                          double=40.5,
144                          boolean=True,
145                          date=datetime(2020, 1, 1),
146                          timestamp=datetime.now(),
147                          string="data",
148                          binary=b"data_binary",
149                      )
150                  ),
151              ],
152              Schema([
153                  ColSpec(
154                      Object([
155                          Property("smallint", DataType.integer),
156                          Property("int", DataType.integer),
157                          Property("bigint", DataType.long),
158                          Property("float", DataType.float),
159                          Property("double", DataType.double),
160                          Property("boolean", DataType.boolean),
161                          Property("date", DataType.datetime),
162                          Property("timestamp", DataType.datetime),
163                          Property("string", DataType.string),
164                          Property("binary", DataType.binary),
165                      ]),
166                      "teststruct",
167                  )
168              ]),
169          ),
170          (
171              StructType([
172                  StructField(
173                      "array",
174                      ArrayType(
175                          StructType([
176                              StructField("name", StringType(), True),
177                              StructField("age", DoubleType(), True),
178                          ])
179                      ),
180                      True,
181                  )
182              ]),
183              [
184                  (
185                      [
186                          Row(name="Alice", age=30.0),
187                          Row(name="Bob", age=25.0),
188                          Row(name="Catherine", age=35.0),
189                      ],
190                  )
191              ],
192              Schema([
193                  ColSpec(
194                      Array(
195                          Object([
196                              Property("name", DataType.string),
197                              Property("age", DataType.double),
198                          ])
199                      ),
200                      name="array",
201                  ),
202              ]),
203          ),
204          (
205              StructType([StructField("nested_list", ArrayType(ArrayType(IntegerType())), True)]),
206              [
207                  ([[1, 2, 3], [4, 5, 6], [7, 8, 9]],),
208                  ([[10, 11], [12, 13, 14]],),
209              ],
210              Schema([ColSpec(Array(Array(DataType.integer)), name="nested_list")]),
211          ),
212      ],
213  )
214  def test_enforce_schema_spark_dataframe_complex(spark_df_schema, data, input_schema, spark):
215      input_df = spark.createDataFrame(data, spark_df_schema)
216      result = _enforce_schema(input_df, input_schema)
217      assertDataFrameEqual(input_df, result)
218  
219  
def test_enforce_schema_spark_dataframe_missing_col(spark):
    """Enforcement fails when the schema requires a column ("bigint") the DataFrame lacks."""
    expected_schema = Schema([
        ColSpec(DataType.integer, "smallint"),
        ColSpec(DataType.integer, "int"),
        ColSpec(DataType.long, "bigint"),
    ])

    present_columns = StructType([
        StructField("smallint", ShortType(), True),
        StructField("int", IntegerType(), True),
    ])
    rows = [(1, 2)]  # (smallint, int) -- no "bigint" column at all

    input_df = spark.createDataFrame(rows, present_columns)
    with pytest.raises(MlflowException, match="Model is missing inputs"):
        _enforce_schema(input_df, expected_schema)
242  
243  
def test_enforce_schema_spark_dataframe_incompatible_type(spark):
    """A double column ("b") declared as integer in the schema is rejected."""
    expected_schema = Schema([
        ColSpec(DataType.integer, "a"),
        ColSpec(DataType.integer, "b"),
    ])

    columns = StructType([
        StructField("a", ShortType(), True),
        StructField("b", DoubleType(), True),
    ])
    rows = [(1, 2.3)]  # (a, b)

    input_df = spark.createDataFrame(rows, columns)
    with pytest.raises(MlflowException, match="Incompatible input types"):
        _enforce_schema(input_df, expected_schema)
265  
266  
def test_enforce_schema_spark_dataframe_incompatible_type_complex(spark):
    """A struct field whose Spark type (double) mismatches its Property type (string) fails.

    Note: the ColSpec below is deliberately left unnamed, exactly as originally written.
    """
    inner_fields = StructType([
        StructField("int", IntegerType(), True),
        StructField("double", DoubleType(), True),
    ])
    columns = StructType([StructField("teststruct", inner_fields)])
    rows = [Row(teststruct=Row(int=1000, double=20.5))]

    mismatched_object = Object([
        Property("int", DataType.integer),
        Property("double", DataType.string),  # Spark side is DoubleType
    ])
    expected_schema = Schema([ColSpec(mismatched_object)])

    input_df = spark.createDataFrame(rows, columns)
    with pytest.raises(MlflowException, match="Failed to enforce schema"):
        _enforce_schema(input_df, expected_schema)
299  
300  
def test_enforce_schema_spark_dataframe_extra_col(spark):
    """Columns not declared in the schema ("b") are dropped from the enforced result."""
    columns = StructType([
        StructField("a", ShortType(), True),
        StructField("b", DoubleType(), True),
    ])
    rows = [(1, 2.3)]  # (a, b)
    expected_schema = Schema([ColSpec(DataType.integer, "a")])

    input_df = spark.createDataFrame(rows, columns)
    enforced = _enforce_schema(input_df, expected_schema)
    assertDataFrameEqual(enforced, input_df.drop("b"))
320  
321  
def test_enforce_schema_spark_dataframe_no_schema(spark):
    """A DataFrame built from column names only (types inferred by Spark) is type-checked.

    No explicit Spark schema is given, so Spark infers the column types. Presumably
    column "a" is inferred as a 64-bit integer, which the integer ColSpec rejects
    (confirm against PySpark's type inference).
    """
    expected_schema = Schema([
        ColSpec(DataType.integer, "a"),
        ColSpec(DataType.double, "b"),
    ])
    rows = [(1, 2.3)]  # (a, b)

    input_df = spark.createDataFrame(rows, ["a", "b"])
    with pytest.raises(MlflowException, match="Incompatible input types"):
        _enforce_schema(input_df, expected_schema)
336      with pytest.raises(MlflowException, match="Incompatible input types"):
337          _enforce_schema(df, input_schema)