# examples/all_llm_provider/all_llm_provider.py
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig
import google.generativeai as genai
from litellm import completion, acompletion
from anthropic import Anthropic, AsyncAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_vertexai import ChatVertexAI
from langchain_core.messages import HumanMessage
from groq import Groq, AsyncGroq

from ragaai_catalyst import trace_llm

from dotenv import load_dotenv
load_dotenv()

# Azure OpenAI setup
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_api_key = os.getenv("AZURE_OPENAI_API_KEY")
azure_api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview")

# Google AI setup
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Vertex AI setup
vertexai.init(project="gen-lang-client-0655603261", location="us-central1")

async def get_llm_response(
    prompt,
    model,
    provider,
    temperature,
    max_tokens,
    async_llm=False,
    ):
    """
    Main interface for getting responses from the supported LLM providers.

    `provider` is matched case-insensitively against: azure, openai_beta,
    openai, chat_google, google, chat_vertexai, vertexai, anthropic, groq
    and litellm. Set `async_llm=True` to use the provider's async client.
    """
    if 'azure' in provider.lower():
        if async_llm:
            async_azure_openai_client = AsyncAzureOpenAI(azure_endpoint=azure_endpoint, api_key=azure_api_key, api_version=azure_api_version)
            return await _get_async_azure_openai_response(async_azure_openai_client, prompt, model, temperature, max_tokens)
        else:
            azure_openai_client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=azure_api_key, api_version=azure_api_version)
            return _get_azure_openai_response(azure_openai_client, prompt, model, temperature, max_tokens)
    elif 'openai_beta' in provider.lower():
        openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return _get_openai_beta_response(openai_client, prompt, model, temperature, max_tokens)
    elif 'openai' in provider.lower():
        if async_llm:
            async_openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return await _get_async_openai_response(async_openai_client, prompt, model, temperature, max_tokens)
        else:
            openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return _get_openai_response(openai_client, prompt, model, temperature, max_tokens)
    elif 'chat_google' in provider.lower():
        if async_llm:
            return await _get_async_chat_google_generativeai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_chat_google_generativeai_response(prompt, model, temperature, max_tokens)
    elif 'google' in provider.lower():
        if async_llm:
            return await _get_async_google_generativeai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_google_generativeai_response(prompt, model, temperature, max_tokens)
    elif 'chat_vertexai' in provider.lower():
        if async_llm:
            return await _get_async_chat_vertexai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_chat_vertexai_response(prompt, model, temperature, max_tokens)
    elif 'vertexai' in provider.lower():
        if async_llm:
            return await _get_async_vertexai_response(prompt, model, temperature, max_tokens)
        else:
            return _get_vertexai_response(prompt, model, temperature, max_tokens)
    elif 'anthropic' in provider.lower():
        if async_llm:
            async_anthropic_client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
            return await _get_async_anthropic_response(async_anthropic_client, prompt, model, temperature, max_tokens)
        else:
            anthropic_client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
            return _get_anthropic_response(anthropic_client, prompt, model, temperature, max_tokens)
    elif 'groq' in provider.lower():
        if async_llm:
            async_groq_client = AsyncGroq(api_key=os.getenv("GROQ_API_KEY"))
            return await _get_async_groq_response(async_groq_client, prompt, model, temperature, max_tokens)
        else:
            groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
            return _get_groq_response(groq_client, prompt, model, temperature, max_tokens)
    elif 'litellm' in provider.lower():
        if async_llm:
            return await _get_async_litellm_response(prompt, model, temperature, max_tokens)
        else:
            return _get_litellm_response(prompt, model, temperature, max_tokens)
    else:
        print(f"Unsupported provider: {provider}")
        return None


@trace_llm(name="_get_openai_response")
def _get_openai_response(
    openai_client,
    prompt,
    model,
    temperature,
    max_tokens,
    ):
    """
    Get response from OpenAI API
    """
    try:
        response = openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_async_openai_response")
async def _get_async_openai_response(
    async_openai_client,
    prompt,
    model,
    temperature,
    max_tokens,
    ):
    """
    Get async response from OpenAI API
    """
    try:
        response = await async_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_openai_beta_response")
