/ hermes_cli / curses_ui.py
curses_ui.py
  1  """Shared curses-based UI components for Hermes CLI.
  2  
  3  Used by `hermes tools` and `hermes skills` for interactive checklists.
   4  Provides curses multi-select, radio-list and single-select menus with keyboard
   5  navigation, plus text-based numbered fallbacks for terminals without curses support.
  6  """
  7  import sys
  8  from typing import Callable, List, Optional, Set
  9  
 10  from hermes_cli.colors import Colors, color
 11  
 12  
 13  def flush_stdin() -> None:
 14      """Flush any stray bytes from the stdin input buffer.
 15  
 16      Must be called after ``curses.wrapper()`` (or any terminal-mode library
 17      like simple_term_menu) returns, **before** the next ``input()`` /
 18      ``getpass.getpass()`` call.  ``curses.endwin()`` restores the terminal
 19      but does NOT drain the OS input buffer — leftover escape-sequence bytes
 20      (from arrow keys, terminal mode-switch responses, or rapid keypresses)
 21      remain buffered and silently get consumed by the next ``input()`` call,
 22      corrupting user data (e.g. writing ``^[^[`` into .env files).
 23  
 24      On non-TTY stdin (piped, redirected) or Windows, this is a no-op.
 25      """
 26      try:
 27          if not sys.stdin.isatty():
 28              return
 29          import termios
 30          termios.tcflush(sys.stdin, termios.TCIFLUSH)
 31      except Exception:
 32          pass
 33  
 34  
 35  def curses_checklist(
 36      title: str,
 37      items: List[str],
 38      selected: Set[int],
 39      *,
 40      cancel_returns: Set[int] | None = None,
 41      status_fn: Optional[Callable[[Set[int]], str]] = None,
 42  ) -> Set[int]:
 43      """Curses multi-select checklist. Returns set of selected indices.
 44  
 45      Args:
 46          title: Header line displayed above the checklist.
 47          items: Display labels for each row.
 48          selected: Indices that start checked (pre-selected).
 49          cancel_returns: Returned on ESC/q. Defaults to the original *selected*.
 50          status_fn: Optional callback ``f(chosen_indices) -> str`` whose return
 51              value is rendered on the bottom row of the terminal.  Use this for
 52              live aggregate info (e.g. estimated token counts).
 53      """
 54      if cancel_returns is None:
 55          cancel_returns = set(selected)
 56  
 57      # Safety: curses and input() both hang or spin when stdin is not a
 58      # terminal (e.g. subprocess pipe).  Return defaults immediately.
 59      if not sys.stdin.isatty():
 60          return cancel_returns
 61  
 62      try:
 63          import curses
 64          chosen = set(selected)
 65          result_holder: list = [None]
 66  
 67          def _draw(stdscr):
 68              curses.curs_set(0)
 69              if curses.has_colors():
 70                  curses.start_color()
 71                  curses.use_default_colors()
 72                  curses.init_pair(1, curses.COLOR_GREEN, -1)
 73                  curses.init_pair(2, curses.COLOR_YELLOW, -1)
 74                  curses.init_pair(3, 8, -1)  # dim gray
 75              cursor = 0
 76              scroll_offset = 0
 77  
 78              while True:
 79                  stdscr.clear()
 80                  max_y, max_x = stdscr.getmaxyx()
 81  
 82                  # Reserve bottom row for status bar when status_fn provided
 83                  footer_rows = 1 if status_fn else 0
 84  
 85                  # Header
 86                  try:
 87                      hattr = curses.A_BOLD
 88                      if curses.has_colors():
 89                          hattr |= curses.color_pair(2)
 90                      stdscr.addnstr(0, 0, title, max_x - 1, hattr)
 91                      stdscr.addnstr(
 92                          1, 0,
 93                          "  ↑↓ navigate  SPACE toggle  ENTER confirm  ESC cancel",
 94                          max_x - 1, curses.A_DIM,
 95                      )
 96                  except curses.error:
 97                      pass
 98  
 99                  # Scrollable item list
