# test_cli.py
"""Tests for the mlflow demo CLI command.

Includes both quick help/registration tests and functional tests that
invoke the actual CLI command with a mocked server.
"""

import socket
import sys
from unittest import mock

import click
import pytest
from click.testing import CliRunner

import mlflow
from mlflow.cli import cli
from mlflow.cli.demo import _check_server_connection, demo
from mlflow.demo.base import DEMO_EXPERIMENT_NAME, DEMO_PROMPT_PREFIX
from mlflow.demo.generators.traces import DEMO_VERSION_TAG
from mlflow.demo.registry import demo_registry
from mlflow.genai.datasets import search_datasets
from mlflow.genai.prompts import search_prompts


@pytest.fixture(autouse=True)
def disable_quiet_logging(monkeypatch):
    """Prevent CLI from modifying logging state during tests."""
    demo_module = sys.modules["mlflow.cli.demo"]
    monkeypatch.setattr(demo_module, "_set_quiet_logging", lambda: None)


def _invoke_demo(*extra_args):
    """Invoke `mlflow demo --no-browser` with the real server mocked out.

    ``input="n\\n"`` answers the CLI's interactive confirmation prompt
    with "no". Returns the click ``Result`` of the invocation.
    """
    runner = CliRunner()
    with mock.patch("mlflow.server._run_server"):
        return runner.invoke(demo, ["--no-browser", *extra_args], input="n\n")


def test_demo_command_registered():
    runner = CliRunner()
    result = runner.invoke(cli, ["demo", "--help"])

    assert result.exit_code == 0
    assert "Launch MLflow with pre-populated demo data" in result.output


def test_demo_command_help_shows_options():
    runner = CliRunner()
    result = runner.invoke(demo, ["--help"])

    assert result.exit_code == 0
    assert "--port" in result.output
    assert "--no-browser" in result.output


def test_demo_command_port_option():
    runner = CliRunner()
    result = runner.invoke(demo, ["--help"])

    assert result.exit_code == 0
    assert "Port to run demo server on" in result.output


def test_cli_generates_all_registered_features():
    result = _invoke_demo()

    assert result.exit_code == 0
    assert "Generated:" in result.output

    # Every registered generator must be reported in the CLI output.
    for name in set(demo_registry.list_generators()):
        assert name in result.output


def test_cli_creates_experiment():
    result = _invoke_demo()

    assert result.exit_code == 0

    experiment = mlflow.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    assert experiment is not None
    assert experiment.lifecycle_stage == "active"


def test_cli_creates_traces():
    result = _invoke_demo()

    assert result.exit_code == 0

    experiment = mlflow.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    client = mlflow.MlflowClient()
    all_traces = client.search_traces(
        locations=[experiment.experiment_id],
        max_results=200,
    )

    # Filter for demo traces only (exclude evaluation traces created by evaluate())
    demo_traces = [t for t in all_traces if t.info.trace_metadata.get(DEMO_VERSION_TAG)]
    assert len(demo_traces) == 42


def test_cli_creates_evaluation_datasets():
    result = _invoke_demo()

    assert result.exit_code == 0

    experiment = mlflow.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    datasets = search_datasets(
        experiment_ids=[experiment.experiment_id],
        filter_string="name LIKE 'demo-%'",
        max_results=10,
    )

    assert len(datasets) == 3


def test_cli_creates_prompts():
    result = _invoke_demo()

    assert result.exit_code == 0

    prompts = search_prompts(
        filter_string=f"name LIKE '{DEMO_PROMPT_PREFIX}.%'",
        max_results=100,
    )

    assert len(prompts) == 3


def test_cli_shows_server_url():
    result = _invoke_demo()

    assert result.exit_code == 0
    assert "MLflow Tracking Server running at:" in result.output
    assert "View the demo at:" in result.output


def test_cli_respects_port_option():
    # Patches the server directly (instead of using _invoke_demo) because
    # this test must inspect the mock's call arguments.
    runner = CliRunner()

    with mock.patch("mlflow.server._run_server") as mock_server:
        result = runner.invoke(demo, ["--no-browser", "--port", "5555"], input="n\n")

    assert result.exit_code == 0
    assert "http://127.0.0.1:5555" in result.output
    mock_server.assert_called_once()
    assert mock_server.call_args.kwargs["port"] == 5555


def test_cli_port_in_use_error():
    runner = CliRunner()

    # Bind an ephemeral port ourselves so the CLI's own attempt to use it
    # is guaranteed to fail while the socket is held open.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        bound_port = s.getsockname()[1]

        result = runner.invoke(demo, ["--port", str(bound_port)], input="n\n")

    assert result.exit_code != 0
    assert "already in use" in result.output


def test_cli_unreachable_server_error():
    runner = CliRunner()

    # Use a URL that won't have a server running
    result = runner.invoke(demo, ["--tracking-uri", "http://localhost:59999"])

    assert result.exit_code != 0
    assert "Cannot connect to MLflow server" in result.output
    assert "Please verify" in result.output


def test_check_server_connection_fails_for_bad_url():
    with pytest.raises(click.ClickException, match="Cannot connect to MLflow server"):
        _check_server_connection("http://localhost:59999", max_retries=1, timeout=1)