custom_provider_example.ts
/**
 * Custom Provider Registry Example
 *
 * This example demonstrates how to register and use custom LLM providers
 * with the PraisonAI TypeScript SDK.
 */

import {
  registerProvider,
  createProvider,
  BaseProvider,
  listProviders,
  hasProvider,
  type ProviderConfig,
  type GenerateTextOptions,
  type GenerateTextResult,
  type StreamTextOptions,
  type StreamChunk,
  type GenerateObjectOptions,
  type GenerateObjectResult,
  type Message,
  type ToolDefinition
} from 'praisonai';

// Shared helpers
// --------------

/**
 * Reads a non-empty string option from a provider config without resorting
 * to `any`. The example providers accept extra keys (`apiEndpoint`,
 * `baseUrl`) that are not asserted to be part of `ProviderConfig`, so the
 * config is treated as an opaque value and narrowed at runtime.
 *
 * @param config - The config object handed to the provider constructor (may be undefined).
 * @param key - Name of the option to look up.
 * @returns The option value when it is a non-empty string, otherwise `undefined`.
 */
function readStringOption(config: unknown, key: string): string | undefined {
  if (typeof config !== 'object' || config === null) {
    return undefined;
  }
  const value = (config as Record<string, unknown>)[key];
  return typeof value === 'string' && value.length > 0 ? value : undefined;
}

/** Fields of an Ollama `/api/generate` JSON payload that this example reads. */
interface OllamaGenerateResponse {
  response?: string;
  done?: boolean;
  prompt_eval_count?: number;
  eval_count?: number;
}

/**
 * Parses one line of a streamed Ollama response into a stream chunk.
 *
 * @param line - A single line of the response body (without the trailing newline).
 * @returns The parsed chunk, or `undefined` for blank lines and lines that are not valid JSON.
 */
function parseOllamaStreamLine(line: string): StreamChunk | undefined {
  const trimmed = line.trim();
  if (!trimmed) {
    return undefined;
  }
  try {
    const data = JSON.parse(trimmed) as OllamaGenerateResponse;
    return { text: data.response || '', done: data.done || false };
  } catch {
    // Best effort: a malformed line is skipped rather than aborting the stream.
    return undefined;
  }
}

// Example 1: Simple Custom Provider
// ---------------------------------
// A minimal custom provider that wraps a hypothetical API

/**
 * Minimal provider that never performs network I/O: every method returns a
 * canned response, which makes it a convenient template for a real provider.
 */
class SimpleCustomProvider extends BaseProvider {
  readonly providerId = 'simple-custom';
  // Stored to show where a real implementation would keep its endpoint;
  // the simulated methods below never contact it.
  private readonly apiEndpoint: string;

  /**
   * @param modelId - Model identifier forwarded to `BaseProvider`.
   * @param config - Optional provider config; an `apiEndpoint` string overrides the default URL.
   */
  constructor(modelId: string, config?: ProviderConfig) {
    super(modelId, config);
    this.apiEndpoint = readStringOption(config, 'apiEndpoint') || 'https://api.example.com';
  }

  /**
   * Logs the request and returns a fixed, simulated completion.
   *
   * @param options - Generation options; only `messages` is read (for logging).
   * @returns A canned text response with hard-coded token usage.
   */
  async generateText(options: GenerateTextOptions): Promise<GenerateTextResult> {
    // In a real implementation, you would call your API here
    console.log(`[SimpleCustomProvider] Generating text with model: ${this.modelId}`);
    console.log(`[SimpleCustomProvider] Messages:`, options.messages);

    // Simulated response
    return {
      text: `Response from ${this.providerId}/${this.modelId}: Hello! This is a simulated response.`,
      usage: {
        promptTokens: 10,
        completionTokens: 20,
        totalTokens: 30
      }
    };
  }

