// go/internal/cli/rag.go
// rag: query documents with retrieval-augmented generation.
  1  package cli
  2  
import (
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"unicode/utf8"

	"github.com/TransformerOS/kamaji-go/internal/config"
	"github.com/TransformerOS/kamaji-go/internal/providers"
	"github.com/TransformerOS/kamaji-go/internal/style"
	"github.com/spf13/cobra"
)
 15  
// ragCmd implements `rag [query] [files...]`: answer a query from the
// contents of the given files using keyword-based retrieval followed by
// an LLM call (see runRAG).
var ragCmd = &cobra.Command{
	Use:   "rag [query] [files...]",
	Short: "Query documents using RAG (Retrieval Augmented Generation)",
	Long:  "Query documents using RAG to get contextual answers from your files",
	Args:  cobra.MinimumNArgs(2), // the query plus at least one file
	RunE:  runRAG,
}
 23  
// Document is one loaded input file.
type Document struct {
	Content  string            // full file contents
	Source   string            // base filename, used for source attribution
	Metadata map[string]string // extra info; loadDocuments stores the full path under "path"
}
 29  
// DocumentChunk is a slice of a Document's text used as a retrieval unit.
type DocumentChunk struct {
	Content string // chunk text
	Source  string // base filename of the originating document
	Score   int    // keyword-match score assigned by retrieveRelevant
}
 35  
