/ tests / cli / test_traces.py
test_traces.py
  1  import json
  2  import logging
  3  from unittest import mock
  4  
  5  import pytest
  6  from click.testing import CliRunner
  7  
  8  from mlflow.cli.traces import commands
  9  from mlflow.entities import (
 10      AssessmentSourceType,
 11      MlflowExperimentLocation,
 12      Trace,
 13      TraceData,
 14      TraceInfo,
 15      TraceLocation,
 16      TraceLocationType,
 17      TraceState,
 18  )
 19  from mlflow.store.entities.paged_list import PagedList
 20  
 21  
 22  @pytest.fixture(autouse=True)
 23  def suppress_logging():
 24      """Suppress logging for all tests."""
 25      # Suppress logging
 26      original_root = logging.root.level
 27      original_mlflow = logging.getLogger("mlflow").level
 28      original_alembic = logging.getLogger("alembic").level
 29  
 30      logging.root.setLevel(logging.CRITICAL)
 31      logging.getLogger("mlflow").setLevel(logging.CRITICAL)
 32      logging.getLogger("alembic").setLevel(logging.CRITICAL)
 33  
 34      yield
 35  
 36      # Restore original logging levels
 37      logging.root.setLevel(original_root)
 38      logging.getLogger("mlflow").setLevel(original_mlflow)
 39      logging.getLogger("alembic").setLevel(original_alembic)
 40  
 41  
 42  @pytest.fixture
 43  def runner():
 44      """Provide a CLI runner for testing."""
 45      return CliRunner(catch_exceptions=False)
 46  
 47  
 48  def test_commands_group_exists():
 49      assert commands.name == "traces"
 50      assert commands.help is not None
 51  
 52  
 53  def test_search_command_params():
 54      search_cmd = next((cmd for cmd in commands.commands.values() if cmd.name == "search"), None)
 55      assert search_cmd is not None
 56      param_names = [p.name for p in search_cmd.params]
 57      assert "experiment_id" in param_names
 58      assert "filter_string" in param_names
 59      assert "max_results" in param_names
 60      assert "order_by" in param_names
 61      assert "page_token" in param_names
 62      assert "output" in param_names
 63      assert "extract_fields" in param_names
 64  
 65  
 66  def test_get_command_params():
 67      get_cmd = next((cmd for cmd in commands.commands.values() if cmd.name == "get"), None)
 68      assert get_cmd is not None
 69      param_names = [p.name for p in get_cmd.params]
 70      assert "trace_id" in param_names
 71      assert "extract_fields" in param_names
 72  
 73  
 74  def test_assessment_source_type_choices():
 75      log_feedback_cmd = next(
 76          (cmd for cmd in commands.commands.values() if cmd.name == "log-feedback"), None
 77      )
 78      assert log_feedback_cmd is not None
 79  
 80      source_type_param = next(
 81          (param for param in log_feedback_cmd.params if param.name == "source_type"), None
 82      )
 83      assert source_type_param is not None
 84      assert AssessmentSourceType.HUMAN in source_type_param.type.choices
 85      assert AssessmentSourceType.LLM_JUDGE in source_type_param.type.choices
 86      assert AssessmentSourceType.CODE in source_type_param.type.choices
 87  
 88  
 89  def test_search_command_with_fields(runner):
 90      trace_location = TraceLocation(
 91          type=TraceLocationType.MLFLOW_EXPERIMENT,
 92          mlflow_experiment=MlflowExperimentLocation(experiment_id="1"),
 93      )
 94      trace = Trace(
 95          info=TraceInfo(
 96              trace_id="tr-123",
 97              state=TraceState.OK,
 98              request_time=1700000000000,
 99              execution_duration=1234,
100              request_preview="test request",
101              response_preview="test response",
102              trace_location=trace_location,
103          ),
104          data=TraceData(spans=[]),
105      )
106  
107      mock_result = PagedList([trace], None)
108  
109      with mock.patch("mlflow.cli.traces.TracingClient") as mock_client:
110          mock_client.return_value.search_traces.return_value = mock_result
111          result = runner.invoke(
112              commands,
113              ["search", "--experiment-id", "1", "--extract-fields", "info.trace_id,info.state"],
114          )
115  
116          assert result.exit_code == 0
117          assert "tr-123" in result.output
118          assert "OK" in result.output
119  
120  
def test_get_command_with_fields(runner):
    """``get --extract-fields info.trace_id`` prints only that field as JSON.

    ``TracingClient`` is patched so ``get_trace`` returns a prepared trace;
    the command must exit with code 0 and emit exactly
    ``{"info": {"trace_id": "tr-123"}}``.
    """
    experiment_location = MlflowExperimentLocation(experiment_id="1")
    location = TraceLocation(
        type=TraceLocationType.MLFLOW_EXPERIMENT,
        mlflow_experiment=experiment_location,
    )
    info = TraceInfo(
        trace_id="tr-123",
        trace_location=location,
        request_time=1700000000000,
        execution_duration=1234,
        state=TraceState.OK,
    )
    trace = Trace(info=info, data=TraceData(spans=[]))

    cli_args = ["get", "--trace-id", "tr-123", "--extract-fields", "info.trace_id"]

    with mock.patch("mlflow.cli.traces.TracingClient") as client_cls:
        client_cls.return_value.get_trace.return_value = trace
        result = runner.invoke(commands, cli_args)

        assert result.exit_code == 0
        assert json.loads(result.output) == {"info": {"trace_id": "tr-123"}}
