/ mlflow / entities / gateway_guardrail.py
gateway_guardrail.py
  1  from __future__ import annotations
  2  
  3  from dataclasses import dataclass
  4  from enum import Enum
  5  
  6  from mlflow.entities._mlflow_object import _MlflowObject
  7  from mlflow.entities.scorer import ScorerVersion
  8  from mlflow.protos.service_pb2 import GatewayGuardrail as ProtoGatewayGuardrail
  9  from mlflow.protos.service_pb2 import GatewayGuardrailConfig as ProtoGatewayGuardrailConfig
 10  from mlflow.protos.service_pb2 import GuardrailAction as ProtoGuardrailAction
 11  from mlflow.protos.service_pb2 import GuardrailStage as ProtoGuardrailStage
 12  from mlflow.utils.workspace_utils import resolve_entity_workspace_name
 13  
 14  
 15  class GuardrailStage(str, Enum):
 16      BEFORE = "BEFORE"
 17      AFTER = "AFTER"
 18  
 19      def __str__(self) -> str:
 20          return self.value
 21  
 22      @classmethod
 23      def from_proto(cls, proto: ProtoGuardrailStage) -> GuardrailStage:
 24          return cls(ProtoGuardrailStage.Name(proto))
 25  
 26      def to_proto(self) -> ProtoGuardrailStage:
 27          return ProtoGuardrailStage.Value(self.value)
 28  
 29  
 30  class GuardrailAction(str, Enum):
 31      VALIDATION = "VALIDATION"
 32      SANITIZATION = "SANITIZATION"
 33  
 34      def __str__(self) -> str:
 35          return self.value
 36  
 37      @classmethod
 38      def from_proto(cls, proto: ProtoGuardrailAction) -> GuardrailAction:
 39          return cls(ProtoGuardrailAction.Name(proto))
 40  
 41      def to_proto(self) -> ProtoGuardrailAction:
 42          return ProtoGuardrailAction.Value(self.value)
 43  
 44  
 45  @dataclass
 46  class GatewayGuardrail(_MlflowObject):
 47      guardrail_id: str
 48      name: str
 49      scorer: ScorerVersion
 50      stage: GuardrailStage
 51      action: GuardrailAction
 52      created_at: int
 53      last_updated_at: int
 54      action_endpoint_name: str | None = None
 55      created_by: str | None = None
 56      last_updated_by: str | None = None
 57      workspace: str | None = None
 58  
 59      def __post_init__(self):
 60          self.workspace = resolve_entity_workspace_name(self.workspace)
 61          if isinstance(self.stage, str):
 62              self.stage = GuardrailStage(self.stage)
 63          if isinstance(self.action, str):
 64              self.action = GuardrailAction(self.action)
 65  
 66      def to_proto(self):
 67          proto = ProtoGatewayGuardrail()
 68          proto.guardrail_id = self.guardrail_id
 69          proto.name = self.name
 70          proto.scorer.CopyFrom(self.scorer.to_proto())
 71          proto.stage = self.stage.to_proto()
 72          proto.action = self.action.to_proto()
 73          if self.action_endpoint_name:
 74              proto.action_endpoint_id = self.action_endpoint_name
 75          proto.created_by = self.created_by or ""
 76          proto.created_at = self.created_at
 77          proto.last_updated_by = self.last_updated_by or ""
 78          proto.last_updated_at = self.last_updated_at
 79          return proto
 80  
 81      @classmethod
 82      def from_proto(cls, proto):
 83          return cls(
 84              guardrail_id=proto.guardrail_id,
 85              name=proto.name,
 86              scorer=ScorerVersion.from_proto(proto.scorer),
 87              stage=GuardrailStage.from_proto(proto.stage),
 88              action=GuardrailAction.from_proto(proto.action),
 89              action_endpoint_name=proto.action_endpoint_id or None,
 90              created_by=proto.created_by or None,
 91              created_at=proto.created_at,
 92              last_updated_by=proto.last_updated_by or None,
 93              last_updated_at=proto.last_updated_at,
 94          )
 95  
 96  
 97  @dataclass
 98  class GatewayGuardrailConfig(_MlflowObject):
 99      """Junction between a guardrail and a gateway endpoint, with ordering."""
100  
101      endpoint_id: str
102      guardrail_id: str
103      execution_order: int | None
104      created_at: int
105      guardrail: GatewayGuardrail | None = None
106      created_by: str | None = None
107      workspace: str | None = None
108  
109      def to_proto(self):
110          proto = ProtoGatewayGuardrailConfig()
111          proto.endpoint_id = self.endpoint_id
112          proto.guardrail_id = self.guardrail_id
113          if self.execution_order is not None:
114              proto.execution_order = self.execution_order
115          if self.guardrail is not None:
116              proto.guardrail.CopyFrom(self.guardrail.to_proto())
117          proto.created_by = self.created_by or ""
118          proto.created_at = self.created_at
119          return proto
120  
121      @classmethod
122      def from_proto(cls, proto):
123          guardrail = None
124          if proto.HasField("guardrail"):
125              guardrail = GatewayGuardrail.from_proto(proto.guardrail)
126          return cls(
127              endpoint_id=proto.endpoint_id,
128              guardrail_id=proto.guardrail_id,
129              execution_order=proto.execution_order if proto.HasField("execution_order") else None,
130              guardrail=guardrail,
131              created_at=proto.created_at,
132              created_by=proto.created_by or None,
133          )