# ollama.py
"""Ollama LLM provider: model creation, model listing, and capability detection."""

import builtins
import json
import sys
import os
import requests
import typing
from chainlit import logger
from datetime import datetime, timedelta
from functools import lru_cache
from inspect import signature
from typing import List, Optional, TypeVar, Dict, Any, Union, Set
from typing import get_type_hints, get_args, get_origin
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_ollama import ChatOllama
from .base import LLMProvider
from ..capabilities import ModelCapability

T = TypeVar('T')

# Timeout (seconds) for every HTTP call to the Ollama server, so a hung
# server cannot block the caller forever.
_REQUEST_TIMEOUT = 30


class TimedCache:
    """Cache with TTL support"""

    def __init__(self, ttl_seconds: int = 300):
        # Maps key -> (value, insertion timestamp)
        self.cache: Dict[str, Any] = {}
        self.ttl = ttl_seconds

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if absent or expired."""
        if key in self.cache:
            value, timestamp = self.cache[key]
            if datetime.now() - timestamp < timedelta(seconds=self.ttl):
                return value
            # Expired: evict so the cache does not grow without bound.
            del self.cache[key]
        return None

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key*, stamped with the current time."""
        self.cache[key] = (value, datetime.now())


# 5 minute TTL for model list
model_cache = TimedCache(ttl_seconds=300)


@lru_cache(maxsize=1)
def _ollama_param_types() -> Dict[str, Any]:
    """
    Dynamically extract and cache parameter types from ChatOllama.

    Module-level function (rather than an lru_cache'd method) so the cache
    does not pin provider instances in memory (ruff B019).

    Returns:
        Dict mapping ChatOllama constructor parameter names to their
        resolved type hints (``Any`` when a hint cannot be resolved).
    """
    # Collect every type visible in langchain/typing modules so that
    # get_type_hints can resolve forward-referenced annotations on ChatOllama.
    localns: Dict[str, Any] = {}
    for module_name, module in sys.modules.items():
        if module_name.startswith(('langchain', 'typing')) and module:
            localns.update({
                k: v for k, v in module.__dict__.items()
                if isinstance(v, type) or hasattr(v, '__origin__')
            })

    # Add built-in types
    localns.update({
        k: v for k, v in builtins.__dict__.items()
        if isinstance(v, type)
    })

    # Add typing constructs
    localns.update(typing.__dict__)

    sig = signature(ChatOllama)
    type_hints = get_type_hints(ChatOllama, localns=localns)

    return {
        param_name: type_hints.get(param_name, Any)
        for param_name in sig.parameters
        if param_name != 'self'
    }