147  
148  
def test_delete_command(runner):
    """``delete`` reports the count returned by the patched ``delete_traces``."""
    cli_args = ["delete", "--experiment-id", "1", "--trace-ids", "tr-1,tr-2,tr-3"]

    with mock.patch("mlflow.cli.traces.TracingClient") as client_cls:
        # The reported number comes from the mocked client, not from the id list.
        client_cls.return_value.delete_traces.return_value = 5
        result = runner.invoke(commands, cli_args)

        assert result.exit_code == 0
        assert "Deleted 5 trace(s)" in result.output
159  
160  
def test_field_validation_error(runner):
    """An unknown ``--extract-fields`` path makes ``search`` fail.

    Without ``--verbose`` the command must exit non-zero, and its output must
    contain "Invalid field path" and mention ``--verbose``.
    """
    experiment_location = MlflowExperimentLocation(experiment_id="1")
    location = TraceLocation(
        type=TraceLocationType.MLFLOW_EXPERIMENT,
        mlflow_experiment=experiment_location,
    )
    info = TraceInfo(
        trace_id="tr-123",
        trace_location=location,
        request_time=1700000000000,
        execution_duration=1234,
        state=TraceState.OK,
    )
    trace = Trace(info=info, data=TraceData(spans=[]))
    page = PagedList([trace], None)

    cli_args = ["search", "--experiment-id", "1", "--extract-fields", "invalid.field"]

    with mock.patch("mlflow.cli.traces.TracingClient") as client_cls:
        client_cls.return_value.search_traces.return_value = page
        result = runner.invoke(commands, cli_args)

        assert result.exit_code != 0
        for fragment in ("Invalid field path", "--verbose"):
            assert fragment in result.output
189  
190  
def test_field_validation_error_verbose_mode(runner):
    """With ``--verbose``, an unknown field path error names specific fields.

    The command must exit non-zero; its output must contain "Invalid field
    path" together with ``info.trace_id``, ``info.state`` and
    ``info.request_time``, and must not contain "Tip: Use --verbose".
    """
    experiment_location = MlflowExperimentLocation(experiment_id="1")
    location = TraceLocation(
        type=TraceLocationType.MLFLOW_EXPERIMENT,
        mlflow_experiment=experiment_location,
    )
    info = TraceInfo(
        trace_id="tr-123",
        trace_location=location,
        request_time=1700000000000,
        execution_duration=1234,
        state=TraceState.OK,
    )
    trace = Trace(info=info, data=TraceData(spans=[]))
    page = PagedList([trace], None)

    cli_args = [
        "search",
        "--experiment-id",
        "1",
        "--extract-fields",
        "invalid.field",
        "--verbose",
    ]

    with mock.patch("mlflow.cli.traces.TracingClient") as client_cls:
        client_cls.return_value.search_traces.return_value = page
        result = runner.invoke(commands, cli_args)

        assert result.exit_code != 0
        for fragment in (
            "Invalid field path",
            "info.trace_id",
            "info.state",
            "info.request_time",
        ):
            assert fragment in result.output
        assert "Tip: Use --verbose" not in result.output