all_llm_provider.py
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))

from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig
import google.generativeai as genai
from litellm import completion, acompletion
import litellm
import argparse
import anthropic
import asyncio
from anthropic import Anthropic, AsyncAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from groq import Groq, AsyncGroq

from ragaai_catalyst import trace_llm
from config import initialize_tracing
tracer = initialize_tracing()

from dotenv import load_dotenv
load_dotenv()

# Azure OpenAI setup
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_api_key = os.getenv("AZURE_OPENAI_API_KEY")
azure_api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")

# Google AI setup
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Vertex AI setup
vertexai.init(project="gen-lang-client-0655603261", location="us-central1")

async def get_llm_response(
    prompt,
    model,
    provider,
    temperature,
    max_tokens,
    async_llm=False,
):
    """
    Main interface for getting responses from various LLM providers
    """
    if 'azure' in provider.lower():
        if async_llm:
            async_azure_openai_client = AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_key=azure_api_key, api_version=azure_api_version)
            return await _get_async_azure_openai_response(async_azure_openai_client, prompt, model, temperature, max_tokens)
        else:
            azure_openai_client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=azure_api_key, api_version=azure_api_version)
            return _get_azure_openai_response(azure_openai_client, prompt, model, temperature, max_tokens)
    elif 'openai_beta' in provider.lower():
        openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return _get_openai_beta_response(openai_client, prompt, model, temperature, max_tokens)
    elif 'openai' in provider.lower():
        if async_llm:
            async_openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return await _get_async_openai_response(async_openai_client, prompt, model, temperature, max_tokens)
        else:
            openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return _get_openai_response(openai_client, prompt, model, temperature, max_tokens)
    elif 'chat_google' in provider.lower():
        if async_llm:
            return await _get_async_chat_google_generativeai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_chat_google_generativeai_response(prompt, model, temperature, max_tokens)
    elif 'google' in provider.lower():
        if async_llm:
            return await _get_async_google_generativeai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_google_generativeai_response(prompt, model, temperature, max_tokens)
    elif 'chat_vertexai' in provider.lower():
        if async_llm:
            return await _get_async_chat_vertexai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_chat_vertexai_response(prompt, model, temperature, max_tokens)
    elif 'vertexai' in provider.lower():
        if async_llm:
            return await _get_async_vertexai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_vertexai_response(prompt, model, temperature, max_tokens)
    elif 'anthropic' in provider.lower():
        if async_llm:
            async_anthropic_client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
            return await _get_async_anthropic_response(async_anthropic_client, prompt, model, temperature, max_tokens)
        else:
            anthropic_client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
            return _get_anthropic_response(anthropic_client, prompt, model, temperature, max_tokens)
    elif 'groq' in provider.lower():
        if async_llm:
            async_groq_client = AsyncGroq(api_key=os.getenv("GROQ_API_KEY"))
            return await _get_async_groq_response(async_groq_client, prompt, model, temperature, max_tokens)
        else:
            groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
            return _get_groq_response(groq_client, prompt, model, temperature, max_tokens)
    elif 'litellm' in provider.lower():
        if async_llm:
            return await _get_async_litellm_response(prompt, model, temperature, max_tokens)
        else:
            return _get_litellm_response(prompt, model, temperature, max_tokens)
    else:
        raise ValueError(f"Unsupported provider: {provider}")

@trace_llm(name="_get_openai_response")
def _get_openai_response(
    openai_client,
    prompt,
    model,
    temperature,
    max_tokens,
):
    """
    Get response from OpenAI API
    """
    try:
        response = openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_async_openai_response")
async def _get_async_openai_response(
    async_openai_client,
    prompt,
    model,
    temperature,
    max_tokens,
):
    """
    Get async response from OpenAI API
    """
    try:
        response = await async_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_openai_beta_response")
def _get_openai_beta_response(
    openai_client,
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get response from the OpenAI Assistants (beta) API
    """
    assistant = openai_client.beta.assistants.create(model=model)
    thread = openai_client.beta.threads.create()
    message = openai_client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=prompt
    )
    run = openai_client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=assistant.id,
        temperature=temperature,
        max_completion_tokens=max_tokens
    )
    if run.status == 'completed':
        messages = openai_client.beta.threads.messages.list(thread_id=thread.id)
        return messages.data[0].content[0].text.value
    return None

@trace_llm(name="_get_azure_openai_response")
def _get_azure_openai_response(
    azure_openai_client,
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get response from Azure OpenAI API
    """
    try:
        response = azure_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with Azure OpenAI API: {str(e)}")
        return None
@trace_llm(name="_get_async_azure_openai_response")
async def _get_async_azure_openai_response(
    async_azure_openai_client,
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get async response from Azure OpenAI API
    """
    try:
        response = await async_azure_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async Azure OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_litellm_response")
def _get_litellm_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get response using LiteLLM
    """
    try:
        response = completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with LiteLLM: {str(e)}")
        return None

@trace_llm(name="_get_async_litellm_response")
async def _get_async_litellm_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get async response using LiteLLM
    """
    try:
        response = await acompletion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async LiteLLM: {str(e)}")
        return None

@trace_llm(name="_get_vertexai_response")
def _get_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get response from VertexAI
    """
    try:
        # vertexai.init(project="gen-lang-client-0655603261", location="us-central1")
        model = GenerativeModel(
            model_name=model
        )
        response = model.generate_content(
            prompt,
            generation_config=GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with VertexAI: {str(e)}")
        return None

@trace_llm(name="_get_async_vertexai_response")
async def _get_async_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get async response from VertexAI
    """
    try:
        model = GenerativeModel(
            model_name=model
        )
        response = await model.generate_content_async(
            prompt,
            generation_config=GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with async VertexAI: {str(e)}")
        return None

@trace_llm(name="_get_google_generativeai_response")
def _get_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get response from Google GenerativeAI
    """
    try:
        model = genai.GenerativeModel(model)
        response = model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with Google GenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_async_google_generativeai_response")
async def _get_async_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get async response from Google GenerativeAI
    """
    try:
        model = genai.GenerativeModel(model)
        response = await model.generate_content_async(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with async Google GenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_anthropic_response")
def _get_anthropic_response(
    anthropic_client,
    prompt,
    model,
    temperature,
    max_tokens,
):
    """
    Get response from Anthropic API
    """
    try:
        response = anthropic_client.messages.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.content[0].text
    except Exception as e:
        print(f"Error with Anthropic: {str(e)}")
        return None

@trace_llm(name="_get_async_anthropic_response")
async def _get_async_anthropic_response(
    async_anthropic_client,
    prompt,
    model,
    temperature,
    max_tokens,
):
    """
    Get async response from Anthropic API
    """
    try:
        response = await async_anthropic_client.messages.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.content[0].text
    except Exception as e:
        print(f"Error with async Anthropic: {str(e)}")
        return None

@trace_llm(name="_get_chat_google_generativeai_response")
def _get_chat_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get response via LangChain's ChatGoogleGenerativeAI
    """
    try:
        model = ChatGoogleGenerativeAI(model=model)
        response = model._generate(
            [HumanMessage(content=prompt)],
            generation_config=dict(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.generations[0].text
    except Exception as e:
        print(f"Error with Google GenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_async_chat_google_generativeai_response")
async def _get_async_chat_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get async response via LangChain's ChatGoogleGenerativeAI
    """
    try:
        model = ChatGoogleGenerativeAI(model=model)
        response = await model._agenerate(
            [HumanMessage(content=prompt)],
            generation_config=dict(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.generations[0].text
    except Exception as e:
        print(f"Error with async Google GenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_chat_vertexai_response")
def _get_chat_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get response via LangChain's ChatVertexAI
    """
    try:
        model = ChatVertexAI(
            model=model,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )
        response = model._generate(
            [HumanMessage(content=prompt)],
            generation_config=dict(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.generations[0].text
    except Exception as e:
        print(f"Error with VertexAI: {str(e)}")
        return None

@trace_llm(name="_get_async_chat_vertexai_response")
async def _get_async_chat_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get async response via LangChain's ChatVertexAI
    """
    try:
        model = ChatVertexAI(
            model=model,
            google_api_key=os.getenv("GOOGLE_API_KEY")
        )
        response = await model._agenerate(
            [HumanMessage(content=prompt)],
            generation_config=dict(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.generations[0].text
    except Exception as e:
        print(f"Error with async VertexAI: {str(e)}")
        return None

@trace_llm(name="_get_groq_response")
def _get_groq_response(
    groq_client,
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get response from Groq API
    """
    try:
        response = groq_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with Groq: {str(e)}")
        return None

@trace_llm(name="_get_async_groq_response")
async def _get_async_groq_response(
    async_groq_client,
    prompt,
    model,
    temperature,
    max_tokens
):
    """
    Get async response from Groq API
    """
    try:
        response = await async_groq_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async Groq: {str(e)}")
        return None


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Run the LLM provider test with different LLM models.")
    parser.add_argument("--model", type=str, default="gpt-4o-mini", help="The model to use (e.g., gpt-4o-mini).")
    parser.add_argument("--provider", type=str, default="openai", help="The LLM provider (e.g., openai, azure, google).")
    parser.add_argument("--async_llm", action="store_true", help="Whether to use async LLM calls.")
    args = parser.parse_args()

    with tracer:
        response = asyncio.run(get_llm_response(
            prompt="Hello, how are you? Explain in one sentence.",
            model=args.model,
            provider=args.provider,
            temperature=0.7,
            max_tokens=100,
            async_llm=args.async_llm
        ))
    print(response)
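
# Example invocations, a sketch based on the argparse options above. Model names other
# than the gpt-4o-mini default are illustrative, and each call assumes the matching
# API key is set in .env:
#
#   python all_llm_provider.py --model gpt-4o-mini --provider openai
#   python all_llm_provider.py --model gpt-4o-mini --provider openai --async_llm
#   python all_llm_provider.py --model claude-3-5-sonnet-20241022 --provider anthropic
#   python all_llm_provider.py --model gemini-1.5-flash --provider google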