/ src / core / ssh / ssh_connection_pool.py
ssh_connection_pool.py
  1  """
  2  SSH connection pool with persistent connections, keepalive, and reconnection.
  3  
  4  Manages SSH connections per-user, per-session, per-profile with:
  5  - SSH-level keepalive (encrypted, application-layer)
  6  - Idle timeout with asyncio watchdog timer
  7  - Transparent reconnection on next use after disconnect
  8  - Background health checker for zombie detection
  9  - All connections closed on session end
 10  """
 11  import asyncio
 12  import logging
 13  from dataclasses import dataclass, field
 14  from datetime import datetime, timezone
 15  from typing import Any, Optional
 16  
 17  logger = logging.getLogger(__name__)
 18  
 19  
 20  class SSHConnectionLimitError(Exception):
 21      """Raised when connection limit per session is exceeded."""
 22      pass
 23  
 24  
 25  @dataclass
 26  class SSHCommandResult:
 27      """Result of an SSH command execution."""
 28      exit_code: int
 29      stdout: str
 30      stderr: str
 31      command: str
 32      timed_out: bool = False
 33      connection_lost: bool = False
 34  
 35  
 36  @dataclass
 37  class SSHConnectionEntry:
 38      """Tracks a single SSH connection with lifecycle metadata."""
 39      conn: Any  # asyncssh.SSHClientConnection (typed as Any to avoid import)
 40      profile_name: str
 41      host: str
 42      port: int
 43      username: str
 44      user_id: str
 45      session_id: str
 46      connected_at: datetime = field(
 47          default_factory=lambda: datetime.now(timezone.utc)
 48      )
 49      last_activity: datetime = field(
 50          default_factory=lambda: datetime.now(timezone.utc)
 51      )
 52      command_count: int = 0
 53      relay_mode: bool = False
 54      privilege_level: int = 0
 55      _watchdog_task: Optional[asyncio.Task] = field(
 56          default=None, repr=False
 57      )
 58  
 59  
 60  class SSHConnectionPool:
 61      """Manages persistent SSH connections across agent turns.
 62  
 63      Key properties:
 64      - Connections are per-user, per-session, per-profile (no sharing)
 65      - SSH-level keepalive via ServerAliveInterval (30s)
 66      - Idle timeout with watchdog timer (configurable, default 15min)
 67      - Transparent reconnection on next use after disconnect
 68      - All connections closed on session end
 69      - Background health checker for zombie detection
 70      """
 71  
 72      def __init__(
 73          self,
 74          idle_timeout_seconds: int = 900,  # 15 minutes
 75          max_connections_per_session: int = 5,
 76          health_check_interval_seconds: int = 60,
 77      ) -> None:
 78          self._connections: dict[str, SSHConnectionEntry] = {}
 79          self._idle_timeout = idle_timeout_seconds
 80          self._max_per_session = max_connections_per_session
 81          self._health_check_interval = health_check_interval_seconds
 82          self._health_task: Optional[asyncio.Task] = None
 83          self._lock = asyncio.Lock()
 84  
 85      def _connection_key(
 86          self, session_id: str, profile_name: str
 87      ) -> str:
 88          """Generate unique connection key."""
 89          return f"{session_id}:{profile_name}"
 90  
 91      async def get_connection(
 92          self,
 93          session_id: str,
 94          profile_name: str,
 95          user_id: str,
 96          connect_fn: Any,  # async callable returning SSHClientConnection
 97      ) -> Any:
 98          """Get or create a connection with transparent reconnection.
 99  
100          Args:
101              session_id: The Ag3ntum session ID.
102              profile_name: SSH profile name.
103              user_id: The Ag3ntum user ID.
104              connect_fn: Async callable that returns an authenticated
105                          asyncssh.SSHClientConnection. Called on new
106                          connections or reconnections.
107  
108          Returns:
109              Live asyncssh.SSHClientConnection.
110  
111          Raises:
112              SSHConnectionLimitError: If session has too many connections.
113          """
114          key = self._connection_key(session_id, profile_name)
115  
116          async with self._lock:
117              entry = self._connections.get(key)
118  
119              # Check if existing connection is alive
120              if entry is not None:
121                  if not self._is_connection_closed(entry.conn):
122                      # Connection alive — reset watchdog and return
123                      entry.last_activity = datetime.now(timezone.utc)
124                      self._reset_watchdog(key, entry)
125                      return entry.conn
126  
127                  # Connection dead — clean up
128                  logger.info(
129                      f"SSH connection dead for {profile_name}, "
130                      f"will reconnect"
131                  )
132                  await self._cleanup_entry(
133                      key, entry, reason="connection_lost"
134                  )
135  
136              # Check per-session limit
137              session_count = sum(
138                  1 for k in self._connections
139                  if k.startswith(f"{session_id}:")
140              )
141              if session_count >= self._max_per_session:
142                  raise SSHConnectionLimitError(
143                      f"Maximum {self._max_per_session} concurrent SSH "
144                      f"connections per session"
145                  )
146  
147              # Connect
148              logger.info(f"Establishing SSH connection for {profile_name}")
149              conn = await connect_fn()
150  
151              entry = SSHConnectionEntry(
152                  conn=conn,
153                  profile_name=profile_name,
154                  host=getattr(conn, '_host', 'unknown'),
155                  port=getattr(conn, '_port', 22),
156                  username=getattr(conn, '_username', 'unknown'),
157                  user_id=user_id,
158                  session_id=session_id,
159              )
160  
161              self._connections[key] = entry
162              self._start_watchdog(key, entry)
163              self._ensure_health_checker()
164  
165              logger.info(
166                  f"SSH connection established: {profile_name} "
167                  f"(session={session_id})"
168              )
169              return conn
170  
171      async def release_connection(
172          self, session_id: str, profile_name: str,
173      ) -> None:
174          """Explicitly close a connection."""
175          key = self._connection_key(session_id, profile_name)
176          async with self._lock:
177              entry = self._connections.get(key)
178              if entry:
179                  await self._cleanup_entry(
180                      key, entry, reason="explicit_close"
181                  )
182  
183      async def close_session_connections(
184          self, session_id: str,
185      ) -> int:
186          """Close ALL connections for a session.
187  
188          Called when an agent session ends.
189  
190          Returns:
191              Number of connections closed.
192          """
193          closed = 0
194          async with self._lock:
195              keys_to_close = [
196                  k for k in self._connections
197                  if k.startswith(f"{session_id}:")
198              ]
199              for key in keys_to_close:
200                  entry = self._connections[key]
201                  await self._cleanup_entry(
202                      key, entry, reason="session_end"
203                  )
204                  closed += 1
205          if closed:
206              logger.info(
207                  f"Closed {closed} SSH connection(s) for session "
208                  f"{session_id}"
209              )
210          return closed
211  
212      def record_activity(
213          self, session_id: str, profile_name: str,
214      ) -> None:
215          """Record activity on a connection (resets watchdog)."""
216          key = self._connection_key(session_id, profile_name)
217          entry = self._connections.get(key)
218          if entry:
219              entry.last_activity = datetime.now(timezone.utc)
220              entry.command_count += 1
221  
222      def get_connection_info(
223          self, session_id: str,
224      ) -> list[dict]:
225          """Get info about all connections for a session."""
226          result = []
227          for key, entry in self._connections.items():
228              if key.startswith(f"{session_id}:"):
229                  now = datetime.now(timezone.utc)
230                  result.append({
231                      "profile": entry.profile_name,
232                      "host": entry.host,
233                      "port": entry.port,
234                      "username": entry.username,
235                      "connected_at": entry.connected_at.isoformat(),
236                      "last_activity": entry.last_activity.isoformat(),
237                      "command_count": entry.command_count,
238                      "privilege_level": entry.privilege_level,
239                      "relay_mode": entry.relay_mode,
240                      "alive": not self._is_connection_closed(entry.conn),
241                      "idle_seconds": int(
242                          (now - entry.last_activity).total_seconds()
243                      ),
244                  })
245          return result
246  
247      @property
248      def total_connections(self) -> int:
249          """Total number of tracked connections."""
250          return len(self._connections)
251  
252      # --- Internal helpers ---
253  
254      def _is_connection_closed(self, conn: Any) -> bool:
255          """Check if an asyncssh connection is closed."""
256          try:
257              # asyncssh connections have a _transport attribute
258              # that is None when closed, or use is_closed() if available
259              if hasattr(conn, 'is_closed'):
260                  return conn.is_closed()
261              if hasattr(conn, '_transport'):
262                  return conn._transport is None
263              return True  # Assume closed if we can't check
264          except Exception:
265              return True
266  
267      def _start_watchdog(
268          self, key: str, entry: SSHConnectionEntry,
269      ) -> None:
270          """Start idle timeout watchdog for a connection."""
271          async def _watchdog() -> None:
272              while True:
273                  await asyncio.sleep(self._idle_timeout)
274                  elapsed = (
275                      datetime.now(timezone.utc) - entry.last_activity
276                  ).total_seconds()
277                  if elapsed >= self._idle_timeout:
278                      async with self._lock:
279                          if key in self._connections:
280                              await self._cleanup_entry(
281                                  key, entry,
282                                  reason=f"idle_timeout_{int(elapsed)}s",
283                              )
284                      return
285  
286          if entry._watchdog_task is not None:
287              entry._watchdog_task.cancel()
288          entry._watchdog_task = asyncio.create_task(_watchdog())
289  
290      def _reset_watchdog(
291          self, key: str, entry: SSHConnectionEntry,
292      ) -> None:
293          """Cancel and restart the watchdog timer on activity."""
294          if entry._watchdog_task is not None:
295              entry._watchdog_task.cancel()
296          self._start_watchdog(key, entry)
297  
298      async def _cleanup_entry(
299          self, key: str, entry: SSHConnectionEntry, reason: str,
300      ) -> None:
301          """Close connection and clean up resources."""
302          if entry._watchdog_task is not None:
303              entry._watchdog_task.cancel()
304              entry._watchdog_task = None
305          try:
306              if not self._is_connection_closed(entry.conn):
307                  entry.conn.close()
308                  await asyncio.wait_for(
309                      entry.conn.wait_closed(), timeout=5
310                  )
311          except asyncio.TimeoutError:
312              logger.warning(
313                  f"SSH connection close timed out: {entry.profile_name}"
314              )
315          except Exception as e:
316              logger.debug(
317                  f"Error closing SSH connection {entry.profile_name}: {e}"
318              )
319          self._connections.pop(key, None)
320          alive_seconds = int(
321              (datetime.now(timezone.utc) - entry.connected_at)
322              .total_seconds()
323          )
324          logger.info(
325              f"SSH connection closed: {entry.profile_name} "
326              f"({reason}, {entry.command_count} commands, "
327              f"alive {alive_seconds}s)"
328          )
329  
330      def _ensure_health_checker(self) -> None:
331          """Start background health check if not running."""
332          if self._health_task is None or self._health_task.done():
333              self._health_task = asyncio.create_task(
334                  self._health_check_loop()
335              )
336  
337      async def _health_check_loop(self) -> None:
338          """Periodically check connection health and close zombies."""
339          while self._connections:
340              await asyncio.sleep(self._health_check_interval)
341              async with self._lock:
342                  for key, entry in list(self._connections.items()):
343                      if self._is_connection_closed(entry.conn):
344                          await self._cleanup_entry(
345                              key, entry, reason="zombie_detected"
346                          )
347          self._health_task = None
348  
349      async def shutdown(self) -> None:
350          """Close all connections and stop background tasks."""
351          async with self._lock:
352              for key, entry in list(self._connections.items()):
353                  await self._cleanup_entry(
354                      key, entry, reason="pool_shutdown"
355                  )
356          if self._health_task is not None:
357              self._health_task.cancel()
358              self._health_task = None