delta_dataset_source.py
import logging
from typing import Any

from mlflow.data.dataset_source import DatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_managed_catalog_messages_pb2 import (
    GetTable,
    GetTableResponse,
)
from mlflow.protos.databricks_managed_catalog_service_pb2 import DatabricksUnityCatalogService
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils._spark_utils import _get_active_spark_session
from mlflow.utils._unity_catalog_utils import get_full_name_from_sc
from mlflow.utils.databricks_utils import get_databricks_host_creds
from mlflow.utils.proto_json_utils import message_to_json
from mlflow.utils.rest_utils import (
    _REST_API_PATH_PREFIX,
    call_endpoint,
    extract_api_info_for_service,
)
from mlflow.utils.string_utils import _backtick_quote

DATABRICKS_HIVE_METASTORE_NAME = "hive_metastore"
# These two catalog names both point to the workspace-local default HMS (Hive metastore).
DATABRICKS_LOCAL_METASTORE_NAMES = [DATABRICKS_HIVE_METASTORE_NAME, "spark_catalog"]
# The samples catalog is managed by Databricks for hosting public datasets such as the
# NYC taxi dataset. It is neither a UC catalog nor a local-metastore catalog.
DATABRICKS_SAMPLES_CATALOG_NAME = "samples"

_logger = logging.getLogger(__name__)


class DeltaDatasetSource(DatasetSource):
    """
    Represents the source of a dataset stored in a Delta table.
    """

    def __init__(
        self,
        path: str | None = None,
        delta_table_name: str | None = None,
        delta_table_version: int | None = None,
        delta_table_id: str | None = None,
    ):
        """
        Args:
            path: Filesystem path of the Delta table. Mutually exclusive with
                ``delta_table_name``.
            delta_table_name: Name of the Delta table; resolved to a fully qualified
                name using the active Spark session's catalog configuration.
                Mutually exclusive with ``path``.
            delta_table_version: Optional version of the Delta table to read
                (time travel). ``0`` is a valid version.
            delta_table_id: Optional unique identifier of the table.

        Raises:
            MlflowException: If neither or both of ``path`` and ``delta_table_name``
                are specified.
        """
        # Exactly one of (path, delta_table_name) must be provided.
        if (path, delta_table_name).count(None) != 1:
            raise MlflowException(
                'Must specify exactly one of "path" or "delta_table_name"',
                INVALID_PARAMETER_VALUE,
            )
        self._path = path
        if delta_table_name is not None:
            # Expand a partial table name (e.g. "schema.table") to its full
            # "catalog.schema.table" form using the active Spark session.
            self._delta_table_name = get_full_name_from_sc(
                delta_table_name, _get_active_spark_session()
            )
        else:
            self._delta_table_name = delta_table_name
        self._delta_table_version = delta_table_version
        self._delta_table_id = delta_table_id

    @staticmethod
    def _get_source_type() -> str:
        return "delta_table"

    def load(self, **kwargs):
        """
        Loads the dataset source as a Delta Dataset Source.

        Returns:
            An instance of ``pyspark.sql.DataFrame``.
        """
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()

        spark_read_op = spark.read.format("delta")
        if self._delta_table_version is not None:
            spark_read_op = spark_read_op.option("versionAsOf", self._delta_table_version)

        if self._path:
            return spark_read_op.load(self._path)
        else:
            # Backtick-quote each component of the table name so that names
            # containing special characters are handled correctly.
            backticked_delta_table_name = ".".join(
                map(_backtick_quote, self._delta_table_name.split("."))
            )
            return spark_read_op.table(backticked_delta_table_name)

    @property
    def path(self) -> str | None:
        return self._path

    @property
    def delta_table_name(self) -> str | None:
        return self._delta_table_name

    @property
    def delta_table_id(self) -> str | None:
        return self._delta_table_id

    @property
    def delta_table_version(self) -> int | None:
        return self._delta_table_version

    @staticmethod
    def _can_resolve(raw_source: Any) -> bool:
        # Delta sources cannot be auto-resolved from a raw source value.
        return False

    @classmethod
    def _resolve(cls, raw_source: str) -> "DeltaDatasetSource":
        raise NotImplementedError

    def _is_databricks_uc_table(self) -> bool:
        """Checks whether the table is in the Databricks Unity Catalog.

        A table is considered a UC table when its catalog is neither a
        workspace-local metastore catalog nor the Databricks-managed
        "samples" catalog.
        """
        if self._delta_table_name is not None:
            catalog_name = self._delta_table_name.split(".", 1)[0]
            return (
                catalog_name not in DATABRICKS_LOCAL_METASTORE_NAMES
                and catalog_name != DATABRICKS_SAMPLES_CATALOG_NAME
            )
        else:
            return False

    def _lookup_table_id(self, table_name):
        """Best-effort lookup of the UC table ID via the Databricks REST API.

        Returns:
            The table ID string, or ``None`` if the lookup fails for any reason.
        """
        try:
            req_body = message_to_json(GetTable(full_name_arg=table_name))
            _METHOD_TO_INFO = extract_api_info_for_service(
                DatabricksUnityCatalogService, _REST_API_PATH_PREFIX
            )
            db_creds = get_databricks_host_creds()
            endpoint, method = _METHOD_TO_INFO[GetTable]
            # We need to replace the full_name_arg in the endpoint definition with
            # the actual table name for the REST API to work.
            final_endpoint = endpoint.replace("{full_name_arg}", table_name)
            resp = call_endpoint(
                host_creds=db_creds,
                endpoint=final_endpoint,
                method=method,
                json_body=req_body,
                response_proto=GetTableResponse,
            )
            return resp.table_id
        except Exception:
            # Table-ID lookup is best-effort metadata enrichment; never let a
            # failure here propagate, but leave a trace for debugging.
            _logger.debug("Failed to look up table ID for table %s", table_name, exc_info=True)
            return None

    def to_dict(self) -> dict[Any, Any]:
        """Serializes the source to a dict, omitting unset fields."""
        info = {}
        if self._path:
            info["path"] = self._path
        if self._delta_table_name:
            info["delta_table_name"] = self._delta_table_name
        # Use an explicit None check: version 0 is a valid Delta table version
        # and must not be dropped by a truthiness test.
        if self._delta_table_version is not None:
            info["delta_table_version"] = self._delta_table_version
        if self._is_databricks_uc_table():
            info["is_databricks_uc_table"] = True
            if self._delta_table_id:
                info["delta_table_id"] = self._delta_table_id
            else:
                # May be None if the REST lookup fails; recorded as-is.
                info["delta_table_id"] = self._lookup_table_id(self._delta_table_name)
        return info

    @classmethod
    def from_dict(cls, source_dict: dict[Any, Any]) -> "DeltaDatasetSource":
        """Deserializes a source from a dict produced by :meth:`to_dict`."""
        return cls(
            path=source_dict.get("path"),
            delta_table_name=source_dict.get("delta_table_name"),
            delta_table_version=source_dict.get("delta_table_version"),
            delta_table_id=source_dict.get("delta_table_id"),
        )