# haystack/components/preprocessors/recursive_splitter.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import re
from copy import deepcopy
from typing import Any, Literal

from haystack import Document, component, logging
from haystack.lazy_imports import LazyImport

with LazyImport("Run 'pip install tiktoken'") as tiktoken_imports:
    import tiktoken

logger = logging.getLogger(__name__)


@component
class RecursiveDocumentSplitter:
    """
    Recursively chunk text into smaller chunks.

    This component splits text into smaller chunks by recursively applying a list of separators to the text.

    The separators are applied in the order they are provided, typically ordered from the most general to the
    most specific separator.

    Each separator is applied to the text and the resulting chunks are checked: chunks within the split_length
    are kept, while the next separator in the list is applied to any chunk that is still larger than the
    split_length.

    This continues until all chunks are no longer than the split_length parameter.
    Example:

    ```python
    from haystack import Document
    from haystack.components.preprocessors import RecursiveDocumentSplitter

    chunker = RecursiveDocumentSplitter(split_length=260, split_overlap=0, separators=["\\n\\n", "\\n", ".", " "])
    text = ('''Artificial intelligence (AI) - Introduction

    AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.
    AI technology is widely used throughout industry, government, and science. Some high-profile applications include advanced web search engines; recommendation systems; interacting via human speech; autonomous vehicles; generative and creative tools; and superhuman play and analysis in strategy games.''')
    doc = Document(content=text)
    doc_chunks = chunker.run([doc])
    print(doc_chunks["documents"])
    # [
    # Document(id=..., content: 'Artificial intelligence (AI) - Introduction\\n\\n', meta: {'parent_id': '...', 'split_id': 0, 'split_idx_start': 0, '_split_overlap': []})
    # Document(id=..., content: 'AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.\\n', meta: {'parent_id': '...', 'split_id': 1, 'split_idx_start': 45, '_split_overlap': []})
    # Document(id=..., content: 'AI technology is widely used throughout industry, government, and science.', meta: {'parent_id': '...', 'split_id': 2, 'split_idx_start': 142, '_split_overlap': []})
    # Document(id=..., content: ' Some high-profile applications include advanced web search engines; recommendation systems; interac...', meta: {'parent_id': '...', 'split_id': 3, 'split_idx_start': 216, '_split_overlap': []})
    # ]
    ```
    """  # noqa: E501

    def __init__(
        self,
        *,
        split_length: int = 200,
        split_overlap: int = 0,
        split_unit: Literal["word", "char", "token"] = "word",
        separators: list[str] | None = None,
        sentence_splitter_params: dict[str, Any] | None = None,
    ) -> None:
        """
        Initializes a RecursiveDocumentSplitter.

        :param split_length: The maximum length of each chunk, by default in words, but it can also be in
            characters or tokens. See the `split_unit` parameter.
        :param split_overlap: The number of units (words, characters, or tokens) to overlap between consecutive
            chunks.
        :param split_unit: The unit of the split_length parameter. It can be either "word", "char", or "token".
            If "token" is selected, the text will be split into tokens using the tiktoken tokenizer (o200k_base).
        :param separators: An optional list of separator strings to use for splitting the text. The string
            separators will be treated as regular expressions unless the separator is "sentence", in which case
            the text will be split into sentences using a custom sentence tokenizer based on NLTK.
            See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter.
            If no separators are provided, the default separators ["\\n\\n", "sentence", "\\n", " "] are used.
        :param sentence_splitter_params: Optional parameters to pass to the sentence tokenizer.
            See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter for more information.

        :raises ValueError: If the overlap is greater than or equal to the chunk size, if the overlap is negative,
            or if any separator is not a string.
        """
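        # Illustrative sketch (hypothetical values): RecursiveDocumentSplitter(split_length=50,
        # split_overlap=10, split_unit="char") keeps chunks at most 50 characters long, with the last
        # 10 characters of each chunk repeated at the start of the next one.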
        self.split_length = split_length
        self.split_overlap = split_overlap
        self.split_units = split_unit
        self.separators = separators if separators else ["\n\n", "sentence", "\n", " "]  # default separators
        self._check_params()
        self.nltk_tokenizer = None
        self.sentence_splitter_params = (
            {"keep_white_spaces": True} if sentence_splitter_params is None else sentence_splitter_params
        )
        self.tiktoken_tokenizer: "tiktoken.Encoding" | None = None
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the sentence tokenizer and tiktoken tokenizer if needed.
        """
        if "sentence" in self.separators:
            self.nltk_tokenizer = self._get_custom_sentence_tokenizer(self.sentence_splitter_params)
        if self.split_units == "token":
            tiktoken_imports.check()
            self.tiktoken_tokenizer = tiktoken.get_encoding("o200k_base")
        self._is_warmed_up = True

    def _check_params(self) -> None:
        if self.split_length < 1:
            raise ValueError("Split length must be at least 1.")
        if self.split_overlap < 0:
            raise ValueError("Overlap must be greater than or equal to zero.")
        if self.split_overlap >= self.split_length:
            raise ValueError("Overlap cannot be greater than or equal to the chunk size.")
        if not all(isinstance(separator, str) for separator in self.separators):
            raise ValueError("All separators must be strings.")

    @staticmethod
    def _get_custom_sentence_tokenizer(sentence_splitter_params: dict[str, Any]) -> Any:
        from haystack.components.preprocessors.sentence_tokenizer import SentenceSplitter

        return SentenceSplitter(**sentence_splitter_params)

    def _split_chunk(self, current_chunk: str) -> tuple[str, str]:
        """
        Splits a chunk based on the split_length and split_units attributes.

        :param current_chunk: The current chunk to be split.
        :returns:
            A tuple containing the current chunk and the remaining chunk.
        """
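        # Illustrative sketch: with split_units="word" and split_length=3,
        # "one two three four five" -> ("one two three", "four five").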
        if self.split_units == "word":
            words = current_chunk.split()
            current_chunk = " ".join(words[: self.split_length])
            remaining_words = words[self.split_length :]
            return current_chunk, " ".join(remaining_words)
        if self.split_units == "char":
            text = current_chunk
            current_chunk = text[: self.split_length]
            remaining_chars = text[self.split_length :]
            return current_chunk, remaining_chars

        # at this point we know that the tokenizer is already initialized
        tokens = self.tiktoken_tokenizer.encode(current_chunk)  # type: ignore
        current_tokens = tokens[: self.split_length]
        remaining_tokens = tokens[self.split_length :]
        return self.tiktoken_tokenizer.decode(current_tokens), self.tiktoken_tokenizer.decode(remaining_tokens)  # type: ignore

    def _apply_overlap(self, chunks: list[str]) -> list[str]:
        """
        Applies an overlap between consecutive chunks if the split_overlap attribute is greater than zero.

        Works for word-, character-, and token-level splitting. It trims the last chunk if it exceeds the
        split_length and adds the trimmed content to the next chunk. If the last chunk is still too long after
        trimming, it splits it and adds the first chunk to the list. This process continues until the last chunk
        is within the split_length.

        :param chunks: A list of text chunks.
        :returns:
            A list of text chunks with the overlap applied.
        """
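        # Illustrative sketch: with split_units="word", split_length=4, and split_overlap=2, the last two
        # words of each chunk are prepended to the next one: ["a b c d", "e f"] -> ["a b c d", "c d e f"].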
        overlapped_chunks: list[str] = []

        for idx, chunk in enumerate(chunks):
            if idx == 0:
                overlapped_chunks.append(chunk)
                continue

            # get the overlap between the current and previous chunk
            overlap, prev_chunk = self._get_overlap(overlapped_chunks)
            if overlap == prev_chunk:
                logger.warning(
                    "Overlap is the same as the previous chunk. "
                    "Consider increasing the `split_length` parameter or decreasing the `split_overlap` parameter."
                )

            current_chunk = self._create_chunk_starting_with_overlap(chunk, overlap)

            # if this new chunk exceeds 'split_length', trim it and move the remaining text to the next chunk
            # if this is the last chunk, another new chunk will contain the trimmed text preceded by the overlap
            # of the last chunk
            if self._chunk_length(current_chunk) > self.split_length:
                current_chunk, remaining_text = self._split_chunk(current_chunk)
                if idx < len(chunks) - 1:
                    if self.split_units == "word":
                        chunks[idx + 1] = remaining_text + " " + chunks[idx + 1]
                    elif self.split_units == "token":
                        # For token-based splitting, combine at token level
                        # at this point we know that the tokenizer is already initialized
                        remaining_tokens = self.tiktoken_tokenizer.encode(remaining_text)  # type: ignore
                        next_chunk_tokens = self.tiktoken_tokenizer.encode(chunks[idx + 1])  # type: ignore
                        chunks[idx + 1] = self.tiktoken_tokenizer.decode(remaining_tokens + next_chunk_tokens)  # type: ignore
                    else:  # char
                        chunks[idx + 1] = remaining_text + chunks[idx + 1]
                elif remaining_text:
                    # create a new chunk with the trimmed text preceded by the overlap of the last chunk
                    overlapped_chunks.append(current_chunk)
                    chunk = remaining_text
                    overlap, _ = self._get_overlap(overlapped_chunks)
                    current_chunk = self._create_chunk_starting_with_overlap(chunk, overlap)

            overlapped_chunks.append(current_chunk)

            # it can still be that the new last chunk exceeds the 'split_length'
            # continue splitting until the last chunk is within 'split_length'
            if idx == len(chunks) - 1 and self._chunk_length(current_chunk) > self.split_length:
                last_chunk = overlapped_chunks.pop()
                first_chunk, remaining_chunk = self._split_chunk(last_chunk)
                overlapped_chunks.append(first_chunk)

                while remaining_chunk:
                    # combine overlap with remaining chunk
                    overlap, _ = self._get_overlap(overlapped_chunks)
                    current = self._create_chunk_starting_with_overlap(remaining_chunk, overlap)

                    # if it fits within split_length we are done
                    if self._chunk_length(current) <= self.split_length:
                        overlapped_chunks.append(current)
                        break

                    # otherwise split it again
                    first_chunk, remaining_chunk = self._split_chunk(current)
                    overlapped_chunks.append(first_chunk)

        return overlapped_chunks

    def _create_chunk_starting_with_overlap(self, chunk: str, overlap: str) -> str:
        """Prepends the overlap to the chunk, joining at the word, token, or character level."""
        if self.split_units == "word":
            current_chunk = overlap + " " + chunk
        elif self.split_units == "token":
            # For token-based splitting, combine at token level
            # at this point we know that the tokenizer is already initialized
            overlap_tokens = self.tiktoken_tokenizer.encode(overlap)  # type: ignore
            chunk_tokens = self.tiktoken_tokenizer.encode(chunk)  # type: ignore
            current_chunk = self.tiktoken_tokenizer.decode(overlap_tokens + chunk_tokens)  # type: ignore
        else:  # char
            current_chunk = overlap + chunk
        return current_chunk

    def _get_overlap(self, overlapped_chunks: list[str]) -> tuple[str, str]:
        """Get the overlap from the previous overlapped chunk instead of from the original chunk."""
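        # Illustrative sketch: with split_units="word" and split_overlap=2, a previous chunk "a b c d"
        # yields the overlap "c d".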
        prev_chunk = overlapped_chunks[-1]
        overlap_start = max(0, self._chunk_length(prev_chunk) - self.split_overlap)

        if self.split_units == "word":
            word_chunks = prev_chunk.split()
            overlap = " ".join(word_chunks[overlap_start:])
        elif self.split_units == "token":
            # For token-based splitting, handle overlap at token level
            # at this point we know that the tokenizer is already initialized
            tokens = self.tiktoken_tokenizer.encode(prev_chunk)  # type: ignore
            overlap_tokens = tokens[overlap_start:]
            overlap = self.tiktoken_tokenizer.decode(overlap_tokens)  # type: ignore
        else:  # char
            overlap = prev_chunk[overlap_start:]

        return overlap, prev_chunk

    def _chunk_length(self, text: str) -> int:
        """
        Get the length of the chunk in the specified units (words, characters, or tokens).

        :param text: The text to measure.
        :returns: The length of the text in the specified units.
        """
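        # Illustrative sketch: "hello  world" has length 2 with split_units="word" and 12 with "char".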
        if self.split_units == "word":
            words = [word for word in text.split(" ") if word]
            return len(words)
        if self.split_units == "char":
            return len(text)
        # token
        # at this point we know that the tokenizer is already initialized
        return len(self.tiktoken_tokenizer.encode(text))  # type: ignore

    def _chunk_text(self, text: str) -> list[str]:
        """
        Recursive chunking algorithm that divides text into smaller chunks based on a list of separator characters.

        It starts with a list of separator characters (e.g., ["\\n\\n", "sentence", "\\n", " "]) and attempts to
        divide the text using the first separator. If the resulting chunks are still larger than the specified
        chunk size, it moves to the next separator in the list. This process continues recursively, applying
        progressively more specific separators until the chunks meet the desired size criteria.

        :param text: The text to be split into chunks.
        :returns:
            A list of text chunks.
        """
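        # Illustrative sketch of the recursion with the default separators: the text is first split on
        # "\n\n"; any paragraph still longer than split_length is split into sentences, then on "\n",
        # then on " ", and finally falls back to fixed-size chunking if the last separator is not enough.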
        if self._chunk_length(text) <= self.split_length:
            return [text]

        for curr_separator in self.separators:
            if curr_separator == "sentence":
                # the sentence tokenizer is initialized in warm_up(), hence the type: ignore below
                sentence_with_spans = self.nltk_tokenizer.split_sentences(text)  # type: ignore
                splits = [sentence["sentence"] for sentence in sentence_with_spans]
            else:
                # escape the separator and wrap it in a group so that it's included in the splits as well
                escaped_separator = re.escape(curr_separator)
                escaped_separator = f"({escaped_separator})"

                # split the text and merge every two consecutive splits, i.e.: the text and the separator after it
                splits = re.split(escaped_separator, text)
                splits = [
                    "".join([splits[i], splits[i + 1]]) if i < len(splits) - 1 else splits[i]
                    for i in range(0, len(splits), 2)
                ]

                # remove last split if it's empty
                splits = splits[:-1] if splits[-1] == "" else splits

            if len(splits) == 1:  # go to next separator, if current separator not found in the text
                continue

            chunks = []
            current_chunk: list[str] = []
            current_length = 0

            # check splits, if any is too long, recursively chunk it, otherwise add to current chunk
            for split in splits:
                split_text = split

                # if adding this split exceeds chunk_size, process current_chunk
                if current_length + self._chunk_length(split_text) > self.split_length:
                    # process current_chunk
                    if current_chunk:  # keep the good splits
                        chunks.append("".join(current_chunk))
                        current_chunk = []
                        current_length = 0

                    # recursively handle splits that are too large
                    if self._chunk_length(split_text) > self.split_length:
                        if curr_separator == self.separators[-1]:
                            # tried last separator, can't split further, do a fixed-split based on word/character/token
                            fall_back_chunks = self._fall_back_to_fixed_chunking(split_text, self.split_units)
                            chunks.extend(fall_back_chunks)
                        else:
                            chunks.extend(self._chunk_text(split_text))
                    else:
                        current_chunk.append(split_text)
                        current_length += self._chunk_length(split_text)
                else:
                    current_chunk.append(split_text)
                    current_length += self._chunk_length(split_text)

            if current_chunk:
                chunks.append("".join(current_chunk))

            if self.split_overlap > 0:
                chunks = self._apply_overlap(chunks)

            if chunks:
                return chunks

        # if no separator worked, fall back to fixed-size chunking by word, character, or token
        return self._fall_back_to_fixed_chunking(text, self.split_units)

    def _fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "char", "token"]) -> list[str]:
        """
        Fall back to a fixed chunking approach if no separator works for the text.

        Splits the text into smaller chunks based on the split_length and split_units attributes, either by words,
        characters, or tokens.

        :param text: The text to be split into chunks.
        :param split_units: The unit of the split_length parameter. It can be either "word", "char", or "token".
        :returns:
            A list of text chunks.
        """
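        # Illustrative sketch: with split_units="word" and split_length=2, "a b c d e" becomes
        # ["a b", " c d", " e"]; whitespace is preserved, so later chunks keep their leading space.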
        chunks = []

        if split_units == "word":
            words = re.findall(r"\S+|\s+", text)
            current_chunk = []
            current_length = 0

            for word in words:
                if word != " ":
                    current_chunk.append(word)
                    current_length += 1
                    if current_length == self.split_length and current_chunk:
                        chunks.append("".join(current_chunk))
                        current_chunk = []
                        current_length = 0
                else:
                    current_chunk.append(word)

            if current_chunk:
                chunks.append("".join(current_chunk))
        elif split_units == "char":
            for i in range(0, self._chunk_length(text), self.split_length):
                chunks.append(text[i : i + self.split_length])
        else:  # token
            # at this point we know that the tokenizer is already initialized
            tokens = self.tiktoken_tokenizer.encode(text)  # type: ignore
            for i in range(0, len(tokens), self.split_length):
                chunk_tokens = tokens[i : i + self.split_length]
                chunks.append(self.tiktoken_tokenizer.decode(chunk_tokens))  # type: ignore
        return chunks

    def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: list[Document]) -> None:
        """Records the overlap range between the new chunk document and the previous one in both documents' meta."""
        prev_doc = new_docs[-1]
        overlap_length = self._chunk_length(prev_doc.content) - (curr_pos - prev_doc.meta["split_idx_start"])  # type: ignore
        if overlap_length > 0:
            prev_doc.meta["_split_overlap"].append({"doc_id": new_doc.id, "range": (0, overlap_length)})
            new_doc.meta["_split_overlap"].append(
                {
                    "doc_id": prev_doc.id,
                    "range": (
                        self._chunk_length(prev_doc.content) - overlap_length,  # type: ignore
                        self._chunk_length(prev_doc.content),  # type: ignore
                    ),
                }
            )

    def _run_one(self, doc: Document) -> list[Document]:
        chunks = self._chunk_text(doc.content)  # type: ignore # the caller already checked for non-empty doc.content
        chunks = chunks[:-1] if len(chunks[-1]) == 0 else chunks  # remove last empty chunk if it exists
        current_position = 0
        current_page = 1

        new_docs: list[Document] = []

        for split_nr, chunk in enumerate(chunks):
            meta = deepcopy(doc.meta)
            meta["parent_id"] = doc.id
            meta["split_id"] = split_nr
            meta["split_idx_start"] = current_position
            meta["_split_overlap"] = [] if self.split_overlap > 0 else None
            new_doc = Document(content=chunk, meta=meta)

            # add overlap information to the previous and current doc
            if split_nr > 0 and self.split_overlap > 0:
                self._add_overlap_info(current_position, new_doc, new_docs)

            # count page breaks in the chunk
            current_page += chunk.count("\f")

            # if there are consecutive page breaks at the end with no more text, adjust the page number
            # e.g.: "text\f\f\f" -> 3 page breaks, but current_page should be 1
            consecutive_page_breaks = len(chunk) - len(chunk.rstrip("\f"))

            if consecutive_page_breaks > 0:
                new_doc.meta["page_number"] = current_page - consecutive_page_breaks
            else:
                new_doc.meta["page_number"] = current_page

            # keep the new chunk doc and update the current position
            new_docs.append(new_doc)
            current_position += len(chunk) - (self.split_overlap if split_nr < len(chunks) - 1 else 0)

        return new_docs

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document]) -> dict[str, list[Document]]:
        """
        Split a list of documents into documents with smaller chunks of text.

        :param documents: List of Documents to split.
        :returns:
            A dictionary containing a key "documents" with a List of Documents with smaller chunks of text
            corresponding to the input documents.
        """
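        # Usage sketch: result = splitter.run([Document(content="...")]); result["documents"] holds the
        # chunked Documents. warm_up() runs automatically below when "sentence" separators or token units
        # are used without an explicit warm-up.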
        if not self._is_warmed_up and ("sentence" in self.separators or self.split_units == "token"):
            self.warm_up()

        docs = []
        for doc in documents:
            if not doc.content:
                logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
                continue
            docs.extend(self._run_one(doc))

        return {"documents": docs}