test_ssh_service_manager.py
"""Tests for SSHServiceManager — SSH agent integration."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.core.ssh.ssh_config import SSHProfile, SSHSecurityConfig
from src.services.ssh_service_manager import SSHServiceManager

# Dotted path of the module under test; prefix for every ``patch`` target below.
_SVC_MODULE = "src.services.ssh_service_manager"


def _patch_resolve_flag(**kwargs):
    """Patch ``SSHServiceManager._resolve_user_ssh_flag`` with an ``AsyncMock``.

    Extra keyword arguments (e.g. ``return_value``) are forwarded to
    ``patch.object`` and from there to the ``AsyncMock`` it creates.
    """
    return patch.object(
        SSHServiceManager,
        "_resolve_user_ssh_flag",
        new_callable=AsyncMock,
        **kwargs,
    )


def _patch_redis():
    """Patch the service module's ``_redis`` handle with a ``MagicMock``."""
    return patch(f"{_SVC_MODULE}._redis")


@pytest.fixture
def manager():
    """Fresh SSHServiceManager for each test."""
    return SSHServiceManager()


@pytest.fixture
def sample_profile():
    """Sample SSHProfile for testing."""
    return SSHProfile(
        name="test-server",
        host="192.168.1.100",
        port=22,
        username="deploy",
        mode="readonly",
        privilege_level=0,
        description="Test server",
    )


@pytest.fixture
def sample_profiles(sample_profile):
    """Dict of profiles keyed by name."""
    return {sample_profile.name: sample_profile}


class TestInitialize:
    """Tests for SSHServiceManager.initialize()."""

    @pytest.mark.asyncio
    async def test_always_initializes_without_yaml(self, manager):
        """SSH infrastructure always initializes — no YAML needed."""
        await manager.initialize()
        assert manager.enabled is True
        assert manager._pool is not None
        assert manager._command_filter is not None
        assert manager.security_config is not None
        assert manager.security_config.enabled is True

    @pytest.mark.asyncio
    async def test_logs_deprecation_when_yaml_exists(self, manager, tmp_path):
        """Logs warning when legacy ssh-security.yaml is found."""
        yaml_path = tmp_path / "ssh-security.yaml"
        yaml_path.write_text("ssh:\n enabled: true\n")

        with (
            patch(f"{_SVC_MODULE}._LEGACY_SSH_SECURITY_CONFIG_PATH", yaml_path),
            patch(f"{_SVC_MODULE}.logger") as mock_logger,
        ):
            await manager.initialize()

        mock_logger.warning.assert_called_once()
        assert "no longer used" in mock_logger.warning.call_args[0][0]
        assert manager.enabled is True


class TestIsUserSshEnabled:
    """Tests for SSHServiceManager.is_user_ssh_enabled()."""

    @pytest.mark.asyncio
    async def test_returns_false_when_flag_disabled(self, manager):
        """Returns False when user's ssh_enabled feature flag is False."""
        await manager.initialize()

        # Take the Redis cache out of play (same technique as
        # test_redis_unavailable_falls_through). Otherwise a value cached for
        # "user-1" by another test or run could bypass the resolver.
        with (
            _patch_redis() as mock_redis,
            _patch_resolve_flag(return_value=False),
        ):
            mock_redis.get.side_effect = RuntimeError("Redis down")
            result = await manager.is_user_ssh_enabled("user-1")

        assert result is False

    @pytest.mark.asyncio
    async def test_returns_true_when_flag_enabled(self, manager):
        """Returns True when user's ssh_enabled feature flag is True."""
        await manager.initialize()

        # Cache disabled for the same isolation reason as the test above.
        with (
            _patch_redis() as mock_redis,
            _patch_resolve_flag(return_value=True),
        ):
            mock_redis.get.side_effect = RuntimeError("Redis down")
            result = await manager.is_user_ssh_enabled("user-1")

        assert result is True

    @pytest.mark.asyncio
    async def test_cache_hit_skips_db(self, manager):
        """Redis cache hit returns immediately without calling _resolve_user_ssh_flag."""
        await manager.initialize()

        mock_client = AsyncMock()
        mock_client.get = AsyncMock(return_value="1")

        with (
            _patch_redis() as mock_redis,
            _patch_resolve_flag() as mock_resolve,
        ):
            mock_redis.get.return_value = mock_client
            result = await manager.is_user_ssh_enabled("user-1")

        assert result is True
        mock_resolve.assert_not_called()

    @pytest.mark.asyncio
    async def test_cache_miss_queries_db_and_sets_cache(self, manager):
        """Redis cache miss falls through to DB then caches with 30s TTL."""
        await manager.initialize()

        mock_client = AsyncMock()
        mock_client.get = AsyncMock(return_value=None)
        mock_client.set = AsyncMock()

        with (
            _patch_redis() as mock_redis,
            _patch_resolve_flag(return_value=True),
        ):
            mock_redis.get.return_value = mock_client
            result = await manager.is_user_ssh_enabled("user-1")

        assert result is True
        mock_client.set.assert_called_once_with(
            "feature:ssh_enabled:user-1", "1", ex=30,
        )

    @pytest.mark.asyncio
    async def test_redis_unavailable_falls_through(self, manager):
        """When Redis raises, method still returns correct value from DB."""
        await manager.initialize()

        with (
            _patch_redis() as mock_redis,
            _patch_resolve_flag(return_value=False),
        ):
            mock_redis.get.side_effect = RuntimeError("Redis down")
            result = await manager.is_user_ssh_enabled("user-1")

        assert result is False