100                  visible_rows = max_y - 3 - footer_rows
101                  if cursor < scroll_offset:
102                      scroll_offset = cursor
103                  elif cursor >= scroll_offset + visible_rows:
104                      scroll_offset = cursor - visible_rows + 1
105  
106                  for draw_i, i in enumerate(
107                      range(scroll_offset, min(len(items), scroll_offset + visible_rows))
108                  ):
109                      y = draw_i + 3
110                      if y >= max_y - 1 - footer_rows:
111                          break
112                      check = "✓" if i in chosen else " "
113                      arrow = "→" if i == cursor else " "
114                      line = f" {arrow} [{check}] {items[i]}"
115                      attr = curses.A_NORMAL
116                      if i == cursor:
117                          attr = curses.A_BOLD
118                          if curses.has_colors():
119                              attr |= curses.color_pair(1)
120                      try:
121                          stdscr.addnstr(y, 0, line, max_x - 1, attr)
122                      except curses.error:
123                          pass
124  
125                  # Status bar (bottom row, right-aligned)
126                  if status_fn:
127                      try:
128                          status_text = status_fn(chosen)
129                          if status_text:
130                              # Right-align on the bottom row
131                              sx = max(0, max_x - len(status_text) - 1)
132                              sattr = curses.A_DIM
133                              if curses.has_colors():
134                                  sattr |= curses.color_pair(3)
135                              stdscr.addnstr(max_y - 1, sx, status_text, max_x - sx - 1, sattr)
136                      except curses.error:
137                          pass
138  
139                  stdscr.refresh()
140                  key = stdscr.getch()
141  
142                  if key in (curses.KEY_UP, ord("k")):
143                      cursor = (cursor - 1) % len(items)
144                  elif key in (curses.KEY_DOWN, ord("j")):
145                      cursor = (cursor + 1) % len(items)
146                  elif key == ord(" "):
147                      chosen.symmetric_difference_update({cursor})
148                  elif key in (curses.KEY_ENTER, 10, 13):
149                      result_holder[0] = set(chosen)
150                      return
151                  elif key in (27, ord("q")):
152                      result_holder[0] = cancel_returns
153                      return
154  
155          curses.wrapper(_draw)
156          flush_stdin()
157          return result_holder[0] if result_holder[0] is not None else cancel_returns
158  
159      except KeyboardInterrupt:
160          return cancel_returns
161      except Exception:
162          return _numbered_fallback(title, items, selected, cancel_returns, status_fn)
163  
164  
165  def curses_radiolist(
166      title: str,
167      items: List[str],
168      selected: int = 0,
169      *,
170      cancel_returns: int | None = None,
171      description: str | None = None,
172  ) -> int:
173      """Curses single-select radio list. Returns the selected index.
174  
175      Args:
176          title: Header line displayed above the list.
177          items: Display labels for each row.
178          selected: Index that starts selected (pre-selected).
179          cancel_returns: Returned on ESC/q. Defaults to the original *selected*.
180          description: Optional multi-line text shown between the title and
181              the item list.  Useful for context that should survive the
182              curses screen clear.
183      """
184      if cancel_returns is None:
185          cancel_returns = selected
186  
187      if not sys.stdin.isatty():
188          return cancel_returns
189  
190      desc_lines: list[str] = []
191      if description:
192          desc_lines = description.splitlines()
193  
194      try:
195          import curses
196          result_holder: list = [None]
197  
198          def _draw(stdscr):
199              curses.curs_set(0)
200              if curses.has_colors():
201                  curses.start_color()
202                  curses.use_default_colors()
203                  curses.init_pair(1, curses.COLOR_GREEN, -1)
204                  curses.init_pair(2, curses.COLOR_YELLOW, -1)
205              cursor = selected
206              scroll_offset = 0
207  
208              while True:
209                  stdscr.clear()
210                  max_y, max_x = stdscr.getmaxyx()
211  
212                  row = 0
213  
214                  # Header
215                  try:
216                      hattr = curses.A_BOLD
217                      if curses.has_colors():
218                          hattr |= curses.color_pair(2)
219                      stdscr.addnstr(row, 0, title, max_x - 1, hattr)
220                      row += 1
221  
222                      # Description lines
223                      for dline in desc_lines:
224                          if row >= max_y - 1:
225                              break
226                          stdscr.addnstr(row, 0, dline, max_x - 1, curses.A_NORMAL)
227                          row += 1
228  
229                      stdscr.addnstr(
230                          row, 0,
231                          "  \u2191\u2193 navigate  ENTER/SPACE select  ESC cancel",
232                          max_x - 1, curses.A_DIM,
233                      )
234                      row += 1
235                  except curses.error:
236                      pass
237  
238                  # Scrollable item list
239                  items_start = row + 1
240                  visible_rows = max_y - items_start - 1
241                  if cursor < scroll_offset:
242                      scroll_offset = cursor
243                  elif cursor >= scroll_offset + visible_rows:
244                      scroll_offset = cursor - visible_rows + 1
245  
246                  for draw_i, i in enumerate(
247                      range(scroll_offset, min(len(items), scroll_offset + visible_rows))
248                  ):
249                      y = draw_i + items_start
250                      if y >= max_y - 1:
251                          break
252                      radio = "\u25cf" if i == selected else "\u25cb"
253                      arrow = "\u2192" if i == cursor else " "
254                      line = f" {arrow} ({radio}) {items[i]}"
255                      attr = curses.A_NORMAL
256                      if i == cursor:
257                          attr = curses.A_BOLD
258                          if curses.has_colors():
259                              attr |= curses.color_pair(1)
260                      try:
261                          stdscr.addnstr(y, 0, line, max_x - 1, attr)
262                      except curses.error:
263                          pass
264  
265                  stdscr.refresh()
266                  key = stdscr.getch()
267  
268                  if key in (curses.KEY_UP, ord("k")):
269                      cursor = (cursor - 1) % len(items)
270                  elif key in (curses.KEY_DOWN, ord("j")):
271                      cursor = (cursor + 1) % len(items)
272                  elif key in (ord(" "), curses.KEY_ENTER, 10, 13):
273                      result_holder[0] = cursor
274                      return
275                  elif key in (27, ord("q")):
276                      result_holder[0] = cancel_returns
277                      return
278  
279          curses.wrapper(_draw)
280          flush_stdin()
281          return result_holder[0] if result_holder[0] is not None else cancel_returns
282  
283      except KeyboardInterrupt:
284          return cancel_returns
285      except Exception:
286          return _radio_numbered_fallback(title, items, selected, cancel_returns)
287  
288  
def _radio_numbered_fallback(
    title: str,
    items: List[str],
    selected: int,
    cancel_returns: int,
) -> int:
    """Plain-text radio prompt used when curses is unavailable.

    Prints *title*, then every item numbered from 1 with a filled marker on
    *selected*, and reads a single answer from stdin.

    Returns:
        *selected* for an empty answer or an out-of-range number; the chosen
        0-based index for a valid number; *cancel_returns* when the answer is
        not an integer or the prompt is interrupted (Ctrl-C / EOF).
    """
    print(color(f"\n  {title}", Colors.YELLOW))
    print(color("  Select by number, Enter to confirm.\n", Colors.DIM))

    filled = "(\u25cf)"
    hollow = "(\u25cb)"
    for number, label in enumerate(items, start=1):
        marker = color(filled, Colors.GREEN) if number - 1 == selected else hollow
        print(f"  {marker} {number:>2}. {label}")
    print()

    try:
        answer = input(color(f"  Choice [default {selected + 1}]: ", Colors.DIM)).strip()
        if not answer:
            return selected
        choice = int(answer) - 1
    except (ValueError, KeyboardInterrupt, EOFError):
        return cancel_returns
    return choice if 0 <= choice < len(items) else selected
