/ qwencoder-eval / instruct / aider / aider / coders / base_coder.py
base_coder.py
   1  #!/usr/bin/env python
   2  
   3  import base64
   4  import hashlib
   5  import json
   6  import locale
   7  import math
   8  import mimetypes
   9  import os
  10  import platform
  11  import re
  12  import sys
  13  import threading
  14  import time
  15  import traceback
  16  from collections import defaultdict
  17  from datetime import datetime
  18  from json.decoder import JSONDecodeError
  19  from pathlib import Path
  20  
  21  from rich.console import Console, Text
  22  from rich.markdown import Markdown
  23  
  24  from aider import __version__, models, prompts, urls, utils
  25  from aider.commands import Commands
  26  from aider.history import ChatSummary
  27  from aider.io import ConfirmGroup, InputOutput
  28  from aider.linter import Linter
  29  from aider.llm import litellm
  30  from aider.mdstream import MarkdownStream
  31  from aider.repo import ANY_GIT_ERROR, GitRepo
  32  from aider.repomap import RepoMap
  33  from aider.run_cmd import run_cmd
  34  from aider.sendchat import retry_exceptions, send_completion
  35  from aider.utils import format_content, format_messages, format_tokens, is_image_file
  36  
  37  from ..dump import dump  # noqa: F401
  38  from .chat_chunks import ChatChunks
  39  
  40  
  41  class MissingAPIKeyError(ValueError):
  42      pass
  43  
  44  
  45  class FinishReasonLength(Exception):
  46      pass
  47  
  48  
  49  def wrap_fence(name):
  50      return f"<{name}>", f"</{name}>"
  51  
  52  
  53  all_fences = [
  54      ("``" + "`", "``" + "`"),
  55      wrap_fence("source"),
  56      wrap_fence("code"),
  57      wrap_fence("pre"),
  58      wrap_fence("codeblock"),
  59      wrap_fence("sourcecode"),
  60  ]
  61  
  62  
