# test_traces.py
"""Tests for the ``mlflow traces`` CLI command group."""

import json
import logging
from unittest import mock

import pytest
from click.testing import CliRunner

from mlflow.cli.traces import commands
from mlflow.entities import (
    AssessmentSourceType,
    MlflowExperimentLocation,
    Trace,
    TraceData,
    TraceInfo,
    TraceLocation,
    TraceLocationType,
    TraceState,
)
from mlflow.store.entities.paged_list import PagedList


@pytest.fixture(autouse=True)
def suppress_logging():
    """Silence the root, ``mlflow`` and ``alembic`` loggers for every test.

    Each logger is raised to CRITICAL before the test runs and its original
    level is restored afterwards, even if teardown is triggered abnormally.
    """
    loggers = [logging.root, logging.getLogger("mlflow"), logging.getLogger("alembic")]
    original_levels = [logger.level for logger in loggers]
    for logger in loggers:
        logger.setLevel(logging.CRITICAL)
    try:
        yield
    finally:
        for logger, level in zip(loggers, original_levels):
            logger.setLevel(level)


@pytest.fixture
def runner():
    """Provide a CLI runner for testing.

    ``catch_exceptions=False`` lets unexpected (non-Click) exceptions propagate
    so test failures show the real traceback instead of a swallowed exit code.
    """
    return CliRunner(catch_exceptions=False)


@pytest.fixture
def tracing_client_mock():
    """Patch ``TracingClient`` in the CLI module and yield the mocked instance.

    Tests configure return values on the yielded object (e.g.
    ``tracing_client_mock.search_traces.return_value = ...``); the patch stays
    active for the whole test.
    """
    with mock.patch("mlflow.cli.traces.TracingClient") as client_cls:
        yield client_cls.return_value


def _get_subcommand(name):
    """Return the ``traces`` subcommand whose ``name`` equals *name*, or None."""
    return next((cmd for cmd in commands.commands.values() if cmd.name == name), None)


def _make_trace(**info_kwargs):
    """Build an OK-state trace ``tr-123`` in experiment ``"1"`` with no spans.

    Extra keyword arguments (e.g. ``request_preview``) are forwarded to
    ``TraceInfo`` so individual tests can add fields without repeating the
    common setup.
    """
    trace_location = TraceLocation(
        type=TraceLocationType.MLFLOW_EXPERIMENT,
        mlflow_experiment=MlflowExperimentLocation(experiment_id="1"),
    )
    return Trace(
        info=TraceInfo(
            trace_id="tr-123",
            state=TraceState.OK,
            trace_location=trace_location,
            request_time=1700000000000,
            execution_duration=1234,
            **info_kwargs,
        ),
        data=TraceData(spans=[]),
    )


def test_commands_group_exists():
    assert commands.name == "traces"
    assert commands.help is not None


def test_search_command_params():
    search_cmd = _get_subcommand("search")
    assert search_cmd is not None
    param_names = {p.name for p in search_cmd.params}
    expected = {
        "experiment_id",
        "filter_string",
        "max_results",
        "order_by",
        "page_token",
        "output",
        "extract_fields",
    }
    # Report every missing parameter at once rather than stopping at the first.
    missing = expected - param_names
    assert not missing


def test_get_command_params():
    get_cmd = _get_subcommand("get")
    assert get_cmd is not None
    param_names = {p.name for p in get_cmd.params}
    missing = {"trace_id", "extract_fields"} - param_names
    assert not missing


def test_assessment_source_type_choices():
    log_feedback_cmd = _get_subcommand("log-feedback")
    assert log_feedback_cmd is not None

    source_type_param = next(
        (param for param in log_feedback_cmd.params if param.name == "source_type"), None
    )
    assert source_type_param is not None
    choices = source_type_param.type.choices
    assert AssessmentSourceType.HUMAN in choices
    assert AssessmentSourceType.LLM_JUDGE in choices
    assert AssessmentSourceType.CODE in choices


def test_search_command_with_fields(runner, tracing_client_mock):
    trace = _make_trace(request_preview="test request", response_preview="test response")
    tracing_client_mock.search_traces.return_value = PagedList([trace], None)

    result = runner.invoke(
        commands,
        ["search", "--experiment-id", "1", "--extract-fields", "info.trace_id,info.state"],
    )

    assert result.exit_code == 0
    assert "tr-123" in result.output
    assert "OK" in result.output


def test_get_command_with_fields(runner, tracing_client_mock):
    tracing_client_mock.get_trace.return_value = _make_trace()

    result = runner.invoke(
        commands,
        ["get", "--trace-id", "tr-123", "--extract-fields", "info.trace_id"],
    )

    assert result.exit_code == 0
    # Only the requested field should survive extraction.
    assert json.loads(result.output) == {"info": {"trace_id": "tr-123"}}


def test_delete_command(runner, tracing_client_mock):
    tracing_client_mock.delete_traces.return_value = 5

    result = runner.invoke(
        commands,
        ["delete", "--experiment-id", "1", "--trace-ids", "tr-1,tr-2,tr-3"],
    )

    assert result.exit_code == 0
    assert "Deleted 5 trace(s)" in result.output


def test_field_validation_error(runner, tracing_client_mock):
    tracing_client_mock.search_traces.return_value = PagedList([_make_trace()], None)

    result = runner.invoke(
        commands,
        ["search", "--experiment-id", "1", "--extract-fields", "invalid.field"],
    )

    assert result.exit_code != 0
    assert "Invalid field path" in result.output
    # Without --verbose the error should point the user at the flag.
    assert "--verbose" in result.output


def test_field_validation_error_verbose_mode(runner, tracing_client_mock):
    tracing_client_mock.search_traces.return_value = PagedList([_make_trace()], None)

    result = runner.invoke(
        commands,
        [
            "search",
            "--experiment-id",
            "1",
            "--extract-fields",
            "invalid.field",
            "--verbose",
        ],
    )

    assert result.exit_code != 0
    assert "Invalid field path" in result.output
    # Verbose mode lists the valid field paths instead of the hint.
    assert "info.trace_id" in result.output
    assert "info.state" in result.output
    assert "info.request_time" in result.output
    assert "Tip: Use --verbose" not in result.output