# gateway_guardrail.py
"""Entities for gateway guardrails and their attachment to gateway endpoints."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.scorer import ScorerVersion
from mlflow.protos.service_pb2 import GatewayGuardrail as ProtoGatewayGuardrail
from mlflow.protos.service_pb2 import GatewayGuardrailConfig as ProtoGatewayGuardrailConfig
from mlflow.protos.service_pb2 import GuardrailAction as ProtoGuardrailAction
from mlflow.protos.service_pb2 import GuardrailStage as ProtoGuardrailStage
from mlflow.utils.workspace_utils import resolve_entity_workspace_name


class GuardrailStage(str, Enum):
    """Stage at which a guardrail runs: ``BEFORE`` or ``AFTER``.

    Member values are spelled exactly like the names of the proto
    ``GuardrailStage`` enum; the proto conversions below rely on that.
    """

    BEFORE = "BEFORE"
    AFTER = "AFTER"

    def __str__(self) -> str:
        # Render as the bare value ("BEFORE") instead of "GuardrailStage.BEFORE".
        return self.value

    @classmethod
    def from_proto(cls, proto: ProtoGuardrailStage) -> GuardrailStage:
        """Map a proto ``GuardrailStage`` number to the member with the same name.

        Raises ``ValueError`` if the number or its name is not recognized.
        """
        proto_name = ProtoGuardrailStage.Name(proto)
        return cls(proto_name)

    def to_proto(self) -> ProtoGuardrailStage:
        """Return the proto ``GuardrailStage`` number whose name equals this value."""
        return ProtoGuardrailStage.Value(str(self))


class GuardrailAction(str, Enum):
    """Kind of action a guardrail takes: ``VALIDATION`` or ``SANITIZATION``.

    Member values are spelled exactly like the names of the proto
    ``GuardrailAction`` enum; the proto conversions below rely on that.
    """

    VALIDATION = "VALIDATION"
    SANITIZATION = "SANITIZATION"

    def __str__(self) -> str:
        # Render as the bare value instead of "GuardrailAction.<NAME>".
        return self.value

    @classmethod
    def from_proto(cls, proto: ProtoGuardrailAction) -> GuardrailAction:
        """Map a proto ``GuardrailAction`` number to the member with the same name.

        Raises ``ValueError`` if the number or its name is not recognized.
        """
        proto_name = ProtoGuardrailAction.Name(proto)
        return cls(proto_name)

    def to_proto(self) -> ProtoGuardrailAction:
        """Return the proto ``GuardrailAction`` number whose name equals this value."""
        return ProtoGuardrailAction.Value(str(self))


@dataclass
class GatewayGuardrail(_MlflowObject):
    """A named guardrail: a scorer version plus the stage and action it is used with.

    ``stage`` and ``action`` may be passed as plain strings; ``__post_init__``
    coerces them to :class:`GuardrailStage` / :class:`GuardrailAction`.
    ``workspace`` is resolved at construction time but is not part of the
    proto representation.
    """

    guardrail_id: str
    name: str
    scorer: ScorerVersion
    stage: GuardrailStage
    action: GuardrailAction
    created_at: int
    last_updated_at: int
    action_endpoint_name: str | None = None
    created_by: str | None = None
    last_updated_by: str | None = None
    workspace: str | None = None

    def __post_init__(self):
        self.workspace = resolve_entity_workspace_name(self.workspace)
        # Accept raw strings such as "BEFORE" / "VALIDATION" as well as members.
        raw_stage, raw_action = self.stage, self.action
        if isinstance(raw_stage, str):
            self.stage = GuardrailStage(raw_stage)
        if isinstance(raw_action, str):
            self.action = GuardrailAction(raw_action)

    def to_proto(self) -> ProtoGatewayGuardrail:
        """Serialize to a ``GatewayGuardrail`` proto message.

        Missing ``created_by`` / ``last_updated_by`` are written as empty
        strings; ``action_endpoint_id`` is only set when an action endpoint
        name is present. ``workspace`` is not serialized.
        """
        message = ProtoGatewayGuardrail()
        message.guardrail_id = self.guardrail_id
        message.name = self.name
        message.scorer.CopyFrom(self.scorer.to_proto())
        message.stage = self.stage.to_proto()
        message.action = self.action.to_proto()
        # NOTE(review): the endpoint *name* is stored in the proto field named
        # `action_endpoint_id` (and read back the same way in from_proto);
        # confirm this matches the intent of the proto schema.
        endpoint_name = self.action_endpoint_name
        if endpoint_name:
            message.action_endpoint_id = endpoint_name
        message.created_by = self.created_by if self.created_by else ""
        message.created_at = self.created_at
        message.last_updated_by = self.last_updated_by if self.last_updated_by else ""
        message.last_updated_at = self.last_updated_at
        return message

    @classmethod
    def from_proto(cls, proto):
        """Build an instance from a ``GatewayGuardrail`` proto message.

        Empty ``action_endpoint_id``, ``created_by`` and ``last_updated_by``
        strings become ``None``; ``workspace`` is left for ``__post_init__``
        to resolve.
        """
        scorer = ScorerVersion.from_proto(proto.scorer)
        stage = GuardrailStage.from_proto(proto.stage)
        action = GuardrailAction.from_proto(proto.action)
        return cls(
            guardrail_id=proto.guardrail_id,
            name=proto.name,
            scorer=scorer,
            stage=stage,
            action=action,
            created_at=proto.created_at,
            last_updated_at=proto.last_updated_at,
            # An empty string on the message maps back to None.
            action_endpoint_name=proto.action_endpoint_id or None,
            created_by=proto.created_by or None,
            last_updated_by=proto.last_updated_by or None,
        )


@dataclass
class GatewayGuardrailConfig(_MlflowObject):
    """Link between a guardrail and a gateway endpoint, with an optional execution order.

    ``guardrail`` may carry the full :class:`GatewayGuardrail` alongside its id.
    NOTE(review): unlike ``GatewayGuardrail``, this class does not resolve
    ``workspace`` in a ``__post_init__``, and ``workspace`` is not written to or
    read from the proto -- confirm that is intended.
    """

    endpoint_id: str
    guardrail_id: str
    execution_order: int | None
    created_at: int
    guardrail: GatewayGuardrail | None = None
    created_by: str | None = None
    workspace: str | None = None

    def to_proto(self) -> ProtoGatewayGuardrailConfig:
        """Serialize to a ``GatewayGuardrailConfig`` proto message.

        ``execution_order`` and the nested ``guardrail`` are left unset on the
        message when they are ``None``; a missing ``created_by`` is written as
        an empty string.
        """
        message = ProtoGatewayGuardrailConfig()
        message.endpoint_id = self.endpoint_id
        message.guardrail_id = self.guardrail_id
        order = self.execution_order
        if order is not None:
            message.execution_order = order
        nested = self.guardrail
        if nested is not None:
            message.guardrail.CopyFrom(nested.to_proto())
        message.created_by = self.created_by if self.created_by else ""
        message.created_at = self.created_at
        return message

    @classmethod
    def from_proto(cls, proto):
        """Build an instance from a ``GatewayGuardrailConfig`` proto message.

        ``guardrail`` and ``execution_order`` are only read when present on the
        message (checked with ``HasField``); otherwise they become ``None``.
        """
        has_guardrail = proto.HasField("guardrail")
        guardrail = GatewayGuardrail.from_proto(proto.guardrail) if has_guardrail else None
        execution_order = proto.execution_order if proto.HasField("execution_order") else None
        return cls(
            endpoint_id=proto.endpoint_id,
            guardrail_id=proto.guardrail_id,
            execution_order=execution_order,
            guardrail=guardrail,
            created_at=proto.created_at,
            created_by=proto.created_by or None,
        )