# ragaai_catalyst/proxy_call.py
"""Thin client for a local model-proxy service.

Builds a provider-agnostic request document from chat-style messages,
POSTs it to the proxy endpoint, and extracts the generated text from
the response payload.
"""

import json
import logging
import subprocess
import traceback

import requests

logger = logging.getLogger(__name__)


def api_completion(model, messages, api_base='http://127.0.0.1:8000',
                   api_key='', model_config=None):
    """POST a completion request to the proxy and return its outputs.

    Args:
        model: Model identifier forwarded as ``target.model``.
        messages: Chat-style message list; only ``messages[0]['content']``
            is sent (see :func:`convert_input`).
        api_base: URL of the proxy endpoint.
        api_key: Accepted for interface compatibility; currently unused.
        model_config: Optional dict of overrides (``job_id``, ``provider``,
            ``task_type``, ``prediction_type``, ``safetySettings``,
            ``generationConfig``, ``log_level``).

    Returns:
        A single-element list containing the extracted text, or ``None``
        when the response could not be parsed/formatted.

    Raises:
        ValueError: on transport errors, a non-200 status, invalid JSON,
            or an ``error`` key in the response body.
    """
    # Bug fix: avoid a shared mutable default argument.
    if model_config is None:
        model_config = {}
    whoami = get_username()
    all_response = []
    job_id = model_config.get('job_id', -1)
    payload = json.dumps(convert_input(messages, model, model_config))
    headers = {
        'Content-Type': 'application/json',
        # Bug fix: the original f-string carried the literal text
        # "$(whoami)" and never used the computed username.
        'Wd-PCA-Feature-Key': f'your_feature_key, {whoami}',
    }
    try:
        # NOTE(review): verify=False disables TLS certificate checks;
        # acceptable only because this targets a localhost proxy.
        response = requests.request("POST", api_base, headers=headers,
                                    data=payload, verify=False)
        if model_config.get('log_level', '') == 'debug':
            logger.info('Model response Job ID %s %s', job_id, response.text)
        if response.status_code != 200:
            raise ValueError(str(response.text))
    except Exception as e:
        # Bug fix: the original passed str(e) as a lazy %-argument to an
        # f-string with no placeholder, so the error text was never logged.
        logger.error('Error in calling api Job ID %s: %s', job_id, e)
        raise ValueError(str(e)) from e
    try:
        # Bug fix: keep the Response object separate from the decoded body
        # so the handlers below can still read ``response.text`` (the
        # original rebound ``response`` and then concatenated str with a
        # Response/dict, raising TypeError inside the handler).
        body = response.json()
        if 'error' in body:
            logger.error('Invalid response from API Job ID %s: %s', job_id, body)
            raise ValueError(str(body.get('error')))
        all_response.append(convert_output(body, job_id))
    except ValueError as e1:
        logger.error('Invalid json response from API Job ID %s: %s',
                     job_id, response.text)
        raise ValueError(str(e1)) from e1
    except Exception as e1:
        if model_config.get('log_level', '') == 'debug':
            # Bug fix: traceback.print_exc() returns None; format_exc()
            # actually yields the trace text for the log record.
            logger.info('Error trace Job ID: %s %s', job_id, traceback.format_exc())
        logger.error('Exception in parsing model response Job ID:%s %s', job_id, e1)
        logger.error('Model response Job ID: %s %s', job_id, response.text)
        all_response.append(None)
    return all_response


def get_username():
    """Return the current OS user name as reported by ``whoami``."""
    result = subprocess.run(['whoami'], capture_output=True, text=True)
    # Strip the trailing newline so the value is safe to embed in an
    # HTTP header (a raw newline in a header value is rejected).
    return result.stdout.strip()


def convert_output(response, job_id):
    """Extract the generated text from a proxy prediction payload.

    Returns the output string for the supported prediction types, or
    ``None`` when the payload cannot be formatted.

    Raises:
        ValueError: for an unknown prediction type or a non-STOP
            ``finishReason`` in a streamed chunk.
    """
    try:
        prediction = response.get('prediction', {})
        prediction_type = prediction.get('type', '')
        if prediction_type == 'generic-text-generation-v1':
            return prediction['output']
        if prediction_type == 'gcp-multimodal-v1':
            full_response = ''
            for chunk in prediction['output']['chunks']:
                candidate = chunk['candidates'][0]
                # Robustness fix: ``finishReason`` may be absent in
                # mid-stream chunks; .get avoids a KeyError here.
                finish_reason = candidate.get('finishReason')
                if finish_reason and finish_reason not in ['STOP']:
                    raise ValueError(finish_reason)
                full_response += candidate['content']['parts'][0]['text']
            return full_response
        raise ValueError('Invalid prediction type passed in config')
    except ValueError:
        # Propagate with the original traceback instead of re-wrapping.
        raise
    except Exception as e:
        logger.warning('Exception in formatting model response Job ID %s: %s',
                       job_id, e)
        return None


def convert_input(prompt, model, model_config):
    """Build the proxy request document for *prompt* and *model*.

    Starts from a GCP-multimodal template and overlays values from
    ``model_config``. Only ``prompt[0]['content']`` is sent — presumably
    the single user message; verify against callers.
    """
    doc_input = {
        "target": {
            "provider": "echo",
            "model": "echo"
        },
        "task": {
            "type": "gcp-multimodal-v1",
            "prediction_type": "gcp-multimodal-v1",
            "input": {
                "contents": [
                    {
                        "role": "user",
                        "parts": [
                            {
                                "text": "Give me a recipe for banana bread."
                            }
                        ]
                    }
                ],
                "safetySettings":
                [
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
                ],
                "generationConfig": {
                    "temperature": 0,
                    "maxOutputTokens": 8000,
                    "topK": 40,
                    "topP": 0.95,
                    "stopSequences": [],
                    "candidateCount": 1
                }
            }
        }
    }
    if 'provider' not in model_config:
        doc_input['target']['provider'] = 'gcp'
    else:
        doc_input['target']['provider'] = model_config['provider']
        # NOTE(review): task_type / prediction_type overrides apply only
        # when a provider is supplied — preserved from the original; the
        # template defaults remain in effect otherwise.
        doc_input['task']['type'] = model_config.get('task_type', 'gcp-multimodal-v1')
        doc_input['task']['prediction_type'] = model_config.get(
            'prediction_type', 'generic-text-generation-v1')
    if 'safetySettings' in model_config:
        doc_input['task']['input']['safetySettings'] = model_config.get('safetySettings')
    if 'generationConfig' in model_config:
        doc_input['task']['input']['generationConfig'] = model_config.get('generationConfig')
    doc_input['target']['model'] = model
    if model_config.get('log_level', '') == 'debug':
        logger.info('Using model configs Job ID %s%s',
                    model_config.get('job_id', -1), doc_input)
    # Replace the template's placeholder text with the caller's prompt.
    doc_input['task']['input']['contents'][0]['parts'] = [{"text": prompt[0]['content']}]
    return doc_input


if __name__ == '__main__':
    # Bug fix: the original called undefined ``batch_completion`` (NameError);
    # exercise the real entry point instead.
    message_list = [{'role': 'user', 'content': 'Hi How are you'}]
    response = api_completion('gemini/gemini-1.5-flash', message_list,
                              api_base='http://127.0.0.1:5000')
    print(response)