/ mlflow / entities / gateway_budget_policy.py
gateway_budget_policy.py
  1  from __future__ import annotations
  2  
  3  from dataclasses import dataclass
  4  from enum import Enum
  5  
  6  from mlflow.entities._mlflow_object import _MlflowObject
  7  from mlflow.protos.service_pb2 import BudgetAction as ProtoBudgetAction
  8  from mlflow.protos.service_pb2 import BudgetDuration as ProtoBudgetDuration
  9  from mlflow.protos.service_pb2 import BudgetDurationUnit as ProtoBudgetDurationUnit
 10  from mlflow.protos.service_pb2 import BudgetTargetScope as ProtoBudgetTargetScope
 11  from mlflow.protos.service_pb2 import BudgetUnit as ProtoBudgetUnit
 12  from mlflow.protos.service_pb2 import GatewayBudgetPolicy as ProtoGatewayBudgetPolicy
 13  from mlflow.utils.workspace_utils import resolve_entity_workspace_name
 14  
 15  
 16  class BudgetDurationUnit(str, Enum):
 17      """Duration unit for budget policy fixed windows."""
 18  
 19      MINUTES = "MINUTES"
 20      HOURS = "HOURS"
 21      DAYS = "DAYS"
 22      WEEKS = "WEEKS"
 23      MONTHS = "MONTHS"
 24  
 25      @classmethod
 26      def from_proto(cls, proto: ProtoBudgetDurationUnit) -> BudgetDurationUnit | None:
 27          try:
 28              return cls(ProtoBudgetDurationUnit.Name(proto))
 29          except ValueError:
 30              return None
 31  
 32      def to_proto(self) -> ProtoBudgetDurationUnit:
 33          return ProtoBudgetDurationUnit.Value(self.value)
 34  
 35  
 36  class BudgetTargetScope(str, Enum):
 37      """Target scope for a budget policy."""
 38  
 39      GLOBAL = "GLOBAL"
 40      WORKSPACE = "WORKSPACE"
 41  
 42      @classmethod
 43      def from_proto(cls, proto: ProtoBudgetTargetScope) -> BudgetTargetScope | None:
 44          try:
 45              return cls(ProtoBudgetTargetScope.Name(proto))
 46          except ValueError:
 47              return None
 48  
 49      def to_proto(self) -> ProtoBudgetTargetScope:
 50          return ProtoBudgetTargetScope.Value(self.value)
 51  
 52  
 53  class BudgetAction(str, Enum):
 54      """Action to take when a budget is exceeded."""
 55  
 56      ALERT = "ALERT"
 57      REJECT = "REJECT"
 58  
 59      @classmethod
 60      def from_proto(cls, proto: ProtoBudgetAction) -> BudgetAction | None:
 61          try:
 62              return cls(ProtoBudgetAction.Name(proto))
 63          except ValueError:
 64              return None
 65  
 66      def to_proto(self) -> ProtoBudgetAction:
 67          return ProtoBudgetAction.Value(self.value)
 68  
 69  
 70  class BudgetUnit(str, Enum):
 71      """Budget measurement unit."""
 72  
 73      USD = "USD"
 74  
 75      @classmethod
 76      def from_proto(cls, proto: ProtoBudgetUnit) -> BudgetUnit | None:
 77          try:
 78              return cls(ProtoBudgetUnit.Name(proto))
 79          except ValueError:
 80              return None
 81  
 82      def to_proto(self) -> ProtoBudgetUnit:
 83          return ProtoBudgetUnit.Value(self.value)
 84  
 85  
 86  @dataclass
 87  class BudgetDuration:
 88      """Fixed window duration: a (unit, value) pair defining the length of a budget window."""
 89  
 90      unit: BudgetDurationUnit
 91      value: int
 92  
 93      def __post_init__(self):
 94          if isinstance(self.unit, str):
 95              self.unit = BudgetDurationUnit(self.unit)
 96  
 97      def to_proto(self) -> ProtoBudgetDuration:
 98          proto = ProtoBudgetDuration()
 99          proto.unit = self.unit.to_proto()