class Coder:
    """Base class for aider's chat coders.

    ``Coder.create()`` instantiates the class in ``aider.coders.__all__``
    whose ``edit_format`` matches the requested format. The class attributes
    below are defaults; ``__init__`` assigns per-instance values for several
    of them.
    """

    # Absolute paths of files added to the chat (a set once __init__ runs).
    abs_fnames = None
    # Absolute paths of read-only reference files (a set once __init__ runs).
    abs_read_only_fnames = None
    # GitRepo wrapper, or None when not working inside a git repo.
    repo = None
    last_aider_commit_hash = None
    # Reset to an empty set at the start of each user message.
    aider_edited_files = None
    last_asked_for_commit_time = 0
    # RepoMap instance when the repo map is enabled, else None.
    repo_map = None
    # Optional function-call JSON schemas; validated with Draft7 in __init__.
    functions = None
    num_exhausted_context_windows = 0
    num_malformed_responses = 0
    # time.time() of the previous Ctrl-C (see keyboard_interrupt()).
    last_keyboard_interrupt = None
    # run_one() re-sends reflected messages at most max_reflections times.
    num_reflections = 0
    max_reflections = 3
    # Subclasses set this; Coder.create() matches on it.
    edit_format = None
    yield_stream = False
    temperature = 0
    auto_lint = True
    auto_test = False
    test_cmd = None
    # Per-message lint/test results, reset in init_before_message().
    lint_outcome = None
    test_outcome = None
    multi_response_content = ""
    # Returned by run() when called with a one-shot message.
    partial_response_content = ""
    # Mutable class-level default; __init__ replaces it with a fresh list per
    # instance so instances never share it.
    commit_before_message = []
    message_cost = 0.0
    message_tokens_sent = 0
    message_tokens_received = 0
    # Set True in __init__ when prompt caching is requested and supported.
    add_cache_headers = False
    cache_warming_thread = None
    num_cache_warming_pings = 0
    suggest_shell_commands = True
    # Replaced by a set in __init__; clone() shares it with the new coder.
    ignore_mentions = None
    # Explicit chat language; takes precedence in get_user_language().
    chat_language = None
  97  
  98      @classmethod
  99      def create(
 100          self,
 101          main_model=None,
 102          edit_format=None,
 103          io=None,
 104          from_coder=None,
 105          summarize_from_coder=True,
 106          **kwargs,
 107      ):
 108          import aider.coders as coders
 109  
 110          if not main_model:
 111              if from_coder:
 112                  main_model = from_coder.main_model
 113              else:
 114                  main_model = models.Model(models.DEFAULT_MODEL_NAME)
 115  
 116          if edit_format == "code":
 117              edit_format = None
 118          if edit_format is None:
 119              if from_coder:
 120                  edit_format = from_coder.edit_format
 121              else:
 122                  edit_format = main_model.edit_format
 123  
 124          if not io and from_coder:
 125              io = from_coder.io
 126  
 127          if from_coder:
 128              use_kwargs = dict(from_coder.original_kwargs)  # copy orig kwargs
 129  
 130              # If the edit format changes, we can't leave old ASSISTANT
 131              # messages in the chat history. The old edit format will
 132              # confused the new LLM. It may try and imitate it, disobeying
 133              # the system prompt.
 134              done_messages = from_coder.done_messages
 135              if edit_format != from_coder.edit_format and done_messages and summarize_from_coder:
 136                  done_messages = from_coder.summarizer.summarize_all(done_messages)
 137  
 138              # Bring along context from the old Coder
 139              update = dict(
 140                  fnames=list(from_coder.abs_fnames),
 141                  read_only_fnames=list(from_coder.abs_read_only_fnames),  # Copy read-only files
 142                  done_messages=done_messages,
 143                  cur_messages=from_coder.cur_messages,
 144                  aider_commit_hashes=from_coder.aider_commit_hashes,
 145                  commands=from_coder.commands.clone(),
 146                  total_cost=from_coder.total_cost,
 147              )
 148  
 149              use_kwargs.update(update)  # override to complete the switch
 150              use_kwargs.update(kwargs)  # override passed kwargs
 151  
 152              kwargs = use_kwargs
 153  
 154          for coder in coders.__all__:
 155              if hasattr(coder, "edit_format") and coder.edit_format == edit_format:
 156                  res = coder(main_model, io, **kwargs)
 157                  res.original_kwargs = dict(kwargs)
 158                  return res
 159  
 160          raise ValueError(f"Unknown edit format {edit_format}")
 161  
 162      def clone(self, **kwargs):
 163          new_coder = Coder.create(from_coder=self, **kwargs)
 164          new_coder.ignore_mentions = self.ignore_mentions
 165          return new_coder
 166  
 167      def get_announcements(self):
 168          lines = []
 169          lines.append(f"Aider v{__version__}")
 170  
 171          # Model
 172          main_model = self.main_model
 173          weak_model = main_model.weak_model
 174  
 175          if weak_model is not main_model:
 176              prefix = "Main model"
 177          else:
 178              prefix = "Model"
 179  
 180          output = f"{prefix}: {main_model.name} with {self.edit_format} edit format"
 181          if self.add_cache_headers or main_model.caches_by_default:
 182              output += ", prompt cache"
 183          if main_model.info.get("supports_assistant_prefill"):
 184              output += ", infinite output"
 185          lines.append(output)
 186  
 187          if weak_model is not main_model:
 188              output = f"Weak model: {weak_model.name}"
 189              lines.append(output)
 190  
 191          # Repo
 192          if self.repo:
 193              rel_repo_dir = self.repo.get_rel_repo_dir()
 194              num_files = len(self.repo.get_tracked_files())
 195  
 196              lines.append(f"Git repo: {rel_repo_dir} with {num_files:,} files")
 197              if num_files > 1000:
 198                  lines.append(
 199                      "Warning: For large repos, consider using --subtree-only and .aiderignore"
 200                  )
 201                  lines.append(f"See: {urls.large_repos}")
 202          else:
 203              lines.append("Git repo: none")
 204  
 205          # Repo-map
 206          if self.repo_map:
 207              map_tokens = self.repo_map.max_map_tokens
 208              if map_tokens > 0:
 209                  refresh = self.repo_map.refresh
 210                  lines.append(f"Repo-map: using {map_tokens} tokens, {refresh} refresh")
 211                  max_map_tokens = 2048
 212                  if map_tokens > max_map_tokens:
 213                      lines.append(
 214                          f"Warning: map-tokens > {max_map_tokens} is not recommended as too much"
 215                          " irrelevant code can confuse LLMs."
 216                      )
 217              else:
 218                  lines.append("Repo-map: disabled because map_tokens == 0")
 219          else:
 220              lines.append("Repo-map: disabled")
 221  
 222          # Files
 223          for fname in self.get_inchat_relative_files():
 224              lines.append(f"Added {fname} to the chat.")
 225  
 226          if self.done_messages:
 227              lines.append("Restored previous conversation history.")
 228  
 229          return lines
 230  
    def __init__(
        self,
        main_model,
        io,
        repo=None,
        fnames=None,
        read_only_fnames=None,
        show_diffs=False,
        auto_commits=True,
        dirty_commits=True,
        dry_run=False,
        map_tokens=1024,
        verbose=False,
        assistant_output_color="blue",
        code_theme="default",
        stream=True,
        use_git=True,
        cur_messages=None,
        done_messages=None,
        restore_chat_history=False,
        auto_lint=True,
        auto_test=False,
        lint_cmds=None,
        test_cmd=None,
        aider_commit_hashes=None,
        map_mul_no_files=8,
        commands=None,
        summarizer=None,
        total_cost=0.0,
        map_refresh="auto",
        cache_prompts=False,
        num_cache_warming_pings=0,
        suggest_shell_commands=True,
        chat_language=None,
    ):
        """Initialize chat state, git repo, files, repo map, summarizer and linter.

        Usually called through ``Coder.create()``, which passes ``main_model``
        and ``io`` positionally plus any carried-over keyword arguments.

        Notable parameters:
          - fnames: paths to add to the chat. Missing files are created
            (``utils.touch_file``). Files matching the aiderignore spec, files
            that cannot be created, and non-regular files are skipped with a
            warning.
          - read_only_fnames: reference files, resolved against the root.
            Missing ones are skipped with a warning.
          - auto_commits: when False, also forces ``dirty_commits`` to False.
          - use_git: if true and no ``repo`` is given, try to build a GitRepo.
          - map_tokens: ``None`` defers to ``main_model.use_repo_map`` with a
            1024-token budget. Otherwise a value > 0 enables the repo map.
          - restore_chat_history: load history from ``io.chat_history_file``
            when no ``done_messages`` were supplied.
          - cache_prompts: enables cache headers only if
            ``main_model.cache_control`` is truthy.
        """
        self.chat_language = chat_language
        # Fresh per-instance containers (several class attributes hold shared defaults).
        self.commit_before_message = []
        self.aider_commit_hashes = set()
        self.rejected_urls = set()
        self.abs_root_path_cache = {}
        self.ignore_mentions = set()

        self.suggest_shell_commands = suggest_shell_commands

        self.num_cache_warming_pings = num_cache_warming_pings

        if not fnames:
            fnames = []

        if io is None:
            io = InputOutput()

        # NOTE(review): aider_commit_hashes was already set to set() above;
        # this branch simply overwrites it.
        if aider_commit_hashes:
            self.aider_commit_hashes = aider_commit_hashes
        else:
            self.aider_commit_hashes = set()

        self.chat_completion_call_hashes = []
        self.chat_completion_response_hashes = []
        self.need_commit_before_edits = set()

        self.total_cost = total_cost

        self.verbose = verbose
        self.abs_fnames = set()
        self.abs_read_only_fnames = set()

        # Note: supplied message lists are stored as-is (not copied).
        if cur_messages:
            self.cur_messages = cur_messages
        else:
            self.cur_messages = []

        if done_messages:
            self.done_messages = done_messages
        else:
            self.done_messages = []

        self.io = io
        self.stream = stream

        self.shell_commands = []

        # Dirty commits only make sense when auto-commits are on.
        if not auto_commits:
            dirty_commits = False

        self.auto_commits = auto_commits
        self.dirty_commits = dirty_commits
        self.assistant_output_color = assistant_output_color
        self.code_theme = code_theme

        self.dry_run = dry_run
        self.pretty = self.io.pretty

        # Plain (no color, no terminal control) console when not pretty.
        if self.pretty:
            self.console = Console()
        else:
            self.console = Console(force_terminal=False, no_color=True)

        self.main_model = main_model

        if cache_prompts and self.main_model.cache_control:
            self.add_cache_headers = True

        self.show_diffs = show_diffs

        # A Commands object passed in (Coder.create passes a clone) is
        # re-pointed at this coder.
        self.commands = commands or Commands(self.io, self)
        self.commands.coder = self

        self.repo = repo
        if use_git and self.repo is None:
            try:
                self.repo = GitRepo(
                    self.io,
                    fnames,
                    None,
                    models=main_model.commit_message_models(),
                )
            except FileNotFoundError:
                # Presumably raised when no git repo can be found for the
                # files -- continue without git (TODO confirm in GitRepo).
                pass

        if self.repo:
            self.root = self.repo.root

        for fname in fnames:
            fname = Path(fname)
            if self.repo and self.repo.ignored_file(fname):
                self.io.tool_warning(f"Skipping {fname} that matches aiderignore spec.")
                continue

            # Missing files are created empty so they can be edited.
            if not fname.exists():
                if utils.touch_file(fname):
                    self.io.tool_output(f"Creating empty file {fname}")
                else:
                    self.io.tool_warning(f"Can not create {fname}, skipping.")
                    continue

            if not fname.is_file():
                self.io.tool_warning(f"Skipping {fname} that is not a normal file.")
                continue

            fname = str(fname.resolve())

            self.abs_fnames.add(fname)
            self.check_added_files()

        # Without git, the root is the common ancestor of the added files.
        if not self.repo:
            self.root = utils.find_common_root(self.abs_fnames)

        if read_only_fnames:
            # NOTE(review): abs_read_only_fnames is already an empty set here.
            self.abs_read_only_fnames = set()
            for fname in read_only_fnames:
                abs_fname = self.abs_root_path(fname)
                if os.path.exists(abs_fname):
                    self.abs_read_only_fnames.add(abs_fname)
                else:
                    self.io.tool_warning(f"Error: Read-only file {fname} does not exist. Skipping.")

        # map_tokens=None: let the model decide whether to use a repo map.
        if map_tokens is None:
            use_repo_map = main_model.use_repo_map
            map_tokens = 1024
        else:
            use_repo_map = map_tokens > 0

        max_inp_tokens = self.main_model.info.get("max_input_tokens") or 0

        # Only subclasses that define gpt_prompts with a repo_content_prefix
        # can present a repo map.
        has_map_prompt = hasattr(self, "gpt_prompts") and self.gpt_prompts.repo_content_prefix

        if use_repo_map and self.repo and has_map_prompt:
            self.repo_map = RepoMap(
                map_tokens,
                self.root,
                self.main_model,
                io,
                self.gpt_prompts.repo_content_prefix,
                self.verbose,
                max_inp_tokens,
                map_mul_no_files=map_mul_no_files,
                refresh=map_refresh,
            )

        # Summarizer tries the weak model first, then the main model.
        self.summarizer = summarizer or ChatSummary(
            [self.main_model.weak_model, self.main_model],
            self.main_model.max_chat_history_tokens,
        )

        self.summarizer_thread = None
        self.summarized_done_messages = []

        if not self.done_messages and restore_chat_history:
            history_md = self.io.read_text(self.io.chat_history_file)
            if history_md:
                self.done_messages = utils.split_chat_history_markdown(history_md)
                self.summarize_start()

        # Linting and testing
        self.linter = Linter(root=self.root, encoding=io.encoding)
        self.auto_lint = auto_lint
        self.setup_lint_cmds(lint_cmds)
        self.lint_cmds = lint_cmds
        self.auto_test = auto_test
        self.test_cmd = test_cmd

        # validate the functions jsonschema
        if self.functions:
            # Imported lazily: only needed by function-calling coders.
            from jsonschema import Draft7Validator

            for function in self.functions:
                Draft7Validator.check_schema(function)

            if self.verbose:
                self.io.tool_output("JSON Schema:")
                self.io.tool_output(json.dumps(self.functions, indent=4))
 443  
 444      def setup_lint_cmds(self, lint_cmds):
 445          if not lint_cmds:
 446              return
 447          for lang, cmd in lint_cmds.items():
 448              self.linter.set_linter(lang, cmd)
 449  
 450      def show_announcements(self):
 451          bold = True
 452          for line in self.get_announcements():
 453              self.io.tool_output(line, bold=bold)
 454              bold = False
 455  
 456      def add_rel_fname(self, rel_fname):
 457          self.abs_fnames.add(self.abs_root_path(rel_fname))
 458          self.check_added_files()
 459  
 460      def drop_rel_fname(self, fname):
 461          abs_fname = self.abs_root_path(fname)
 462          if abs_fname in self.abs_fnames:
 463              self.abs_fnames.remove(abs_fname)
 464              return True
 465  
 466      def abs_root_path(self, path):
 467          key = path
 468          if key in self.abs_root_path_cache:
 469              return self.abs_root_path_cache[key]
 470  
 471          res = Path(self.root) / path
 472          res = utils.safe_abs_path(res)
 473          self.abs_root_path_cache[key] = res
 474          return res
 475  
    # Candidate (open, close) fence pairs; choose_fence() picks the first one
    # whose markers appear in none of the chat files.
    fences = all_fences
    # Default fence (the triple-backtick pair). show_pretty() only enables
    # pretty output while this default is in use.
    fence = fences[0]
 478  
 479      def show_pretty(self):
 480          if not self.pretty:
 481              return False
 482  
 483          # only show pretty output if fences are the normal triple-backtick
 484          if self.fence != self.fences[0]:
 485              return False
 486  
 487          return True
 488  
 489      def get_abs_fnames_content(self):
 490          for fname in list(self.abs_fnames):
 491              content = self.io.read_text(fname)
 492  
 493              if content is None:
 494                  relative_fname = self.get_rel_fname(fname)
 495                  self.io.tool_warning(f"Dropping {relative_fname} from the chat.")
 496                  self.abs_fnames.remove(fname)
 497              else:
 498                  yield fname, content
 499  
 500      def choose_fence(self):
 501          all_content = ""
 502          for _fname, content in self.get_abs_fnames_content():
 503              all_content += content + "\n"
 504          for _fname in self.abs_read_only_fnames:
 505              content = self.io.read_text(_fname)
 506              if content is not None:
 507                  all_content += content + "\n"
 508  
 509          good = False
 510          for fence_open, fence_close in self.fences:
 511              if fence_open in all_content or fence_close in all_content:
 512                  continue
 513              good = True
 514              break
 515  
 516          if good:
 517              self.fence = (fence_open, fence_close)
 518          else:
 519              self.fence = self.fences[0]
 520              self.io.tool_warning(
 521                  "Unable to find a fencing strategy! Falling back to:"
 522                  f" {self.fence[0]}...{self.fence[1]}"
 523              )
 524  
 525          return
 526  
 527      def get_files_content(self, fnames=None):
 528          if not fnames:
 529              fnames = self.abs_fnames
 530  
 531          prompt = ""
 532          for fname, content in self.get_abs_fnames_content():
 533              if not is_image_file(fname):
 534                  relative_fname = self.get_rel_fname(fname)
 535                  prompt += "\n"
 536                  prompt += relative_fname
 537                  prompt += f"\n{self.fence[0]}\n"
 538  
 539                  prompt += content
 540  
 541                  # lines = content.splitlines(keepends=True)
 542                  # lines = [f"{i+1:03}:{line}" for i, line in enumerate(lines)]
 543                  # prompt += "".join(lines)
 544  
 545                  prompt += f"{self.fence[1]}\n"
 546  
 547          return prompt
 548  
 549      def get_read_only_files_content(self):
 550          prompt = ""
 551          for fname in self.abs_read_only_fnames:
 552              content = self.io.read_text(fname)
 553              if content is not None and not is_image_file(fname):
 554                  relative_fname = self.get_rel_fname(fname)
 555                  prompt += "\n"
 556                  prompt += relative_fname
 557                  prompt += f"\n{self.fence[0]}\n"
 558                  prompt += content
 559                  prompt += f"{self.fence[1]}\n"
 560          return prompt
 561  
 562      def get_cur_message_text(self):
 563          text = ""
 564          for msg in self.cur_messages:
 565              text += msg["content"] + "\n"
 566          return text
 567  
 568      def get_ident_mentions(self, text):
 569          # Split the string on any character that is not alphanumeric
 570          # \W+ matches one or more non-word characters (equivalent to [^a-zA-Z0-9_]+)
 571          words = set(re.split(r"\W+", text))
 572          return words
 573  
 574      def get_ident_filename_matches(self, idents):
 575          all_fnames = defaultdict(set)
 576          for fname in self.get_all_relative_files():
 577              base = Path(fname).with_suffix("").name.lower()
 578              if len(base) >= 5:
 579                  all_fnames[base].add(fname)
 580  
 581          matches = set()
 582          for ident in idents:
 583              if len(ident) < 5:
 584                  continue
 585              matches.update(all_fnames[ident.lower()])
 586  
 587          return matches
 588  
 589      def get_repo_map(self, force_refresh=False):
 590          if not self.repo_map:
 591              return
 592  
 593          cur_msg_text = self.get_cur_message_text()
 594          mentioned_fnames = self.get_file_mentions(cur_msg_text)
 595          mentioned_idents = self.get_ident_mentions(cur_msg_text)
 596  
 597          mentioned_fnames.update(self.get_ident_filename_matches(mentioned_idents))
 598  
 599          all_abs_files = set(self.get_all_abs_files())
 600          repo_abs_read_only_fnames = set(self.abs_read_only_fnames) & all_abs_files
 601          chat_files = set(self.abs_fnames) | repo_abs_read_only_fnames
 602          other_files = all_abs_files - chat_files
 603  
 604          repo_content = self.repo_map.get_repo_map(
 605              chat_files,
 606              other_files,
 607              mentioned_fnames=mentioned_fnames,
 608              mentioned_idents=mentioned_idents,
 609              force_refresh=force_refresh,
 610          )
 611  
 612          # fall back to global repo map if files in chat are disjoint from rest of repo
 613          if not repo_content:
 614              repo_content = self.repo_map.get_repo_map(
 615                  set(),
 616                  all_abs_files,
 617                  mentioned_fnames=mentioned_fnames,
 618                  mentioned_idents=mentioned_idents,
 619              )
 620  
 621          # fall back to completely unhinted repo
 622          if not repo_content:
 623              repo_content = self.repo_map.get_repo_map(
 624                  set(),
 625                  all_abs_files,
 626              )
 627  
 628          return repo_content
 629  
 630      def get_repo_messages(self):
 631          repo_messages = []
 632          repo_content = self.get_repo_map()
 633          if repo_content:
 634              repo_messages += [
 635                  dict(role="user", content=repo_content),
 636                  dict(
 637                      role="assistant",
 638                      content="Ok, I won't try and edit those files without asking first.",
 639                  ),
 640              ]
 641          return repo_messages
 642  
 643      def get_readonly_files_messages(self):
 644          readonly_messages = []
 645          read_only_content = self.get_read_only_files_content()
 646          if read_only_content:
 647              readonly_messages += [
 648                  dict(
 649                      role="user", content=self.gpt_prompts.read_only_files_prefix + read_only_content
 650                  ),
 651                  dict(
 652                      role="assistant",
 653                      content="Ok, I will use these files as references.",
 654                  ),
 655              ]
 656          return readonly_messages
 657  
 658      def get_chat_files_messages(self):
 659          chat_files_messages = []
 660          if self.abs_fnames:
 661              files_content = self.gpt_prompts.files_content_prefix
 662              files_content += self.get_files_content()
 663              files_reply = "Ok, any changes I propose will be to those files."
 664          elif self.get_repo_map() and self.gpt_prompts.files_no_full_files_with_repo_map:
 665              files_content = self.gpt_prompts.files_no_full_files_with_repo_map
 666              files_reply = self.gpt_prompts.files_no_full_files_with_repo_map_reply
 667          else:
 668              files_content = self.gpt_prompts.files_no_full_files
 669              files_reply = "Ok."
 670  
 671          if files_content:
 672              chat_files_messages += [
 673                  dict(role="user", content=files_content),
 674                  dict(role="assistant", content=files_reply),
 675              ]
 676  
 677          images_message = self.get_images_message()
 678          if images_message is not None:
 679              chat_files_messages += [
 680                  images_message,
 681                  dict(role="assistant", content="Ok."),
 682              ]
 683  
 684          return chat_files_messages
 685  
 686      def get_images_message(self):
 687          if not self.main_model.accepts_images:
 688              return None
 689  
 690          image_messages = []
 691          for fname, content in self.get_abs_fnames_content():
 692              if is_image_file(fname):
 693                  with open(fname, "rb") as image_file:
 694                      encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
 695                  mime_type, _ = mimetypes.guess_type(fname)
 696                  if mime_type and mime_type.startswith("image/"):
 697                      image_url = f"data:{mime_type};base64,{encoded_string}"
 698                      rel_fname = self.get_rel_fname(fname)
 699                      image_messages += [
 700                          {"type": "text", "text": f"Image file: {rel_fname}"},
 701                          {"type": "image_url", "image_url": {"url": image_url, "detail": "high"}},
 702                      ]
 703  
 704          if not image_messages:
 705              return None
 706  
 707          return {"role": "user", "content": image_messages}
 708  
 709      def run_stream(self, user_message):
 710          self.io.user_input(user_message)
 711          self.init_before_message()
 712          yield from self.send_message(user_message)
 713  
 714      def init_before_message(self):
 715          self.aider_edited_files = set()
 716          self.reflected_message = None
 717          self.num_reflections = 0
 718          self.lint_outcome = None
 719          self.test_outcome = None
 720          self.shell_commands = []
 721  
 722          if self.repo:
 723              self.commit_before_message.append(self.repo.get_head_commit_sha())
 724  
 725      def run(self, with_message=None, preproc=True):
 726          try:
 727              if with_message:
 728                  self.io.user_input(with_message)
 729                  self.run_one(with_message, preproc)
 730                  return self.partial_response_content
 731  
 732              while True:
 733                  try:
 734                      user_message = self.get_input()
 735                      self.run_one(user_message, preproc)
 736                      self.show_undo_hint()
 737                  except KeyboardInterrupt:
 738                      self.keyboard_interrupt()
 739          except EOFError:
 740              return
 741  
 742      def get_input(self):
 743          inchat_files = self.get_inchat_relative_files()
 744          read_only_files = [self.get_rel_fname(fname) for fname in self.abs_read_only_fnames]
 745          all_files = sorted(set(inchat_files + read_only_files))
 746          edit_format = "" if self.edit_format == self.main_model.edit_format else self.edit_format
 747          return self.io.get_input(
 748              self.root,
 749              all_files,
 750              self.get_addable_relative_files(),
 751              self.commands,
 752              self.abs_read_only_fnames,
 753              edit_format=edit_format,
 754          )
 755  
 756      def preproc_user_input(self, inp):
 757          if not inp:
 758              return
 759  
 760          if self.commands.is_command(inp):
 761              return self.commands.run(inp)
 762  
 763          self.check_for_file_mentions(inp)
 764          self.check_for_urls(inp)
 765  
 766          return inp
 767  
 768      def run_one(self, user_message, preproc):
 769          self.init_before_message()
 770  
 771          if preproc:
 772              message = self.preproc_user_input(user_message)
 773          else:
 774              message = user_message
 775  
 776          while message:
 777              self.reflected_message = None
 778              list(self.send_message(message))
 779  
 780              if not self.reflected_message:
 781                  break
 782  
 783              if self.num_reflections >= self.max_reflections:
 784                  self.io.tool_warning(f"Only {self.max_reflections} reflections allowed, stopping.")
 785                  return
 786  
 787              self.num_reflections += 1
 788              message = self.reflected_message
 789  
 790      def check_for_urls(self, inp):
 791          url_pattern = re.compile(r"(https?://[^\s/$.?#].[^\s]*[^\s,.])")
 792          urls = list(set(url_pattern.findall(inp)))  # Use set to remove duplicates
 793          added_urls = []
 794          group = ConfirmGroup(urls)
 795          for url in urls:
 796              if url not in self.rejected_urls:
 797                  if self.io.confirm_ask("Add URL to the chat?", subject=url, group=group):
 798                      inp += "\n\n"
 799                      inp += self.commands.cmd_web(url)
 800                      added_urls.append(url)
 801                  else:
 802                      self.rejected_urls.add(url)
 803  
 804          return added_urls
 805  
 806      def keyboard_interrupt(self):
 807          now = time.time()
 808  
 809          thresh = 2  # seconds
 810          if self.last_keyboard_interrupt and now - self.last_keyboard_interrupt < thresh:
 811              self.io.tool_warning("\n\n^C KeyboardInterrupt")
 812              sys.exit()
 813  
 814          self.io.tool_warning("\n\n^C again to exit")
 815  
 816          self.last_keyboard_interrupt = now
 817  
 818      def summarize_start(self):
 819          if not self.summarizer.too_big(self.done_messages):
 820              return
 821  
 822          self.summarize_end()
 823  
 824          if self.verbose:
 825              self.io.tool_output("Starting to summarize chat history.")
 826  
 827          self.summarizer_thread = threading.Thread(target=self.summarize_worker)
 828          self.summarizer_thread.start()
 829  
 830      def summarize_worker(self):
 831          try:
 832              self.summarized_done_messages = self.summarizer.summarize(self.done_messages)
 833          except ValueError as err:
 834              self.io.tool_warning(err.args[0])
 835  
 836          if self.verbose:
 837              self.io.tool_output("Finished summarizing chat history.")
 838  
 839      def summarize_end(self):
 840          if self.summarizer_thread is None:
 841              return
 842  
 843          self.summarizer_thread.join()
 844          self.summarizer_thread = None
 845  
 846          self.done_messages = self.summarized_done_messages
 847          self.summarized_done_messages = []
 848  
 849      def move_back_cur_messages(self, message):
 850          self.done_messages += self.cur_messages
 851          self.summarize_start()
 852  
 853          # TODO check for impact on image messages
 854          if message:
 855              self.done_messages += [
 856                  dict(role="user", content=message),
 857                  dict(role="assistant", content="Ok."),
 858              ]
 859          self.cur_messages = []
 860  
 861      def get_user_language(self):
 862          if self.chat_language:
 863              return self.chat_language
 864  
 865          try:
 866              lang = locale.getlocale()[0]
 867              if lang:
 868                  return lang  # Return the full language code, including country
 869          except Exception:
 870              pass
 871  
 872          for env_var in ["LANG", "LANGUAGE", "LC_ALL", "LC_MESSAGES"]:
 873              lang = os.environ.get(env_var)
 874              if lang:
 875                  return lang.split(".")[
 876                      0
 877                  ]  # Return language and country, but remove encoding if present
 878  
 879          return None
 880  
 881      def get_platform_info(self):
 882          platform_text = f"- Platform: {platform.platform()}\n"
 883          shell_var = "COMSPEC" if os.name == "nt" else "SHELL"
 884          shell_val = os.getenv(shell_var)
 885          platform_text += f"- Shell: {shell_var}={shell_val}\n"
 886  
 887          user_lang = self.get_user_language()
 888          if user_lang:
 889              platform_text += f"- Language: {user_lang}\n"
 890  
 891          dt = datetime.now().astimezone().strftime("%Y-%m-%d")
 892          platform_text += f"- Current date: {dt}\n"
 893  
 894          if self.repo:
 895              platform_text += "- The user is operating inside a git repository\n"
 896  
 897          if self.lint_cmds:
 898              if self.auto_lint:
 899                  platform_text += (
 900                      "- The user's pre-commit runs these lint commands, don't suggest running"
 901                      " them:\n"
 902                  )
 903              else:
 904                  platform_text += "- The user prefers these lint commands:\n"
 905              for lang, cmd in self.lint_cmds.items():
 906                  if lang is None:
 907                      platform_text += f"  - {cmd}\n"
 908                  else:
 909                      platform_text += f"  - {lang}: {cmd}\n"
 910  
 911          if self.test_cmd:
 912              if self.auto_test:
 913                  platform_text += (
 914                      "- The user's pre-commit runs this test command, don't suggest running them: "
 915                  )
 916              else:
 917                  platform_text += "- The user prefers this test command: "
 918              platform_text += self.test_cmd + "\n"
 919  
 920          return platform_text
 921  
 922      def fmt_system_prompt(self, prompt):
 923          lazy_prompt = self.gpt_prompts.lazy_prompt if self.main_model.lazy else ""
 924          platform_text = self.get_platform_info()
 925  
 926          prompt = prompt.format(
 927              fence=self.fence,
 928              lazy_prompt=lazy_prompt,
 929              platform=platform_text,
 930          )
 931          return prompt
 932  
 933      def format_chat_chunks(self):
 934          self.choose_fence()
 935          main_sys = self.fmt_system_prompt(self.gpt_prompts.main_system)
 936  
 937          example_messages = []
 938          if self.main_model.examples_as_sys_msg:
 939              if self.gpt_prompts.example_messages:
 940                  main_sys += "\n# Example conversations:\n\n"
 941              for msg in self.gpt_prompts.example_messages:
 942                  role = msg["role"]
 943                  content = self.fmt_system_prompt(msg["content"])
 944                  main_sys += f"## {role.upper()}: {content}\n\n"
 945              main_sys = main_sys.strip()
 946          else:
 947              for msg in self.gpt_prompts.example_messages:
 948                  example_messages.append(
 949                      dict(
 950                          role=msg["role"],
 951                          content=self.fmt_system_prompt(msg["content"]),
 952                      )
 953                  )
 954              if self.gpt_prompts.example_messages:
 955                  example_messages += [
 956                      dict(
 957                          role="user",
 958                          content=(
 959                              "I switched to a new code base. Please don't consider the above files"
 960                              " or try to edit them any longer."
 961                          ),
 962                      ),
 963                      dict(role="assistant", content="Ok."),
 964                  ]
 965  
 966          if self.gpt_prompts.system_reminder:
 967              main_sys += "\n" + self.fmt_system_prompt(self.gpt_prompts.system_reminder)
 968  
 969          chunks = ChatChunks()
 970  
 971          chunks.system = [
 972              dict(role="system", content=main_sys),
 973          ]
 974          chunks.examples = example_messages
 975  
 976          self.summarize_end()
 977          chunks.done = self.done_messages
 978  
 979          chunks.repo = self.get_repo_messages()
 980          chunks.readonly_files = self.get_readonly_files_messages()
 981          chunks.chat_files = self.get_chat_files_messages()
 982  
 983          if self.gpt_prompts.system_reminder:
 984              reminder_message = [
 985                  dict(
 986                      role="system", content=self.fmt_system_prompt(self.gpt_prompts.system_reminder)
 987                  ),
 988              ]
 989          else:
 990              reminder_message = []
 991  
 992          chunks.cur = list(self.cur_messages)
 993          chunks.reminder = []
 994  
 995          # TODO review impact of token count on image messages
 996          messages_tokens = self.main_model.token_count(chunks.all_messages())
 997          reminder_tokens = self.main_model.token_count(reminder_message)
 998          cur_tokens = self.main_model.token_count(chunks.cur)
 999  
