# ragaai_catalyst/experiment.py
  1  import os
  2  import requests
  3  import logging
  4  import pandas as pd
  5  from .utils import response_checker
  6  from .ragaai_catalyst import RagaAICatalyst
  7  
# Configure the root logger at DEBUG level as soon as this module is imported.
# NOTE(review): calling basicConfig() from library code changes logging for the
# whole host application - confirm this is intended before relying on it.
logging.basicConfig(level=logging.DEBUG)
# Module-scoped logger used by everything defined in this file.
logger = logging.getLogger(__name__)

# Short alias so the code below can call get_token() without the class prefix.
get_token = RagaAICatalyst.get_token
 12  
 13  
 14  class Experiment:
 15      BASE_URL = None
 16      TIMEOUT = 10
 17      NUM_PROJECTS = 100
 18  
 19      def __init__(
 20          self, project_name, experiment_name, experiment_description, dataset_name
 21      ):
 22          """
 23          Initializes the Experiment object with the provided project details and initializes various attributes.
 24  
 25          Parameters:
 26              project_name (str): The name of the project.
 27              experiment_name (str): The name of the experiment.
 28              experiment_description (str): The description of the experiment.
 29              dataset_name (str): The name of the dataset.
 30  
 31          Returns:
 32              None
 33          """
 34          Experiment.BASE_URL = RagaAICatalyst.BASE_URL
 35          self.project_name = project_name
 36          self.experiment_name = experiment_name
 37          self.experiment_description = experiment_description
 38          self.dataset_name = dataset_name
 39          self.experiment_id = None
 40          self.job_id = None
 41  
 42          params = {
 43              "size": str(self.NUM_PROJECTS),
 44              "page": "0",
 45              "type": "llm",
 46          }
 47          headers = {
 48              "Content-Type": "application/json",
 49              "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
 50          }
 51          response = requests.get(
 52              f"{RagaAICatalyst.BASE_URL}/projects",
 53              params=params,
 54              headers=headers,
 55              timeout=10,
 56          )
 57          response.raise_for_status()
 58          # logger.debug("Projects list retrieved successfully")
 59          experiment_list = [exp["name"] for project in response.json()["data"]["content"] if project["name"] == self.project_name for exp in project["experiments"]]
 60          # print(experiment_list)
 61          if self.experiment_name in experiment_list:
 62              raise ValueError("The experiment name already exists in the project. Enter a unique experiment name.")
 63  
 64          self.access_key = os.getenv("RAGAAI_CATALYST_ACCESS_KEY")
 65          self.secret_key = os.getenv("RAGAAI_CATALYST_SECRET_KEY")
 66  
 67          self.token = (
 68              os.getenv("RAGAAI_CATALYST_TOKEN")
 69              if os.getenv("RAGAAI_CATALYST_TOKEN") is not None
 70              else get_token()
 71          )
 72          
 73          if not self._check_if_project_exists(project_name=project_name):
 74              raise ValueError(f"Project '{project_name}' not found. Please enter a valid project name")
 75          
 76          if not self._check_if_dataset_exists(project_name=project_name,dataset_name=dataset_name):
 77              raise ValueError(f"dataset '{dataset_name}' not found. Please enter a valid dataset name")
 78  
 79  
 80          self.metrics = []
 81      def _check_if_dataset_exists(self,project_name,dataset_name):
 82          headers = {
 83              "X-Project-Name":project_name,
 84              # "accept":"application/json, text/plain, */*",
 85              "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
 86          }
 87          response = requests.get(
 88              f"{RagaAICatalyst.BASE_URL}/v1/llm/sub-datasets?projectName={project_name}",
 89              headers=headers,
 90              timeout=self.TIMEOUT,
 91          )
 92          response.raise_for_status()
 93          logger.debug("dataset list retrieved successfully")
 94          dataset_list = [
 95              item['name'] for item in response.json()['data']['content']
 96          ]
 97          exists = dataset_name in dataset_list
 98          if exists:
 99              logger.info(f"dataset '{dataset_name}' exists.")
