test_delta_dataset_source.py
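"""Tests for DeltaDatasetSource.

Covers loading Delta tables by filesystem path and by table name (optionally at a
specific version), JSON round-trips through the dataset source registry, input
validation, and Unity Catalog table-id lookup (exercised via mocks).
"""
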
import json
from unittest import mock

import pandas as pd
import pytest

from mlflow.data.dataset_source_registry import get_dataset_source_from_json
from mlflow.data.delta_dataset_source import DeltaDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_managed_catalog_messages_pb2 import GetTable, GetTableResponse
from mlflow.utils.proto_json_utils import message_to_json


@pytest.fixture(scope="module")
def spark_session():
    from pyspark.sql import SparkSession

    # Local Spark session with the Delta Lake extension enabled, shared across
    # the tests in this module.
    with (
        SparkSession.builder
        .master("local[*]")
        .config("spark.jars.packages", "io.delta:delta-spark_2.13:4.0.0")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config(
            "spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog"
        )
        .getOrCreate()
    ) as session:
        yield session


def test_delta_dataset_source_from_path(spark_session, tmp_path):
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.delta")
    df_spark.write.format("delta").mode("overwrite").save(path)

    delta_datasource = DeltaDatasetSource(path=path)
    loaded_df_spark = delta_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()
    assert delta_datasource.to_dict()["path"] == path

    # The source should survive a JSON round-trip through the registry.
    reloaded_source = get_dataset_source_from_json(
        delta_datasource.to_json(), source_type=delta_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, DeltaDatasetSource)
    assert type(delta_datasource) == type(reloaded_source)
    assert reloaded_source.to_json() == delta_datasource.to_json()


def test_delta_dataset_source_from_table(spark_session, tmp_path):
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.format("delta").mode("overwrite").saveAsTable(
        "default.temp_delta", path=tmp_path
    )

    delta_datasource = DeltaDatasetSource(delta_table_name="temp_delta")
    loaded_df_spark = delta_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()
    assert delta_datasource.to_dict()["delta_table_name"] == "temp_delta"

    reloaded_source = get_dataset_source_from_json(
        delta_datasource.to_json(), source_type=delta_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, DeltaDatasetSource)
    assert type(delta_datasource) == type(reloaded_source)
    assert reloaded_source.to_json() == delta_datasource.to_json()


def test_delta_dataset_source_from_table_versioned(spark_session, tmp_path):
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.format("delta").mode("overwrite").saveAsTable(
        "default.temp_delta_versioned", path=tmp_path
    )

    df2 = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "c"])
    df2_spark = spark_session.createDataFrame(df2)
    df2_spark.write.format("delta").mode("overwrite").saveAsTable(
        "default.temp_delta_versioned", path=tmp_path
    )

    # The first write is Delta version 0 and the overwrite is version 1, so
    # loading version 1 should return df2's single row.
    delta_datasource = DeltaDatasetSource(
        delta_table_name="temp_delta_versioned", delta_table_version=1
    )
    loaded_df_spark = delta_datasource.load()
    assert loaded_df_spark.count() == df2_spark.count()
    config = delta_datasource.to_dict()
    assert config["delta_table_name"] == "temp_delta_versioned"
    assert config["delta_table_version"] == 1

    reloaded_source = get_dataset_source_from_json(
        delta_datasource.to_json(), source_type=delta_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, DeltaDatasetSource)
    assert type(delta_datasource) == type(reloaded_source)
    assert reloaded_source.to_json() == delta_datasource.to_json()


def test_delta_dataset_source_too_many_inputs(spark_session, tmp_path):
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.format("delta").mode("overwrite").saveAsTable(
        "default.temp_delta_too_many_inputs", path=tmp_path
    )

    with pytest.raises(
        MlflowException, match='Must specify exactly one of "path" or "table_name"'
    ):
        DeltaDatasetSource(path=tmp_path, delta_table_name="temp_delta_too_many_inputs")


def test_uc_table_id_retrieval_works(spark_session, tmp_path):
    def mock_resolve_table_name(table_name, spark):
        if table_name == "temp_delta_versioned_with_id":
            return "default.temp_delta_versioned_with_id"
        return table_name

    def mock_lookup_table_id(table_name):
        if table_name == "default.temp_delta_versioned_with_id":
            return "uc_table_id_1"
        return None

    # Unity Catalog name resolution and table-id lookup are mocked so the test
    # can run against a local metastore.
    with (
        mock.patch(
            "mlflow.data.delta_dataset_source.get_full_name_from_sc",
            side_effect=mock_resolve_table_name,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.DeltaDatasetSource._lookup_table_id",
            side_effect=mock_lookup_table_id,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source._get_active_spark_session",
            return_value=None,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.DeltaDatasetSource._is_databricks_uc_table",
            return_value=True,
        ),
    ):
        df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
        df_spark = spark_session.createDataFrame(df)
        df_spark.write.format("delta").mode("overwrite").saveAsTable(
            "default.temp_delta_versioned_with_id", path=tmp_path
        )

        df2 = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "c"])
        df2_spark = spark_session.createDataFrame(df2)
        df2_spark.write.format("delta").mode("overwrite").saveAsTable(
            "default.temp_delta_versioned_with_id", path=tmp_path
        )

        delta_datasource = DeltaDatasetSource(
            delta_table_name="temp_delta_versioned_with_id", delta_table_version=1
        )
        loaded_df_spark = delta_datasource.load()
        assert loaded_df_spark.count() == df2_spark.count()
        assert delta_datasource.to_json() == json.dumps(
            {
                "delta_table_name": "default.temp_delta_versioned_with_id",
                "delta_table_version": 1,
                "is_databricks_uc_table": True,
                "delta_table_id": "uc_table_id_1",
            }
        )


def _args(table_name, json_body):
    # Expected keyword arguments for the call_endpoint request issued by
    # DeltaDatasetSource._lookup_table_id.
    return {
        "host_creds": None,
        "endpoint": f"/api/2.0/unity-catalog/tables/{table_name}",
        "method": "GET",
        "json_body": json_body,
        "response_proto": GetTableResponse,
    }


@pytest.mark.parametrize(
    ("call_endpoint_response", "expected_lookup_response", "test_table_name"),
    [
        # No response, a request failure, and a valid response, respectively.
        (None, None, "delta_table_1"),
        (Exception("Exception from call_endpoint"), None, "delta_table_2"),
        (GetTableResponse(table_id="uc_table_id_1"), "uc_table_id_1", "delta_table_3"),
    ],
)
def test_lookup_table_id(
    call_endpoint_response, expected_lookup_response, test_table_name, tmp_path
):
    def mock_resolve_table_name(table_name, spark):
        if table_name == test_table_name:
            return f"default.{test_table_name}"
        return table_name

    def mock_call_endpoint(host_creds, endpoint, method, json_body, response_proto):
        if isinstance(call_endpoint_response, Exception):
            raise call_endpoint_response
        return call_endpoint_response

    with (
        mock.patch(
            "mlflow.data.delta_dataset_source.get_full_name_from_sc",
            side_effect=mock_resolve_table_name,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source._get_active_spark_session",
            return_value=None,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.get_databricks_host_creds",
            return_value=None,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.DeltaDatasetSource._is_databricks_uc_table",
            return_value=True,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.call_endpoint",
            side_effect=mock_call_endpoint,
        ) as mock_endpoint,
    ):
        delta_datasource = DeltaDatasetSource(
            delta_table_name=test_table_name, delta_table_version=1
        )
        assert delta_datasource._lookup_table_id(test_table_name) == expected_lookup_response
        req_body = message_to_json(GetTable(full_name_arg=test_table_name))
        call_args = _args(test_table_name, req_body)
        mock_endpoint.assert_any_call(**call_args)


@pytest.mark.parametrize(
    ("table_name", "expected_result"),
    [
        ("default.test", True),
        # Known non-UC catalogs should not be classified as Unity Catalog tables.
        ("hive_metastore.test", False),
        ("spark_catalog.test", False),
        ("samples.test", False),
    ],
)
def test_is_databricks_uc_table(table_name, expected_result):
    with (
        mock.patch(
            "mlflow.data.delta_dataset_source.get_full_name_from_sc",
            return_value=table_name,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source._get_active_spark_session",
            return_value=None,
        ),
    ):
        delta_datasource = DeltaDatasetSource(delta_table_name=table_name, delta_table_version=1)
        assert delta_datasource._is_databricks_uc_table() == expected_result
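

# For orientation: a minimal sketch of the request that test_lookup_table_id
# expects DeltaDatasetSource._lookup_table_id to issue, reconstructed from the
# mocked call_endpoint signature and the _args helper above. The body and the
# mlflow.utils import paths are assumptions inferred from the test's
# expectations, not mlflow's actual implementation.
def _lookup_table_id_sketch(table_name):
    from mlflow.utils.databricks_utils import get_databricks_host_creds
    from mlflow.utils.rest_utils import call_endpoint

    try:
        req_body = message_to_json(GetTable(full_name_arg=table_name))
        response = call_endpoint(
            host_creds=get_databricks_host_creds(),
            endpoint=f"/api/2.0/unity-catalog/tables/{table_name}",
            method="GET",
            json_body=req_body,
            response_proto=GetTableResponse,
        )
        # An empty response and a failed request both degrade to "no table id",
        # matching the parametrized expectations in test_lookup_table_id.
        return response.table_id if response else None
    except Exception:
        return None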