schema.py
from typing import Any, Optional

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.types.schema import Schema


class TensorDatasetSchema:
    """
    Represents the schema of a dataset with tensor features and targets.
    """

    def __init__(self, features: Schema, targets: Optional[Schema] = None):
        """Create a dataset schema from a features schema and an optional targets schema.

        Args:
            features: Schema describing the dataset's features. Required.
            targets: Schema describing the dataset's targets, or ``None`` if the dataset
                has no targets.

        Raises:
            MlflowException: If ``features`` is not an ``mlflow.types.Schema``, or if
                ``targets`` is neither ``None`` nor an ``mlflow.types.Schema``.
        """
        if not isinstance(features, Schema):
            raise MlflowException(
                f"features must be mlflow.types.Schema, got '{type(features)}'",
                INVALID_PARAMETER_VALUE,
            )
        if targets is not None and not isinstance(targets, Schema):
            # Report the type of ``targets`` itself, the argument that failed validation.
            raise MlflowException(
                f"targets must be either None or mlflow.types.Schema, got '{type(targets)}'",
                INVALID_PARAMETER_VALUE,
            )
        self.features = features
        self.targets = targets

    def to_dict(self) -> dict[str, Any]:
        """Serialize into a 'jsonable' dictionary.

        Returns:
            A dictionary with a single ``'mlflow_tensorspec'`` key. Its value maps
            ``'features'`` to the JSON form of the features schema and ``'targets'`` to the
            JSON form of the targets schema, or to ``None`` when no targets are defined.
        """
        return {
            "mlflow_tensorspec": {
                "features": self.features.to_json(),
                "targets": self.targets.to_json() if self.targets is not None else None,
            },
        }

    @classmethod
    def from_dict(cls, schema_dict: dict[str, Any]) -> "TensorDatasetSchema":
        """Deserialize from the dictionary representation produced by :meth:`to_dict`.

        Args:
            schema_dict: Dictionary representation of the dataset schema. Expected format:
                ``{'mlflow_tensorspec': {'features': <json string>,
                'targets': <json string or None>}}``. The ``'targets'`` entry may be
                omitted or ``None`` for a dataset without targets.

        Returns:
            TensorDatasetSchema populated with the data from the dictionary.

        Raises:
            MlflowException: If ``schema_dict`` has no ``'mlflow_tensorspec'`` key.
            KeyError: If the ``'mlflow_tensorspec'`` entry has no ``'features'`` key.
        """
        if "mlflow_tensorspec" not in schema_dict:
            raise MlflowException(
                "TensorDatasetSchema dictionary is missing expected key 'mlflow_tensorspec'",
                INVALID_PARAMETER_VALUE,
            )

        tensorspec = schema_dict["mlflow_tensorspec"]
        features = Schema.from_json(tensorspec["features"])
        # A missing 'targets' key and an explicit None both mean "no targets".
        targets_json = tensorspec.get("targets")
        targets = Schema.from_json(targets_json) if targets_json is not None else None
        return cls(features, targets)

    def __eq__(self, other) -> bool:
        """Return True if ``other`` is a TensorDatasetSchema with equal features and targets."""
        # NOTE: defining __eq__ without __hash__ leaves instances unhashable.
        return (
            isinstance(other, TensorDatasetSchema)
            and self.features == other.features
            and self.targets == other.targets
        )

    def __repr__(self) -> str:
        """Return a multi-line listing of the features and targets schemas."""
        return f"features:\n {self.features!r}\ntargets:\n {self.targets!r}\n"