/ ragaai_catalyst / guardrails_manager.py
guardrails_manager.py
import json
import logging
import os

import requests

from .ragaai_catalyst import RagaAICatalyst
from .utils import response_checker

logger = logging.getLogger(__name__)


class GuardrailsManager:
    """
    Client for the guardrail deployment endpoints of a single project.

    Every request authenticates with the bearer token read from the
    ``RAGAAI_CATALYST_TOKEN`` environment variable at call time.
    """

    # Values accepted by _get_one_guardrail_data when validating a guardrail config.
    _VALID_SCHEMA_NAMES = ('Text', 'Prompt', 'Context', 'Response')
    _VALID_VARIABLE_NAMES = ('Instruction', 'Prompt', 'Context', 'Response')
    _VALID_MODELS = ('gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo')

    def __init__(self, project_name):
        """
        Initialize the GuardrailsManager with the given project name.

        Performs one HTTP request to resolve the project name to its ID.

        :param project_name: The name of the project to manage guardrails for.
        :raises ValueError: If no project with that name is returned by the API.
        """
        self.project_name = project_name
        self.timeout = 10
        # Used as the ``size`` query parameter so listings come back in one page.
        self.num_projects = 99999
        self.deployment_name = "NA"
        self.deployment_id = "NA"
        self.base_url = f"{RagaAICatalyst.BASE_URL}"
        list_projects, project_name_with_id = self._get_project_list()
        if project_name not in list_projects:
            raise ValueError(f"Project '{self.project_name}' does not exists")

        # Membership was checked above, so a match is guaranteed here.
        self.project_id = next(
            project["id"]
            for project in project_name_with_id
            if project["name"] == self.project_name
        )

    def _request_headers(self, json_content=False):
        """
        Build the headers shared by all project-scoped requests.

        :param json_content: When True, also set ``Content-Type: application/json``.
        :return: A dictionary of HTTP headers.
        """
        headers = {
            'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
            'X-Project-Id': str(self.project_id),
        }
        if json_content:
            headers['Content-Type'] = 'application/json'
        return headers

    def _get_project_list(self):
        """
        Retrieve the list of projects and their IDs from the API.

        :return: A tuple containing a list of project names and a list of
            dictionaries with project IDs and names.
        """
        # Called from __init__ before project_id exists, so no X-Project-Id header.
        headers = {'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}'}
        response = requests.request(
            "GET",
            f"{self.base_url}/v2/llm/projects?size={self.num_projects}",
            headers=headers,
            timeout=self.timeout,
        )
        project_content = response.json()["data"]["content"]
        list_project = [project["name"] for project in project_content]
        project_name_with_id = [
            {"id": project["id"], "name": project["name"]}
            for project in project_content
        ]
        return list_project, project_name_with_id

    def list_deployment_ids(self):
        """
        List all deployment IDs and their names for the current project.

        :return: A list of dictionaries containing deployment IDs and names.
        """
        response = requests.request(
            "GET",
            f"{self.base_url}/guardrail/deployment?size={self.num_projects}&page=0&sort=lastUsedAt,desc",
            headers=self._request_headers(),
            timeout=self.timeout,
        )
        deployments = response.json()["data"]["content"]
        return [{"id": item["id"], "name": item["name"]} for item in deployments]

    def get_deployment(self, deployment_id):
        """
        Get details of a specific deployment ID, including its name and guardrails.

        :param deployment_id: The ID of the deployment to retrieve details for.
        :return: The decoded JSON response when its ``success`` field is truthy;
            otherwise ``None`` (the error message is printed).
        """
        response = requests.request(
            "GET",
            f"{self.base_url}/guardrail/deployment/{deployment_id}",
            headers=self._request_headers(),
            timeout=self.timeout,
        )
        # Decode the body once and reuse it for both branches.
        body = response.json()
        if body['success']:
            return body
        print('Error in retrieving deployment details:', body['message'])
        return None

    def list_guardrails(self):
        """
        List all available guardrails for the current project.

        :return: A list of guardrail names.
        """
        response = requests.request(
            "GET",
            f"{self.base_url}/v1/llm/llm-metrics?category=Guardrail",
            headers=self._request_headers(),
            timeout=self.timeout,
        )
        metrics = response.json()["data"]["metrics"]
        return [metric["name"] for metric in metrics]

    def list_fail_condition(self):
        """
        List all fail conditions for the current project's deployments.

        :return: The ``data`` field of the configurations response.
        """
        response = requests.request(
            "GET",
            f"{self.base_url}/guardrail/deployment/configurations",
            headers=self._request_headers(),
            timeout=self.timeout,
        )
        return response.json()["data"]

    def list_datasets(self):
        """
        Retrieve dataset names for the current project.

        :return: A list of dataset names, or a dictionary with ``status_code``
            and ``message`` when the response status is not 200.
        :raises requests.exceptions.RequestException: If the request fails or
            the server answers with an HTTP error status.
        """

        def make_request():
            headers = self._request_headers(json_content=True)
            # NOTE(review): only the first page of 12 datasets is requested, so
            # names beyond it are not returned -- confirm this is intended.
            json_data = {"size": 12, "page": "0", "projectId": str(self.project_id), "search": ""}
            try:
                response = requests.post(
                    f"{self.base_url}/v2/llm/dataset",
                    headers=headers,
                    json=json_data,
                    timeout=30,
                )
                response.raise_for_status()
                return response
            except requests.exceptions.RequestException as e:
                logger.error(f"Failed to list datasets: {e}")
                raise

        try:
            response = make_request()
            response_checker(response, "Dataset.list_datasets")
            # NOTE(review): raise_for_status() above already raises on 401, so
            # this retry looks unreachable -- confirm before relying on it.
            if response.status_code == 401:
                response = make_request()  # Retry the request
            if response.status_code != 200:
                return {
                    "status_code": response.status_code,
                    "message": response.json(),
                }
            datasets = response.json()["data"]["content"]
            return [dataset["name"] for dataset in datasets]
        except Exception as e:
            logger.error(f"Error in list_datasets: {e}")
            raise

    def create_deployment(self, deployment_name, deployment_dataset_name):
        """
        Create a new deployment ID with the given name.

        :param deployment_name: The name of the new deployment.
        :param deployment_dataset_name: The name of the tracking dataset.
        :return: The ID of the created deployment, or ``None`` when the API
            reports failure (the response object is printed in that case).
        :raises ValueError: If a deployment with the given name already exists
            or the API answers with HTTP 409.
        """
        self.deployment_name = deployment_name
        self.deployment_dataset_name = deployment_dataset_name
        existing_names = [item["name"] for item in self.list_deployment_ids()]
        if deployment_name in existing_names:
            raise ValueError(f"Deployment with '{deployment_name}' already exists, choose a unique name")

        # The tracking dataset is flagged as new when its name is not among
        # the names returned by list_datasets().
        list_datasets = self.list_datasets()
        is_new_dataset = deployment_dataset_name not in list_datasets

        payload = json.dumps({
            "name": str(deployment_name),
            "trackingDataset": {
                "addNew": is_new_dataset,
                "name": str(deployment_dataset_name)
            }
        })
        response = requests.request(
            "POST",
            f"{self.base_url}/guardrail/deployment",
            headers=self._request_headers(json_content=True),
            data=payload,
            timeout=self.timeout,
        )
        if response.status_code == 409:
            raise ValueError(f"Data with '{deployment_name}' already exists, choose a unique name")
        body = response.json()
        if body["success"]:
            print(body["message"])
            # The create response is not used for the ID; look it up by name.
            deployment_ids = self.list_deployment_ids()
            self.deployment_id = [
                item["id"] for item in deployment_ids if item["name"] == self.deployment_name
            ][0]
            return self.deployment_id
        else:
            print(response)

    def add_guardrails(self, deployment_id, guardrails, guardrails_config=None):
        """
        Add guardrails to the given deployment.

        :param deployment_id: The ID of the deployment to configure.
        :param guardrails: A list of guardrail dictionaries to add.
        :param guardrails_config: Optional configuration settings for the
            guardrails; defaults are applied for missing keys.
        :return: ``None``. The API message is printed on success and on failure.
        :raises ValueError: If a guardrail name already exists or its type is invalid.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        if guardrails_config is None:
            guardrails_config = {}

        self.deployment_id = deployment_id
        deployment_details = self.get_deployment(self.deployment_id)
        if not deployment_details:
            return None

        # Checking if guardrails names given already exist or not
        existing_guardrails = deployment_details["data"]["guardrailsResponse"]
        existing_display_names = [
            item['metricSpec']["displayName"] for item in existing_guardrails
        ]
        # NOTE(review): the user's "name" is compared against existing
        # "displayName" values -- confirm which field is meant to be unique.
        requested_names = [item["name"] for item in guardrails]
        for g_name in requested_names:
            if g_name in existing_display_names:
                raise ValueError(f"Guardrail with '{g_name}' already exists, choose a unique name")

        # Checking if guardrails type is correct or not
        available_guardrails_list = self.list_guardrails()
        for g_type in requested_names:
            if g_type not in available_guardrails_list:
                raise ValueError(f"Guardrail type '{g_type}' does not exists, choose a correct type")

        payload = self._get_guardrail_config_payload(guardrails_config)
        payload["guardrails"] = self._get_guardrail_list_payload(guardrails)
        payload = json.dumps(payload)
        response = requests.request(
            "POST",
            f"{self.base_url}/guardrail/deployment/{str(self.deployment_id)}/configure",
            headers=self._request_headers(json_content=True),
            data=payload,
            # Same timeout as every other request, so a stalled server cannot hang the caller.
            timeout=self.timeout,
        )
        body = response.json()
        if body["success"]:
            print(body["message"])
        else:
            print('Error updating guardrail ', body['message'])

    def _get_guardrail_config_payload(self, guardrails_config):
        """
        Construct the payload for guardrail configuration.

        :param guardrails_config: Configuration settings for the guardrails;
            ``None`` is treated as an empty configuration.
        :return: A dictionary representing the guardrail configuration payload.
        """
        guardrails_config = guardrails_config or {}
        alternate_response = guardrails_config.get(
            "alternateResponse", "This is the Alternate Response"
        )
        data = {
            "isActive": guardrails_config.get("isActive", False),
            "guardrailFailConditions": guardrails_config.get("guardrailFailConditions", ["FAIL"]),
            "deploymentFailCondition": guardrails_config.get("deploymentFailCondition", "ONE_FAIL"),
            "failAction": {
                "action": "ALTERNATE_RESPONSE",
                # "args" is a JSON document carried as a string. json.dumps
                # escapes quotes and backslashes, which hand-built text does not.
                "args": json.dumps(
                    {"alternateResponse": str(alternate_response)},
                    ensure_ascii=False,
                ),
            },
            "guardrails": []
        }
        return data

    def _get_guardrail_list_payload(self, guardrails):
        """
        Construct the payload for a list of guardrails.

        :param guardrails: A list of guardrails to include in the payload.
        :return: A list of dictionaries representing each guardrail's data.
        """
        return [self._get_one_guardrail_data(guardrail) for guardrail in guardrails]

    def _get_one_guardrail_data(self, guardrail):
        """
        Construct the data for a single guardrail.

        The caller's ``guardrail`` dictionary is not modified; default
        ``params`` are added to a copy of its ``config``.

        :param guardrail: A dictionary containing the guardrail's attributes
            (``displayName``, ``name`` and optionally ``config``).
        :return: A dictionary representing the guardrail's data.
        :raises ValueError: If a mapping or the model name in ``config`` is invalid.
        """
        config = {}
        if 'config' in guardrail:
            # Shallow copy so the default params below do not leak into the input.
            config = dict(guardrail['config'])
            for mapping in config.get('mappings', []):
                if mapping['schemaName'] not in self._VALID_SCHEMA_NAMES:
                    raise ValueError('Invalid schemaName in guardrail mapping schema')
                if mapping['variableName'] not in self._VALID_VARIABLE_NAMES:
                    raise ValueError('Invalid variableName in guardrail mapping schema')
            if 'model' in config and config['model'] not in self._VALID_MODELS:
                raise ValueError('Invalid model name in guardrail model schema')
            # Defaults apply only when the caller supplied a config without params.
            config.setdefault('params', {
                "isActive": {"value": False},
                "isHighRisk": {"value": False},
                "threshold": {"lt": 1}
            })

        data = {
            "displayName": guardrail["displayName"],
            "name": guardrail["name"],
            "config": config
        }
        return data

    def _run(self, **kwargs):
        """
        Execute the guardrail checks with the provided variables.
        """