/ src / core / ssh / ssh_credential_vault.py
ssh_credential_vault.py
"""
SSH credential vault — manages SSH authentication via unified vault.

The agent NEVER sees raw credentials. This module provides authenticated
SSH connections using credentials stored in the VaultService. Credentials
are decrypted into memory when a connection factory is built and are then
held in that factory's closure for as long as the factory is referenced;
this module does not zero them afterwards.
"""
  9  from __future__ import annotations
 10  
 11  import logging
 12  from typing import TYPE_CHECKING, Awaitable, Callable
 13  
 14  if TYPE_CHECKING:
 15      import asyncssh
 16      from sqlalchemy.ext.asyncio import AsyncSession
 17  
 18      from ...services.vault_service import VaultService
 19      from .ssh_config import SSHProfile, SSHSecurityConfig
 20      from .ssh_host_key_resolver import SSHHostKeyResolver
 21  
 22  logger = logging.getLogger(__name__)
 23  
 24  
 25  class SSHCredentialVault:
 26      """SSH-specific credential management. Provides authenticated connections.
 27  
 28      The agent NEVER accesses this class directly — it's called internally
 29      by the SSH MCP tool implementation.
 30  
 31      Wraps VaultService for SSH operations. Credentials are fetched once
 32      under the caller's db session and captured in closures so that vault
 33      access does not happen at connection time (after the session may have
 34      been released).
 35      """
 36  
 37      def __init__(
 38          self,
 39          vault_service: VaultService,
 40          security_config: SSHSecurityConfig,
 41          host_key_resolver: SSHHostKeyResolver | None = None,
 42      ) -> None:
 43          self._vault = vault_service
 44          self._config = security_config
 45          self._host_key_resolver = host_key_resolver
 46  
 47      async def get_connect_fn(
 48          self,
 49          db: AsyncSession,
 50          user_id: str,
 51          session_id: str,
 52          profile: SSHProfile,
 53      ) -> Callable[[], Awaitable[asyncssh.SSHClientConnection]]:
 54          """Create a connection factory for the connection pool.
 55  
 56          Returns an async callable that, when called, establishes an
 57          authenticated SSH connection. This is passed to
 58          SSHConnectionPool.get_connection() as the connect_fn argument.
 59  
 60          Credentials are pre-loaded here so vault access happens under the
 61          caller's db session, not deferred to the connect_fn closure.
 62  
 63          Supports three auth methods based on profile.auth_method:
 64          - "key":         Load SSH private key from vault via get_ssh_key()
 65          - "certificate": Load key + certificate from vault
 66          - "password":    Load password from vault (if allowed by config)
 67  
 68          Raises:
 69              ValueError: If the auth method is unsupported, credentials are
 70                          missing from the profile, or password auth is
 71                          disabled by security config.
 72          """
 73          auth_method = profile.auth_method
 74  
 75          # Resolve host key verification under the active db session
 76          known_hosts = await self._resolve_host_verification(
 77              db, user_id, session_id, profile
 78          )
 79  
 80          if auth_method == "key":
 81              return await self._build_key_connect_fn(
 82                  db, user_id, session_id, profile, known_hosts
 83              )
 84  
 85          if auth_method == "certificate":
 86              return await self._build_certificate_connect_fn(
 87                  db, user_id, session_id, profile, known_hosts
 88              )
 89  
 90          if auth_method == "password":
 91              return await self._build_password_connect_fn(
 92                  db, user_id, session_id, profile, known_hosts
 93              )
 94  
 95          raise ValueError(
 96              f"Unsupported auth method '{auth_method}' for profile '{profile.name}'. "
 97              f"Supported: key, certificate, password"
 98          )
 99  
