/ tests / demo / test_cli.py
test_cli.py
  1  """Tests for the mlflow demo CLI command.
  2  
  3  Includes both quick help/registration tests and functional tests that
  4  invoke the actual CLI command with a mocked server.
  5  """
  6  
  7  import socket
  8  import sys
  9  from unittest import mock
 10  
 11  import click
 12  import pytest
 13  from click.testing import CliRunner
 14  
 15  import mlflow
 16  from mlflow.cli import cli
 17  from mlflow.cli.demo import _check_server_connection, demo
 18  from mlflow.demo.base import DEMO_EXPERIMENT_NAME, DEMO_PROMPT_PREFIX
 19  from mlflow.demo.generators.traces import DEMO_VERSION_TAG
 20  from mlflow.demo.registry import demo_registry
 21  from mlflow.genai.datasets import search_datasets
 22  from mlflow.genai.prompts import search_prompts
 23  
 24  
 25  @pytest.fixture(autouse=True)
 26  def disable_quiet_logging(monkeypatch):
 27      """Prevent CLI from modifying logging state during tests."""
 28      demo_module = sys.modules["mlflow.cli.demo"]
 29      monkeypatch.setattr(demo_module, "_set_quiet_logging", lambda: None)
 30  
 31  
 32  def test_demo_command_registered():
 33      runner = CliRunner()
 34      result = runner.invoke(cli, ["demo", "--help"])
 35  
 36      assert result.exit_code == 0
 37      assert "Launch MLflow with pre-populated demo data" in result.output
 38  
 39  
 40  def test_demo_command_help_shows_options():
 41      runner = CliRunner()
 42      result = runner.invoke(demo, ["--help"])
 43  
 44      assert result.exit_code == 0
 45      assert "--port" in result.output
 46      assert "--no-browser" in result.output
 47  
 48  
 49  def test_demo_command_port_option():
 50      runner = CliRunner()
 51      result = runner.invoke(demo, ["--help"])
 52  
 53      assert result.exit_code == 0
 54      assert "Port to run demo server on" in result.output
 55  
 56  
 57  def test_cli_generates_all_registered_features():
 58      runner = CliRunner()
 59  
 60      with mock.patch("mlflow.server._run_server"):
 61          result = runner.invoke(demo, ["--no-browser"], input="n\n")
 62  
 63      assert result.exit_code == 0
 64      assert "Generated:" in result.output
 65  
 66      tracking_uri = mlflow.get_tracking_uri()
 67      mlflow.set_tracking_uri(tracking_uri)
 68  
 69      registered_names = set(demo_registry.list_generators())
 70      for name in registered_names:
 71          assert name in result.output
 72  
 73  
 74  def test_cli_creates_experiment():
 75      runner = CliRunner()
 76  
 77      with mock.patch("mlflow.server._run_server"):
 78          result = runner.invoke(demo, ["--no-browser"], input="n\n")
 79  
 80      assert result.exit_code == 0
 81  
 82      experiment = mlflow.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
 83      assert experiment is not None
 84      assert experiment.lifecycle_stage == "active"
 85  
 86  
 87  def test_cli_creates_traces():
 88      runner = CliRunner()
 89  
 90      with mock.patch("mlflow.server._run_server"):
 91          result = runner.invoke(demo, ["--no-browser"], input="n\n")
 92  
 93      assert result.exit_code == 0
 94  
 95      experiment = mlflow.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
 96      client = mlflow.MlflowClient()
 97      all_traces = client.search_traces(
 98          locations=[experiment.experiment_id],
 99          max_results=200,
100      )
101  
102      # Filter for demo traces only (exclude evaluation traces created by evaluate())
103      demo_traces = [t for t in all_traces if t.info.trace_metadata.get(DEMO_VERSION_TAG)]
104      assert len(demo_traces) == 42
105  
106  
def test_cli_creates_evaluation_datasets():
    """After the CLI runs, three datasets named ``demo-*`` exist in the demo experiment."""
    runner = CliRunner()

    # Replace the server launcher with a mock so no real server is started.
    with mock.patch("mlflow.server._run_server"):
        result = runner.invoke(demo, ["--no-browser"], input="n\n")

    assert result.exit_code == 0

    experiment = mlflow.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    # Fail with a clear assertion (as test_cli_creates_experiment does) instead
    # of an AttributeError on ``None.experiment_id`` if the experiment is missing.
    assert experiment is not None

    datasets = search_datasets(
        experiment_ids=[experiment.experiment_id],
        filter_string="name LIKE 'demo-%'",
        max_results=10,
    )

    assert len(datasets) == 3
