// examples/provider-registry/custom_provider_example.ts
  1  /**
  2   * Custom Provider Registry Example
  3   * 
  4   * This example demonstrates how to register and use custom LLM providers
  5   * with the PraisonAI TypeScript SDK.
  6   */
  7  
  8  import {
  9    registerProvider,
 10    createProvider,
 11    BaseProvider,
 12    listProviders,
 13    hasProvider,
 14    type ProviderConfig,
 15    type GenerateTextOptions,
 16    type GenerateTextResult,
 17    type StreamTextOptions,
 18    type StreamChunk,
 19    type GenerateObjectOptions,
 20    type GenerateObjectResult,
 21    type Message,
 22    type ToolDefinition
 23  } from 'praisonai';
 24  
 25  // Example 1: Simple Custom Provider
 26  // ---------------------------------
 27  // A minimal custom provider that wraps a hypothetical API
 28  
 29  class SimpleCustomProvider extends BaseProvider {
 30    readonly providerId = 'simple-custom';
 31    private apiEndpoint: string;
 32  
 33    constructor(modelId: string, config?: ProviderConfig) {
 34      super(modelId, config);
 35      this.apiEndpoint = (config as any)?.apiEndpoint || 'https://api.example.com';
 36    }
 37  
 38    async generateText(options: GenerateTextOptions): Promise<GenerateTextResult> {
 39      // In a real implementation, you would call your API here
 40      console.log(`[SimpleCustomProvider] Generating text with model: ${this.modelId}`);
 41      console.log(`[SimpleCustomProvider] Messages:`, options.messages);
 42      
 43      // Simulated response
 44      return {
 45        text: `Response from ${this.providerId}/${this.modelId}: Hello! This is a simulated response.`,
 46        usage: {
 47          promptTokens: 10,
 48          completionTokens: 20,
 49          totalTokens: 30
 50        }
 51      };
 52    }
 53  
 54    async *streamText(options: StreamTextOptions): AsyncGenerator<StreamChunk> {
 55      const words = ['Hello', 'from', 'streaming', 'response!'];
 56      for (const word of words) {
 57        yield { text: word + ' ', done: false };
 58        await new Promise(resolve => setTimeout(resolve, 100));
 59      }
 60      yield { text: '', done: true };
 61    }
 62  
 63    async generateObject<T>(options: GenerateObjectOptions<T>): Promise<GenerateObjectResult<T>> {
 64      // Simulated structured output
 65      return {
 66        object: { message: 'Structured response' } as T,
 67        usage: {
 68          promptTokens: 10,
 69          completionTokens: 20,
 70          totalTokens: 30
 71        }
 72      };
 73    }
 74  
 75    formatTools(tools: ToolDefinition[]): any[] {
 76      return tools;
 77    }
 78  
 79    formatMessages(messages: Message[]): any[] {
 80      return messages;
 81    }
 82  }
 83  
 84  // Example 2: Ollama Provider
 85  // --------------------------
 86  // A more realistic example for local Ollama integration
 87  
 88  class OllamaProvider extends BaseProvider {
 89    readonly providerId = 'ollama';
 90    private baseUrl: string;
 91  
 92    constructor(modelId: string, config?: ProviderConfig) {
 93      super(modelId, config);
 94      this.baseUrl = (config as any)?.baseUrl || 'http://localhost:11434';
 95    }
 96  
 97    async generateText(options: GenerateTextOptions): Promise<GenerateTextResult> {
 98      const response = await fetch(`${this.baseUrl}/api/generate`, {
 99        method: 'POST',
100        headers: { 'Content-Type': 'application/json' },
101        body: JSON.stringify({
102          model: this.modelId,
103          prompt: options.messages.map(m => `${m.role}: ${m.content}`).join('\n'),
104          stream: false
105        })
106      });
107  
108      if (!response.ok) {
109        throw new Error(`Ollama API error: ${response.statusText}`);
110      }
111  
112      const data = await response.json();
113      return {
114        text: data.response,
115        usage: {
116          promptTokens: data.prompt_eval_count || 0,
117          completionTokens: data.eval_count || 0,
118          totalTokens: (data.prompt_eval_count || 0) + (data.eval_count || 0)
119        }
120      };
121    }
122  
123    async *streamText(options: StreamTextOptions): AsyncGenerator<StreamChunk> {
124      const response = await fetch(`${this.baseUrl}/api/generate`, {
125        method: 'POST',
126        headers: { 'Content-Type': 'application/json' },
127        body: JSON.stringify({
128          model: this.modelId,
129          prompt: options.messages.map(m => `${m.role}: ${m.content}`).join('\n'),
130          stream: true
131        })
132      });
133  
134      if (!response.ok) {
135        throw new Error(`Ollama API error: ${response.statusText}`);
136      }
137  
138      const reader = response.body?.getReader();
139      if (!reader) throw new Error('No response body');
140  
141      const decoder = new TextDecoder();
142      while (true) {
143        const { done, value } = await reader.read();
144        if (done) break;
145        
146        const chunk = decoder.decode(value);
147        const lines = chunk.split('\n').filter(line => line.trim());
148        
149        for (const line of lines) {
150          try {
151            const data = JSON.parse(line);
152            yield { text: data.response || '', done: data.done || false };
153          } catch {
154            // Skip invalid JSON
155          }
156        }
157      }
158    }
159  
160    async generateObject<T>(options: GenerateObjectOptions<T>): Promise<GenerateObjectResult<T>> {
161      const result = await this.generateText({
162        messages: [
163          ...options.messages,
164          { role: 'system', content: `Respond with valid JSON matching this schema: ${JSON.stringify(options.schema)}` }
165        ]
166      });
167  
168      return {
169        object: JSON.parse(result.text) as T,
170        usage: result.usage
171      };
172    }
173  
174    formatTools(tools: ToolDefinition[]): any[] {
175      return tools;
176    }
177  
178    formatMessages(messages: Message[]): any[] {
179      return messages;
180    }
181  }
182  
183  // Main Example
184  // ------------
185  
186  async function main() {
187    console.log('=== Provider Registry Example ===\n');
188  
189    // Check initial providers
190    console.log('Initial providers:', listProviders());
191    console.log('Has openai:', hasProvider('openai'));
192    console.log('Has ollama:', hasProvider('ollama'));
193    console.log();
194  
195    // Register custom providers
196    console.log('Registering custom providers...');
197    registerProvider('simple-custom', SimpleCustomProvider);
198    registerProvider('ollama', OllamaProvider, { aliases: ['local'] });
199    console.log();
200  
201    // Check providers after registration
202    console.log('Providers after registration:', listProviders());
203    console.log('Has ollama:', hasProvider('ollama'));
204    console.log('Has local (alias):', hasProvider('local'));
205    console.log();
206  
207    // Create and use providers
208    console.log('=== Using Custom Providers ===\n');
209  
210    // Use simple custom provider
211    const simpleProvider = createProvider('simple-custom/test-model');
212    console.log(`Created provider: ${simpleProvider.providerId}/${simpleProvider.modelId}`);
213    
214    const simpleResult = await simpleProvider.generateText({
215      messages: [{ role: 'user', content: 'Hello!' }]
216    });
217    console.log('Response:', simpleResult.text);
218    console.log('Usage:', simpleResult.usage);
219    console.log();
220  
221    // Use ollama provider via alias
222    const ollamaProvider = createProvider('local/llama2', {
223      baseUrl: 'http://localhost:11434'
224    } as any);
225    console.log(`Created provider: ${ollamaProvider.providerId}/${ollamaProvider.modelId}`);
226    console.log();
227  
228    // Demonstrate error handling for unknown provider
229    console.log('=== Error Handling ===\n');
230    try {
231      createProvider('unknown-provider/model');
232    } catch (error: any) {
233      console.log('Expected error:', error.message);
234    }
235  
236    console.log('\n=== Example Complete ===');
237  }
238  
239  main().catch(console.error);