1000          if None not in (messages_tokens, reminder_tokens, cur_tokens):
1001              total_tokens = messages_tokens + reminder_tokens + cur_tokens
1002          else:
1003              # add the reminder anyway
1004              total_tokens = 0
1005  
1006          final = chunks.cur[-1]
1007  
1008          max_input_tokens = self.main_model.info.get("max_input_tokens") or 0
1009          # Add the reminder prompt if we still have room to include it.
1010          if (
1011              max_input_tokens is None
1012              or total_tokens < max_input_tokens
1013              and self.gpt_prompts.system_reminder
1014          ):
1015              if self.main_model.reminder == "sys":
1016                  chunks.reminder = reminder_message
1017              elif self.main_model.reminder == "user" and final["role"] == "user":
1018                  # stuff it into the user message
1019                  new_content = (
1020                      final["content"]
1021                      + "\n\n"
1022                      + self.fmt_system_prompt(self.gpt_prompts.system_reminder)
1023                  )
1024                  chunks.cur[-1] = dict(role=final["role"], content=new_content)
1025  
1026          return chunks
1027  
1028      def format_messages(self):
1029          chunks = self.format_chat_chunks()
1030          if self.add_cache_headers:
1031              chunks.add_cache_control_headers()
1032  
1033          return chunks
1034  
1035      def warm_cache(self, chunks):
1036          if not self.add_cache_headers:
1037              return
1038          if not self.num_cache_warming_pings:
1039              return
1040  
1041          delay = 5 * 60 - 5
1042          self.next_cache_warm = time.time() + delay
1043          self.warming_pings_left = self.num_cache_warming_pings
1044          self.cache_warming_chunks = chunks
1045  
1046          if self.cache_warming_thread:
1047              return
1048  
1049          def warm_cache_worker():
1050              while True:
1051                  time.sleep(1)
1052                  if self.warming_pings_left <= 0:
1053                      continue
1054                  now = time.time()
1055                  if now < self.next_cache_warm:
1056                      continue
1057  
1058                  self.warming_pings_left -= 1
1059                  self.next_cache_warm = time.time() + delay
1060  
1061                  try:
1062                      completion = litellm.completion(
1063                          model=self.main_model.name,
1064                          messages=self.cache_warming_chunks.cacheable_messages(),
1065                          stream=False,
1066                          max_tokens=1,
1067                          extra_headers=self.main_model.extra_headers,
1068                      )
1069                  except Exception as err:
1070                      self.io.tool_warning(f"Cache warming error: {str(err)}")
1071                      continue
1072  
1073                  cache_hit_tokens = getattr(
1074                      completion.usage, "prompt_cache_hit_tokens", 0
1075                  ) or getattr(completion.usage, "cache_read_input_tokens", 0)
1076  
1077                  if self.verbose:
1078                      self.io.tool_output(f"Warmed {format_tokens(cache_hit_tokens)} cached tokens.")
1079  
1080          self.cache_warming_thread = threading.Timer(0, warm_cache_worker)
1081          self.cache_warming_thread.daemon = True
1082          self.cache_warming_thread.start()
1083  
1084          return chunks
1085  
    def send_message(self, inp):
        """Send ``inp`` to the LLM and handle the reply (generator).

        Appends ``inp`` as a user message, builds the prompt and runs
        ``self.send``, re-yielding whatever it yields.

        Error handling around ``self.send``:
          * ``retry_exceptions()`` errors are retried with exponential back-off.
          * ``FinishReasonLength`` continues the reply through assistant
            prefill when the model supports it; otherwise the context is
            treated as exhausted.
          * ``BadRequestError`` and unexpected errors abort the message.

        After a complete reply it applies edits and auto-commits. When files
        were edited it may also lint and test them, and it runs shell commands
        from the reply. It can set ``self.reflected_message`` (lint/test errors
        or newly mentioned files), presumably for the caller to send as a
        follow-up turn.

        Args:
            inp: The user's message text.
        """
        self.cur_messages += [
            dict(role="user", content=inp),
        ]

        chunks = self.format_messages()
        messages = chunks.all_messages()
        # (Re)arm background cache-warming pings for this prompt (no-op unless enabled).
        self.warm_cache(chunks)

        if self.verbose:
            utils.show_messages(messages, functions=self.functions)

        self.multi_response_content = ""
        if self.show_pretty() and self.stream:
            mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
            self.mdstream = MarkdownStream(mdargs=mdargs)
        else:
            self.mdstream = None

        # Doubled before every retry (first wait is 0.25s); retrying stops once
        # the next wait would exceed 60s.
        retry_delay = 0.125

        self.usage_report = None
        exhausted = False
        interrupted = False
        try:
            while True:
                try:
                    yield from self.send(messages, functions=self.functions)
                    break
                except retry_exceptions() as err:
                    self.io.tool_warning(str(err))
                    retry_delay *= 2
                    if retry_delay > 60:
                        # Out of retries: fall through and process whatever
                        # (possibly empty) reply was received.
                        break
                    self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...")
                    time.sleep(retry_delay)
                    continue
                except KeyboardInterrupt:
                    interrupted = True
                    break
                except litellm.ContextWindowExceededError:
                    # The input is overflowing the context window!
                    exhausted = True
                    break
                except litellm.exceptions.BadRequestError as br_err:
                    # Not retryable: report and abandon this message (the
                    # finally block below still runs).
                    self.io.tool_error(f"BadRequestError: {br_err}")
                    return
                except FinishReasonLength:
                    # We hit the output limit!
                    if not self.main_model.info.get("supports_assistant_prefill"):
                        exhausted = True
                        break

                    # Send the reply so far back as an assistant prefix and
                    # loop, so the model continues where it stopped.
                    self.multi_response_content = self.get_multi_response_content()

                    if messages[-1]["role"] == "assistant":
                        messages[-1]["content"] = self.multi_response_content
                    else:
                        messages.append(
                            dict(role="assistant", content=self.multi_response_content, prefix=True)
                        )
                except Exception as err:
                    self.io.tool_error(f"Unexpected error: {err}")
                    lines = traceback.format_exception(type(err), err, err.__traceback__)
                    self.io.tool_error("".join(lines))
                    return
        finally:
            # Always close the live markdown view and collect the full
            # (possibly multi-part) reply.
            if self.mdstream:
                self.live_incremental_response(True)
                self.mdstream = None

            self.partial_response_content = self.get_multi_response_content(True)
            self.multi_response_content = ""

        self.io.tool_output()

        self.show_usage_report()

        if exhausted:
            self.show_exhausted_error()
            self.num_exhausted_context_windows += 1
            return

        # Text used for file-mention detection below.
        if self.partial_response_function_call:
            args = self.parse_partial_args()
            if args:
                content = args.get("explanation") or ""
            else:
                content = ""
        elif self.partial_response_content:
            content = self.partial_response_content
        else:
            content = ""

        if interrupted:
            # Keep the partial reply in the conversation but do not apply edits.
            content += "\n^C KeyboardInterrupt"
            self.cur_messages += [dict(role="assistant", content=content)]
            return

        # Apply file edits from the reply; ``edited`` is an iterable of edited
        # files (falsy when nothing was edited).
        edited = self.apply_updates()

        self.update_cur_messages()

        if edited:
            self.aider_edited_files.update(edited)
            saved_message = self.auto_commit(edited)

            if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
                saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo

            # Archive this exchange (plus the commit/no-repo note) into history.
            self.move_back_cur_messages(saved_message)

        # A follow-up is already pending (presumably set by apply_updates);
        # handle it before linting/testing.
        if self.reflected_message:
            return

        if edited and self.auto_lint:
            lint_errors = self.lint_edited(edited)
            self.auto_commit(edited, context="Ran the linter")
            self.lint_outcome = not lint_errors
            if lint_errors:
                ok = self.io.confirm_ask("Attempt to fix lint errors?")
                if ok:
                    self.reflected_message = lint_errors
                    # NOTE(review): move_back_cur_messages above already moved
                    # the assistant reply into done_messages, so this appears to
                    # add it to cur_messages a second time — confirm intended.
                    self.update_cur_messages()
                    return

        # Output of shell commands from the reply is added as conversation context.
        shared_output = self.run_shell_commands()
        if shared_output:
            self.cur_messages += [
                dict(role="user", content=shared_output),
                dict(role="assistant", content="Ok"),
            ]

        if edited and self.auto_test:
            test_errors = self.commands.cmd_test(self.test_cmd)
            self.test_outcome = not test_errors
            if test_errors:
                ok = self.io.confirm_ask("Attempt to fix test errors?")
                if ok:
                    self.reflected_message = test_errors
                    # NOTE(review): same possible duplication as in the lint branch.
                    self.update_cur_messages()
                    return

        # Offer to add files the reply mentions; the resulting note is appended
        # to any pending reflected message.
        add_rel_files_message = self.check_for_file_mentions(content)
        if add_rel_files_message:
            if self.reflected_message:
                self.reflected_message += "\n\n" + add_rel_files_message
            else:
                self.reflected_message = add_rel_files_message
