/ ragaai_catalyst / prompt_manager.py
prompt_manager.py
import copy
import json
import os
import re
from urllib.parse import quote

import requests

from .ragaai_catalyst import RagaAICatalyst
  7  
  8  class PromptManager:
  9      NUM_PROJECTS = 100
 10      TIMEOUT = 10
 11  
 12      def __init__(self, project_name):
 13          """
 14          Initialize the PromptManager with a project name.
 15  
 16          Args:
 17              project_name (str): The name of the project.
 18  
 19          Raises:
 20              requests.RequestException: If there's an error with the API request.
 21              ValueError: If the project is not found.
 22          """
 23          self.project_name = project_name
 24          self.base_url = f"{RagaAICatalyst.BASE_URL}/playground/prompt"
 25          self.timeout = 10
 26          self.size = 99999 #Number of projects to fetch
 27  
 28          try:
 29              response = requests.get(
 30                  f"{RagaAICatalyst.BASE_URL}/v2/llm/projects?size={self.size}",
 31                  headers={
 32                      "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
 33                  },
 34                  timeout=self.timeout,
 35              )
 36              response.raise_for_status()
 37              # logger.debug("Projects list retrieved successfully")
 38  
 39              project_list = [
 40                  project["name"] for project in response.json()["data"]["content"]
 41              ]
 42              self.project_id = [
 43              project["id"] for project in response.json()["data"]["content"] if project["name"]==project_name
 44              ][0]
 45  
 46          except (KeyError, json.JSONDecodeError) as e:
 47              raise ValueError(f"Error parsing project list: {str(e)}")
 48  
 49          if self.project_name not in project_list:
 50              raise ValueError("Project not found. Please enter a valid project name")
 51  
 52  
 53          self.headers = {
 54                  "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
 55                  "X-Project-Id": str(self.project_id)
 56              }
 57  
 58  
 59      def list_prompts(self):
 60          """
 61          List all available prompts.
 62  
 63          Returns:
 64              list: A list of prompt names.
 65  
 66          Raises:
 67              requests.RequestException: If there's an error with the API request.
 68          """
 69          prompt = Prompt()
 70          try:
 71              prompt_list = prompt.list_prompts(self.base_url, self.headers, self.timeout)
 72              return prompt_list
 73          except requests.RequestException as e:
 74              raise requests.RequestException(f"Error listing prompts: {str(e)}")
 75      
 76      def get_prompt(self, prompt_name, version=None):
 77          """
 78          Get a specific prompt.
 79  
 80          Args:
 81              prompt_name (str): The name of the prompt.
 82              version (str, optional): The version of the prompt. Defaults to None.
 83  
 84          Returns:
 85              PromptObject: An object representing the prompt.
 86  
 87          Raises:
 88              ValueError: If the prompt or version is not found.
 89              requests.RequestException: If there's an error with the API request.
 90          """
 91          try:
 92              prompt_list = self.list_prompts()
 93          except requests.RequestException as e:
 94              raise requests.RequestException(f"Error fetching prompt list: {str(e)}")
 95  
 96          if prompt_name not in prompt_list:
 97              raise ValueError("Prompt not found. Please enter a valid prompt name")
 98  
 99          try:
