# ragaai_catalyst/prompt_manager.py
"""Client helpers for the RagaAI Catalyst prompt playground API (``/playground/prompt``)."""
import copy
import json
import os
import re
from urllib.parse import quote

import requests

from .ragaai_catalyst import RagaAICatalyst


class PromptManager:
    """Access the prompts stored in a single RagaAI Catalyst project."""

    NUM_PROJECTS = 100  # NOTE(review): not referenced anywhere in this module.
    TIMEOUT = 10

    def __init__(self, project_name):
        """
        Initialize the PromptManager with a project name.

        Resolves ``project_name`` to its project id by fetching the project
        list, then builds the authenticated headers used by every later request.

        Args:
            project_name (str): The name of the project.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If the project list cannot be parsed or the project is not found.
        """
        self.project_name = project_name
        self.base_url = f"{RagaAICatalyst.BASE_URL}/playground/prompt"
        self.timeout = self.TIMEOUT
        self.size = 99999  # Number of projects to fetch

        response = requests.get(
            f"{RagaAICatalyst.BASE_URL}/v2/llm/projects?size={self.size}",
            headers={
                "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
            },
            timeout=self.timeout,
        )
        response.raise_for_status()

        try:
            # Parse the body once; collect the ids of projects with a matching name.
            projects = response.json()["data"]["content"]
            matching_ids = [
                project["id"] for project in projects if project["name"] == project_name
            ]
        except (KeyError, TypeError, ValueError) as e:
            # ValueError covers JSON decode errors: json.JSONDecodeError and
            # requests' JSONDecodeError both subclass it.
            raise ValueError(f"Error parsing project list: {str(e)}") from e

        # Checked before indexing so an unknown project raises the documented
        # ValueError rather than an IndexError.
        if not matching_ids:
            raise ValueError("Project not found. Please enter a valid project name")
        self.project_id = matching_ids[0]  # first match, as before

        self.headers = {
            "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
            "X-Project-Id": str(self.project_id),
        }

    def _ensure_prompt_exists(self, prompt_name):
        """
        Raise unless ``prompt_name`` is one of this project's prompts.

        Raises:
            ValueError: If the prompt is not found or the prompt list cannot be parsed.
            requests.RequestException: If the prompt list cannot be fetched.
        """
        try:
            prompt_list = self.list_prompts()
        except requests.RequestException as e:
            raise requests.RequestException(f"Error fetching prompt list: {str(e)}") from e

        if prompt_name not in prompt_list:
            raise ValueError("Prompt not found. Please enter a valid prompt name")

    def list_prompts(self):
        """
        List all available prompts.

        Returns:
            list: A list of prompt names.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If the prompt list cannot be parsed.
        """
        prompt = Prompt()
        try:
            return prompt.list_prompts(self.base_url, self.headers, self.timeout)
        except requests.RequestException as e:
            raise requests.RequestException(f"Error listing prompts: {str(e)}") from e

    def get_prompt(self, prompt_name, version=None):
        """
        Get a specific prompt.

        Args:
            prompt_name (str): The name of the prompt.
            version (str, optional): The version of the prompt. Defaults to None,
                which fetches the version the API returns when none is requested.

        Returns:
            PromptObject: An object representing the prompt.

        Raises:
            ValueError: If the prompt or version is not found, or a response cannot be parsed.
            requests.RequestException: If there's an error with the API request.
        """
        self._ensure_prompt_exists(prompt_name)

        prompt = Prompt()
        if version:
            # Only version names are needed to validate ``version``;
            # list_prompt_versions would also download every version's text.
            try:
                version_names = prompt._list_version_names(
                    self.base_url, self.headers, self.timeout, prompt_name
                )
            except requests.RequestException as e:
                raise requests.RequestException(f"Error fetching prompt versions: {str(e)}") from e
            if version not in version_names:
                raise ValueError("Version not found. Please enter a valid version name")

        try:
            return prompt.get_prompt(self.base_url, self.headers, self.timeout, prompt_name, version)
        except requests.RequestException as e:
            raise requests.RequestException(f"Error fetching prompt: {str(e)}") from e

    def list_prompt_versions(self, prompt_name):
        """
        List all versions of a specific prompt.

        Args:
            prompt_name (str): The name of the prompt.

        Returns:
            dict: A dictionary mapping version names to prompt texts.

        Raises:
            ValueError: If the prompt is not found or a response cannot be parsed.
            requests.RequestException: If there's an error with the API request.
        """
        self._ensure_prompt_exists(prompt_name)

        prompt = Prompt()
        try:
            return prompt.list_prompt_versions(self.base_url, self.headers, self.timeout, prompt_name)
        except requests.RequestException as e:
            raise requests.RequestException(f"Error fetching prompt versions: {str(e)}") from e