  /**
   * Yields four fixed words, pausing 100 ms after each one, followed by a
   * final empty chunk flagged `done: true`.
   *
   * @param _options - Unused by the simulation.
   */
  async *streamText(_options: StreamTextOptions): AsyncGenerator<StreamChunk> {
    const words = ['Hello', 'from', 'streaming', 'response!'];
    for (const word of words) {
      yield { text: word + ' ', done: false };
      await new Promise<void>(resolve => setTimeout(resolve, 100));
    }
    yield { text: '', done: true };
  }

  /**
   * Returns a fixed object cast to the requested type.
   *
   * NOTE: the cast is unchecked; the returned object is NOT validated
   * against the caller's schema.
   *
   * @param _options - Unused by the simulation.
   */
  async generateObject<T>(_options: GenerateObjectOptions<T>): Promise<GenerateObjectResult<T>> {
    // Simulated structured output
    return {
      object: { message: 'Structured response' } as T,
      usage: {
        promptTokens: 10,
        completionTokens: 20,
        totalTokens: 30
      }
    };
  }

  /** Passes tool definitions through unchanged. */
  formatTools(tools: ToolDefinition[]): any[] {
    return tools;
  }

  /** Passes messages through unchanged. */
  formatMessages(messages: Message[]): any[] {
    return messages;
  }
}

// Example 2: Ollama Provider
// --------------------------
// A more realistic example for local Ollama integration

/**
 * Provider that talks to an Ollama server through its `/api/generate`
 * HTTP endpoint, in both one-shot and streaming mode.
 */
class OllamaProvider extends BaseProvider {
  readonly providerId = 'ollama';
  private readonly baseUrl: string;

  /**
   * @param modelId - Model name sent as `model` in every request.
   * @param config - Optional provider config; a `baseUrl` string overrides `http://localhost:11434`.
   */
  constructor(modelId: string, config?: ProviderConfig) {
    super(modelId, config);
    const configured = readStringOption(config, 'baseUrl') || 'http://localhost:11434';
    // Strip trailing slashes so `${baseUrl}/api/generate` never contains `//`.
    this.baseUrl = configured.replace(/\/+$/, '');
  }

  /**
   * Flattens the chat messages into a single prompt string and POSTs it to
   * `/api/generate`.
   *
   * @param messages - Chat messages; each is rendered as `role: content` on its own line.
   * @param stream - Value of the `stream` flag in the request body.
   * @returns The HTTP response, guaranteed to have an OK status.
   * @throws Error when the server answers with a non-OK status.
   */
  private async requestGeneration(
    messages: ReadonlyArray<{ role: string; content: unknown }>,
    stream: boolean
  ): Promise<Response> {
    const prompt = messages.map(m => `${m.role}: ${m.content}`).join('\n');
    const response = await fetch(`${this.baseUrl}/api/generate`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ model: this.modelId, prompt, stream })
    });

    if (!response.ok) {
      throw new Error(`Ollama API error: ${response.statusText}`);
    }
    return response;
  }

  /**
   * Generates a complete (non-streamed) response.
   *
   * @param options - Generation options; only `messages` is used.
   * @returns The generated text and token usage; counts missing from the payload are reported as 0.
   * @throws Error when the server answers with a non-OK status.
   */
  async generateText(options: GenerateTextOptions): Promise<GenerateTextResult> {
    const response = await this.requestGeneration(options.messages, false);
    const data = (await response.json()) as OllamaGenerateResponse;

    const promptTokens = data.prompt_eval_count ?? 0;
    const completionTokens = data.eval_count ?? 0;
    return {
      text: data.response ?? '',
      usage: {
        promptTokens,
        completionTokens,
        totalTokens: promptTokens + completionTokens
      }
    };
  }

  /**
   * Streams a response, yielding one chunk per JSON line sent by the server.
   *
   * Network chunks do not align with line boundaries, so the trailing
   * partial line of each read is carried over to the next one, and the
   * decoder runs in streaming mode so multi-byte characters split across
   * reads are decoded correctly.
   *
   * @param options - Generation options; only `messages` is used.
   * @throws Error when the server answers with a non-OK status or the response has no body.
   */
  async *streamText(options: StreamTextOptions): AsyncGenerator<StreamChunk> {
    const response = await this.requestGeneration(options.messages, true);

    const reader = response.body?.getReader();
    if (!reader) throw new Error('No response body');

    const decoder = new TextDecoder();
    let pending = '';
    try {
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        pending += decoder.decode(value, { stream: true });
        const lines = pending.split('\n');
        // The last element is either '' (chunk ended on a newline) or an
        // incomplete line that must wait for more data.
        pending = lines.pop() ?? '';

        for (const line of lines) {
          const chunk = parseOllamaStreamLine(line);
          if (chunk) yield chunk;
        }
      }

      // Flush bytes buffered inside the decoder and any final unterminated line.
      pending += decoder.decode();
      const last = parseOllamaStreamLine(pending);
      if (last) yield last;
    } finally {
      // Runs on normal completion, on errors, and when the consumer stops
      // iterating early; cancelling lets the underlying connection be freed.
      // A failure to cancel is not actionable here, so it is ignored.
      await reader.cancel().catch(() => undefined);
    }
  }

  /**
   * Requests structured output by appending a system message that asks for
   * JSON matching `options.schema`, then parsing the returned text.
   *
   * NOTE: the parsed value is cast to `T` without validation against the schema.
   *
   * @param options - Object-generation options; `messages` and `schema` are used.
   * @throws SyntaxError when the model's reply is not valid JSON.
   */
  async generateObject<T>(options: GenerateObjectOptions<T>): Promise<GenerateObjectResult<T>> {
    const result = await this.generateText({
      messages: [
        ...options.messages,
        { role: 'system', content: `Respond with valid JSON matching this schema: ${JSON.stringify(options.schema)}` }
      ]
    });

    return {
      object: JSON.parse(result.text) as T,
      usage: result.usage
    };
  }

  /** Passes tool definitions through unchanged. */
  formatTools(tools: ToolDefinition[]): any[] {
    return tools;
  }

  /** Passes messages through unchanged. */
  formatMessages(messages: Message[]): any[] {
    return messages;
  }
}

