/ haystack / components / routers / document_type_router.py
document_type_router.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import mimetypes
  6  import re
  7  from collections import defaultdict
  8  from pathlib import Path
  9  
 10  from haystack import component
 11  from haystack.dataclasses import Document
 12  from haystack.utils.misc import _guess_mime_type
 13  
 14  
 15  @component
 16  class DocumentTypeRouter:
 17      """
 18      Routes documents by their MIME types.
 19  
 20      DocumentTypeRouter is used to dynamically route documents within a pipeline based on their MIME types.
 21      It supports exact MIME type matches and regex patterns.
 22  
 23      MIME types can be extracted directly from document metadata or inferred from file paths using standard or
 24      user-supplied MIME type mappings.
 25  
 26      ### Usage example
 27  
 28      ```python
 29      from haystack.components.routers import DocumentTypeRouter
 30      from haystack.dataclasses import Document
 31  
 32      docs = [
 33          Document(content="Example text", meta={"file_path": "example.txt"}),
 34          Document(content="Another document", meta={"mime_type": "application/pdf"}),
 35          Document(content="Unknown type")
 36      ]
 37  
 38      router = DocumentTypeRouter(
 39          mime_type_meta_field="mime_type",
 40          file_path_meta_field="file_path",
 41          mime_types=["text/plain", "application/pdf"]
 42      )
 43  
 44      result = router.run(documents=docs)
 45      print(result)
 46      ```
 47  
 48      Expected output:
 49      ```python
 50      {
 51          "text/plain": [Document(...)],
 52          "application/pdf": [Document(...)],
 53          "unclassified": [Document(...)]
 54      }
 55      ```
 56      """
 57  
 58      def __init__(
 59          self,
 60          *,
 61          mime_types: list[str],
 62          mime_type_meta_field: str | None = None,
 63          file_path_meta_field: str | None = None,
 64          additional_mimetypes: dict[str, str] | None = None,
 65      ) -> None:
 66          """
 67          Initialize the DocumentTypeRouter component.
 68  
 69          :param mime_types:
 70              A list of MIME types or regex patterns to classify the input documents.
 71              (for example: `["text/plain", "audio/x-wav", "image/jpeg"]`).
 72          :param mime_type_meta_field:
 73              Optional name of the metadata field that holds the MIME type.
 74          :param file_path_meta_field:
 75              Optional name of the metadata field that holds the file path. Used to infer the MIME type if
 76              `mime_type_meta_field` is not provided or missing in a document.
 77          :param additional_mimetypes:
 78              Optional dictionary mapping MIME types to file extensions to enhance or override the standard
 79              `mimetypes` module. Useful when working with uncommon or custom file types.
 80              For example: `{"application/vnd.custom-type": ".custom"}`.
 81  
 82          :raises ValueError: If `mime_types` is empty or if both `mime_type_meta_field` and `file_path_meta_field` are
 83              not provided.
 84          """
 85          if not mime_types:
 86              raise ValueError("The list of mime types cannot be empty.")
 87  
 88          if mime_type_meta_field is None and file_path_meta_field is None:
 89              raise ValueError(
 90                  "At least one of 'mime_type_meta_field' or 'file_path_meta_field' must be provided to determine MIME "
 91                  "types."
 92              )
 93          self.mime_type_meta_field = mime_type_meta_field
 94          self.file_path_meta_field = file_path_meta_field
 95  
 96          if additional_mimetypes:
 97              for mime, ext in additional_mimetypes.items():
 98                  mimetypes.add_type(mime, ext)
 99  
100          self._mime_type_patterns = []
101          for mime_type in mime_types:
102              try:
103                  pattern = re.compile(mime_type)
104              except re.error as e:
105                  raise ValueError(f"Invalid regex pattern '{mime_type}'.") from e
106              self._mime_type_patterns.append(pattern)
107  
108          component.set_output_types(self, unclassified=list[Document], **dict.fromkeys(mime_types, list[Document]))
109          self.mime_types = mime_types
110          self.additional_mimetypes = additional_mimetypes
111  
112      def run(self, documents: list[Document]) -> dict[str, list[Document]]:
113          """
114          Categorize input documents into groups based on their MIME type.
115  
116          MIME types can either be directly available in document metadata or derived from file paths using the
117          standard Python `mimetypes` module and custom mappings.
118  
119          :param documents:
120              A list of documents to be categorized.
121  
122          :returns:
123              A dictionary where the keys are MIME types (or `"unclassified"`) and the values are lists of documents.
124          """
125          mime_types = defaultdict(list)
126  
127          for doc in documents:
128              mime_type = doc.meta.get(self.mime_type_meta_field) if self.mime_type_meta_field else None
129              file_path = doc.meta.get(self.file_path_meta_field) if self.file_path_meta_field else None
130  
131              if mime_type is None and file_path:
132                  # if mime_type is not provided, try to guess it from the file path
133                  mime_type = _guess_mime_type(Path(file_path))
134  
135              matched = False
136              if mime_type:
137                  for pattern in self._mime_type_patterns:
138                      if pattern.fullmatch(mime_type):
139                          mime_types[pattern.pattern].append(doc)
140                          matched = True
141                          break
142              if not matched:
143                  mime_types["unclassified"].append(doc)
144  
145          return dict(mime_types)