# langchain_chunker.py
1 from __future__ import annotations 2 3 """LangChain-based chunker implementation.""" 4 import logging 5 from pathlib import Path 6 from typing import Any 7 8 from langchain_core.documents import Document 9 from langchain_text_splitters import ( 10 CharacterTextSplitter, 11 RecursiveCharacterTextSplitter, 12 TokenTextSplitter, 13 ) 14 15 from ..constants import DEFAULT_ENCODING 16 from .constants import ( 17 CHUNK_INDEX_METADATA_KEY, 18 CHUNKING_METHOD_CHARACTER, 19 CHUNKING_METHOD_RECURSIVE, 20 CHUNKING_METHOD_TOKEN, 21 DEFAULT_CHUNK_OVERLAP, 22 DEFAULT_CHUNK_SIZE, 23 RECURSIVE_SEPARATORS, 24 TOTAL_CHUNKS_METADATA_KEY, 25 ) 26 from .protocol import Chunker 27 28 29 class LangChainChunker: 30 """Chunk markdown text or LangChain documents using LangChain splitters. 31 32 This is a concrete implementation of the Chunker protocol using 33 LangChain's text splitter implementations. 34 """ 35 36 def __init__( 37 self, 38 chunk_size: int = DEFAULT_CHUNK_SIZE, 39 chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, 40 method: str = CHUNKING_METHOD_RECURSIVE, 41 ): 42 """Initialize the chunker. 43 44 Parameters 45 ---------- 46 chunk_size 47 Maximum size of each chunk (characters or tokens depending on method). 48 chunk_overlap 49 Number of characters/tokens to overlap between chunks. 50 method 51 Chunking method: "recursive" (default), "character", or "token". 52 - "recursive": Smart splitting that tries to preserve structure 53 - "character": Simple character-based splitting 54 - "token": Token-based splitting (requires tiktoken) 55 """ 56 self.logger = logging.getLogger(__name__) 57 58 self.logger.info( 59 f"Chunking markdown (size={chunk_size}, " 60 f"overlap={chunk_overlap})..." 61 ) 62 63 self.chunk_size = chunk_size 64 self.chunk_overlap = chunk_overlap 65 self.method = method 66 self._splitter = self._create_splitter() 67 68 def _create_splitter(self): 69 """Create the appropriate text splitter based on the chunking method. 
70 71 Returns 72 ------- 73 Text splitter instance (RecursiveCharacterTextSplitter, CharacterTextSplitter, 74 or TokenTextSplitter) configured with chunk_size and chunk_overlap. 75 76 Raises 77 ------ 78 ValueError 79 If the method is not one of the supported methods. 80 """ 81 splitters = { 82 CHUNKING_METHOD_RECURSIVE: ( 83 RecursiveCharacterTextSplitter, 84 {"separators": RECURSIVE_SEPARATORS}, 85 ), 86 CHUNKING_METHOD_CHARACTER: (CharacterTextSplitter, {"separator": "\n\n"}), 87 CHUNKING_METHOD_TOKEN: (TokenTextSplitter, {}), 88 } 89 if self.method not in splitters: 90 raise ValueError( 91 f"Unknown method: {self.method}. Choose from: {', '.join(splitters.keys())}" 92 ) 93 splitter_class, extra_kwargs = splitters[self.method] 94 return splitter_class( 95 chunk_size=self.chunk_size, 96 chunk_overlap=self.chunk_overlap, 97 **extra_kwargs 98 ) 99 100 def chunk_text(self, text: str, metadata: dict | None = None) -> list[Document]: 101 """Chunk a markdown text string into LangChain Documents.""" 102 if not text or not text.strip(): 103 return [] 104 105 temp_doc = Document(page_content=text, metadata=metadata or {}) 106 chunks = self._splitter.split_documents([temp_doc]) 107 108 for i, chunk in enumerate(chunks): 109 chunk.metadata[CHUNK_INDEX_METADATA_KEY] = i 110 chunk.metadata[TOTAL_CHUNKS_METADATA_KEY] = len(chunks) 111 112 return chunks 113 114 def chunk_documents(self, documents: list[Document]) -> list[Document]: 115 """Chunk a list of LangChain Documents into smaller chunks.""" 116 if not documents: 117 return [] 118 119 all_chunks = self._splitter.split_documents(documents) 120 121 for i, chunk in enumerate(all_chunks): 122 chunk.metadata[CHUNK_INDEX_METADATA_KEY] = i 123 chunk.metadata[TOTAL_CHUNKS_METADATA_KEY] = len(all_chunks) 124 125 return all_chunks 126 127 def chunk_markdown_file( 128 self, file_path: str, encoding: str = DEFAULT_ENCODING 129 ) -> list[Document]: 130 """Load a markdown file and chunk it. 
131 132 Parameters 133 ---------- 134 file_path 135 Path to the markdown file. 136 encoding 137 File encoding (default: utf-8). 138 139 Returns 140 ------- 141 List of chunked LangChain Document objects. 142 """ 143 path = Path(file_path) 144 if not path.exists(): 145 raise FileNotFoundError(f"Markdown file not found: {file_path}") 146 147 text = path.read_text(encoding=encoding) 148 metadata = { 149 "source": str(path), 150 "file_name": path.name, 151 } 152 153 return self.chunk_text(text, metadata=metadata) 154 155 156 def create_langchain_chunker(config: dict[str, Any]) -> Chunker: 157 """Create a LangChain chunker from configuration. 158 159 Parameters 160 ---------- 161 config 162 Configuration dictionary with keys: 163 - chunk_size: int (optional) - Size of text chunks in characters 164 - chunk_overlap: int (optional) - Overlap between chunks in characters 165 - method: str (optional) - Chunking method (recursive, character, token) 166 167 Returns 168 ------- 169 Chunker instance. 
170 """ 171 chunk_size = config.get("chunk_size", DEFAULT_CHUNK_SIZE) 172 chunk_overlap = config.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP) 173 method = config.get("method", CHUNKING_METHOD_RECURSIVE) 174 175 if not isinstance(chunk_size, int) or chunk_size <= 0: 176 raise ValueError(f"chunk_size must be a positive integer, got: {chunk_size}") 177 if not isinstance(chunk_overlap, int) or chunk_overlap < 0: 178 raise ValueError(f"chunk_overlap must be a non-negative integer, got: {chunk_overlap}") 179 if chunk_overlap >= chunk_size: 180 raise ValueError(f"chunk_overlap ({chunk_overlap}) must be less than chunk_size ({chunk_size})") 181 if method not in [CHUNKING_METHOD_RECURSIVE, CHUNKING_METHOD_CHARACTER, CHUNKING_METHOD_TOKEN]: 182 raise ValueError( 183 f"method must be one of: {CHUNKING_METHOD_RECURSIVE}, {CHUNKING_METHOD_CHARACTER}, " 184 f"{CHUNKING_METHOD_TOKEN}, got: {method}" 185 ) 186 187 return LangChainChunker( 188 chunk_size=chunk_size, 189 chunk_overlap=chunk_overlap, 190 method=method, 191 )