# chat_workflow/llm/providers/ollama.py
  1  import builtins
  2  import json
  3  import sys
  4  import os
  5  import requests
  6  import typing
  7  from chainlit import logger
  8  from datetime import datetime, timedelta
  9  from functools import lru_cache
 10  from inspect import signature
 11  from typing import List, Optional, TypeVar, Dict, Any, Union, Set
 12  from typing import get_type_hints, get_args, get_origin
 13  from langchain_core.language_models.chat_models import BaseChatModel
 14  from langchain_ollama import ChatOllama
 15  from .base import LLMProvider
 16  from ..capabilities import ModelCapability
 17  
 18  T = TypeVar('T')
 19  
 20  
 21  class TimedCache:
 22      """Cache with TTL support"""
 23  
 24      def __init__(self, ttl_seconds: int = 300):
 25          self.cache = {}
 26          self.ttl = ttl_seconds
 27  
 28      def get(self, key: str) -> Optional[T]:
 29          if key in self.cache:
 30              value, timestamp = self.cache[key]
 31              if datetime.now() - timestamp < timedelta(seconds=self.ttl):
 32                  return value
 33              del self.cache[key]
 34          return None
 35  
 36      def set(self, key: str, value: T) -> None:
 37          self.cache[key] = (value, datetime.now())
 38  
 39  
 40  # 5 minute TTL for model list
 41  model_cache = TimedCache(ttl_seconds=300)
 42  
 43  
 44  class OllamaProvider(LLMProvider):
 45      def __init__(self, base_url: str = "http://localhost:11434"):
 46          self.base_url = base_url
 47  
 48      def create_model(self, name: str, model: str, tools: Optional[List] = None, **kwargs) -> BaseChatModel:
 49          # Fetch the model parameters from the Ollama API
 50          base_url = os.getenv("OLLAMA_URL")
 51          response = requests.post(
 52              f"{base_url}/api/show", json={"name": model})
 53          logger.debug(f"Response: {json.dumps(response.json(), indent=2)}")
 54          params_kwargs = {}
 55          if response.status_code == 200:
 56              response_json = response.json()
 57              if "parameters" in response_json:
 58                  params_kwargs = self.parse_ollama_params(
 59                      response_json["parameters"])
 60  
 61          # Merge the parameters from the Ollama API response with the provided kwargs
 62          # When conflict, kwargs overrides
 63          params_kwargs.update(kwargs)
 64          logger.debug(f"Params kwargs: {params_kwargs}")
 65          llm = ChatOllama(name=name, model=model,
 66                           base_url=os.getenv("OLLAMA_URL", "http://localhost:11434"), **params_kwargs)
 67          return llm.bind_tools(tools) if tools else llm
 68  
 69      def list_models(self) -> List[str]:
 70          cache_key = f"models_{self.base_url}"
 71          cached_models = model_cache.get(cache_key)
 72          if cached_models is not None:
 73              return cached_models
 74  
 75          try:
 76              response = requests.get(f"{self.base_url}/api/tags")
 77              response.raise_for_status()
 78              models = [f'{model["name"]}' for model in response.json()[
 79                  "models"]]
 80              model_cache.set(cache_key, models)
 81              return models
 82          except:
 83              return []
 84  
 85      @lru_cache(maxsize=1)
 86      def get_ollama_param_types(self) -> Dict[str, Any]:
 87          """
 88          Dynamically extract and cache parameter types from ChatOllama
 89          """
 90  
 91          # Get all modules from langchain packages
 92          localns = {}
 93          for module_name, module in sys.modules.items():
 94              if module_name.startswith(('langchain', 'typing')):
 95                  if module:
 96                      localns.update({
 97                          k: v for k, v in module.__dict__.items()
 98                          if isinstance(v, type) or hasattr(v, '__origin__')
 99                      })
