# tests/data/test_delta_dataset_source.py
import json
from unittest import mock

import pandas as pd
import pytest

from mlflow.data.dataset_source_registry import get_dataset_source_from_json
from mlflow.data.delta_dataset_source import DeltaDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_managed_catalog_messages_pb2 import GetTable, GetTableResponse
from mlflow.utils.proto_json_utils import message_to_json

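# Module-scoped local Spark session with the Delta Lake package, SQL extension,
# and catalog configured, shared by every test in this file.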
@pytest.fixture(scope="module")
def spark_session():
    from pyspark.sql import SparkSession

    with (
        SparkSession.builder.master("local[*]")
        .config("spark.jars.packages", "io.delta:delta-spark_2.13:4.0.0")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config(
            "spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog"
        )
        .getOrCreate()
    ) as session:
        yield session

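# A source built from a filesystem path should round-trip through load(),
# to_dict(), and JSON serialization via the dataset source registry.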
def test_delta_dataset_source_from_path(spark_session, tmp_path):
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.delta")
    df_spark.write.format("delta").mode("overwrite").save(path)

    delta_datasource = DeltaDatasetSource(path=path)
    loaded_df_spark = delta_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()
    assert delta_datasource.to_dict()["path"] == path

    reloaded_source = get_dataset_source_from_json(
        delta_datasource.to_json(), source_type=delta_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, DeltaDatasetSource)
    assert type(delta_datasource) is type(reloaded_source)
    assert reloaded_source.to_json() == delta_datasource.to_json()

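# A source built from a table name (rather than a path) should round-trip the
# same way.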
def test_delta_dataset_source_from_table(spark_session, tmp_path):
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.format("delta").mode("overwrite").saveAsTable(
        "default.temp_delta", path=tmp_path
    )

    delta_datasource = DeltaDatasetSource(delta_table_name="temp_delta")
    loaded_df_spark = delta_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()
    assert delta_datasource.to_dict()["delta_table_name"] == "temp_delta"

    reloaded_source = get_dataset_source_from_json(
        delta_datasource.to_json(), source_type=delta_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, DeltaDatasetSource)
    assert type(delta_datasource) is type(reloaded_source)
    assert reloaded_source.to_json() == delta_datasource.to_json()

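# Overwriting the table creates version 1; pinning delta_table_version=1 should
# load the second write and record the version in the source config.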
def test_delta_dataset_source_from_table_versioned(spark_session, tmp_path):
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.format("delta").mode("overwrite").saveAsTable(
        "default.temp_delta_versioned", path=tmp_path
    )

    df2 = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "c"])
    df2_spark = spark_session.createDataFrame(df2)
    df2_spark.write.format("delta").mode("overwrite").saveAsTable(
        "default.temp_delta_versioned", path=tmp_path
    )

    delta_datasource = DeltaDatasetSource(
        delta_table_name="temp_delta_versioned", delta_table_version=1
    )
    loaded_df_spark = delta_datasource.load()
    assert loaded_df_spark.count() == df2_spark.count()
    config = delta_datasource.to_dict()
    assert config["delta_table_name"] == "temp_delta_versioned"
    assert config["delta_table_version"] == 1

    reloaded_source = get_dataset_source_from_json(
        delta_datasource.to_json(), source_type=delta_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, DeltaDatasetSource)
    assert type(delta_datasource) is type(reloaded_source)
    assert reloaded_source.to_json() == delta_datasource.to_json()

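# Supplying both a path and a table name is ambiguous and must be rejected.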
def test_delta_dataset_source_too_many_inputs(spark_session, tmp_path):
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.format("delta").mode("overwrite").saveAsTable(
        "default.temp_delta_too_many_inputs", path=tmp_path
    )

    with pytest.raises(
        MlflowException, match='Must specify exactly one of "path" or "table_name"'
    ):
        DeltaDatasetSource(path=tmp_path, delta_table_name="temp_delta_too_many_inputs")

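# When the table resolves to a Databricks Unity Catalog table, the source's
# JSON should carry the fully qualified name, the UC flag, and the table ID.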
def test_uc_table_id_retrieval_works(spark_session, tmp_path):
    def mock_resolve_table_name(table_name, spark):
        if table_name == "temp_delta_versioned_with_id":
            return "default.temp_delta_versioned_with_id"
        return table_name

    def mock_lookup_table_id(table_name):
        if table_name == "default.temp_delta_versioned_with_id":
            return "uc_table_id_1"
        return None

    with (
        mock.patch(
            "mlflow.data.delta_dataset_source.get_full_name_from_sc",
            side_effect=mock_resolve_table_name,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.DeltaDatasetSource._lookup_table_id",
            side_effect=mock_lookup_table_id,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source._get_active_spark_session",
            return_value=None,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.DeltaDatasetSource._is_databricks_uc_table",
            return_value=True,
        ),
    ):
        df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
        df_spark = spark_session.createDataFrame(df)
        df_spark.write.format("delta").mode("overwrite").saveAsTable(
            "default.temp_delta_versioned_with_id", path=tmp_path
        )

        df2 = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "c"])
        df2_spark = spark_session.createDataFrame(df2)
        df2_spark.write.format("delta").mode("overwrite").saveAsTable(
            "default.temp_delta_versioned_with_id", path=tmp_path
        )

        delta_datasource = DeltaDatasetSource(
            delta_table_name="temp_delta_versioned_with_id", delta_table_version=1
        )
        loaded_df_spark = delta_datasource.load()
        assert loaded_df_spark.count() == df2_spark.count()
        assert delta_datasource.to_json() == json.dumps(
            {
                "delta_table_name": "default.temp_delta_versioned_with_id",
                "delta_table_version": 1,
                "is_databricks_uc_table": True,
                "delta_table_id": "uc_table_id_1",
            }
        )

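# Expected keyword arguments for the mocked call_endpoint invocation against
# the Unity Catalog GetTable REST endpoint.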
def _args(endpoint, json_body):
    return {
        "host_creds": None,
        "endpoint": f"/api/2.0/unity-catalog/tables/{endpoint}",
        "method": "GET",
        "json_body": json_body,
        "response_proto": GetTableResponse,
    }

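# _lookup_table_id should return the table ID from a successful GetTable
# response, and None when the endpoint returns nothing or raises.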
@pytest.mark.parametrize(
    ("call_endpoint_response", "expected_lookup_response", "test_table_name"),
    [
        (None, None, "delta_table_1"),
        (Exception("Exception from call_endpoint"), None, "delta_table_2"),
        (GetTableResponse(table_id="uc_table_id_1"), "uc_table_id_1", "delta_table_3"),
    ],
)
def test_lookup_table_id(
    call_endpoint_response, expected_lookup_response, test_table_name, tmp_path
):
    def mock_resolve_table_name(table_name, spark):
        if table_name == test_table_name:
            return f"default.{test_table_name}"
        return table_name

    def mock_call_endpoint(host_creds, endpoint, method, json_body, response_proto):
        if isinstance(call_endpoint_response, Exception):
            raise call_endpoint_response
        return call_endpoint_response

    with (
        mock.patch(
            "mlflow.data.delta_dataset_source.get_full_name_from_sc",
            side_effect=mock_resolve_table_name,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source._get_active_spark_session",
            return_value=None,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.get_databricks_host_creds",
            return_value=None,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.DeltaDatasetSource._is_databricks_uc_table",
            return_value=True,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source.call_endpoint",
            side_effect=mock_call_endpoint,
        ) as mock_endpoint,
    ):
        delta_datasource = DeltaDatasetSource(
            delta_table_name=test_table_name, delta_table_version=1
        )
        assert delta_datasource._lookup_table_id(test_table_name) == expected_lookup_response
        req_body = message_to_json(GetTable(full_name_arg=test_table_name))
        call_args = _args(test_table_name, req_body)
        mock_endpoint.assert_any_call(**call_args)

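# Tables in the hive_metastore, spark_catalog, and samples catalogs are not
# Unity Catalog tables; anything else resolving through UC is.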
@pytest.mark.parametrize(
    ("table_name", "expected_result"),
    [
        ("default.test", True),
        ("hive_metastore.test", False),
        ("spark_catalog.test", False),
        ("samples.test", False),
    ],
)
def test_is_databricks_uc_table(table_name, expected_result):
    with (
        mock.patch(
            "mlflow.data.delta_dataset_source.get_full_name_from_sc",
            return_value=table_name,
        ),
        mock.patch(
            "mlflow.data.delta_dataset_source._get_active_spark_session",
            return_value=None,
        ),
    ):
        delta_datasource = DeltaDatasetSource(delta_table_name=table_name, delta_table_version=1)
        assert delta_datasource._is_databricks_uc_table() == expected_result