100              prompt_versions = self.list_prompt_versions(prompt_name)
101          except requests.RequestException as e:
102              raise requests.RequestException(f"Error fetching prompt versions: {str(e)}")
103  
104          if version and version not in prompt_versions.keys():
105              raise ValueError("Version not found. Please enter a valid version name")
106  
107          prompt = Prompt()
108          try:
109              prompt_object = prompt.get_prompt(self.base_url, self.headers, self.timeout, prompt_name, version)
110              return prompt_object
111          except requests.RequestException as e:
112              raise requests.RequestException(f"Error fetching prompt: {str(e)}")
113  
114      def list_prompt_versions(self, prompt_name):
115          """
116          List all versions of a specific prompt.
117  
118          Args:
119              prompt_name (str): The name of the prompt.
120  
121          Returns:
122              dict: A dictionary mapping version names to prompt texts.
123  
124          Raises:
125              ValueError: If the prompt is not found.
126              requests.RequestException: If there's an error with the API request.
127          """
128          try:
129              prompt_list = self.list_prompts()
130          except requests.RequestException as e:
131              raise requests.RequestException(f"Error fetching prompt list: {str(e)}")
132  
133          if prompt_name not in prompt_list:
134              raise ValueError("Prompt not found. Please enter a valid prompt name")
135          
136          prompt = Prompt()
137          try:
138              prompt_versions = prompt.list_prompt_versions(self.base_url, self.headers, self.timeout, prompt_name)
139              return prompt_versions
140          except requests.RequestException as e:
141              raise requests.RequestException(f"Error fetching prompt versions: {str(e)}")
142  
143  
class Prompt:
    """Stateless helper that performs the raw HTTP calls for the prompt playground API."""

    def __init__(self):
        """
        Initialize the Prompt class.

        The class holds no state; every method receives the base URL, headers and
        timeout explicitly.
        """

    @staticmethod
    def _path_segment(value):
        """Percent-encode ``value`` for use as one URL path segment (``/`` is encoded too)."""
        return quote(str(value), safe="")

    @staticmethod
    def _field(node, *keys):
        """
        Walk ``node`` through ``keys`` (dict keys or list indices).

        Raises:
            ValueError: If any key/index is missing or ``node`` is not subscriptable.
        """
        try:
            for key in keys:
                node = node[key]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(f"Error parsing prompt version: {str(e)}") from e
        return node

    @classmethod
    def _first_doc(cls, response):
        """
        Decode ``response`` as JSON and return its ``data.docs[0]`` entry.

        Raises:
            ValueError: If the body is not JSON or lacks ``data.docs[0]``.
        """
        try:
            payload = response.json()
        except ValueError as e:  # JSON decode errors from requests are ValueError subclasses
            raise ValueError(f"Error parsing prompt version: {str(e)}") from e
        return cls._field(payload, "data", "docs", 0)

    def list_prompts(self, url, headers, timeout):
        """
        List all available prompts.

        Args:
            url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.

        Returns:
            list: A list of prompt names.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If there's an error parsing the prompt list.
        """
        try:
            response = requests.get(url, headers=headers, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            raise requests.RequestException(f"Error listing prompts: {str(e)}") from e

        # Parse outside the request try: requests' JSONDecodeError is also a
        # RequestException and would otherwise be misreported as a transport error.
        try:
            return [prompt["name"] for prompt in response.json()["data"]]
        except (KeyError, TypeError, ValueError) as e:
            raise ValueError(f"Error parsing prompt list: {str(e)}") from e

    def _get_response_by_version(self, base_url, headers, timeout, prompt_name, version):
        """
        Fetch a specific version of a prompt.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.
            version (str): The version of the prompt.

        Returns:
            response: The response object containing the prompt version data.

        Raises:
            requests.RequestException: If there's an error with the API request.
        """
        try:
            # params= lets requests encode the version (e.g. "&", "#", spaces) safely.
            response = requests.get(
                f"{base_url}/version/{self._path_segment(prompt_name)}",
                params={"version": version},
                headers=headers,
                timeout=timeout,
            )
            response.raise_for_status()
        except requests.RequestException as e:
            raise requests.RequestException(f"Error fetching prompt version: {str(e)}") from e
        return response

    def _get_response(self, base_url, headers, timeout, prompt_name):
        """
        Fetch a prompt without specifying a version.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.

        Returns:
            response: The response object returned for the unversioned request.

        Raises:
            requests.RequestException: If there's an error with the API request.
        """
        try:
            response = requests.get(
                f"{base_url}/version/{self._path_segment(prompt_name)}",
                headers=headers,
                timeout=timeout,
            )
            response.raise_for_status()
        except requests.RequestException as e:
            raise requests.RequestException(f"Error fetching prompt version: {str(e)}") from e
        return response

    def _get_prompt_by_version(self, base_url, headers, timeout, prompt_name, version):
        """
        Get the ``textFields`` of a specific version of a prompt.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.
            version (str): The version of the prompt.

        Returns:
            The ``textFields`` value of the first returned doc.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If the response payload cannot be parsed.
        """
        response = self._get_response_by_version(base_url, headers, timeout, prompt_name, version)
        return self._field(self._first_doc(response), "textFields")

    def get_prompt(self, base_url, headers, timeout, prompt_name, version=None):
        """
        Get a prompt, optionally specifying a version.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.
            version (str, optional): The version of the prompt. Defaults to None.

        Returns:
            PromptObject: An object representing the prompt.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If the response payload cannot be parsed.
        """
        if version:
            response = self._get_response_by_version(base_url, headers, timeout, prompt_name, version)
        else:
            response = self._get_response(base_url, headers, timeout, prompt_name)

        # Decode the body once and read every field from the same doc.
        doc = self._first_doc(response)
        prompt_text = self._field(doc, "textFields")
        prompt_parameters = self._field(doc, "modelSpecs", "parameters")
        model = self._field(doc, "modelSpecs", "model")
        return PromptObject(prompt_text, prompt_parameters, model)

    def list_prompt_versions(self, base_url, headers, timeout, prompt_name):
        """
        List all versions of a specific prompt.

        Issues one request for the version names plus one request per version to
        fetch its text.

        Args:
            base_url (str): The base URL for the API.
            headers (dict): The headers to be used in the request.
            timeout (int): The timeout for the request.
            prompt_name (str): The name of the prompt.

        Returns:
            dict: A dictionary mapping version names to prompt texts.

        Raises:
            requests.RequestException: If there's an error with the API request.
            ValueError: If there's an error parsing the prompt versions.
        """
        try:
            response = requests.get(
                f"{base_url}/{self._path_segment(prompt_name)}/version",
                headers=headers,
                timeout=timeout,
            )
            response.raise_for_status()
        except requests.RequestException as e:
            raise requests.RequestException(f"Error listing prompt versions: {str(e)}") from e

        try:
            version_names = [entry["name"] for entry in response.json()["data"]]
        except (KeyError, TypeError, ValueError) as e:
            raise ValueError(f"Error parsing prompt versions: {str(e)}") from e

        prompt_versions = {}
        try:
            for version_name in version_names:
                prompt_versions[version_name] = self._get_prompt_by_version(
                    base_url, headers, timeout, prompt_name, version_name
                )
        except requests.RequestException as e:
            raise requests.RequestException(f"Error listing prompt versions: {str(e)}") from e
        return prompt_versions
312              raise ValueError(f"Error parsing prompt versions: {str(e)}")
313  
314  
class PromptObject:
    """A fetched prompt: its message list, model parameters and model name."""

    # Matches ``{{...}}`` placeholders (non-greedy, within one line); group 1 is the
    # raw text between the braces, before whitespace stripping.
    _VARIABLE_PATTERN = re.compile(r"\{\{(.*?)\}\}")

    def __init__(self, text, parameters, model):
        """
        Initialize a PromptObject with the given messages, parameters and model.

        Args:
            text (list[dict]): The prompt messages. Each item is a dict whose
                "content" string may contain ``{{variable}}`` placeholders.
            parameters (list[dict]): Parameter specs, each with "name" and "type"
                keys and an optional "value" key.
            model (str): The model of the prompt.
        """
        self.text = text
        self.parameters = parameters
        self.model = model

    def _extract_variable_from_content(self, content):
        """
        Extract variable names from the ``{{...}}`` placeholders in ``content``.

        Surrounding whitespace is stripped, so ``{{ name }}`` yields ``"name"``.
        Placeholders containing a double quote (e.g. ``{{"literal"}}``) or only
        whitespace (``{{}}``) are not variables.

        Args:
            content (str): The content containing variables.

        Returns:
            list: Variable names in order of appearance (duplicates kept).
        """
        variables = []
        for raw in self._VARIABLE_PATTERN.findall(content):
            name = raw.strip()
            if name and '"' not in raw:
                variables.append(name)
        return variables

    def _add_variable_value_to_content(self, content, user_variables):
        """
        Replace variable placeholders in ``content`` with the user-provided values.

        Every spelling of a placeholder is replaced (``{{name}}`` and ``{{ name }}``),
        matching how ``_extract_variable_from_content`` recognizes variables.

        Args:
            content (str): The content containing variables.
            user_variables (dict): A dictionary of user-provided variable values.

        Returns:
            str: The content with variables replaced by their values.

        Raises:
            ValueError: If any value in ``user_variables`` is not a string.
        """
        for key, value in user_variables.items():
            if not isinstance(value, str):
                raise ValueError(f"Value for variable '{key}' must be a string, not {type(value).__name__}")

        def substitute(match):
            raw = match.group(1)
            name = raw.strip()
            # Leave quoted literals, empty placeholders and unknown names untouched.
            if not name or '"' in raw or name not in user_variables:
                return match.group(0)
            return user_variables[name]

        # One regex pass: substituted values are inserted literally and never re-scanned,
        # so a value that itself contains "{{other}}" is not expanded again.
        return self._VARIABLE_PATTERN.sub(substitute, content)

    def compile(self, **kwargs):
        """
        Compile the prompt by replacing variables with provided values.

        Args:
            **kwargs: Keyword arguments where keys are variable names and values are their replacements.

        Returns:
            list[dict]: A deep copy of the prompt messages with variables replaced;
            ``self.text`` is not modified.

        Raises:
            ValueError: If there are missing or extra variables, or if a value is not a string.
        """
        required_variables = self.get_variables()
        required_set = set(required_variables)

        # Iterate ordered sources so error messages are deterministic.
        missing_variables = [name for name in required_variables if name not in kwargs]
        extra_variables = [name for name in kwargs if name not in required_set]

        if missing_variables:
            raise ValueError(f"Missing variable(s): {', '.join(missing_variables)}")
        if extra_variables:
            raise ValueError(f"Extra variable(s) provided: {', '.join(extra_variables)}")

        updated_text = copy.deepcopy(self.text)
        for item in updated_text:
            item["content"] = self._add_variable_value_to_content(item["content"], kwargs)
        return updated_text

    def get_variables(self):
        """
        Get all variables in the prompt text.

        Returns:
            list: Unique variable names, in order of first appearance across messages.
        """
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(
            name
            for item in self.text
            for name in self._extract_variable_from_content(item["content"])
        ))

    def _convert_value(self, value, type_):
        """
        Convert value based on type.

        Args:
            value: The value to be converted.
            type_ (str): The declared type; only "float" and "int" are converted.

        Returns:
            The converted value, or ``value`` unchanged for any other type.
        """
        if type_ == "float":
            return float(value)
        if type_ == "int":
            return int(value)
        return value  # Default case, return as is

    def get_model_parameters(self):
        """
        Get the model parameters of the prompt.

        Returns:
            dict: Parameter name -> value (converted per its "type"; "" when the spec
            has no "value"), plus a "model" entry holding the model name.
        """
        parameters = {}
        for param in self.parameters:
            if "value" in param:
                parameters[param["name"]] = self._convert_value(param["value"], param["type"])
            else:
                parameters[param["name"]] = ""
        parameters["model"] = self.model
        return parameters

    def get_prompt_content(self):
        """Return the raw prompt messages, with placeholders not substituted."""
        return self.text
444      def get_prompt_content(self):
445          return self.text