# tests/test_catalyst/test_evaluation_metrics.py
import pytest
import requests
from unittest.mock import patch, MagicMock
from ragaai_catalyst.evaluation import Evaluation

@pytest.fixture
def evaluation():
    """Build an Evaluation instance with the project and dataset lookups mocked out."""
    with patch('requests.get') as mock_get, \
         patch('requests.post') as mock_post:
        # Mock project list response
        mock_get.return_value.json.return_value = {
            "data": {
                "content": [{
                    "id": "test_project_id",
                    "name": "test_project"
                }]
            }
        }
        mock_get.return_value.status_code = 200

        # Mock dataset list response
        mock_post.return_value.json.return_value = {
            "data": {
                "content": [{
                    "id": "test_dataset_id",
                    "name": "test_dataset"
                }]
            }
        }
        mock_post.return_value.status_code = 200

        return Evaluation(project_name="test_project", dataset_name="test_dataset")

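# A hedged alternative to the fixture above (a sketch, not used by the tests
# below): yielding from inside the `with` block keeps the requests patches
# active for the whole test, which would matter only if Evaluation issued
# HTTP calls after __init__ that a test did not patch itself. The fixture
# name is illustrative.
@pytest.fixture
def evaluation_with_live_patches():
    with patch('requests.get') as mock_get, \
         patch('requests.post') as mock_post:
        mock_get.return_value.status_code = 200
        mock_get.return_value.json.return_value = {
            "data": {"content": [{"id": "test_project_id", "name": "test_project"}]}
        }
        mock_post.return_value.status_code = 200
        mock_post.return_value.json.return_value = {
            "data": {"content": [{"id": "test_dataset_id", "name": "test_dataset"}]}
        }
        # Patches stay active until the test using this fixture finishes
        yield Evaluation(project_name="test_project", dataset_name="test_dataset")
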
@pytest.fixture
def valid_metrics():
    """A minimal well-formed metrics payload containing all required keys."""
    return [{
        "name": "accuracy",
        "config": {"threshold": 0.8},
        "column_name": "accuracy_col",
        "schema_mapping": {"input": "test_input"}
    }]

@pytest.fixture
def mock_response():
    """A successful add-metrics response carrying a job id."""
    mock = MagicMock()
    mock.status_code = 200
    mock.json.return_value = {
        "success": True,
        "message": "Metrics added successfully",
        "data": {"jobId": "test_job_123"}
    }
    return mock

def test_add_metrics_success(evaluation, valid_metrics, mock_response):
    """Test successful addition of metrics"""
    with patch('requests.post') as mock_post, \
         patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
         patch.object(evaluation, '_update_base_json', return_value={}):

        mock_post.return_value = mock_response
        evaluation.add_metrics(valid_metrics)

        # Verify the request was made with correct project_id
        assert mock_post.call_args[1]['headers']['X-Project-Id'] == str(evaluation.project_id)
        assert evaluation.jobId == "test_job_123"

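# Hedged follow-up sketch (an assumption, not asserted by the suite above):
# if add_metrics builds its request payload through _update_base_json, the
# patched builder should be exercised when metrics are submitted. The test
# name is illustrative.
def test_add_metrics_invokes_payload_builder(evaluation, valid_metrics, mock_response):
    """Sketch: the payload builder patched above is exercised by add_metrics."""
    with patch('requests.post') as mock_post, \
         patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
         patch.object(evaluation, '_update_base_json', return_value={}) as mock_update:

        mock_post.return_value = mock_response
        evaluation.add_metrics(valid_metrics)

        # Assumes at least one call; the exact payload shape is internal to Evaluation
        mock_update.assert_called()
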
def test_add_metrics_missing_required_keys(evaluation):
    """Test validation of required keys"""
    invalid_metrics = [{
        "name": "accuracy",
        "config": {"threshold": 0.8}
        # missing column_name and schema_mapping
    }]

    with pytest.raises(ValueError) as exc_info:
        evaluation.add_metrics(invalid_metrics)

    assert "required for each metric evaluation" in str(exc_info.value)

def test_add_metrics_invalid_metric_name(evaluation, valid_metrics):
    """Test validation of metric names"""
    with patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["different_metric"]):

        with pytest.raises(ValueError) as exc_info:
            evaluation.add_metrics(valid_metrics)

        assert "Enter a valid metric name" in str(exc_info.value)

def test_add_metrics_duplicate_column_name(evaluation, valid_metrics):
    """Test validation of duplicate column names"""
    with patch.object(evaluation, '_get_executed_metrics_list',
                      return_value=["accuracy_col"]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]):

        with pytest.raises(ValueError) as exc_info:
            evaluation.add_metrics(valid_metrics)

        assert "Column name 'accuracy_col' already exists" in str(exc_info.value)

def test_add_metrics_http_error(evaluation, valid_metrics):
    """Test handling of HTTP errors"""
    with patch('requests.post') as mock_post, \
         patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
         patch.object(evaluation, '_update_base_json', return_value={}):

        mock_post.side_effect = requests.exceptions.HTTPError("HTTP Error")
        evaluation.add_metrics(valid_metrics)
        # Should log error but not raise exception

def test_add_metrics_connection_error(evaluation, valid_metrics):
    """Test handling of connection errors"""
    with patch('requests.post') as mock_post, \
         patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
         patch.object(evaluation, '_update_base_json', return_value={}):

        mock_post.side_effect = requests.exceptions.ConnectionError("Connection Error")
        evaluation.add_metrics(valid_metrics)
        # Should log error but not raise exception

def test_add_metrics_timeout_error(evaluation, valid_metrics):
    """Test handling of timeout errors"""
    with patch('requests.post') as mock_post, \
         patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
         patch.object(evaluation, '_update_base_json', return_value={}):

        mock_post.side_effect = requests.exceptions.Timeout("Timeout Error")
        evaluation.add_metrics(valid_metrics)
        # Should log error but not raise exception

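# The three network-error tests above share one pattern; this parametrized
# sketch shows the same coverage in a single test. It assumes the same
# contract those tests assert: request-level exceptions are logged by
# add_metrics, not raised.
@pytest.mark.parametrize("error", [
    requests.exceptions.HTTPError("HTTP Error"),
    requests.exceptions.ConnectionError("Connection Error"),
    requests.exceptions.Timeout("Timeout Error"),
])
def test_add_metrics_network_errors_parametrized(evaluation, valid_metrics, error):
    """Sketch: add_metrics swallows request-level exceptions for each error type."""
    with patch('requests.post') as mock_post, \
         patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
         patch.object(evaluation, '_update_base_json', return_value={}):

        mock_post.side_effect = error
        evaluation.add_metrics(valid_metrics)  # should log, not raise
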
def test_add_metrics_bad_request(evaluation, valid_metrics):
    """Test handling of 400 bad request"""
    mock_response = MagicMock()
    mock_response.status_code = 400
    mock_response.json.return_value = {"message": "Bad request error"}

    with patch('requests.post') as mock_post, \
         patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
         patch.object(evaluation, '_update_base_json', return_value={}), \
         patch('ragaai_catalyst.evaluation.logger') as mock_logger:

        mock_post.return_value = mock_response
        evaluation.add_metrics(valid_metrics)

        # Verify error is logged
        mock_logger.error.assert_called_with(
            "An unexpected error occurred: Bad request error"
        )
        assert evaluation.jobId is None
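
# Hedged follow-up (an assumption based on test_add_metrics_bad_request above,
# which implies jobId is None until a submission succeeds): after a swallowed
# network error, jobId should remain unset. getattr guards against the
# attribute not existing at all before a successful call.
def test_add_metrics_error_leaves_jobid_unset(evaluation, valid_metrics):
    """Sketch: jobId stays None when the POST never succeeds."""
    with patch('requests.post') as mock_post, \
         patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
         patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
         patch.object(evaluation, '_update_base_json', return_value={}):

        mock_post.side_effect = requests.exceptions.ConnectionError("Connection Error")
        evaluation.add_metrics(valid_metrics)

        assert getattr(evaluation, "jobId", None) is None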