# auth.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 import os 6 from abc import ABC, abstractmethod 7 from collections.abc import Iterable 8 from dataclasses import dataclass 9 from enum import Enum 10 from typing import Any 11 12 13 class SecretType(Enum): 14 """ 15 Type of secret: token (API key) or environment variable. 16 """ 17 18 TOKEN = "token" 19 ENV_VAR = "env_var" 20 21 def __str__(self) -> str: 22 return self.value 23 24 @staticmethod 25 def from_str(string: str) -> "SecretType": 26 """ 27 Convert a string to a SecretType. 28 29 :param string: The string to convert. 30 """ 31 mapping = {e.value: e for e in SecretType} 32 _type = mapping.get(string) 33 if _type is None: 34 raise ValueError(f"Unknown secret type '{string}'") 35 return _type 36 37 38 class Secret(ABC): 39 """ 40 Encapsulates a secret used for authentication. 41 42 Usage example: 43 ```python 44 from haystack.components.generators import OpenAIGenerator 45 from haystack.utils import Secret 46 47 generator = OpenAIGenerator(api_key=Secret.from_token("<here_goes_your_token>")) 48 ``` 49 """ 50 51 @staticmethod 52 def from_token(token: str) -> "Secret": 53 """ 54 Create a token-based secret. Cannot be serialized. 55 56 :param token: 57 The token to use for authentication. 58 """ 59 return TokenSecret(_token=token) 60 61 @staticmethod 62 def from_env_var(env_vars: str | list[str], *, strict: bool = True) -> "Secret": 63 """ 64 Create an environment variable-based secret. Accepts one or more environment variables. 65 66 Upon resolution, it returns a string token from the first environment variable that is set. 67 68 :param env_vars: 69 A single environment variable or an ordered list of 70 candidate environment variables. 71 :param strict: 72 Whether to raise an exception if none of the environment 73 variables are set. 
74 """ 75 if isinstance(env_vars, str): 76 env_vars = [env_vars] 77 return EnvVarSecret(_env_vars=tuple(env_vars), _strict=strict) 78 79 def to_dict(self) -> dict[str, Any]: 80 """ 81 Convert the secret to a JSON-serializable dictionary. 82 83 Some secrets may not be serializable. 84 85 :returns: 86 The serialized policy. 87 """ 88 out = {"type": self.type.value} 89 inner = self._to_dict() 90 assert all(k not in inner for k in out) 91 out.update(inner) 92 return out 93 94 @staticmethod 95 def from_dict(dict: dict[str, Any]) -> "Secret": # noqa:A002 96 """ 97 Create a secret from a JSON-serializable dictionary. 98 99 :param dict: 100 The dictionary with the serialized data. 101 :returns: 102 The deserialized secret. 103 """ 104 secret_map = {SecretType.TOKEN: TokenSecret, SecretType.ENV_VAR: EnvVarSecret} 105 secret_type = SecretType.from_str(dict["type"]) 106 return secret_map[secret_type]._from_dict(dict) # type: ignore 107 108 @abstractmethod 109 def resolve_value(self) -> Any | None: 110 """ 111 Resolve the secret to an atomic value. The semantics of the value is secret-dependent. 112 113 :returns: 114 The value of the secret, if any. 115 """ 116 pass 117 118 @property 119 @abstractmethod 120 def type(self) -> SecretType: 121 """ 122 The type of the secret. 123 """ 124 pass 125 126 @abstractmethod 127 def _to_dict(self) -> dict[str, Any]: 128 pass 129 130 @staticmethod 131 @abstractmethod 132 def _from_dict(_: dict[str, Any]) -> "Secret": 133 pass 134 135 136 @dataclass(frozen=True) 137 class TokenSecret(Secret): 138 """ 139 A secret that uses a string token/API key. 140 141 Cannot be serialized. 
142 """ 143 144 _token: str 145 _type: SecretType = SecretType.TOKEN 146 147 def __post_init__(self) -> None: 148 super().__init__() 149 assert self._type == SecretType.TOKEN 150 151 if len(self._token) == 0: 152 raise ValueError("Authentication token cannot be empty.") 153 154 def _to_dict(self) -> dict[str, Any]: 155 raise ValueError( 156 "Cannot serialize token-based secret. Use an alternative secret type like environment variables." 157 ) 158 159 @staticmethod 160 def _from_dict(_: dict[str, Any]) -> "Secret": 161 raise ValueError( 162 "Cannot deserialize token-based secret. Use an alternative secret type like environment variables." 163 ) 164 165 def resolve_value(self) -> Any | None: 166 """Return the token.""" 167 return self._token 168 169 @property 170 def type(self) -> SecretType: 171 """The type of the secret.""" 172 return self._type 173 174 175 @dataclass(frozen=True) 176 class EnvVarSecret(Secret): 177 """ 178 A secret that accepts one or more environment variables. 179 180 Upon resolution, it returns a string token from the first environment variable that is set. Can be serialized. 181 """ 182 183 _env_vars: tuple[str, ...] 184 _strict: bool = True 185 _type: SecretType = SecretType.ENV_VAR 186 187 def __post_init__(self) -> None: 188 super().__init__() 189 assert self._type == SecretType.ENV_VAR 190 191 if len(self._env_vars) == 0: 192 raise ValueError("One or more environment variables must be provided for the secret.") 193 194 def _to_dict(self) -> dict[str, Any]: 195 return {"env_vars": list(self._env_vars), "strict": self._strict} 196 197 @staticmethod 198 def _from_dict(dictionary: dict[str, Any]) -> "Secret": 199 return EnvVarSecret(tuple(dictionary["env_vars"]), _strict=dictionary["strict"]) 200 201 def resolve_value(self) -> Any | None: 202 """Resolve the secret to an atomic value. 
The semantics of the value is secret-dependent.""" 203 out = None 204 for env_var in self._env_vars: 205 value = os.getenv(env_var) 206 if value is not None: 207 out = value 208 break 209 if out is None and self._strict: 210 raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}") 211 return out 212 213 @property 214 def type(self) -> SecretType: 215 """The type of the secret.""" 216 return self._type 217 218 219 def deserialize_secrets_inplace(data: dict[str, Any], keys: Iterable[str], *, recursive: bool = False) -> None: 220 """ 221 Deserialize secrets in a dictionary inplace. 222 223 :param data: 224 The dictionary with the serialized data. 225 :param keys: 226 The keys of the secrets to deserialize. 227 :param recursive: 228 Whether to recursively deserialize nested dictionaries. 229 """ 230 for k, v in data.items(): 231 if isinstance(v, dict) and recursive: 232 deserialize_secrets_inplace(v, keys) 233 elif k in keys and v is not None: 234 data[k] = Secret.from_dict(v)