# custom_provider_example.py
"""
Custom Provider Registry Example (Python)

This example demonstrates how to register and use custom LLM providers
with the PraisonAI Python wrapper.

Note: The Python provider registry is for custom provider extensions.
Built-in providers (OpenAI, Anthropic, Google) are handled automatically
by LiteLLM in praisonaiagents.
"""

import os
import sys

# Make the in-repo ``praisonai`` package importable without installing it.
# The path is resolved relative to this file rather than the current working
# directory, so the example works no matter where it is launched from.
_EXAMPLE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.normpath(os.path.join(_EXAMPLE_DIR, '..', '..', 'src', 'praisonai')))

from praisonai.llm import (  # noqa: E402  (must follow the sys.path tweak above)
    LLMProviderRegistry,
    register_llm_provider,
    unregister_llm_provider,
    has_llm_provider,
    list_llm_providers,
    create_llm_provider,
    get_default_llm_registry,
    parse_model_string,
)


# Example 1: Simple Custom Provider
# ---------------------------------

class SimpleCustomProvider:
    """A minimal custom provider that returns a simulated response.

    It performs no network I/O. It only prints its model id and a preview of
    the prompt, which makes it useful for exercising the registry itself.
    """

    provider_id = "simple-custom"

    def __init__(self, model_id: str, config: dict = None):
        """Store the model id and configuration.

        Args:
            model_id: Model name for this provider instance (presumably the
                part after ``provider/`` in the model string; confirm against
                ``create_llm_provider``).
            config: Optional settings. ``api_endpoint`` is read (default
                ``https://api.example.com``) but not otherwise used here.
        """
        self.model_id = model_id
        self.config = config or {}
        self.api_endpoint = self.config.get('api_endpoint', 'https://api.example.com')

    def generate(self, prompt: str) -> str:
        """Log the model and a prompt preview, then return a canned response."""
        # Only show an ellipsis when the prompt was actually truncated.
        preview = prompt if len(prompt) <= 50 else f"{prompt[:50]}..."
        print(f"[SimpleCustomProvider] Generating with model: {self.model_id}")
        print(f"[SimpleCustomProvider] Prompt: {preview}")
        return f"Response from {self.provider_id}/{self.model_id}: Hello! This is a simulated response."
# Example 2: Ollama Provider
# --------------------------

class OllamaProvider:
    """Custom provider for a locally running Ollama server.

    Config keys (all optional):
        base_url: Root URL of the Ollama server (default ``http://localhost:11434``).
        timeout:  Timeout in seconds passed to ``requests`` (default 300).
                  Without a timeout, an unreachable or stalled server would
                  block the caller indefinitely.
    """

    provider_id = "ollama"

    def __init__(self, model_id: str, config: dict = None):
        self.model_id = model_id
        self.config = config or {}
        self.base_url = self.config.get('base_url', 'http://localhost:11434')
        self.timeout = self.config.get('timeout', 300)

    def _generate_url(self) -> str:
        """Return the ``/api/generate`` endpoint, tolerating a trailing slash in base_url."""
        return f"{self.base_url.rstrip('/')}/api/generate"

    def generate(self, prompt: str) -> str:
        """Generate a complete (non-streaming) response using the Ollama API.

        Returns:
            The ``response`` field of the JSON reply, or ``""`` if it is absent.

        Raises:
            RuntimeError: if the server answers with a non-200 status.
        """
        import requests

        response = requests.post(
            self._generate_url(),
            json={"model": self.model_id, "prompt": prompt, "stream": False},
            timeout=self.timeout,
        )

        if response.status_code != 200:
            raise RuntimeError(f"Ollama API error: {response.text}")

        return response.json().get("response", "")

    def generate_stream(self, prompt: str):
        """Yield response fragments as the Ollama API streams them.

        Each non-empty streamed line is parsed as JSON. Its ``response`` text
        is yielded, and iteration stops once an object with ``done`` set
        arrives.

        Raises:
            RuntimeError: on a non-200 status or an ``error`` object in the stream.
        """
        import json
        import requests

        # ``with`` releases the pooled connection even when we stop early on
        # ``done`` or the consumer abandons the generator.
        with requests.post(
            self._generate_url(),
            json={"model": self.model_id, "prompt": prompt, "stream": True},
            stream=True,
            timeout=self.timeout,
        ) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Ollama API error: {response.text}")

            for line in response.iter_lines():
                if not line:
                    continue  # skip blank keep-alive lines
                data = json.loads(line)
                if "error" in data:
                    raise RuntimeError(f"Ollama API error: {data['error']}")
                yield data.get("response", "")
                if data.get("done"):
                    break


# Example 3: Cloudflare Workers AI Provider
# -----------------------------------------