class TestBuildSessionContext:
    """Tests for SSHServiceManager.build_session_context()."""

    @pytest.mark.asyncio
    async def test_returns_none_when_not_initialized(self, manager, sample_profiles):
        """Returns None when manager hasn't been initialized."""
        result = await manager.build_session_context(
            session_id="sess-1",
            user_id="user-1",
            profiles=sample_profiles,
            db_session_factory=AsyncMock(),
            vault_service=MagicMock(),
        )
        assert result is None

    @pytest.mark.asyncio
    async def test_returns_none_when_no_profiles(self, manager):
        """Returns None when profiles dict is empty."""
        await manager.initialize()

        result = await manager.build_session_context(
            session_id="sess-1",
            user_id="user-1",
            profiles={},
            db_session_factory=AsyncMock(),
            vault_service=MagicMock(),
        )
        assert result is None

    @pytest.mark.asyncio
    async def test_builds_context_with_profiles(self, manager, sample_profiles):
        """Builds SSHToolContext when initialized and profiles exist."""
        await manager.initialize()

        vault_svc = MagicMock()
        db_factory = AsyncMock()

        ctx = await manager.build_session_context(
            session_id="sess-1",
            user_id="user-1",
            profiles=sample_profiles,
            db_session_factory=db_factory,
            vault_service=vault_svc,
        )
        assert ctx is not None
        assert ctx.session_id == "sess-1"
        assert ctx.user_id == "user-1"
        assert ctx.profiles == sample_profiles
        assert ctx.connection_pool is manager._pool
        assert ctx.command_filter is manager._command_filter
        assert ctx.command_semaphore is not None

    @pytest.mark.asyncio
    async def test_multiple_sessions_share_pool(self, manager, sample_profiles):
        """Multiple sessions share the same connection pool."""
        await manager.initialize()

        vault_svc = MagicMock()
        ctx1 = await manager.build_session_context(
            session_id="sess-1",
            user_id="user-1",
            profiles=sample_profiles,
            db_session_factory=AsyncMock(),
            vault_service=vault_svc,
        )
        ctx2 = await manager.build_session_context(
            session_id="sess-2",
            user_id="user-1",
            profiles=sample_profiles,
            db_session_factory=AsyncMock(),
            vault_service=vault_svc,
        )
        assert ctx1.connection_pool is ctx2.connection_pool


class TestCleanupSession:
    """Tests for SSHServiceManager.cleanup_session()."""

    @pytest.mark.asyncio
    async def test_cleanup_calls_pool(self, manager):
        """Cleanup calls close_session_connections on the pool."""
        await manager.initialize()

        manager._pool.close_session_connections = AsyncMock(return_value=2)
        await manager.cleanup_session("sess-1")
        manager._pool.close_session_connections.assert_awaited_once_with("sess-1")

    @pytest.mark.asyncio
    async def test_cleanup_noop_when_not_initialized(self, manager):
        """Cleanup is a no-op when not initialized (no pool)."""
        manager._pool = None
        await manager.cleanup_session("sess-1")  # Should not raise


class TestShutdown:
    """Tests for SSHServiceManager.shutdown()."""

    @pytest.mark.asyncio
    async def test_shutdown_closes_pool(self, manager):
        """Shutdown calls pool.shutdown()."""
        await manager.initialize()

        manager._pool.shutdown = AsyncMock()
        await manager.shutdown()
        manager._pool.shutdown.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_shutdown_noop_when_not_initialized(self, manager):
        """Shutdown is a no-op when never initialized."""
        manager._pool = None
        await manager.shutdown()  # Should not raise


