xai.py
 1  import os
 2  from openai import OpenAI
 3  from typing import List, Optional, Dict, Set
 4  from langchain_openai import ChatOpenAI
 5  from langchain_core.language_models.chat_models import BaseChatModel
 6  from .base import LLMProvider
 7  from ..capabilities import ModelCapability
 8  
 9  
class XAIProvider(LLMProvider):
    """LLM provider for xAI, accessed through its OpenAI-compatible API.

    Connection settings come from the environment:

    * ``XAI_API_KEY``  -- API key passed to every client this provider builds.
    * ``XAI_BASE_URL`` -- API root; falls back to ``DEFAULT_BASE_URL`` when
      unset or empty, so requests (and the xAI key) are never routed to the
      OpenAI client's own default endpoint.
    """

    # Public xAI API root, used when XAI_BASE_URL is not configured.
    DEFAULT_BASE_URL = "https://api.x.ai/v1"

    def _connection_kwargs(self) -> Dict[str, Optional[str]]:
        """Return the ``api_key``/``base_url`` settings shared by all clients built here."""
        # NOTE(review): if XAI_API_KEY is unset, api_key is None and the OpenAI /
        # ChatOpenAI clients presumably fall back to their own key lookup
        # (e.g. OPENAI_API_KEY) -- confirm whether that should be an error instead.
        return {
            "api_key": os.getenv("XAI_API_KEY"),
            "base_url": os.getenv("XAI_BASE_URL") or self.DEFAULT_BASE_URL,
        }

    def create_model(self, name: str, model: str, tools: Optional[List] = None, **kwargs) -> BaseChatModel:
        """Build a LangChain chat model that talks to the xAI API.

        Args:
            name: Passed through as ``ChatOpenAI(name=...)``.
            model: xAI model id, e.g. ``"grok-beta"``.
            tools: Optional tool definitions; bound with ``bind_tools`` when
                the list is non-empty.
            **kwargs: Extra ``ChatOpenAI`` options. ``api_key`` or ``base_url``
                supplied here override the environment-derived values.

        Returns:
            The ``ChatOpenAI`` instance, or the result of ``bind_tools`` on it
            when ``tools`` is non-empty.
        """
        # Caller-supplied kwargs win. Previously, passing api_key/base_url raised
        # "TypeError: got multiple values for keyword argument".
        options = {**self._connection_kwargs(), **kwargs}
        llm = ChatOpenAI(name=name, model=model, **options)
        return llm.bind_tools(tools) if tools else llm

    def list_models(self) -> List[str]:
        """Return the ids of the models listed by the xAI ``/models`` endpoint.

        The endpoint uses the OpenAI-compatible list format, for example::

            {
              "object": "list",
              "data": [
                {"id": "model-id-0", "object": "model",
                 "created": 1686935002, "owned_by": "organization-owner"},
                {"id": "model-id-1", "object": "model",
                 "created": 1686935002, "owned_by": "organization-owner"}
              ]
            }

        Listing is best-effort: any exception raised while creating the client
        or calling the endpoint is logged, and an empty list is returned
        instead of propagating the error.
        """
        # Function-scope import so the module's import block stays unchanged.
        import logging

        try:
            client = OpenAI(**self._connection_kwargs())
            response = client.models.list()
            return [str(model.id) for model in response.data]
        except Exception:
            # Deliberately non-fatal for callers, but no longer silent.
            logging.getLogger(__name__).warning(
                "Failed to list xAI models", exc_info=True
            )
            return []

    @property
    def name(self) -> str:
        """Provider identifier: ``"xai"``."""
        return "xai"

    @property
    def capabilities(self) -> Dict[str, Set[ModelCapability]]:
        """Map each known xAI model id to the capabilities this provider declares for it."""
        return {
            "grok-beta": {ModelCapability.TEXT_TO_TEXT, ModelCapability.TOOL_CALLING},
        }