/ tests / core-tests / test_ssh_service_manager.py
test_ssh_service_manager.py
  1  """Tests for SSHServiceManager — SSH agent integration."""
  2  from unittest.mock import AsyncMock, MagicMock, patch
  3  
  4  import pytest
  5  
  6  from src.core.ssh.ssh_config import SSHProfile, SSHSecurityConfig
  7  from src.services.ssh_service_manager import SSHServiceManager
  8  
  9  
 10  @pytest.fixture
 11  def manager():
 12      """Fresh SSHServiceManager for each test."""
 13      return SSHServiceManager()
 14  
 15  
 16  @pytest.fixture
 17  def sample_profile():
 18      """Sample SSHProfile for testing."""
 19      return SSHProfile(
 20          name="test-server",
 21          host="192.168.1.100",
 22          port=22,
 23          username="deploy",
 24          mode="readonly",
 25          privilege_level=0,
 26          description="Test server",
 27      )
 28  
 29  
 30  @pytest.fixture
 31  def sample_profiles(sample_profile):
 32      """Dict of profiles keyed by name."""
 33      return {sample_profile.name: sample_profile}
 34  
 35  
 36  class TestInitialize:
 37      """Tests for SSHServiceManager.initialize()."""
 38  
 39      @pytest.mark.asyncio
 40      async def test_always_initializes_without_yaml(self, manager):
 41          """SSH infrastructure always initializes — no YAML needed."""
 42          await manager.initialize()
 43          assert manager.enabled is True
 44          assert manager._pool is not None
 45          assert manager._command_filter is not None
 46          assert manager.security_config is not None
 47          assert manager.security_config.enabled is True
 48  
 49      @pytest.mark.asyncio
 50      async def test_logs_deprecation_when_yaml_exists(self, manager, tmp_path):
 51          """Logs warning when legacy ssh-security.yaml is found."""
 52          yaml_path = tmp_path / "ssh-security.yaml"
 53          yaml_path.write_text("ssh:\n  enabled: true\n")
 54  
 55          with patch(
 56              "src.services.ssh_service_manager._LEGACY_SSH_SECURITY_CONFIG_PATH",
 57              yaml_path,
 58          ), patch(
 59              "src.services.ssh_service_manager.logger"
 60          ) as mock_logger:
 61              await manager.initialize()
 62  
 63          mock_logger.warning.assert_called_once()
 64          assert "no longer used" in mock_logger.warning.call_args[0][0]
 65          assert manager.enabled is True
 66  
 67  
 68  class TestIsUserSshEnabled:
 69      """Tests for SSHServiceManager.is_user_ssh_enabled()."""
 70  
 71      @pytest.mark.asyncio
 72      async def test_returns_false_when_flag_disabled(self, manager):
 73          """Returns False when user's ssh_enabled feature flag is False."""
 74          await manager.initialize()
 75  
 76          with patch.object(
 77              SSHServiceManager,
 78              "_resolve_user_ssh_flag",
 79              new_callable=AsyncMock,
 80              return_value=False,
 81          ):
 82              result = await manager.is_user_ssh_enabled("user-1")
 83  
 84          assert result is False
 85  
 86      @pytest.mark.asyncio
 87      async def test_returns_true_when_flag_enabled(self, manager):
 88          """Returns True when user's ssh_enabled feature flag is True."""
 89          await manager.initialize()
 90  
 91          with patch.object(
 92              SSHServiceManager,
 93              "_resolve_user_ssh_flag",
 94              new_callable=AsyncMock,
 95              return_value=True,
 96          ):
 97              result = await manager.is_user_ssh_enabled("user-1")
 98  
 99          assert result is True
