# haystack/utils/misc.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import json
  6  import mimetypes
  7  import tempfile
  8  from math import inf
  9  from pathlib import Path
 10  from typing import TYPE_CHECKING, Any, Literal, overload
 11  
 12  from numpy import exp, ndarray
 13  
 14  from haystack import logging
 15  
 16  if TYPE_CHECKING:
 17      from haystack.dataclasses import Document
 18  
# File-extension -> MIME type mappings that the stdlib `mimetypes` registry lacks;
# consulted (with precedence over the stdlib guess) by `_guess_mime_type` below.
CUSTOM_MIMETYPES = {
    # we add markdown because it is not added by the mimetypes module
    # see https://github.com/python/cpython/pull/17995
    ".md": "text/markdown",
    ".markdown": "text/markdown",
    # we add msg because it is not added by the mimetypes module
    ".msg": "application/vnd.ms-outlook",
}

# Module-level logger, following the project-wide haystack.logging convention.
logger = logging.getLogger(__name__)
 29  
 30  
 31  def expand_page_range(page_range: list[str | int]) -> list[int]:
 32      """
 33      Takes a list of page numbers and ranges and expands them into a list of page numbers.
 34  
 35      For example, given a page_range=['1-3', '5', '8', '10-12'] the function will return [1, 2, 3, 5, 8, 10, 11, 12]
 36  
 37      :param page_range: List of page numbers and ranges
 38      :returns:
 39          An expanded list of page integers
 40  
 41      """
 42      expanded_page_range = []
 43  
 44      for page in page_range:
 45          if isinstance(page, int):
 46              # check if it's a range wrongly passed as an integer expression
 47              if "-" in str(page):
 48                  msg = "range must be a string in the format 'start-end'"
 49                  raise ValueError(f"Invalid page range: {page} - {msg}")
 50              expanded_page_range.append(page)
 51  
 52          elif isinstance(page, str) and page.isdigit():
 53              expanded_page_range.append(int(page))
 54  
 55          elif isinstance(page, str) and "-" in page:
 56              start, end = page.split("-")
 57              expanded_page_range.extend(range(int(start), int(end) + 1))
 58  
 59          else:
 60              msg = "range must be a string in the format 'start-end' or an integer"
 61              raise ValueError(f"Invalid page range: {page} - {msg}")
 62  
 63      if not expanded_page_range:
 64          raise ValueError("No valid page numbers or ranges found in the input list")
 65  
 66      return expanded_page_range
 67  
 68  
 69  @overload
 70  def expit(x: float) -> float: ...
 71  @overload
 72  def expit(x: ndarray[Any, Any]) -> ndarray[Any, Any]: ...
 73  def expit(x: float | ndarray[Any, Any]) -> float | ndarray[Any, Any]:
 74      """
 75      Compute logistic sigmoid function. Maps input values to a range between 0 and 1
 76  
 77      :param x: input value. Can be a scalar or a numpy array.
 78      """
 79      return 1 / (1 + exp(-x))
 80  
 81  
 82  def _guess_mime_type(path: Path) -> str | None:
 83      """
 84      Guess the MIME type of the provided file path.
 85  
 86      :param path: The file path to get the MIME type for.
 87  
 88      :returns: The MIME type of the provided file path, or `None` if the MIME type cannot be determined.
 89      """
 90      extension = path.suffix.lower()
 91      mime_type = mimetypes.guess_type(path.as_posix())[0]
 92      # lookup custom mappings if the mime type is not found
 93      return CUSTOM_MIMETYPES.get(extension, mime_type)
 94  
 95  
 96  def _get_output_dir(out_dir: str) -> str:
 97      """
 98      Find or create a writable directory for saving status files.
 99  
100      Tries in the following order:
101  
102          1. ~/.haystack/{out_dir}
103          2. {tempdir}/haystack/{out_dir}
104          3. ./.haystack/{out_dir}
105  
106      :raises RuntimeError: If no directory could be created.
107      :returns:
108          The path to the created directory.
109      """
110  
111      candidates = [
112          Path.home() / ".haystack" / out_dir,
113          Path(tempfile.gettempdir()) / "haystack" / out_dir,
114          Path.cwd() / ".haystack" / out_dir,
115      ]
116  
117      for candidate in candidates:
118          try:
119              candidate.mkdir(parents=True, exist_ok=True)
120              return str(candidate)
121          except Exception:
122              continue
123  
124      raise RuntimeError(
125          f"Could not create a writable directory for output files in any of the following locations: {candidates}"
126      )
127  
128  
129  def _deduplicate_documents(documents: list["Document"]) -> list["Document"]:
130      """
131      Deduplicate a list of documents by their id keeping the duplicate with the highest score if a score is present.
132  
133      :param documents: List of documents to deduplicate.
134      :returns: List of deduplicated documents.
135      """
136      # Keep for each Document id the one with the highest score
137      highest_scoring_docs: dict[str, "Document"] = {}
138      for doc in documents:
139          score = doc.score if doc.score is not None else -inf
140          best = highest_scoring_docs.get(doc.id)
141  
142          if best is None or score > (best.score if best.score is not None else -inf):
143              highest_scoring_docs[doc.id] = doc
144  
145      return list(highest_scoring_docs.values())
146  
147  
148  @overload
149  def _parse_dict_from_json(
150      text: str, expected_keys: list[str] | None = ..., raise_on_failure: Literal[True] = ...
151  ) -> dict[str, Any]: ...
152  @overload
153  def _parse_dict_from_json(
154      text: str, expected_keys: list[str] | None = ..., raise_on_failure: Literal[False] = ...
155  ) -> dict[str, Any] | None: ...
156  @overload
157  def _parse_dict_from_json(
158      text: str, expected_keys: list[str] | None = ..., raise_on_failure: bool = ...
159  ) -> dict[str, Any] | None: ...
160  def _parse_dict_from_json(
161      text: str, expected_keys: list[str] | None = None, raise_on_failure: bool = True
162  ) -> dict[str, Any] | None:
163      """
164      Parses a JSON string containing a dictionary.
165  
166      :param text: The string to parse.
167      :param expected_keys: A list of keys that must be present in the parsed dictionary.
168      :param raise_on_failure: If True, raises an exception on failure. If False, logs a warning and returns None.
169  
170      :return: The parsed dictionary, or None if parsing fails and raise_on_failure is False.
171      :raises json.JSONDecodeError: If the text is not valid JSON and raise_on_failure is True.
172      :raises ValueError: If the parsed object is not a dictionary or has missing expected keys,
173          and `raise_on_failure` is True.
174      """
175      cleaned_text = text.strip()
176  
177      try:
178          parsed_json = json.loads(cleaned_text)
179      except json.JSONDecodeError as e:
180          if raise_on_failure:
181              raise e
182          logger.warning("Failed to parse JSON from text: {text}. Error: {error}", text=text, error=e)
183          return None
184  
185      if not isinstance(parsed_json, dict):
186          if raise_on_failure:
187              raise ValueError(f"Expected a JSON object containing a dictionary but got {type(parsed_json).__name__}")
188          logger.warning(
189              "Expected a JSON object containing a dictionary but got {type}. Returning None",
190              type=type(parsed_json).__name__,
191          )
192          return None
193  
194      if not expected_keys:
195          return parsed_json
196  
197      missing_keys = [key for key in expected_keys if key not in parsed_json]
198      if missing_keys:
199          if raise_on_failure:
200              raise ValueError(f"Missing expected keys in JSON: {missing_keys}. Got keys: {list(parsed_json.keys())}")
201          logger.warning(
202              "Missing expected keys in JSON: {missing_keys}. Got keys: {keys}",
203              missing_keys=missing_keys,
204              keys=list(parsed_json.keys()),
205          )
206          return None
207  
208      return parsed_json
209  
210  
211  def _normalize_metadata_field_name(metadata_field: str) -> str:
212      """
213      Normalizes a metadata field name by removing the "meta." prefix if present.
214      """
215      return metadata_field[5:] if metadata_field.startswith("meta.") else metadata_field