# ragaai_catalyst/proxy_call.py
  1  import requests
  2  import json
  3  import subprocess
  4  import logging
  5  import traceback
  6  
  7  logger = logging.getLogger(__name__)
  8  
  9  def api_completion(model,messages, api_base='http://127.0.0.1:8000',
 10                      api_key='',model_config=dict()):
 11      whoami = get_username()
 12      all_response = list()
 13      job_id = model_config.get('job_id',-1)
 14      converted_message = convert_input(messages,model,model_config)
 15      payload = json.dumps(converted_message)
 16      response = payload
 17      headers = {
 18          'Content-Type': 'application/json',
 19          'Wd-PCA-Feature-Key':f'your_feature_key, $(whoami)'
 20      }
 21      try:
 22          response = requests.request("POST", api_base, headers=headers, data=payload, verify=False)
 23          if model_config.get('log_level','')=='debug':
 24              logger.info(f'Model response Job ID {job_id} {response.text}')
 25          if response.status_code!=200:
 26              # logger.error(f'Error in model response Job ID {job_id}:',str(response.text))
 27              raise ValueError(str(response.text))
 28      except Exception as e:
 29          logger.error(f'Error in calling api Job ID {job_id}:',str(e))
 30          raise ValueError(str(e))
 31      try:
 32          response = response.json()
 33          if 'error' in response:
 34              logger.error(f'Invalid response from API Job ID {job_id}:'+str(response))
 35              raise ValueError(str(response.get('error')))
 36          all_response.append(convert_output(response,job_id))
 37      except ValueError as e1:
 38          logger.error(f'Invalid json response from API Job ID {job_id}:'+response)
 39          raise ValueError(str(e1))
 40      except Exception as e1:
 41          if model_config.get('log_level','')=='debug':
 42              logger.info(f"Error trace Job ID: {job_id} {traceback.print_exc()}")
 43          logger.error(f"Exception in parsing model response Job ID:{job_id} {str(e1)}")
 44          logger.error(f"Model response Job ID: {job_id} {response.text}")
 45          all_response.append(None)
 46      return all_response
 47  
 48  def get_username():
 49      result = subprocess.run(['whoami'], capture_output=True, text=True)
 50      result = result.stdout
 51      return result
 52  
 53  def convert_output(response,job_id):
 54      try:
 55          if response.get('prediction',{}).get('type','')=='generic-text-generation-v1':
 56              return response['prediction']['output']
 57          elif response.get('prediction',{}).get('type','')=='gcp-multimodal-v1':
 58              full_response = ''
 59              for chunk in response['prediction']['output']['chunks']:
 60                  candidate = chunk['candidates'][0]
 61                  if candidate['finishReason'] and candidate['finishReason'] not in ['STOP']:
 62                      raise ValueError(candidate['finishReason'])
 63                  part = candidate['content']['parts'][0]
 64                  full_response += part['text']
 65              return full_response
 66          else:
 67              raise ValueError('Invalid prediction type passed in config')
 68      except ValueError as e1:
 69          raise ValueError(str(e1))
 70      except Exception as e:
 71          logger.warning(f'Exception in formatting model response Job ID {job_id}:'+str(e))
 72          return None
 73  
 74  
 75  def convert_input(prompt,model,model_config):
 76      doc_input = {
 77          "target": {
 78              "provider": "echo",
 79              "model": "echo"
 80          },
 81          "task": {
 82              "type": "gcp-multimodal-v1",
 83              "prediction_type": "gcp-multimodal-v1",
 84              "input": {
 85              "contents": [
 86                  {
 87                  "role": "user",
 88                  "parts": [
 89                      {
 90                      "text": "Give me a recipe for banana bread."
 91                      }
 92                  ]
 93                  }
 94              ],
 95              "safetySettings": 
 96                  [
 97                      {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
 98                      {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
 99                      {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
100                      {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
101                  ],
102              "generationConfig": {
103                  "temperature": 0,
104                  "maxOutputTokens": 8000,
105                  "topK": 40,
106                  "topP": 0.95,
107                  "stopSequences": [],
108                  "candidateCount": 1
109              }
110              }
111          }
112      }
113      if 'provider' not in model_config:
114          doc_input['target']['provider'] = 'gcp'
115      else:
116          doc_input['target']['provider'] = model_config['provider']
117      doc_input['task']['type'] = model_config.get('task_type','gcp-multimodal-v1')
118      doc_input['task']['prediction_type'] = model_config.get('prediction_type','generic-text-generation-v1')
119      if 'safetySettings' in model_config:
120          doc_input['task']['input']['safetySettings'] = model_config.get('safetySettings')
121      if 'generationConfig' in model_config:
122          doc_input['task']['input']['generationConfig'] = model_config.get('generationConfig')
123      doc_input['target']['model'] = model
124      if model_config.get('log_level','')=='debug':
125          logger.info(f"Using model configs Job ID {model_config.get('job_id',-1)}{doc_input}")
126      doc_input['task']['input']['contents'][0]['parts'] = [{"text":prompt[0]['content']}]
127      return doc_input
128  
129  
130  
if __name__ == '__main__':
    # Smoke-test against a locally running proxy.
    # Bug fix: the original called `batch_completion`, which is not defined in
    # this module; the entry point defined above is `api_completion`, and it
    # expects chat-style messages (only messages[0]['content'] is read).
    message_list = [{'role': 'user', 'content': 'Hi How are you'}]
    response = api_completion('gemini/gemini-1.5-flash', message_list,
                              api_base='http://127.0.0.1:5000')
    print(response)