1235  
1236      def show_exhausted_error(self):
1237          output_tokens = 0
1238          if self.partial_response_content:
1239              output_tokens = self.main_model.token_count(self.partial_response_content)
1240          max_output_tokens = self.main_model.info.get("max_output_tokens") or 0
1241  
1242          input_tokens = self.main_model.token_count(self.format_messages().all_messages())
1243          max_input_tokens = self.main_model.info.get("max_input_tokens") or 0
1244  
1245          total_tokens = input_tokens + output_tokens
1246  
1247          fudge = 0.7
1248  
1249          out_err = ""
1250          if output_tokens >= max_output_tokens * fudge:
1251              out_err = " -- possibly exceeded output limit!"
1252  
1253          inp_err = ""
1254          if input_tokens >= max_input_tokens * fudge:
1255              inp_err = " -- possibly exhausted context window!"
1256  
1257          tot_err = ""
1258          if total_tokens >= max_input_tokens * fudge:
1259              tot_err = " -- possibly exhausted context window!"
1260  
1261          res = ["", ""]
1262          res.append(f"Model {self.main_model.name} has hit a token limit!")
1263          res.append("Token counts below are approximate.")
1264          res.append("")
1265          res.append(f"Input tokens: ~{input_tokens:,} of {max_input_tokens:,}{inp_err}")
1266          res.append(f"Output tokens: ~{output_tokens:,} of {max_output_tokens:,}{out_err}")
1267          res.append(f"Total tokens: ~{total_tokens:,} of {max_input_tokens:,}{tot_err}")
1268  
1269          if output_tokens >= max_output_tokens:
1270              res.append("")
1271              res.append("To reduce output tokens:")
1272              res.append("- Ask for smaller changes in each request.")
1273              res.append("- Break your code into smaller source files.")
1274              if "diff" not in self.main_model.edit_format:
1275                  res.append(
1276                      "- Use a stronger model like gpt-4o, sonnet or opus that can return diffs."
1277                  )
1278  
1279          if input_tokens >= max_input_tokens or total_tokens >= max_input_tokens:
1280              res.append("")
1281              res.append("To reduce input tokens:")
1282              res.append("- Use /tokens to see token usage.")
1283              res.append("- Use /drop to remove unneeded files from the chat session.")
1284              res.append("- Use /clear to clear the chat history.")
1285              res.append("- Break your code into smaller source files.")
1286  
1287          res.append("")
1288          res.append(f"For more info: {urls.token_limits}")
1289  
1290          res = "".join([line + "\n" for line in res])
1291          self.io.tool_error(res)
1292  
1293      def lint_edited(self, fnames):
1294          res = ""
1295          for fname in fnames:
1296              errors = self.linter.lint(self.abs_root_path(fname))
1297  
1298              if errors:
1299                  res += "\n"
1300                  res += errors
1301                  res += "\n"
1302  
1303          if res:
1304              self.io.tool_warning(res)
1305  
1306          return res
1307  
1308      def update_cur_messages(self):
1309          if self.partial_response_content:
1310              self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
1311          if self.partial_response_function_call:
1312              self.cur_messages += [
1313                  dict(
1314                      role="assistant",
1315                      content=None,
1316                      function_call=self.partial_response_function_call,
1317                  )
1318              ]
1319  
1320      def get_file_mentions(self, content):
1321          words = set(word for word in content.split())
1322  
1323          # drop sentence punctuation from the end
1324          words = set(word.rstrip(",.!;:") for word in words)
1325  
1326          # strip away all kinds of quotes
1327          quotes = "".join(['"', "'", "`"])
1328          words = set(word.strip(quotes) for word in words)
1329  
1330          addable_rel_fnames = self.get_addable_relative_files()
1331  
1332          mentioned_rel_fnames = set()
1333          fname_to_rel_fnames = {}
1334          for rel_fname in addable_rel_fnames:
1335              normalized_rel_fname = rel_fname.replace("\\", "/")
1336              normalized_words = set(word.replace("\\", "/") for word in words)
1337              if normalized_rel_fname in normalized_words:
1338                  mentioned_rel_fnames.add(rel_fname)
1339  
1340              fname = os.path.basename(rel_fname)
1341  
1342              # Don't add basenames that could be plain words like "run" or "make"
1343              if "/" in fname or "\\" in fname or "." in fname or "_" in fname or "-" in fname:
1344                  if fname not in fname_to_rel_fnames:
1345                      fname_to_rel_fnames[fname] = []
1346                  fname_to_rel_fnames[fname].append(rel_fname)
1347  
1348          for fname, rel_fnames in fname_to_rel_fnames.items():
1349              if len(rel_fnames) == 1 and fname in words:
1350                  mentioned_rel_fnames.add(rel_fnames[0])
1351  
1352          return mentioned_rel_fnames
1353  
1354      def check_for_file_mentions(self, content):
1355          mentioned_rel_fnames = self.get_file_mentions(content)
1356  
1357          new_mentions = mentioned_rel_fnames - self.ignore_mentions
1358  
1359          if not new_mentions:
1360              return
1361  
1362          added_fnames = []
1363          group = ConfirmGroup(new_mentions)
1364          for rel_fname in sorted(new_mentions):
1365              if self.io.confirm_ask(f"Add {rel_fname} to the chat?", group=group):
1366                  self.add_rel_fname(rel_fname)
1367                  added_fnames.append(rel_fname)
1368              else:
1369                  self.ignore_mentions.add(rel_fname)
1370  
1371          if added_fnames:
1372              return prompts.added_files.format(fnames=", ".join(added_fnames))
1373  
    def send(self, messages, model=None, functions=None):
        """Send ``messages`` to the LLM and display the reply (generator).

        Resets ``partial_response_content`` and
        ``partial_response_function_call``, logs the request, calls
        ``send_completion`` and records the request hash in
        ``chat_completion_call_hashes``. Streamed replies go through
        ``show_send_output_stream``, whose yields are passed through;
        non-streamed ones through ``show_send_output``.

        Whatever happens, the ``finally`` block logs the reply, echoes it via
        ``io.ai_output`` and updates token/cost accounting.

        Args:
            messages: Chat messages to send.
            model: Model to use; defaults to ``self.main_model``.
            functions: Passed through unchanged to ``send_completion``.

        Raises:
            KeyboardInterrupt: re-raised after ``keyboard_interrupt`` runs
                (which may instead exit the process on a quick second ^C).
        """
        if not model:
            model = self.main_model

        self.partial_response_content = ""
        self.partial_response_function_call = dict()

        # The module-level ``format_messages`` (imported from aider.utils), not
        # the method of the same name.
        self.io.log_llm_history("TO LLM", format_messages(messages))

        completion = None
        try:
            hash_object, completion = send_completion(
                model.name,
                messages,
                functions,
                self.stream,
                self.temperature,
                extra_headers=model.extra_headers,
                max_tokens=model.max_tokens,
            )
            self.chat_completion_call_hashes.append(hash_object.hexdigest())

            if self.stream:
                yield from self.show_send_output_stream(completion)
            else:
                self.show_send_output(completion)
        except KeyboardInterrupt as kbi:
            self.keyboard_interrupt()
            raise kbi
        finally:
            self.io.log_llm_history(
                "LLM RESPONSE",
                format_content("ASSISTANT", self.partial_response_content),
            )

            if self.partial_response_content:
                self.io.ai_output(self.partial_response_content)
            elif self.partial_response_function_call:
                # TODO: push this into subclasses
                args = self.parse_partial_args()
                if args:
                    self.io.ai_output(json.dumps(args, indent=4))

            # NOTE(review): cost accounting uses self.main_model's pricing even
            # when a different ``model`` was passed in — confirm intended.
            self.calculate_and_show_tokens_and_cost(messages, completion)
