test_evaluation_metrics.py
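"""Unit tests for Evaluation.add_metrics in ragaai_catalyst.evaluation.

All HTTP traffic is mocked with unittest.mock; no live Catalyst backend
is required to run these tests.
"""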
import pytest
import requests
from unittest.mock import patch, MagicMock

from ragaai_catalyst.evaluation import Evaluation


@pytest.fixture
def evaluation():
    with patch('requests.get') as mock_get, \
            patch('requests.post') as mock_post:
        # Mock project list response
        mock_get.return_value.json.return_value = {
            "data": {
                "content": [{
                    "id": "test_project_id",
                    "name": "test_project"
                }]
            }
        }
        mock_get.return_value.status_code = 200

        # Mock dataset list response
        mock_post.return_value.json.return_value = {
            "data": {
                "content": [{
                    "id": "test_dataset_id",
                    "name": "test_dataset"
                }]
            }
        }
        mock_post.return_value.status_code = 200

        return Evaluation(project_name="test_project", dataset_name="test_dataset")


@pytest.fixture
def valid_metrics():
    return [{
        "name": "accuracy",
        "config": {"threshold": 0.8},
        "column_name": "accuracy_col",
        "schema_mapping": {"input": "test_input"}
    }]


@pytest.fixture
def mock_response():
    mock = MagicMock()
    mock.status_code = 200
    mock.json.return_value = {
        "success": True,
        "message": "Metrics added successfully",
        "data": {"jobId": "test_job_123"}
    }
    return mock


def test_add_metrics_success(evaluation, valid_metrics, mock_response):
    """Test successful addition of metrics."""
    with patch('requests.post') as mock_post, \
            patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
            patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
            patch.object(evaluation, '_update_base_json', return_value={}):
        mock_post.return_value = mock_response
        evaluation.add_metrics(valid_metrics)

        # Verify the request was made with the correct project id
        assert mock_post.call_args[1]['headers']['X-Project-Id'] == str(evaluation.project_id)
        assert evaluation.jobId == "test_job_123"


def test_add_metrics_missing_required_keys(evaluation):
    """Test validation of required keys."""
    invalid_metrics = [{
        "name": "accuracy",
        "config": {"threshold": 0.8}
        # missing column_name and schema_mapping
    }]

    with pytest.raises(ValueError) as exc_info:
        evaluation.add_metrics(invalid_metrics)

    assert "required for each metric evaluation" in str(exc_info.value)


def test_add_metrics_invalid_metric_name(evaluation, valid_metrics):
    """Test validation of metric names."""
    with patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
            patch.object(evaluation, 'list_metrics', return_value=["different_metric"]):
        with pytest.raises(ValueError) as exc_info:
            evaluation.add_metrics(valid_metrics)

    assert "Enter a valid metric name" in str(exc_info.value)


def test_add_metrics_duplicate_column_name(evaluation, valid_metrics):
    """Test validation of duplicate column names."""
    with patch.object(evaluation, '_get_executed_metrics_list',
                      return_value=["accuracy_col"]), \
            patch.object(evaluation, 'list_metrics', return_value=["accuracy"]):
        with pytest.raises(ValueError) as exc_info:
            evaluation.add_metrics(valid_metrics)

    assert "Column name 'accuracy_col' already exists" in str(exc_info.value)
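# ---------------------------------------------------------------------------
# Hedged sketch: multiple metrics in one call. The "required for each metric
# evaluation" error above suggests add_metrics iterates over every metric in
# the list; this test assumes that behavior. The "relevance" metric name and
# column are illustrative, not taken from the real metric catalog.
# ---------------------------------------------------------------------------
def test_add_metrics_multiple_metrics(evaluation, mock_response):
    metrics = [
        {"name": "accuracy", "config": {"threshold": 0.8},
         "column_name": "accuracy_col", "schema_mapping": {"input": "test_input"}},
        {"name": "relevance", "config": {"threshold": 0.5},
         "column_name": "relevance_col", "schema_mapping": {"input": "test_input"}},
    ]
    with patch('requests.post') as mock_post, \
            patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
            patch.object(evaluation, 'list_metrics', return_value=["accuracy", "relevance"]), \
            patch.object(evaluation, '_update_base_json', return_value={}):
        mock_post.return_value = mock_response
        evaluation.add_metrics(metrics)

        assert evaluation.jobId == "test_job_123"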
def test_add_metrics_http_error(evaluation, valid_metrics):
    """Test handling of HTTP errors."""
    with patch('requests.post') as mock_post, \
            patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
            patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
            patch.object(evaluation, '_update_base_json', return_value={}):
        mock_post.side_effect = requests.exceptions.HTTPError("HTTP Error")
        # Should log the error but not raise an exception
        evaluation.add_metrics(valid_metrics)


def test_add_metrics_connection_error(evaluation, valid_metrics):
    """Test handling of connection errors."""
    with patch('requests.post') as mock_post, \
            patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
            patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
            patch.object(evaluation, '_update_base_json', return_value={}):
        mock_post.side_effect = requests.exceptions.ConnectionError("Connection Error")
        # Should log the error but not raise an exception
        evaluation.add_metrics(valid_metrics)


def test_add_metrics_timeout_error(evaluation, valid_metrics):
    """Test handling of timeout errors."""
    with patch('requests.post') as mock_post, \
            patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
            patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
            patch.object(evaluation, '_update_base_json', return_value={}):
        mock_post.side_effect = requests.exceptions.Timeout("Timeout Error")
        # Should log the error but not raise an exception
        evaluation.add_metrics(valid_metrics)


def test_add_metrics_bad_request(evaluation, valid_metrics):
    """Test handling of a 400 bad request."""
    mock_response = MagicMock()
    mock_response.status_code = 400
    mock_response.json.return_value = {"message": "Bad request error"}

    with patch('requests.post') as mock_post, \
            patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
            patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
            patch.object(evaluation, '_update_base_json', return_value={}), \
            patch('ragaai_catalyst.evaluation.logger') as mock_logger:
        mock_post.return_value = mock_response
        evaluation.add_metrics(valid_metrics)

        # Verify the error is logged and no job id is recorded
        mock_logger.error.assert_called_with(
            "An unexpected error occurred: Bad request error"
        )
        assert evaluation.jobId is None
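# ---------------------------------------------------------------------------
# Hedged sketch: the three network-error tests above share one structure and
# could be collapsed into a single parametrized test. This is an optional
# refactor sketch, not part of the original suite.
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("error", [
    requests.exceptions.HTTPError("HTTP Error"),
    requests.exceptions.ConnectionError("Connection Error"),
    requests.exceptions.Timeout("Timeout Error"),
])
def test_add_metrics_network_errors(evaluation, valid_metrics, error):
    """add_metrics is expected to log network errors rather than raise them."""
    with patch('requests.post') as mock_post, \
            patch.object(evaluation, '_get_executed_metrics_list', return_value=[]), \
            patch.object(evaluation, 'list_metrics', return_value=["accuracy"]), \
            patch.object(evaluation, '_update_base_json', return_value={}):
        mock_post.side_effect = error
        evaluation.add_metrics(valid_metrics)  # should not raise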