/ haystack / components / routers / metadata_router.py
metadata_router.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from typing import Any
  6  
  7  from haystack import Document, component, default_from_dict, default_to_dict
  8  from haystack.dataclasses import ByteStream
  9  from haystack.utils import deserialize_type, serialize_type
 10  from haystack.utils.filters import document_matches_filter
 11  
 12  
 13  @component
 14  class MetadataRouter:
 15      """
 16      Routes documents or byte streams to different connections based on their metadata fields.
 17  
 18      Specify the routing rules in the `init` method.
 19      If a document or byte stream does not match any of the rules, it's routed to a connection named "unmatched".
 20  
 21  
 22      ### Usage examples
 23  
 24      **Routing Documents by metadata:**
 25      ```python
 26      from haystack import Document
 27      from haystack.components.routers import MetadataRouter
 28  
 29      docs = [Document(content="Paris is the capital of France.", meta={"language": "en"}),
 30              Document(content="Berlin ist die Haupststadt von Deutschland.", meta={"language": "de"})]
 31  
 32      router = MetadataRouter(rules={"en": {"field": "meta.language", "operator": "==", "value": "en"}})
 33  
 34      print(router.run(documents=docs))
 35      # {'en': [Document(id=..., content: 'Paris is the capital of France.', meta: {'language': 'en'})],
 36      # 'unmatched': [Document(id=..., content: 'Berlin ist die Haupststadt von Deutschland.', meta: {'language': 'de'})]}
 37      ```
 38  
 39      **Routing ByteStreams by metadata:**
 40      ```python
 41      from haystack.dataclasses import ByteStream
 42      from haystack.components.routers import MetadataRouter
 43  
 44      streams = [
 45          ByteStream.from_string("Hello world", meta={"language": "en"}),
 46          ByteStream.from_string("Bonjour le monde", meta={"language": "fr"})
 47      ]
 48  
 49      router = MetadataRouter(
 50          rules={"english": {"field": "meta.language", "operator": "==", "value": "en"}},
 51          output_type=list[ByteStream]
 52      )
 53  
 54      result = router.run(documents=streams)
 55      # {'english': [ByteStream(...)], 'unmatched': [ByteStream(...)]}
 56      ```
 57      """
 58  
 59      def __init__(self, rules: dict[str, dict], output_type: type = list[Document]) -> None:
 60          """
 61          Initializes the MetadataRouter component.
 62  
 63          :param rules: A dictionary defining how to route documents or byte streams to output connections based on their
 64              metadata. Keys are output connection names, and values are dictionaries of
 65              [filtering expressions](https://docs.haystack.deepset.ai/docs/metadata-filtering) in Haystack.
 66              For example:
 67              ```python
 68              {
 69              "edge_1": {
 70                  "operator": "AND",
 71                  "conditions": [
 72                      {"field": "meta.created_at", "operator": ">=", "value": "2023-01-01"},
 73                      {"field": "meta.created_at", "operator": "<", "value": "2023-04-01"},
 74                  ],
 75              },
 76              "edge_2": {
 77                  "operator": "AND",
 78                  "conditions": [
 79                      {"field": "meta.created_at", "operator": ">=", "value": "2023-04-01"},
 80                      {"field": "meta.created_at", "operator": "<", "value": "2023-07-01"},
 81                  ],
 82              },
 83              "edge_3": {
 84                  "operator": "AND",
 85                  "conditions": [
 86                      {"field": "meta.created_at", "operator": ">=", "value": "2023-07-01"},
 87                      {"field": "meta.created_at", "operator": "<", "value": "2023-10-01"},
 88                  ],
 89              },
 90              "edge_4": {
 91                  "operator": "AND",
 92                  "conditions": [
 93                      {"field": "meta.created_at", "operator": ">=", "value": "2023-10-01"},
 94                      {"field": "meta.created_at", "operator": "<", "value": "2024-01-01"},
 95                  ],
 96              },
 97              }
 98              ```
 99              :param output_type: The type of the output produced. Lists of Documents or ByteStreams can be specified.
100          """
101          self.rules = rules
102          self.output_type = output_type
103          for rule in self.rules.values():
104              if "operator" not in rule:
105                  raise ValueError(
106                      "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
107                  )
108          component.set_output_types(self, unmatched=self.output_type, **dict.fromkeys(rules, self.output_type))
109  
110      def run(self, documents: list[Document] | list[ByteStream]) -> dict[str, list[Document] | list[ByteStream]]:
111          """
112          Routes documents or byte streams to different connections based on their metadata fields.
113  
114          If a document or byte stream does not match any of the rules, it's routed to a connection named "unmatched".
115  
116          :param documents: A list of `Document` or `ByteStream` objects to be routed based on their metadata.
117  
118          :returns: A dictionary where the keys are the names of the output connections (including `"unmatched"`)
119              and the values are lists of `Document` or `ByteStream` objects that matched the corresponding rules.
120          """
121  
122          unmatched: list[Document] | list[ByteStream] = []
123          output: dict[str, list[Document] | list[ByteStream]] = {edge: [] for edge in self.rules}
124  
125          for doc_or_bytestream in documents:
126              current_obj_matched = False
127              for edge, rule in self.rules.items():
128                  if document_matches_filter(filters=rule, document=doc_or_bytestream):
129                      # we need to ignore the arg-type here because the underlying
130                      # filter methods use type Union[Document, ByteStream]
131                      output[edge].append(doc_or_bytestream)  # type: ignore[arg-type]
132                      current_obj_matched = True
133  
134              if not current_obj_matched:
135                  unmatched.append(doc_or_bytestream)  # type: ignore[arg-type]
136  
137          output["unmatched"] = unmatched
138          return output
139  
140      def to_dict(self) -> dict[str, Any]:
141          """
142          Serialize this component to a dictionary.
143  
144          :returns:
145              The serialized component as a dictionary.
146          """
147          return default_to_dict(self, rules=self.rules, output_type=serialize_type(self.output_type))
148  
149      @classmethod
150      def from_dict(cls, data: dict[str, Any]) -> "MetadataRouter":
151          """
152          Deserialize this component from a dictionary.
153  
154          :param data:
155              The dictionary representation of this component.
156          :returns:
157              The deserialized component instance.
158          """
159          init_params = data.get("init_parameters", {})
160          if "output_type" in init_params:
161              # Deserialize the output_type to its original type
162              init_params["output_type"] = deserialize_type(init_params["output_type"])
163          return default_from_dict(cls, data)