class TestPromptContext:
    """Tests for SSH profile injection into prompt context."""

    def test_ssh_enabled_flag_set(self, sample_profiles):
        """SSH_ENABLED flag is True when profiles provided."""
        from src.core.prompt_context import build_prompt_context
        ctx = build_prompt_context(ssh_profiles=sample_profiles)
        assert ctx.flags["SSH_ENABLED"] is True

    def test_ssh_disabled_flag_when_no_profiles(self):
        """SSH_ENABLED flag is False when no profiles."""
        from src.core.prompt_context import build_prompt_context
        ctx = build_prompt_context(ssh_profiles=None)
        assert ctx.flags["SSH_ENABLED"] is False

    def test_ssh_disabled_flag_when_empty_profiles(self):
        """SSH_ENABLED flag is False when profiles dict is empty."""
        from src.core.prompt_context import build_prompt_context
        ctx = build_prompt_context(ssh_profiles={})
        assert ctx.flags["SSH_ENABLED"] is False

    def test_profiles_block_generated(self, sample_profile):
        """SSH_PROFILES_BLOCK contains profile info."""
        from src.core.prompt_context import build_prompt_context
        profiles = {sample_profile.name: sample_profile}
        ctx = build_prompt_context(ssh_profiles=profiles)
        block = ctx.strings["SSH_PROFILES_BLOCK"]
        assert "test-server" in block
        assert "deploy@192.168.1.100:22" in block
        assert "P0 Observer" in block
        assert "Test server" in block

    def test_tool_names_registered(self, sample_profiles):
        """SSH tool name variables are registered in context."""
        from src.core.prompt_context import build_prompt_context
        ctx = build_prompt_context(ssh_profiles=sample_profiles)
        assert ctx.tool_names["AG3NTUM_SSH_EXEC_TOOL"] == "mcp__ag3ntum__SSHExec"
        assert ctx.tool_names["AG3NTUM_SSH_READ_TOOL"] == "mcp__ag3ntum__SSHRead"
        assert ctx.tool_names["AG3NTUM_SSH_CONNECT_TOOL"] == "mcp__ag3ntum__SSHConnect"

    def test_multiple_profiles_in_block(self):
        """Multiple profiles are listed in SSH_PROFILES_BLOCK."""
        from src.core.prompt_context import build_prompt_context
        profiles = {
            "prod-web": SSHProfile(
                name="prod-web", host="10.0.0.1", port=22,
                username="root", mode="operations", privilege_level=1,
            ),
            "staging-db": SSHProfile(
                name="staging-db", host="10.0.0.2", port=2222,
                username="admin", mode="readonly", privilege_level=0,
                description="Staging database",
            ),
        }
        ctx = build_prompt_context(ssh_profiles=profiles)
        block = ctx.strings["SSH_PROFILES_BLOCK"]
        assert "prod-web" in block
        assert "staging-db" in block
        assert "root@10.0.0.1:22" in block
        assert "admin@10.0.0.2:2222" in block
        assert "P1 Site Manager" in block
        assert "P0 Observer" in block


class TestSSHPromptRendering:
    """Tests for SSH system prompt template rendering."""

    # These paths are relative, so they resolve against the current working
    # directory. The tests presumably expect pytest to run from the
    # repository root (where ``prompts/`` lives) — TODO confirm.
    _PROMPTS_DIR = Path("prompts")
    _SSH_PROMPT_FILE = _PROMPTS_DIR / "system-prompts" / "07b-ssh.md"

    @classmethod
    def _render(cls, ssh_profiles):
        """Render 07b-ssh.md with a prompt context built from *ssh_profiles*."""
        from src.core.prompt_context import build_prompt_context
        from src.core.prompt_engine import PromptTemplateEngine

        ctx = build_prompt_context(ssh_profiles=ssh_profiles)
        engine = PromptTemplateEngine(base_dir=cls._PROMPTS_DIR)
        return engine.load_and_render(cls._SSH_PROMPT_FILE, ctx)

    def test_prompt_rendered_when_ssh_enabled(self, sample_profiles):
        """07b-ssh.md renders content when SSH_ENABLED is True."""
        rendered = self._render(sample_profiles)
        assert "SSH Remote Server Access" in rendered
        assert "test-server" in rendered
        assert "mcp__ag3ntum__SSHExec" in rendered

    def test_prompt_empty_when_ssh_disabled(self):
        """07b-ssh.md renders empty when SSH_ENABLED is False."""
        rendered = self._render(None)
        assert "SSH Remote Server Access" not in rendered
        assert rendered.strip() == ""


class TestSSHSecurityDefaults:
    """Tests for hardcoded SSH security defaults."""

    def test_default_config_has_sane_limits(self):
        """get_default_ssh_security_config returns config with sensible limits."""
        from src.core.ssh.ssh_config import get_default_ssh_security_config
        config = get_default_ssh_security_config()
        assert config.enabled is True
        assert config.limits.max_connections_per_user == 3
        assert config.limits.command_timeout_seconds == 300
        assert config.limits.max_output_bytes == 1_048_576
        assert config.credentials.password_auth_allowed is False
        assert config.host_key_verification.mode == "tofu"

    def test_always_blocked_hosts_constant(self):
        """ALWAYS_BLOCKED_HOSTS includes localhost and metadata IPs."""
        from src.core.ssh.ssh_config import ALWAYS_BLOCKED_HOSTS
        assert "127.0.0.1" in ALWAYS_BLOCKED_HOSTS
        assert "localhost" in ALWAYS_BLOCKED_HOSTS
        assert "::1" in ALWAYS_BLOCKED_HOSTS
        assert "169.254.0.0/16" in ALWAYS_BLOCKED_HOSTS

    def test_default_config_blocks_dangerous_hosts(self):
        """Default config always_blocked list includes all ALWAYS_BLOCKED_HOSTS."""
        from src.core.ssh.ssh_config import (
            ALWAYS_BLOCKED_HOSTS,
            get_default_ssh_security_config,
        )
        config = get_default_ssh_security_config()
        for host in ALWAYS_BLOCKED_HOSTS:
            assert host in config.hosts.always_blocked