// rag.go
1 package cli 2 3 import ( 4 "fmt" 5 "os" 6 "path/filepath" 7 "sort" 8 "strings" 9 10 "github.com/spf13/cobra" 11 "github.com/TransformerOS/kamaji-go/internal/config" 12 "github.com/TransformerOS/kamaji-go/internal/providers" 13 "github.com/TransformerOS/kamaji-go/internal/style" 14 ) 15 16 var ragCmd = &cobra.Command{ 17 Use: "rag [query] [files...]", 18 Short: "Query documents using RAG (Retrieval Augmented Generation)", 19 Long: "Query documents using RAG to get contextual answers from your files", 20 Args: cobra.MinimumNArgs(2), 21 RunE: runRAG, 22 } 23 24 type Document struct { 25 Content string 26 Source string 27 Metadata map[string]string 28 } 29 30 type DocumentChunk struct { 31 Content string 32 Source string 33 Score int 34 } 35 36 func init() { 37 ragCmd.Flags().Float64P("temperature", "t", 0.7, "Temperature for response generation") 38 ragCmd.Flags().IntP("chunks", "k", 3, "Number of relevant chunks to retrieve") 39 } 40 41 func runRAG(cmd *cobra.Command, args []string) error { 42 query := args[0] 43 files := args[1:] 44 45 chunks, _ := cmd.Flags().GetInt("chunks") 46 47 fmt.Printf("\n%s\n", style.Info("📚 Loading documents...")) 48 49 // Load documents 50 documents, err := loadDocuments(files) 51 if err != nil { 52 return fmt.Errorf("failed to load documents: %w", err) 53 } 54 55 if len(documents) == 0 { 56 return fmt.Errorf("no documents loaded") 57 } 58 59 fmt.Printf("%s\n\n", style.Success(fmt.Sprintf("✓ Loaded %d document(s)", len(documents)))) 60 61 // Split documents into chunks 62 allChunks := splitDocuments(documents) 63 64 // Retrieve relevant chunks 65 relevantChunks := retrieveRelevant(query, allChunks, chunks) 66 67 if len(relevantChunks) == 0 { 68 fmt.Printf("%s\n", style.Warning("⚠️ No relevant documents found")) 69 fmt.Printf("\nAnswer: I couldn't find relevant information in the provided documents.\n\n") 70 return nil 71 } 72 73 // Build context 74 context := buildContext(relevantChunks) 75 76 // Create prompt 77 prompt := 
fmt.Sprintf(`Use the following context to answer the question. If you cannot answer based on the context, say so. 78 79 Context: 80 %s 81 82 Question: %s 83 84 Answer:`, context, query) 85 86 // Get LLM and generate response 87 cfg, _ := config.Load() 88 apiKey := os.Getenv("OPENAI_API_KEY") 89 if apiKey == "" { 90 return fmt.Errorf("OPENAI_API_KEY environment variable not set") 91 } 92 llm := providers.NewOpenAIProviderWrapper(apiKey, cfg.Model) 93 94 fmt.Printf("%s\n", style.Info("💭 Thinking...")) 95 fmt.Printf("\nAnswer: ") 96 97 // Get response 98 response, err := llm.Call(cmd.Context(), prompt) 99 if err != nil { 100 return fmt.Errorf("failed to generate response: %w", err) 101 } 102 103 fmt.Print(response) 104 105 // Show sources 106 fmt.Printf("\n\n%s\n", style.Info("📄 Sources:")) 107 sources := make(map[string]bool) 108 for _, chunk := range relevantChunks { 109 if !sources[chunk.Source] { 110 fmt.Printf(" - %s\n", chunk.Source) 111 sources[chunk.Source] = true 112 } 113 } 114 fmt.Println() 115 116 return nil 117 } 118 119 func loadDocuments(filePaths []string) ([]Document, error) { 120 var documents []Document 121 122 for _, filePath := range filePaths { 123 if _, err := os.Stat(filePath); os.IsNotExist(err) { 124 fmt.Printf("%s\n", style.Warning(fmt.Sprintf("⚠️ Warning: File not found: %s", filePath))) 125 continue 126 } 127 128 content, err := os.ReadFile(filePath) 129 if err != nil { 130 fmt.Printf("%s\n", style.Warning(fmt.Sprintf("⚠️ Warning: Could not read %s: %v", filePath, err))) 131 continue 132 } 133 134 documents = append(documents, Document{ 135 Content: string(content), 136 Source: filepath.Base(filePath), 137 Metadata: map[string]string{ 138 "path": filePath, 139 }, 140 }) 141 142 fmt.Printf("%s\n", style.Success(fmt.Sprintf("✓ Loaded: %s", filepath.Base(filePath)))) 143 } 144 145 return documents, nil 146 } 147 148 func splitDocuments(documents []Document) []DocumentChunk { 149 var chunks []DocumentChunk 150 chunkSize := 1000 151 overlap := 
200 152 153 for _, doc := range documents { 154 content := doc.Content 155 156 // Simple text splitting 157 for i := 0; i < len(content); i += chunkSize - overlap { 158 end := i + chunkSize 159 if end > len(content) { 160 end = len(content) 161 } 162 163 chunk := content[i:end] 164 if strings.TrimSpace(chunk) != "" { 165 chunks = append(chunks, DocumentChunk{ 166 Content: chunk, 167 Source: doc.Source, 168 Score: 0, 169 }) 170 } 171 172 if end == len(content) { 173 break 174 } 175 } 176 } 177 178 return chunks 179 } 180 181 func retrieveRelevant(query string, chunks []DocumentChunk, k int) []DocumentChunk { 182 queryTerms := strings.Fields(strings.ToLower(query)) 183 184 // Score chunks based on keyword matching 185 for i := range chunks { 186 content := strings.ToLower(chunks[i].Content) 187 score := 0 188 for _, term := range queryTerms { 189 score += strings.Count(content, term) 190 } 191 chunks[i].Score = score 192 } 193 194 // Sort by score (descending) 195 sort.Slice(chunks, func(i, j int) bool { 196 return chunks[i].Score > chunks[j].Score 197 }) 198 199 // Return top k chunks with score > 0 200 var relevant []DocumentChunk 201 for i, chunk := range chunks { 202 if i >= k || chunk.Score == 0 { 203 break 204 } 205 relevant = append(relevant, chunk) 206 } 207 208 return relevant 209 } 210 211 func buildContext(chunks []DocumentChunk) string { 212 var parts []string 213 for _, chunk := range chunks { 214 parts = append(parts, chunk.Content) 215 } 216 return strings.Join(parts, "\n\n") 217 }