1418  
1419      def show_send_output(self, completion):
1420          if self.verbose:
1421              print(completion)
1422  
1423          if not completion.choices:
1424              self.io.tool_error(str(completion))
1425              return
1426  
1427          show_func_err = None
1428          show_content_err = None
1429          try:
1430              if completion.choices[0].message.tool_calls:
1431                  self.partial_response_function_call = (
1432                      completion.choices[0].message.tool_calls[0].function
1433                  )
1434          except AttributeError as func_err:
1435              show_func_err = func_err
1436  
1437          try:
1438              self.partial_response_content = completion.choices[0].message.content or ""
1439          except AttributeError as content_err:
1440              show_content_err = content_err
1441  
1442          resp_hash = dict(
1443              function_call=str(self.partial_response_function_call),
1444              content=self.partial_response_content,
1445          )
1446          resp_hash = hashlib.sha1(json.dumps(resp_hash, sort_keys=True).encode())
1447          self.chat_completion_response_hashes.append(resp_hash.hexdigest())
1448  
1449          if show_func_err and show_content_err:
1450              self.io.tool_error(show_func_err)
1451              self.io.tool_error(show_content_err)
1452              raise Exception("No data found in LLM response!")
1453  
1454          show_resp = self.render_incremental_response(True)
1455          if self.show_pretty():
1456              show_resp = Markdown(
1457                  show_resp, style=self.assistant_output_color, code_theme=self.code_theme
1458              )
1459          else:
1460              show_resp = Text(show_resp or "<no response>")
1461  
1462          self.io.console.print(show_resp)
1463  
1464          if (
1465              hasattr(completion.choices[0], "finish_reason")
1466              and completion.choices[0].finish_reason == "length"
1467          ):
1468              raise FinishReasonLength()
1469  
1470      def show_send_output_stream(self, completion):
1471          for chunk in completion:
1472              if len(chunk.choices) == 0:
1473                  continue
1474  
1475              if (
1476                  hasattr(chunk.choices[0], "finish_reason")
1477                  and chunk.choices[0].finish_reason == "length"
1478              ):
1479                  raise FinishReasonLength()
1480  
1481              try:
1482                  func = chunk.choices[0].delta.function_call
1483                  # dump(func)
1484                  for k, v in func.items():
1485                      if k in self.partial_response_function_call:
1486                          self.partial_response_function_call[k] += v
1487                      else:
1488                          self.partial_response_function_call[k] = v
1489              except AttributeError:
1490                  pass
1491  
1492              try:
1493                  text = chunk.choices[0].delta.content
1494                  if text:
1495                      self.partial_response_content += text
1496              except AttributeError:
1497                  text = None
1498  
1499              if self.show_pretty():
1500                  self.live_incremental_response(False)
1501              elif text:
1502                  try:
1503                      sys.stdout.write(text)
1504                  except UnicodeEncodeError:
1505                      # Safely encode and decode the text
1506                      safe_text = text.encode(sys.stdout.encoding, errors="backslashreplace").decode(
1507                          sys.stdout.encoding
1508                      )
1509                      sys.stdout.write(safe_text)
1510                  sys.stdout.flush()
1511                  yield text
1512  
1513      def live_incremental_response(self, final):
1514          show_resp = self.render_incremental_response(final)
1515          self.mdstream.update(show_resp, final=final)
1516  
1517      def render_incremental_response(self, final):
1518          return self.get_multi_response_content()
1519  
1520      def calculate_and_show_tokens_and_cost(self, messages, completion=None):
1521          prompt_tokens = 0
1522          completion_tokens = 0
1523          cache_hit_tokens = 0
1524          cache_write_tokens = 0
1525  
1526          if completion and hasattr(completion, "usage") and completion.usage is not None:
1527              prompt_tokens = completion.usage.prompt_tokens
1528              completion_tokens = completion.usage.completion_tokens
1529              cache_hit_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
1530                  completion.usage, "cache_read_input_tokens", 0
1531              )
1532              cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0)
1533  
1534              if hasattr(completion.usage, "cache_read_input_tokens") or hasattr(
1535                  completion.usage, "cache_creation_input_tokens"
1536              ):
1537                  self.message_tokens_sent += prompt_tokens
1538                  self.message_tokens_sent += cache_hit_tokens
1539                  self.message_tokens_sent += cache_write_tokens
1540              else:
1541                  self.message_tokens_sent += prompt_tokens
1542  
1543          else:
1544              prompt_tokens = self.main_model.token_count(messages)
1545              completion_tokens = self.main_model.token_count(self.partial_response_content)
1546              self.message_tokens_sent += prompt_tokens
1547  
1548          self.message_tokens_received += completion_tokens
1549  
1550          tokens_report = f"Tokens: {format_tokens(self.message_tokens_sent)} sent"
1551  
1552          if cache_write_tokens:
1553              tokens_report += f", {format_tokens(cache_write_tokens)} cache write"
1554          if cache_hit_tokens:
1555              tokens_report += f", {format_tokens(cache_hit_tokens)} cache hit"
1556          tokens_report += f", {format_tokens(self.message_tokens_received)} received."
1557  
1558          if not self.main_model.info.get("input_cost_per_token"):
1559              self.usage_report = tokens_report
1560              return
1561  
1562          cost = 0
1563  
1564          input_cost_per_token = self.main_model.info.get("input_cost_per_token") or 0
1565          output_cost_per_token = self.main_model.info.get("output_cost_per_token") or 0
1566          input_cost_per_token_cache_hit = (
1567              self.main_model.info.get("input_cost_per_token_cache_hit") or 0
1568          )
1569  
1570          # deepseek
1571          # prompt_cache_hit_tokens + prompt_cache_miss_tokens
1572          #    == prompt_tokens == total tokens that were sent
1573          #
1574          # Anthropic
1575          # cache_creation_input_tokens + cache_read_input_tokens + prompt
1576          #    == total tokens that were
1577  
1578          if input_cost_per_token_cache_hit:
1579              # must be deepseek
1580              cost += input_cost_per_token_cache_hit * cache_hit_tokens
1581              cost += (prompt_tokens - input_cost_per_token_cache_hit) * input_cost_per_token
1582          else:
1583              # hard code the anthropic adjustments, no-ops for other models since cache_x_tokens==0
1584              cost += cache_write_tokens * input_cost_per_token * 1.25
1585              cost += cache_hit_tokens * input_cost_per_token * 0.10
1586              cost += prompt_tokens * input_cost_per_token
1587  
1588          cost += completion_tokens * output_cost_per_token
1589  
1590          self.total_cost += cost
1591          self.message_cost += cost
1592  
1593          def format_cost(value):
1594              if value == 0:
1595                  return "0.00"
1596              magnitude = abs(value)
1597              if magnitude >= 0.01:
1598                  return f"{value:.2f}"
1599              else:
1600                  return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
1601  
1602          cost_report = (
1603              f"Cost: ${format_cost(self.message_cost)} message,"
1604              f" ${format_cost(self.total_cost)} session."
1605          )
1606  
1607          if self.add_cache_headers and self.stream:
1608              warning = " Use --no-stream for accurate caching costs."
1609              self.usage_report = tokens_report + "\n" + cost_report + warning
1610              return
1611  
1612          if cache_hit_tokens and cache_write_tokens:
1613              sep = "\n"
1614          else:
1615              sep = " "
1616  
1617          self.usage_report = tokens_report + sep + cost_report
1618  
1619      def show_usage_report(self):
1620          if self.usage_report:
1621              self.io.tool_output(self.usage_report)
1622              self.message_cost = 0.0
1623              self.message_tokens_sent = 0
1624              self.message_tokens_received = 0
1625  
1626      def get_multi_response_content(self, final=False):
1627          cur = self.multi_response_content or ""
1628          new = self.partial_response_content or ""
1629  
1630          if new.rstrip() != new and not final:
1631              new = new.rstrip()
1632          return cur + new
1633  
1634      def get_rel_fname(self, fname):
1635          try:
1636              return os.path.relpath(fname, self.root)
1637          except ValueError:
1638              return fname
1639  
1640      def get_inchat_relative_files(self):
1641          files = [self.get_rel_fname(fname) for fname in self.abs_fnames]
1642          return sorted(set(files))
1643  
1644      def is_file_safe(self, fname):
1645          try:
1646              return Path(self.abs_root_path(fname)).is_file()
1647          except OSError:
1648              return
1649  
1650      def get_all_relative_files(self):
1651          if self.repo:
1652              files = self.repo.get_tracked_files()
1653          else:
1654              files = self.get_inchat_relative_files()
1655  
1656          # This is quite slow in large repos
1657          # files = [fname for fname in files if self.is_file_safe(fname)]
1658  
1659          return sorted(set(files))
1660  
1661      def get_all_abs_files(self):
1662          files = self.get_all_relative_files()
1663          files = [self.abs_root_path(path) for path in files]
1664          return files
1665  
1666      def get_addable_relative_files(self):
1667          all_files = set(self.get_all_relative_files())
1668          inchat_files = set(self.get_inchat_relative_files())
1669          read_only_files = set(self.get_rel_fname(fname) for fname in self.abs_read_only_fnames)
1670          return all_files - inchat_files - read_only_files
1671  
1672      def check_for_dirty_commit(self, path):
1673          if not self.repo:
1674              return
1675          if not self.dirty_commits:
1676              return
1677          if not self.repo.is_dirty(path):
1678              return
1679  
1680          # We need a committed copy of the file in order to /undo, so skip this
1681          # fullp = Path(self.abs_root_path(path))
1682          # if not fullp.stat().st_size:
1683          #     return
1684  
1685          self.io.tool_output(f"Committing {path} before applying edits.")
1686          self.need_commit_before_edits.add(path)
1687  
    def allowed_to_edit(self, path):
        """Decide whether the LLM may edit ``path``; on approval make it a chat file.

        ``path`` is resolved with ``abs_root_path``. There are three cases:

        * Already in the chat: allowed. A pre-edit commit is queued if the
          file is dirty (via ``check_for_dirty_commit``).
        * Not present on disk: the user must confirm creating it. Unless
          ``dry_run`` is set, the file is touched with ``utils.touch_file``
          and ``git add``-ed when ``repo.path_in_repo`` said it was not in the
          repo. In both cases it is then added to the chat.
        * Present on disk but not in the chat: the user must confirm. It is
          ``git add``-ed if needed, added to the chat and checked for a
          dirty commit.

        Returns True when the edit is allowed. Returns None (falsy) when the
        user declines or the file could not be created.
        """
        full_path = self.abs_root_path(path)
        # Remember whether the repo already knows this path; if not, it is
        # `git add`-ed further down once the edit is approved.
        if self.repo:
            need_to_add = not self.repo.path_in_repo(path)
        else:
            need_to_add = False

        # Already part of the chat: no confirmation needed.
        if full_path in self.abs_fnames:
            self.check_for_dirty_commit(path)
            return True

        # Brand-new file: ask before creating it.
        if not Path(full_path).exists():
            if not self.io.confirm_ask("Create new file?", subject=path):
                self.io.tool_output(f"Skipping edits to {path}")
                return

            if not self.dry_run:
                if not utils.touch_file(full_path):
                    self.io.tool_error(f"Unable to create {path}, skipping edits.")
                    return

                # Seems unlikely that we needed to create the file, but it was
                # actually already part of the repo.
                # But let's only add if we need to, just to be safe.
                if need_to_add:
                    self.repo.repo.git.add(full_path)

            # Added to the chat even in dry-run mode (the file was not created).
            self.abs_fnames.add(full_path)
            self.check_added_files()
            return True

        # Existing file outside the chat: ask before letting the LLM edit it.
        if not self.io.confirm_ask(
            "Allow edits to file that has not been added to the chat?",
            subject=path,
        ):
            self.io.tool_output(f"Skipping edits to {path}")
            return

        if need_to_add:
            self.repo.repo.git.add(full_path)

        self.abs_fnames.add(full_path)
        self.check_added_files()
        self.check_for_dirty_commit(path)

        return True
