/ ragaai_catalyst / guardrails_manager.py
guardrails_manager.py
  1  import requests
  2  import json
  3  import os
  4  import logging
  5  logger = logging.getLogger(__name__)
  6  from .utils import response_checker
  7  from .ragaai_catalyst import RagaAICatalyst
  8  
  9  
 10  class GuardrailsManager:
 11      def __init__(self, project_name):
 12          """
 13          Initialize the GuardrailsManager with the given project name.
 14          
 15          :param project_name: The name of the project to manage guardrails for.
 16          """
 17          self.project_name = project_name
 18          self.timeout = 10
 19          self.num_projects = 99999
 20          self.deployment_name = "NA"
 21          self.deployment_id = "NA"
 22          self.base_url = f"{RagaAICatalyst.BASE_URL}"
 23          list_projects, project_name_with_id = self._get_project_list()
 24          if project_name not in list_projects:
 25              raise ValueError(f"Project '{self.project_name}' does not exists")
 26          
 27          self.project_id = [_["id"] for _ in project_name_with_id if _["name"]==self.project_name][0]
 28  
 29  
 30      def _get_project_list(self):
 31          """
 32          Retrieve the list of projects and their IDs from the API.
 33          
 34          :return: A tuple containing a list of project names and a list of dictionaries with project IDs and names.
 35          """
 36          headers = {'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}'}
 37          response = requests.request("GET", f"{self.base_url}/v2/llm/projects?size={self.num_projects}", headers=headers, timeout=self.timeout)
 38          project_content = response.json()["data"]["content"]
 39          list_project = [_["name"] for _ in project_content]
 40          project_name_with_id = [{"id": _["id"], "name": _["name"]} for _ in project_content]
 41          return list_project, project_name_with_id
 42  
 43  
 44      def list_deployment_ids(self):
 45          """
 46          List all deployment IDs and their names for the current project.
 47          
 48          :return: A list of dictionaries containing deployment IDs and names.
 49          """
 50          payload = {}
 51          headers = {
 52                  'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
 53                  'X-Project-Id': str(self.project_id)
 54                  }
 55          response = requests.request("GET", f"{self.base_url}/guardrail/deployment?size={self.num_projects}&page=0&sort=lastUsedAt,desc", headers=headers, data=payload, timeout=self.timeout)
 56          deployment_ids_content = response.json()["data"]["content"]
 57          deployment_ids_content = [{"id": _["id"], "name": _["name"]} for _ in deployment_ids_content]
 58          return deployment_ids_content
 59  
 60  
 61      def get_deployment(self, deployment_id):
 62          """
 63          Get details of a specific deployment ID, including its name and guardrails.
 64          
 65          :param deployment_id: The ID of the deployment to retrieve details for.
 66          :return: A dictionary containing the deployment name and a list of guardrails.
 67          """
 68          payload = {}
 69          headers = {
 70                  'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
 71                  'X-Project-Id': str(self.project_id)
 72                  }
 73          response = requests.request("GET", f"{self.base_url}/guardrail/deployment/{deployment_id}", headers=headers, data=payload, timeout=self.timeout)
 74          if response.json()['success']:
 75              return response.json()
 76          else:
 77              print('Error in retrieving deployment details:',response.json()['message'])
 78              return None
 79  
 80  
 81      def list_guardrails(self):
 82          """
 83          List all available guardrails for the current project.
 84          
 85          :return: A list of guardrail names.
 86          """
 87          payload = {}
 88          headers = {
 89                  'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
 90                  'X-Project-Id': str(self.project_id)
 91                  }
 92          response = requests.request("GET", f"{self.base_url}/v1/llm/llm-metrics?category=Guardrail", headers=headers, data=payload, timeout=self.timeout)
 93          list_guardrails_content = response.json()["data"]["metrics"]
 94          list_guardrails = [_["name"] for _ in list_guardrails_content]
 95          return list_guardrails
 96  
 97  
 98      def list_fail_condition(self):
 99          """
100          List all fail conditions for the current project's deployments.
101          
102          :return: A list of fail conditions.
103          """
104          payload = {}
105          headers = {
106                  'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
107                  'X-Project-Id': str(self.project_id)
108                  }
109          response = requests.request("GET", f"{self.base_url}/guardrail/deployment/configurations", headers=headers, data=payload, timeout=self.timeout)
110          return response.json()["data"]
111  
112      
113      def list_datasets(self):
114          """
115          Retrieves a list of datasets for a given project.
116  
117          Returns:
118              list: A list of dataset names.
119  
120          Raises:
121              None.
122          """
123  
124          def make_request():
125              headers = {
126                  'Content-Type': 'application/json',
127                  "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
128                  "X-Project-Id": str(self.project_id),
129              }
130              json_data = {"size": 12, "page": "0", "projectId": str(self.project_id), "search": ""}
131              try:
132                  response = requests.post(
133                      f"{self.base_url}/v2/llm/dataset",
134                      headers=headers,
135                      json=json_data,
136                      timeout=30,
137                  )
138                  response.raise_for_status()
139                  return response
140              except requests.exceptions.RequestException as e:
141                  logger.error(f"Failed to list datasets: {e}")
142                  raise
143  
144          try:
145              response = make_request()
146              response_checker(response, "Dataset.list_datasets")
147              if response.status_code == 401:
148                  response = make_request()  # Retry the request
149              if response.status_code != 200:
150                  return {
151                      "status_code": response.status_code,
152                      "message": response.json(),
153                  }
154              datasets = response.json()["data"]["content"]
155              dataset_list = [dataset["name"] for dataset in datasets]
156              return dataset_list
157          except Exception as e:
158              logger.error(f"Error in list_datasets: {e}")
159              raise
160  
161  
162      def create_deployment(self, deployment_name, deployment_dataset_name):
163          """
164          Create a new deployment ID with the given name.
165          
166          :param deployment_name: The name of the new deployment.
167          :param deployment_dataset_name: The name of the tracking dataset.
168          :raises ValueError: If a deployment with the given name already exists.
169          """
170          self.deployment_name = deployment_name
171          self.deployment_dataset_name = deployment_dataset_name
172          list_deployment_ids = self.list_deployment_ids()
173          list_deployment_names = [_["name"] for _ in list_deployment_ids]
174          if deployment_name in list_deployment_names:
175              raise ValueError(f"Deployment with '{deployment_name}' already exists, choose a unique name")
176          
177          # Check if dataset name exists
178          list_datasets = self.list_datasets()
179          # Assuming this method exists to get list of datasets
180          is_new_dataset = deployment_dataset_name not in list_datasets
181          
182          payload = json.dumps({
183              "name": str(deployment_name),
184              "trackingDataset": {
185                  "addNew": is_new_dataset,
186                  "name": str(deployment_dataset_name)
187              }
188          })
189          headers = {
190                  'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
191                  'Content-Type': 'application/json',
192                  'X-Project-Id': str(self.project_id)
193                  }
194          response = requests.request("POST", f"{self.base_url}/guardrail/deployment", headers=headers, data=payload, timeout=self.timeout)
195          if response.status_code == 409:
196              raise ValueError(f"Data with '{deployment_name}' already exists, choose a unique name")
197          if response.json()["success"]:
198              print(response.json()["message"])
199              deployment_ids = self.list_deployment_ids()
200              self.deployment_id = [_["id"] for _ in deployment_ids if _["name"]==self.deployment_name][0]
201              return self.deployment_id
202          else:
203              print(response)
204              
205  
206      def add_guardrails(self, deployment_id, guardrails, guardrails_config={}):
207          """
208          Add guardrails to the current deployment.
209          
210          :param guardrails: A list of guardrails to add.
211          :param guardrails_config: Configuration settings for the guardrails.
212          :raises ValueError: If a guardrail name or type is invalid.
213          """
214          # Checking if guardrails names given already exist or not
215          self.deployment_id = deployment_id
216          deployment_details = self.get_deployment(self.deployment_id)
217          if not deployment_details:
218              return None
219          deployment_id_name = deployment_details["data"]["name"]
220          deployment_id_guardrails = deployment_details["data"]["guardrailsResponse"]
221          guardrails_type_name_exists = [{_['metricSpec']["name"]:_['metricSpec']["displayName"]} for _ in deployment_id_guardrails]
222          guardrails_type_name_exists = [list(d.values())[0] for d in guardrails_type_name_exists]
223          user_guardrails_name_list = [_["name"] for _ in guardrails]
224          for g_name in user_guardrails_name_list:
225              if g_name in guardrails_type_name_exists:
226                  raise ValueError(f"Guardrail with '{g_name} already exists, choose a unique name'")
227          # Checking if guardrails type is correct or not
228          available_guardrails_list = self.list_guardrails()
229          user_guardrails_type_list = [_["name"] for _ in guardrails]
230          for g_type in user_guardrails_type_list:
231              if g_type not in available_guardrails_list:
232                  raise ValueError(f"Guardrail type '{g_type} does not exists, choose a correct type'")
233  
234          payload = self._get_guardrail_config_payload(guardrails_config)
235          payload["guardrails"] = self._get_guardrail_list_payload(guardrails)
236          payload = json.dumps(payload)
237          headers = {
238                  'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
239                  'Content-Type': 'application/json',
240                  'X-Project-Id': str(self.project_id)
241                  }
242          response = requests.request("POST", f"{self.base_url}/guardrail/deployment/{str(self.deployment_id)}/configure", headers=headers, data=payload)
243          if response.json()["success"]:
244              print(response.json()["message"])
245          else:
246              print('Error updating guardrail ',response.json()['message'])
247  
248      def _get_guardrail_config_payload(self, guardrails_config):
249          """
250          Construct the payload for guardrail configuration.
251          
252          :param guardrails_config: Configuration settings for the guardrails.
253          :return: A dictionary representing the guardrail configuration payload.
254          """
255          data = {
256              "isActive": guardrails_config.get("isActive",False),
257              "guardrailFailConditions": guardrails_config.get("guardrailFailConditions",["FAIL"]),
258              "deploymentFailCondition": guardrails_config.get("deploymentFailCondition","ONE_FAIL"),
259              "failAction": {
260                  "action": "ALTERNATE_RESPONSE",
261                  "args": f'{{\"alternateResponse\": \"{guardrails_config.get("alternateResponse","This is the Alternate Response")}\"}}'
262                  },
263              "guardrails" : []
264              }
265          return data
266  
267      def _get_guardrail_list_payload(self, guardrails):
268          """
269          Construct the payload for a list of guardrails.
270          
271          :param guardrails: A list of guardrails to include in the payload.
272          :return: A list of dictionaries representing each guardrail's data.
273          """
274          guardrails_list_payload = []
275          for guardrail in guardrails:
276              guardrails_list_payload.append(self._get_one_guardrail_data(guardrail))
277          return guardrails_list_payload
278  
279      def _get_one_guardrail_data(self, guardrail):
280          """
281          Construct the data for a single guardrail.
282          
283          :param guardrail: A dictionary containing the guardrail's attributes.
284          :return: A dictionary representing the guardrail's data.
285          """
286          if 'config' in guardrail:
287              if 'mappings' in guardrail.get('config'):
288                  for mapping in guardrail.get('config',{}).get('mappings',{}):
289                      if mapping['schemaName'] not in ['Text','Prompt','Context','Response']:
290                          raise(ValueError('Invalid schemaName in guardrail mapping schema'))
291                      if mapping['variableName'] not in ['Instruction','Prompt','Context','Response']:
292                          raise(ValueError('Invalid variableName in guardrail mapping schema'))
293              if 'model' in guardrail.get('config'):
294                  if guardrail.get('config',{}).get('model','') not in ['gpt-4o-mini','gpt-4o','gpt-4-turbo']:
295                      raise(ValueError('Invalid model name in guardrail model schema'))
296              if 'params' not in guardrail.get('config'):
297                  guardrail['config']['params'] = {
298                      "isActive": {"value": False},
299                      "isHighRisk": {"value": False},
300                      "threshold": {"lt": 1}
301                  }
302  
303  
304          data = {
305              "displayName": guardrail["displayName"],
306              "name": guardrail["name"],
307              "config": guardrail.get("config", {})
308          }
309          '''
310          if "lte" in guardrail["threshold"]:
311              data["threshold"]["lte"] = guardrail["threshold"]["lte"]
312          elif "gte" in guardrail["threshold"]:
313              data["threshold"]["gte"] = guardrail["threshold"]["gte"]
314          elif "eq" in guardrail["threshold"]:
315              data["threshold"]["eq"] = guardrail["threshold"]["eq"]
316          else:
317              data["threshold"]["gte"] = 0.0'''
318          return data
319  
320  
321      def _run(self, **kwargs):
322          """
323          Execute the guardrail checks with the provided variables.
324          """