metadata_router.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.dataclasses import ByteStream
from haystack.utils import deserialize_type, serialize_type
from haystack.utils.filters import document_matches_filter


@component
class MetadataRouter:
    """
    Routes documents or byte streams to different connections based on their metadata fields.

    Specify the routing rules in the `init` method.
    If a document or byte stream does not match any of the rules, it's routed to a connection named "unmatched".


    ### Usage examples

    **Routing Documents by metadata:**
    ```python
    from haystack import Document
    from haystack.components.routers import MetadataRouter

    docs = [Document(content="Paris is the capital of France.", meta={"language": "en"}),
            Document(content="Berlin ist die Hauptstadt von Deutschland.", meta={"language": "de"})]

    router = MetadataRouter(rules={"en": {"field": "meta.language", "operator": "==", "value": "en"}})

    print(router.run(documents=docs))
    # {'en': [Document(id=..., content: 'Paris is the capital of France.', meta: {'language': 'en'})],
    # 'unmatched': [Document(id=..., content: 'Berlin ist die Hauptstadt von Deutschland.', meta: {'language': 'de'})]}
    ```

    **Routing ByteStreams by metadata:**
    ```python
    from haystack.dataclasses import ByteStream
    from haystack.components.routers import MetadataRouter

    streams = [
        ByteStream.from_string("Hello world", meta={"language": "en"}),
        ByteStream.from_string("Bonjour le monde", meta={"language": "fr"})
    ]

    router = MetadataRouter(
        rules={"english": {"field": "meta.language", "operator": "==", "value": "en"}},
        output_type=list[ByteStream]
    )

    result = router.run(documents=streams)
    # {'english': [ByteStream(...)], 'unmatched': [ByteStream(...)]}
    ```
    """

    def __init__(self, rules: dict[str, dict], output_type: type = list[Document]) -> None:
        """
        Initializes the MetadataRouter component.

        :param rules: A dictionary defining how to route documents or byte streams to output connections based on their
            metadata. Keys are output connection names, and values are dictionaries of
            [filtering expressions](https://docs.haystack.deepset.ai/docs/metadata-filtering) in Haystack.
            The name `"unmatched"` is reserved for objects that match no rule and cannot be used as a key.
            For example:
            ```python
            {
                "edge_1": {
                    "operator": "AND",
                    "conditions": [
                        {"field": "meta.created_at", "operator": ">=", "value": "2023-01-01"},
                        {"field": "meta.created_at", "operator": "<", "value": "2023-04-01"},
                    ],
                },
                "edge_2": {
                    "operator": "AND",
                    "conditions": [
                        {"field": "meta.created_at", "operator": ">=", "value": "2023-04-01"},
                        {"field": "meta.created_at", "operator": "<", "value": "2023-07-01"},
                    ],
                },
                "edge_3": {
                    "operator": "AND",
                    "conditions": [
                        {"field": "meta.created_at", "operator": ">=", "value": "2023-07-01"},
                        {"field": "meta.created_at", "operator": "<", "value": "2023-10-01"},
                    ],
                },
                "edge_4": {
                    "operator": "AND",
                    "conditions": [
                        {"field": "meta.created_at", "operator": ">=", "value": "2023-10-01"},
                        {"field": "meta.created_at", "operator": "<", "value": "2024-01-01"},
                    ],
                },
            }
            ```
        :param output_type: The type of the output produced. Lists of Documents or ByteStreams can be specified.

        :raises ValueError: If a rule has no top-level `"operator"` key, or if a rule is named `"unmatched"`.
        """
        self.rules = rules
        self.output_type = output_type
        for rule in self.rules.values():
            if "operator" not in rule:
                raise ValueError(
                    "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
                )
        # "unmatched" is passed as an explicit keyword below; a rule with the same name would otherwise
        # surface as an opaque "got multiple values for keyword argument" TypeError.
        if "unmatched" in self.rules:
            raise ValueError("'unmatched' is a reserved output connection name and cannot be used as a rule name.")
        # One output socket per rule plus the catch-all "unmatched" socket, all sharing the same type.
        component.set_output_types(self, unmatched=self.output_type, **dict.fromkeys(rules, self.output_type))

    def run(self, documents: list[Document] | list[ByteStream]) -> dict[str, list[Document] | list[ByteStream]]:
        """
        Routes documents or byte streams to different connections based on their metadata fields.

        Every rule is evaluated for every object, so an object that satisfies several rules is emitted on each
        of the matching connections.
        If a document or byte stream does not match any of the rules, it's routed to a connection named "unmatched".

        :param documents: A list of `Document` or `ByteStream` objects to be routed based on their metadata.

        :returns: A dictionary where the keys are the names of the output connections (including `"unmatched"`)
            and the values are lists of `Document` or `ByteStream` objects that matched the corresponding rules.
        """

        unmatched: list[Document] | list[ByteStream] = []
        output: dict[str, list[Document] | list[ByteStream]] = {edge: [] for edge in self.rules}

        for doc_or_bytestream in documents:
            current_obj_matched = False
            for edge, rule in self.rules.items():
                if document_matches_filter(filters=rule, document=doc_or_bytestream):
                    # we need to ignore the arg-type here because the underlying
                    # filter methods use type Union[Document, ByteStream]
                    output[edge].append(doc_or_bytestream)  # type: ignore[arg-type]
                    current_obj_matched = True

            if not current_obj_matched:
                unmatched.append(doc_or_bytestream)  # type: ignore[arg-type]

        output["unmatched"] = unmatched
        return output

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
        """
        return default_to_dict(self, rules=self.rules, output_type=serialize_type(self.output_type))

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MetadataRouter":
        """
        Deserialize this component from a dictionary.

        The dictionary passed in is not modified.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        if "output_type" in init_params:
            # Deserialize the output_type to its original type. Work on shallow copies so the
            # caller's dictionary keeps its serialized value and can be reused.
            data = {
                **data,
                "init_parameters": {**init_params, "output_type": deserialize_type(init_params["output_type"])},
            }
        return default_from_dict(cls, data)