/ examples / provider-registry / custom_provider_example.py
custom_provider_example.py
  1  """
  2  Custom Provider Registry Example (Python)
  3  
  4  This example demonstrates how to register and use custom LLM providers
  5  with the PraisonAI Python wrapper.
  6  
  7  Note: The Python provider registry is for custom provider extensions.
  8  Built-in providers (OpenAI, Anthropic, Google) are handled by LiteLLM
  9  in praisonaiagents automatically.
 10  """
 11  
 12  import sys
 13  sys.path.insert(0, '../../src/praisonai')
 14  
 15  from praisonai.llm import (
 16      LLMProviderRegistry,
 17      register_llm_provider,
 18      unregister_llm_provider,
 19      has_llm_provider,
 20      list_llm_providers,
 21      create_llm_provider,
 22      get_default_llm_registry,
 23      parse_model_string
 24  )
 25  
 26  
 27  # Example 1: Simple Custom Provider
 28  # ---------------------------------
 29  
 30  class SimpleCustomProvider:
 31      """A minimal custom provider example."""
 32      
 33      provider_id = "simple-custom"
 34      
 35      def __init__(self, model_id: str, config: dict = None):
 36          self.model_id = model_id
 37          self.config = config or {}
 38          self.api_endpoint = self.config.get('api_endpoint', 'https://api.example.com')
 39      
 40      def generate(self, prompt: str) -> str:
 41          """Generate a response (simulated)."""
 42          print(f"[SimpleCustomProvider] Generating with model: {self.model_id}")
 43          print(f"[SimpleCustomProvider] Prompt: {prompt[:50]}...")
 44          return f"Response from {self.provider_id}/{self.model_id}: Hello! This is a simulated response."
 45  
 46  
 47  # Example 2: Ollama Provider
 48  # --------------------------
 49  
 50  class OllamaProvider:
 51      """Custom provider for local Ollama integration."""
 52      
 53      provider_id = "ollama"
 54      
 55      def __init__(self, model_id: str, config: dict = None):
 56          self.model_id = model_id
 57          self.config = config or {}
 58          self.base_url = self.config.get('base_url', 'http://localhost:11434')
 59      
 60      def generate(self, prompt: str) -> str:
 61          """Generate a response using Ollama API."""
 62          import requests
 63          
 64          response = requests.post(
 65              f"{self.base_url}/api/generate",
 66              json={
 67                  "model": self.model_id,
 68                  "prompt": prompt,
 69                  "stream": False
 70              }
 71          )
 72          
 73          if response.status_code != 200:
 74              raise Exception(f"Ollama API error: {response.text}")
 75          
 76          return response.json().get("response", "")
 77      
 78      def generate_stream(self, prompt: str):
 79          """Generate a streaming response using Ollama API."""
 80          import requests
 81          
 82          response = requests.post(
 83              f"{self.base_url}/api/generate",
 84              json={
 85                  "model": self.model_id,
 86                  "prompt": prompt,
 87                  "stream": True
 88              },
 89              stream=True
 90          )
 91          
 92          for line in response.iter_lines():
 93              if line:
 94                  import json
 95                  data = json.loads(line)
 96                  yield data.get("response", "")
 97                  if data.get("done"):
 98                      break
 99  
100  
101  # Example 3: Cloudflare Workers AI Provider
102  # -----------------------------------------
103  
class CloudflareProvider:
    """Custom provider that calls Cloudflare Workers AI's ``ai/run`` REST endpoint.

    Config keys read by this class:
        account_id: Cloudflare account id (required by ``generate``).
        api_token: API token, sent as a Bearer credential (required by ``generate``).
        timeout: Optional ``requests`` timeout; ``None`` (the default) waits
            indefinitely, the same as ``requests`` does without a timeout.
    """

    provider_id = "cloudflare"

    def __init__(self, model_id: str, config: dict = None):
        self.model_id = model_id
        self.config = config or {}
        self.account_id = self.config.get('account_id')
        self.api_token = self.config.get('api_token')
        self.timeout = self.config.get('timeout')

    def _endpoint(self) -> str:
        """Return the ``ai/run`` URL for this provider's account and model."""
        return (
            "https://api.cloudflare.com/client/v4/accounts/"
            f"{self.account_id}/ai/run/{self.model_id}"
        )

    def generate(self, prompt: str) -> str:
        """Run *prompt* against the configured Workers AI model.

        Returns:
            ``result.response`` from the JSON reply, or ``""`` when
            ``result`` is missing or null.

        Raises:
            ValueError: if ``account_id`` or ``api_token`` is not configured.
            Exception: if the API answers with a non-200 status.
        """
        # Validate configuration before importing requests so that a
        # misconfiguration is reported as such, even where requests is absent.
        if not self.account_id or not self.api_token:
            raise ValueError("Cloudflare account_id and api_token are required")

        import requests

        response = requests.post(
            self._endpoint(),
            headers={
                "Authorization": f"Bearer {self.api_token}",
                "Content-Type": "application/json"
            },
            json={"prompt": prompt},
            timeout=self.timeout,
        )

        if response.status_code != 200:
            raise Exception(f"Cloudflare API error: {response.text}")

        # ``or {}`` also covers ``"result": null``, which would otherwise
        # crash with AttributeError on ``None.get``.
        result = response.json().get("result") or {}
        return result.get("response", "")
