wrapper.py
1 import asyncio 2 import dataclasses 3 import datetime 4 import json 5 from abc import ABC 6 from abc import abstractmethod 7 from asyncio import Lock 8 from asyncio import Semaphore 9 from asyncio import sleep 10 from importlib.util import find_spec 11 from typing import Any 12 from typing import Awaitable 13 from typing import Callable 14 from typing import ClassVar 15 from typing import Dict 16 from typing import Generic 17 from typing import List 18 from typing import Optional 19 from typing import Sequence 20 from typing import Tuple 21 from typing import Type 22 from typing import TypeVar 23 24 from evidently._pydantic_compat import BaseModel 25 from evidently._pydantic_compat import SecretStr 26 from evidently.legacy.options.base import Options 27 from evidently.legacy.options.option import Option 28 from evidently.legacy.utils.sync import sync_api 29 from evidently.llm.models import LLMMessage 30 from evidently.llm.utils.errors import LLMRateLimitError 31 from evidently.llm.utils.errors import LLMRequestError 32 33 TResult = TypeVar("TResult") 34 35 36 class RateLimits(BaseModel): 37 """Rate limiting configuration for LLM API calls. 38 39 Defines limits for requests per minute (RPM) and tokens per minute (TPM) 40 to avoid exceeding API rate limits. 41 """ 42 43 rpm: Optional[int] = None 44 """Optional requests per minute limit.""" 45 itpm: Optional[int] = None 46 """Optional input tokens per minute limit.""" 47 otpm: Optional[int] = None 48 """Optional output tokens per minute limit.""" 49 tpm: Optional[int] = None 50 """Optional total tokens per minute limit.""" 51 interval: datetime.timedelta = datetime.timedelta(minutes=1) 52 """Time window for rate limiting.""" 53 # continious_token_refresh: bool = False 54 55 56 @dataclasses.dataclass 57 class _Enter: 58 ts: datetime.datetime 59 estimated_input_tokens: int 60 estimated_output_tokens: int 61 input_tokens: Optional[int] = None 62 output_tokens: Optional[int] = None 63 done: bool = False 64 65 @property 66 def estimated_tokens(self): 67 return self.estimated_input_tokens + self.estimated_output_tokens 68 69 @property 70 def tokens(self): 71 return self.input_tokens + self.output_tokens 72 73 74 class _RateLimiterEntrypoint: 75 def __init__(self, limiter: "RateLimiter", request: "LimitRequest"): 76 self.limiter = limiter 77 self.request = request 78 self.enter = _Enter(datetime.datetime.now(), request.estimated_input, 0) 79 80 @property 81 def limits(self): 82 return self.limiter.limits 83 84 @property 85 def enters(self): 86 return self.limiter.enters 87 88 @property 89 def lock(self): 90 return self.limiter.lock 91 92 async def __aenter__(self): 93 while True: 94 async with self.lock: 95 await self.limiter.clean() 96 if self._check_rpm() and self._check_tokens(): 97 self.enter.ts = datetime.datetime.now() 98 self.enter.estimated_output_tokens = self.limiter.mean_output_size() 99 self.enters.append(self.enter) 100 break 101 await sleep(0.1) 102 return self 103 104 async def __aexit__(self, exc_type, exc_val, exc_tb): 105 async with self.lock: 106 self.enter.done = True 107 stat = LimiterStat( 108 self.enter.estimated_input_tokens, 109 self.enter.estimated_output_tokens, 110 self.enter.input_tokens or 0, 111 self.enter.output_tokens or 0, 112 ) 113 self.limiter.stats.append(stat) 114 115 def record(self, input_tokens: int, output_tokens: int): 116 self.enter.input_tokens = input_tokens 117 self.enter.output_tokens = output_tokens 118 119 def _check_rpm(self): 120 res = self.limits.rpm is None or len(self.enters) < self.limits.rpm 121 return res 122 123 def _check_tokens(self): 124 used_input_tokens = 0 125 used_output_tokens = 0 126 for e in self.enters: 127 used_input_tokens += e.input_tokens or e.estimated_input_tokens 128 used_output_tokens += e.output_tokens or e.estimated_output_tokens 129 input_good = self.limits.itpm is None or used_input_tokens < self.limits.itpm 130 output_good = self.limits.otpm is None or used_output_tokens < self.limits.otpm 131 total_good = self.limits.tpm is None or used_output_tokens + used_input_tokens < self.limits.tpm 132 res = input_good and output_good and total_good 133 return res 134 135 136 @dataclasses.dataclass 137 class LimiterStat: 138 estimated_input_tokens: int 139 estimated_output_tokens: int 140 input_tokens: int 141 output_tokens: int 142 143 144 class RateLimiter: 145 """Rate limiter for LLM API calls. 146 147 Enforces rate limits on requests and tokens to avoid exceeding API quotas. 148 Tracks usage over time windows and blocks requests when limits would be exceeded. 149 """ 150 151 def __init__(self, limits: RateLimits, initial_output_estimation: int = 100000): 152 """Initialize the rate limiter. 153 154 Args: 155 * `limits`: `RateLimits` configuration. 156 * `initial_output_estimation`: Initial estimate for output token size. 157 """ 158 self.limits = limits 159 self.enters: List[_Enter] = [] 160 self.stats: List[LimiterStat] = [] 161 self.lock = Lock() 162 self.initial_output_estimation = initial_output_estimation 163 164 def enter(self, request: "LimitRequest"): 165 """Get a context manager for a rate-limited request. 166 167 Args: 168 * `request`: `LimitRequest` to rate limit. 169 170 Returns: 171 * Context manager that blocks until rate limits allow the request. 172 """ 173 return _RateLimiterEntrypoint(self, request) 174 175 async def clean(self): 176 """Remove old entries outside the rate limit window.""" 177 now = datetime.datetime.now() 178 self.enters = [e for e in self.enters if not e.done or now - e.ts < self.limits.interval] 179 180 def mean_output_size(self): 181 """Get the mean output token size from historical data. 182 183 Returns: 184 * Mean output token size, or initial estimation if no data available. 185 """ 186 if len(self.stats) == 0: 187 return self.initial_output_estimation 188 return sum(s.output_tokens for s in self.stats) / len(self.stats) 189 190 191 @dataclasses.dataclass 192 class LLMRequest(Generic[TResult]): 193 """Request to an LLM with messages and response parsing.""" 194 195 messages: List[LLMMessage] 196 """List of `LLMMessage` objects for the conversation.""" 197 response_parser: Callable[[str], TResult] 198 """Function to parse the raw string response into `TResult`.""" 199 response_type: Type[TResult] 200 """Type of the expected result.""" 201 retries: int = 1 202 """Number of retry attempts on failure.""" 203 204 205 @dataclasses.dataclass 206 class LLMResult(Generic[TResult]): 207 """Result from an LLM API call.""" 208 209 result: TResult 210 """Parsed result value.""" 211 input_tokens: int 212 """Number of input tokens used.""" 213 output_tokens: int 214 """Number of output tokens used.""" 215 216 217 TBatchItem = TypeVar("TBatchItem") 218 TBatchResult = TypeVar("TBatchResult") 219 220 221 @dataclasses.dataclass 222 class LimitRequest(Generic[TBatchItem]): 223 request: TBatchItem 224 estimated_input: int 225 # estimated_output: int 226 227 228 class LLMWrapper(ABC): 229 """Base class for LLM API wrappers. 230 231 Provides a unified interface for calling different LLM providers 232 with rate limiting, batching, and retry logic. 233 234 Subclasses should implement `complete()` for the specific provider. 235 """ 236 237 __used_options__: ClassVar[List[Type[Option]]] = [] 238 239 @abstractmethod 240 async def complete(self, messages: List[LLMMessage], seed: Optional[int] = None) -> LLMResult[str]: 241 """Complete a conversation with the LLM. 242 243 Args: 244 * `messages`: List of `LLMMessage` objects for the conversation. 245 * `seed`: Optional random seed for deterministic outputs. 246 247 Returns: 248 * `LLMResult` with the response string and token usage. 249 """ 250 raise NotImplementedError 251 252 async def _batch( 253 self, 254 coro: Callable[[TBatchItem], Awaitable[LLMResult[TBatchResult]]], 255 batches: Sequence[LimitRequest[TBatchItem]], 256 batch_size: Optional[int] = None, 257 limits: Optional[RateLimits] = None, 258 ) -> List[TBatchResult]: 259 """Execute a batch of requests with rate limiting and concurrency control. 260 261 Args: 262 * `coro`: Coroutine function to execute for each request. 263 * `batches`: Sequence of rate-limited requests. 264 * `batch_size`: Optional maximum concurrent requests. 265 * `limits`: Optional rate limits (uses default if not provided). 266 267 Returns: 268 * List of results from all requests. 269 """ 270 if batch_size is None: 271 batch_size = self.get_batch_size() 272 if limits is None: 273 limits = self.get_limits() 274 rate_limiter = RateLimiter(limits=limits) 275 semaphore = Semaphore(batch_size) 276 277 async def work(request: LimitRequest[TBatchItem]) -> TBatchResult: 278 async with semaphore, rate_limiter.enter(request) as rate: 279 res = await coro(request.request) 280 rate.record(res.input_tokens, res.output_tokens) 281 return res.result 282 283 return await asyncio.gather(*[work(batch) for batch in batches]) 284 285 async def complete_batch( 286 self, 287 messages_batch: List[List[LLMMessage]], 288 batch_size: Optional[int] = None, 289 limits: Optional[RateLimits] = None, 290 ) -> List[str]: 291 """Complete multiple conversations in parallel. 292 293 Args: 294 * `messages_batch`: List of message lists, one per conversation. 295 * `batch_size`: Optional maximum concurrent requests. 296 * `limits`: Optional rate limits. 297 298 Returns: 299 * List of response strings. 300 """ 301 requests = [LimitRequest(msgs, sum(self.estimate_tokens(m) for m in msgs)) for msgs in messages_batch] 302 return await self._batch(self.complete, requests, batch_size, limits) 303 304 async def run(self, request: LLMRequest[TResult]) -> TResult: 305 """Run a single LLM request with retry logic. 306 307 Args: 308 * `request`: `LLMRequest` to execute. 309 310 Returns: 311 * Parsed result value. 312 313 Raises: 314 * Exception: If all retries fail. 315 """ 316 return (await self._run(request)).result 317 318 async def _run(self, request: LLMRequest[TResult]) -> LLMResult[TResult]: 319 """Run a request with retry logic and response parsing. 320 321 Args: 322 * `request`: `LLMRequest` to execute. 323 324 Returns: 325 * `LLMResult` with parsed result and token usage. 326 327 Raises: 328 * Exception: If all retries fail. 329 """ 330 num_retries = request.retries 331 error = None 332 while num_retries >= 0: 333 num_retries -= 1 334 try: 335 response = await self.complete(request.messages) 336 return LLMResult( 337 request.response_parser(response.result), response.input_tokens, response.output_tokens 338 ) 339 except Exception as e: 340 error = e 341 raise error 342 343 async def run_batch( 344 self, 345 requests: Sequence[LLMRequest[TResult]], 346 batch_size: Optional[int] = None, 347 limits: Optional[RateLimits] = None, 348 ) -> List[TResult]: 349 """Run multiple requests in parallel with rate limiting. 350 351 Args: 352 * `requests`: Sequence of `LLMRequest` objects. 353 * `batch_size`: Optional maximum concurrent requests. 354 * `limits`: Optional rate limits. 355 356 Returns: 357 * List of parsed results. 358 """ 359 rs = [LimitRequest(r, sum(self.estimate_tokens(m) for m in r.messages)) for r in requests] 360 return await self._batch(self._run, rs, batch_size, limits) 361 362 def get_batch_size(self) -> int: 363 """Get the default batch size for concurrent requests. 364 365 Returns: 366 * Maximum number of concurrent requests (default: 100). 367 """ 368 return 100 369 370 def get_limits(self) -> RateLimits: 371 """Get the default rate limits. 372 373 Returns: 374 * `RateLimits` with default values (no limits). 375 """ 376 return RateLimits() 377 378 def get_used_options(self) -> List[Type[Option]]: 379 """Get the option types used by this wrapper. 380 381 Returns: 382 * List of `Option` classes that this wrapper accepts. 383 """ 384 return self.__used_options__ 385 386 def estimate_tokens(self, msg: LLMMessage): 387 """Estimate token count for a message. 388 389 Args: 390 * `msg`: `LLMMessage` to estimate. 391 392 Returns: 393 * Estimated token count (default: character count). 394 """ 395 return len(msg.content) 396 397 complete_batch_sync = sync_api(complete_batch) 398 run_sync = sync_api(run) 399 run_batch_sync = sync_api(run_batch) 400 401 402 LLMProvider = str 403 LLMModel = str 404 LLMWrapperProvider = Callable[[LLMModel, Options], LLMWrapper] 405 _wrappers: Dict[Tuple[LLMProvider, Optional[LLMModel]], LLMWrapperProvider] = {} 406 407 408 def llm_provider(name: LLMProvider, model: Optional[LLMModel]) -> Callable[[LLMWrapperProvider], LLMWrapperProvider]: 409 def dec(f: LLMWrapperProvider): 410 _wrappers[(name, model)] = f 411 return f 412 413 return dec 414 415 416 def get_llm_wrapper(provider: LLMProvider, model: LLMModel, options: Options) -> LLMWrapper: 417 """Get an LLM wrapper for the specified provider and model. 418 419 Looks up registered wrappers first, then falls back to LiteLLM if available. 420 421 Args: 422 * `provider`: Provider name (e.g., "openai", "anthropic"). 423 * `model`: Model name (e.g., "gpt-4o-mini", "claude-3-sonnet"). 424 * `options`: Processing options with provider-specific configuration. 425 426 Returns: 427 * `LLMWrapper` instance for the provider/model. 428 429 Raises: 430 * `ValueError`: If no wrapper is found for the provider/model. 431 """ 432 key: Tuple[str, Optional[str]] = (provider, model) 433 if key in _wrappers: 434 return _wrappers[key](model, options) 435 key = (provider, None) 436 if key in _wrappers: 437 return _wrappers[key](model, options) 438 if find_spec("litellm") is not None: 439 litellm_wrapper = get_litellm_wrapper(provider, model, options) 440 if litellm_wrapper is not None: 441 return litellm_wrapper 442 raise ValueError(f"LLM wrapper for provider {provider} model {model} not found. Try installing litellm") 443 444 445 class LLMOptions(Option): 446 """Base class for LLM provider options. 447 448 Provides common configuration for API keys, rate limits, and custom API URLs. 449 """ 450 451 __provider_name__: ClassVar[str] 452 453 class Config: 454 extra = "forbid" 455 456 api_key: Optional[SecretStr] = None 457 """Optional API key for the provider.""" 458 # rpm_limit: int = 500 459 limits: RateLimits = RateLimits() 460 """Rate limiting configuration.""" 461 api_url: Optional[str] = None 462 """Optional custom API URL (for self-hosted providers).""" 463 464 def __init__(self, api_key: Optional[str] = None, rpm_limit: Optional[int] = None, **data): 465 """Initialize LLM options. 466 467 Args: 468 * `api_key`: Optional API key for the provider. 469 * `rpm_limit`: Optional requests per minute limit (backward compatibility). 470 """ 471 self.api_key = SecretStr(api_key) if api_key is not None else None 472 super().__init__(**data) 473 # backward comp 474 if rpm_limit is not None: 475 self.limits.rpm = rpm_limit 476 477 def get_api_key(self) -> Optional[str]: 478 """Get the API key as a plain string. 479 480 Returns: 481 * API key string, or `None` if not set. 482 """ 483 if self.api_key is None: 484 return None 485 return self.api_key.get_secret_value() 486 487 def get_additional_kwargs(self) -> Dict[str, Any]: 488 """Get additional keyword arguments for the LLM client. 489 490 Returns: 491 * Dictionary of additional arguments (empty by default, can be overridden). 492 """ 493 return {} 494 495 496 class OpenAIKey(LLMOptions): 497 """Options for OpenAI provider.""" 498 499 __provider_name__: ClassVar[str] = "openai" 500 limits: RateLimits = RateLimits(rpm=500) 501 """Rate limiting configuration (default: 500 requests per minute).""" 502 503 504 OpenAIOptions = OpenAIKey # for consistency 505 506 507 @llm_provider("openai", None) 508 class OpenAIWrapper(LLMWrapper): 509 """Wrapper for OpenAI API. 510 511 Provides async access to OpenAI's chat completion API with rate limiting 512 and token tracking. 513 """ 514 515 __used_options__: ClassVar = [OpenAIKey] 516 517 def __init__(self, model: str, options: Options): 518 import openai 519 520 self.model = model 521 self.options = options.get(OpenAIKey) 522 self._clients: Dict[int, openai.AsyncOpenAI] = {} 523 524 @property 525 def client(self): 526 import openai 527 528 try: 529 loop = asyncio.get_running_loop() 530 except RuntimeError as e: 531 raise RuntimeError("Cannot access OpenAIWrapper client without loop") from e 532 loop_id = id(loop) 533 if loop_id not in self._clients: 534 self._clients[loop_id] = openai.AsyncOpenAI( 535 api_key=self.options.get_api_key(), base_url=self.options.api_url 536 ) 537 return self._clients[loop_id] 538 539 async def complete(self, messages: List[LLMMessage], seed: Optional[int] = None) -> LLMResult[str]: 540 import openai 541 from openai.types.chat.chat_completion import ChatCompletion 542 543 messages = [{"role": msg.role, "content": msg.content} for msg in messages] 544 try: 545 response: ChatCompletion = await self.client.chat.completions.create( 546 model=self.model, messages=messages, seed=seed 547 ) # type: ignore[arg-type] 548 except openai.RateLimitError as e: 549 raise LLMRateLimitError(e.message) from e 550 except openai.APIError as e: 551 raise LLMRequestError(f"Failed to call OpenAI complete API: {e.message}", original_error=e) from e 552 553 content = response.choices[0].message.content 554 assert content is not None # todo: better error 555 if response.usage is None: 556 return LLMResult(content, 0, 0) 557 return LLMResult(content, response.usage.prompt_tokens, response.usage.completion_tokens) 558 559 def get_limits(self) -> RateLimits: 560 return self.options.limits 561 562 563 def get_litellm_wrapper(provider: LLMProvider, model: LLMModel, options: Options) -> Optional[LLMWrapper]: 564 from litellm import BadRequestError 565 from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider 566 567 try: 568 model, provider, *_ = get_llm_provider(model, provider) 569 return LiteLLMWrapper(f"{provider}/{model}", options) 570 except BadRequestError: 571 return None 572 573 574 @llm_provider("litellm", None) 575 class LiteLLMWrapper(LLMWrapper): 576 __llm_options_type__: ClassVar[Type[LLMOptions]] = LLMOptions 577 578 def get_used_options(self) -> List[Type[Option]]: 579 return [self.__llm_options_type__] 580 581 def __init__(self, model: str, options: Options): 582 self.model = model 583 self.options: LLMOptions = options.get(self.__llm_options_type__) 584 585 @property 586 def provider_and_model(self) -> Tuple[Optional[str], str]: 587 if not hasattr(self.options, "__provider_name__"): 588 return None, self.model 589 provider_name = self.options.__provider_name__ 590 if self.model.startswith(provider_name + "/"): 591 return provider_name, self.model[len(provider_name) + 1 :] 592 return provider_name, self.model 593 594 async def complete(self, messages: List[LLMMessage], seed: Optional[int] = None) -> LLMResult[str]: 595 from litellm import acompletion 596 from litellm.types.utils import ModelResponse 597 from litellm.types.utils import Usage 598 599 provider_name, model = self.provider_and_model 600 response: ModelResponse = await acompletion( 601 model=model, 602 custom_llm_provider=provider_name, 603 messages=[m.dict() for m in messages], 604 api_key=self.options.get_api_key(), 605 api_base=self.options.api_url, 606 seed=seed, 607 **self.options.get_additional_kwargs(), 608 ) 609 content = response.choices[0].message.content 610 usage: Optional[Usage] = response.model_extra.get("usage") 611 if usage is None: 612 return LLMResult(content, 0, 0) 613 return LLMResult(content, usage.prompt_tokens, usage.completion_tokens) 614 615 def get_limits(self) -> RateLimits: 616 return self.options.limits 617 618 619 class AnthropicOptions(LLMOptions): 620 __provider_name__: ClassVar = "anthropic" 621 limits: RateLimits = RateLimits( 622 rpm=50 // 12, itpm=40000 // 12, otpm=8000 // 12, interval=datetime.timedelta(seconds=5) 623 ) 624 625 626 @llm_provider("anthropic", None) 627 class AnthropicWrapper(LiteLLMWrapper): 628 __llm_options_type__: ClassVar = AnthropicOptions 629 630 631 class GeminiOptions(LLMOptions): 632 __provider_name__: ClassVar = "gemini" 633 634 635 @llm_provider("gemini", None) 636 class GeminiWrapper(LiteLLMWrapper): 637 __llm_options_type__: ClassVar = GeminiOptions 638 639 640 class VertexAIOptions(LLMOptions): 641 __provider_name__: ClassVar = "vertex_ai" 642 643 def get_additional_kwargs(self) -> Dict[str, Any]: 644 if self.api_key is None or len(self.api_key.get_secret_value()) > 10000: # check for using non-strict json 645 return {} 646 try: 647 vertex_credentials = json.loads(self.api_key.get_secret_value()) 648 except json.decoder.JSONDecodeError: 649 return {} 650 return {"vertex_credentials": vertex_credentials} 651 652 653 @llm_provider("vertex_ai", None) 654 class VertexAIWrapper(LiteLLMWrapper): 655 __llm_options_type__: ClassVar = VertexAIOptions 656 657 658 class DeepSeekOptions(LLMOptions): 659 __provider_name__: ClassVar = "deepseek" 660 661 662 @llm_provider("deepseek", None) 663 class DeepSeekWrapper(LiteLLMWrapper): 664 __llm_options_type__: ClassVar = DeepSeekOptions 665 666 667 class MistralOptions(LLMOptions): 668 __provider_name__: ClassVar = "mistral" 669 limits: RateLimits = RateLimits(rpm=1, itpm=500000 // 60, otpm=500000 // 60, interval=datetime.timedelta(seconds=1)) 670 671 672 @llm_provider("mistral", None) 673 class MistralWrapper(LiteLLMWrapper): 674 __llm_options_type__: ClassVar = MistralOptions 675 676 677 class OllamaOptions(LLMOptions): 678 __provider_name__: ClassVar = "ollama" 679 api_url: str 680 681 682 @llm_provider("ollama", None) 683 class OllamaWrapper(LiteLLMWrapper): 684 __llm_options_type__: ClassVar = OllamaOptions 685 686 687 class NebiusOptions(LLMOptions): 688 __provider_name__: ClassVar = "nebius" 689 690 691 @llm_provider("nebius", None) 692 class NebiusWrapper(LiteLLMWrapper): 693 __llm_options_type__: ClassVar = NebiusOptions 694 695 696 excludes = [ 697 "openai", # supported natively 698 "openai_like", 699 "custom_openai", 700 "text-completion-openai", 701 "anthropic_text", 702 "huggingface", # llama models do not work, disable until tested 703 "vertex_ai_beta", 704 "azure_text", 705 "sagemaker_chat", 706 "ollama_chat", 707 "text-completion-codestral", 708 "watsonx_text", 709 "custom", 710 "aiohttp_openai", 711 ] 712 litellm_providers = [ 713 "openai", 714 "openai_like", 715 "jina_ai", 716 "xai", 717 "custom_openai", 718 "text-completion-openai", 719 "cohere", 720 "cohere_chat", 721 "clarifai", 722 "anthropic", 723 "anthropic_text", 724 "bytez", 725 "replicate", 726 "huggingface", 727 "together_ai", 728 "openrouter", 729 "datarobot", 730 "vertex_ai", 731 "vertex_ai_beta", 732 "gemini", 733 "ai21", 734 "baseten", 735 "azure", 736 "azure_text", 737 "azure_ai", 738 "sagemaker", 739 "sagemaker_chat", 740 "bedrock", 741 "vllm", 742 "nlp_cloud", 743 "petals", 744 "oobabooga", 745 "ollama", 746 "ollama_chat", 747 "deepinfra", 748 "perplexity", 749 "mistral", 750 "groq", 751 "nvidia_nim", 752 "cerebras", 753 "ai21_chat", 754 "volcengine", 755 "codestral", 756 "text-completion-codestral", 757 "dashscope", 758 "deepseek", 759 "sambanova", 760 "maritalk", 761 "voyage", 762 "cloudflare", 763 "xinference", 764 "fireworks_ai", 765 "friendliai", 766 "featherless_ai", 767 "watsonx", 768 "watsonx_text", 769 "triton", 770 "predibase", 771 "databricks", 772 "empower", 773 "github", 774 "custom", 775 "litellm_proxy", 776 "hosted_vllm", 777 "llamafile", 778 "lm_studio", 779 "galadriel", 780 "nebius", 781 "infinity", 782 "deepgram", 783 "elevenlabs", 784 "novita", 785 "aiohttp_openai", 786 "langfuse", 787 "humanloop", 788 "topaz", 789 "assemblyai", 790 "github_copilot", 791 "snowflake", 792 "meta_llama", 793 "nscale", 794 ] 795 litellm_providers = [p for p in litellm_providers if p not in excludes] 796 797 798 def _create_litellm_wrapper(provider: str): 799 words = provider.split("_") 800 class_name_prefix = "".join(word.upper() if word.lower() == "ai" else word.capitalize() for word in words) 801 802 wrapper_name = f"{class_name_prefix}Wrapper" 803 options_name = f"{class_name_prefix}Options" 804 options_type = type( 805 options_name, 806 (LLMOptions,), 807 {"__provider_name__": provider, "__annotations__": {"__provider_name__": ClassVar[str]}}, 808 ) 809 810 def __init__(self, model: str, options: Options): 811 super(self.__class__, self).__init__(model, options) 812 813 wrapper_type = type( 814 wrapper_name, 815 (LiteLLMWrapper,), 816 { 817 "__llm_options_type__": options_type, 818 "__annotations__": {"__llm_options_type__": ClassVar}, 819 "__init__": __init__, 820 }, 821 ) 822 823 return { 824 wrapper_name: llm_provider(provider, None)(wrapper_type), 825 options_name: options_type, 826 } 827 828 829 for provider in litellm_providers: 830 key = (provider, None) 831 if key in _wrappers: 832 continue 833 locals().update(**_create_litellm_wrapper(provider)) 834 835 836 def main(): 837 from pprint import pformat 838 839 import litellm 840 841 print(pformat([p.value for p in litellm.provider_list]).replace("'", '"')) 842 843 844 if __name__ == "__main__": 845 main()