/ haystack / utils / auth.py
auth.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import os
  6  from abc import ABC, abstractmethod
  7  from collections.abc import Iterable
  8  from dataclasses import dataclass
  9  from enum import Enum
 10  from typing import Any
 11  
 12  
 13  class SecretType(Enum):
 14      """
 15      Type of secret: token (API key) or environment variable.
 16      """
 17  
 18      TOKEN = "token"
 19      ENV_VAR = "env_var"
 20  
 21      def __str__(self) -> str:
 22          return self.value
 23  
 24      @staticmethod
 25      def from_str(string: str) -> "SecretType":
 26          """
 27          Convert a string to a SecretType.
 28  
 29          :param string: The string to convert.
 30          """
 31          mapping = {e.value: e for e in SecretType}
 32          _type = mapping.get(string)
 33          if _type is None:
 34              raise ValueError(f"Unknown secret type '{string}'")
 35          return _type
 36  
 37  
 38  class Secret(ABC):
 39      """
 40      Encapsulates a secret used for authentication.
 41  
 42      Usage example:
 43      ```python
 44      from haystack.components.generators import OpenAIGenerator
 45      from haystack.utils import Secret
 46  
 47      generator = OpenAIGenerator(api_key=Secret.from_token("<here_goes_your_token>"))
 48      ```
 49      """
 50  
 51      @staticmethod
 52      def from_token(token: str) -> "Secret":
 53          """
 54          Create a token-based secret. Cannot be serialized.
 55  
 56          :param token:
 57              The token to use for authentication.
 58          """
 59          return TokenSecret(_token=token)
 60  
 61      @staticmethod
 62      def from_env_var(env_vars: str | list[str], *, strict: bool = True) -> "Secret":
 63          """
 64          Create an environment variable-based secret. Accepts one or more environment variables.
 65  
 66          Upon resolution, it returns a string token from the first environment variable that is set.
 67  
 68          :param env_vars:
 69              A single environment variable or an ordered list of
 70              candidate environment variables.
 71          :param strict:
 72              Whether to raise an exception if none of the environment
 73              variables are set.
 74          """
 75          if isinstance(env_vars, str):
 76              env_vars = [env_vars]
 77          return EnvVarSecret(_env_vars=tuple(env_vars), _strict=strict)
 78  
 79      def to_dict(self) -> dict[str, Any]:
 80          """
 81          Convert the secret to a JSON-serializable dictionary.
 82  
 83          Some secrets may not be serializable.
 84  
 85          :returns:
 86              The serialized policy.
 87          """
 88          out = {"type": self.type.value}
 89          inner = self._to_dict()
 90          assert all(k not in inner for k in out)
 91          out.update(inner)
 92          return out
 93  
 94      @staticmethod
 95      def from_dict(dict: dict[str, Any]) -> "Secret":  # noqa:A002
 96          """
 97          Create a secret from a JSON-serializable dictionary.
 98  
 99          :param dict:
100              The dictionary with the serialized data.
101          :returns:
102              The deserialized secret.
103          """
104          secret_map = {SecretType.TOKEN: TokenSecret, SecretType.ENV_VAR: EnvVarSecret}
105          secret_type = SecretType.from_str(dict["type"])
106          return secret_map[secret_type]._from_dict(dict)  # type: ignore
107  
108      @abstractmethod
109      def resolve_value(self) -> Any | None:
110          """
111          Resolve the secret to an atomic value. The semantics of the value is secret-dependent.
112  
113          :returns:
114              The value of the secret, if any.
115          """
116          pass
117  
118      @property
119      @abstractmethod
120      def type(self) -> SecretType:
121          """
122          The type of the secret.
123          """
124          pass
125  
126      @abstractmethod
127      def _to_dict(self) -> dict[str, Any]:
128          pass
129  
130      @staticmethod
131      @abstractmethod
132      def _from_dict(_: dict[str, Any]) -> "Secret":
133          pass
134  
135  
@dataclass(frozen=True)
class TokenSecret(Secret):
    """
    A secret that uses a string token/API key.

    Cannot be serialized. The token is masked in `repr()` so that it does not
    end up in logs, tracebacks or debugger output; use `resolve_value()` to read it.
    """

    _token: str
    _type: SecretType = SecretType.TOKEN

    def __post_init__(self) -> None:
        """
        Validate the freshly created instance.

        :raises ValueError: If the token is empty.
        """
        super().__init__()
        assert self._type == SecretType.TOKEN

        if len(self._token) == 0:
            raise ValueError("Authentication token cannot be empty.")

    def __repr__(self) -> str:
        """Return a representation that never contains the raw token."""
        # `dataclass` keeps a user-defined `__repr__`; the generated one would print `_token` verbatim.
        return f"{self.__class__.__name__}(_token='***', _type={self._type!r})"

    def _to_dict(self) -> dict[str, Any]:
        """Always raise: writing a raw token into serialized data would expose it."""
        raise ValueError(
            "Cannot serialize token-based secret. Use an alternative secret type like environment variables."
        )

    @staticmethod
    def _from_dict(_: dict[str, Any]) -> "Secret":
        """Always raise: token-based secrets are never read back from serialized data."""
        raise ValueError(
            "Cannot deserialize token-based secret. Use an alternative secret type like environment variables."
        )

    def resolve_value(self) -> Any | None:
        """Return the token."""
        return self._token

    @property
    def type(self) -> SecretType:
        """The type of the secret."""
        return self._type