class Prompt:
    """
    Stateless helpers for the prompt playground endpoints.

    Every method receives the base URL, headers and timeout explicitly.
    """

    def __init__(self):
        """
        Initialize the Prompt class.
        """
        pass

    def _get(self, url, headers, timeout, error_prefix, params=None):
        """
        Issue a GET request and fail on error status codes.

        Args:
            url (str): The full request URL.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request, in seconds.
            error_prefix (str): Message prefix for a re-raised request error.
            params (dict, optional): Query parameters. requests URL-encodes them
                and omits keys whose value is None.

        Returns:
            requests.Response: The successful response.

        Raises:
            requests.RequestException: If the request fails or returns an error status.
        """
        try:
            response = requests.get(url, headers=headers, params=params, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            raise requests.RequestException(f"{error_prefix}: {str(e)}") from e
        return response

    def _version_url(self, base_url, prompt_name):
        """Build a prompt's ``/version/<name>`` URL with the name percent-encoded."""
        # safe="" also encodes "/", "?", "#" and "&" so the name stays one path segment.
        return f"{base_url}/version/{quote(str(prompt_name), safe='')}"

    def _parse_first_doc(self, response, extract):
        """
        Apply ``extract`` to the first entry of ``data.docs`` in ``response``.

        Raises:
            ValueError: If the body is not JSON or lacks the expected fields.
        """
        try:
            return extract(response.json()["data"]["docs"][0])
        except (KeyError, IndexError, TypeError, ValueError) as e:
            raise ValueError(f"Error parsing prompt version: {str(e)}") from e

    def list_prompts(self, url, headers, timeout):
        """
        List all available prompts.

        Args:
            url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.

        Returns:
            list: A list of prompt names.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If there's an error parsing the prompt list.
        """
        response = self._get(url, headers, timeout, "Error listing prompts")
        # Parsed outside the network try-block so JSON errors surface as ValueError.
        try:
            return [prompt["name"] for prompt in response.json()["data"]]
        except (KeyError, TypeError, ValueError) as e:
            raise ValueError(f"Error parsing prompt list: {str(e)}") from e

    def _get_response_by_version(self, base_url, headers, timeout, prompt_name, version):
        """
        Get a specific version of a prompt.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.
            version (str): The version of the prompt.

        Returns:
            response: The response object containing the prompt version data.

        Raises:
            requests.RequestException: If there's an error with the API request.
        """
        return self._get(
            self._version_url(base_url, prompt_name),
            headers,
            timeout,
            "Error fetching prompt version",
            params={"version": version},
        )

    def _get_response(self, base_url, headers, timeout, prompt_name):
        """
        Get the prompt version the API returns when no version is requested.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.

        Returns:
            response: The response object containing the prompt version data.

        Raises:
            requests.RequestException: If there's an error with the API request.
        """
        return self._get(
            self._version_url(base_url, prompt_name),
            headers,
            timeout,
            "Error fetching prompt version",
        )

    def _get_prompt_by_version(self, base_url, headers, timeout, prompt_name, version):
        """
        Get the text of a specific version of a prompt.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.
            version (str): The version of the prompt.

        Returns:
            The ``textFields`` value of the first returned doc.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If the response cannot be parsed.
        """
        response = self._get_response_by_version(base_url, headers, timeout, prompt_name, version)
        return self._parse_first_doc(response, lambda doc: doc["textFields"])

    def get_prompt(self, base_url, headers, timeout, prompt_name, version=None):
        """
        Get a prompt, optionally specifying a version.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.
            version (str, optional): The version of the prompt. Defaults to None.

        Returns:
            PromptObject: An object representing the prompt.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If the response cannot be parsed.
        """
        if version:
            response = self._get_response_by_version(base_url, headers, timeout, prompt_name, version)
        else:
            response = self._get_response(base_url, headers, timeout, prompt_name)

        text, parameters, model = self._parse_first_doc(
            response,
            lambda doc: (
                doc["textFields"],
                doc["modelSpecs"]["parameters"],
                doc["modelSpecs"]["model"],
            ),
        )
        return PromptObject(text, parameters, model)

    def _list_version_names(self, base_url, headers, timeout, prompt_name):
        """
        List the version names of a prompt without fetching their contents.

        Returns:
            list: The version names, in API order.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If the response cannot be parsed.
        """
        response = self._get(
            f"{base_url}/{quote(str(prompt_name), safe='')}/version",
            headers,
            timeout,
            "Error listing prompt versions",
        )
        try:
            return [version["name"] for version in response.json()["data"]]
        except (KeyError, TypeError, ValueError) as e:
            raise ValueError(f"Error parsing prompt versions: {str(e)}") from e

    def list_prompt_versions(self, base_url, headers, timeout, prompt_name):
        """
        List all versions of a specific prompt.

        Issues one extra request per version to fetch its text.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.

        Returns:
            dict: A dictionary mapping version names to prompt texts.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If there's an error parsing the prompt versions.
        """
        version_names = self._list_version_names(base_url, headers, timeout, prompt_name)
        try:
            return {
                name: self._get_prompt_by_version(base_url, headers, timeout, prompt_name, name)
                for name in version_names
            }
        except requests.RequestException as e:
            raise requests.RequestException(f"Error listing prompt versions: {str(e)}") from e


class PromptObject:
    """A fetched prompt: its messages, model parameters and model."""

    # Matches {{ ... }} placeholders; group 1 is the raw text between the braces.
    _VARIABLE_PATTERN = re.compile(r'\{\{(.*?)\}\}')

    def __init__(self, text, parameters, model):
        """
        Initialize a PromptObject.

        Args:
            text (list[dict]): The prompt messages. Each item is a dict whose
                "content" string may contain {{variable}} placeholders.
            parameters (list[dict]): The model parameters. Each item has a "name",
                and optionally a "value" and a "type".
            model (str): The model of the prompt.
        """
        self.text = text
        self.parameters = parameters
        self.model = model

    def _extract_variable_from_content(self, content):
        """
        Extract variables from the content.

        Args:
            content (str): The content containing variables.

        Returns:
            list: Placeholder names, whitespace-stripped, in order of appearance
            (duplicates kept). Placeholders containing a double quote are skipped.
        """
        matches = self._VARIABLE_PATTERN.findall(content)
        return [match.strip() for match in matches if '"' not in match]

    def _add_variable_value_to_content(self, content, user_variables):
        """
        Add variable values to the content.

        Placeholders are matched the same way ``_extract_variable_from_content``
        extracts them, so {{name}} and {{ name }} are both substituted. Inserted
        values are not re-scanned for placeholders.

        Args:
            content (str): The content containing variables.
            user_variables (dict): A dictionary of user-provided variable values.

        Returns:
            str: The content with variables replaced by their values.

        Raises:
            ValueError: If any provided value is not a string.
        """
        for key, value in user_variables.items():
            if not isinstance(value, str):
                raise ValueError(f"Value for variable '{key}' must be a string, not {type(value).__name__}")

        def substitute(match):
            raw_name = match.group(1)
            if '"' in raw_name:
                # Same rule as extraction: quoted placeholders are not variables.
                return match.group(0)
            # Unknown names are left untouched.
            return user_variables.get(raw_name.strip(), match.group(0))

        # A replacement function (not a template string) keeps backslashes in values literal.
        return self._VARIABLE_PATTERN.sub(substitute, content)

    def compile(self, **kwargs):
        """
        Compile the prompt by replacing variables with provided values.

        Args:
            **kwargs: Keyword arguments where keys are variable names and values are their replacements.

        Returns:
            list[dict]: A deep copy of the prompt messages with variables replaced;
            ``self.text`` is not modified.

        Raises:
            ValueError: If there are missing or extra variables, or if a value is not a string.
        """
        required_variables = self.get_variables()
        provided_variables = set(kwargs)

        missing_variables = [name for name in required_variables if name not in provided_variables]
        # Sorted so the error message is deterministic.
        extra_variables = sorted(provided_variables.difference(required_variables))

        if missing_variables:
            raise ValueError(f"Missing variable(s): {', '.join(missing_variables)}")
        if extra_variables:
            raise ValueError(f"Extra variable(s) provided: {', '.join(extra_variables)}")

        updated_text = copy.deepcopy(self.text)
        for item in updated_text:
            item["content"] = self._add_variable_value_to_content(item["content"], kwargs)

        return updated_text

    def get_variables(self):
        """
        Get all variables in the prompt text.

        Returns:
            list: Unique variable names in order of first appearance.
        """
        # dict.fromkeys de-duplicates while keeping first-occurrence order.
        return list(dict.fromkeys(
            var
            for item in self.text
            for var in self._extract_variable_from_content(item["content"])
        ))

    def _convert_value(self, value, type_):
        """
        Convert value based on type.

        Args:
            value: The value to be converted.
            type_ (str): The type to convert the value to ("float" or "int";
                anything else returns the value unchanged).

        Returns:
            The converted value.
        """
        if type_ == "float":
            return float(value)
        elif type_ == "int":
            return int(value)
        return value  # Default case, return as is

    def get_model_parameters(self):
        """
        Get all model parameters of the prompt.

        Returns:
            dict: Parameter name -> converted value ("" when a parameter has no
            "value"), plus a "model" entry holding the prompt's model.
        """
        parameters = {}
        for param in self.parameters:
            if "value" in param:
                # A missing "type" falls through to _convert_value's return-as-is case.
                parameters[param["name"]] = self._convert_value(param["value"], param.get("type"))
            else:
                parameters[param["name"]] = ""
        parameters["model"] = self.model
        return parameters

    def get_prompt_content(self):
        """Return the uncompiled prompt messages (the list stored on this object, not a copy)."""
        return self.text