100  
101      @pytest.mark.asyncio
102      async def test_cache_hit_skips_db(self, manager):
103          """Redis cache hit returns immediately without calling _resolve_user_ssh_flag."""
104          await manager.initialize()
105  
106          mock_client = AsyncMock()
107          mock_client.get = AsyncMock(return_value="1")
108  
109          with (
110              patch(
111                  "src.services.ssh_service_manager._redis"
112              ) as mock_redis,
113              patch.object(
114                  SSHServiceManager,
115                  "_resolve_user_ssh_flag",
116                  new_callable=AsyncMock,
117              ) as mock_resolve,
118          ):
119              mock_redis.get.return_value = mock_client
120              result = await manager.is_user_ssh_enabled("user-1")
121  
122          assert result is True
123          mock_resolve.assert_not_called()
124  
125      @pytest.mark.asyncio
126      async def test_cache_miss_queries_db_and_sets_cache(self, manager):
127          """Redis cache miss falls through to DB then caches with 30s TTL."""
128          await manager.initialize()
129  
130          mock_client = AsyncMock()
131          mock_client.get = AsyncMock(return_value=None)
132          mock_client.set = AsyncMock()
133  
134          with (
135              patch(
136                  "src.services.ssh_service_manager._redis"
137              ) as mock_redis,
138              patch.object(
139                  SSHServiceManager,
140                  "_resolve_user_ssh_flag",
141                  new_callable=AsyncMock,
142                  return_value=True,
143              ),
144          ):
145              mock_redis.get.return_value = mock_client
146              result = await manager.is_user_ssh_enabled("user-1")
147  
148          assert result is True
149          mock_client.set.assert_called_once_with(
150              "feature:ssh_enabled:user-1", "1", ex=30,
151          )
152  
153      @pytest.mark.asyncio
154      async def test_redis_unavailable_falls_through(self, manager):
155          """When Redis raises, method still returns correct value from DB."""
156          await manager.initialize()
157  
158          with (
159              patch(
160                  "src.services.ssh_service_manager._redis"
161              ) as mock_redis,
162              patch.object(
163                  SSHServiceManager,
164                  "_resolve_user_ssh_flag",
165                  new_callable=AsyncMock,
166                  return_value=False,
167              ),
168          ):
169              mock_redis.get.side_effect = RuntimeError("Redis down")
170              result = await manager.is_user_ssh_enabled("user-1")
171  
172          assert result is False
173  
174  
class TestBuildSessionContext:
    """Exercise the gating and wiring of SSHServiceManager.build_session_context()."""

    @staticmethod
    async def _build(
        manager,
        profiles,
        *,
        session_id="sess-1",
        vault_service=None,
        db_session_factory=None,
    ):
        """Call build_session_context for "user-1", creating fresh mocks unless given."""
        if db_session_factory is None:
            db_session_factory = AsyncMock()
        if vault_service is None:
            vault_service = MagicMock()
        return await manager.build_session_context(
            session_id=session_id,
            user_id="user-1",
            profiles=profiles,
            db_session_factory=db_session_factory,
            vault_service=vault_service,
        )

    @pytest.mark.asyncio
    async def test_returns_none_when_not_initialized(self, manager, sample_profiles):
        """Without initialize(), no context is produced even though profiles exist."""
        assert await self._build(manager, sample_profiles) is None

    @pytest.mark.asyncio
    async def test_returns_none_when_no_profiles(self, manager):
        """An initialized manager still yields None for an empty profile mapping."""
        await manager.initialize()

        assert await self._build(manager, {}) is None

    @pytest.mark.asyncio
    async def test_builds_context_with_profiles(self, manager, sample_profiles):
        """With setup done and profiles present, the context reuses the manager's shared parts."""
        await manager.initialize()

        ctx = await self._build(manager, sample_profiles)

        assert ctx is not None
        assert (ctx.session_id, ctx.user_id) == ("sess-1", "user-1")
        assert ctx.profiles == sample_profiles
        # Pool and filter are the manager's own instances, not per-session copies.
        assert ctx.connection_pool is manager._pool
        assert ctx.command_filter is manager._command_filter
        assert ctx.command_semaphore is not None

    @pytest.mark.asyncio
    async def test_multiple_sessions_share_pool(self, manager, sample_profiles):
        """Two sessions built from one manager point at the identical pool object."""
        await manager.initialize()

        shared_vault = MagicMock()
        contexts = []
        for sid in ("sess-1", "sess-2"):
            contexts.append(
                await self._build(
                    manager,
                    sample_profiles,
                    session_id=sid,
                    vault_service=shared_vault,
                )
            )

        first, second = contexts
        assert first.connection_pool is second.connection_pool
