# tests/examples/all_llm_provider/all_llm_provider.py
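"""
Example script that sends a single prompt to a chosen LLM provider (OpenAI,
Azure OpenAI, Google Generative AI, Vertex AI, Anthropic, Groq, LiteLLM, and the
LangChain ChatGoogleGenerativeAI/ChatVertexAI wrappers), in either sync or async
mode, with each provider call traced via ragaai_catalyst's trace_llm decorator.
"""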
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))

from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig
import google.generativeai as genai
from litellm import completion, acompletion
import litellm
import argparse
import anthropic
import asyncio
from anthropic import Anthropic, AsyncAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from groq import Groq, AsyncGroq

from ragaai_catalyst import trace_llm
from config import initialize_tracing
tracer = initialize_tracing()

from dotenv import load_dotenv
load_dotenv()

# Azure OpenAI setup
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_api_key = os.getenv("AZURE_OPENAI_API_KEY")
azure_api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")

# Google AI setup
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Vertex AI setup (NOTE: replace the project ID with your own GCP project when running locally)
vertexai.init(project="gen-lang-client-0655603261", location="us-central1")

async def get_llm_response(
    prompt,
    model,
    provider,
    temperature,
    max_tokens,
    async_llm=False,
    ):
    """
    Main interface for getting responses from various LLM providers
    """
    if 'azure' in provider.lower():
        if async_llm:
            async_azure_openai_client = AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_key=azure_api_key, api_version=azure_api_version)
            return await _get_async_azure_openai_response(async_azure_openai_client, prompt, model, temperature, max_tokens)
        else:
            azure_openai_client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=azure_api_key, api_version=azure_api_version)
            return _get_azure_openai_response(azure_openai_client, prompt, model, temperature, max_tokens)
    elif 'openai_beta' in provider.lower():
        openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return _get_openai_beta_response(openai_client, prompt, model, temperature, max_tokens)
    elif 'openai' in provider.lower():
        if async_llm:
            async_openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return await _get_async_openai_response(async_openai_client, prompt, model, temperature, max_tokens)
        else:
            openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return _get_openai_response(openai_client, prompt, model, temperature, max_tokens)
    elif 'chat_google' in provider.lower():
        if async_llm:
            return await _get_async_chat_google_generativeai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_chat_google_generativeai_response(prompt, model, temperature, max_tokens)
    elif 'google' in provider.lower():
        if async_llm:
            return await _get_async_google_generativeai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_google_generativeai_response(prompt, model, temperature, max_tokens)
    elif 'chat_vertexai' in provider.lower():
        if async_llm:
            return await _get_async_chat_vertexai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_chat_vertexai_response(prompt, model, temperature, max_tokens)
    elif 'vertexai' in provider.lower():
        if async_llm:
            return await _get_async_vertexai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_vertexai_response(prompt, model, temperature, max_tokens)
    elif 'anthropic' in provider.lower():
        if async_llm:
            async_anthropic_client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
            return await _get_async_anthropic_response(async_anthropic_client, prompt, model, temperature, max_tokens)
        else:
            anthropic_client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
            return _get_anthropic_response(anthropic_client, prompt, model, temperature, max_tokens)
    elif 'groq' in provider.lower():
        if async_llm:
            async_groq_client = AsyncGroq(api_key=os.getenv("GROQ_API_KEY"))
            return await _get_async_groq_response(async_groq_client, prompt, model, temperature, max_tokens)
        else:
            groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
            return _get_groq_response(groq_client, prompt, model, temperature, max_tokens)
    elif 'litellm' in provider.lower():
        if async_llm:
            return await _get_async_litellm_response(prompt, model, temperature, max_tokens)
        else:
            return _get_litellm_response(prompt, model, temperature, max_tokens)
    else:
        raise ValueError(f"Unknown provider: {provider}")


@trace_llm(name="_get_openai_response")
def _get_openai_response(
    openai_client,
    prompt,
    model,
    temperature,
    max_tokens,
    ):
    """
    Get response from OpenAI API
    """
    try:
        response = openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_async_openai_response")
async def _get_async_openai_response(
    async_openai_client,
    prompt,
    model,
    temperature,
    max_tokens,
    ):
    """
    Get async response from OpenAI API
    """
    try:
        response = await async_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_openai_beta_response")
def _get_openai_beta_response(
    openai_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from the OpenAI Assistants API (beta): create an assistant and a thread,
    post the prompt, then run and poll until completion
    """
    assistant = openai_client.beta.assistants.create(model=model)
    thread = openai_client.beta.threads.create()
    message = openai_client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=prompt
    )
    run = openai_client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=assistant.id,
        temperature=temperature,
        max_completion_tokens=max_tokens
    )
    if run.status == 'completed':
        messages = openai_client.beta.threads.messages.list(thread_id=thread.id)
        return messages.data[0].content[0].text.value
    print(f"Assistants run did not complete: {run.status}")
    return None

@trace_llm(name="_get_azure_openai_response")
def _get_azure_openai_response(
    azure_openai_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from Azure OpenAI API
    """
    try:
        response = azure_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with Azure OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_async_azure_openai_response")
async def _get_async_azure_openai_response(
    async_azure_openai_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from Azure OpenAI API
    """
    try:
        response = await async_azure_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async Azure OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_litellm_response")
def _get_litellm_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response using LiteLLM
    """
    try:
        response = completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with LiteLLM: {str(e)}")
        return None

@trace_llm(name="_get_async_litellm_response")
async def _get_async_litellm_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response using LiteLLM
    """
    try:
        response = await acompletion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async LiteLLM: {str(e)}")
        return None

@trace_llm(name="_get_vertexai_response")
def _get_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from VertexAI
    """
    try:
        # vertexai.init(project="gen-lang-client-0655603261", location="us-central1")
        model = GenerativeModel(
            model_name=model
            )
        response = model.generate_content(
            prompt,
            generation_config=GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with VertexAI: {str(e)}")
        return None

@trace_llm(name="_get_async_vertexai_response")
async def _get_async_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from VertexAI
    """
    try:
        model = GenerativeModel(
            model_name=model
            )
        response = await model.generate_content_async(
            prompt,
            generation_config=GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with async VertexAI: {str(e)}")
        return None

@trace_llm(name="_get_google_generativeai_response")
def _get_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from Google GenerativeAI
    """
    try:
        model = genai.GenerativeModel(model)
        response = model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with Google GenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_async_google_generativeai_response")
async def _get_async_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from Google GenerativeAI
    """
    try:
        model = genai.GenerativeModel(model)
        response = await model.generate_content_async(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with async Google GenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_anthropic_response")
def _get_anthropic_response(
    anthropic_client,
    prompt,
    model,
    temperature,
    max_tokens,
    ):
    """
    Get response from Anthropic API
    """
    try:
        response = anthropic_client.messages.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.content[0].text
    except Exception as e:
        print(f"Error with Anthropic: {str(e)}")
        return None

@trace_llm(name="_get_async_anthropic_response")
async def _get_async_anthropic_response(
    async_anthropic_client,
    prompt,
    model,
    temperature,
    max_tokens,
    ):
    """
    Get async response from Anthropic API
    """
    try:
        response = await async_anthropic_client.messages.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.content[0].text
    except Exception as e:
        print(f"Error with async Anthropic: {str(e)}")
        return None

@trace_llm(name="_get_chat_google_generativeai_response")
def _get_chat_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from LangChain's ChatGoogleGenerativeAI
    """
    try:
        model = ChatGoogleGenerativeAI(model=model)
        response = model._generate(
            [HumanMessage(content=prompt)],
            generation_config=dict(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.generations[0].text
    except Exception as e:
        print(f"Error with ChatGoogleGenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_async_chat_google_generativeai_response")
async def _get_async_chat_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from LangChain's ChatGoogleGenerativeAI
    """
    try:
        model = ChatGoogleGenerativeAI(model=model)
        response = await model._agenerate(
            [HumanMessage(content=prompt)],
            generation_config=dict(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.generations[0].text
    except Exception as e:
        print(f"Error with async ChatGoogleGenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_chat_vertexai_response")
def _get_chat_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from LangChain's ChatVertexAI
    """
    try:
        model = ChatVertexAI(
            model=model,
            google_api_key=os.getenv("GOOGLE_API_KEY")
            )
        response = model._generate(
            [HumanMessage(content=prompt)],
            generation_config=dict(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.generations[0].text
    except Exception as e:
        print(f"Error with ChatVertexAI: {str(e)}")
        return None

@trace_llm(name="_get_async_chat_vertexai_response")
async def _get_async_chat_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from LangChain's ChatVertexAI
    """
    try:
        model = ChatVertexAI(
            model=model,
            google_api_key=os.getenv("GOOGLE_API_KEY")
            )
        response = await model._agenerate(
            [HumanMessage(content=prompt)],
            generation_config=dict(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.generations[0].text
    except Exception as e:
        print(f"Error with async ChatVertexAI: {str(e)}")
        return None

@trace_llm(name="_get_groq_response")
def _get_groq_response(
    groq_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from Groq API
    """
    try:
        response = groq_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with Groq: {str(e)}")
        return None

@trace_llm(name="_get_async_groq_response")
async def _get_async_groq_response(
    async_groq_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from Groq API
    """
    try:
        response = await async_groq_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async Groq: {str(e)}")
        return None


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Run the LLM provider test with different LLM models.")
    parser.add_argument("--model", type=str, default="gpt-4o-mini", help="The model to use (e.g., gpt-4o-mini).")
    parser.add_argument("--provider", type=str, default="openai", help="The LLM provider (e.g., openai, azure, google).")
    # argparse's type=bool would treat any non-empty string as True, so expose a store_true flag instead
    parser.add_argument("--async_llm", action="store_true", help="Whether to use async LLM calls.")
    args = parser.parse_args()

    with tracer:
        response = asyncio.run(get_llm_response(
            prompt="Hello, how are you? Explain in one sentence.",
            model=args.model,
            provider=args.provider,
            temperature=0.7,
            max_tokens=100,
            async_llm=args.async_llm
        ))
        print(response)
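# Example invocations (illustrative; the matching API keys must be set in .env):
#   python all_llm_provider.py --provider openai --model gpt-4o-mini
#   python all_llm_provider.py --provider anthropic --model <anthropic-model-id> --async_llm
#   python all_llm_provider.py --provider litellm --model gpt-4o-mini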