# ssh_credential_vault.py
"""
SSH credential vault — manages SSH authentication via unified vault.

The agent NEVER sees raw credentials. This module provides authenticated
SSH connections using credentials stored in the VaultService. Credentials
are fetched once, kept only in memory inside the returned connection
factory, and are never written to the log.

NOTE(review): this module does not itself zero key material; the fetched
credential stays referenced by the returned ``connect_fn`` closure for as
long as that closure is alive. Any zeroing has to happen in VaultService
or in the connection pool — confirm.
"""
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Awaitable, Callable

if TYPE_CHECKING:
    import asyncssh
    from sqlalchemy.ext.asyncio import AsyncSession

    from ...services.vault_service import VaultService
    from .ssh_config import SSHProfile, SSHSecurityConfig
    from .ssh_host_key_resolver import SSHHostKeyResolver

logger = logging.getLogger(__name__)

# Keepalive settings applied to every connection opened by this module.
_KEEPALIVE_INTERVAL_SECONDS = 30
_KEEPALIVE_COUNT_MAX = 3


async def _open_connection(
    host: str,
    port: int,
    username: str,
    known_hosts: object,
    **auth_options: object,
) -> asyncssh.SSHClientConnection:
    """Open an SSH connection with the option set shared by all auth methods.

    Args:
        host: Remote host name or address.
        port: Remote TCP port.
        username: Login name on the remote host.
        known_hosts: Host key verification material passed to asyncssh.
        **auth_options: Method-specific asyncssh options
            (``client_keys=...`` or ``password=...``).

    Returns:
        The established asyncssh client connection.
    """
    # Imported lazily: asyncssh is only imported for type checking at module
    # level, so the module itself can be loaded without it.
    import asyncssh as _asyncssh

    return await _asyncssh.connect(
        host=host,
        port=port,
        username=username,
        known_hosts=known_hosts,
        keepalive_interval=_KEEPALIVE_INTERVAL_SECONDS,
        keepalive_count_max=_KEEPALIVE_COUNT_MAX,
        # The options below are pinned to empty/None so that only the
        # vault-provided credential is offered (no ssh-agent, no OpenSSH
        # config files, no X.509 trust roots) — see asyncssh option docs.
        agent_path=None,
        config=[],
        x509_trusted_certs=None,
        x509_trusted_cert_paths=[],
        **auth_options,
    )