248  
249  
class TestCleanupSession:
    """Verify cleanup_session() delegates to the pool and tolerates a missing pool."""

    @pytest.mark.asyncio
    async def test_cleanup_calls_pool(self, manager):
        """The session id is forwarded exactly once to close_session_connections."""
        await manager.initialize()
        pool = manager._pool
        pool.close_session_connections = AsyncMock(return_value=2)

        await manager.cleanup_session("sess-1")

        manager._pool.close_session_connections.assert_awaited_once_with("sess-1")

    @pytest.mark.asyncio
    async def test_cleanup_noop_when_not_initialized(self, manager):
        """With no pool, cleanup returns quietly instead of raising."""
        manager._pool = None
        # Reaching the end of the test without an exception is the assertion.
        await manager.cleanup_session("sess-1")
267  
268  
class TestShutdown:
    """Verify shutdown() closes the pool and is safe when no pool exists."""

    @pytest.mark.asyncio
    async def test_shutdown_closes_pool(self, manager):
        """shutdown() awaits the pool's own shutdown coroutine exactly once."""
        await manager.initialize()
        pool = manager._pool
        pool.shutdown = AsyncMock()

        await manager.shutdown()

        manager._pool.shutdown.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_shutdown_noop_when_not_initialized(self, manager):
        """With no pool, shutdown returns quietly instead of raising."""
        manager._pool = None
        # Reaching the end of the test without an exception is the assertion.
        await manager.shutdown()
286  
287  
class TestPromptContext:
    """Check how SSH profiles feed flags, strings and tool names into the prompt context."""

    @staticmethod
    def _context(ssh_profiles):
        """Build a prompt context; the import is deferred until the test runs."""
        from src.core.prompt_context import build_prompt_context

        return build_prompt_context(ssh_profiles=ssh_profiles)

    def test_ssh_enabled_flag_set(self, sample_profiles):
        """A non-empty profile mapping turns SSH_ENABLED on."""
        assert self._context(sample_profiles).flags["SSH_ENABLED"] is True

    def test_ssh_disabled_flag_when_no_profiles(self):
        """Passing None leaves SSH_ENABLED off."""
        assert self._context(None).flags["SSH_ENABLED"] is False

    def test_ssh_disabled_flag_when_empty_profiles(self):
        """An empty mapping behaves like no profiles: SSH_ENABLED is off."""
        assert self._context({}).flags["SSH_ENABLED"] is False

    def test_profiles_block_generated(self, sample_profile):
        """SSH_PROFILES_BLOCK shows the name, user@host:port, privilege label and description."""
        ctx = self._context({sample_profile.name: sample_profile})
        block = ctx.strings["SSH_PROFILES_BLOCK"]
        for fragment in (
            "test-server",
            "deploy@192.168.1.100:22",
            "P0 Observer",
            "Test server",
        ):
            assert fragment in block

    def test_tool_names_registered(self, sample_profiles):
        """The three SSH tool-name variables map to their MCP tool identifiers."""
        tool_names = self._context(sample_profiles).tool_names
        expected = {
            "AG3NTUM_SSH_EXEC_TOOL": "mcp__ag3ntum__SSHExec",
            "AG3NTUM_SSH_READ_TOOL": "mcp__ag3ntum__SSHRead",
            "AG3NTUM_SSH_CONNECT_TOOL": "mcp__ag3ntum__SSHConnect",
        }
        for variable, tool in expected.items():
            assert tool_names[variable] == tool

    def test_multiple_profiles_in_block(self):
        """Each of several profiles gets its own entry in SSH_PROFILES_BLOCK."""
        prod_web = SSHProfile(
            name="prod-web", host="10.0.0.1", port=22,
            username="root", mode="operations", privilege_level=1,
        )
        staging_db = SSHProfile(
            name="staging-db", host="10.0.0.2", port=2222,
            username="admin", mode="readonly", privilege_level=0,
            description="Staging database",
        )
        block = self._context(
            {"prod-web": prod_web, "staging-db": staging_db}
        ).strings["SSH_PROFILES_BLOCK"]
        for fragment in (
            "prod-web",
            "staging-db",
            "root@10.0.0.1:22",
            "admin@10.0.0.2:2222",
            "P1 Site Manager",
            "P0 Observer",
        ):
            assert fragment in block
