# ssh_connection_pool.py
"""
SSH connection pool with persistent connections, keepalive, and reconnection.

Manages SSH connections per-user, per-session, per-profile with:
- SSH-level keepalive (encrypted, application-layer)
- Idle timeout with asyncio watchdog timer
- Transparent reconnection on next use after disconnect
- Background health checker for zombie detection
- All connections closed on session end

Keepalive itself is not configured in this module; presumably the
``connect_fn`` passed to ``SSHConnectionPool.get_connection`` sets it up.
"""
import asyncio
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Optional

logger = logging.getLogger(__name__)


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


class SSHConnectionLimitError(Exception):
    """Raised when connection limit per session is exceeded."""


@dataclass
class SSHCommandResult:
    """Result of an SSH command execution."""

    exit_code: int
    stdout: str
    stderr: str
    command: str
    timed_out: bool = False
    connection_lost: bool = False


@dataclass
class SSHConnectionEntry:
    """Tracks a single SSH connection with lifecycle metadata."""

    conn: Any  # asyncssh.SSHClientConnection (typed as Any to avoid import)
    profile_name: str
    host: str
    port: int
    username: str
    user_id: str
    session_id: str
    connected_at: datetime = field(default_factory=_utcnow)
    # Refreshed by get_connection() and record_activity(); the idle
    # watchdog measures idleness from this timestamp.
    last_activity: datetime = field(default_factory=_utcnow)
    # Incremented by record_activity().
    command_count: int = 0
    # Reported by get_connection_info(); not modified by the pool itself.
    relay_mode: bool = False
    privilege_level: int = 0
    # Idle-timeout watchdog task owned by the pool; None once cleaned up.
    _watchdog_task: Optional[asyncio.Task] = field(
        default=None, repr=False
    )


