/ haystack / components / preprocessors / csv_document_splitter.py
csv_document_splitter.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from io import StringIO
  6  from typing import Any, Literal, get_args
  7  
  8  from haystack import Document, component, logging
  9  from haystack.lazy_imports import LazyImport
 10  
 11  with LazyImport("Run 'pip install pandas'") as pandas_import:
 12      import pandas as pd
 13  
logger = logging.getLogger(__name__)

# Allowed values for `CSVDocumentSplitter(split_mode=...)`:
# - "threshold": split on runs of consecutive empty rows and/or columns.
# - "row-wise": emit one sub-table per CSV row.
SplitMode = Literal["threshold", "row-wise"]
 17  
 18  
 19  @component
 20  class CSVDocumentSplitter:
 21      """
 22      A component for splitting CSV documents into sub-tables based on split arguments.
 23  
 24      The splitter supports two modes of operation:
 25      - identify consecutive empty rows or columns that exceed a given threshold
 26      and uses them as delimiters to segment the document into smaller tables.
 27      - split each row into a separate sub-table, represented as a Document.
 28  
 29      """
 30  
 31      def __init__(
 32          self,
 33          row_split_threshold: int | None = 2,
 34          column_split_threshold: int | None = 2,
 35          read_csv_kwargs: dict[str, Any] | None = None,
 36          split_mode: SplitMode = "threshold",
 37      ) -> None:
 38          """
 39          Initializes the CSVDocumentSplitter component.
 40  
 41          :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
 42          :param column_split_threshold: The minimum number of consecutive empty columns required to trigger a split.
 43          :param read_csv_kwargs: Additional keyword arguments to pass to `pandas.read_csv`.
 44              By default, the component with options:
 45              - `header=None`
 46              - `skip_blank_lines=False` to preserve blank lines
 47              - `dtype=object` to prevent type inference (e.g., converting numbers to floats).
 48              See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
 49          :param split_mode:
 50              If `threshold`, the component will split the document based on the number of
 51              consecutive empty rows or columns that exceed the `row_split_threshold` or `column_split_threshold`.
 52              If `row-wise`, the component will split each row into a separate sub-table.
 53          """
 54          pandas_import.check()
 55          if split_mode not in get_args(SplitMode):
 56              raise ValueError(
 57                  f"Split mode '{split_mode}' not recognized. Choose one among: {', '.join(get_args(SplitMode))}."
 58              )
 59          if row_split_threshold is not None and row_split_threshold < 1:
 60              raise ValueError("row_split_threshold must be greater than 0")
 61  
 62          if column_split_threshold is not None and column_split_threshold < 1:
 63              raise ValueError("column_split_threshold must be greater than 0")
 64  
 65          if row_split_threshold is None and column_split_threshold is None:
 66              raise ValueError("At least one of row_split_threshold or column_split_threshold must be specified.")
 67  
 68          self.row_split_threshold = row_split_threshold
 69          self.column_split_threshold = column_split_threshold
 70          self.read_csv_kwargs = read_csv_kwargs or {}
 71          self.split_mode = split_mode
 72  
 73      @component.output_types(documents=list[Document])
 74      def run(self, documents: list[Document]) -> dict[str, list[Document]]:
 75          """
 76          Processes and splits a list of CSV documents into multiple sub-tables.
 77  
 78          **Splitting Process:**
 79          1. Applies a row-based split if `row_split_threshold` is provided.
 80          2. Applies a column-based split if `column_split_threshold` is provided.
 81          3. If both thresholds are specified, performs a recursive split by rows first, then columns, ensuring
 82             further fragmentation of any sub-tables that still contain empty sections.
 83          4. Sorts the resulting sub-tables based on their original positions within the document.
 84  
 85          :param documents: A list of Documents containing CSV-formatted content.
 86              Each document is assumed to contain one or more tables separated by empty rows or columns.
 87  
 88          :return:
 89              A dictionary with a key `"documents"`, mapping to a list of new `Document` objects,
 90              each representing an extracted sub-table from the original CSV.
 91              The metadata of each document includes:
 92                  - A field `source_id` to track the original document.
 93                  - A field `row_idx_start` to indicate the starting row index of the sub-table in the original table.
 94                  - A field `col_idx_start` to indicate the starting column index of the sub-table in the original table.
 95                  - A field `split_id` to indicate the order of the split in the original document.
 96                  - All other metadata copied from the original document.
 97  
 98          - If a document cannot be processed, it is returned unchanged.
 99          - The `meta` field from the original document is preserved in the split documents.
100          """
101          if len(documents) == 0:
102              return {"documents": documents}
103  
104          resolved_read_csv_kwargs = {"header": None, "skip_blank_lines": False, "dtype": object, **self.read_csv_kwargs}
105  
106          split_documents = []
107          split_dfs = []
108          for document in documents:
109              try:
110                  df = pd.read_csv(StringIO(document.content), **resolved_read_csv_kwargs)
111              except Exception as e:
112                  logger.exception(
113                      "Error processing document {document_id}. Keeping it, but skipping splitting. Error: {error}",
114                      document_id=document.id,
115                      error=e,
116                  )
117                  split_documents.append(document)
118                  continue
119  
120              if self.split_mode == "row-wise":
121                  # each row is a separate sub-table
122                  split_dfs = self._split_by_row(df=df)
123  
124              elif self.split_mode == "threshold":
125                  if self.row_split_threshold is not None and self.column_split_threshold is None:
126                      # split by rows
127                      split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
128                  elif self.column_split_threshold is not None and self.row_split_threshold is None:
129                      # split by columns
130                      split_dfs = self._split_dataframe(df=df, split_threshold=self.column_split_threshold, axis="column")
131                  else:
132                      # recursive split
133                      split_dfs = self._recursive_split(
134                          df=df,
135                          row_split_threshold=self.row_split_threshold,  # type: ignore
136                          column_split_threshold=self.column_split_threshold,  # type: ignore
137                      )
138  
139              # check if no sub-tables were found
140              if len(split_dfs) == 0:
141                  logger.warning(
142                      "No sub-tables found while splitting CSV Document with id {doc_id}. Skipping document.",
143                      doc_id=document.id,
144                  )
145                  continue
146  
147              # Sort split_dfs first by row index, then by column index
148              split_dfs.sort(key=lambda dataframe: (dataframe.index[0], dataframe.columns[0]))
149  
150              for split_id, split_df in enumerate(split_dfs):
151                  split_documents.append(
152                      Document(
153                          content=split_df.to_csv(index=False, header=False, lineterminator="\n"),
154                          meta={
155                              **document.meta.copy(),
156                              "source_id": document.id,
157                              "row_idx_start": int(split_df.index[0]),
158                              "col_idx_start": int(split_df.columns[0]),
159                              "split_id": split_id,
160                          },
161                      )
162                  )
163  
164          return {"documents": split_documents}
165  
166      @staticmethod
167      def _find_split_indices(
168          df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"]
169      ) -> list[tuple[int, int]]:
170          """
171          Finds the indices of consecutive empty rows or columns in a DataFrame.
172  
173          :param df: DataFrame to split.
174          :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
175          :param axis: Axis along which to find empty elements. Either "row" or "column".
176          :return: List of indices where consecutive empty rows or columns start.
177          """
178          if axis == "row":
179              empty_elements = df[df.isnull().all(axis=1)].index.tolist()
180          else:
181              empty_elements = df.columns[df.isnull().all(axis=0)].tolist()
182  
183          # If no empty elements found, return empty list
184          if len(empty_elements) == 0:
185              return []
186  
187          # Identify groups of consecutive empty elements
188          split_indices = []
189          consecutive_count = 1
190          start_index = empty_elements[0]
191  
192          for i in range(1, len(empty_elements)):
193              if empty_elements[i] == empty_elements[i - 1] + 1:
194                  consecutive_count += 1
195              else:
196                  if consecutive_count >= split_threshold:
197                      split_indices.append((start_index, empty_elements[i - 1]))
198                  consecutive_count = 1
199                  start_index = empty_elements[i]
200  
201          # Handle the last group of consecutive elements
202          if consecutive_count >= split_threshold:
203              split_indices.append((start_index, empty_elements[-1]))
204  
205          return split_indices
206  
207      def _split_dataframe(
208          self, df: "pd.DataFrame", split_threshold: int, axis: Literal["row", "column"]
209      ) -> list["pd.DataFrame"]:
210          """
211          Splits a DataFrame into sub-tables based on consecutive empty rows or columns exceeding `split_threshold`.
212  
213          :param df: DataFrame to split.
214          :param split_threshold: Minimum number of consecutive empty rows or columns to trigger a split.
215          :param axis: Axis along which to split. Either "row" or "column".
216          :return: List of split DataFrames.
217          """
218          # Find indices of consecutive empty rows or columns
219          split_indices = self._find_split_indices(df=df, split_threshold=split_threshold, axis=axis)
220  
221          # If no split_indices are found, return the original DataFrame
222          if len(split_indices) == 0:
223              return [df]
224  
225          # Split the DataFrame at identified indices
226          sub_tables = []
227          table_start_idx = 0
228          df_length = df.shape[0] if axis == "row" else df.shape[1]
229          for empty_start_idx, empty_end_idx in split_indices + [(df_length, df_length)]:
230              # Avoid empty splits
231              if empty_start_idx - table_start_idx >= 1:
232                  if axis == "row":
233                      sub_table = df.iloc[table_start_idx:empty_start_idx]
234                  else:
235                      sub_table = df.iloc[:, table_start_idx:empty_start_idx]
236                  if not sub_table.empty:
237                      sub_tables.append(sub_table)
238              table_start_idx = empty_end_idx + 1
239  
240          return sub_tables
241  
242      def _recursive_split(
243          self, df: "pd.DataFrame", row_split_threshold: int, column_split_threshold: int
244      ) -> list["pd.DataFrame"]:
245          """
246          Recursively splits a DataFrame.
247  
248          Recursively splits a DataFrame first by empty rows, then by empty columns, and repeats the process
249          until no more splits are possible. Returns a list of DataFrames, each representing a fully separated sub-table.
250  
251          :param df: A Pandas DataFrame representing a table (or multiple tables) extracted from a CSV.
252          :param row_split_threshold: The minimum number of consecutive empty rows required to trigger a split.
253          :param column_split_threshold: The minimum number of consecutive empty columns to trigger a split.
254          """
255  
256          # Step 1: Split by rows
257          new_sub_tables = self._split_dataframe(df=df, split_threshold=row_split_threshold, axis="row")
258  
259          # Step 2: Split by columns
260          final_tables = []
261          for table in new_sub_tables:
262              final_tables.extend(self._split_dataframe(df=table, split_threshold=column_split_threshold, axis="column"))
263  
264          # Step 3: Recursively reapply splitting checked by whether any new empty rows appear after column split
265          result = []
266          for table in final_tables:
267              # Check if there are consecutive rows >= row_split_threshold now present
268              if len(self._find_split_indices(df=table, split_threshold=row_split_threshold, axis="row")) > 0:
269                  result.extend(
270                      self._recursive_split(
271                          df=table, row_split_threshold=row_split_threshold, column_split_threshold=column_split_threshold
272                      )
273                  )
274              else:
275                  result.append(table)
276  
277          return result
278  
279      def _split_by_row(self, df: "pd.DataFrame") -> list["pd.DataFrame"]:
280          """Split each CSV row into a separate subtable"""
281          split_dfs = []
282          for idx, row in enumerate(df.itertuples(index=False)):
283              split_df = pd.DataFrame(row).T
284              split_df.index = [idx]  # Set the index of the new DataFrame to idx
285              split_dfs.append(split_df)
286          return split_dfs