1734  
    # Class-level default. check_added_files() returns early while this is
    # true and sets it on the instance after warning, so the "too many files"
    # warning is shown at most once per coder.
    warning_given = False
1736  
1737      def check_added_files(self):
1738          if self.warning_given:
1739              return
1740  
1741          warn_number_of_files = 4
1742          warn_number_of_tokens = 20 * 1024
1743  
1744          num_files = len(self.abs_fnames)
1745          if num_files < warn_number_of_files:
1746              return
1747  
1748          tokens = 0
1749          for fname in self.abs_fnames:
1750              if is_image_file(fname):
1751                  continue
1752              content = self.io.read_text(fname)
1753              tokens += self.main_model.token_count(content)
1754  
1755          if tokens < warn_number_of_tokens:
1756              return
1757  
1758          self.io.tool_warning("Warning: it's best to only add files that need changes to the chat.")
1759          self.io.tool_warning(urls.edit_errors)
1760          self.warning_given = True
1761  
1762      def prepare_to_edit(self, edits):
1763          res = []
1764          seen = dict()
1765  
1766          self.need_commit_before_edits = set()
1767  
1768          for edit in edits:
1769              path = edit[0]
1770              if path is None:
1771                  res.append(edit)
1772                  continue
1773              if path == "python":
1774                  dump(edits)
1775              if path in seen:
1776                  allowed = seen[path]
1777              else:
1778                  allowed = self.allowed_to_edit(path)
1779                  seen[path] = allowed
1780  
1781              if allowed:
1782                  res.append(edit)
1783  
1784          self.dirty_commit()
1785          self.need_commit_before_edits = set()
1786  
1787          return res
1788  
1789      def apply_updates(self):
1790          edited = set()
1791          try:
1792              edits = self.get_edits()
1793              edits = self.prepare_to_edit(edits)
1794              edited = set(edit[0] for edit in edits)
1795              self.apply_edits(edits)
1796          except ValueError as err:
1797              self.num_malformed_responses += 1
1798  
1799              err = err.args[0]
1800  
1801              self.io.tool_error("The LLM did not conform to the edit format.")
1802              self.io.tool_output(urls.edit_errors)
1803              self.io.tool_output()
1804              self.io.tool_output(str(err))
1805  
1806              self.reflected_message = str(err)
1807              return edited
1808  
1809          except ANY_GIT_ERROR as err:
1810              self.io.tool_error(str(err))
1811              return edited
1812          except Exception as err:
1813              self.io.tool_error("Exception while updating files:")
1814              self.io.tool_error(str(err), strip=False)
1815  
1816              traceback.print_exc()
1817  
1818              self.reflected_message = str(err)
1819              return edited
1820  
1821          for path in edited:
1822              if self.dry_run:
1823                  self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
1824              else:
1825                  self.io.tool_output(f"Applied edit to {path}")
1826  
1827          return edited
1828  
1829      def parse_partial_args(self):
1830          # dump(self.partial_response_function_call)
1831  
1832          data = self.partial_response_function_call.get("arguments")
1833          if not data:
1834              return
1835  
1836          try:
1837              return json.loads(data)
1838          except JSONDecodeError:
1839              pass
1840  
1841          try:
1842              return json.loads(data + "]}")
1843          except JSONDecodeError:
1844              pass
1845  
1846          try:
1847              return json.loads(data + "}]}")
1848          except JSONDecodeError:
1849              pass
1850  
1851          try:
1852              return json.loads(data + '"}]}')
1853          except JSONDecodeError:
1854              pass
1855  
1856      # commits...
1857  
1858      def get_context_from_history(self, history):
1859          context = ""
1860          if history:
1861              for msg in history:
1862                  context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"
1863  
1864          return context
1865  
1866      def auto_commit(self, edited, context=None):
1867          if not self.repo or not self.auto_commits or self.dry_run:
1868              return
1869  
1870          if not context:
1871              context = self.get_context_from_history(self.cur_messages)
1872  
1873          try:
1874              res = self.repo.commit(fnames=edited, context=context, aider_edits=True)
1875              if res:
1876                  self.show_auto_commit_outcome(res)
1877                  commit_hash, commit_message = res
1878                  return self.gpt_prompts.files_content_gpt_edits.format(
1879                      hash=commit_hash,
1880                      message=commit_message,
1881                  )
1882  
1883              self.io.tool_output("No changes made to git tracked files.")
1884              return self.gpt_prompts.files_content_gpt_no_edits
1885          except ANY_GIT_ERROR as err:
1886              self.io.tool_error(f"Unable to commit: {str(err)}")
1887              return
1888  
1889      def show_auto_commit_outcome(self, res):
1890          commit_hash, commit_message = res
1891          self.last_aider_commit_hash = commit_hash
1892          self.aider_commit_hashes.add(commit_hash)
1893          self.last_aider_commit_message = commit_message
1894          if self.show_diffs:
1895              self.commands.cmd_diff()
1896  
1897      def show_undo_hint(self):
1898          if not self.commit_before_message:
1899              return
1900          if self.commit_before_message[-1] != self.repo.get_head_commit_sha():
1901              self.io.tool_output("You can use /undo to undo and discard each aider commit.")
1902  
1903      def dirty_commit(self):
1904          if not self.need_commit_before_edits:
1905              return
1906          if not self.dirty_commits:
1907              return
1908          if not self.repo:
1909              return
1910  
1911          self.repo.commit(fnames=self.need_commit_before_edits)
1912  
1913          # files changed, move cur messages back behind the files messages
1914          # self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
1915          return True
1916  
1917      def get_edits(self, mode="update"):
1918          return []
1919  
1920      def apply_edits(self, edits):
1921          return
1922  
1923      def run_shell_commands(self):
1924          if not self.suggest_shell_commands:
1925              return ""
1926  
1927          done = set()
1928          group = ConfirmGroup(set(self.shell_commands))
1929          accumulated_output = ""
1930          for command in self.shell_commands:
1931              if command in done:
1932                  continue
1933              done.add(command)
1934              output = self.handle_shell_commands(command, group)
1935              if output:
1936                  accumulated_output += output + "\n\n"
1937          return accumulated_output
1938  
1939      def handle_shell_commands(self, commands_str, group):
1940          commands = commands_str.strip().splitlines()
1941          command_count = sum(
1942              1 for cmd in commands if cmd.strip() and not cmd.strip().startswith("#")
1943          )
1944          prompt = "Run shell command?" if command_count == 1 else "Run shell commands?"
1945          if not self.io.confirm_ask(
1946              prompt, subject="\n".join(commands), explicit_yes_required=True, group=group
1947          ):
1948              return
1949  
1950          accumulated_output = ""
1951          for command in commands:
1952              command = command.strip()
1953              if not command or command.startswith("#"):
1954                  continue
1955  
1956              self.io.tool_output()
1957              self.io.tool_output(f"Running {command}")
1958              # Add the command to input history
1959              self.io.add_to_input_history(f"/run {command.strip()}")
1960              exit_status, output = run_cmd(command, error_print=self.io.tool_error)
1961              if output:
1962                  accumulated_output += f"Output from {command}\n{output}\n"
1963  
1964          if accumulated_output.strip() and not self.io.confirm_ask(
1965              "Add command output to the chat?"
1966          ):
1967              accumulated_output = ""
1968  
1969          return accumulated_output