logged_model_output.py
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.protos.service_pb2 import ModelOutput


class LoggedModelOutput(_MlflowObject):
    """ModelOutput object associated with a Run.

    Pairs the ID of a logged model with the step at which it was logged, and
    converts to/from the ``ModelOutput`` protobuf message and a plain dict.

    Args:
        model_id: ID of the logged model.
        step: Step at which the model was logged.
    """

    def __init__(self, model_id: str, step: int) -> None:
        self._model_id = model_id
        self._step = step

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` has exactly this type and identical attributes."""
        # Exact type match (not isinstance): a subclass instance never compares
        # equal to a base-class instance.
        # NOTE: defining __eq__ without __hash__ makes Python set __hash__ to
        # None on this class, so instances are unhashable.
        if type(other) is type(self):
            return self.__dict__ == other.__dict__
        return False

    @property
    def model_id(self) -> str:
        """Model ID"""
        return self._model_id

    @property
    def step(self) -> int:
        """Step at which the model was logged"""
        return self._step

    def to_proto(self) -> ModelOutput:
        """Convert this entity to a ``ModelOutput`` protobuf message."""
        return ModelOutput(model_id=self.model_id, step=self.step)

    def to_dictionary(self) -> dict[str, str | int]:
        """Return a plain dict with ``"model_id"`` and ``"step"`` keys."""
        return {"model_id": self.model_id, "step": self.step}

    @classmethod
    def from_proto(cls, proto: ModelOutput) -> "LoggedModelOutput":
        """Build an instance from a ``ModelOutput`` protobuf message."""
        return cls(proto.model_id, proto.step)