class OllamaProvider(LLMProvider):
    """LLM provider backed by a local or remote Ollama server."""

    def __init__(self, base_url: str = "http://localhost:11434"):
        self.base_url = base_url

    def create_model(self, name: str, model: str, tools: Optional[List] = None, **kwargs) -> BaseChatModel:
        """
        Build a ChatOllama model, seeding it with the parameters the Ollama
        server reports for *model* (kwargs override server-reported values).

        Args:
            name: Display name for the model instance.
            model: Ollama model identifier (e.g. "llama3").
            tools: Optional list of tools to bind to the model.
            **kwargs: Extra ChatOllama parameters; win over server defaults.

        Returns:
            A ChatOllama instance, with tools bound when *tools* is given.
        """
        # Prefer the OLLAMA_URL env var, falling back to the configured
        # base_url. (Previously an unset OLLAMA_URL produced a request to
        # the literal URL "None/api/show".)
        base_url = os.getenv("OLLAMA_URL", self.base_url)

        # Fetch the model parameters from the Ollama API; this is
        # best-effort — an unreachable server or malformed body just means
        # we fall back to the caller-supplied kwargs.
        params_kwargs: Dict[str, Any] = {}
        try:
            response = requests.post(
                f"{base_url}/api/show", json={"name": model},
                timeout=_REQUEST_TIMEOUT)
            if response.status_code == 200:
                # Only decode JSON after confirming success; error bodies
                # may not be JSON at all.
                response_json = response.json()
                logger.debug(f"Response: {json.dumps(response_json, indent=2)}")
                if "parameters" in response_json:
                    params_kwargs = self.parse_ollama_params(
                        response_json["parameters"])
        except (requests.RequestException, ValueError) as e:
            logger.debug(f"Could not fetch model parameters for {model}: {e}")

        # Merge the parameters from the Ollama API response with the provided
        # kwargs. When conflict, kwargs overrides.
        params_kwargs.update(kwargs)

        logger.debug(f"Params kwargs: {params_kwargs}")
        llm = ChatOllama(name=name, model=model,
                         base_url=base_url, **params_kwargs)
        return llm.bind_tools(tools) if tools else llm

    def list_models(self) -> List[str]:
        """Return the names of models available on the server (TTL-cached)."""
        cache_key = f"models_{self.base_url}"
        cached_models = model_cache.get(cache_key)
        if cached_models is not None:
            return cached_models

        try:
            response = requests.get(f"{self.base_url}/api/tags",
                                    timeout=_REQUEST_TIMEOUT)
            response.raise_for_status()
            models = [model["name"] for model in response.json()["models"]]
            model_cache.set(cache_key, models)
            return models
        except (requests.RequestException, KeyError, ValueError):
            # Unreachable server or unexpected payload: report "no models"
            # rather than crashing the caller.
            return []

    def get_ollama_param_types(self) -> Dict[str, Any]:
        """
        Dynamically extract and cache parameter types from ChatOllama.

        Delegates to a module-level lru_cache'd helper so the cache does not
        keep this provider instance alive.
        """
        return _ollama_param_types()

    def parse_value(self, value_str: str, type_hint):
        """
        Parse the string value to the type specified by type_hint.

        Raises:
            ValueError: when the value cannot be coerced to *type_hint*.
        """
        origin = get_origin(type_hint)
        args = get_args(type_hint)

        if origin is Union:
            # Try each union member in order, skipping NoneType; first
            # successful parse wins.
            for arg_type in args:
                if arg_type == type(None):
                    continue
                try:
                    return self.parse_value(value_str, arg_type)
                except ValueError:
                    continue
            raise ValueError(f"Cannot parse {value_str} as {type_hint}")
        elif type_hint == int:
            return int(value_str)
        elif type_hint == float:
            return float(value_str)
        elif type_hint == str:
            # Keep the original string value without stripping special characters
            return value_str.strip('"\'')
        elif origin == list:
            # For stop tokens, preserve the exact string including brackets
            elem_type = args[0] if args else str

            if value_str.startswith('[') and value_str.endswith(']'):
                # Single element with brackets - return as is
                return [value_str.strip('"\'')]
            # Handle comma-separated list case
            elements = [elem.strip().strip('"\'')
                        for elem in value_str.split(',')]
            return [self.parse_value(elem, elem_type) for elem in elements]
        elif origin == dict:
            # json is imported at module level.
            return json.loads(value_str)
        else:
            return value_str.strip('"\'')

    def parse_ollama_params(self, parameters: str) -> Dict[str, Any]:
        """
        Parse the parameters from the Ollama API response

        Args:
            parameters: Raw parameter string from Ollama API

        Returns:
            Dict of parsed parameters with appropriate types

        Example:
            Input: 'num_ctx 4096\\nstop "[INST]"\\nstop "[/INST]"'
            Output: {
                "num_ctx": 4096,
                "stop": ["[INST]", "[/INST]"]
            }
        """
        param_types = self.get_ollama_param_types()
        result: Dict[str, Any] = {}

        if not parameters or not isinstance(parameters, str):
            return result

        # First pass: collect all values for each key (a key such as "stop"
        # may legitimately appear on multiple lines).
        collected_values: Dict[str, List[str]] = {}
        for line in parameters.strip().split('\n'):
            # Split on the first run of whitespace: "key value with spaces".
            parts = line.strip().split(None, 1)
            if len(parts) != 2:
                continue

            key, raw_value = parts
            value_str = raw_value.strip('"\'')
            if not value_str:
                continue

            collected_values.setdefault(key, []).append(value_str)

        # Second pass: parse values with the correct types.
        for key, values in collected_values.items():
            type_hint = param_types.get(key, str)
            try:
                # If we have multiple values, treat as a list of strings.
                if len(values) > 1:
                    result[key] = [self.parse_value(v, str) for v in values]
                else:
                    result[key] = self.parse_value(values[0], type_hint)
            except (ValueError, TypeError) as e:
                logger.debug(
                    f"Skipping invalid parameter value for {key}: {values} ({e})")
                continue

        return result

    @property
    def name(self) -> str:
        return "ollama"

    @property
    def capabilities(self) -> Dict[str, Set[ModelCapability]]:
        """
        Map every available model name to its inferred capability set.

        Ollama has no capability-query endpoint, so capabilities are inferred
        heuristically from each model's metadata (mainly its prompt template).
        Results are TTL-cached alongside the model list.
        """
        def get_model_capabilities(metadata: dict) -> Set[ModelCapability]:
            # Keyword heuristics: a capability is assumed present when its
            # keywords appear in the model's prompt template.
            capability_keywords = {
                ModelCapability.TEXT_TO_TEXT: ["text", "response", "conversation", "Q&A", "template"],
                ModelCapability.IMAGE_TO_TEXT: ["vision", "image", "CLIP", "image encoder", "patch_size", "projection_dim"],
                ModelCapability.TOOL_CALLING: [
                    "tool", "tool_calls", "function call", "parameters",
                    "function name", "arguments", "tool calling capabilities"
                ],
                ModelCapability.STRUCTURED_OUTPUT: [
                    "json", "structured", "format", "parameters", "dictionary"]
            }

            detected_capabilities = set(
                [ModelCapability.TEXT_TO_TEXT])  # Base capability

            # Check template for tool calling patterns
            template = metadata.get("template", "").lower()
            for capability, keywords in capability_keywords.items():
                if any(keyword in template for keyword in keywords):
                    detected_capabilities.add(capability)

            # Check for JSON function call format in template
            if '"name":' in template and '"parameters":' in template:
                detected_capabilities.add(ModelCapability.TOOL_CALLING)
                detected_capabilities.add(ModelCapability.STRUCTURED_OUTPUT)

            return detected_capabilities

        try:
            cache_key = f"capabilities_{self.base_url}"
            cached_capabilities = model_cache.get(cache_key)
            if cached_capabilities is not None:
                return cached_capabilities

            response = requests.get(f"{self.base_url}/api/tags",
                                    timeout=_REQUEST_TIMEOUT)
            response.raise_for_status()
            logger.debug(f"API response: {response.json()}")

            model_capabilities = {}
            for model in response.json()["models"]:
                # Get detailed info for each model
                show_response = requests.post(
                    f"{self.base_url}/api/show",
                    json={"name": model["name"]},
                    timeout=_REQUEST_TIMEOUT
                )
                show_response.raise_for_status()
                capabilities = get_model_capabilities(show_response.json())
                model_capabilities[model["name"]] = capabilities

            model_cache.set(cache_key, model_capabilities)
            logger.debug(f"Model capabilities: {model_capabilities}")
            return model_capabilities
        except Exception as e:
            # Top-level boundary: any failure degrades to "no capabilities".
            logger.error(f"Error getting capabilities: {e}")
            return {}