/ mlflow / entities / dataset_input.py
dataset_input.py
 1  from mlflow.entities._mlflow_object import _MlflowObject
 2  from mlflow.entities.dataset import Dataset
 3  from mlflow.entities.input_tag import InputTag
 4  from mlflow.protos.service_pb2 import DatasetInput as ProtoDatasetInput
 5  
 6  
 7  class DatasetInput(_MlflowObject):
 8      """DatasetInput object associated with an experiment."""
 9  
10      def __init__(self, dataset: Dataset, tags: list[InputTag] | None = None) -> None:
11          self._dataset = dataset
12          self._tags = tags or []
13  
14      def __eq__(self, other: _MlflowObject) -> bool:
15          if type(other) is type(self):
16              return self.__dict__ == other.__dict__
17          return False
18  
19      def _add_tag(self, tag: InputTag) -> None:
20          self._tags.append(tag)
21  
22      @property
23      def tags(self) -> list[InputTag]:
24          """Array of input tags."""
25          return self._tags
26  
27      @property
28      def dataset(self) -> Dataset:
29          """Dataset."""
30          return self._dataset
31  
32      def to_proto(self):
33          dataset_input = ProtoDatasetInput()
34          dataset_input.tags.extend([tag.to_proto() for tag in self.tags])
35          dataset_input.dataset.MergeFrom(self.dataset.to_proto())
36          return dataset_input
37  
38      @classmethod
39      def from_proto(cls, proto):
40          dataset_input = cls(Dataset.from_proto(proto.dataset))
41          for input_tag in proto.tags:
42              dataset_input._add_tag(InputTag.from_proto(input_tag))
43          return dataset_input
44  
45      def to_dictionary(self):
46          return {
47              "dataset": self.dataset.to_dictionary(),
48              "tags": {tag.key: tag.value for tag in self.tags},
49          }