350  
351  
class TestSSHPromptRendering:
    """Tests for SSH system prompt template rendering.

    The engine base dir and the template are addressed with repo-relative
    paths. An autouse fixture therefore moves the working directory to the
    directory that contains them, so the tests pass no matter where pytest is
    launched from.
    """

    # Repo-relative location of the SSH system-prompt template.
    _PROMPT_FILE = "prompts/system-prompts/07b-ssh.md"

    @pytest.fixture(autouse=True)
    def _run_from_repo_root(self, monkeypatch):
        """chdir to the nearest ancestor of this file that contains the template.

        Relative paths like ``Path("prompts")`` are resolved against the
        current working directory. Without this fixture, the tests only work
        when pytest runs from the repository root. If no ancestor holds the
        template, the working directory is left unchanged. monkeypatch restores
        the original directory after each test.
        """
        from pathlib import Path

        for candidate in Path(__file__).resolve().parents:
            if (candidate / self._PROMPT_FILE).is_file():
                monkeypatch.chdir(candidate)
                break

    def _render(self, ssh_profiles):
        """Render 07b-ssh.md against a prompt context built from ``ssh_profiles``."""
        from pathlib import Path
        from src.core.prompt_engine import PromptTemplateEngine
        from src.core.prompt_context import build_prompt_context

        ctx = build_prompt_context(ssh_profiles=ssh_profiles)
        engine = PromptTemplateEngine(
            base_dir=Path("prompts"),
        )
        return engine.load_and_render(Path(self._PROMPT_FILE), ctx)

    def test_prompt_rendered_when_ssh_enabled(self, sample_profiles):
        """07b-ssh.md renders content when SSH_ENABLED is True."""
        rendered = self._render(sample_profiles)
        assert "SSH Remote Server Access" in rendered
        assert "test-server" in rendered
        assert "mcp__ag3ntum__SSHExec" in rendered

    def test_prompt_empty_when_ssh_disabled(self):
        """07b-ssh.md renders empty when SSH_ENABLED is False."""
        rendered = self._render(None)
        assert "SSH Remote Server Access" not in rendered
        assert rendered.strip() == ""
385  
386  
class TestSSHSecurityDefaults:
    """Pin the hardcoded SSH security defaults and the always-blocked host list."""

    def test_default_config_has_sane_limits(self):
        """The default config enables SSH with fixed limits, key-only auth and TOFU host keys."""
        from src.core.ssh.ssh_config import get_default_ssh_security_config

        cfg = get_default_ssh_security_config()
        limits = cfg.limits
        assert cfg.enabled is True
        assert (
            limits.max_connections_per_user,
            limits.command_timeout_seconds,
            limits.max_output_bytes,
        ) == (3, 300, 1_048_576)
        assert cfg.credentials.password_auth_allowed is False
        assert cfg.host_key_verification.mode == "tofu"

    def test_always_blocked_hosts_constant(self):
        """ALWAYS_BLOCKED_HOSTS covers loopback names/addresses and the link-local range."""
        from src.core.ssh.ssh_config import ALWAYS_BLOCKED_HOSTS

        for entry in ("127.0.0.1", "localhost", "::1", "169.254.0.0/16"):
            assert entry in ALWAYS_BLOCKED_HOSTS

    def test_default_config_blocks_dangerous_hosts(self):
        """Every ALWAYS_BLOCKED_HOSTS entry also appears in the default always_blocked list."""
        from src.core.ssh.ssh_config import (
            ALWAYS_BLOCKED_HOSTS,
            get_default_ssh_security_config,
        )

        blocked = get_default_ssh_security_config().hosts.always_blocked
        missing = [host for host in ALWAYS_BLOCKED_HOSTS if host not in blocked]
        assert not missing