recursive_splitter.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 import re 6 from copy import deepcopy 7 from typing import Any, Literal 8 9 from haystack import Document, component, logging 10 from haystack.lazy_imports import LazyImport 11 12 with LazyImport("Run 'pip install tiktoken'") as tiktoken_imports: 13 import tiktoken 14 15 logger = logging.getLogger(__name__) 16 17 18 @component 19 class RecursiveDocumentSplitter: 20 """ 21 Recursively chunk text into smaller chunks. 22 23 This component is used to split text into smaller chunks, it does so by recursively applying a list of separators 24 to the text. 25 26 The separators are applied in the order they are provided, typically this is a list of separators that are 27 applied in a specific order, being the last separator the most specific one. 28 29 Each separator is applied to the text, it then checks each of the resulting chunks, it keeps the chunks that 30 are within the split_length, for the ones that are larger than the split_length, it applies the next separator in the 31 list to the remaining text. 32 33 This is done until all chunks are smaller than the split_length parameter. 34 35 Example: 36 37 ```python 38 from haystack import Document 39 from haystack.components.preprocessors import RecursiveDocumentSplitter 40 41 chunker = RecursiveDocumentSplitter(split_length=260, split_overlap=0, separators=["\\n\\n", "\\n", ".", " "]) 42 text = ('''Artificial intelligence (AI) - Introduction 43 44 AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems. 45 AI technology is widely used throughout industry, government, and science. Some high-profile applications include advanced web search engines; recommendation systems; interacting via human speech; autonomous vehicles; generative and creative tools; and superhuman play and analysis in strategy games.''') 46 doc = Document(content=text) 47 doc_chunks = chunker.run([doc]) 48 print(doc_chunks["documents"]) 49 # [ 50 # Document(id=..., content: 'Artificial intelligence (AI) - Introduction\\n\\n', meta: {'original_id': '...', 'split_id': 0, 'split_idx_start': 0, '_split_overlap': []}) 51 # Document(id=..., content: 'AI, in its broadest sense, is intelligence exhibited by machines, particularly computer systems.\\n', meta: {'original_id': '...', 'split_id': 1, 'split_idx_start': 45, '_split_overlap': []}) 52 # Document(id=..., content: 'AI technology is widely used throughout industry, government, and science.', meta: {'original_id': '...', 'split_id': 2, 'split_idx_start': 142, '_split_overlap': []}) 53 # Document(id=..., content: ' Some high-profile applications include advanced web search engines; recommendation systems; interac...', meta: {'original_id': '...', 'split_id': 3, 'split_idx_start': 216, '_split_overlap': []}) 54 # ] 55 ``` 56 """ # noqa: E501 57 58 def __init__( 59 self, 60 *, 61 split_length: int = 200, 62 split_overlap: int = 0, 63 split_unit: Literal["word", "char", "token"] = "word", 64 separators: list[str] | None = None, 65 sentence_splitter_params: dict[str, Any] | None = None, 66 ) -> None: 67 """ 68 Initializes a RecursiveDocumentSplitter. 69 70 :param split_length: The maximum length of each chunk by default in words, but can be in characters or tokens. 71 See the `split_units` parameter. 72 :param split_overlap: The number of characters to overlap between consecutive chunks. 73 :param split_unit: The unit of the split_length parameter. It can be either "word", "char", or "token". 74 If "token" is selected, the text will be split into tokens using the tiktoken tokenizer (o200k_base). 75 :param separators: An optional list of separator strings to use for splitting the text. The string 76 separators will be treated as regular expressions unless the separator is "sentence", in that case the 77 text will be split into sentences using a custom sentence tokenizer based on NLTK. 78 See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter. 79 If no separators are provided, the default separators ["\\n\\n", "sentence", "\\n", " "] are used. 80 :param sentence_splitter_params: Optional parameters to pass to the sentence tokenizer. 81 See: haystack.components.preprocessors.sentence_tokenizer.SentenceSplitter for more information. 82 83 :raises ValueError: If the overlap is greater than or equal to the chunk size or if the overlap is negative, or 84 if any separator is not a string. 85 """ 86 self.split_length = split_length 87 self.split_overlap = split_overlap 88 self.split_units = split_unit 89 self.separators = separators if separators else ["\n\n", "sentence", "\n", " "] # default separators 90 self._check_params() 91 self.nltk_tokenizer = None 92 self.sentence_splitter_params = ( 93 {"keep_white_spaces": True} if sentence_splitter_params is None else sentence_splitter_params 94 ) 95 self.tiktoken_tokenizer: "tiktoken.Encoding" | None = None 96 self._is_warmed_up = False 97 98 def warm_up(self) -> None: 99 """ 100 Warm up the sentence tokenizer and tiktoken tokenizer if needed. 101 """ 102 if "sentence" in self.separators: 103 self.nltk_tokenizer = self._get_custom_sentence_tokenizer(self.sentence_splitter_params) 104 if self.split_units == "token": 105 tiktoken_imports.check() 106 self.tiktoken_tokenizer = tiktoken.get_encoding("o200k_base") 107 self._is_warmed_up = True 108 109 def _check_params(self) -> None: 110 if self.split_length < 1: 111 raise ValueError("Split length must be at least 1 character.") 112 if self.split_overlap < 0: 113 raise ValueError("Overlap must be greater than zero.") 114 if self.split_overlap >= self.split_length: 115 raise ValueError("Overlap cannot be greater than or equal to the chunk size.") 116 if not all(isinstance(separator, str) for separator in self.separators): 117 raise ValueError("All separators must be strings.") 118 119 @staticmethod 120 def _get_custom_sentence_tokenizer(sentence_splitter_params: dict[str, Any]) -> Any: 121 from haystack.components.preprocessors.sentence_tokenizer import SentenceSplitter 122 123 return SentenceSplitter(**sentence_splitter_params) 124 125 def _split_chunk(self, current_chunk: str) -> tuple[str, str]: 126 """ 127 Splits a chunk based on the split_length and split_units attribute. 128 129 :param current_chunk: The current chunk to be split. 130 :returns: 131 A tuple containing the current chunk and the remaining chunk. 132 """ 133 if self.split_units == "word": 134 words = current_chunk.split() 135 current_chunk = " ".join(words[: self.split_length]) 136 remaining_words = words[self.split_length :] 137 return current_chunk, " ".join(remaining_words) 138 if self.split_units == "char": 139 text = current_chunk 140 current_chunk = text[: self.split_length] 141 remaining_chars = text[self.split_length :] 142 return current_chunk, remaining_chars 143 144 # at this point we know that the tokenizer is already initialized 145 tokens = self.tiktoken_tokenizer.encode(current_chunk) # type: ignore 146 current_tokens = tokens[: self.split_length] 147 remaining_tokens = tokens[self.split_length :] 148 return self.tiktoken_tokenizer.decode(current_tokens), self.tiktoken_tokenizer.decode(remaining_tokens) # type: ignore 149 150 def _apply_overlap(self, chunks: list[str]) -> list[str]: 151 """ 152 Applies an overlap between consecutive chunks if the chunk_overlap attribute is greater than zero. 153 154 Works for both word- and character-level splitting. It trims the last chunk if it exceeds the split_length and 155 adds the trimmed content to the next chunk. If the last chunk is still too long after trimming, it splits it 156 and adds the first chunk to the list. This process continues until the last chunk is within the split_length. 157 158 :param chunks: A list of text chunks. 159 :returns: 160 A list of text chunks with the overlap applied. 161 """ 162 overlapped_chunks: list[str] = [] 163 164 for idx, chunk in enumerate(chunks): 165 if idx == 0: 166 overlapped_chunks.append(chunk) 167 continue 168 169 # get the overlap between the current and previous chunk 170 overlap, prev_chunk = self._get_overlap(overlapped_chunks) 171 if overlap == prev_chunk: 172 logger.warning( 173 "Overlap is the same as the previous chunk. " 174 "Consider increasing the `split_length` parameter or decreasing the `split_overlap` parameter." 175 ) 176 177 current_chunk = self._create_chunk_starting_with_overlap(chunk, overlap) 178 179 # if this new chunk exceeds 'split_length', trim it and move the remaining text to the next chunk 180 # if this is the last chunk, another new chunk will contain the trimmed text preceded by the overlap 181 # of the last chunk 182 if self._chunk_length(current_chunk) > self.split_length: 183 current_chunk, remaining_text = self._split_chunk(current_chunk) 184 if idx < len(chunks) - 1: 185 if self.split_units == "word": 186 chunks[idx + 1] = remaining_text + " " + chunks[idx + 1] 187 elif self.split_units == "token": 188 # For token-based splitting, combine at token level 189 # at this point we know that the tokenizer is already initialized 190 remaining_tokens = self.tiktoken_tokenizer.encode(remaining_text) # type: ignore 191 next_chunk_tokens = self.tiktoken_tokenizer.encode(chunks[idx + 1]) # type: ignore 192 chunks[idx + 1] = self.tiktoken_tokenizer.decode(remaining_tokens + next_chunk_tokens) # type: ignore 193 else: # char 194 chunks[idx + 1] = remaining_text + chunks[idx + 1] 195 elif remaining_text: 196 # create a new chunk with the trimmed text preceded by the overlap of the last chunk 197 overlapped_chunks.append(current_chunk) 198 chunk = remaining_text 199 overlap, _ = self._get_overlap(overlapped_chunks) 200 current_chunk = self._create_chunk_starting_with_overlap(chunk, overlap) 201 202 overlapped_chunks.append(current_chunk) 203 204 # it can still be that the new last chunk exceeds the 'split_length' 205 # continue splitting until the last chunk is within 'split_length' 206 if idx == len(chunks) - 1 and self._chunk_length(current_chunk) > self.split_length: 207 last_chunk = overlapped_chunks.pop() 208 first_chunk, remaining_chunk = self._split_chunk(last_chunk) 209 overlapped_chunks.append(first_chunk) 210 211 while remaining_chunk: 212 # combine overlap with remaining chunk 213 overlap, _ = self._get_overlap(overlapped_chunks) 214 current = self._create_chunk_starting_with_overlap(remaining_chunk, overlap) 215 216 # if it fits within split_length we are done 217 if self._chunk_length(current) <= self.split_length: 218 overlapped_chunks.append(current) 219 break 220 221 # otherwise split it again 222 first_chunk, remaining_chunk = self._split_chunk(current) 223 overlapped_chunks.append(first_chunk) 224 225 return overlapped_chunks 226 227 def _create_chunk_starting_with_overlap(self, chunk: str, overlap: str) -> str: 228 if self.split_units == "word": 229 current_chunk = overlap + " " + chunk 230 elif self.split_units == "token": 231 # For token-based splitting, combine at token level 232 # at this point we know that the tokenizer is already initialized 233 overlap_tokens = self.tiktoken_tokenizer.encode(overlap) # type: ignore 234 chunk_tokens = self.tiktoken_tokenizer.encode(chunk) # type: ignore 235 current_chunk = self.tiktoken_tokenizer.decode(overlap_tokens + chunk_tokens) # type: ignore 236 else: # char 237 current_chunk = overlap + chunk 238 return current_chunk 239 240 def _get_overlap(self, overlapped_chunks: list[str]) -> tuple[str, str]: 241 """Get the previous overlapped chunk instead of the original chunk.""" 242 prev_chunk = overlapped_chunks[-1] 243 overlap_start = max(0, self._chunk_length(prev_chunk) - self.split_overlap) 244 245 if self.split_units == "word": 246 word_chunks = prev_chunk.split() 247 overlap = " ".join(word_chunks[overlap_start:]) 248 elif self.split_units == "token": 249 # For token-based splitting, handle overlap at token level 250 # at this point we know that the tokenizer is already initialized 251 tokens = self.tiktoken_tokenizer.encode(prev_chunk) # type: ignore 252 overlap_tokens = tokens[overlap_start:] 253 overlap = self.tiktoken_tokenizer.decode(overlap_tokens) # type: ignore 254 else: # char 255 overlap = prev_chunk[overlap_start:] 256 257 return overlap, prev_chunk 258 259 def _chunk_length(self, text: str) -> int: 260 """ 261 Get the length of the chunk in the specified units (words, characters, or tokens). 262 263 :param text: The text to measure. 264 :returns: The length of the text in the specified units. 265 """ 266 if self.split_units == "word": 267 words = [word for word in text.split(" ") if word] 268 return len(words) 269 if self.split_units == "char": 270 return len(text) 271 # token 272 # at this point we know that the tokenizer is already initialized 273 return len(self.tiktoken_tokenizer.encode(text)) # type: ignore 274 275 def _chunk_text(self, text: str) -> list[str]: 276 """ 277 Recursive chunking algorithm that divides text into smaller chunks based on a list of separator characters. 278 279 It starts with a list of separator characters (e.g., ["\n\n", "sentence", "\n", " "]) and attempts to divide 280 the text using the first separator. If the resulting chunks are still larger than the specified chunk size, 281 it moves to the next separator in the list. This process continues recursively, progressively applying each 282 specific separator until the chunks meet the desired size criteria. 283 284 :param text: The text to be split into chunks. 285 :returns: 286 A list of text chunks. 287 """ 288 if self._chunk_length(text) <= self.split_length: 289 return [text] 290 291 for curr_separator in self.separators: 292 if curr_separator == "sentence": 293 # re. ignore: correct SentenceSplitter initialization is checked at the initialization of the component 294 sentence_with_spans = self.nltk_tokenizer.split_sentences(text) # type: ignore 295 splits = [sentence["sentence"] for sentence in sentence_with_spans] 296 else: 297 # add escape "\" to the separator and wrapped it in a group so that it's included in the splits as well 298 escaped_separator = re.escape(curr_separator) 299 escaped_separator = f"({escaped_separator})" 300 301 # split the text and merge every two consecutive splits, i.e.: the text and the separator after it 302 splits = re.split(escaped_separator, text) 303 splits = [ 304 "".join([splits[i], splits[i + 1]]) if i < len(splits) - 1 else splits[i] 305 for i in range(0, len(splits), 2) 306 ] 307 308 # remove last split if it's empty 309 splits = splits[:-1] if splits[-1] == "" else splits 310 311 if len(splits) == 1: # go to next separator, if current separator not found in the text 312 continue 313 314 chunks = [] 315 current_chunk: list[str] = [] 316 current_length = 0 317 318 # check splits, if any is too long, recursively chunk it, otherwise add to current chunk 319 for split in splits: 320 split_text = split 321 322 # if adding this split exceeds chunk_size, process current_chunk 323 if current_length + self._chunk_length(split_text) > self.split_length: 324 # process current_chunk 325 if current_chunk: # keep the good splits 326 chunks.append("".join(current_chunk)) 327 current_chunk = [] 328 current_length = 0 329 330 # recursively handle splits that are too large 331 if self._chunk_length(split_text) > self.split_length: 332 if curr_separator == self.separators[-1]: 333 # tried last separator, can't split further, do a fixed-split based on word/character/token 334 fall_back_chunks = self._fall_back_to_fixed_chunking(split_text, self.split_units) 335 chunks.extend(fall_back_chunks) 336 else: 337 chunks.extend(self._chunk_text(split_text)) 338 339 else: 340 current_chunk.append(split_text) 341 current_length += self._chunk_length(split_text) 342 else: 343 current_chunk.append(split_text) 344 current_length += self._chunk_length(split_text) 345 346 if current_chunk: 347 chunks.append("".join(current_chunk)) 348 349 if self.split_overlap > 0: 350 chunks = self._apply_overlap(chunks) 351 352 if chunks: 353 return chunks 354 355 # if no separator worked, fall back to word- or character-level chunking 356 return self._fall_back_to_fixed_chunking(text, self.split_units) 357 358 def _fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "char", "token"]) -> list[str]: 359 """ 360 Fall back to a fixed chunking approach if no separator works for the text. 361 362 Splits the text into smaller chunks based on the split_length and split_units attributes, either by words, 363 characters, or tokens. 364 365 :param text: The text to be split into chunks. 366 :param split_units: The unit of the split_length parameter. It can be either "word", "char", or "token". 367 :returns: 368 A list of text chunks. 369 """ 370 chunks = [] 371 372 if split_units == "word": 373 words = re.findall(r"\S+|\s+", text) 374 current_chunk = [] 375 current_length = 0 376 377 for word in words: 378 if word != " ": 379 current_chunk.append(word) 380 current_length += 1 381 if current_length == self.split_length and current_chunk: 382 chunks.append("".join(current_chunk)) 383 current_chunk = [] 384 current_length = 0 385 else: 386 current_chunk.append(word) 387 388 if current_chunk: 389 chunks.append("".join(current_chunk)) 390 elif split_units == "char": 391 for i in range(0, self._chunk_length(text), self.split_length): 392 chunks.append(text[i : i + self.split_length]) 393 else: # token 394 # at this point we know that the tokenizer is already initialized 395 tokens = self.tiktoken_tokenizer.encode(text) # type: ignore 396 for i in range(0, len(tokens), self.split_length): 397 chunk_tokens = tokens[i : i + self.split_length] 398 chunks.append(self.tiktoken_tokenizer.decode(chunk_tokens)) # type: ignore 399 return chunks 400 401 def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: list[Document]) -> None: 402 prev_doc = new_docs[-1] 403 overlap_length = self._chunk_length(prev_doc.content) - (curr_pos - prev_doc.meta["split_idx_start"]) # type: ignore 404 if overlap_length > 0: 405 prev_doc.meta["_split_overlap"].append({"doc_id": new_doc.id, "range": (0, overlap_length)}) 406 new_doc.meta["_split_overlap"].append( 407 { 408 "doc_id": prev_doc.id, 409 "range": ( 410 self._chunk_length(prev_doc.content) - overlap_length, # type: ignore 411 self._chunk_length(prev_doc.content), # type: ignore 412 ), 413 } 414 ) 415 416 def _run_one(self, doc: Document) -> list[Document]: 417 chunks = self._chunk_text(doc.content) # type: ignore # the caller already check for a non-empty doc.content 418 chunks = chunks[:-1] if len(chunks[-1]) == 0 else chunks # remove last empty chunk if it exists 419 current_position = 0 420 current_page = 1 421 422 new_docs: list[Document] = [] 423 424 for split_nr, chunk in enumerate(chunks): 425 meta = deepcopy(doc.meta) 426 meta["parent_id"] = doc.id 427 meta["split_id"] = split_nr 428 meta["split_idx_start"] = current_position 429 meta["_split_overlap"] = [] if self.split_overlap > 0 else None 430 new_doc = Document(content=chunk, meta=meta) 431 432 # add overlap information to the previous and current doc 433 if split_nr > 0 and self.split_overlap > 0: 434 self._add_overlap_info(current_position, new_doc, new_docs) 435 436 # count page breaks in the chunk 437 current_page += chunk.count("\f") 438 439 # if there are consecutive page breaks at the end with no more text, adjust the page number 440 # e.g: "text\f\f\f" -> 3 page breaks, but current_page should be 1 441 consecutive_page_breaks = len(chunk) - len(chunk.rstrip("\f")) 442 443 if consecutive_page_breaks > 0: 444 new_doc.meta["page_number"] = current_page - consecutive_page_breaks 445 else: 446 new_doc.meta["page_number"] = current_page 447 448 # keep the new chunk doc and update the current position 449 new_docs.append(new_doc) 450 current_position += len(chunk) - (self.split_overlap if split_nr < len(chunks) - 1 else 0) 451 452 return new_docs 453 454 @component.output_types(documents=list[Document]) 455 def run(self, documents: list[Document]) -> dict[str, list[Document]]: 456 """ 457 Split a list of documents into documents with smaller chunks of text. 458 459 :param documents: List of Documents to split. 460 :returns: 461 A dictionary containing a key "documents" with a List of Documents with smaller chunks of text corresponding 462 to the input documents. 463 """ 464 if not self._is_warmed_up and ("sentence" in self.separators or self.split_units == "token"): 465 self.warm_up() 466 467 docs = [] 468 for doc in documents: 469 if not doc.content or doc.content == "": 470 logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id) 471 continue 472 docs.extend(self._run_one(doc)) 473 474 return {"documents": docs}