/ go / internal / cli / ask.go
ask.go
  1  package cli
  2  
  3  import (
  4  	"context"
  5  	"fmt"
  6  	"os"
  7  	"strings"
  8  
  9  	"github.com/spf13/cobra"
 10  	"github.com/TransformerOS/kamaji-go/internal/config"
 11  	"github.com/TransformerOS/kamaji-go/internal/providers"
 12  	"github.com/TransformerOS/kamaji-go/internal/rag"
 13  	"github.com/TransformerOS/kamaji-go/internal/style"
 14  	"github.com/TransformerOS/kamaji-go/internal/types"
 15  )
 16  
 17  var askCmd = &cobra.Command{
 18  	Use:   "ask [question]",
 19  	Short: "Ask a single question",
 20  	Long:  "Ask a one-off question and get an answer",
 21  	RunE:  runAsk,
 22  }
 23  
 24  func init() {
 25  	askCmd.Flags().Float64P("temperature", "t", 0.7, "Temperature for response generation (0.0 - 1.0)")
 26  	askCmd.Flags().Bool("no-stream", false, "Disable streaming output")
 27  	askCmd.Flags().StringSliceP("files", "f", []string{}, "Files to load into context for RAG")
 28  }
 29  
 30  func runAsk(cmd *cobra.Command, args []string) error {
 31  	if len(args) == 0 {
 32  		return fmt.Errorf("please provide a question")
 33  	}
 34  
 35  	question := strings.Join(args, " ")
 36  	noStream, _ := cmd.Flags().GetBool("no-stream")
 37  	files, _ := cmd.Flags().GetStringSlice("files")
 38  
 39  	// Load config
 40  	cfg, err := config.Load()
 41  	if err != nil {
 42  		return fmt.Errorf("failed to load config: %w", err)
 43  	}
 44  
 45  	// Create LLM provider
 46  	var llm types.LLMProvider
 47  	switch cfg.Provider {
 48  	case "ollama":
 49  		llm, err = providers.NewOllamaProvider(cfg.BaseURL, cfg.Model)
 50  		if err != nil {
 51  			return fmt.Errorf("failed to create Ollama provider: %w", err)
 52  		}
 53  	case "anthropic":
 54  		apiKey := os.Getenv("ANTHROPIC_API_KEY")
 55  		if apiKey == "" {
 56  			return fmt.Errorf("ANTHROPIC_API_KEY not set")
 57  		}
 58  		llm, err = providers.NewAnthropicProvider(apiKey, cfg.Model)
 59  		if err != nil {
 60  			return fmt.Errorf("failed to create Anthropic provider: %w", err)
 61  		}
 62  	case "openai":
 63  		apiKey := os.Getenv("OPENAI_API_KEY")
 64  		if apiKey == "" {
 65  			return fmt.Errorf("OPENAI_API_KEY not set")
 66  		}
 67  		llm = providers.NewOpenAIProviderWrapper(apiKey, cfg.Model)
 68  	case "q":
 69  		llm, err = providers.GetQProvider()
 70  		if err != nil {
 71  			return fmt.Errorf("failed to create Q provider: %w", err)
 72  		}
 73  	default:
 74  		return fmt.Errorf("unsupported provider: %s", cfg.Provider)
 75  	}
 76  
 77  	// Load documents if files provided
 78  	var docContext string
 79  	if len(files) > 0 {
 80  		docStore := rag.NewDocumentStore()
 81  		loadedCount := docStore.LoadDocuments(files)
 82  		if loadedCount > 0 {
 83  			fmt.Printf("%s\n\n", style.Success(fmt.Sprintf("📂 Loaded %d document(s) into context", loadedCount)))
 84  			// Query documents with the question to get relevant context
 85  			docContext = docStore.Query(question, 3)
 86  		} else {
 87  			fmt.Printf("%s\n\n", style.Warning("⚠️  Warning: Could not load any documents"))
 88  		}
 89  	}
 90  
 91  	// Build prompt with document context if available
 92  	prompt := question
 93  	if docContext != "" {
 94  		prompt = fmt.Sprintf("Context from documents:\n%s\n\nQuestion: %s", docContext, question)
 95  	}
 96  
 97  	ctx := context.Background()
 98  
 99  	if noStream {
100  		// Non-streaming mode
101  		fmt.Printf("\n%s\n", style.Fire("🔥 Kamaji:"))
102  		response, err := llm.Call(ctx, prompt)
103  		if err != nil {
104  			return fmt.Errorf("LLM call failed: %w", err)
105  		}
106  		fmt.Printf("%s\n\n", response)
107  	} else {
108  		// Streaming mode
109  		fmt.Printf("\n%s ", style.Fire("🔥 Kamaji:"))
110  		responseChan, err := llm.CallStream(ctx, prompt)
111  		if err != nil {
112  			return fmt.Errorf("LLM streaming call failed: %w", err)
113  		}
114  
115  		for chunk := range responseChan {
116  			if chunk.Error != nil {
117  				fmt.Printf("\n%s\n", style.Error(fmt.Sprintf("Error: %v", chunk.Error)))
118  				break
119  			}
120  			fmt.Print(chunk.Content)
121  		}
122  		fmt.Println("\n")
123  	}
124  
125  	return nil
126  }