csv_document_splitter.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from io import StringIO
from typing import Any, Literal, get_args

from haystack import Document, component, logging
from haystack.lazy_imports import LazyImport

with LazyImport("Run 'pip install pandas'") as pandas_import:
    import pandas as pd

logger = logging.getLogger(__name__)

SplitMode = Literal["threshold", "row-wise"]


@component
class CSVDocumentSplitter:
    """
    A component for splitting CSV documents into sub-tables based on split arguments.

    The splitter supports two modes of operation:
    - identify consecutive empty rows or columns that exceed a given threshold
    and uses them as delimiters to segment the document into smaller tables.
    - split each row into a separate sub-table, represented as a Document.

    """

    def __init__(
        self,
        row_split_threshold: int | None = 2,
        column_split_threshold: int | None = 2,
        read_csv_kwargs: dict[str, Any] | None = None,
        split_mode: SplitMode = "threshold",
    ) -> None:
        """
        Initializes the CSVDocumentSplitter component.

        :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
        :param column_split_threshold: The minimum number of consecutive empty columns required to trigger a split.
        :param read_csv_kwargs: Additional keyword arguments to pass to `pandas.read_csv`.
            By default, the component reads the CSV with the following options:
            - `header=None`
            - `skip_blank_lines=False` to preserve blank lines
            - `dtype=object` to prevent type inference (e.g., converting numbers to floats).
            See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
        :param split_mode:
            If `threshold`, the component will split the document based on the number of
            consecutive empty rows or columns that exceed the `row_split_threshold` or `column_split_threshold`.
            If `row-wise`, the component will split each row into a separate sub-table.
        :raises ValueError: If `split_mode` is not a known mode, if a provided threshold is smaller than 1,
            or if both thresholds are `None`.
        """
        pandas_import.check()
        if split_mode not in get_args(SplitMode):
            raise ValueError(
                f"Split mode '{split_mode}' not recognized. Choose one among: {', '.join(get_args(SplitMode))}."
            )
        if row_split_threshold is not None and row_split_threshold < 1:
            raise ValueError("row_split_threshold must be greater than 0")

        if column_split_threshold is not None and column_split_threshold < 1:
            raise ValueError("column_split_threshold must be greater than 0")

        if row_split_threshold is None and column_split_threshold is None:
            raise ValueError("At least one of row_split_threshold or column_split_threshold must be specified.")

        self.row_split_threshold = row_split_threshold
        self.column_split_threshold = column_split_threshold
        self.read_csv_kwargs = read_csv_kwargs or {}
        self.split_mode = split_mode

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document]) -> dict[str, list[Document]]:
        """
        Processes and splits a list of CSV documents into multiple sub-tables.

        **Splitting Process:**
        1. Applies a row-based split if `row_split_threshold` is provided.
        2. Applies a column-based split if `column_split_threshold` is provided.
        3. If both thresholds are specified, performs a recursive split by rows first, then columns, ensuring
           further fragmentation of any sub-tables that still contain empty sections.
        4. Sorts the resulting sub-tables based on their original positions within the document.

        :param documents: A list of Documents containing CSV-formatted content.
            Each document is assumed to contain one or more tables separated by empty rows or columns.

        :return:
            A dictionary with a key `"documents"`, mapping to a list of new `Document` objects,
            each representing an extracted sub-table from the original CSV.
            The metadata of each document includes:
                - A field `source_id` to track the original document.
                - A field `row_idx_start` to indicate the starting row index of the sub-table in the original table.
                - A field `col_idx_start` to indicate the starting column index of the sub-table in the original table.
                - A field `split_id` to indicate the order of the split in the original document.
                - All other metadata copied from the original document.

        - If a document cannot be processed, it is returned unchanged.
        - If no sub-tables are found in a document, a warning is logged and the document is skipped.
        - The `meta` field from the original document is preserved in the split documents.
        """
        if len(documents) == 0:
            return {"documents": documents}

        # User-supplied kwargs come last so they can override the defaults.
        resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs}

        split_documents = []
        split_dfs = []
        for document in documents:
            try:
                df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs)
            except Exception as e:
                logger.exception(
                    "Error processing document {document_id}. Keeping it, but skipping splitting. Error: {error}",
                    document_id=document.id,
                    error=e,
                )
                split_documents.append(document)
                continue

            if self.split_mode == "row-wise":
                # each row is a separate sub-table
                split_dfs = self._split_by_row(df=df)

            elif self.split_mode == "threshold":
                if self.row_split_threshold is not None and self.column_split_threshold is None:
                    # split by rows
                    split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
                elif self.column_split_threshold is not None and self.row_split_threshold is None:
                    # split by columns
                    split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
                else:
                    # recursive split
                    split_dfs = self._recursive_split(
                        df=df,
                        row_split_threshold=self.row_split_threshold,  # type: ignore
                        column_split_threshold=self.column_split_threshold,  # type: ignore
                    )

            # check if no sub-tables were found
            if len(split_dfs) == 0:
                logger.warning(
                    "No sub-tables found while splitting CSV Document with id {doc_id}. Skipping document.",
                    doc_id=document.id,
                )
                continue

            # Sort split_dfs first by row index, then by column index
            split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))

            for split_id, split_df in enumerate(split_dfs):
                split_documents.append(
                    Document(
                        content=split_df.to_csv(index=False, header=False, lineterminator="\n"),
                        meta={
                            **document.meta.copy(),
                            "source_id": document.id,
                            "row_idx_start": int(split_df.index[0]),
                            "col_idx_start": int(split_df.columns[0]),
                            "split_id": split_id,
                        },
                    )
                )

        return {"documents": split_documents}

    @staticmethod
    def _find_split_indices(
        df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"]
    ) -> list[tuple[int, int]]:
        """
        Finds the runs of consecutive empty rows or columns in a DataFrame.

        :param df: DataFrame to split.
        :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
        :param axis: Axis along which to find empty elements. Either "row" or "column".
        :return: List of `(start, end)` tuples, one per run of at least `split_threshold` consecutive empty
            rows or columns. Both bounds are inclusive positional offsets into `df` (usable with `iloc`).
        """
        if axis == "row":
            empty_mask = df.isnull().all(axis=1)
        else:
            empty_mask = df.isnull().all(axis=0)

        # Collect positions rather than index/column labels: sub-tables produced by an earlier split keep the
        # labels of the original table, while `_split_dataframe` slices positionally with `iloc`. Using labels
        # here would make the slices wrong (and the recursive split never terminate) for such sub-tables.
        empty_positions = [position for position, is_empty in enumerate(empty_mask) if is_empty]

        # If no empty elements found, return empty list
        if len(empty_positions) == 0:
            return []

        # Identify groups of consecutive empty elements
        split_indices = []
        run_start = empty_positions[0]
        previous = run_start
        for position in empty_positions[1:]:
            if position != previous + 1:
                # The current run ended at `previous`; keep it only if it is long enough.
                if previous - run_start + 1 >= split_threshold:
                    split_indices.append((run_start, previous))
                run_start = position
            previous = position

        # Handle the last group of consecutive elements
        if previous - run_start + 1 >= split_threshold:
            split_indices.append((run_start, previous))

        return split_indices

    def _split_dataframe(
        self, df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"]
    ) -> list["pd.DataFrame"]:
        """
        Splits a DataFrame into sub-tables based on consecutive empty rows or columns exceeding `split_threshold`.

        :param df: DataFrame to split.
        :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
        :param axis: Axis along which to split. Either "row" or "column".
        :return: List of split DataFrames. The sub-tables are `iloc` slices of `df`, so they keep the
            index and column labels of `df`.
        """
        # Find positions of consecutive empty rows or columns
        split_indices = self._find_split_indices(df=df, split_threshold=split_threshold, axis=axis)

        # If no split_indices are found, return the original DataFrame
        if len(split_indices) == 0:
            return [df]

        # Split the DataFrame at identified positions
        sub_tables = []
        table_start_idx = 0
        df_length = df.shape[0] if axis == "row" else df.shape[1]
        # The sentinel run placed at the end of the table flushes the last sub-table.
        for empty_start_idx, empty_end_idx in split_indices + [(df_length, df_length)]:
            # Avoid empty splits
            if empty_start_idx - table_start_idx >= 1:
                if axis == "row":
                    sub_table = df.iloc[table_start_idx:empty_start_idx]
                else:
                    sub_table = df.iloc[:, table_start_idx:empty_start_idx]
                if not sub_table.empty:
                    sub_tables.append(sub_table)
            table_start_idx = empty_end_idx + 1

        return sub_tables

    def _recursive_split(
        self, df: "pd.DataFrame", row_split_threshold: int, column_split_threshold: int
    ) -> list["pd.DataFrame"]:
        """
        Recursively splits a DataFrame.

        Recursively splits a DataFrame first by empty rows, then by empty columns, and repeats the process
        until no more splits are possible. Returns a list of DataFrames, each representing a fully separated sub-table.

        :param df: A Pandas DataFrame representing a table (or multiple tables) extracted from a CSV.
        :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
        :param column_split_threshold: The minimum number of consecutive empty columns to trigger a split.
        :return: List of sub-tables for which no further row split is possible.
        """

        # Step 1: Split by rows
        new_sub_tables = self._split_dataframe(df=df, split_threshold=row_split_threshold, axis="row")

        # Step 2: Split by columns
        final_tables = []
        for table in new_sub_tables:
            final_tables.extend(self._split_dataframe(df=table, split_threshold=column_split_threshold, axis="column"))

        # Step 3: Recursively reapply splitting checked by whether any new empty rows appear after column split
        result = []
        for table in final_tables:
            # Check if there are consecutive rows >= row_split_threshold now present
            if len(self._find_split_indices(df=table, split_threshold=row_split_threshold, axis="row")) > 0:
                result.extend(
                    self._recursive_split(
                        df=table, row_split_threshold=row_split_threshold, column_split_threshold=column_split_threshold
                    )
                )
            else:
                result.append(table)

        return result

    def _split_by_row(self, df: "pd.DataFrame") -> list["pd.DataFrame"]:
        """Split each CSV row into a separate subtable"""
        split_dfs = []
        for idx, row in enumerate(df.itertuples(index=False)):
            split_df = pd.DataFrame(row).T
            split_df.index = [idx]  # Set the index of the new DataFrame to idx
            split_dfs.append(split_df)
        return split_dfs