# xai.py
import logging
import os
from typing import Dict, List, Optional, Set

from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel

from .base import LLMProvider
from ..capabilities import ModelCapability

logger = logging.getLogger(__name__)

# Public xAI API endpoint, used when XAI_BASE_URL is unset or empty. Without
# this fallback, base_url=None makes the OpenAI SDK / ChatOpenAI use the
# OpenAI endpoint, sending the xAI key to the wrong provider.
DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1"


def _xai_connection_settings() -> Dict[str, Optional[str]]:
    """Return the ``api_key``/``base_url`` keyword arguments for xAI clients.

    ``api_key`` comes from ``XAI_API_KEY`` (may be None if unset); ``base_url``
    comes from ``XAI_BASE_URL``, falling back to ``DEFAULT_XAI_BASE_URL``.
    """
    return {
        "api_key": os.getenv("XAI_API_KEY"),
        "base_url": os.getenv("XAI_BASE_URL") or DEFAULT_XAI_BASE_URL,
    }


class XAIProvider(LLMProvider):
    """LLM provider for xAI models, accessed through OpenAI-compatible clients.

    Credentials are read from the ``XAI_API_KEY`` environment variable and the
    endpoint from ``XAI_BASE_URL`` (default: ``DEFAULT_XAI_BASE_URL``).
    """

    def create_model(
        self,
        name: str,
        model: str,
        tools: Optional[List] = None,
        **kwargs,
    ) -> BaseChatModel:
        """Build a LangChain chat model pointed at the xAI API.

        Args:
            name: Name given to the ``ChatOpenAI`` instance.
            model: xAI model identifier (e.g. ``"grok-beta"``).
            tools: Optional tools to bind; when non-empty the model is
                returned with ``bind_tools(tools)`` applied.
            **kwargs: Extra ``ChatOpenAI`` arguments. ``api_key`` and
                ``base_url`` given here override the environment defaults.

        Returns:
            The configured chat model, or its tool-bound wrapper when
            ``tools`` is non-empty.
        """
        # Caller kwargs are merged last so explicit api_key/base_url win
        # instead of raising a duplicate-keyword TypeError.
        settings = {**_xai_connection_settings(), **kwargs}
        llm = ChatOpenAI(name=name, model=model, **settings)
        return llm.bind_tools(tools) if tools else llm

    def list_models(self) -> List[str]:
        """Return the IDs of the models reported by the xAI models endpoint.

        Uses ``OpenAI(...).models.list()`` against the configured xAI base URL
        and collects the ``id`` of every entry in the response's ``data``.

        Returns:
            Model ID strings. Best-effort: if building the client, the request,
            or reading the response fails, the error is logged and an empty
            list is returned instead of raising.
        """
        try:
            # Context manager closes the client's HTTP connection pool.
            with OpenAI(**_xai_connection_settings()) as client:
                response = client.models.list()
                return [str(model.id) for model in response.data]
        except Exception:
            logger.warning("Failed to list xAI models", exc_info=True)
            return []

    @property
    def name(self) -> str:
        """Provider identifier: ``"xai"``."""
        return "xai"

    @property
    def capabilities(self) -> Dict[str, Set[ModelCapability]]:
        """Map each known model ID to the capabilities this provider declares for it."""
        return {
            "grok-beta": {ModelCapability.TEXT_TO_TEXT, ModelCapability.TOOL_CALLING},
        }