# src/chunkers/langchain_chunker.py
  1  from __future__ import annotations
  2  
  3  """LangChain-based chunker implementation."""
  4  import logging
  5  from pathlib import Path
  6  from typing import Any
  7  
  8  from langchain_core.documents import Document
  9  from langchain_text_splitters import (
 10      CharacterTextSplitter,
 11      RecursiveCharacterTextSplitter,
 12      TokenTextSplitter,
 13  )
 14  
 15  from ..constants import DEFAULT_ENCODING
 16  from .constants import (
 17      CHUNK_INDEX_METADATA_KEY,
 18      CHUNKING_METHOD_CHARACTER,
 19      CHUNKING_METHOD_RECURSIVE,
 20      CHUNKING_METHOD_TOKEN,
 21      DEFAULT_CHUNK_OVERLAP,
 22      DEFAULT_CHUNK_SIZE,
 23      RECURSIVE_SEPARATORS,
 24      TOTAL_CHUNKS_METADATA_KEY,
 25  )
 26  from .protocol import Chunker
 27  
 28  
 29  class LangChainChunker:
 30      """Chunk markdown text or LangChain documents using LangChain splitters.
 31  
 32      This is a concrete implementation of the Chunker protocol using
 33      LangChain's text splitter implementations.
 34      """
 35  
 36      def __init__(
 37          self,
 38          chunk_size: int = DEFAULT_CHUNK_SIZE,
 39          chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
 40          method: str = CHUNKING_METHOD_RECURSIVE,
 41      ):
 42          """Initialize the chunker.
 43  
 44          Parameters
 45          ----------
 46          chunk_size
 47              Maximum size of each chunk (characters or tokens depending on method).
 48          chunk_overlap
 49              Number of characters/tokens to overlap between chunks.
 50          method
 51              Chunking method: "recursive" (default), "character", or "token".
 52              - "recursive": Smart splitting that tries to preserve structure
 53              - "character": Simple character-based splitting
 54              - "token": Token-based splitting (requires tiktoken)
 55          """
 56          self.logger = logging.getLogger(__name__)
 57  
 58          self.logger.info(
 59              f"Chunking markdown (size={chunk_size}, "
 60              f"overlap={chunk_overlap})..."
 61          )
 62  
 63          self.chunk_size = chunk_size
 64          self.chunk_overlap = chunk_overlap
 65          self.method = method
 66          self._splitter = self._create_splitter()
 67  
 68      def _create_splitter(self):
 69          """Create the appropriate text splitter based on the chunking method.
 70  
 71          Returns
 72          -------
 73          Text splitter instance (RecursiveCharacterTextSplitter, CharacterTextSplitter,
 74          or TokenTextSplitter) configured with chunk_size and chunk_overlap.
 75  
 76          Raises
 77          ------
 78          ValueError
 79              If the method is not one of the supported methods.
 80          """
 81          splitters = {
 82              CHUNKING_METHOD_RECURSIVE: (
 83                  RecursiveCharacterTextSplitter,
 84                  {"separators": RECURSIVE_SEPARATORS},
 85              ),
 86              CHUNKING_METHOD_CHARACTER: (CharacterTextSplitter, {"separator": "\n\n"}),
 87              CHUNKING_METHOD_TOKEN: (TokenTextSplitter, {}),
 88          }
 89          if self.method not in splitters:
 90              raise ValueError(
 91                  f"Unknown method: {self.method}. Choose from: {', '.join(splitters.keys())}"
 92              )
 93          splitter_class, extra_kwargs = splitters[self.method]
 94          return splitter_class(
 95              chunk_size=self.chunk_size,
 96              chunk_overlap=self.chunk_overlap,
 97              **extra_kwargs
 98          )
 99  
100      def chunk_text(self, text: str, metadata: dict | None = None) -> list[Document]:
101          """Chunk a markdown text string into LangChain Documents."""
102          if not text or not text.strip():
103              return []
104  
105          temp_doc = Document(page_content=text, metadata=metadata or {})
106          chunks = self._splitter.split_documents([temp_doc])
107  
108          for i, chunk in enumerate(chunks):
109              chunk.metadata[CHUNK_INDEX_METADATA_KEY] = i
110              chunk.metadata[TOTAL_CHUNKS_METADATA_KEY] = len(chunks)
111  
112          return chunks
113  
114      def chunk_documents(self, documents: list[Document]) -> list[Document]:
115          """Chunk a list of LangChain Documents into smaller chunks."""
116          if not documents:
117              return []
118  
119          all_chunks = self._splitter.split_documents(documents)
120  
121          for i, chunk in enumerate(all_chunks):
122              chunk.metadata[CHUNK_INDEX_METADATA_KEY] = i
123              chunk.metadata[TOTAL_CHUNKS_METADATA_KEY] = len(all_chunks)
124  
125          return all_chunks
126  
127      def chunk_markdown_file(
128          self, file_path: str, encoding: str = DEFAULT_ENCODING
129      ) -> list[Document]:
130          """Load a markdown file and chunk it.
131  
132          Parameters
133          ----------
134          file_path
135              Path to the markdown file.
136          encoding
137              File encoding (default: utf-8).
138  
139          Returns
140          -------
141          List of chunked LangChain Document objects.
142          """
143          path = Path(file_path)
144          if not path.exists():
145              raise FileNotFoundError(f"Markdown file not found: {file_path}")
146  
147          text = path.read_text(encoding=encoding)
148          metadata = {
149              "source": str(path),
150              "file_name": path.name,
151          }
152  
153          return self.chunk_text(text, metadata=metadata)
154  
155  
def create_langchain_chunker(config: dict[str, Any]) -> Chunker:
    """Build a LangChain chunker from a configuration mapping.

    Parameters
    ----------
    config
        Configuration dictionary with optional keys:
        - chunk_size: int - Size of text chunks in characters
        - chunk_overlap: int - Overlap between chunks in characters
        - method: str - Chunking method (recursive, character, token)

    Returns
    -------
    Chunker instance.

    Raises
    ------
    ValueError
        If any configuration value fails validation.
    """
    size = config.get("chunk_size", DEFAULT_CHUNK_SIZE)
    overlap = config.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP)
    method = config.get("method", CHUNKING_METHOD_RECURSIVE)

    # Guard clauses: reject bad values before constructing anything.
    if not (isinstance(size, int) and size > 0):
        raise ValueError(f"chunk_size must be a positive integer, got: {size}")
    if not (isinstance(overlap, int) and overlap >= 0):
        raise ValueError(f"chunk_overlap must be a non-negative integer, got: {overlap}")
    if not overlap < size:
        raise ValueError(f"chunk_overlap ({overlap}) must be less than chunk_size ({size})")

    allowed_methods = (
        CHUNKING_METHOD_RECURSIVE,
        CHUNKING_METHOD_CHARACTER,
        CHUNKING_METHOD_TOKEN,
    )
    if method not in allowed_methods:
        raise ValueError(
            f"method must be one of: {CHUNKING_METHOD_RECURSIVE}, {CHUNKING_METHOD_CHARACTER}, "
            f"{CHUNKING_METHOD_TOKEN}, got: {method}"
        )

    return LangChainChunker(
        chunk_size=size,
        chunk_overlap=overlap,
        method=method,
    )