100      def validate_profile_credentials(self, profile: SSHProfile) -> list[str]:
101          """Validate a profile's credential configuration without accessing vault.
102  
103          Checks that required credential references are present and that
104          the requested auth method is permitted by the security config.
105  
106          Returns:
107              List of validation error strings. Empty list means valid.
108          """
109          errors: list[str] = []
110  
111          if profile.auth_method == "key":
112              if not profile.key_ref:
113                  errors.append(
114                      f"Profile '{profile.name}': key auth requires key_ref"
115                  )
116  
117          elif profile.auth_method == "certificate":
118              if not profile.key_ref:
119                  errors.append(
120                      f"Profile '{profile.name}': cert auth requires key_ref"
121                  )
122              if not profile.certificate_ref:
123                  errors.append(
124                      f"Profile '{profile.name}': cert auth requires certificate_ref"
125                  )
126  
127          elif profile.auth_method == "password":
128              if not self._config.credentials.password_auth_allowed:
129                  errors.append(
130                      f"Profile '{profile.name}': password auth disabled in security config"
131                  )
132              if profile.password_secret_id is None:
133                  errors.append(
134                      f"Profile '{profile.name}': password auth requires password_secret_id"
135                  )
136  
137          return errors
138  
139      # --- Internal helpers ---
140  
141      async def _resolve_host_verification(
142          self,
143          db: AsyncSession,
144          user_id: str,
145          session_id: str,
146          profile: SSHProfile,
147      ) -> object:
148          """Resolve the known_hosts parameter for asyncssh.
149  
150          Priority:
151          1. Host key resolver (vault-pinned keys) — if configured
152          2. Explicit known_hosts_path from config — legacy fallback
153          3. Empty tuple — asyncssh system defaults
154  
155          Never returns None — that would disable host key verification.
156          """
157          if self._host_key_resolver is not None:
158              return await self._host_key_resolver.resolve(
159                  profile, user_id, session_id, db
160              )
161  
162          configured = self._config.credentials.known_hosts_path
163          if configured:
164              return configured
165  
166          # asyncssh accepts empty tuple = use system defaults
167          return ()
168  
169      # --- Connection builders ---
170  
171      async def _build_key_connect_fn(
172          self,
173          db: AsyncSession,
174          user_id: str,
175          session_id: str,
176          profile: SSHProfile,
177          known_hosts,
178      ) -> Callable[[], Awaitable[asyncssh.SSHClientConnection]]:
179          """Build a connect_fn for key-based authentication."""
180          if not profile.key_ref:
181              raise ValueError(
182                  f"Profile '{profile.name}' uses key auth but no key_ref configured"
183              )
184  
185          ssh_key = await self._vault.get_ssh_key(
186              db, user_id, profile.key_ref, session_id
187          )
188          host = profile.host
189          port = profile.port
190          username = profile.username
191  
192          logger.debug(
193              "Built key connect_fn for profile '%s' (user=%s, host=%s:%d)",
194              profile.name, user_id, host, port,
195          )
196  
197          async def connect() -> asyncssh.SSHClientConnection:
198              import asyncssh as _asyncssh
199              return await _asyncssh.connect(
200                  host=host,
201                  port=port,
202                  username=username,
203                  client_keys=[ssh_key],
204                  known_hosts=known_hosts,
205                  keepalive_interval=30,
206                  keepalive_count_max=3,
207                  agent_path=None,
208                  config=[],
209                  x509_trusted_certs=None,
210                  x509_trusted_cert_paths=[],
211              )
212  
213          return connect
214  
215      async def _build_certificate_connect_fn(
216          self,
217          db: AsyncSession,
218          user_id: str,
219          session_id: str,
220          profile: SSHProfile,
221          known_hosts,
222      ) -> Callable[[], Awaitable[asyncssh.SSHClientConnection]]:
223          """Build a connect_fn for certificate-based authentication."""
224          if not profile.key_ref:
225              raise ValueError(
226                  f"Profile '{profile.name}' uses cert auth but no key_ref configured"
227              )
228          if not profile.certificate_ref:
229              raise ValueError(
230                  f"Profile '{profile.name}' uses cert auth but no certificate_ref configured"
231              )
232  
233          ssh_key = await self._vault.get_ssh_key(
234              db, user_id, profile.key_ref, session_id
235          )
236          # certificate_ref stores the vault secret ID (int) as the certificate PEM
237          cert_pem = await self._vault.get_secret_value(
238              db, user_id, profile.certificate_ref, session_id
239          )
240          host = profile.host
241          port = profile.port
242          username = profile.username
243  
244          logger.debug(
245              "Built certificate connect_fn for profile '%s' (user=%s, host=%s:%d)",
246              profile.name, user_id, host, port,
247          )
248  
249          async def connect() -> asyncssh.SSHClientConnection:
250              import asyncssh as _asyncssh
251              cert = _asyncssh.import_certificate(cert_pem)
252              return await _asyncssh.connect(
253                  host=host,
254                  port=port,
255                  username=username,
256                  client_keys=[(ssh_key, cert)],
257                  known_hosts=known_hosts,
258                  keepalive_interval=30,
259                  keepalive_count_max=3,
260                  agent_path=None,
261                  config=[],
262                  x509_trusted_certs=None,
263                  x509_trusted_cert_paths=[],
264              )
265  
266          return connect
267  
268      async def _build_password_connect_fn(
269          self,
270          db: AsyncSession,
271          user_id: str,
272          session_id: str,
273          profile: SSHProfile,
274          known_hosts,
275      ) -> Callable[[], Awaitable[asyncssh.SSHClientConnection]]:
276          """Build a connect_fn for password-based authentication."""
277          if not self._config.credentials.password_auth_allowed:
278              raise ValueError(
279                  "Password authentication is disabled in SSH security config"
280              )
281          if profile.password_secret_id is None:
282              raise ValueError(
283                  f"Profile '{profile.name}' uses password auth but no password_secret_id configured"
284              )
285  
286          password = await self._vault.get_secret_value(
287              db, user_id, profile.password_secret_id, session_id
288          )
289          host = profile.host
290          port = profile.port
291          username = profile.username
292  
293          logger.debug(
294              "Built password connect_fn for profile '%s' (user=%s, host=%s:%d)",
295              profile.name, user_id, host, port,
296          )
297  
298          async def connect() -> asyncssh.SSHClientConnection:
299              import asyncssh as _asyncssh
300              return await _asyncssh.connect(
301                  host=host,
302                  port=port,
303                  username=username,
304                  password=password,
305                  known_hosts=known_hosts,
306                  keepalive_interval=30,
307                  keepalive_count_max=3,
308                  agent_path=None,
309                  config=[],
310                  x509_trusted_certs=None,
311                  x509_trusted_cert_paths=[],
312              )
313  
314          return connect