/ mlflow / data / delta_dataset_source.py
delta_dataset_source.py
  1  import logging
  2  from typing import Any
  3  
  4  from mlflow.data.dataset_source import DatasetSource
  5  from mlflow.exceptions import MlflowException
  6  from mlflow.protos.databricks_managed_catalog_messages_pb2 import (
  7      GetTable,
  8      GetTableResponse,
  9  )
 10  from mlflow.protos.databricks_managed_catalog_service_pb2 import DatabricksUnityCatalogService
 11  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
 12  from mlflow.utils._spark_utils import _get_active_spark_session
 13  from mlflow.utils._unity_catalog_utils import get_full_name_from_sc
 14  from mlflow.utils.databricks_utils import get_databricks_host_creds
 15  from mlflow.utils.proto_json_utils import message_to_json
 16  from mlflow.utils.rest_utils import (
 17      _REST_API_PATH_PREFIX,
 18      call_endpoint,
 19      extract_api_info_for_service,
 20  )
 21  from mlflow.utils.string_utils import _backtick_quote
 22  
# Catalog name of the workspace-local default Hive metastore on Databricks.
DATABRICKS_HIVE_METASTORE_NAME = "hive_metastore"
# These two catalog names both point to the workspace-local default HMS (Hive metastore).
DATABRICKS_LOCAL_METASTORE_NAMES = [DATABRICKS_HIVE_METASTORE_NAME, "spark_catalog"]
# The "samples" catalog is managed by Databricks for hosting public datasets such as the
# NYC taxi dataset. It is neither a UC catalog nor a local-metastore catalog.
DATABRICKS_SAMPLES_CATALOG_NAME = "samples"

# Module-level logger, following the standard `logging.getLogger(__name__)` convention.
_logger = logging.getLogger(__name__)
 31  
 32  
 33  class DeltaDatasetSource(DatasetSource):
 34      """
 35      Represents the source of a dataset stored at in a delta table.
 36      """
 37  
 38      def __init__(
 39          self,
 40          path: str | None = None,
 41          delta_table_name: str | None = None,
 42          delta_table_version: int | None = None,
 43          delta_table_id: str | None = None,
 44      ):
 45          if (path, delta_table_name).count(None) != 1:
 46              raise MlflowException(
 47                  'Must specify exactly one of "path" or "table_name"',
 48                  INVALID_PARAMETER_VALUE,
 49              )
 50          self._path = path
 51          if delta_table_name is not None:
 52              self._delta_table_name = get_full_name_from_sc(
 53                  delta_table_name, _get_active_spark_session()
 54              )
 55          else:
 56              self._delta_table_name = delta_table_name
 57          self._delta_table_version = delta_table_version
 58          self._delta_table_id = delta_table_id
 59  
 60      @staticmethod
 61      def _get_source_type() -> str:
 62          return "delta_table"
 63  
 64      def load(self, **kwargs):
 65          """
 66          Loads the dataset source as a Delta Dataset Source.
 67  
 68          Returns:
 69              An instance of ``pyspark.sql.DataFrame``.
 70          """
 71          from pyspark.sql import SparkSession
 72  
 73          spark = SparkSession.builder.getOrCreate()
 74  
 75          spark_read_op = spark.read.format("delta")
 76          if self._delta_table_version is not None:
 77              spark_read_op = spark_read_op.option("versionAsOf", self._delta_table_version)
 78  
 79          if self._path:
 80              return spark_read_op.load(self._path)
 81          else:
 82              backticked_delta_table_name = ".".join(
 83                  map(_backtick_quote, self._delta_table_name.split("."))
 84              )
 85              return spark_read_op.table(backticked_delta_table_name)
 86  
 87      @property
 88      def path(self) -> str | None:
 89          return self._path
 90  
 91      @property
 92      def delta_table_name(self) -> str | None:
 93          return self._delta_table_name
 94  
 95      @property
 96      def delta_table_id(self) -> str | None:
 97          return self._delta_table_id
 98  
 99      @property
100      def delta_table_version(self) -> int | None:
101          return self._delta_table_version
102  
103      @staticmethod
104      def _can_resolve(raw_source: Any):
105          return False
106  
107      @classmethod
108      def _resolve(cls, raw_source: str) -> "DeltaDatasetSource":
109          raise NotImplementedError
110  
111      # check if table is in the Databricks Unity Catalog
112      def _is_databricks_uc_table(self):
113          if self._delta_table_name is not None:
114              catalog_name = self._delta_table_name.split(".", 1)[0]
115              return (
116                  catalog_name not in DATABRICKS_LOCAL_METASTORE_NAMES
117                  and catalog_name != DATABRICKS_SAMPLES_CATALOG_NAME
118              )
119          else:
120              return False
121  
122      def _lookup_table_id(self, table_name):
123          try:
124              req_body = message_to_json(GetTable(full_name_arg=table_name))
125              _METHOD_TO_INFO = extract_api_info_for_service(
126                  DatabricksUnityCatalogService, _REST_API_PATH_PREFIX
127              )
128              db_creds = get_databricks_host_creds()
129              endpoint, method = _METHOD_TO_INFO[GetTable]
130              # We need to replace the full_name_arg in the endpoint definition with
131              # the actual table name for the REST API to work.
132              final_endpoint = endpoint.replace("{full_name_arg}", table_name)
133              resp = call_endpoint(
134                  host_creds=db_creds,
135                  endpoint=final_endpoint,
136                  method=method,
137                  json_body=req_body,
138                  response_proto=GetTableResponse,
139              )
140              return resp.table_id
141          except Exception:
142              return None
143  
144      def to_dict(self) -> dict[Any, Any]:
145          info = {}
146          if self._path:
147              info["path"] = self._path
148          if self._delta_table_name:
149              info["delta_table_name"] = self._delta_table_name
150          if self._delta_table_version:
151              info["delta_table_version"] = self._delta_table_version
152          if self._is_databricks_uc_table():
153              info["is_databricks_uc_table"] = True
154              if self._delta_table_id:
155                  info["delta_table_id"] = self._delta_table_id
156              else:
157                  info["delta_table_id"] = self._lookup_table_id(self._delta_table_name)
158          return info
159  
160      @classmethod
161      def from_dict(cls, source_dict: dict[Any, Any]) -> "DeltaDatasetSource":
162          return cls(
163              path=source_dict.get("path"),
164              delta_table_name=source_dict.get("delta_table_name"),
165              delta_table_version=source_dict.get("delta_table_version"),
166              delta_table_id=source_dict.get("delta_table_id"),
167          )