class SSHConnectionPool:
    """Manages persistent SSH connections across agent turns.

    Key properties:
    - Connections are per-user, per-session, per-profile (no sharing)
    - SSH-level keepalive via ServerAliveInterval (30s) -- expected to be
      configured by ``connect_fn``; this class does not set it
    - Idle timeout with watchdog timer (configurable, default 15min)
    - Transparent reconnection on next use after disconnect
    - All connections closed on session end
    - Background health checker for zombie detection

    Bookkeeping is guarded by a single ``asyncio.Lock``, which is not
    thread-safe: use the pool from one event loop only.
    """

    def __init__(
        self,
        idle_timeout_seconds: float = 900,  # 15 minutes
        max_connections_per_session: int = 5,
        health_check_interval_seconds: float = 60,
    ) -> None:
        self._connections: dict[str, SSHConnectionEntry] = {}
        self._idle_timeout = idle_timeout_seconds
        self._max_per_session = max_connections_per_session
        self._health_check_interval = health_check_interval_seconds
        self._health_task: Optional[asyncio.Task] = None
        self._lock = asyncio.Lock()

    def _connection_key(
        self, session_id: str, profile_name: str
    ) -> str:
        """Generate unique connection key."""
        return f"{session_id}:{profile_name}"

    def _session_items(
        self, session_id: str,
    ) -> list[tuple[str, SSHConnectionEntry]]:
        """Return the ``(key, entry)`` pairs owned by *session_id*.

        Matches on ``entry.session_id`` rather than on a key prefix, so a
        session whose ID is a prefix of another one (``"a"`` vs ``"a:b"``)
        never picks up the other session's connections.
        """
        return [
            (key, entry)
            for key, entry in self._connections.items()
            if entry.session_id == session_id
        ]

    async def get_connection(
        self,
        session_id: str,
        profile_name: str,
        user_id: str,
        connect_fn: Callable[[], Awaitable[Any]],
    ) -> Any:
        """Get or create a connection with transparent reconnection.

        Args:
            session_id: The Ag3ntum session ID.
            profile_name: SSH profile name.
            user_id: The Ag3ntum user ID.
            connect_fn: Async callable that returns an authenticated
                        asyncssh.SSHClientConnection. Called on new
                        connections or reconnections.

        Returns:
            The pooled connection if it is still open, otherwise the fresh
            connection returned by ``connect_fn``.

        Raises:
            SSHConnectionLimitError: If session has too many connections.
            Any exception raised by ``connect_fn`` propagates unchanged;
            nothing is registered in the pool in that case.
        """
        key = self._connection_key(session_id, profile_name)

        async with self._lock:
            entry = self._connections.get(key)

            if entry is not None:
                if not self._is_connection_closed(entry.conn):
                    # Connection alive -- refresh activity and return it.
                    entry.last_activity = _utcnow()
                    self._reset_watchdog(key, entry)
                    return entry.conn

                logger.info(
                    "SSH connection dead for %s, will reconnect",
                    profile_name,
                )
                await self._cleanup_entry(
                    key, entry, reason="connection_lost"
                )

            if len(self._session_items(session_id)) >= self._max_per_session:
                raise SSHConnectionLimitError(
                    f"Maximum {self._max_per_session} concurrent SSH "
                    f"connections per session"
                )

            # NOTE: connect_fn is awaited while the pool-wide lock is held,
            # so connection setup is serialized across all sessions.
            logger.info("Establishing SSH connection for %s", profile_name)
            conn = await connect_fn()

            entry = SSHConnectionEntry(
                conn=conn,
                profile_name=profile_name,
                # Private attributes, presumably set by asyncssh -- TODO
                # confirm; the defaults are used when they are absent.
                host=getattr(conn, '_host', 'unknown'),
                port=getattr(conn, '_port', 22),
                username=getattr(conn, '_username', 'unknown'),
                user_id=user_id,
                session_id=session_id,
            )

            self._connections[key] = entry
            self._start_watchdog(key, entry)
            self._ensure_health_checker()

            logger.info(
                "SSH connection established: %s (session=%s)",
                profile_name, session_id,
            )
            return conn

    async def release_connection(
        self, session_id: str, profile_name: str,
    ) -> None:
        """Explicitly close a connection (no-op if it is not pooled)."""
        key = self._connection_key(session_id, profile_name)
        async with self._lock:
            entry = self._connections.get(key)
            if entry:
                await self._cleanup_entry(
                    key, entry, reason="explicit_close"
                )

    async def close_session_connections(
        self, session_id: str,
    ) -> int:
        """Close ALL connections for a session.

        Called when an agent session ends.

        Returns:
            Number of connections closed.
        """
        async with self._lock:
            doomed = self._session_items(session_id)
            for key, entry in doomed:
                await self._cleanup_entry(
                    key, entry, reason="session_end"
                )
        closed = len(doomed)
        if closed:
            logger.info(
                "Closed %d SSH connection(s) for session %s",
                closed, session_id,
            )
        return closed

    def record_activity(
        self, session_id: str, profile_name: str,
    ) -> None:
        """Record activity on a connection.

        Refreshes ``last_activity``, which the idle watchdog re-reads before
        closing, so the idle deadline moves back, and bumps
        ``command_count``. Does nothing for an unknown connection.
        """
        key = self._connection_key(session_id, profile_name)
        entry = self._connections.get(key)
        if entry:
            entry.last_activity = _utcnow()
            entry.command_count += 1

    def get_connection_info(
        self, session_id: str,
    ) -> list[dict]:
        """Get info about all connections for a session."""
        now = _utcnow()
        return [
            {
                "profile": entry.profile_name,
                "host": entry.host,
                "port": entry.port,
                "username": entry.username,
                "connected_at": entry.connected_at.isoformat(),
                "last_activity": entry.last_activity.isoformat(),
                "command_count": entry.command_count,
                "privilege_level": entry.privilege_level,
                "relay_mode": entry.relay_mode,
                "alive": not self._is_connection_closed(entry.conn),
                "idle_seconds": int(
                    (now - entry.last_activity).total_seconds()
                ),
            }
            for _key, entry in self._session_items(session_id)
        ]

    @property
    def total_connections(self) -> int:
        """Total number of tracked connections."""
        return len(self._connections)

    # --- Internal helpers ---

    def _is_connection_closed(self, conn: Any) -> bool:
        """Return True if *conn* is closed, or if that cannot be determined.

        Prefers ``conn.is_closed()``; otherwise treats a ``None``
        ``_transport`` as closed. Objects exposing neither, or whose check
        raises, are reported as closed.
        """
        try:
            if hasattr(conn, 'is_closed'):
                return conn.is_closed()
            if hasattr(conn, '_transport'):
                return conn._transport is None
            return True  # Assume closed if we can't check
        except Exception:
            return True

    def _start_watchdog(
        self, key: str, entry: SSHConnectionEntry,
    ) -> None:
        """(Re)start the idle-timeout watchdog task for *entry*.

        The watchdog sleeps until ``last_activity + idle_timeout`` and then
        re-checks. Activity recorded in the meantime therefore pushes the
        deadline back, instead of the connection lingering for up to twice
        the timeout.
        """
        async def _watchdog() -> None:
            while True:
                idle = (_utcnow() - entry.last_activity).total_seconds()
                remaining = self._idle_timeout - idle
                if remaining > 0:
                    await asyncio.sleep(remaining)
                    continue
                async with self._lock:
                    idle = (_utcnow() - entry.last_activity).total_seconds()
                    if idle < self._idle_timeout:
                        # Activity arrived while we waited for the lock.
                        continue
                    # Only close the entry this watchdog was started for;
                    # the key may since have been reused by a new entry.
                    if self._connections.get(key) is entry:
                        await self._cleanup_entry(
                            key, entry,
                            reason=f"idle_timeout_{int(idle)}s",
                        )
                return

        if entry._watchdog_task is not None:
            entry._watchdog_task.cancel()
        entry._watchdog_task = asyncio.create_task(_watchdog())

    def _reset_watchdog(
        self, key: str, entry: SSHConnectionEntry,
    ) -> None:
        """Cancel and restart the watchdog timer on activity."""
        # _start_watchdog cancels any previous watchdog task itself.
        self._start_watchdog(key, entry)

    async def _cleanup_entry(
        self, key: str, entry: SSHConnectionEntry, reason: str,
    ) -> None:
        """Close connection and clean up resources."""
        task, entry._watchdog_task = entry._watchdog_task, None
        # Never cancel the task we are running in: the idle watchdog calls
        # this method itself, and a self-cancel would raise CancelledError
        # at the wait_closed() await below, aborting the cleanup.
        if task is not None and task is not asyncio.current_task():
            task.cancel()

        # Drop the bookkeeping before awaiting anything, so the pool stays
        # consistent even if the calling task is cancelled mid-close.
        if self._connections.get(key) is entry:
            del self._connections[key]

        try:
            if not self._is_connection_closed(entry.conn):
                entry.conn.close()
                await asyncio.wait_for(
                    entry.conn.wait_closed(), timeout=5
                )
        except asyncio.TimeoutError:
            logger.warning(
                "SSH connection close timed out: %s", entry.profile_name
            )
        except Exception as e:
            # Best effort: a failing close must not break pool cleanup.
            logger.debug(
                "Error closing SSH connection %s: %s", entry.profile_name, e
            )

        alive_seconds = int(
            (_utcnow() - entry.connected_at).total_seconds()
        )
        logger.info(
            "SSH connection closed: %s (%s, %d commands, alive %ds)",
            entry.profile_name, reason, entry.command_count, alive_seconds,
        )

    def _ensure_health_checker(self) -> None:
        """Start background health check if not running."""
        if self._health_task is None or self._health_task.done():
            self._health_task = asyncio.create_task(
                self._health_check_loop()
            )

    async def _health_check_loop(self) -> None:
        """Periodically close zombie connections; stop once the pool is empty."""
        while self._connections:
            await asyncio.sleep(self._health_check_interval)
            async with self._lock:
                for key, entry in list(self._connections.items()):
                    if self._is_connection_closed(entry.conn):
                        await self._cleanup_entry(
                            key, entry, reason="zombie_detected"
                        )
        # Only clear the handle if it still refers to this task.
        if self._health_task is asyncio.current_task():
            self._health_task = None

    async def shutdown(self) -> None:
        """Close all connections and stop background tasks."""
        async with self._lock:
            for key, entry in list(self._connections.items()):
                await self._cleanup_entry(
                    key, entry, reason="pool_shutdown"
                )
        if self._health_task is not None:
            self._health_task.cancel()
            self._health_task = None