def _get_openai_beta_response(
    openai_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from the OpenAI Assistants (beta) API
    """
    assistant = openai_client.beta.assistants.create(model=model)
    thread = openai_client.beta.threads.create()
    message = openai_client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=prompt
    )
    run = openai_client.beta.threads.runs.create_and_poll(
        thread_id=thread.id,
        assistant_id=assistant.id,
        temperature=temperature,
        max_completion_tokens=max_tokens
    )
    if run.status == 'completed':
        messages = openai_client.beta.threads.messages.list(thread_id=thread.id)
        return messages.data[0].content[0].text.value
    print(f"OpenAI beta run did not complete (status: {run.status})")
    return None

@trace_llm(name="_get_azure_openai_response")
def _get_azure_openai_response(
    azure_openai_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from Azure OpenAI API
    """
    try:
        response = azure_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with Azure OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_async_azure_openai_response")
async def _get_async_azure_openai_response(
    async_azure_openai_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from Azure OpenAI API
    """
    try:
        response = await async_azure_openai_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async Azure OpenAI API: {str(e)}")
        return None

@trace_llm(name="_get_litellm_response")
def _get_litellm_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response using LiteLLM
    """
    try:
        response = completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with LiteLLM: {str(e)}")
        return None

@trace_llm(name="_get_async_litellm_response")
async def _get_async_litellm_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response using LiteLLM
    """
    try:
        response = await acompletion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async LiteLLM: {str(e)}")
        return None

@trace_llm(name="_get_vertexai_response")
def _get_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from VertexAI
    """
    try:
        # vertexai.init(project="gen-lang-client-0655603261", location="us-central1")
        model = GenerativeModel(
            model_name=model
            )
        response = model.generate_content(
            prompt,
            generation_config=GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with VertexAI: {str(e)}")
        return None

@trace_llm(name="_get_async_vertexai_response")
async def _get_async_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from VertexAI
    """
    try:
        model = GenerativeModel(
            model_name=model
            )
        response = await model.generate_content_async(
            prompt,
            generation_config=GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with async VertexAI: {str(e)}")
        return None

@trace_llm(name="_get_google_generativeai_response")
def _get_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from Google GenerativeAI
    """
    try:
        model = genai.GenerativeModel(model)
        response = model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with Google GenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_async_google_generativeai_response")
async def _get_async_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from Google GenerativeAI
    """
    try:
        model = genai.GenerativeModel(model)
        response = await model.generate_content_async(
            prompt,
            generation_config=genai.GenerationConfig(
                temperature=temperature,
                max_output_tokens=max_tokens
            )
        )
        return response.text
    except Exception as e:
        print(f"Error with async Google GenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_anthropic_response")
def _get_anthropic_response(
    anthropic_client,
    prompt,
    model,
    temperature,
    max_tokens,
    ):
    """
    Get response from Anthropic API
    """
    try:
        response = anthropic_client.messages.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.content[0].text
    except Exception as e:
        print(f"Error with Anthropic: {str(e)}")
        return None

@trace_llm(name="_get_async_anthropic_response")
async def _get_async_anthropic_response(
    async_anthropic_client,
    prompt,
    model,
    temperature,
    max_tokens,
    ):
    """
    Get async response from Anthropic API
    """
    try:
        response = await async_anthropic_client.messages.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.content[0].text
    except Exception as e:
        print(f"Error with async Anthropic: {str(e)}")
        return None

@trace_llm(name="_get_chat_google_generativeai_response")
def _get_chat_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response via LangChain's ChatGoogleGenerativeAI
    """
    try:
        # Generation parameters are configured on the model instance.
        model = ChatGoogleGenerativeAI(
            model=model,
            temperature=temperature,
            max_output_tokens=max_tokens
        )
        response = model._generate([HumanMessage(content=prompt)])
        return response.generations[0].text
    except Exception as e:
        print(f"Error with ChatGoogleGenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_async_chat_google_generativeai_response")
async def _get_async_chat_google_generativeai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response via LangChain's ChatGoogleGenerativeAI
    """
    try:
        model = ChatGoogleGenerativeAI(
            model=model,
            temperature=temperature,
            max_output_tokens=max_tokens
        )
        response = await model._agenerate([HumanMessage(content=prompt)])
        return response.generations[0].text
    except Exception as e:
        print(f"Error with async ChatGoogleGenerativeAI: {str(e)}")
        return None

@trace_llm(name="_get_chat_vertexai_response")
def _get_chat_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response via LangChain's ChatVertexAI
    """
    try:
        # ChatVertexAI authenticates against the Vertex AI project initialised
        # at module load (application default credentials), not a Google API key.
        model = ChatVertexAI(
            model=model,
            temperature=temperature,
            max_output_tokens=max_tokens
        )
        response = model._generate([HumanMessage(content=prompt)])
        return response.generations[0].text
    except Exception as e:
        print(f"Error with ChatVertexAI: {str(e)}")
        return None

@trace_llm(name="_get_async_chat_vertexai_response")
async def _get_async_chat_vertexai_response(
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response via LangChain's ChatVertexAI
    """
    try:
        model = ChatVertexAI(
            model=model,
            temperature=temperature,
            max_output_tokens=max_tokens
        )
        response = await model._agenerate([HumanMessage(content=prompt)])
        return response.generations[0].text
    except Exception as e:
        print(f"Error with async ChatVertexAI: {str(e)}")
        return None

@trace_llm(name="_get_groq_response")
def _get_groq_response(
    groq_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get response from Groq API
    """
    try:
        response = groq_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with Groq: {str(e)}")
        return None

@trace_llm(name="_get_async_groq_response")
async def _get_async_groq_response(
    async_groq_client,
    prompt,
    model,
    temperature,
    max_tokens
    ):
    """
    Get async response from Groq API
    """
    try:
        response = await async_groq_client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error with async Groq: {str(e)}")
        return None
