/ mlflow / data / schema.py
schema.py
 1  from typing import Any
 2  
 3  from mlflow.exceptions import MlflowException
 4  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
 5  from mlflow.types.schema import Schema
 6  
 7  
 8  class TensorDatasetSchema:
 9      """
10      Represents the schema of a dataset with tensor features and targets.
11      """
12  
13      def __init__(self, features: Schema, targets: Schema = None):
14          if not isinstance(features, Schema):
15              raise MlflowException(
16                  f"features must be mlflow.types.Schema, got '{type(features)}'",
17                  INVALID_PARAMETER_VALUE,
18              )
19          if targets is not None and not isinstance(targets, Schema):
20              raise MlflowException(
21                  f"targets must be either None or mlflow.types.Schema, got '{type(features)}'",
22                  INVALID_PARAMETER_VALUE,
23              )
24          self.features = features
25          self.targets = targets
26  
27      def to_dict(self) -> dict[str, Any]:
28          """Serialize into a 'jsonable' dictionary.
29  
30          Returns:
31              dictionary representation of the schema's features and targets (if defined).
32  
33          """
34  
35          return {
36              "mlflow_tensorspec": {
37                  "features": self.features.to_json(),
38                  "targets": self.targets.to_json() if self.targets is not None else None,
39              },
40          }
41  
42      @classmethod
43      def from_dict(cls, schema_dict: dict[str, Any]):
44          """Deserialize from dictionary representation.
45  
46          Args:
47              schema_dict: Dictionary representation of model signature. Expected dictionary format:
48                  `{'features': <json string>, 'targets': <json string>" }`
49  
50          Returns:
51              TensorDatasetSchema populated with the data from the dictionary.
52  
53          """
54          if "mlflow_tensorspec" not in schema_dict:
55              raise MlflowException(
56                  "TensorDatasetSchema dictionary is missing expected key 'mlflow_tensorspec'",
57                  INVALID_PARAMETER_VALUE,
58              )
59  
60          schema_dict = schema_dict["mlflow_tensorspec"]
61          features = Schema.from_json(schema_dict["features"])
62          if "targets" in schema_dict and schema_dict["targets"] is not None:
63              targets = Schema.from_json(schema_dict["targets"])
64              return cls(features, targets)
65          else:
66              return cls(features)
67  
68      def __eq__(self, other) -> bool:
69          return (
70              isinstance(other, TensorDatasetSchema)
71              and self.features == other.features
72              and self.targets == other.targets
73          )
74  
75      def __repr__(self) -> str:
76          return f"features:\n  {self.features!r}\ntargets:\n  {self.targets!r}\n"