100          proto.value = self.value
101          return proto
102  
103      @classmethod
104      def from_proto(cls, proto: ProtoBudgetDuration) -> BudgetDuration:
105          return cls(
106              unit=BudgetDurationUnit.from_proto(proto.unit),
107              value=proto.value,
108          )
109  
110  
@dataclass
class GatewayBudgetPolicy(_MlflowObject):
    """
    Represents a budget policy for the AI Gateway.

    Budget policies set limits with fixed time windows,
    supporting global or per-workspace scoping.

    Args:
        budget_policy_id: Unique identifier for this budget policy.
        budget_unit: Budget measurement unit (e.g. USD). A string is coerced to
            :class:`BudgetUnit`. It is ``None`` when ``from_proto`` saw a value
            that ``BudgetUnit.from_proto`` could not map.
        budget_amount: Budget limit amount.
        duration: Fixed time window (unit + length pair).
        target_scope: Scope of the budget (GLOBAL or WORKSPACE). A string is
            coerced to :class:`BudgetTargetScope`. It is ``None`` when
            ``from_proto`` saw an unmappable value.
        budget_action: Action when budget is exceeded (ALERT, REJECT). A string
            is coerced to :class:`BudgetAction`. It is ``None`` when
            ``from_proto`` saw an unmappable value.
        created_at: Timestamp (milliseconds) when the policy was created.
        last_updated_at: Timestamp (milliseconds) when the policy was last updated.
        created_by: User ID who created the policy.
        last_updated_by: User ID who last updated the policy.
        workspace: Workspace that owns the policy. It is normalized through
            ``resolve_entity_workspace_name`` on construction.
    """

    budget_policy_id: str
    budget_unit: BudgetUnit | None
    budget_amount: float
    duration: BudgetDuration
    target_scope: BudgetTargetScope | None
    budget_action: BudgetAction | None
    created_at: int
    last_updated_at: int
    created_by: str | None = None
    last_updated_by: str | None = None
    workspace: str | None = None

    def __post_init__(self):
        self.workspace = resolve_entity_workspace_name(self.workspace)
        # Accept raw strings for the enum-typed fields so callers need not
        # import the enums; unknown strings raise ValueError from the enum.
        if isinstance(self.budget_unit, str):
            self.budget_unit = BudgetUnit(self.budget_unit)
        if isinstance(self.target_scope, str):
            self.target_scope = BudgetTargetScope(self.target_scope)
        if isinstance(self.budget_action, str):
            self.budget_action = BudgetAction(self.budget_action)

    def to_proto(self):
        """Serialize to a ``GatewayBudgetPolicy`` protobuf message.

        Enum fields that are ``None`` are left at their protobuf default.
        ``from_proto`` yields ``None`` for values it cannot map, and these fields
        previously raised ``AttributeError``. ``None`` user IDs are written as
        empty strings.
        """
        proto = ProtoGatewayBudgetPolicy()
        proto.budget_policy_id = self.budget_policy_id
        if self.budget_unit is not None:
            proto.budget_unit = self.budget_unit.to_proto()
        proto.budget_amount = self.budget_amount
        proto.duration.CopyFrom(self.duration.to_proto())
        if self.target_scope is not None:
            proto.target_scope = self.target_scope.to_proto()
        if self.budget_action is not None:
            proto.budget_action = self.budget_action.to_proto()
        # Proto string fields cannot hold None; from_proto maps "" back to None.
        proto.created_by = self.created_by or ""
        proto.created_at = self.created_at
        proto.last_updated_by = self.last_updated_by or ""
        proto.last_updated_at = self.last_updated_at
        # NOTE(review): workspace is neither written here nor read in from_proto;
        # confirm whether the proto message is meant to carry it.
        return proto

    @classmethod
    def from_proto(cls, proto):
        """Build an instance from a ``GatewayBudgetPolicy`` protobuf message.

        Empty user-ID strings become ``None``. ``workspace`` is not read from the
        proto, so it is resolved from ``None`` by ``__post_init__``.
        """
        return cls(
            budget_policy_id=proto.budget_policy_id,
            budget_unit=BudgetUnit.from_proto(proto.budget_unit),
            budget_amount=proto.budget_amount,
            duration=BudgetDuration.from_proto(proto.duration),
            target_scope=BudgetTargetScope.from_proto(proto.target_scope),
            budget_action=BudgetAction.from_proto(proto.budget_action),
            created_by=proto.created_by or None,
            created_at=proto.created_at,
            last_updated_by=proto.last_updated_by or None,
            last_updated_at=proto.last_updated_at,
        )