// Main Example
// ------------

/**
 * Walks through the registry API: lists providers, registers the two custom
 * providers above, creates instances by `provider/model` string (including
 * via an alias), and shows the error raised for an unknown provider.
 */
async function main() {
  console.log('=== Provider Registry Example ===\n');

  // Check initial providers
  console.log('Initial providers:', listProviders());
  console.log('Has openai:', hasProvider('openai'));
  console.log('Has ollama:', hasProvider('ollama'));
  console.log();

  // Register custom providers
  console.log('Registering custom providers...');
  registerProvider('simple-custom', SimpleCustomProvider);
  registerProvider('ollama', OllamaProvider, { aliases: ['local'] });
  console.log();

  // Check providers after registration
  console.log('Providers after registration:', listProviders());
  console.log('Has ollama:', hasProvider('ollama'));
  console.log('Has local (alias):', hasProvider('local'));
  console.log();

  // Create and use providers
  console.log('=== Using Custom Providers ===\n');

  // Use simple custom provider
  const simpleProvider = createProvider('simple-custom/test-model');
  console.log(`Created provider: ${simpleProvider.providerId}/${simpleProvider.modelId}`);

  const simpleResult = await simpleProvider.generateText({
    messages: [{ role: 'user', content: 'Hello!' }]
  });
  console.log('Response:', simpleResult.text);
  console.log('Usage:', simpleResult.usage);
  console.log();

  // Use ollama provider via alias
  // `baseUrl` is a provider-specific extra key, hence the cast.
  const ollamaProvider = createProvider('local/llama2', {
    baseUrl: 'http://localhost:11434'
  } as any);
  console.log(`Created provider: ${ollamaProvider.providerId}/${ollamaProvider.modelId}`);
  console.log();

  // Demonstrate error handling for unknown provider
  console.log('=== Error Handling ===\n');
  try {
    createProvider('unknown-provider/model');
  } catch (error: unknown) {
    const message = error instanceof Error ? error.message : String(error);
    console.log('Expected error:', message);
  }

  console.log('\n=== Example Complete ===');
}

main().catch(console.error);