135  
136  
# Primary ids of the providers this example registers on the default registry.
_DEMO_PROVIDER_IDS = ("simple-custom", "ollama", "cloudflare")


def _register_demo_providers():
    """Register the example providers (two with aliases) and print lookup results."""
    print("Registering custom providers...")
    register_llm_provider("simple-custom", SimpleCustomProvider)
    register_llm_provider("ollama", OllamaProvider, aliases=["local"])
    register_llm_provider("cloudflare", CloudflareProvider, aliases=["cf", "workers-ai"])
    print()

    print("Providers after registration:", list_llm_providers())
    print("Has ollama:", has_llm_provider("ollama"))
    print("Has local (alias):", has_llm_provider("local"))
    print("Has cloudflare:", has_llm_provider("cloudflare"))
    print("Has cf (alias):", has_llm_provider("cf"))
    print()


def _show_model_parsing():
    """Print how parse_model_string splits a few representative model strings."""
    print("=== Model String Parsing ===\n")

    test_strings = [
        "openai/gpt-4o-mini",
        "gpt-4o-mini",
        "claude-3-5-sonnet-latest",
        "gemini-2.0-flash",
        "ollama/llama2",
        "cloudflare/workers-ai-model"
    ]

    for model_str in test_strings:
        parsed = parse_model_string(model_str)
        print(f"  '{model_str}' -> provider={parsed['provider_id']}, model={parsed['model_id']}")
    print()


def _use_custom_providers():
    """Create providers by name and by alias; only the offline one is invoked."""
    print("=== Using Custom Providers ===\n")

    # The simulated provider needs no external service, so it is safe to call.
    provider = create_llm_provider("simple-custom/test-model")
    print(f"Created provider: {provider.provider_id}/{provider.model_id}")
    response = provider.generate("Hello, world!")
    print(f"Response: {response}")
    print()

    # Resolve ollama via its alias; generate() is not called so no server is needed.
    provider = create_llm_provider("local/llama2", config={"base_url": "http://localhost:11434"})
    print(f"Created provider: {provider.provider_id}/{provider.model_id}")
    print()


def _show_error_handling():
    """Show the error raised when an unknown provider is requested."""
    print("=== Error Handling ===\n")
    try:
        create_llm_provider("unknown-provider/model")
    except ValueError as e:
        print(f"Expected error: {e}")
    print()


def _show_isolated_registry():
    """Register a provider on a private registry, leaving the default one untouched."""
    print("=== Isolated Registries ===\n")

    isolated_registry = LLMProviderRegistry()
    isolated_registry.register("isolated-provider", SimpleCustomProvider)

    print(f"Default registry providers: {list_llm_providers()}")
    print(f"Isolated registry providers: {isolated_registry.list()}")

    provider = create_llm_provider("isolated-provider/model", registry=isolated_registry)
    print(f"Created from isolated registry: {provider.provider_id}/{provider.model_id}")
    print()


def _cleanup():
    """Unregister whichever example providers are still present on the default registry."""
    print("=== Cleanup ===\n")
    for provider_id in _DEMO_PROVIDER_IDS:
        # Registration may have failed part-way; only remove what is there.
        if has_llm_provider(provider_id):
            unregister_llm_provider(provider_id)
    print(f"Providers after cleanup: {list_llm_providers()}")


def main():
    """Walk through registering, resolving and unregistering custom LLM providers.

    The custom providers are registered on the process-wide default registry.
    Cleanup runs in ``finally`` so they are removed even if a demo step raises.
    """
    print("=== Provider Registry Example (Python) ===\n")

    print("Initial providers:", list_llm_providers())
    print()

    try:
        _register_demo_providers()
        _show_model_parsing()
        _use_custom_providers()
        _show_error_handling()
        _show_isolated_registry()
    finally:
        _cleanup()

    print("\n=== Example Complete ===")


if __name__ == "__main__":
    main()