100          else:
101              logger.info(f"dataset '{dataset_name}' does not exist.")
102          return exists
103  
104  
105  
106  
107      def _check_if_project_exists(self,project_name,num_projects=100):
108          # TODO: 1. List All projects
109          params = {
110              "size": str(num_projects),
111              "page": "0",
112              "type": "llm",
113          }
114          headers = {
115              "Content-Type": "application/json",
116              "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
117          }
118          response = requests.get(
119              f"{RagaAICatalyst.BASE_URL}/projects",
120              params=params,
121              headers=headers,
122              timeout=self.TIMEOUT,
123          )
124          response.raise_for_status()
125          logger.debug("Projects list retrieved successfully")
126          project_list = [
127              project["name"] for project in response.json()["data"]["content"]
128          ]
129          
130          # TODO: 2. Check if the given project_name exists
131          # TODO: 3. Return bool (True / False output)
132          exists = project_name in project_list
133          if exists:
134              logger.info(f"Project '{project_name}' exists.")
135          else:
136              logger.info(f"Project '{project_name}' does not exist.")
137          return exists
138          
139      def list_experiments(self):
140          """
141          Retrieves a list of experiments associated with the current project.
142  
143          Returns:
144              list: A list of experiment names.
145  
146          Raises:
147              requests.exceptions.RequestException: If the request fails.
148  
149          """
150  
151          def make_request():
152              headers = {
153                  "authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
154                  "X-Project-Name": self.project_name,
155              }
156              params = {
157                  "name": self.project_name,
158              }
159              response = requests.get(
160                  f"{Experiment.BASE_URL}/project",
161                  headers=headers,
162                  params=params,
163                  timeout=Experiment.TIMEOUT,
164              )
165              return response
166  
167          response = make_request()
168          response_checker(response, "Experiment.list_experiments")
169          if response.status_code == 401:
170              get_token()  # Fetch a new token and set it in the environment
171              response = make_request()  # Retry the request
172          if response.status_code != 200:
173              return {
174                  "status_code": response.status_code,
175                  "message": response.json(),
176              }
177          experiments = response.json()["data"]["experiments"]
178          return [experiment["name"] for experiment in experiments]
179  
180      def add_metrics(self, metrics):
181          """
182          Adds metrics to the experiment and handles different status codes in the response.
183  
184          Parameters:
185              metrics: The metrics to be added to the experiment. It can be a single metric or a list of metrics.
186  
187          Returns:
188              If the status code is 200, returns a success message with the added metric names.
189              If the status code is 401, retries the request, updates the job and experiment IDs, and returns the test response.
190              If the status code is not 200 or 401, logs an error, and returns an error message with the response check.
191          """
192          headers = {
193              "Content-Type": "application/json",
194              "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
195              "X-Project-Name": self.project_name,
196          }
197  
198          if not isinstance(metrics, list):
199              metrics = [metrics]
200          else:
201              metrics_list = metrics
202          sub_providers = ["openai","azure","gemini","groq"]
203          sub_metrics = RagaAICatalyst.list_metrics()  
204          for metric in metrics_list:
205              provider = metric.get('config', {}).get('provider', '').lower()
206              if provider and provider not in sub_providers:
207                  raise ValueError("Enter a valid provider name. The following Provider names are supported: OpenAI, Azure, Gemini, Groq")
208  
209              if metric['name'] not in sub_metrics:
210                  raise ValueError("Enter a valid metric name. Refer to RagaAI Metric Library to select a valid metric")
211  
212          json_data = {
213              "projectName": self.project_name,
214              "datasetName": self.dataset_name,
215              "experimentName": self.experiment_name,
216              "metrics": metrics_list,
217          }
218          logger.debug(
219              f"Preparing to add metrics for '{self.experiment_name}': {metrics}"
220          )
221          response = requests.post(
222              f"{Experiment.BASE_URL}/v1/llm/experiment",
223              headers=headers,
224              json=json_data,
225              timeout=Experiment.TIMEOUT,
226          )
227  
228          status_code = response.status_code
229          if status_code == 200:
230              test_response = response.json()
231              self.job_id = test_response.get("data").get("jobId")
232              self.experiment_id = test_response.get("data").get("experiment").get("id")
233              self.project_id = (
234                  test_response.get("data").get("experiment").get("projectId")
235              )
236              print(f"Metrics added successfully. Job ID: {self.job_id}")
237              metric_names = [
238                  execution["metricName"]
239                  for execution in test_response["data"]["experiment"]["executions"]
240              ]
241              return f"Metrics {metric_names} added successfully"
242          elif status_code == 401:
243              headers = {
244                  "Content-Type": "application/json",
245                  "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
246                  "X-Project-Name": self.project_name,
247              }
248              response = requests.post(
249                  f"{Experiment.BASE_URL}/v1/llm/experiment",
250                  headers=headers,
251                  json=json_data,
252                  timeout=Experiment.TIMEOUT,
253              )
254              status_code = response.status_code
255              if status_code == 200:
256                  test_response = response.json()
257                  self.job_id = test_response.get("data").get("jobId")
258                  self.experiment_id = (
259                      test_response.get("data").get("experiment").get("id")
260                  )
261                  self.project_id = (
262                      test_response.get("data").get("experiment").get("projectId")
263                  )
264  
265                  return test_response
266              else:
267                  logger.error("Endpoint not responsive after retry attempts.")
268                  return response_checker(response, "Experiment.run_tests")
269          else:
270              logger.error(f"Failed to add metrics: HTTP {status_code}")
271              return (
272                  "Error in running tests",
273                  response_checker(response, "Experiment.run_tests"),
274              )
275  
276      def get_status(self, job_id=None):
277          """
278          Retrieves the status of a job based on the provided job ID.
279  
280          Returns the status of the job if the status code is 200, otherwise handles different status codes.
281          """
282          headers = {
283              "Content-Type": "application/json",
284              "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
285              "X-Project-Name": self.project_name,
286          }
287          if job_id is not None:
288              job_id_to_check = job_id
289          else:
290              job_id_to_check = self.job_id
291  
292          if job_id_to_check is None:
293              logger.warning("Attempt to fetch status without a valid job ID.")
294              return "Please run an experiment test first"
295          json_data = {
296              "jobId": job_id_to_check,
297          }
298          logger.debug(f"Fetching status for Job ID: {job_id_to_check}")
299          response = requests.get(
300              f"{Experiment.BASE_URL}/job/status",
301              headers=headers,
302              json=json_data,
303              timeout=Experiment.TIMEOUT,
304          )
305          status_code = response_checker(response, "Experiment.get_status")
306          if status_code == 200:
307              test_response = response.json()
308              jobs = test_response["data"]["content"]
309              for job in jobs:
310                  if job["id"] == job_id_to_check:
311                      return job["status"]
312          elif status_code == 401:
313              headers = {
314                  "Content-Type": "application/json",
315                  "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
316                  "X-Project-Name": self.project_name,
317              }
318              response = requests.post(
319                  f"{Experiment.BASE_URL}/job/status",
320                  headers=headers,
321                  json=json_data,
322                  timeout=Experiment.TIMEOUT,
323              )
324              status_code = response_checker(response, "Experiment.get_status")
325              if status_code == 200:
326                  test_response = response.json()
327                  self.experiment_id = (
328                      test_response.get("data").get("experiment").get("id")
329                  )
330                  return test_response
331              else:
332                  logger.error("Endpoint not responsive after retry attempts.")
333                  return response_checker(response, "Experiment.get_status")
334          else:
335              return (
336                  "Error in running tests",
337                  response_checker(response, "Experiment.get_status"),
338              )
339  
340      def get_results(self, job_id=None):
341          """
342          A function that retrieves results based on the experiment ID.
343          It makes a POST request to the BASE_URL to fetch results using the provided JSON data.
344          If the request is successful (status code 200), it returns the retrieved test response.
345          If the status code is 401, it retries the request and returns the test response if successful.
346          If the status is neither 200 nor 401, it logs an error and returns the response checker result.
347          """
348          if job_id is not None:
349              job_id_to_use = job_id
350          else:
351              job_id_to_use = self.job_id
352  
353          if job_id_to_use is None:
354              logger.warning("Results fetch attempted without prior job execution.")
355              return "Please run an experiment test first"
356  
357          headers = {
358              "Content-Type": "application/json",
359              "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
360              "X-Project-Id": str(self.project_id),
361          }
362  
363          json_data = {
364              "fields": [],
365              "experimentId": self.experiment_id,
366              "numRecords": 4,
367              "projectId": self.project_id,
368              "filterList": [],
369          }
370          base_url_without_api = Experiment.BASE_URL.removesuffix('/api')
371  
372          status_json = self.get_status(job_id_to_use)
373          if status_json == "Failed":
374              return print("Job failed. No results to fetch.")
375          elif status_json == "In Progress":
376              return print(f"Job in progress. Please wait while the job completes.\n Visit Job Status: {base_url_without_api}/home/job-status to track")
377          elif status_json == "Completed":
378              print(f"Job completed. fetching results.\n Visit Job Status: {base_url_without_api}/home/job-status to track")
379  
380          response = requests.post(
381              f"{Experiment.BASE_URL}/v1/llm/docs",
382              headers=headers,
383              json=json_data,
384              timeout=Experiment.TIMEOUT,
385          )
386          if response.status_code == 200:
387              print("Results successfully retrieved.")
388              test_response = response.json()
389  
390              if test_response["success"]:
391                  parse_success, parsed_response = self.parse_response(test_response)
392                  if parse_success:
393                      return parsed_response
394                  else:
395                      logger.error(f"Failed to parse response: {test_response}")
396                      raise FailedToRetrieveResults(
397                          f"Failed to parse response: {test_response}"
398                      )
399  
400              else:
401                  logger.error(f"Failed to retrieve results for job: {job_id_to_use}")
402                  raise FailedToRetrieveResults(
403                      f"Failed to retrieve results for job: {job_id_to_use}"
404                  )
405  
406              return parsed_response
407          elif response.status_code == 401:
408              headers = {
409                  "Content-Type": "application/json",
410                  "Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
411                  "X-Project-Id": str(self.project_id),
412              }
413              response = requests.post(
414                  f"{Experiment.BASE_URL}/v1/llm/docs",
415                  headers=headers,
416                  json=json_data,
417                  timeout=Experiment.TIMEOUT,
418              )
419              if response.status_code == 200:
420                  test_response = response.json()
421                  return test_response
422              else:
423                  logger.error("Endpoint not responsive after retry attempts.")
424                  return response_checker(response, "Experiment.get_test_results")
425          else:
426              return (
427                  "Error in running tests",
428                  response_checker(response, "Experiment.get_test_results"),
429              )
430  
431      def parse_response(self, response):
432          """
433          Parse the response to get the results
434          """
435          try:
436              x = pd.DataFrame(response["data"]["docs"])
437  
438              column_names_to_replace = [
439                  {item["columnName"]: item["displayName"]}
440                  for item in response["data"]["columns"]
441              ]
442  
443              if column_names_to_replace:
444                  for item in column_names_to_replace:
445                      x = x.rename(columns=item)
446  
447                  dict_cols = [
448                      col
449                      for col in x.columns
450                      if x[col].dtype == "object"
451                      and x[col].apply(lambda y: isinstance(y, dict)).any()
452                  ]
453  
454                  for dict_col in dict_cols:
455                      x[f"{dict_col}_reason"] = x[dict_col].apply(
456                          lambda y: y.get("reason") if isinstance(y, dict) else None
457                      )
458                      x[f"{dict_col}_metric_config"] = x[dict_col].apply(
459                          lambda y: (
460                              y.get("metric_config") if isinstance(y, dict) else None
461                          )
462                      )
463                      x[f"{dict_col}_status"] = x[dict_col].apply(
464                          lambda y: y.get("status") if isinstance(y, dict) else None
465                      )
466  
467                      x = x.drop(columns=[dict_col])
468  
469              x.columns = x.columns.str.replace("_reason_reason", "_reason")
470              x.columns = x.columns.str.replace("_reason_metric_config", "_metric_config")
471              x.columns = x.columns.str.replace("_reason_status", "_status")
472  
473              columns_list = x.columns.tolist()
474              #remove trace_uri from columns_list if it exists
475              columns_list = list(set(columns_list) - {"trace_uri"})
476              x = x[columns_list]
477  
478              return True, x
479  
480          except Exception as e:
481              logger.error(f"Failed to parse response: {e}", exc_info=True)
482              return False, pd.DataFrame()
483  
484  
class FailedToRetrieveResults(Exception):
    """Raised when experiment results are reported as failed or cannot be parsed."""