# ragaai_catalyst/guard_executor.py
import json
import logging
import os

import litellm
import requests
from google import genai
from google.genai.types import GenerateContentConfig
from typing import Optional, List, Dict, Any

logger = logging.getLogger('LiteLLM')
logger.setLevel(logging.ERROR)


class GuardExecutor:

    def __init__(self, guard_manager, input_deployment_id=None, output_deployment_id=None, field_map=None):
        # Avoid a shared mutable default for field_map
        self.field_map = field_map or {}
        self.guard_manager = guard_manager
        try:
            if input_deployment_id:
                self.input_deployment_id = input_deployment_id
                self.input_deployment_details = self.guard_manager.get_deployment(input_deployment_id)
            if output_deployment_id:
                self.output_deployment_id = output_deployment_id
                self.output_deployment_details = self.guard_manager.get_deployment(output_deployment_id)
            if input_deployment_id and output_deployment_id:
                # check that the two deployments are mapped to the same dataset
                if self.input_deployment_details['data']['datasetId'] != \
                        self.output_deployment_details['data']['datasetId']:
                    raise ValueError('Input deployment and output deployment should be mapped to the same dataset')
            if input_deployment_id:
                # the 'Response' schema field only makes sense for output guardrails
                for guardrail in self.input_deployment_details['data']['guardrailsResponse']:
                    for _map in guardrail['metricSpec']['config']['mappings']:
                        if _map['schemaName'] == 'Response':
                            raise ValueError('Response field should be mapped only in output guardrails')
        except Exception as e:
            raise ValueError(str(e))
        self.base_url = guard_manager.base_url
        for key in self.field_map:
            if key not in ['prompt', 'context', 'response', 'instruction']:
                print('Keys in field map should be in ["prompt", "context", "response", "instruction"]')
        self.current_trace_id = None
        self.id_2_doc = {}

    def execute_deployment(self, deployment_id, payload):
        api = self.base_url + f'/guardrail/deployment/{deployment_id}/ingest'
        if self.current_trace_id:
            payload['traceId'] = self.current_trace_id
        payload = json.dumps(payload)
        headers = {
            'x-project-id': str(self.guard_manager.project_id),
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}'
        }
        try:
            response = requests.request('POST', api, headers=headers, data=payload,
                                        timeout=self.guard_manager.timeout)
        except Exception as e:
            print('Failed running guardrail: ', str(e))
            return None
        if response.status_code != 200:
            print('Error in running deployment ', response.json()['message'])
            return None
        if response.json()['success']:
            return response.json()
        else:
            print(response.json()['message'])
            return None

    def llm_executor(self, prompt, model_params, llm_caller):
        messages = [{
            'role': 'user',
            'content': prompt
        }]
        if self.current_trace_id:
            doc = self.id_2_doc[self.current_trace_id]
            messages[0]['content'] = messages[0]['content'] + '\n' + doc.get('context', '')
        if llm_caller == 'litellm':
            model_params['messages'] = messages
            response = litellm.completion(**model_params)
            return response.choices[0].message.content
        elif llm_caller == 'genai':
            # google-genai expects `contents` plus a GenerateContentConfig,
            # not chat-style `messages`; non-model kwargs become the config
            genai_client = genai.Client(api_key=os.getenv('GENAI_API_KEY'))
            params = dict(model_params)
            model = params.pop('model')
            config = GenerateContentConfig(**params) if params else None
            response = genai_client.models.generate_content(
                model=model, contents=messages[0]['content'], config=config)
            return response.text
        else:
            print(f'{llm_caller} not supported currently, use litellm as llm caller')
            return None
        '''
        elif llm_caller == 'anthropic':
            response = anthropic.completion(prompt=messages, **model_params)
            return response['completion']
        elif llm_caller == 'langchain':
            response = langchain.completion(prompt=messages, **model_params)
            return response['choices'][0].text
        '''
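    # Example (sketch): `llm_executor` hands `model_params` through unchanged,
    # so for litellm these are ordinary `litellm.completion` kwargs. The model
    # name and values below are placeholders, not pinned down by this module:
    #
    #   executor.llm_executor('Summarise the context in one line',
    #                         {'model': 'gpt-4o-mini', 'temperature': 0.1},
    #                         llm_caller='litellm')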
        '''
        elif llm_caller == 'azure_openai':
            response = azure_openai.completion(prompt=messages, **model_params)
            return response['choices'][0].text
        elif llm_caller == 'aws_bedrock':
            response = aws_bedrock.completion(prompt=messages, **model_params)
            return response['choices'][0].text
        elif llm_caller == 'meta':
            response = meta.completion(prompt=messages, **model_params)
            return response['choices'][0].text
        elif llm_caller == 'llamaindex':
            response = llamaindex.completion(prompt=messages, **model_params)
            return response['choices'][0].text
        '''

    def set_input_params(self, prompt: Optional[str] = None, context: Optional[str] = None,
                         instruction: Optional[str] = None, **kwargs):
        if 'latest' not in self.id_2_doc:
            self.id_2_doc['latest'] = {}
        if prompt:
            self.id_2_doc['latest']['prompt'] = prompt
        if context:
            self.id_2_doc['latest']['context'] = context
        if instruction:
            self.id_2_doc['latest']['instruction'] = instruction

    def __call__(self, prompt, prompt_params, model_params, llm_caller='litellm'):
        # Run the input guardrails
        alternate_response, input_deployment_response = self.execute_input_guardrails(prompt, prompt_params)
        if input_deployment_response and input_deployment_response['data']['status'].lower() == 'fail':
            return alternate_response, None, input_deployment_response

        # Input guardrails passed (or were inconclusive); run the LLM call
        try:
            llm_response = self.llm_executor(prompt, model_params, llm_caller)
        except Exception as e:
            print('Error in running llm:', str(e))
            return None, None, input_deployment_response

        # Run the output guardrails on the generated response
        alternate_op_response, output_deployment_response = self.execute_output_guardrails(llm_response)
        if output_deployment_response and output_deployment_response['data']['status'].lower() == 'fail':
            return alternate_op_response, llm_response, output_deployment_response
        else:
            return None, llm_response, output_deployment_response
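    # Example (sketch): a guarded end-to-end call. The three return values are
    # (alternate response on failure, raw LLM response, deployment result);
    # every parameter value here is a placeholder:
    #
    #   alternate, llm_response, result = executor(
    #       'Answer from the context: {question}',
    #       prompt_params={'question': '...', 'doc': '...'},
    #       model_params={'model': 'gpt-4o-mini'},
    #       llm_caller='litellm')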
    def set_variables(self, prompt, prompt_params):
        for key in self.field_map:
            if key not in ['prompt', 'response']:
                if self.field_map[key] not in prompt_params:
                    raise ValueError(f'{key} added as field map but not passed as prompt parameter')
        context_var = self.field_map.get('context', None)

        doc = dict()
        doc['prompt'] = prompt
        # Fall back to an empty context when no context field is mapped
        doc['context'] = prompt_params[context_var] if context_var else ''
        if 'instruction' in self.field_map:
            doc['instruction'] = prompt_params[self.field_map['instruction']]
        return doc

    def execute_input_guardrails(self, prompt, prompt_params):
        self.current_trace_id = None
        doc = self.set_variables(prompt, prompt_params)
        deployment_response = self.execute_deployment(self.input_deployment_id, doc)
        if not deployment_response:
            return None, None
        self.current_trace_id = deployment_response['data']['results'][0]['executionId']
        self.id_2_doc[self.current_trace_id] = doc
        if deployment_response['data']['status'].lower() == 'fail':
            return deployment_response['data']['alternateResponse'], deployment_response
        return None, deployment_response

    def execute_output_guardrails(self, llm_response: str, prompt=None, prompt_params=None):
        if not prompt:  # user has not passed input
            if self.current_trace_id not in self.id_2_doc:
                raise Exception(f'No input doc found for trace_id: {self.current_trace_id}')
            doc = self.id_2_doc[self.current_trace_id]
        else:
            doc = self.set_variables(prompt, prompt_params)
        doc['response'] = llm_response
        deployment_response = self.execute_deployment(self.output_deployment_id, doc)
        if self.current_trace_id in self.id_2_doc:
            del self.id_2_doc[self.current_trace_id]
        self.current_trace_id = None
        if not deployment_response:
            return None, None
        if deployment_response['data']['status'].lower() == 'fail':
            return deployment_response['data']['alternateResponse'], deployment_response
        return None, deployment_response
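    # Example (sketch): the two stages can also be driven separately when the
    # LLM call happens outside this class; `my_llm_call` is a hypothetical
    # helper, not part of this module:
    #
    #   alternate, in_result = executor.execute_input_guardrails(prompt, prompt_params)
    #   if not (in_result and in_result['data']['status'].lower() == 'fail'):
    #       llm_response = my_llm_call(prompt)
    #       alternate, out_result = executor.execute_output_guardrails(llm_response)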
    '''
    # Earlier execute_output_guardrails implementation, kept for reference:
    # doc = dict()
    # doc['response'] = llm_response
    # if trace_id:
    #     doc['trace_id'] = trace_id
    trace_id = self.current_trace_id
    if not trace_id:
        for key in self.field_map:
            if key not in ['prompt', 'response']:
                if not prompt_params or self.field_map[key] not in prompt_params:
                    if key not in self.id_2_doc.get('latest', {}):
                        raise ValueError(f'{key} added as field map but not passed as prompt parameter or set in executor')
            elif key == 'prompt':
                if not messages:
                    if key not in self.id_2_doc.get('latest', {}):
                        raise ValueError('messages should be provided when prompt is used as field or prompt should be set in executor')
        # raise Exception(f'\'doc_id\' not provided and there is no doc_id currently available. Either run \'execute_input_guardrails\' or pass a valid \'doc_id\'')
        # deployment_details = self.guard_manager.get_deployment(self.output_deployment_id)
        # deployed_guardrails = deployment_details['data']['guardrailsResponse']

        for guardrail in deployed_guardrails:
            metric_spec_mappings = guardrail['metricSpec']['config']['mappings']
            var_names = [mapping['variableName'].lower() for mapping in metric_spec_mappings]
            for var_name in var_names:
                if var_name not in ['prompt', 'response']:
                    if var_name not in self.field_map:
                        raise ValueError(f'{var_name} required for {guardrail} guardrail in deployment {self.deployment_id} but not added as field map')
                    if not prompt_params or (self.field_map[var_name] not in prompt_params):
                        if var_name not in self.id_2_doc.get('latest', {}):
                            raise ValueError(f'{var_name} added as field map but not passed as prompt parameter')
                elif var_name == 'prompt':
                    if not messages:
                        if var_name not in self.id_2_doc.get('latest', {}):
                            raise ValueError('messages must be provided if doc_id is not provided')

        prompt = None
        if messages:
            for msg in messages:
                if 'role' in msg:
                    if msg['role'] == 'user':
                        prompt = msg['content']
        else:
            prompt = self.id_2_doc['latest']['prompt']
        context_var = self.field_map.get('context', None)
        doc = dict()
        doc['prompt'] = prompt
        if context_var and prompt_params and context_var in prompt_params:
            doc['context'] = prompt_params[self.field_map[context_var]]
        elif context_var:
            doc['context'] = self.id_2_doc['latest']['context']
        elif 'latest' in self.id_2_doc and 'context' in self.id_2_doc['latest'] and self.id_2_doc['latest']['context']:
            doc['context'] = self.id_2_doc['latest']['context']
        else:
            doc['context'] = ''
        if 'instruction' in self.field_map:
            if prompt_params and 'instruction' in prompt_params:
                instruction = prompt_params[self.field_map['instruction']]
            elif 'latest' in self.id_2_doc and 'instruction' in self.id_2_doc['latest'] and self.id_2_doc['latest']['instruction']:
                instruction = self.id_2_doc['latest']['instruction']
            else:
                raise ValueError('instruction added as field map but not passed as prompt parameter or set in executor')
            doc['instruction'] = instruction
    elif trace_id not in self.id_2_doc:
        raise Exception(f'trace_id {trace_id} is not valid. Please run \'execute_input_guardrails\' first')
    else:
        doc = self.id_2_doc[trace_id]
    doc['response'] = llm_response
    response = self.execute_deployment(doc)
    if response and response['data']['status'] == 'FAIL':
        print('Guardrail deployment run returned failed status, replacing with alternate response')
        return response['data']['alternateResponse'], llm_response, response
    else:
        self.current_trace_id = None
        return None, llm_response, response
    '''
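
# Minimal end-to-end sketch, assuming this package's GuardrailsManager and two
# deployments already created in Catalyst. Every id, model name and value below
# is a placeholder, not something this module defines.
if __name__ == '__main__':
    from ragaai_catalyst import GuardrailsManager  # assumed package export

    manager = GuardrailsManager(project_name='my-project')  # hypothetical project
    executor = GuardExecutor(
        manager,
        input_deployment_id=101,   # placeholder deployment ids
        output_deployment_id=102,
        field_map={'context': 'doc'})

    alternate, llm_response, result = executor(
        'Answer strictly from the context: what is the refund window?',
        prompt_params={'doc': 'Refunds are accepted within 30 days of purchase.'},
        model_params={'model': 'gpt-4o-mini', 'temperature': 0},
        llm_caller='litellm')
    print(alternate or llm_response)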