123  
124  
def test_cli_creates_prompts():
    """After the CLI runs, three prompts exist under the demo prompt prefix."""
    # Replace the server launcher with a mock so no real server is started.
    with mock.patch("mlflow.server._run_server"):
        outcome = CliRunner().invoke(demo, ["--no-browser"], input="n\n")

    assert outcome.exit_code == 0

    prefix_filter = f"name LIKE '{DEMO_PROMPT_PREFIX}.%'"
    demo_prompts = search_prompts(filter_string=prefix_filter, max_results=100)

    assert len(demo_prompts) == 3
139  
140  
def test_cli_shows_server_url():
    """The CLI output contains both the server-running and the view-the-demo lines."""
    # Replace the server launcher with a mock so no real server is started.
    with mock.patch("mlflow.server._run_server"):
        outcome = CliRunner().invoke(demo, ["--no-browser"], input="n\n")

    assert outcome.exit_code == 0
    # Checked in the same order as before.
    for banner in ("MLflow Tracking Server running at:", "View the demo at:"):
        assert banner in outcome.output
150  
151  
def test_cli_respects_port_option():
    """``--port 5555`` appears in the printed URL and is passed to the server launcher."""
    cli_args = ["--no-browser", "--port", "5555"]

    # Capture the launcher call instead of starting a real server.
    with mock.patch("mlflow.server._run_server") as run_server_stub:
        outcome = CliRunner().invoke(demo, cli_args, input="n\n")

    assert outcome.exit_code == 0
    assert "http://127.0.0.1:5555" in outcome.output

    run_server_stub.assert_called_once()
    launch_kwargs = run_server_stub.call_args.kwargs
    assert launch_kwargs["port"] == 5555
162  
163  
def test_cli_port_in_use_error():
    """Requesting a port that is already bound fails with an "already in use" message."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as occupier:
        # Port 0 lets the OS pick a free port; read the chosen port back.
        occupier.bind(("127.0.0.1", 0))
        taken_port = occupier.getsockname()[1]

        # Invoke while the socket is still open so the port stays occupied.
        outcome = CliRunner().invoke(demo, ["--port", str(taken_port)], input="n\n")

    assert outcome.exit_code != 0
    assert "already in use" in outcome.output
175  
176  
def test_cli_unreachable_server_error():
    """Pointing ``--tracking-uri`` at an unreachable server exits with a connection error."""
    # A URL where no server is expected to be listening.
    dead_uri = "http://localhost:59999"

    outcome = CliRunner().invoke(demo, ["--tracking-uri", dead_uri])

    assert outcome.exit_code != 0
    # Checked in the same order as before.
    for fragment in ("Cannot connect to MLflow server", "Please verify"):
        assert fragment in outcome.output
186  
187  
def test_check_server_connection_fails_for_bad_url():
    """``_check_server_connection`` raises ``ClickException`` for an unreachable URL."""
    expected_error = pytest.raises(
        click.ClickException,
        match="Cannot connect to MLflow server",
    )
    # A single retry and a one-second timeout are passed to the helper.
    with expected_error:
        _check_server_connection("http://localhost:59999", max_retries=1, timeout=1)