# test_cli.py
"""Tests for the ``commands`` CLI entry point in ``mlflow.assistant.cli``."""

import contextlib
import os
from unittest import mock

import pytest
from click.testing import CliRunner

from mlflow.assistant.cli import commands
from mlflow.assistant.config import ProviderConfig

# Path returned by the patched ``shutil.which`` so the CLI treats the Claude CLI as installed.
_FAKE_CLAUDE_PATH = "/usr/bin/claude"


@pytest.fixture
def runner():
    return CliRunner()


def _cli_check_result(returncode=0, stderr=""):
    """Return a stand-in for the ``subprocess.run`` result seen by the Claude Code provider."""
    result = mock.Mock()
    result.returncode = returncode
    result.stderr = stderr
    return result


def _make_mock_config():
    """Return a mock ``AssistantConfig`` backed by real ``providers``/``projects`` dicts.

    ``set_provider`` is replaced with a function that stores a selected ``ProviderConfig``
    in ``providers``, so the effect of the CLI's call is observable by the test.
    """
    config = mock.Mock()
    config.providers = {"claude_code": ProviderConfig()}
    config.projects = {}

    def set_provider(name, model):
        config.providers[name] = ProviderConfig(model=model, selected=True)

    config.set_provider = set_provider
    return config


@contextlib.contextmanager
def _patched_configure_env(cli_result, *, experiments=None, config=None):
    """Patch the collaborators that ``--configure`` reaches out to.

    ``shutil.which`` is always patched to report the Claude CLI as installed.

    Args:
        cli_result: Returned by the ``subprocess.run`` call in the Claude Code provider.
        experiments: When not None, returned by ``_fetch_recent_experiments``
            (an empty list is a valid value and is still patched in).
        config: When not None, returned by ``AssistantConfig.load``; its ``save``
            attribute is patched as well.
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(
            mock.patch("mlflow.assistant.cli.shutil.which", return_value=_FAKE_CLAUDE_PATH)
        )
        stack.enter_context(
            mock.patch(
                "mlflow.assistant.providers.claude_code.subprocess.run",
                return_value=cli_result,
            )
        )
        # `is not None` (not truthiness) so that `experiments=[]` still patches the fetch.
        if experiments is not None:
            stack.enter_context(
                mock.patch(
                    "mlflow.assistant.cli._fetch_recent_experiments",
                    return_value=experiments,
                )
            )
        if config is not None:
            stack.enter_context(
                mock.patch("mlflow.assistant.cli.AssistantConfig.load", return_value=config)
            )
            stack.enter_context(mock.patch.object(config, "save"))
        yield


def test_assistant_help(runner):
    result = runner.invoke(commands, ["--help"])
    assert result.exit_code == 0
    assert "AI-powered trace analysis" in result.output
    assert "--configure" in result.output


def test_configure_cli_not_found(runner):
    with mock.patch("mlflow.assistant.cli.shutil.which", return_value=None):
        result = runner.invoke(commands, ["--configure"], input="1\n")
    assert "not installed" in result.output


def test_configure_auth_failure(runner):
    with _patched_configure_env(_cli_check_result(returncode=1, stderr="unauthorized")):
        result = runner.invoke(commands, ["--configure"], input="1\n")
    assert result.exit_code == 0, result.output
    # Should show an error about authentication.
    # NOTE(review): the "not installed" alternative also matches the CLI-not-found
    # message, which weakens this check; tighten once the exact wording is confirmed.
    assert "Not authenticated" in result.output or "not installed" in result.output.lower()


def test_configure_experiment_fetch_failure(runner):
    with _patched_configure_env(_cli_check_result(), experiments=[]):
        # Input: provider=1, connect=y, tracking_uri=http://localhost:5000
        result = runner.invoke(
            commands,
            ["--configure"],
            input="1\ny\nhttp://localhost:5000\n",
        )
    assert "Could not fetch experiments" in result.output


def test_configure_success(runner, tmp_path):
    mock_config = _make_mock_config()

    with (
        _patched_configure_env(
            _cli_check_result(),
            experiments=[("1", "Test Experiment")],
            config=mock_config,
        ),
        runner.isolated_filesystem(temp_dir=tmp_path),
    ):
        # Input: provider=1, connect=y, tracking_uri, experiment=1, project_path=tmp_path,
        # model=default, skill_location=1
        result = runner.invoke(
            commands,
            ["--configure"],
            input=f"1\ny\nhttp://localhost:5000\n1\n{tmp_path}\ndefault\n1\n",
        )
    assert "Setup Complete" in result.output, result.output


def test_configure_tilde_expansion(runner):
    mock_config = _make_mock_config()
    # Keep a handle on the dict the CLI is given, so we inspect exactly what it mutated.
    projects_dict = mock_config.projects
    home_dir = os.path.expanduser("~")

    with _patched_configure_env(
        _cli_check_result(),
        experiments=[("1", "Test Experiment")],
        config=mock_config,
    ):
        # Input: provider=1, connect=y, tracking_uri, experiment=1, project_path=~,
        # model=default, skill_location=1
        result = runner.invoke(
            commands,
            ["--configure"],
            input="1\ny\nhttp://localhost:5000\n1\n~\ndefault\n1\n",
        )
    # Should succeed because ~ expands to the home dir, which exists.
    assert "Setup Complete" in result.output, result.output
    # The saved path must be the expanded path, not a literal "~".
    assert "1" in projects_dict
    assert projects_dict["1"].location == home_dir


def test_configure_relative_path(runner):
    mock_config = _make_mock_config()
    # Keep a handle on the dict the CLI is given, so we inspect exactly what it mutated.
    projects_dict = mock_config.projects

    with _patched_configure_env(
        _cli_check_result(),
        experiments=[("1", "Test Experiment")],
        config=mock_config,
    ):
        # "." should resolve to the current directory.
        # Input: provider=1, connect=y, tracking_uri, experiment=1, project_path=.,
        # model=default, skill_location=1
        result = runner.invoke(
            commands,
            ["--configure"],
            input="1\ny\nhttp://localhost:5000\n1\n.\ndefault\n1\n",
        )
    assert "Setup Complete" in result.output, result.output
    # The saved path must be absolute, not a literal ".".
    assert "1" in projects_dict
    assert os.path.isabs(projects_dict["1"].location)
    assert projects_dict["1"].location != "."