# tests/test_catalyst/test_dataset.py
  1  import pytest
  2  import os
  3  import dotenv
  4  dotenv.load_dotenv()
  5  import pandas as pd
  6  from datetime import datetime
  7  from typing import Dict, List
  8  from unittest.mock import patch, Mock
  9  import requests
 10  from ragaai_catalyst import Dataset,RagaAICatalyst
 11  
 12  csv_path = os.path.join(os.path.dirname(__file__), os.path.join("test_data", "util_test_dataset.csv"))
 13  
 14  
 15  @pytest.fixture
 16  def base_url():
 17      return os.getenv("RAGAAI_CATALYST_BASE_URL")
 18  
 19  @pytest.fixture
 20  def access_keys():
 21      return {
 22          "access_key": os.getenv("RAGAAI_CATALYST_ACCESS_KEY"),
 23          "secret_key": os.getenv("RAGAAI_CATALYST_SECRET_KEY")}
 24  
 25  @pytest.fixture
 26  def dataset(base_url, access_keys):
 27      """Create evaluation instance with specific project and dataset"""
 28      os.environ["RAGAAI_CATALYST_BASE_URL"] = base_url
 29      catalyst = RagaAICatalyst(
 30          access_key=access_keys["access_key"],
 31          secret_key=access_keys["secret_key"]
 32      )
 33      return Dataset(project_name="prompt_metric_dataset")
 34  
 35  def test_list_dataset(dataset) -> List[str]:
 36      datasets = dataset.list_datasets()
 37      return datasets
 38  
 39  
 40  # def test_get_dataset_columns(dataset)  -> List[str]:
 41  #     dataset_column = dataset.get_dataset_columns(dataset_name="schema_metric_dataset_ritika_3")
 42  #     return dataset_column
 43  
 44  def test_incorrect_dataset(dataset):
 45      with pytest.raises(ValueError, match="Please enter a valid dataset name"):
 46          dataset.get_dataset_columns(dataset_name="ritika_datset")
 47  
 48  def test_get_schema_mapping(dataset):
 49      schema_mapping_columns= dataset.get_schema_mapping()
 50      return schema_mapping_columns
 51  
 52  
 53  def test_upload_csv(dataset):
 54      project_name = 'prompt_metric_dataset3'
 55  
 56      schema_mapping = {
 57          'Query': 'prompt',
 58          'Response': 'response',
 59          'Context': 'context',
 60          'ExpectedResponse': 'expected_response',
 61      }
 62  
 63      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 
 64      dataset_name = f"schema_metric_dataset_ritika_{timestamp}"  
 65  
 66      
 67  
 68      dataset.create_from_csv(
 69          csv_path=csv_path,
 70          dataset_name=dataset_name,
 71          schema_mapping=schema_mapping
 72      )
 73  
 74  def test_upload_csv_repeat_dataset(dataset):
 75      with pytest.raises(ValueError, match="already exists"):
 76          project_name = 'prompt_metric_dataset'
 77  
 78          schema_mapping = {
 79              'Query': 'prompt',
 80              'Response': 'response',
 81              'Context': 'context',
 82              'ExpectedResponse': 'expected_response',
 83          }
 84  
 85          dataset.create_from_csv(
 86              csv_path=csv_path,
 87              dataset_name="schema_metric_dataset_ritika_3",
 88              schema_mapping=schema_mapping
 89          )
 90  
 91  
 92  def test_upload_csv_no_schema_mapping(dataset):
 93      with pytest.raises(TypeError, match="missing 1 required positional argument"):
 94          project_name = 'prompt_metric_dataset'
 95  
 96          schema_mapping = {
 97              'Query': 'prompt',
 98              'Response': 'response',
 99              'Context': 'context',
100              'ExpectedResponse': 'expected_response',
101          }
102  
103          dataset.create_from_csv(
104              csv_path=csv_path,
105              dataset_name="schema_metric_dataset_ritika_3",
106          )
107  
def test_upload_csv_empty_csv_path(dataset):
    """An empty ``csv_path`` raises ``FileNotFoundError``."""
    schema_mapping = {
        'Query': 'prompt',
        'Response': 'response',
        'Context': 'context',
        'ExpectedResponse': 'expected_response',
    }

    # Only the call under test sits inside pytest.raises.
    with pytest.raises(FileNotFoundError, match="No such file or directory"):
        dataset.create_from_csv(
            csv_path="",
            dataset_name="schema_metric_dataset_ritika_12",
            schema_mapping=schema_mapping,
        )
125  
def test_upload_csv_empty_schema_mapping(dataset):
    """Passing an empty string as ``schema_mapping`` raises ``AttributeError``.

    NOTE(review): presumably raised because create_from_csv treats the mapping
    as a dict (e.g. calls a dict method on it) — confirm.
    """
    with pytest.raises(AttributeError):
        dataset.create_from_csv(
            csv_path=csv_path,
            dataset_name="schema_metric_dataset_ritika_12",
            schema_mapping="",
        )
143  
144  
def test_upload_csv_invalid_schema(dataset):
    """A schema mapping that does not fit the CSV raises ``ValueError``.

    NOTE(review): presumably rejected because these keys are not columns of
    ``util_test_dataset.csv`` — confirm against the fixture file.
    """
    schema_mapping = {
        'prompt': 'prompt',
        'response': 'response',
        'chatId': 'chatId',
        'chatSequence': 'chatSequence',
    }

    # Only the call under test sits inside pytest.raises.
    with pytest.raises(ValueError, match="Invalid schema mapping provided"):
        dataset.create_from_csv(
            csv_path=csv_path,
            dataset_name="schema_metric_dataset_ritika_12",
            schema_mapping=schema_mapping,
        )