# misc.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import json
import mimetypes
import tempfile
from math import inf
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, overload

from numpy import exp, ndarray

from haystack import logging

if TYPE_CHECKING:
    from haystack.dataclasses import Document

CUSTOM_MIMETYPES = {
    # we add markdown because it is not added by the mimetypes module
    # see https://github.com/python/cpython/pull/17995
    ".md": "text/markdown",
    ".markdown": "text/markdown",
    # we add msg because it is not added by the mimetypes module
    ".msg": "application/vnd.ms-outlook",
}

logger = logging.getLogger(__name__)


def expand_page_range(page_range: list[str | int]) -> list[int]:
    """
    Takes a list of page numbers and ranges and expands them into a list of page numbers.

    For example, given a page_range=['1-3', '5', '8', '10-12'] the function will return [1, 2, 3, 5, 8, 10, 11, 12]

    :param page_range: List of page numbers and ranges
    :returns:
        An expanded list of page integers
    :raises ValueError: If an entry is neither a non-negative integer, a digit string, nor a valid
        'start-end' range string, or if the resulting list is empty.
    """
    expanded_page_range: list[int] = []

    for page in page_range:
        if isinstance(page, int):
            # check if it's a range wrongly passed as an integer expression
            # (str() of a negative int contains "-", so negative pages are rejected here)
            if "-" in str(page):
                msg = "range must be a string in the format 'start-end'"
                raise ValueError(f"Invalid page range: {page} - {msg}")
            expanded_page_range.append(page)

        elif isinstance(page, str) and page.isdigit():
            expanded_page_range.append(int(page))

        elif isinstance(page, str) and "-" in page:
            # a malformed range ('1-2-3', 'a-b') raises the same consistent
            # error as other invalid inputs instead of a raw unpack/int() error
            try:
                start, end = page.split("-")
                expanded_page_range.extend(range(int(start), int(end) + 1))
            except ValueError:
                msg = "range must be a string in the format 'start-end'"
                raise ValueError(f"Invalid page range: {page} - {msg}") from None

        else:
            msg = "range must be a string in the format 'start-end' or an integer"
            raise ValueError(f"Invalid page range: {page} - {msg}")

    if not expanded_page_range:
        raise ValueError("No valid page numbers or ranges found in the input list")

    return expanded_page_range


@overload
def expit(x: float) -> float: ...
@overload
def expit(x: ndarray[Any, Any]) -> ndarray[Any, Any]: ...
def expit(x: float | ndarray[Any, Any]) -> float | ndarray[Any, Any]:
    """
    Compute logistic sigmoid function. Maps input values to a range between 0 and 1

    :param x: input value. Can be a scalar or a numpy array.
    """
    return 1 / (1 + exp(-x))


def _guess_mime_type(path: Path) -> str | None:
    """
    Guess the MIME type of the provided file path.

    :param path: The file path to get the MIME type for.

    :returns: The MIME type of the provided file path, or `None` if the MIME type cannot be determined.
    """
    extension = path.suffix.lower()
    mime_type = mimetypes.guess_type(path.as_posix())[0]
    # lookup custom mappings if the mime type is not found
    return CUSTOM_MIMETYPES.get(extension, mime_type)


def _get_output_dir(out_dir: str) -> str:
    """
    Find or create a writable directory for saving status files.

    Tries in the following order:

    1. ~/.haystack/{out_dir}
    2. {tempdir}/haystack/{out_dir}
    3. ./.haystack/{out_dir}

    :raises RuntimeError: If no directory could be created.
    :returns:
        The path to the created directory.
    """

    candidates = [
        Path.home() / ".haystack" / out_dir,
        Path(tempfile.gettempdir()) / "haystack" / out_dir,
        Path.cwd() / ".haystack" / out_dir,
    ]

    for candidate in candidates:
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            return str(candidate)
        except Exception:
            # deliberately best-effort: fall through to the next candidate location
            continue

    raise RuntimeError(
        f"Could not create a writable directory for output files in any of the following locations: {candidates}"
    )


def _deduplicate_documents(documents: list["Document"]) -> list["Document"]:
    """
    Deduplicate a list of documents by their id keeping the duplicate with the highest score if a score is present.

    :param documents: List of documents to deduplicate.
    :returns: List of deduplicated documents.
    """
    # Keep for each Document id the one with the highest score
    # (documents without a score compare as -inf, so any scored duplicate wins)
    highest_scoring_docs: dict[str, "Document"] = {}
    for doc in documents:
        score = doc.score if doc.score is not None else -inf
        best = highest_scoring_docs.get(doc.id)

        if best is None or score > (best.score if best.score is not None else -inf):
            highest_scoring_docs[doc.id] = doc

    return list(highest_scoring_docs.values())


@overload
def _parse_dict_from_json(
    text: str, expected_keys: list[str] | None = ..., raise_on_failure: Literal[True] = ...
) -> dict[str, Any]: ...
@overload
def _parse_dict_from_json(
    text: str, expected_keys: list[str] | None = ..., raise_on_failure: Literal[False] = ...
) -> dict[str, Any] | None: ...
@overload
def _parse_dict_from_json(
    text: str, expected_keys: list[str] | None = ..., raise_on_failure: bool = ...
) -> dict[str, Any] | None: ...
def _parse_dict_from_json(
    text: str, expected_keys: list[str] | None = None, raise_on_failure: bool = True
) -> dict[str, Any] | None:
    """
    Parses a JSON string containing a dictionary.

    :param text: The string to parse.
    :param expected_keys: A list of keys that must be present in the parsed dictionary.
    :param raise_on_failure: If True, raises an exception on failure. If False, logs a warning and returns None.

    :return: The parsed dictionary, or None if parsing fails and raise_on_failure is False.
    :raises json.JSONDecodeError: If the text is not valid JSON and raise_on_failure is True.
    :raises ValueError: If the parsed object is not a dictionary or has missing expected keys,
        and `raise_on_failure` is True.
    """
    cleaned_text = text.strip()

    try:
        parsed_json = json.loads(cleaned_text)
    except json.JSONDecodeError as e:
        if raise_on_failure:
            raise e
        logger.warning("Failed to parse JSON from text: {text}. Error: {error}", text=text, error=e)
        return None

    if not isinstance(parsed_json, dict):
        if raise_on_failure:
            raise ValueError(f"Expected a JSON object containing a dictionary but got {type(parsed_json).__name__}")
        logger.warning(
            "Expected a JSON object containing a dictionary but got {type}. Returning None",
            type=type(parsed_json).__name__,
        )
        return None

    if not expected_keys:
        return parsed_json

    missing_keys = [key for key in expected_keys if key not in parsed_json]
    if missing_keys:
        if raise_on_failure:
            raise ValueError(f"Missing expected keys in JSON: {missing_keys}. Got keys: {list(parsed_json.keys())}")
        logger.warning(
            "Missing expected keys in JSON: {missing_keys}. Got keys: {keys}",
            missing_keys=missing_keys,
            keys=list(parsed_json.keys()),
        )
        return None

    return parsed_json


def _normalize_metadata_field_name(metadata_field: str) -> str:
    """
    Normalizes a metadata field name by removing the "meta." prefix if present.
    """
    return metadata_field[5:] if metadata_field.startswith("meta.") else metadata_field