class CloudflareProvider:
    """Custom provider for Cloudflare Workers AI.

    Config keys:
        account_id: Cloudflare account id (required to generate).
        api_token:  API token sent as a Bearer token (required to generate).
        timeout:    Timeout in seconds passed to ``requests`` (default 60).
    """

    provider_id = "cloudflare"

    def __init__(self, model_id: str, config: dict = None):
        self.model_id = model_id
        self.config = config or {}
        self.account_id = self.config.get('account_id')
        self.api_token = self.config.get('api_token')
        self.timeout = self.config.get('timeout', 60)

    def generate(self, prompt: str) -> str:
        """Generate a response using Cloudflare Workers AI.

        Returns:
            ``result.response`` from the JSON reply, or ``""`` if it is absent.

        Raises:
            ValueError: if ``account_id`` or ``api_token`` is missing.
            RuntimeError: if the API answers with a non-200 status.
        """
        # Validate configuration before touching the network so that a
        # misconfiguration fails fast with a clear message.
        if not self.account_id or not self.api_token:
            raise ValueError("Cloudflare account_id and api_token are required")

        import requests

        url = (
            "https://api.cloudflare.com/client/v4/accounts/"
            f"{self.account_id}/ai/run/{self.model_id}"
        )
        response = requests.post(
            url,
            headers={
                "Authorization": f"Bearer {self.api_token}",
                "Content-Type": "application/json",
            },
            json={"prompt": prompt},
            timeout=self.timeout,
        )

        if response.status_code != 200:
            raise RuntimeError(f"Cloudflare API error: {response.text}")

        # ``result`` may be present but null; treat that like a missing result.
        result = response.json().get("result") or {}
        return result.get("response", "")


def main():
    """Run the provider-registry walkthrough end to end.

    The walkthrough:
    - registers three custom providers;
    - inspects the registry and parses some model strings;
    - instantiates providers (only ``simple-custom`` is actually invoked,
      because it needs no network);
    - shows the unknown-provider error and an isolated registry;
    - finally unregisters the providers it added, even if a step fails.
    """
    print("=== Provider Registry Example (Python) ===\n")

    # Check initial state
    print("Initial providers:", list_llm_providers())
    print()

    try:
        # Register custom providers
        print("Registering custom providers...")
        register_llm_provider("simple-custom", SimpleCustomProvider)
        register_llm_provider("ollama", OllamaProvider, aliases=["local"])
        register_llm_provider("cloudflare", CloudflareProvider, aliases=["cf", "workers-ai"])
        print()

        # Check providers after registration
        print("Providers after registration:", list_llm_providers())
        print("Has ollama:", has_llm_provider("ollama"))
        print("Has local (alias):", has_llm_provider("local"))
        print("Has cloudflare:", has_llm_provider("cloudflare"))
        print("Has cf (alias):", has_llm_provider("cf"))
        print()

        # Parse model strings
        print("=== Model String Parsing ===\n")

        test_strings = [
            "openai/gpt-4o-mini",
            "gpt-4o-mini",
            "claude-3-5-sonnet-latest",
            "gemini-2.0-flash",
            "ollama/llama2",
            "cloudflare/workers-ai-model",
        ]

        for model_str in test_strings:
            parsed = parse_model_string(model_str)
            print(f" '{model_str}' -> provider={parsed['provider_id']}, model={parsed['model_id']}")
        print()

        # Create and use providers
        print("=== Using Custom Providers ===\n")

        # Use simple custom provider
        provider = create_llm_provider("simple-custom/test-model")
        print(f"Created provider: {provider.provider_id}/{provider.model_id}")
        response = provider.generate("Hello, world!")
        print(f"Response: {response}")
        print()

        # Use ollama provider via alias (not invoked: that would need a running server)
        provider = create_llm_provider("local/llama2", config={"base_url": "http://localhost:11434"})
        print(f"Created provider: {provider.provider_id}/{provider.model_id}")
        print()

        # Demonstrate error handling
        print("=== Error Handling ===\n")
        try:
            create_llm_provider("unknown-provider/model")
        except ValueError as e:
            print(f"Expected error: {e}")
        print()

        # Demonstrate isolated registries
        print("=== Isolated Registries ===\n")

        # Create isolated registry
        isolated_registry = LLMProviderRegistry()
        isolated_registry.register("isolated-provider", SimpleCustomProvider)

        print(f"Default registry providers: {list_llm_providers()}")
        print(f"Isolated registry providers: {isolated_registry.list()}")

        # Use isolated registry
        provider = create_llm_provider("isolated-provider/model", registry=isolated_registry)
        print(f"Created from isolated registry: {provider.provider_id}/{provider.model_id}")
        print()
    finally:
        # Always remove what this script registered, even if a step above
        # failed. The has_llm_provider guard keeps a partially completed
        # registration from making the cleanup itself raise.
        print("=== Cleanup ===\n")
        for provider_id in ("simple-custom", "ollama", "cloudflare"):
            if has_llm_provider(provider_id):
                unregister_llm_provider(provider_id)
        print(f"Providers after cleanup: {list_llm_providers()}")

    print("\n=== Example Complete ===")


if __name__ == "__main__":
    main()