# src/evidently/llm/utils/wrapper.py
import asyncio
import dataclasses
import datetime
import json
from abc import ABC
from abc import abstractmethod
from asyncio import Lock
from asyncio import Semaphore
from asyncio import sleep
from importlib.util import find_spec
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TypeVar

from evidently._pydantic_compat import BaseModel
from evidently._pydantic_compat import SecretStr
from evidently.legacy.options.base import Options
from evidently.legacy.options.option import Option
from evidently.legacy.utils.sync import sync_api
from evidently.llm.models import LLMMessage
from evidently.llm.utils.errors import LLMRateLimitError
from evidently.llm.utils.errors import LLMRequestError

TResult = TypeVar("TResult")


class RateLimits(BaseModel):
    """Rate limiting configuration for LLM API calls.

    Defines limits for requests per minute (RPM) and tokens per minute (TPM)
    to avoid exceeding API rate limits.
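
    Example (illustrative values; any limit left as `None` is simply not enforced):

    ```python
    limits = RateLimits(rpm=500, tpm=30_000)
    ```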
    """

    rpm: Optional[int] = None
    """Optional requests per minute limit."""
    itpm: Optional[int] = None
    """Optional input tokens per minute limit."""
    otpm: Optional[int] = None
    """Optional output tokens per minute limit."""
    tpm: Optional[int] = None
    """Optional total tokens per minute limit."""
    interval: datetime.timedelta = datetime.timedelta(minutes=1)
    """Time window for rate limiting."""
    # continuous_token_refresh: bool = False


@dataclasses.dataclass
class _Enter:
    ts: datetime.datetime
    estimated_input_tokens: int
    estimated_output_tokens: int
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    done: bool = False

    @property
    def estimated_tokens(self):
        return self.estimated_input_tokens + self.estimated_output_tokens

    @property
    def tokens(self):
        # actual token counts may be unset until `record` is called
        return (self.input_tokens or 0) + (self.output_tokens or 0)


class _RateLimiterEntrypoint:
    def __init__(self, limiter: "RateLimiter", request: "LimitRequest"):
        self.limiter = limiter
        self.request = request
        self.enter = _Enter(datetime.datetime.now(), request.estimated_input, 0)

    @property
    def limits(self):
        return self.limiter.limits

    @property
    def enters(self):
        return self.limiter.enters

    @property
    def lock(self):
        return self.limiter.lock

    async def __aenter__(self):
        # poll until both the request and the token budgets have room
        while True:
            async with self.lock:
                await self.limiter.clean()
                if self._check_rpm() and self._check_tokens():
                    self.enter.ts = datetime.datetime.now()
                    self.enter.estimated_output_tokens = self.limiter.mean_output_size()
                    self.enters.append(self.enter)
                    break
            await sleep(0.1)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        async with self.lock:
            self.enter.done = True
            stat = LimiterStat(
                self.enter.estimated_input_tokens,
                self.enter.estimated_output_tokens,
                self.enter.input_tokens or 0,
                self.enter.output_tokens or 0,
            )
            self.limiter.stats.append(stat)

    def record(self, input_tokens: int, output_tokens: int):
        self.enter.input_tokens = input_tokens
        self.enter.output_tokens = output_tokens

    def _check_rpm(self):
        return self.limits.rpm is None or len(self.enters) < self.limits.rpm

    def _check_tokens(self):
        # prefer recorded token counts; fall back to estimates for in-flight requests
        used_input_tokens = 0
        used_output_tokens = 0
        for e in self.enters:
            used_input_tokens += e.input_tokens or e.estimated_input_tokens
            used_output_tokens += e.output_tokens or e.estimated_output_tokens
        input_good = self.limits.itpm is None or used_input_tokens < self.limits.itpm
        output_good = self.limits.otpm is None or used_output_tokens < self.limits.otpm
        total_good = self.limits.tpm is None or used_output_tokens + used_input_tokens < self.limits.tpm
        return input_good and output_good and total_good


@dataclasses.dataclass
class LimiterStat:
    estimated_input_tokens: int
    estimated_output_tokens: int
    input_tokens: int
    output_tokens: int


class RateLimiter:
    """Rate limiter for LLM API calls.

    Enforces rate limits on requests and tokens to avoid exceeding API quotas.
    Tracks usage over time windows and blocks requests when limits would be exceeded.
    """

    def __init__(self, limits: RateLimits, initial_output_estimation: int = 100000):
        """Initialize the rate limiter.

        Args:
        * `limits`: `RateLimits` configuration.
        * `initial_output_estimation`: Initial estimate for output token size.
        """
        self.limits = limits
        self.enters: List[_Enter] = []
        self.stats: List[LimiterStat] = []
        self.lock = Lock()
        self.initial_output_estimation = initial_output_estimation

    def enter(self, request: "LimitRequest"):
        """Get a context manager for a rate-limited request.

        Args:
        * `request`: `LimitRequest` to rate limit.

        Returns:
        * Context manager that blocks until rate limits allow the request.
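
        Example (sketch; `payload` and `call_llm` are placeholders, and `_batch`
        normally drives this loop internally):

        ```python
        async with limiter.enter(LimitRequest(payload, estimated_input=200)) as rate:
            result = await call_llm(payload)
            rate.record(result.input_tokens, result.output_tokens)
        ```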
        """
        return _RateLimiterEntrypoint(self, request)

    async def clean(self):
        """Remove old entries outside the rate limit window."""
        now = datetime.datetime.now()
        self.enters = [e for e in self.enters if not e.done or now - e.ts < self.limits.interval]

    def mean_output_size(self):
        """Get the mean output token size from historical data.

        Returns:
        * Mean output token size, or initial estimation if no data available.
        """
        if len(self.stats) == 0:
            return self.initial_output_estimation
        return sum(s.output_tokens for s in self.stats) / len(self.stats)


@dataclasses.dataclass
class LLMRequest(Generic[TResult]):
    """Request to an LLM with messages and response parsing."""

    messages: List[LLMMessage]
    """List of `LLMMessage` objects for the conversation."""
    response_parser: Callable[[str], TResult]
    """Function to parse the raw string response into `TResult`."""
    response_type: Type[TResult]
    """Type of the expected result."""
    retries: int = 1
    """Number of retry attempts on failure."""


@dataclasses.dataclass
class LLMResult(Generic[TResult]):
    """Result from an LLM API call."""

    result: TResult
    """Parsed result value."""
    input_tokens: int
    """Number of input tokens used."""
    output_tokens: int
    """Number of output tokens used."""


TBatchItem = TypeVar("TBatchItem")
TBatchResult = TypeVar("TBatchResult")


@dataclasses.dataclass
class LimitRequest(Generic[TBatchItem]):
    request: TBatchItem
    estimated_input: int
    # estimated_output: int


class LLMWrapper(ABC):
    """Base class for LLM API wrappers.

    Provides a unified interface for calling different LLM providers
    with rate limiting, batching, and retry logic.

    Subclasses should implement `complete()` for the specific provider.
    """

    __used_options__: ClassVar[List[Type[Option]]] = []

    @abstractmethod
    async def complete(self, messages: List[LLMMessage], seed: Optional[int] = None) -> LLMResult[str]:
        """Complete a conversation with the LLM.

        Args:
        * `messages`: List of `LLMMessage` objects for the conversation.
        * `seed`: Optional random seed for deterministic outputs.

        Returns:
        * `LLMResult` with the response string and token usage.
        """
        raise NotImplementedError

    async def _batch(
        self,
        coro: Callable[[TBatchItem], Awaitable[LLMResult[TBatchResult]]],
        batches: Sequence[LimitRequest[TBatchItem]],
        batch_size: Optional[int] = None,
        limits: Optional[RateLimits] = None,
    ) -> List[TBatchResult]:
        """Execute a batch of requests with rate limiting and concurrency control.

        Args:
        * `coro`: Coroutine function to execute for each request.
        * `batches`: Sequence of rate-limited requests.
        * `batch_size`: Optional maximum concurrent requests.
        * `limits`: Optional rate limits (uses default if not provided).

        Returns:
        * List of results from all requests.
        """
        if batch_size is None:
            batch_size = self.get_batch_size()
        if limits is None:
            limits = self.get_limits()
        rate_limiter = RateLimiter(limits=limits)
        semaphore = Semaphore(batch_size)

        async def work(request: LimitRequest[TBatchItem]) -> TBatchResult:
            async with semaphore, rate_limiter.enter(request) as rate:
                res = await coro(request.request)
                rate.record(res.input_tokens, res.output_tokens)
                return res.result

        return await asyncio.gather(*[work(batch) for batch in batches])

    async def complete_batch(
        self,
        messages_batch: List[List[LLMMessage]],
        batch_size: Optional[int] = None,
        limits: Optional[RateLimits] = None,
    ) -> List[str]:
        """Complete multiple conversations in parallel.

        Args:
        * `messages_batch`: List of message lists, one per conversation.
        * `batch_size`: Optional maximum concurrent requests.
        * `limits`: Optional rate limits.

        Returns:
        * List of response strings.
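
        Example (sketch; assumes the `LLMMessage.user` helper and a concrete
        `wrapper` instance):

        ```python
        replies = await wrapper.complete_batch(
            [[LLMMessage.user(q)] for q in questions],
            limits=RateLimits(rpm=60),
        )
        ```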
        """
        requests = [LimitRequest(msgs, sum(self.estimate_tokens(m) for m in msgs)) for msgs in messages_batch]
        return await self._batch(self.complete, requests, batch_size, limits)

    async def run(self, request: LLMRequest[TResult]) -> TResult:
        """Run a single LLM request with retry logic.

        Args:
        * `request`: `LLMRequest` to execute.

        Returns:
        * Parsed result value.

        Raises:
        * Exception: If all retries fail.
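
        Example (sketch; `wrapper` is any concrete `LLMWrapper`; outside an event
        loop, `wrapper.run_sync(request)` can be used instead):

        ```python
        request = LLMRequest(
            messages=[LLMMessage.user("2 + 2 = ?")],
            response_parser=str.strip,
            response_type=str,
        )
        answer = await wrapper.run(request)
        ```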
        """
        return (await self._run(request)).result

    async def _run(self, request: LLMRequest[TResult]) -> LLMResult[TResult]:
        """Run a request with retry logic and response parsing.

        Args:
        * `request`: `LLMRequest` to execute.

        Returns:
        * `LLMResult` with parsed result and token usage.

        Raises:
        * Exception: If all retries fail.
        """
        num_retries = request.retries
        error = None
        while num_retries >= 0:
            num_retries -= 1
            # both the API call and response parsing may fail; each failure consumes a retry
            try:
                response = await self.complete(request.messages)
                return LLMResult(
                    request.response_parser(response.result), response.input_tokens, response.output_tokens
                )
            except Exception as e:
                error = e
        raise error

    async def run_batch(
        self,
        requests: Sequence[LLMRequest[TResult]],
        batch_size: Optional[int] = None,
        limits: Optional[RateLimits] = None,
    ) -> List[TResult]:
        """Run multiple requests in parallel with rate limiting.

        Args:
        * `requests`: Sequence of `LLMRequest` objects.
        * `batch_size`: Optional maximum concurrent requests.
        * `limits`: Optional rate limits.

        Returns:
        * List of parsed results.
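
        Example (sketch; `requests` built as in `run`):

        ```python
        results = await wrapper.run_batch(requests, batch_size=5)
        ```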
        """
        rs = [LimitRequest(r, sum(self.estimate_tokens(m) for m in r.messages)) for r in requests]
        return await self._batch(self._run, rs, batch_size, limits)

    def get_batch_size(self) -> int:
        """Get the default batch size for concurrent requests.

        Returns:
        * Maximum number of concurrent requests (default: 100).
        """
        return 100

    def get_limits(self) -> RateLimits:
        """Get the default rate limits.

        Returns:
        * `RateLimits` with default values (no limits).
        """
        return RateLimits()

    def get_used_options(self) -> List[Type[Option]]:
        """Get the option types used by this wrapper.

        Returns:
        * List of `Option` classes that this wrapper accepts.
        """
        return self.__used_options__

    def estimate_tokens(self, msg: LLMMessage):
        """Estimate token count for a message.

        Args:
        * `msg`: `LLMMessage` to estimate.

        Returns:
        * Estimated token count (default: character count).
        """
        return len(msg.content)

    # synchronous counterparts generated from the async methods
    complete_batch_sync = sync_api(complete_batch)
    run_sync = sync_api(run)
    run_batch_sync = sync_api(run_batch)


LLMProvider = str
LLMModel = str
LLMWrapperProvider = Callable[[LLMModel, Options], LLMWrapper]
_wrappers: Dict[Tuple[LLMProvider, Optional[LLMModel]], LLMWrapperProvider] = {}


def llm_provider(name: LLMProvider, model: Optional[LLMModel]) -> Callable[[LLMWrapperProvider], LLMWrapperProvider]:
    """Register a wrapper factory for the given provider name and optional model."""

    def dec(f: LLMWrapperProvider):
        _wrappers[(name, model)] = f
        return f

    return dec
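
# Registration sketch (hypothetical provider name "my_provider"): any callable
# mapping (model, options) to an `LLMWrapper` can be registered; a wrapper class
# itself qualifies, which is how `OpenAIWrapper` below is registered:
#
#     @llm_provider("my_provider", None)
#     class MyWrapper(LLMWrapper):
#         async def complete(self, messages, seed=None):
#             ...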


def get_llm_wrapper(provider: LLMProvider, model: LLMModel, options: Options) -> LLMWrapper:
    """Get an LLM wrapper for the specified provider and model.

    Looks up registered wrappers first, then falls back to LiteLLM if available.

    Args:
    * `provider`: Provider name (e.g., "openai", "anthropic").
    * `model`: Model name (e.g., "gpt-4o-mini", "claude-3-sonnet").
    * `options`: Processing options with provider-specific configuration.

    Returns:
    * `LLMWrapper` instance for the provider/model.

    Raises:
    * `ValueError`: If no wrapper is found for the provider/model.
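
    Example (sketch; `options` is an evidently `Options` instance):

    ```python
    wrapper = get_llm_wrapper("openai", "gpt-4o-mini", options)
    ```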
    """
    key: Tuple[str, Optional[str]] = (provider, model)
    if key in _wrappers:
        return _wrappers[key](model, options)
    # fall back to a provider-wide registration
    key = (provider, None)
    if key in _wrappers:
        return _wrappers[key](model, options)
    if find_spec("litellm") is not None:
        litellm_wrapper = get_litellm_wrapper(provider, model, options)
        if litellm_wrapper is not None:
            return litellm_wrapper
    raise ValueError(f"LLM wrapper for provider {provider} model {model} not found. Try installing litellm.")


class LLMOptions(Option):
    """Base class for LLM provider options.

    Provides common configuration for API keys, rate limits, and custom API URLs.
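
    Example (subclass sketch; hypothetical provider name):

    ```python
    class MyProviderOptions(LLMOptions):
        __provider_name__: ClassVar[str] = "my_provider"
        limits: RateLimits = RateLimits(rpm=100)
    ```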
    """

    __provider_name__: ClassVar[str]

    class Config:
        extra = "forbid"

    api_key: Optional[SecretStr] = None
    """Optional API key for the provider."""
    # rpm_limit: int = 500
    limits: RateLimits = RateLimits()
    """Rate limiting configuration."""
    api_url: Optional[str] = None
    """Optional custom API URL (for self-hosted providers)."""

    def __init__(self, api_key: Optional[str] = None, rpm_limit: Optional[int] = None, **data):
        """Initialize LLM options.

        Args:
        * `api_key`: Optional API key for the provider.
        * `rpm_limit`: Optional requests per minute limit (backward compatibility).
        * `data`: Additional provider-specific option field values.
        """
        self.api_key = SecretStr(api_key) if api_key is not None else None
        super().__init__(**data)
        # backward compatibility
        if rpm_limit is not None:
            self.limits.rpm = rpm_limit

    def get_api_key(self) -> Optional[str]:
        """Get the API key as a plain string.

        Returns:
        * API key string, or `None` if not set.
        """
        if self.api_key is None:
            return None
        return self.api_key.get_secret_value()

    def get_additional_kwargs(self) -> Dict[str, Any]:
        """Get additional keyword arguments for the LLM client.

        Returns:
        * Dictionary of additional arguments (empty by default, can be overridden).
        """
        return {}


class OpenAIKey(LLMOptions):
    """Options for the OpenAI provider."""

    __provider_name__: ClassVar[str] = "openai"
    limits: RateLimits = RateLimits(rpm=500)
    """Rate limiting configuration (default: 500 requests per minute)."""


OpenAIOptions = OpenAIKey  # alias for naming consistency with other providers


@llm_provider("openai", None)
class OpenAIWrapper(LLMWrapper):
    """Wrapper for the OpenAI API.

    Provides async access to OpenAI's chat completion API with rate limiting
    and token tracking.
    """

    __used_options__: ClassVar = [OpenAIKey]

    def __init__(self, model: str, options: Options):
        import openai

        self.model = model
        self.options = options.get(OpenAIKey)
        self._clients: Dict[int, openai.AsyncOpenAI] = {}

    @property
    def client(self):
        import openai

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError as e:
            raise RuntimeError("Cannot access OpenAIWrapper client without loop") from e
        # cache one client per event loop, since a client is bound to the loop it was created in
        loop_id = id(loop)
        if loop_id not in self._clients:
            self._clients[loop_id] = openai.AsyncOpenAI(
                api_key=self.options.get_api_key(), base_url=self.options.api_url
            )
        return self._clients[loop_id]

    async def complete(self, messages: List[LLMMessage], seed: Optional[int] = None) -> LLMResult[str]:
        import openai
        from openai.types.chat.chat_completion import ChatCompletion

        messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        try:
            response: ChatCompletion = await self.client.chat.completions.create(
                model=self.model, messages=messages, seed=seed
            )  # type: ignore[arg-type]
        except openai.RateLimitError as e:
            raise LLMRateLimitError(e.message) from e
        except openai.APIError as e:
            raise LLMRequestError(f"Failed to call OpenAI complete API: {e.message}", original_error=e) from e

        content = response.choices[0].message.content
        assert content is not None  # todo: better error
        if response.usage is None:
            return LLMResult(content, 0, 0)
        return LLMResult(content, response.usage.prompt_tokens, response.usage.completion_tokens)

    def get_limits(self) -> RateLimits:
        return self.options.limits


def get_litellm_wrapper(provider: LLMProvider, model: LLMModel, options: Options) -> Optional[LLMWrapper]:
    from litellm import BadRequestError
    from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider

    try:
        model, provider, *_ = get_llm_provider(model, provider)
        return LiteLLMWrapper(f"{provider}/{model}", options)
    except BadRequestError:
        return None


@llm_provider("litellm", None)
class LiteLLMWrapper(LLMWrapper):
    __llm_options_type__: ClassVar[Type[LLMOptions]] = LLMOptions

    def get_used_options(self) -> List[Type[Option]]:
        return [self.__llm_options_type__]

    def __init__(self, model: str, options: Options):
        self.model = model
        self.options: LLMOptions = options.get(self.__llm_options_type__)

    @property
    def provider_and_model(self) -> Tuple[Optional[str], str]:
        # e.g. "anthropic/claude-3-haiku" -> ("anthropic", "claude-3-haiku")
        if not hasattr(self.options, "__provider_name__"):
            return None, self.model
        provider_name = self.options.__provider_name__
        if self.model.startswith(provider_name + "/"):
            return provider_name, self.model[len(provider_name) + 1 :]
        return provider_name, self.model

    async def complete(self, messages: List[LLMMessage], seed: Optional[int] = None) -> LLMResult[str]:
        from litellm import acompletion
        from litellm.types.utils import ModelResponse
        from litellm.types.utils import Usage

        provider_name, model = self.provider_and_model
        response: ModelResponse = await acompletion(
            model=model,
            custom_llm_provider=provider_name,
            messages=[m.dict() for m in messages],
            api_key=self.options.get_api_key(),
            api_base=self.options.api_url,
            seed=seed,
            **self.options.get_additional_kwargs(),
        )
        content = response.choices[0].message.content
        usage: Optional[Usage] = response.model_extra.get("usage")
        if usage is None:
            return LLMResult(content, 0, 0)
        return LLMResult(content, usage.prompt_tokens, usage.completion_tokens)

    def get_limits(self) -> RateLimits:
        return self.options.limits


class AnthropicOptions(LLMOptions):
    __provider_name__: ClassVar = "anthropic"
    # per-minute quotas split over 5-second windows (12 windows per minute)
    limits: RateLimits = RateLimits(
        rpm=50 // 12, itpm=40000 // 12, otpm=8000 // 12, interval=datetime.timedelta(seconds=5)
    )


@llm_provider("anthropic", None)
class AnthropicWrapper(LiteLLMWrapper):
    __llm_options_type__: ClassVar = AnthropicOptions


class GeminiOptions(LLMOptions):
    __provider_name__: ClassVar = "gemini"


@llm_provider("gemini", None)
class GeminiWrapper(LiteLLMWrapper):
    __llm_options_type__: ClassVar = GeminiOptions


class VertexAIOptions(LLMOptions):
    __provider_name__: ClassVar = "vertex_ai"

    def get_additional_kwargs(self) -> Dict[str, Any]:
        # only attempt JSON parsing for values that could plausibly be a credentials blob
        if self.api_key is None or len(self.api_key.get_secret_value()) > 10000:
            return {}
        try:
            vertex_credentials = json.loads(self.api_key.get_secret_value())
        except json.decoder.JSONDecodeError:
            return {}
        return {"vertex_credentials": vertex_credentials}


@llm_provider("vertex_ai", None)
class VertexAIWrapper(LiteLLMWrapper):
    __llm_options_type__: ClassVar = VertexAIOptions


class DeepSeekOptions(LLMOptions):
    __provider_name__: ClassVar = "deepseek"


@llm_provider("deepseek", None)
class DeepSeekWrapper(LiteLLMWrapper):
    __llm_options_type__: ClassVar = DeepSeekOptions


class MistralOptions(LLMOptions):
    __provider_name__: ClassVar = "mistral"
    # per-minute token quota (500k) spread over 1-second windows
    limits: RateLimits = RateLimits(rpm=1, itpm=500000 // 60, otpm=500000 // 60, interval=datetime.timedelta(seconds=1))


@llm_provider("mistral", None)
class MistralWrapper(LiteLLMWrapper):
    __llm_options_type__: ClassVar = MistralOptions


class OllamaOptions(LLMOptions):
    __provider_name__: ClassVar = "ollama"
    api_url: str  # required: Ollama is self-hosted, so an API URL must be provided


@llm_provider("ollama", None)
class OllamaWrapper(LiteLLMWrapper):
    __llm_options_type__: ClassVar = OllamaOptions


class NebiusOptions(LLMOptions):
    __provider_name__: ClassVar = "nebius"


@llm_provider("nebius", None)
class NebiusWrapper(LiteLLMWrapper):
    __llm_options_type__: ClassVar = NebiusOptions


excludes = [
    "openai",  # supported natively
    "openai_like",
    "custom_openai",
    "text-completion-openai",
    "anthropic_text",
    "huggingface",  # llama models do not work, disable until tested
    "vertex_ai_beta",
    "azure_text",
    "sagemaker_chat",
    "ollama_chat",
    "text-completion-codestral",
    "watsonx_text",
    "custom",
    "aiohttp_openai",
]
litellm_providers = [
    "openai",
    "openai_like",
    "jina_ai",
    "xai",
    "custom_openai",
    "text-completion-openai",
    "cohere",
    "cohere_chat",
    "clarifai",
    "anthropic",
    "anthropic_text",
    "bytez",
    "replicate",
    "huggingface",
    "together_ai",
    "openrouter",
    "datarobot",
    "vertex_ai",
    "vertex_ai_beta",
    "gemini",
    "ai21",
    "baseten",
    "azure",
    "azure_text",
    "azure_ai",
    "sagemaker",
    "sagemaker_chat",
    "bedrock",
    "vllm",
    "nlp_cloud",
    "petals",
    "oobabooga",
    "ollama",
    "ollama_chat",
    "deepinfra",
    "perplexity",
    "mistral",
    "groq",
    "nvidia_nim",
    "cerebras",
    "ai21_chat",
    "volcengine",
    "codestral",
    "text-completion-codestral",
    "dashscope",
    "deepseek",
    "sambanova",
    "maritalk",
    "voyage",
    "cloudflare",
    "xinference",
    "fireworks_ai",
    "friendliai",
    "featherless_ai",
    "watsonx",
    "watsonx_text",
    "triton",
    "predibase",
    "databricks",
    "empower",
    "github",
    "custom",
    "litellm_proxy",
    "hosted_vllm",
    "llamafile",
    "lm_studio",
    "galadriel",
    "nebius",
    "infinity",
    "deepgram",
    "elevenlabs",
    "novita",
    "aiohttp_openai",
    "langfuse",
    "humanloop",
    "topaz",
    "assemblyai",
    "github_copilot",
    "snowflake",
    "meta_llama",
    "nscale",
]
litellm_providers = [p for p in litellm_providers if p not in excludes]


def _create_litellm_wrapper(provider: str):
    # dynamically build an Options subclass and a LiteLLMWrapper subclass for a provider
    words = provider.split("_")
    class_name_prefix = "".join(word.upper() if word.lower() == "ai" else word.capitalize() for word in words)

    wrapper_name = f"{class_name_prefix}Wrapper"
    options_name = f"{class_name_prefix}Options"
    options_type = type(
        options_name,
        (LLMOptions,),
        {"__provider_name__": provider, "__annotations__": {"__provider_name__": ClassVar[str]}},
    )

    def __init__(self, model: str, options: Options):
        super(self.__class__, self).__init__(model, options)

    wrapper_type = type(
        wrapper_name,
        (LiteLLMWrapper,),
        {
            "__llm_options_type__": options_type,
            "__annotations__": {"__llm_options_type__": ClassVar},
            "__init__": __init__,
        },
    )

    return {
        wrapper_name: llm_provider(provider, None)(wrapper_type),
        options_name: options_type,
    }
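
# Naming sketch: for example, _create_litellm_wrapper("fireworks_ai") builds and
# registers a `FireworksAIWrapper` class with a matching `FireworksAIOptions` type.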


# register generated wrapper and options classes for every remaining litellm provider
for provider in litellm_providers:
    key = (provider, None)
    if key in _wrappers:
        continue
    locals().update(**_create_litellm_wrapper(provider))


def main():
    # regenerate the `litellm_providers` list above from the installed litellm version
    from pprint import pformat

    import litellm

    print(pformat([p.value for p in litellm.provider_list]).replace("'", '"'))


if __name__ == "__main__":
    main()