100  
101          # Add built-in types
102          localns.update({
103              k: v for k, v in builtins.__dict__.items()
104              if isinstance(v, type)
105          })
106  
107          # Add typing constructs
108          localns.update(typing.__dict__)
109  
110          sig = signature(ChatOllama)
111          type_hints = get_type_hints(ChatOllama, localns=localns)
112  
113          param_types = {}
114          for param_name, param in sig.parameters.items():
115              if param_name == 'self':
116                  continue
117              type_hint = type_hints.get(param_name, Any)
118              param_types[param_name] = type_hint
119  
120          return param_types
121  
122      def parse_value(self, value_str: str, type_hint):
123          """
124          Parse the string value to the type specified by type_hint
125          """
126          origin = get_origin(type_hint)
127          args = get_args(type_hint)
128  
129          if origin is Union:
130              for arg_type in args:
131                  if arg_type == type(None):
132                      continue
133                  try:
134                      return self.parse_value(value_str, arg_type)
135                  except ValueError:
136                      continue
137              raise ValueError(f"Cannot parse {value_str} as {type_hint}")
138          elif type_hint == int:
139              return int(value_str)
140          elif type_hint == float:
141              return float(value_str)
142          elif type_hint == str:
143              # Keep the original string value without stripping special characters
144              return value_str.strip('"\'')
145          elif origin == list:
146              # For stop tokens, preserve the exact string including brackets
147              elem_type = args[0] if args else str
148  
149              if value_str.startswith('[') and value_str.endswith(']'):
150                  # Single element with brackets - return as is
151                  return [value_str.strip('"\'')]
152              # Handle comma-separated list case
153              elements = [elem.strip().strip('"\'')
154                          for elem in value_str.split(',')]
155              return [self.parse_value(elem, elem_type) for elem in elements]
156          elif origin == dict:
157              import json
158              return json.loads(value_str)
159          else:
160  
161              return value_str.strip('"\'')
162  
163      def parse_ollama_params(self, parameters: str) -> Dict[str, Any]:
164          """
165          Parse the parameters from the Ollama API response
166  
167          Args:
168              parameters: Raw parameter string from Ollama API
169  
170          Returns:
171              Dict of parsed parameters with appropriate types
172  
173          Example:
174              Input: 'num_ctx                        4096\nstop                           "[INST]"\nstop                           "[/INST]"'
175              Output: {
176                  "num_ctx": 4096,
177                  "stop": ["[INST]", "[/INST]"]
178              }
179          """
180          param_types = self.get_ollama_param_types()
181          result = {}
182  
183          if not parameters or not isinstance(parameters, str):
184              return result
185  
186          # First pass to collect all values for each key
187          collected_values = {}
188          for line in parameters.strip().split('\n'):
189              parts = line.strip().split(None, 1)
190              if len(parts) != 2:
191                  continue
192  
193              key, raw_value = parts
194              value_str = raw_value.strip('"\'')
195              if not value_str:
196                  continue
197  
198              if key not in collected_values:
199                  collected_values[key] = []
200              collected_values[key].append(value_str)
201  
202          # Second pass to parse values with correct types
203          for key, values in collected_values.items():
204              type_hint = param_types.get(key, str)
205              try:
206                  # If we have multiple values, treat as a list
207                  if len(values) > 1:
208                      result[key] = [self.parse_value(v, str) for v in values]
209                  else:
210                      result[key] = self.parse_value(values[0], type_hint)
211              except (ValueError, TypeError) as e:
212                  logger.debug(
213                      f"Skipping invalid parameter value for {key}: {values} ({e})")
214                  continue
215  
216          return result
217  
218      @property
219      def name(self) -> str:
220          return "ollama"
221  
222      @property
223      def capabilities(self) -> Dict[str, Set[ModelCapability]]:
224          # Currently Ollama doesn't have a way to query capabilities
225          # But we can infer some capabilities based on model metadata
226          def get_model_capabilities(metadata: dict) -> Set[ModelCapability]:
227              capability_keywords = {
228                  ModelCapability.TEXT_TO_TEXT: ["text", "response", "conversation", "Q&A", "template"],
229                  ModelCapability.IMAGE_TO_TEXT: ["vision", "image", "CLIP", "image encoder", "patch_size", "projection_dim"],
230                  ModelCapability.TOOL_CALLING: [
231                      "tool", "tool_calls", "function call", "parameters",
232                      "function name", "arguments", "tool calling capabilities"
233                  ],
234                  ModelCapability.STRUCTURED_OUTPUT: [
235                      "json", "structured", "format", "parameters", "dictionary"]
236              }
237  
238              detected_capabilities = set(
239                  [ModelCapability.TEXT_TO_TEXT])  # Base capability
240  
241              # Check template for tool calling patterns
242              template = metadata.get("template", "").lower()
243              for capability, keywords in capability_keywords.items():
244                  if any(keyword in template for keyword in keywords):
245                      detected_capabilities.add(capability)
246  
247              # Check for JSON function call format in template
248              if '"name":' in template and '"parameters":' in template:
249                  detected_capabilities.add(ModelCapability.TOOL_CALLING)
250                  detected_capabilities.add(ModelCapability.STRUCTURED_OUTPUT)
251  
252              return detected_capabilities
253  
254          try:
255              cache_key = f"capabilities_{self.base_url}"
256              cached_capabilities = model_cache.get(cache_key)
257              if cached_capabilities is not None:
258                  return cached_capabilities
259  
260              response = requests.get(f"{self.base_url}/api/tags")
261              response.raise_for_status()
262              logger.debug(f"API response: {response.json()}")
263  
264              model_capabilities = {}
265              for model in response.json()["models"]:
266                  # Get detailed info for each model
267                  show_response = requests.post(
268                      f"{self.base_url}/api/show",
269                      json={"name": model["name"]}
270                  )
271                  show_response.raise_for_status()
272                  capabilities = get_model_capabilities(show_response.json())
273                  model_capabilities[model["name"]] = capabilities
274  
275              model_cache.set(cache_key, model_capabilities)
276              logger.debug(f"Model capabilities: {model_capabilities}")
277              return model_capabilities
278          except Exception as e:
279              logger.error(f"Error getting capabilities: {e}")
280              return {}