class SSHCredentialVault:
    """SSH-specific credential management. Provides authenticated connections.

    The agent NEVER accesses this class directly — it's called internally
    by the SSH MCP tool implementation.

    Wraps VaultService for SSH operations. Credentials are fetched once
    under the caller's db session and captured in closures so that vault
    access does not happen at connection time (after the session may have
    been released).
    """

    def __init__(
        self,
        vault_service: VaultService,
        security_config: SSHSecurityConfig,
        host_key_resolver: SSHHostKeyResolver | None = None,
    ) -> None:
        """Store the collaborators; performs no I/O.

        Args:
            vault_service: Source of SSH keys and secret values.
            security_config: Security policy; ``credentials`` is read for
                ``password_auth_allowed`` and ``known_hosts_path``.
            host_key_resolver: Optional resolver for host key verification;
                when given it takes priority over ``known_hosts_path``.
        """
        self._vault = vault_service
        self._config = security_config
        self._host_key_resolver = host_key_resolver

    async def get_connect_fn(
        self,
        db: AsyncSession,
        user_id: str,
        session_id: str,
        profile: SSHProfile,
    ) -> Callable[[], Awaitable[asyncssh.SSHClientConnection]]:
        """Create a connection factory for the connection pool.

        Returns an async callable that, when called, establishes an
        authenticated SSH connection. This is passed to
        SSHConnectionPool.get_connection() as the connect_fn argument.

        Credentials are pre-loaded here so vault access happens under the
        caller's db session, not deferred to the connect_fn closure.

        Supports three auth methods based on profile.auth_method:
        - "key": Load SSH private key from vault via get_ssh_key()
        - "certificate": Load key + certificate from vault
        - "password": Load password from vault (if allowed by config)

        The profile is validated *before* host key resolution or any vault
        access, so an unusable profile fails without side effects.

        Raises:
            ValueError: If the auth method is unsupported, credentials are
                        missing from the profile, password auth is
                        disabled by security config, or the host key
                        resolver yields None.
        """
        auth_method = profile.auth_method

        # Fail fast: reject unusable profiles before touching resolver/vault.
        self._check_preconditions(profile, auth_method)

        builders = {
            "key": self._build_key_connect_fn,
            "certificate": self._build_certificate_connect_fn,
            "password": self._build_password_connect_fn,
        }

        # Resolve host key verification under the active db session
        known_hosts = await self._resolve_host_verification(
            db, user_id, session_id, profile
        )

        return await builders[auth_method](
            db, user_id, session_id, profile, known_hosts
        )

    def validate_profile_credentials(self, profile: SSHProfile) -> list[str]:
        """Validate a profile's credential configuration without accessing vault.

        Checks that the auth method is one this class supports, that the
        required credential references are present, and that the requested
        auth method is permitted by the security config.

        Returns:
            List of validation error strings. Empty list means valid.
        """
        errors: list[str] = []

        if profile.auth_method == "key":
            if not profile.key_ref:
                errors.append(
                    f"Profile '{profile.name}': key auth requires key_ref"
                )

        elif profile.auth_method == "certificate":
            if not profile.key_ref:
                errors.append(
                    f"Profile '{profile.name}': cert auth requires key_ref"
                )
            if not profile.certificate_ref:
                errors.append(
                    f"Profile '{profile.name}': cert auth requires certificate_ref"
                )

        elif profile.auth_method == "password":
            if not self._config.credentials.password_auth_allowed:
                errors.append(
                    f"Profile '{profile.name}': password auth disabled in security config"
                )
            if profile.password_secret_id is None:
                errors.append(
                    f"Profile '{profile.name}': password auth requires password_secret_id"
                )

        else:
            # get_connect_fn() rejects such profiles, so report them here too
            # instead of declaring them valid.
            errors.append(
                f"Profile '{profile.name}': unsupported auth method "
                f"'{profile.auth_method}'"
            )

        return errors

    # --- Internal helpers ---

    def _check_preconditions(self, profile: SSHProfile, auth_method: str) -> None:
        """Raise ValueError unless ``profile`` can be used with ``auth_method``.

        Pure check: reads only ``profile`` and the security config, never the
        vault or the database.

        Raises:
            ValueError: On an unsupported method, a missing credential
                reference, or password auth disabled by config.
        """
        if auth_method == "key":
            if not profile.key_ref:
                raise ValueError(
                    f"Profile '{profile.name}' uses key auth but no key_ref configured"
                )
        elif auth_method == "certificate":
            if not profile.key_ref:
                raise ValueError(
                    f"Profile '{profile.name}' uses cert auth but no key_ref configured"
                )
            if not profile.certificate_ref:
                raise ValueError(
                    f"Profile '{profile.name}' uses cert auth but no certificate_ref configured"
                )
        elif auth_method == "password":
            if not self._config.credentials.password_auth_allowed:
                raise ValueError(
                    "Password authentication is disabled in SSH security config"
                )
            if profile.password_secret_id is None:
                raise ValueError(
                    f"Profile '{profile.name}' uses password auth but no password_secret_id configured"
                )
        else:
            raise ValueError(
                f"Unsupported auth method '{auth_method}' for profile '{profile.name}'. "
                f"Supported: key, certificate, password"
            )

    async def _resolve_host_verification(
        self,
        db: AsyncSession,
        user_id: str,
        session_id: str,
        profile: SSHProfile,
    ) -> object:
        """Resolve the known_hosts parameter for asyncssh.

        Priority:
        1. Host key resolver (vault-pinned keys) — if configured
        2. Explicit known_hosts_path from config — legacy fallback
        3. Empty tuple — asyncssh system defaults

        Never returns None — that would disable host key verification.

        Raises:
            ValueError: If the configured resolver returns None.
        """
        if self._host_key_resolver is not None:
            resolved = await self._host_key_resolver.resolve(
                profile, user_id, session_id, db
            )
            if resolved is None:
                # Fail closed: known_hosts=None would turn verification off.
                raise ValueError(
                    f"Host key resolver returned no verification data for "
                    f"profile '{profile.name}'; refusing to connect without "
                    f"host key verification"
                )
            return resolved

        configured = self._config.credentials.known_hosts_path
        if configured:
            return configured

        # asyncssh accepts empty tuple = use system defaults
        return ()

    # --- Connection builders ---

    async def _build_key_connect_fn(
        self,
        db: AsyncSession,
        user_id: str,
        session_id: str,
        profile: SSHProfile,
        known_hosts: object,
    ) -> Callable[[], Awaitable[asyncssh.SSHClientConnection]]:
        """Build a connect_fn for key-based authentication.

        Fetches the private key via ``VaultService.get_ssh_key`` now and
        captures it in the returned closure.

        Raises:
            ValueError: If ``profile.key_ref`` is not set.
        """
        # Re-checked here so the builder is safe when called on its own.
        self._check_preconditions(profile, "key")

        ssh_key = await self._vault.get_ssh_key(
            db, user_id, profile.key_ref, session_id
        )
        host = profile.host
        port = profile.port
        username = profile.username

        logger.debug(
            "Built key connect_fn for profile '%s' (user=%s, host=%s:%d)",
            profile.name, user_id, host, port,
        )

        async def connect() -> asyncssh.SSHClientConnection:
            return await _open_connection(
                host, port, username, known_hosts, client_keys=[ssh_key]
            )

        return connect

    async def _build_certificate_connect_fn(
        self,
        db: AsyncSession,
        user_id: str,
        session_id: str,
        profile: SSHProfile,
        known_hosts: object,
    ) -> Callable[[], Awaitable[asyncssh.SSHClientConnection]]:
        """Build a connect_fn for certificate-based authentication.

        Fetches the private key and the certificate text from the vault now;
        the certificate is parsed with ``asyncssh.import_certificate`` when
        the returned closure is called.

        Raises:
            ValueError: If ``profile.key_ref`` or ``profile.certificate_ref``
                is not set.
        """
        # Re-checked here so the builder is safe when called on its own.
        self._check_preconditions(profile, "certificate")

        ssh_key = await self._vault.get_ssh_key(
            db, user_id, profile.key_ref, session_id
        )
        # certificate_ref stores the vault secret ID (int) as the certificate PEM
        cert_pem = await self._vault.get_secret_value(
            db, user_id, profile.certificate_ref, session_id
        )
        host = profile.host
        port = profile.port
        username = profile.username

        logger.debug(
            "Built certificate connect_fn for profile '%s' (user=%s, host=%s:%d)",
            profile.name, user_id, host, port,
        )

        async def connect() -> asyncssh.SSHClientConnection:
            import asyncssh as _asyncssh

            cert = _asyncssh.import_certificate(cert_pem)
            return await _open_connection(
                host, port, username, known_hosts, client_keys=[(ssh_key, cert)]
            )

        return connect

    async def _build_password_connect_fn(
        self,
        db: AsyncSession,
        user_id: str,
        session_id: str,
        profile: SSHProfile,
        known_hosts: object,
    ) -> Callable[[], Awaitable[asyncssh.SSHClientConnection]]:
        """Build a connect_fn for password-based authentication.

        Fetches the password via ``VaultService.get_secret_value`` now and
        captures it in the returned closure.

        Raises:
            ValueError: If password auth is disabled in the security config
                or ``profile.password_secret_id`` is None.
        """
        # Re-checked here so the builder is safe when called on its own.
        self._check_preconditions(profile, "password")

        password = await self._vault.get_secret_value(
            db, user_id, profile.password_secret_id, session_id
        )
        host = profile.host
        port = profile.port
        username = profile.username

        logger.debug(
            "Built password connect_fn for profile '%s' (user=%s, host=%s:%d)",
            profile.name, user_id, host, port,
        )

        async def connect() -> asyncssh.SSHClientConnection:
            return await _open_connection(
                host, port, username, known_hosts, password=password
            )

        return connect