# ragaai_catalyst/guard_executor.py
  1  import litellm
  2  import json
  3  import requests
  4  import os
  5  from google import genai
  6  from google.genai.types import GenerateContentConfig
  7  from typing import Optional, List, Dict, Any
  8  import logging
  9  logger = logging.getLogger('LiteLLM')
 10  logger.setLevel(logging.ERROR)
 11  
 12  class GuardExecutor:
 13  
 14      def __init__(self,guard_manager,input_deployment_id = None,output_deployment_id=None,field_map={}):
 15          self.field_map = field_map
 16          self.guard_manager = guard_manager
 17          try:
 18              if input_deployment_id:
 19                  self.input_deployment_id = input_deployment_id
 20                  self.input_deployment_details = self.guard_manager.get_deployment(input_deployment_id)
 21              if output_deployment_id:
 22                  self.output_deployment_id = output_deployment_id
 23                  self.output_deployment_details = self.guard_manager.get_deployment(output_deployment_id)
 24              if input_deployment_id and output_deployment_id:
 25                  # check if 2 deployments are mapped to same dataset
 26                  if self.input_deployment_details['data']['datasetId'] != self.output_deployment_details['data']['datasetId']:
 27                      raise ValueError('Input deployment and output deployment should be mapped to same dataset')
 28              for guardrail in self.input_deployment_details['data']['guardrailsResponse']:
 29                  maps = guardrail['metricSpec']['config']['mappings']
 30                  for _map in maps:
 31                      if _map['schemaName']=='Response':
 32                          raise ValueError('Response field should be mapped only in output guardrails')
 33          except Exception as e:
 34              raise ValueError(str(e))
 35          self.base_url = guard_manager.base_url
 36          for key in field_map.keys():
 37              if key not in ['prompt','context','response','instruction']:
 38                  print('Keys in field map should be in ["prompt","context","response","instruction"]')
 39          self.current_trace_id = None
 40          self.id_2_doc = {}
 41  
 42      def execute_deployment(self, deployment_id, payload):
 43          api = self.base_url + f'/guardrail/deployment/{deployment_id}/ingest'
 44          if self.current_trace_id:
 45              payload['traceId'] = self.current_trace_id
 46          payload = json.dumps(payload)
 47          headers = {
 48              'x-project-id': str(self.guard_manager.project_id),
 49              'Content-Type': 'application/json',
 50              'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}'
 51          }
 52          try:
 53              response = requests.request("POST", api, headers=headers, data=payload,timeout=self.guard_manager.timeout)
 54          except Exception as e:
 55              print('Failed running guardrail: ',str(e))
 56              return None
 57          if response.status_code!=200:
 58              print('Error in running deployment ',response.json()['message'])
 59          if response.json()['success']:
 60              return response.json()
 61          else:
 62              print(response.json()['message'])
 63              return None
 64  
 65      def llm_executor(self,prompt,model_params,llm_caller):
 66          messages = [{
 67                      'role':'user',
 68                      'content':prompt
 69                      }]
 70          if self.current_trace_id:
 71              doc = self.id_2_doc[self.current_trace_id]
 72              messages[0]['content'] = messages[0]['content'] + '\n' + doc.get('context','')
 73          if llm_caller == 'litellm':
 74              model_params['messages'] = messages
 75              response = litellm.completion(**model_params)
 76              return response['choices'][0].message.content
 77          elif llm_caller == 'genai':
 78              genai_client = genai.Client(api_key=os.getenv('GENAI_API_KEY'))
 79              model_params['messages'] = messages
 80              response = genai_client.models.generate(**model_params)
 81              return response.text
 82          else:
 83              print(f"{llm_caller} not supported currently, use litellm as llm caller")
 84          '''
 85          elif llm_caller == 'anthropic':
 86              response = anthropic.completion(prompt=messages, **model_params)
 87              return response['completion']
 88          elif llm_caller == 'langchain':
 89              response = langchain.completion(prompt=messages, **model_params)
 90              return response['choices'][0].text
 91          elif llm_caller == 'azure_openai':
 92              response = azure_openai.completion(prompt=messages, **model_params)
 93              return response['choices'][0].text
 94          elif llm_caller == 'aws_bedrock':
 95              response = aws_bedrock.completion(prompt=messages, **model_params)
 96              return response['choices'][0].text
 97          elif llm_caller == 'meta':
 98              response = meta.completion(prompt=messages, **model_params)
 99              return response['choices'][0].text
100          elif llm_csller == 'llamaindex':
101              response = llamaindex.completion(prompt=messages, **model_params)
102              return response['choices'][0].text'''
103  
104      def set_input_params(self, prompt: None, context: None, instruction: None,  **kwargs):
105          if 'latest' not in self.id_2_doc:
106              self.id_2_doc['latest'] = {}
107          if prompt:
108              self.id_2_doc['latest']['prompt'] = prompt
109          if context:
110              self.id_2_doc['latest']['context'] = context
111          if instruction:
112              self.id_2_doc['latest']['instruction'] = instruction
113  
114      
115      def __call__(self,prompt,prompt_params,model_params,llm_caller='litellm'):
116          '''for key in self.field_map:
117              if key not in ['prompt','response']:
118                  if self.field_map[key] not in prompt_params:
119                      raise ValueError(f'{key} added as field map but not passed as prompt parameter')
120          context_var = self.field_map.get('context',None)
121          prompt = None
122          for msg in messages:
123              if 'role' in msg:
124                  if msg['role'] == 'user':
125                      prompt = msg['content']
126                      if not context_var:
127                          msg['content'] += '\n' + prompt_params[context_var]
128          doc = dict()
129          doc['prompt'] = prompt
130          doc['context'] = prompt_params[context_var]'''
131          
132          # Run the input guardrails
133          alternate_response,input_deployment_response = self.execute_input_guardrails(prompt,prompt_params)
134          if input_deployment_response and input_deployment_response['data']['status'].lower() == 'fail':
135              return alternate_response, None, input_deployment_response
136          
137          # activate only guardrails that require response
138          try:
139              llm_response = self.llm_executor(prompt,model_params,llm_caller)
140          except Exception as e:
141              print('Error in running llm:',str(e))
142              return None, None, input_deployment_response
143          if 'instruction' in self.field_map:
144              instruction = prompt_params[self.field_map['instruction']]
145          alternate_op_response,output_deployment_response = self.execute_output_guardrails(llm_response)
146          if output_deployment_response and output_deployment_response['data']['status'].lower() == 'fail':
147              return alternate_op_response,llm_response,output_deployment_response
148          else:
149              return None,llm_response,output_deployment_response
150  
151      def set_variables(self,prompt,prompt_params):
152          for key in self.field_map:
153              if key not in ['prompt', 'response']:
154                  if self.field_map[key] not in prompt_params:
155                      raise ValueError(f'{key} added as field map but not passed as prompt parameter')
156          context_var = self.field_map.get('context', None)
157          
158          doc = dict()
159          doc['prompt'] = prompt
160          doc['context'] = prompt_params[context_var]
161          if 'instruction' in self.field_map:
162              instruction = prompt_params[self.field_map['instruction']]
163              doc['instruction'] = instruction
164          return doc
165  
166      def execute_input_guardrails(self, prompt, prompt_params):
167          self.current_trace_id =None
168          doc = self.set_variables(prompt,prompt_params)
169          deployment_response = self.execute_deployment(self.input_deployment_id,doc)
170          self.current_trace_id = deployment_response['data']['results'][0]['executionId']
171          self.id_2_doc[self.current_trace_id] = doc
172          if deployment_response and deployment_response['data']['status'].lower() == 'fail':
173              return deployment_response['data']['alternateResponse'], deployment_response
174          elif deployment_response:
175              return None, deployment_response
176  
177      def execute_output_guardrails(self, llm_response: str, prompt=None, prompt_params=None) -> None:
178          if not prompt: # user has not passed input
179              if self.current_trace_id not in self.id_2_doc:
180                  raise Exception(f'No input doc found for trace_id: {self.current_trace_id}')
181              else:
182                  doc = self.id_2_doc[self.current_trace_id]
183                  doc['response'] = llm_response
184          else:
185              doc = self.set_variables(prompt,prompt_params)
186          deployment_response = self.execute_deployment(self.output_deployment_id,doc)
187          del self.id_2_doc[self.current_trace_id]
188          self.current_trace_id = None
189          if deployment_response and deployment_response['data']['status'].lower() == 'fail':
190              return deployment_response['data']['alternateResponse'], deployment_response
191          elif deployment_response:
192              return None, deployment_response
193  
194  
195          '''
196          # doc = dict()
197          # doc['response'] = llm_response
198          # if trace_id:
199          #     doc['trace_id'] = trace_id
200          trace_id = self.current_trace_id
201          if not trace_id:
202              for key in self.field_map:
203                  if key not in ['prompt', 'response']:
204                      if not prompt_params or self.field_map[key] not in prompt_params:
205                          if key not in self.id_2_doc.get('latest', {}):
206                              raise ValueError(f'{key} added as field map but not passed as prompt parameter or set in executor')
207                  elif key == 'prompt':
208                      if not messages:
209                          if key not in self.id_2_doc.get('latest', {}):
210                              raise ValueError('messages should be provided when prompt is used as field or prompt should be set in executor')
211              # raise Exception(f'\'doc_id\' not provided and there is no doc_id currently available. Either run \'execute_input_guardrails\' or pass a valid \'doc_id\'')
212              #deployment_details = self.guard_manager.get_deployment(self.output_deployment_id)
213              #deployed_guardrails = deployment_details['data']['guardrailsResponse']
214              
215              for guardrail in deployed_guardrails:
216                  metric_spec_mappings = guardrail['metricSpec']['config']['mappings']
217                  var_names = [mapping['variableNmae'].lower() for mapping in metric_spec_mappings]
218                  for var_name in var_names:
219                      if var_name not in ['prompt', 'response']:
220                          if var_name not in self.field_map:
221                              raise ValueError(f'{var_name} requrired for {guardrail} guardrail in deployment {self.deployment_id} but not added as field map')
222                          if not prompt_params or (self.field_map[var_name] not in prompt_params):
223                              if var_name not in self.id_2_doc.get('latest', {}):
224                                  raise ValueError(f'{var_name} added as field map but not passed as prompt parameter')
225                          elif var_name == 'prompt':
226                              if not messages:
227                                  if var_name not in self.id_2_doc.get('latest', {}):
228                                      raise ValueError('messages must be provided if doc_id is not provided')
229              
230              prompt = None
231              if messages:
232                  for msg in messages:
233                      if 'role' in msg:
234                          if msg['role'] == 'user':
235                              prompt = msg['content']
236              else:
237                  prompt = self.id_2_doc['latest']['prompt']
238              context_var = self.field_map.get('context', None)
239              doc = dict()
240              doc['prompt'] = prompt
241              if context_var and prompt_params and context_var in prompt_params:
242                  doc['context'] = prompt_params[self.field_map[context_var]]
243              elif context_var:
244                  doc['context'] = self.id_2_doc['latest']['context']
245              elif 'latest' in self.id_2_doc and 'context' in self.id_2_doc['latest'] and self.id_2_doc['latest']['context']:
246                  doc['context'] = self.id_2_doc['latest']['context']
247              else:
248                  doc['context'] = ''
249              if 'instruction' in self.field_map:
250                  if prompt_params and 'instruction' in prompt_params:
251                      instruction = prompt_params[self.field_map['instruction']]
252                  elif 'latest' in self.id_2_doc and 'instruction' in self.id_2_doc['latest'] and self.id_2_doc['latest']['instruction']:
253                      instruction = self.id_2_doc['latest']['instruction']
254                  else:
255                      raise ValueError('instruction added as field map but not passed as prompt parameter or set in executor')
256                  doc['instruction'] = instruction
257          elif trace_id not in self.id_2_doc:
258              raise Exception(f'trace_id {trace_id} is not valid. Please run \'execute_input_guardrails\' first')
259          else:
260              doc = self.id_2_doc[trace_id]
261          doc['response'] = llm_response
262          response = self.execute_deployment(doc)
263          if response and response['data']['status'] == 'FAIL':
264              print('Guardrail deployment run retured failed status, replacing with alternate response')
265              return response['data']['alternateResponse'], llm_response, response
266          else:
267              self.current_trace_id = None
268              return None, llm_response, response
269              '''
270  
271  
272