313  
314  
315  def curses_single_select(
316      title: str,
317      items: List[str],
318      default_index: int = 0,
319      *,
320      cancel_label: str = "Cancel",
321  ) -> int | None:
322      """Curses single-select menu. Returns selected index or None on cancel.
323  
324      Works inside prompt_toolkit because curses.wrapper() restores the terminal
325      safely, unlike simple_term_menu which conflicts with /dev/tty.
326      """
327      if not sys.stdin.isatty():
328          return None
329  
330      try:
331          import curses
332          result_holder: list = [None]
333  
334          all_items = list(items) + [cancel_label]
335          cancel_idx = len(items)
336  
337          def _draw(stdscr):
338              curses.curs_set(0)
339              if curses.has_colors():
340                  curses.start_color()
341                  curses.use_default_colors()
342                  curses.init_pair(1, curses.COLOR_GREEN, -1)
343                  curses.init_pair(2, curses.COLOR_YELLOW, -1)
344              cursor = min(default_index, len(all_items) - 1)
345              scroll_offset = 0
346  
347              while True:
348                  stdscr.clear()
349                  max_y, max_x = stdscr.getmaxyx()
350  
351                  try:
352                      hattr = curses.A_BOLD
353                      if curses.has_colors():
354                          hattr |= curses.color_pair(2)
355                      stdscr.addnstr(0, 0, title, max_x - 1, hattr)
356                      stdscr.addnstr(
357                          1, 0,
358                          "  ↑↓ navigate  ENTER confirm  ESC/q cancel",
359                          max_x - 1, curses.A_DIM,
360                      )
361                  except curses.error:
362                      pass
363  
364                  visible_rows = max_y - 3
365                  if cursor < scroll_offset:
366                      scroll_offset = cursor
367                  elif cursor >= scroll_offset + visible_rows:
368                      scroll_offset = cursor - visible_rows + 1
369  
370                  for draw_i, i in enumerate(
371                      range(scroll_offset, min(len(all_items), scroll_offset + visible_rows))
372                  ):
373                      y = draw_i + 3
374                      if y >= max_y - 1:
375                          break
376                      arrow = "→" if i == cursor else " "
377                      line = f" {arrow} {all_items[i]}"
378                      attr = curses.A_NORMAL
379                      if i == cursor:
380                          attr = curses.A_BOLD
381                          if curses.has_colors():
382                              attr |= curses.color_pair(1)
383                      try:
384                          stdscr.addnstr(y, 0, line, max_x - 1, attr)
385                      except curses.error:
386                          pass
387  
388                  stdscr.refresh()
389                  key = stdscr.getch()
390  
391                  if key in (curses.KEY_UP, ord("k")):
392                      cursor = (cursor - 1) % len(all_items)
393                  elif key in (curses.KEY_DOWN, ord("j")):
394                      cursor = (cursor + 1) % len(all_items)
395                  elif key in (curses.KEY_ENTER, 10, 13):
396                      result_holder[0] = cursor
397                      return
398                  elif key in (27, ord("q")):
399                      result_holder[0] = None
400                      return
401  
402          curses.wrapper(_draw)
403          flush_stdin()
404          if result_holder[0] is not None and result_holder[0] >= cancel_idx:
405              return None
406          return result_holder[0]
407  
408      except KeyboardInterrupt:
409          return None
410      except Exception:
411          all_items = list(items) + [cancel_label]
412          cancel_idx = len(items)
413          return _numbered_single_fallback(title, all_items, cancel_idx)
414  
415  
416  def _numbered_single_fallback(
417      title: str,
418      items: List[str],
419      cancel_idx: int,
420  ) -> int | None:
421      """Text-based numbered fallback for single-select."""
422      print(f"\n  {title}\n")
423      for i, label in enumerate(items, 1):
424          print(f"  {i}. {label}")
425      print()
426      try:
427          val = input(f"  Choice [1-{len(items)}]: ").strip()
428          if not val:
429              return None
430          idx = int(val) - 1
431          if 0 <= idx < len(items) and idx < cancel_idx:
432              return idx
433          if idx == cancel_idx:
434              return None
435      except (ValueError, KeyboardInterrupt, EOFError):
436          pass
437      return None
438  
439  
def _numbered_fallback(
    title: str,
    items: List[str],
    selected: Set[int],
    cancel_returns: Set[int],
    status_fn: Optional[Callable[[Set[int]], str]] = None,
) -> Set[int]:
    """Text-based toggle fallback for terminals without curses.

    Repeatedly prints the checklist (plus the optional *status_fn* line) and
    reads a line from stdin.  Every 1-based item number on the line toggles
    that item; several numbers may be given at once, separated by spaces or
    commas (e.g. ``1 3`` or ``1,3``).  Entries that are not a valid item
    number are reported and otherwise ignored, so a typo no longer throws
    away the toggles made so far.  An empty line confirms.

    Returns:
        The toggled selection on confirm, or *cancel_returns* on Ctrl-C/EOF.
    """
    chosen = set(selected)
    print(color(f"\n  {title}", Colors.YELLOW))
    print(color("  Toggle by number, Enter to confirm.\n", Colors.DIM))

    while True:
        for i, label in enumerate(items):
            marker = color("[✓]", Colors.GREEN) if i in chosen else "[ ]"
            print(f"  {marker} {i + 1:>2}. {label}")
        if status_fn:
            status_text = status_fn(chosen)
            if status_text:
                print(color(f"\n  {status_text}", Colors.DIM))
        print()
        try:
            val = input(color("  Toggle # (or Enter to confirm): ", Colors.DIM)).strip()
        except (KeyboardInterrupt, EOFError):
            return cancel_returns
        if not val:
            break

        # Unparsable or out-of-range entries are collected and skipped rather
        # than aborting the whole selection.
        ignored = []
        for token in val.replace(",", " ").split():
            try:
                idx = int(token) - 1
            except ValueError:
                ignored.append(token)
                continue
            if 0 <= idx < len(items):
                chosen.symmetric_difference_update({idx})
            else:
                ignored.append(token)
        if ignored:
            print(color(f"  Ignored: {' '.join(ignored)}", Colors.DIM))
        print()

    return chosen
471  
472      return chosen