dataset_input.py
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.dataset import Dataset
from mlflow.entities.input_tag import InputTag
from mlflow.protos.service_pb2 import DatasetInput as ProtoDatasetInput


class DatasetInput(_MlflowObject):
    """DatasetInput object associated with an experiment.

    Pairs a :class:`Dataset` with a list of :class:`InputTag` objects and
    converts to/from the ``DatasetInput`` protobuf message and a plain dict.
    """

    def __init__(self, dataset: Dataset, tags: list[InputTag] | None = None) -> None:
        """Create a dataset input.

        Args:
            dataset: The dataset this input refers to.
            tags: Optional input tags. Any iterable of ``InputTag`` is
                accepted; it is copied into a new list owned by this object.
        """
        self._dataset = dataset
        # Always copy: ``_add_tag`` appends in place, so storing the caller's
        # list directly (as ``tags or []`` did for non-empty lists) would let
        # this object mutate it, and a tuple argument would have no ``append``.
        self._tags = list(tags) if tags is not None else []

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is the exact same type with equal attributes.

        Attributes are compared via ``__dict__``, i.e. the dataset and the
        ordered tag list. Objects of any other type compare unequal.
        """
        if type(other) is type(self):
            return self.__dict__ == other.__dict__
        return False

    def _add_tag(self, tag: InputTag) -> None:
        """Append *tag* to this object's own tag list."""
        self._tags.append(tag)

    @property
    def tags(self) -> list[InputTag]:
        """Array of input tags."""
        return self._tags

    @property
    def dataset(self) -> Dataset:
        """Dataset."""
        return self._dataset

    def to_proto(self):
        """Serialize to a ``ProtoDatasetInput`` message.

        Returns:
            A new protobuf message whose ``tags`` field holds each tag's proto
            (in order) and whose ``dataset`` field is merged from the
            dataset's proto.
        """
        dataset_input = ProtoDatasetInput()
        dataset_input.tags.extend([tag.to_proto() for tag in self.tags])
        dataset_input.dataset.MergeFrom(self.dataset.to_proto())
        return dataset_input

    @classmethod
    def from_proto(cls, proto):
        """Build an instance from a ``ProtoDatasetInput`` message.

        Args:
            proto: Message providing ``dataset`` and repeated ``tags`` fields.

        Returns:
            A new instance with the dataset and tags (in message order)
            converted via ``Dataset.from_proto`` / ``InputTag.from_proto``.
        """
        return cls(
            Dataset.from_proto(proto.dataset),
            tags=[InputTag.from_proto(input_tag) for input_tag in proto.tags],
        )

    def to_dictionary(self):
        """Return a plain-dict view of this input.

        Returns:
            ``{"dataset": <dataset dict>, "tags": {key: value, ...}}``. Tags
            are keyed by ``tag.key``; if several tags share a key, the last
            one in the list wins.
        """
        return {
            "dataset": self.dataset.to_dictionary(),
            "tags": {tag.key: tag.value for tag in self.tags},
        }