// init registers the rag command's flags.
// NOTE(review): the "temperature" flag is registered but never read in
// runRAG — confirm whether it should be wired into the provider call.
func init() {
	ragCmd.Flags().Float64P("temperature", "t", 0.7, "Temperature for response generation")
	ragCmd.Flags().IntP("chunks", "k", 3, "Number of relevant chunks to retrieve")
}
 40  
 41  func runRAG(cmd *cobra.Command, args []string) error {
 42  	query := args[0]
 43  	files := args[1:]
 44  
 45  	chunks, _ := cmd.Flags().GetInt("chunks")
 46  
 47  	fmt.Printf("\n%s\n", style.Info("📚 Loading documents..."))
 48  
 49  	// Load documents
 50  	documents, err := loadDocuments(files)
 51  	if err != nil {
 52  		return fmt.Errorf("failed to load documents: %w", err)
 53  	}
 54  
 55  	if len(documents) == 0 {
 56  		return fmt.Errorf("no documents loaded")
 57  	}
 58  
 59  	fmt.Printf("%s\n\n", style.Success(fmt.Sprintf("✓ Loaded %d document(s)", len(documents))))
 60  
 61  	// Split documents into chunks
 62  	allChunks := splitDocuments(documents)
 63  
 64  	// Retrieve relevant chunks
 65  	relevantChunks := retrieveRelevant(query, allChunks, chunks)
 66  
 67  	if len(relevantChunks) == 0 {
 68  		fmt.Printf("%s\n", style.Warning("⚠️  No relevant documents found"))
 69  		fmt.Printf("\nAnswer: I couldn't find relevant information in the provided documents.\n\n")
 70  		return nil
 71  	}
 72  
 73  	// Build context
 74  	context := buildContext(relevantChunks)
 75  
 76  	// Create prompt
 77  	prompt := fmt.Sprintf(`Use the following context to answer the question. If you cannot answer based on the context, say so.
 78  
 79  Context:
 80  %s
 81  
 82  Question: %s
 83  
 84  Answer:`, context, query)
 85  
 86  	// Get LLM and generate response
 87  	cfg, _ := config.Load()
 88  	apiKey := os.Getenv("OPENAI_API_KEY")
 89  	if apiKey == "" {
 90  		return fmt.Errorf("OPENAI_API_KEY environment variable not set")
 91  	}
 92  	llm := providers.NewOpenAIProviderWrapper(apiKey, cfg.Model)
 93  
 94  	fmt.Printf("%s\n", style.Info("💭 Thinking..."))
 95  	fmt.Printf("\nAnswer: ")
 96  
 97  	// Get response
 98  	response, err := llm.Call(cmd.Context(), prompt)
 99  	if err != nil {
100  		return fmt.Errorf("failed to generate response: %w", err)
101  	}
102  
103  	fmt.Print(response)
104  
105  	// Show sources
106  	fmt.Printf("\n\n%s\n", style.Info("📄 Sources:"))
107  	sources := make(map[string]bool)
108  	for _, chunk := range relevantChunks {
109  		if !sources[chunk.Source] {
110  			fmt.Printf("  - %s\n", chunk.Source)
111  			sources[chunk.Source] = true
112  		}
113  	}
114  	fmt.Println()
115  
116  	return nil
117  }
118  
119  func loadDocuments(filePaths []string) ([]Document, error) {
120  	var documents []Document
121  
122  	for _, filePath := range filePaths {
123  		if _, err := os.Stat(filePath); os.IsNotExist(err) {
124  			fmt.Printf("%s\n", style.Warning(fmt.Sprintf("⚠️  Warning: File not found: %s", filePath)))
125  			continue
126  		}
127  
128  		content, err := os.ReadFile(filePath)
129  		if err != nil {
130  			fmt.Printf("%s\n", style.Warning(fmt.Sprintf("⚠️  Warning: Could not read %s: %v", filePath, err)))
131  			continue
132  		}
133  
134  		documents = append(documents, Document{
135  			Content: string(content),
136  			Source:  filepath.Base(filePath),
137  			Metadata: map[string]string{
138  				"path": filePath,
139  			},
140  		})
141  
142  		fmt.Printf("%s\n", style.Success(fmt.Sprintf("✓ Loaded: %s", filepath.Base(filePath))))
143  	}
144  
145  	return documents, nil
146  }
147  
148  func splitDocuments(documents []Document) []DocumentChunk {
149  	var chunks []DocumentChunk
150  	chunkSize := 1000
151  	overlap := 200
152  
153  	for _, doc := range documents {
154  		content := doc.Content
155  		
156  		// Simple text splitting
157  		for i := 0; i < len(content); i += chunkSize - overlap {
158  			end := i + chunkSize
159  			if end > len(content) {
160  				end = len(content)
161  			}
162  			
163  			chunk := content[i:end]
164  			if strings.TrimSpace(chunk) != "" {
165  				chunks = append(chunks, DocumentChunk{
166  					Content: chunk,
167  					Source:  doc.Source,
168  					Score:   0,
169  				})
170  			}
171  			
172  			if end == len(content) {
173  				break
174  			}
175  		}
176  	}
177  
178  	return chunks
179  }
180  
181  func retrieveRelevant(query string, chunks []DocumentChunk, k int) []DocumentChunk {
182  	queryTerms := strings.Fields(strings.ToLower(query))
183  
184  	// Score chunks based on keyword matching
185  	for i := range chunks {
186  		content := strings.ToLower(chunks[i].Content)
187  		score := 0
188  		for _, term := range queryTerms {
189  			score += strings.Count(content, term)
190  		}
191  		chunks[i].Score = score
192  	}
193  
194  	// Sort by score (descending)
195  	sort.Slice(chunks, func(i, j int) bool {
196  		return chunks[i].Score > chunks[j].Score
197  	})
198  
199  	// Return top k chunks with score > 0
200  	var relevant []DocumentChunk
201  	for i, chunk := range chunks {
202  		if i >= k || chunk.Score == 0 {
203  			break
204  		}
205  		relevant = append(relevant, chunk)
206  	}
207  
208  	return relevant
209  }
210  
211  func buildContext(chunks []DocumentChunk) string {
212  	var parts []string
213  	for _, chunk := range chunks {
214  		parts = append(parts, chunk.Content)
215  	}
216  	return strings.Join(parts, "\n\n")
217  }