// internal/client/ollama.go
  1  package client
  2  
  3  import (
  4  	"bufio"
  5  	"bytes"
  6  	"context"
  7  	"encoding/json"
  8  	"fmt"
  9  	"io"
 10  	"net/http"
 11  	"strings"
 12  	"time"
 13  )
 14  
 15  // ---------------------------------------------------------------------------
 16  // OpenAI-compatible message types (used for Ollama request/response)
 17  // ---------------------------------------------------------------------------
 18  
// openAIMessage is one chat message in OpenAI wire format. ToolCalls is
// populated on assistant messages that invoke functions; ToolCallID links a
// role:"tool" message back to the tool call it answers.
type openAIMessage struct {
	Role       string           `json:"role"`
	Content    string           `json:"content"`
	ToolCalls  []openAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID string           `json:"tool_call_id,omitempty"`
}
 25  
// openAIToolCall is a single function invocation requested by the assistant,
// in OpenAI wire format. Type is always "function" in this client.
type openAIToolCall struct {
	ID       string             `json:"id"`
	Type     string             `json:"type"`
	Function openAIFunctionCall `json:"function"`
}
 31  
// openAIFunctionCall carries a function name plus its arguments serialized
// as a JSON string, per the OpenAI tool-call schema.
type openAIFunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
 36  
 37  // ---------------------------------------------------------------------------
 38  // Format conversion: ShanClaw (Anthropic) → OpenAI
 39  // ---------------------------------------------------------------------------
 40  
 41  // convertMessagesToOpenAI converts ShanClaw's Anthropic-format messages to
 42  // OpenAI-compatible format for Ollama.
 43  func convertMessagesToOpenAI(msgs []Message) []openAIMessage {
 44  	var result []openAIMessage
 45  	for _, msg := range msgs {
 46  		if !msg.Content.HasBlocks() {
 47  			result = append(result, openAIMessage{
 48  				Role:    msg.Role,
 49  				Content: msg.Content.Text(),
 50  			})
 51  			continue
 52  		}
 53  
 54  		blocks := msg.Content.Blocks()
 55  
 56  		var toolUseBlocks []ContentBlock
 57  		var toolResultBlocks []ContentBlock
 58  		var textParts []string
 59  
 60  		for _, b := range blocks {
 61  			switch b.Type {
 62  			case "tool_use":
 63  				toolUseBlocks = append(toolUseBlocks, b)
 64  			case "tool_result":
 65  				toolResultBlocks = append(toolResultBlocks, b)
 66  			case "text":
 67  				if b.Text != "" {
 68  					textParts = append(textParts, b.Text)
 69  				}
 70  			}
 71  		}
 72  
 73  		if msg.Role == "assistant" && len(toolUseBlocks) > 0 {
 74  			// Assistant message with tool calls → OpenAI tool_calls format
 75  			m := openAIMessage{
 76  				Role:    "assistant",
 77  				Content: strings.Join(textParts, "\n"),
 78  			}
 79  			for _, b := range toolUseBlocks {
 80  				m.ToolCalls = append(m.ToolCalls, openAIToolCall{
 81  					ID:   b.ID,
 82  					Type: "function",
 83  					Function: openAIFunctionCall{
 84  						Name:      b.Name,
 85  						Arguments: string(b.Input),
 86  					},
 87  				})
 88  			}
 89  			result = append(result, m)
 90  		} else if len(toolResultBlocks) > 0 {
 91  			// Tool results → one role:tool message per result
 92  			for _, b := range toolResultBlocks {
 93  				result = append(result, openAIMessage{
 94  					Role:       "tool",
 95  					Content:    ToolResultText(b),
 96  					ToolCallID: b.ToolUseID,
 97  				})
 98  			}
 99  		} else {
100  			// Text-only blocks → concatenate
101  			result = append(result, openAIMessage{
102  				Role:    msg.Role,
103  				Content: strings.Join(textParts, "\n"),
104  			})
105  		}
106  	}
107  	return result
108  }
109  
110  // ---------------------------------------------------------------------------
111  // Format conversion: OpenAI → ShanClaw
112  // ---------------------------------------------------------------------------
113  
114  // mapFinishReason converts OpenAI finish reasons to ShanClaw's internal values.
// mapFinishReason converts OpenAI finish reasons to ShanClaw's internal
// values. Any unrecognized reason (including the empty string) maps to
// "end_turn".
func mapFinishReason(reason string) string {
	translations := map[string]string{
		"stop":           "end_turn",
		"length":         "max_tokens",
		"tool_calls":     "tool_use",
		"content_filter": "content_filter",
	}
	if mapped, ok := translations[reason]; ok {
		return mapped
	}
	return "end_turn"
}
129  
130  // convertOpenAIResponse parses an OpenAI-format JSON response into CompletionResponse.
131  func convertOpenAIResponse(data []byte) (*CompletionResponse, error) {
132  	var raw struct {
133  		Model   string `json:"model"`
134  		Choices []struct {
135  			Message struct {
136  				Role      string  `json:"role"`
137  				Content   *string `json:"content"`
138  				Reasoning *string `json:"reasoning"`
139  				ToolCalls []struct {
140  					ID       string `json:"id"`
141  					Type     string `json:"type"`
142  					Function struct {
143  						Name      string          `json:"name"`
144  						Arguments json.RawMessage `json:"arguments"`
145  					} `json:"function"`
146  				} `json:"tool_calls"`
147  			} `json:"message"`
148  			FinishReason string `json:"finish_reason"`
149  		} `json:"choices"`
150  		Usage struct {
151  			PromptTokens     int `json:"prompt_tokens"`
152  			CompletionTokens int `json:"completion_tokens"`
153  			TotalTokens      int `json:"total_tokens"`
154  		} `json:"usage"`
155  	}
156  
157  	if err := json.Unmarshal(data, &raw); err != nil {
158  		return nil, fmt.Errorf("decode OpenAI response: %w", err)
159  	}
160  	if len(raw.Choices) == 0 {
161  		return nil, fmt.Errorf("empty choices in response")
162  	}
163  
164  	choice := raw.Choices[0]
165  	resp := &CompletionResponse{
166  		Provider:     "ollama",
167  		Model:        raw.Model,
168  		FinishReason: mapFinishReason(choice.FinishReason),
169  		Usage: Usage{
170  			InputTokens:  raw.Usage.PromptTokens,
171  			OutputTokens: raw.Usage.CompletionTokens,
172  			TotalTokens:  raw.Usage.TotalTokens,
173  		},
174  	}
175  
176  	if choice.Message.Content != nil && *choice.Message.Content != "" {
177  		resp.OutputText = *choice.Message.Content
178  	} else if choice.Message.Reasoning != nil && *choice.Message.Reasoning != "" {
179  		// Thinking models (e.g. Qwen3) may exhaust max_tokens during reasoning,
180  		// leaving content empty. Surface the reasoning so the user sees something.
181  		resp.OutputText = "[thinking] " + *choice.Message.Reasoning
182  	}
183  
184  	for _, tc := range choice.Message.ToolCalls {
185  		args := tc.Function.Arguments
186  		// Arguments can be a JSON string (OpenAI compat) or JSON object (Ollama native).
187  		// Normalize to raw JSON object for ShanClaw's FunctionCall.Arguments.
188  		var argsStr string
189  		if err := json.Unmarshal(args, &argsStr); err == nil {
190  			args = json.RawMessage(argsStr)
191  		}
192  
193  		resp.ToolCalls = append(resp.ToolCalls, FunctionCall{
194  			ID:        tc.ID,
195  			Name:      tc.Function.Name,
196  			Arguments: args,
197  		})
198  	}
199  
200  	return resp, nil
201  }
202  
203  // ---------------------------------------------------------------------------
204  // Tool schema filtering
205  // ---------------------------------------------------------------------------
206  
207  // filterToolsForOpenAI removes native tool definitions (e.g. Anthropic computer use)
208  // that Ollama doesn't support. Only standard function tools pass through.
209  func filterToolsForOpenAI(tools []Tool) []Tool {
210  	if tools == nil {
211  		return nil
212  	}
213  	var filtered []Tool
214  	for _, t := range tools {
215  		if t.Type == "function" {
216  			filtered = append(filtered, t)
217  		}
218  	}
219  	return filtered
220  }
221  
222  // ---------------------------------------------------------------------------
223  // OllamaModel — model metadata from /api/tags
224  // ---------------------------------------------------------------------------
225  
226  // OllamaModel represents a model available in the local Ollama instance.
type OllamaModel struct {
	Name string `json:"name"` // model tag as reported by /api/tags
	Size int64  `json:"size"` // model size reported by /api/tags (presumably bytes — confirm against Ollama docs)
	Details struct {
		ParameterSize string `json:"parameter_size"` // human-readable parameter count string
	} `json:"details"`
}
234  
235  // ---------------------------------------------------------------------------
236  // OllamaClient
237  // ---------------------------------------------------------------------------
238  
239  // OllamaClient implements LLMClient for local Ollama instances via the
240  // OpenAI-compatible /v1/chat/completions endpoint.
type OllamaClient struct {
	endpoint   string       // base URL, e.g. "http://localhost:11434"
	model      string       // default model used when a request has no SpecificModel
	httpClient *http.Client // shared client with a long timeout for slow local models
}
246  
247  // NewOllamaClient creates a client for a local Ollama instance.
248  // endpoint is the base URL (e.g. "http://localhost:11434").
249  // model is the default model name (e.g. "llama3.1").
250  func NewOllamaClient(endpoint, model string) *OllamaClient {
251  	return &OllamaClient{
252  		endpoint: endpoint,
253  		model:    model,
254  		httpClient: &http.Client{
255  			Timeout: 600 * time.Second,
256  		},
257  	}
258  }
259  
260  func (c *OllamaClient) resolveModel(req CompletionRequest) string {
261  	if req.SpecificModel != "" {
262  		return req.SpecificModel
263  	}
264  	return c.model
265  }
266  
// ollamaRequestBody is the JSON payload for POST /v1/chat/completions.
type ollamaRequestBody struct {
	Model       string          `json:"model"`
	Messages    []openAIMessage `json:"messages"`
	Tools       []Tool          `json:"tools,omitempty"`
	Temperature float64         `json:"temperature,omitempty"`
	MaxTokens   int             `json:"max_tokens,omitempty"`
	Stream      bool            `json:"stream"` // no omitempty: an explicit false requests a single response
}
275  
276  func (c *OllamaClient) buildRequestBody(req CompletionRequest, stream bool) ollamaRequestBody {
277  	return ollamaRequestBody{
278  		Model:       c.resolveModel(req),
279  		Messages:    convertMessagesToOpenAI(req.Messages),
280  		Tools:       filterToolsForOpenAI(req.Tools),
281  		Temperature: req.Temperature,
282  		MaxTokens:   req.MaxTokens,
283  		Stream:      stream,
284  	}
285  }
286  
287  // Complete sends a non-streaming completion request to Ollama.
288  func (c *OllamaClient) Complete(ctx context.Context, req CompletionRequest) (*CompletionResponse, error) {
289  	body, err := json.Marshal(c.buildRequestBody(req, false))
290  	if err != nil {
291  		return nil, fmt.Errorf("marshal request: %w", err)
292  	}
293  
294  	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
295  		c.endpoint+"/v1/chat/completions", bytes.NewReader(body))
296  	if err != nil {
297  		return nil, fmt.Errorf("create request: %w", err)
298  	}
299  	httpReq.Header.Set("Content-Type", "application/json")
300  
301  	resp, err := c.httpClient.Do(httpReq)
302  	if err != nil {
303  		return nil, fmt.Errorf("request failed: %w", err)
304  	}
305  	defer resp.Body.Close()
306  
307  	if resp.StatusCode != http.StatusOK {
308  		return nil, &APIError{StatusCode: resp.StatusCode, Body: readResponseBody(resp)}
309  	}
310  
311  	respData, err := io.ReadAll(resp.Body)
312  	if err != nil {
313  		return nil, fmt.Errorf("read response: %w", err)
314  	}
315  
316  	return convertOpenAIResponse(respData)
317  }
318  
319  // CompleteStream sends a streaming completion request to Ollama.
320  // It calls onDelta for each text chunk and returns the final response.
// CompleteStream sends a streaming completion request to Ollama.
// It calls onDelta for each text chunk and returns the final response.
//
// The body is consumed as Server-Sent Events: each "data: " line carries an
// OpenAI-style chunk, and the literal payload "[DONE]" terminates the
// stream. Content deltas are forwarded to onDelta as they arrive; reasoning
// text, tool calls, usage, and the finish reason are accumulated and
// returned in the final CompletionResponse.
func (c *OllamaClient) CompleteStream(ctx context.Context, req CompletionRequest, onDelta func(StreamDelta)) (*CompletionResponse, error) {
	body, err := json.Marshal(c.buildRequestBody(req, true))
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
		c.endpoint+"/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "text/event-stream")

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, &APIError{StatusCode: resp.StatusCode, Body: readResponseBody(resp)}
	}

	// A single SSE line can be large (e.g. a whole tool call in one chunk),
	// so raise the scanner's max token size from the 64 KiB default to 4 MiB.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024)

	var fullText strings.Builder      // concatenated content deltas
	var reasoningText strings.Builder // concatenated reasoning deltas (thinking models)
	var finishReason string           // last non-nil finish_reason seen
	var model string                  // model name reported in the chunks
	var usage Usage                   // token usage, if a usage chunk is sent
	var toolCalls []FunctionCall      // accumulated tool calls

	for scanner.Scan() {
		line := scanner.Text()
		// Skip blank separator lines and SSE comments (lines starting with ":").
		if line == "" || strings.HasPrefix(line, ":") {
			continue
		}
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		payload := line[6:] // strip the "data: " prefix
		if payload == "[DONE]" {
			break
		}

		var chunk struct {
			Model   string `json:"model"`
			Choices []struct {
				Delta struct {
					Content   string `json:"content"`
					Reasoning string `json:"reasoning"`
					ToolCalls []struct {
						ID       string `json:"id"`
						Type     string `json:"type"`
						Function struct {
							Name      string          `json:"name"`
							Arguments json.RawMessage `json:"arguments"`
						} `json:"function"`
					} `json:"tool_calls"`
				} `json:"delta"`
				FinishReason *string `json:"finish_reason"`
			} `json:"choices"`
			Usage *struct {
				PromptTokens     int `json:"prompt_tokens"`
				CompletionTokens int `json:"completion_tokens"`
				TotalTokens      int `json:"total_tokens"`
			} `json:"usage"`
		}

		// Malformed chunks are skipped rather than aborting the stream.
		if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
			continue
		}

		if chunk.Model != "" {
			model = chunk.Model
		}

		if len(chunk.Choices) > 0 {
			delta := chunk.Choices[0].Delta
			if delta.Content != "" {
				fullText.WriteString(delta.Content)
				if onDelta != nil {
					onDelta(StreamDelta{Text: delta.Content})
				}
			}
			if delta.Reasoning != "" {
				reasoningText.WriteString(delta.Reasoning)
			}
			// Ollama sends tool calls as a complete chunk (not delta-split like OpenAI)
			for _, tc := range delta.ToolCalls {
				args := tc.Function.Arguments
				// Arguments may arrive as a JSON-encoded string; unwrap it
				// to the raw JSON it contains.
				var argsStr string
				if err := json.Unmarshal(args, &argsStr); err == nil {
					args = json.RawMessage(argsStr)
				}
				toolCalls = append(toolCalls, FunctionCall{
					ID:        tc.ID,
					Name:      tc.Function.Name,
					Arguments: args,
				})
			}
			if chunk.Choices[0].FinishReason != nil {
				finishReason = *chunk.Choices[0].FinishReason
			}
		}

		if chunk.Usage != nil {
			usage = Usage{
				InputTokens:  chunk.Usage.PromptTokens,
				OutputTokens: chunk.Usage.CompletionTokens,
				TotalTokens:  chunk.Usage.TotalTokens,
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("stream read error: %w", err)
	}

	output := fullText.String()
	// Thinking models may spend the entire budget on reasoning, leaving the
	// content empty; surface the reasoning so the caller sees something.
	if output == "" && reasoningText.Len() > 0 {
		output = "[thinking] " + reasoningText.String()
	}

	return &CompletionResponse{
		Provider:     "ollama",
		Model:        model,
		OutputText:   output,
		FinishReason: mapFinishReason(finishReason),
		ToolCalls:    toolCalls,
		Usage:        usage,
	}, nil
}
456  
457  // ListModels queries the Ollama instance for available models via /api/tags.
458  func (c *OllamaClient) ListModels(ctx context.Context) ([]OllamaModel, error) {
459  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.endpoint+"/api/tags", nil)
460  	if err != nil {
461  		return nil, fmt.Errorf("create request: %w", err)
462  	}
463  	resp, err := c.httpClient.Do(req)
464  	if err != nil {
465  		return nil, fmt.Errorf("request failed: %w", err)
466  	}
467  	defer resp.Body.Close()
468  	if resp.StatusCode != http.StatusOK {
469  		return nil, fmt.Errorf("ollama returned %d", resp.StatusCode)
470  	}
471  	var result struct {
472  		Models []OllamaModel `json:"models"`
473  	}
474  	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
475  		return nil, fmt.Errorf("decode response: %w", err)
476  	}
477  	return result.Models, nil
478  }
479  
480  // CheckHealth checks if the Ollama instance is reachable.
481  func (c *OllamaClient) CheckHealth(ctx context.Context) error {
482  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.endpoint+"/", nil)
483  	if err != nil {
484  		return err
485  	}
486  	resp, err := c.httpClient.Do(req)
487  	if err != nil {
488  		return fmt.Errorf("ollama not reachable: %w", err)
489  	}
490  	defer resp.Body.Close()
491  	if resp.StatusCode != http.StatusOK {
492  		return fmt.Errorf("ollama returned %d", resp.StatusCode)
493  	}
494  	return nil
495  }