173  
174  
@dataclass(frozen=True)
class EnvVarSecret(Secret):
    """
    A secret backed by one or more environment variables.

    Resolution walks the variables in order and returns the value of the first one that is set.
    Can be serialized.
    """

    _env_vars: tuple[str, ...]
    _strict: bool = True
    _type: SecretType = SecretType.ENV_VAR

    def __post_init__(self) -> None:
        """
        Validate the freshly created instance.

        :raises ValueError: If no environment variable name was given.
        """
        super().__init__()
        assert self._type == SecretType.ENV_VAR

        if not len(self._env_vars):
            raise ValueError("One or more environment variables must be provided for the secret.")

    def _to_dict(self) -> dict[str, Any]:
        """Return the variable names (as a list) and the strict flag."""
        return dict(env_vars=[*self._env_vars], strict=self._strict)

    @staticmethod
    def _from_dict(dictionary: dict[str, Any]) -> "Secret":
        """Rebuild the secret from the "env_vars" and "strict" entries of the dictionary."""
        names = tuple(dictionary["env_vars"])
        return EnvVarSecret(names, _strict=dictionary["strict"])

    def resolve_value(self) -> Any | None:
        """
        Return the value of the first environment variable that is set.

        :returns:
            The value found, or None when no variable is set and the secret is not strict.
        :raises ValueError:
            If no variable is set and the secret is strict.
        """
        for name in self._env_vars:
            found = os.getenv(name)
            if found is not None:
                return found
        if self._strict:
            raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}")
        return None

    @property
    def type(self) -> SecretType:
        """The type of the secret."""
        return self._type
217  
218  
def deserialize_secrets_inplace(data: dict[str, Any], keys: Iterable[str], *, recursive: bool = False) -> None:
    """
    Deserialize secrets in a dictionary inplace.

    Every non-None value stored under one of `keys` is replaced by `Secret.from_dict(value)`.
    With `recursive=True`, dictionary values stored under any other key are searched
    the same way, at any nesting depth.

    :param data:
        The dictionary with the serialized data. It is modified in place.
    :param keys:
        The keys of the secrets to deserialize.
    :param recursive:
        Whether to recursively deserialize nested dictionaries.
    """
    for k, v in data.items():
        # Test the key first: a serialized secret is itself a dictionary, so it has to be
        # deserialized here rather than descended into.
        if k in keys and v is not None:
            data[k] = Secret.from_dict(v)
        elif recursive and isinstance(v, dict):
            # Forward the flag so secrets nested more than one level deep are reached too.
            deserialize_secrets_inplace(v, keys, recursive=recursive)