# test_pyfunc_schema_enforcement_pyspark.py
"""Tests for MLflow's ``_enforce_schema`` against Spark DataFrames.

Covers scalar, array, struct, and nested-array columns, plus the failure
modes: missing columns, incompatible types, and extra columns (which are
dropped rather than rejected).
"""
from datetime import datetime

import pytest
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import (
    ArrayType,
    BinaryType,
    BooleanType,
    DateType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    ShortType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
from pyspark.testing import assertDataFrameEqual

from mlflow.exceptions import MlflowException
from mlflow.models.utils import _enforce_schema
from mlflow.types import ColSpec, DataType, Schema
from mlflow.types.schema import Array, Object, Property


@pytest.fixture(scope="module")
def spark():
    """Provide a module-scoped SparkSession, stopped automatically on teardown."""
    with SparkSession.builder.getOrCreate() as session:
        yield session


def test_enforce_schema_spark_dataframe(spark):
    """A DataFrame whose columns all match the model schema passes through unchanged."""
    spark_df_schema = StructType([
        StructField("smallint", ShortType(), True),
        StructField("int", IntegerType(), True),
        StructField("bigint", LongType(), True),
        StructField("float", FloatType(), True),
        StructField("double", DoubleType(), True),
        StructField("boolean", BooleanType(), True),
        StructField("date", DateType(), True),
        StructField("timestamp", TimestampType(), True),
        StructField("string", StringType(), True),
        StructField("binary", BinaryType(), True),
    ])

    # One row exercising every supported scalar Spark type.
    rows = [
        (
            1,  # smallint
            2,  # int
            1234567890123456789,  # bigint
            1.23,  # float
            3.456789,  # double
            True,  # boolean
            datetime(2020, 1, 1),  # date
            datetime.now(),  # timestamp
            "example string",  # string
            bytearray("example binary", "utf-8"),  # binary
        )
    ]

    input_schema = Schema([
        ColSpec(DataType.integer, "smallint"),
        ColSpec(DataType.integer, "int"),
        ColSpec(DataType.long, "bigint"),
        ColSpec(DataType.float, "float"),
        ColSpec(DataType.double, "double"),
        ColSpec(DataType.boolean, "boolean"),
        ColSpec(DataType.datetime, "date"),
        ColSpec(DataType.datetime, "timestamp"),
        ColSpec(DataType.string, "string"),
        ColSpec(DataType.binary, "binary"),
    ])

    input_df = spark.createDataFrame(rows, spark_df_schema)
    result = _enforce_schema(input_df, input_schema)
    assertDataFrameEqual(input_df, result)


@pytest.mark.parametrize(
    ("spark_df_schema", "data", "input_schema"),
    [
        # Array of strings.
        (
            StructType([StructField("query", ArrayType(StringType()), True)]),
            [(["sentence_1", "sentence_2"],)],
            Schema([ColSpec(Array(DataType.string), name="query")]),
        ),
        # Struct column covering every scalar type.
        (
            StructType([
                StructField(
                    "teststruct",
                    StructType([
                        StructField("smallint", ShortType(), True),
                        StructField("int", IntegerType(), True),
                        StructField("bigint", LongType(), True),
                        StructField("float", FloatType(), True),
                        StructField("double", DoubleType(), True),
                        StructField("boolean", BooleanType(), True),
                        StructField("date", DateType(), True),
                        StructField("timestamp", TimestampType(), True),
                        StructField("string", StringType(), True),
                        StructField("binary", BinaryType(), True),
                    ]),
                    True,
                )
            ]),
            [
                Row(
                    teststruct=Row(
                        smallint=100,
                        int=1000,
                        bigint=10000000000,
                        float=10.5,
                        double=20.5,
                        boolean=True,
                        date=datetime(2020, 1, 1),
                        timestamp=datetime.now(),
                        string="example",
                        binary=b"binary_data",
                    )
                ),
                Row(
                    teststruct=Row(
                        smallint=200,
                        int=2000,
                        bigint=20000000000,
                        float=20.5,
                        double=30.5,
                        boolean=False,
                        date=datetime(2020, 1, 1),
                        timestamp=datetime.now(),
                        string="sample",
                        binary=b"sample_data",
                    )
                ),
                Row(
                    teststruct=Row(
                        smallint=300,
                        int=3000,
                        bigint=30000000000,
                        float=30.5,
                        double=40.5,
                        boolean=True,
                        date=datetime(2020, 1, 1),
                        timestamp=datetime.now(),
                        string="data",
                        binary=b"data_binary",
                    )
                ),
            ],
            Schema([
                ColSpec(
                    Object([
                        Property("smallint", DataType.integer),
                        Property("int", DataType.integer),
                        Property("bigint", DataType.long),
                        Property("float", DataType.float),
                        Property("double", DataType.double),
                        Property("boolean", DataType.boolean),
                        Property("date", DataType.datetime),
                        Property("timestamp", DataType.datetime),
                        Property("string", DataType.string),
                        Property("binary", DataType.binary),
                    ]),
                    "teststruct",
                )
            ]),
        ),
        # Array of structs.
        (
            StructType([
                StructField(
                    "array",
                    ArrayType(
                        StructType([
                            StructField("name", StringType(), True),
                            StructField("age", DoubleType(), True),
                        ])
                    ),
                    True,
                )
            ]),
            [
                (
                    [
                        Row(name="Alice", age=30.0),
                        Row(name="Bob", age=25.0),
                        Row(name="Catherine", age=35.0),
                    ],
                )
            ],
            Schema([
                ColSpec(
                    Array(
                        Object([
                            Property("name", DataType.string),
                            Property("age", DataType.double),
                        ])
                    ),
                    name="array",
                ),
            ]),
        ),
        # Array of arrays.
        (
            StructType([StructField("nested_list", ArrayType(ArrayType(IntegerType())), True)]),
            [
                ([[1, 2, 3], [4, 5, 6], [7, 8, 9]],),
                ([[10, 11], [12, 13, 14]],),
            ],
            Schema([ColSpec(Array(Array(DataType.integer)), name="nested_list")]),
        ),
    ],
)
def test_enforce_schema_spark_dataframe_complex(spark_df_schema, data, input_schema, spark):
    """Complex column types (arrays, structs, nesting) pass enforcement unchanged."""
    input_df = spark.createDataFrame(data, spark_df_schema)
    result = _enforce_schema(input_df, input_schema)
    assertDataFrameEqual(input_df, result)


def test_enforce_schema_spark_dataframe_missing_col(spark):
    """A column required by the model schema but absent from the DataFrame is rejected."""
    spark_df_schema = StructType([
        StructField("smallint", ShortType(), True),
        StructField("int", IntegerType(), True),
    ])

    rows = [
        (
            1,  # smallint
            2,  # int
        )
    ]

    # The model schema additionally demands "bigint", which the data lacks.
    input_schema = Schema([
        ColSpec(DataType.integer, "smallint"),
        ColSpec(DataType.integer, "int"),
        ColSpec(DataType.long, "bigint"),
    ])

    input_df = spark.createDataFrame(rows, spark_df_schema)
    with pytest.raises(MlflowException, match="Model is missing inputs"):
        _enforce_schema(input_df, input_schema)


def test_enforce_schema_spark_dataframe_incompatible_type(spark):
    """A double column cannot satisfy an integer ColSpec."""
    spark_df_schema = StructType([
        StructField("a", ShortType(), True),
        StructField("b", DoubleType(), True),
    ])

    rows = [
        (
            1,  # a
            2.3,  # b
        )
    ]

    # "b" is declared integer in the model schema but is a double in the data.
    input_schema = Schema([
        ColSpec(DataType.integer, "a"),
        ColSpec(DataType.integer, "b"),
    ])

    input_df = spark.createDataFrame(rows, spark_df_schema)
    with pytest.raises(MlflowException, match="Incompatible input types"):
        _enforce_schema(input_df, input_schema)


def test_enforce_schema_spark_dataframe_incompatible_type_complex(spark):
    """A type mismatch inside a struct column fails enforcement."""
    spark_df_schema = StructType([
        StructField(
            "teststruct",
            StructType([
                StructField("int", IntegerType(), True),
                StructField("double", DoubleType(), True),
            ]),
        )
    ])

    rows = [
        Row(
            teststruct=Row(
                int=1000,
                double=20.5,
            )
        )
    ]

    # The Object property "double" is declared as string, mismatching the data.
    input_schema = Schema([
        ColSpec(
            Object([
                Property("int", DataType.integer),
                Property("double", DataType.string),
            ])
        )
    ])

    input_df = spark.createDataFrame(rows, spark_df_schema)
    with pytest.raises(MlflowException, match="Failed to enforce schema"):
        _enforce_schema(input_df, input_schema)


def test_enforce_schema_spark_dataframe_extra_col(spark):
    """Columns not present in the model schema are dropped, not rejected."""
    spark_df_schema = StructType([
        StructField("a", ShortType(), True),
        StructField("b", DoubleType(), True),
    ])

    rows = [
        (
            1,  # a
            2.3,  # b
        )
    ]

    # Only "a" is in the model schema; "b" should be silently removed.
    input_schema = Schema([ColSpec(DataType.integer, "a")])

    input_df = spark.createDataFrame(rows, spark_df_schema)
    result = _enforce_schema(input_df, input_schema)
    expected_result = input_df.drop("b")
    assertDataFrameEqual(result, expected_result)


def test_enforce_schema_spark_dataframe_no_schema(spark):
    """Without an explicit Spark schema, inferred types do not match the model schema."""
    rows = [
        (
            1,  # a
            2.3,  # b
        )
    ]

    input_schema = Schema([
        ColSpec(DataType.integer, "a"),
        ColSpec(DataType.double, "b"),
    ])

    # Column names only — Spark infers "a" as long, which the integer ColSpec rejects.
    input_df = spark.createDataFrame(rows, ["a", "b"])
    with pytest.raises(MlflowException, match="Incompatible input types"):
        _enforce_schema(input_df, input_schema)