base_coder.py
#!/usr/bin/env python

import base64
import hashlib
import json
import locale
import math
import mimetypes
import os
import platform
import re
import sys
import threading
import time
import traceback
from collections import defaultdict
from datetime import datetime
from json.decoder import JSONDecodeError
from pathlib import Path

from rich.console import Console, Text
from rich.markdown import Markdown

from aider import __version__, models, prompts, urls, utils
from aider.commands import Commands
from aider.history import ChatSummary
from aider.io import ConfirmGroup, InputOutput
from aider.linter import Linter
from aider.llm import litellm
from aider.mdstream import MarkdownStream
from aider.repo import ANY_GIT_ERROR, GitRepo
from aider.repomap import RepoMap
from aider.run_cmd import run_cmd
from aider.sendchat import retry_exceptions, send_completion
from aider.utils import format_content, format_messages, format_tokens, is_image_file

from ..dump import dump  # noqa: F401
from .chat_chunks import ChatChunks


class MissingAPIKeyError(ValueError):
    pass


class FinishReasonLength(Exception):
    pass


def wrap_fence(name):
    return f"<{name}>", f"</{name}>"


all_fences = [
    ("``" + "`", "``" + "`"),
    wrap_fence("source"),
    wrap_fence("code"),
    wrap_fence("pre"),
    wrap_fence("codeblock"),
    wrap_fence("sourcecode"),
]


class Coder:
    abs_fnames = None
    abs_read_only_fnames = None
    repo = None
    last_aider_commit_hash = None
    aider_edited_files = None
    last_asked_for_commit_time = 0
    repo_map = None
    functions = None
    num_exhausted_context_windows = 0
    num_malformed_responses = 0
    last_keyboard_interrupt = None
    num_reflections = 0
    max_reflections = 3
    edit_format = None
    yield_stream = False
    temperature = 0
    auto_lint = True
    auto_test = False
    test_cmd = None
    lint_outcome = None
    test_outcome = None
    multi_response_content = ""
    partial_response_content = ""
    commit_before_message = []
    message_cost = 0.0
    message_tokens_sent = 0
    message_tokens_received = 0
    add_cache_headers = False
    cache_warming_thread = None
    num_cache_warming_pings = 0
    suggest_shell_commands = True
    ignore_mentions = None
    chat_language = None

    @classmethod
    def create(
        cls,
        main_model=None,
        edit_format=None,
        io=None,
        from_coder=None,
        summarize_from_coder=True,
        **kwargs,
    ):
        import aider.coders as coders

        if not main_model:
            if from_coder:
                main_model = from_coder.main_model
            else:
                main_model = models.Model(models.DEFAULT_MODEL_NAME)

        if edit_format == "code":
            edit_format = None
        if edit_format is None:
            if from_coder:
                edit_format = from_coder.edit_format
            else:
                edit_format = main_model.edit_format

        if not io and from_coder:
            io = from_coder.io

        if from_coder:
            use_kwargs = dict(from_coder.original_kwargs)  # copy orig kwargs

            # If the edit format changes, we can't leave old ASSISTANT
            # messages in the chat history. The old edit format will
            # confuse the new LLM. It may try to imitate it, disobeying
            # the system prompt.
            done_messages = from_coder.done_messages
            if edit_format != from_coder.edit_format and done_messages and summarize_from_coder:
                done_messages = from_coder.summarizer.summarize_all(done_messages)

            # Bring along context from the old Coder
            update = dict(
                fnames=list(from_coder.abs_fnames),
                read_only_fnames=list(from_coder.abs_read_only_fnames),  # Copy read-only files
                done_messages=done_messages,
                cur_messages=from_coder.cur_messages,
                aider_commit_hashes=from_coder.aider_commit_hashes,
                commands=from_coder.commands.clone(),
                total_cost=from_coder.total_cost,
            )

            use_kwargs.update(update)  # override to complete the switch
            use_kwargs.update(kwargs)  # override passed kwargs

            kwargs = use_kwargs

        for coder in coders.__all__:
            if hasattr(coder, "edit_format") and coder.edit_format == edit_format:
                res = coder(main_model, io, **kwargs)
                res.original_kwargs = dict(kwargs)
                return res

        raise ValueError(f"Unknown edit format {edit_format}")

    def clone(self, **kwargs):
        new_coder = Coder.create(from_coder=self, **kwargs)
        new_coder.ignore_mentions = self.ignore_mentions
        return new_coder

    def get_announcements(self):
        lines = []
        lines.append(f"Aider v{__version__}")

        # Model
        main_model = self.main_model
        weak_model = main_model.weak_model

        if weak_model is not main_model:
            prefix = "Main model"
        else:
            prefix = "Model"

        output = f"{prefix}: {main_model.name} with {self.edit_format} edit format"
        if self.add_cache_headers or main_model.caches_by_default:
            output += ", prompt cache"
        if main_model.info.get("supports_assistant_prefill"):
            output += ", infinite output"
        lines.append(output)

        if weak_model is not main_model:
            output = f"Weak model: {weak_model.name}"
            lines.append(output)

        # Repo
        if self.repo:
            rel_repo_dir = self.repo.get_rel_repo_dir()
            num_files = len(self.repo.get_tracked_files())

            lines.append(f"Git repo: {rel_repo_dir} with {num_files:,} files")
            if num_files > 1000:
                lines.append(
                    "Warning: For large repos, consider using --subtree-only and .aiderignore"
                )
                lines.append(f"See: {urls.large_repos}")
        else:
            lines.append("Git repo: none")

        # Repo-map
        if self.repo_map:
            map_tokens = self.repo_map.max_map_tokens
            if map_tokens > 0:
                refresh = self.repo_map.refresh
                lines.append(f"Repo-map: using {map_tokens} tokens, {refresh} refresh")
                max_map_tokens = 2048
                if map_tokens > max_map_tokens:
                    lines.append(
                        f"Warning: map-tokens > {max_map_tokens} is not recommended as too much"
                        " irrelevant code can confuse LLMs."
                    )
            else:
                lines.append("Repo-map: disabled because map_tokens == 0")
        else:
            lines.append("Repo-map: disabled")

        # Files
        for fname in self.get_inchat_relative_files():
            lines.append(f"Added {fname} to the chat.")

        if self.done_messages:
            lines.append("Restored previous conversation history.")

        return lines
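    # A minimal usage sketch (illustrative only, not part of the module's
    # API): the `create()` factory picks the Coder subclass whose
    # `edit_format` matches, and `clone()` re-creates a coder while carrying
    # chat context across an edit-format switch. The model name, filenames,
    # and edit format below are hypothetical.
    #
    #   from aider import models
    #   from aider.io import InputOutput
    #
    #   io = InputOutput()
    #   model = models.Model("gpt-4o")
    #   coder = Coder.create(main_model=model, io=io, fnames=["hello.py"])
    #   ask_coder = coder.clone(edit_format="ask")  # keeps history, new format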
    def __init__(
        self,
        main_model,
        io,
        repo=None,
        fnames=None,
        read_only_fnames=None,
        show_diffs=False,
        auto_commits=True,
        dirty_commits=True,
        dry_run=False,
        map_tokens=1024,
        verbose=False,
        assistant_output_color="blue",
        code_theme="default",
        stream=True,
        use_git=True,
        cur_messages=None,
        done_messages=None,
        restore_chat_history=False,
        auto_lint=True,
        auto_test=False,
        lint_cmds=None,
        test_cmd=None,
        aider_commit_hashes=None,
        map_mul_no_files=8,
        commands=None,
        summarizer=None,
        total_cost=0.0,
        map_refresh="auto",
        cache_prompts=False,
        num_cache_warming_pings=0,
        suggest_shell_commands=True,
        chat_language=None,
    ):
        self.chat_language = chat_language
        self.commit_before_message = []
        self.aider_commit_hashes = set()
        self.rejected_urls = set()
        self.abs_root_path_cache = {}
        self.ignore_mentions = set()

        self.suggest_shell_commands = suggest_shell_commands

        self.num_cache_warming_pings = num_cache_warming_pings

        if not fnames:
            fnames = []

        if io is None:
            io = InputOutput()

        if aider_commit_hashes:
            self.aider_commit_hashes = aider_commit_hashes
        else:
            self.aider_commit_hashes = set()

        self.chat_completion_call_hashes = []
        self.chat_completion_response_hashes = []
        self.need_commit_before_edits = set()

        self.total_cost = total_cost

        self.verbose = verbose
        self.abs_fnames = set()
        self.abs_read_only_fnames = set()

        if cur_messages:
            self.cur_messages = cur_messages
        else:
            self.cur_messages = []

        if done_messages:
            self.done_messages = done_messages
        else:
            self.done_messages = []

        self.io = io
        self.stream = stream

        self.shell_commands = []

        if not auto_commits:
            dirty_commits = False

        self.auto_commits = auto_commits
        self.dirty_commits = dirty_commits
        self.assistant_output_color = assistant_output_color
        self.code_theme = code_theme

        self.dry_run = dry_run
        self.pretty = self.io.pretty

        if self.pretty:
            self.console = Console()
        else:
            self.console = Console(force_terminal=False, no_color=True)

        self.main_model = main_model

        if cache_prompts and self.main_model.cache_control:
            self.add_cache_headers = True

        self.show_diffs = show_diffs

        self.commands = commands or Commands(self.io, self)
        self.commands.coder = self

        self.repo = repo
        if use_git and self.repo is None:
            try:
                self.repo = GitRepo(
                    self.io,
                    fnames,
                    None,
                    models=main_model.commit_message_models(),
                )
            except FileNotFoundError:
                pass

        if self.repo:
            self.root = self.repo.root

        for fname in fnames:
            fname = Path(fname)
            if self.repo and self.repo.ignored_file(fname):
                self.io.tool_warning(f"Skipping {fname} that matches aiderignore spec.")
                continue

            if not fname.exists():
                if utils.touch_file(fname):
                    self.io.tool_output(f"Creating empty file {fname}")
                else:
                    self.io.tool_warning(f"Cannot create {fname}, skipping.")
                    continue
            if not fname.is_file():
                self.io.tool_warning(f"Skipping {fname} that is not a normal file.")
                continue

            fname = str(fname.resolve())

            self.abs_fnames.add(fname)
            self.check_added_files()

        if not self.repo:
            self.root = utils.find_common_root(self.abs_fnames)

        if read_only_fnames:
            self.abs_read_only_fnames = set()
            for fname in read_only_fnames:
                abs_fname = self.abs_root_path(fname)
                if os.path.exists(abs_fname):
                    self.abs_read_only_fnames.add(abs_fname)
                else:
                    self.io.tool_warning(f"Error: Read-only file {fname} does not exist. Skipping.")

        if map_tokens is None:
            use_repo_map = main_model.use_repo_map
            map_tokens = 1024
        else:
            use_repo_map = map_tokens > 0

        max_inp_tokens = self.main_model.info.get("max_input_tokens") or 0

        has_map_prompt = hasattr(self, "gpt_prompts") and self.gpt_prompts.repo_content_prefix

        if use_repo_map and self.repo and has_map_prompt:
            self.repo_map = RepoMap(
                map_tokens,
                self.root,
                self.main_model,
                io,
                self.gpt_prompts.repo_content_prefix,
                self.verbose,
                max_inp_tokens,
                map_mul_no_files=map_mul_no_files,
                refresh=map_refresh,
            )

        self.summarizer = summarizer or ChatSummary(
            [self.main_model.weak_model, self.main_model],
            self.main_model.max_chat_history_tokens,
        )

        self.summarizer_thread = None
        self.summarized_done_messages = []

        if not self.done_messages and restore_chat_history:
            history_md = self.io.read_text(self.io.chat_history_file)
            if history_md:
                self.done_messages = utils.split_chat_history_markdown(history_md)
                self.summarize_start()

        # Linting and testing
        self.linter = Linter(root=self.root, encoding=io.encoding)
        self.auto_lint = auto_lint
        self.setup_lint_cmds(lint_cmds)
        self.lint_cmds = lint_cmds
        self.auto_test = auto_test
        self.test_cmd = test_cmd

        # validate the functions jsonschema
        if self.functions:
            from jsonschema import Draft7Validator

            for function in self.functions:
                Draft7Validator.check_schema(function)

            if self.verbose:
                self.io.tool_output("JSON Schema:")
                self.io.tool_output(json.dumps(self.functions, indent=4))

    def setup_lint_cmds(self, lint_cmds):
        if not lint_cmds:
            return
        for lang, cmd in lint_cmds.items():
            self.linter.set_linter(lang, cmd)

    def show_announcements(self):
        bold = True
        for line in self.get_announcements():
            self.io.tool_output(line, bold=bold)
            bold = False

    def add_rel_fname(self, rel_fname):
        self.abs_fnames.add(self.abs_root_path(rel_fname))
        self.check_added_files()

    def drop_rel_fname(self, fname):
        abs_fname = self.abs_root_path(fname)
        if abs_fname in self.abs_fnames:
            self.abs_fnames.remove(abs_fname)
            return True

    def abs_root_path(self, path):
        key = path
        if key in self.abs_root_path_cache:
            return self.abs_root_path_cache[key]

        res = Path(self.root) / path
        res = utils.safe_abs_path(res)
        self.abs_root_path_cache[key] = res
        return res

    fences = all_fences
    fence = fences[0]

    def show_pretty(self):
        if not self.pretty:
            return False

        # only show pretty output if fences are the normal triple-backtick
        if self.fence != self.fences[0]:
            return False

        return True

    def get_abs_fnames_content(self):
        for fname in list(self.abs_fnames):
            content = self.io.read_text(fname)

            if content is None:
                relative_fname = self.get_rel_fname(fname)
                self.io.tool_warning(f"Dropping {relative_fname} from the chat.")
                self.abs_fnames.remove(fname)
            else:
                yield fname, content
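    # Illustrative note (not executed): choose_fence() below scans the
    # contents of every chat file and picks the first fence pair from
    # `fences` that does not already occur in any file, so the LLM's fenced
    # file listings can't collide with the code itself. For example, a file
    # that contains triple-backticks would push the fence to
    # <source>...</source>.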
    def choose_fence(self):
        all_content = ""
        for _fname, content in self.get_abs_fnames_content():
            all_content += content + "\n"
        for _fname in self.abs_read_only_fnames:
            content = self.io.read_text(_fname)
            if content is not None:
                all_content += content + "\n"

        good = False
        for fence_open, fence_close in self.fences:
            if fence_open in all_content or fence_close in all_content:
                continue
            good = True
            break

        if good:
            self.fence = (fence_open, fence_close)
        else:
            self.fence = self.fences[0]
            self.io.tool_warning(
                "Unable to find a fencing strategy! Falling back to:"
                f" {self.fence[0]}...{self.fence[1]}"
            )

        return

    def get_files_content(self, fnames=None):
        if not fnames:
            fnames = self.abs_fnames

        prompt = ""
        for fname, content in self.get_abs_fnames_content():
            if not is_image_file(fname):
                relative_fname = self.get_rel_fname(fname)
                prompt += "\n"
                prompt += relative_fname
                prompt += f"\n{self.fence[0]}\n"

                prompt += content

                # lines = content.splitlines(keepends=True)
                # lines = [f"{i+1:03}:{line}" for i, line in enumerate(lines)]
                # prompt += "".join(lines)

                prompt += f"{self.fence[1]}\n"

        return prompt

    def get_read_only_files_content(self):
        prompt = ""
        for fname in self.abs_read_only_fnames:
            content = self.io.read_text(fname)
            if content is not None and not is_image_file(fname):
                relative_fname = self.get_rel_fname(fname)
                prompt += "\n"
                prompt += relative_fname
                prompt += f"\n{self.fence[0]}\n"
                prompt += content
                prompt += f"{self.fence[1]}\n"
        return prompt

    def get_cur_message_text(self):
        text = ""
        for msg in self.cur_messages:
            text += msg["content"] + "\n"
        return text

    def get_ident_mentions(self, text):
        # Split the string on any character that is not alphanumeric
        # \W+ matches one or more non-word characters (equivalent to [^a-zA-Z0-9_]+)
        words = set(re.split(r"\W+", text))
        return words

    def get_ident_filename_matches(self, idents):
        all_fnames = defaultdict(set)
        for fname in self.get_all_relative_files():
            base = Path(fname).with_suffix("").name.lower()
            if len(base) >= 5:
                all_fnames[base].add(fname)

        matches = set()
        for ident in idents:
            if len(ident) < 5:
                continue
            matches.update(all_fnames[ident.lower()])

        return matches

    def get_repo_map(self, force_refresh=False):
        if not self.repo_map:
            return

        cur_msg_text = self.get_cur_message_text()
        mentioned_fnames = self.get_file_mentions(cur_msg_text)
        mentioned_idents = self.get_ident_mentions(cur_msg_text)

        mentioned_fnames.update(self.get_ident_filename_matches(mentioned_idents))

        all_abs_files = set(self.get_all_abs_files())
        repo_abs_read_only_fnames = set(self.abs_read_only_fnames) & all_abs_files
        chat_files = set(self.abs_fnames) | repo_abs_read_only_fnames
        other_files = all_abs_files - chat_files

        repo_content = self.repo_map.get_repo_map(
            chat_files,
            other_files,
            mentioned_fnames=mentioned_fnames,
            mentioned_idents=mentioned_idents,
            force_refresh=force_refresh,
        )
        # fall back to global repo map if files in chat are disjoint from rest of repo
        if not repo_content:
            repo_content = self.repo_map.get_repo_map(
                set(),
                all_abs_files,
                mentioned_fnames=mentioned_fnames,
                mentioned_idents=mentioned_idents,
            )

        # fall back to completely unhinted repo
        if not repo_content:
            repo_content = self.repo_map.get_repo_map(
                set(),
                all_abs_files,
            )

        return repo_content

    def get_repo_messages(self):
        repo_messages = []
        repo_content = self.get_repo_map()
        if repo_content:
            repo_messages += [
                dict(role="user", content=repo_content),
                dict(
                    role="assistant",
                    content="Ok, I won't try and edit those files without asking first.",
                ),
            ]
        return repo_messages

    def get_readonly_files_messages(self):
        readonly_messages = []
        read_only_content = self.get_read_only_files_content()
        if read_only_content:
            readonly_messages += [
                dict(
                    role="user", content=self.gpt_prompts.read_only_files_prefix + read_only_content
                ),
                dict(
                    role="assistant",
                    content="Ok, I will use these files as references.",
                ),
            ]
        return readonly_messages

    def get_chat_files_messages(self):
        chat_files_messages = []
        if self.abs_fnames:
            files_content = self.gpt_prompts.files_content_prefix
            files_content += self.get_files_content()
            files_reply = "Ok, any changes I propose will be to those files."
        elif self.get_repo_map() and self.gpt_prompts.files_no_full_files_with_repo_map:
            files_content = self.gpt_prompts.files_no_full_files_with_repo_map
            files_reply = self.gpt_prompts.files_no_full_files_with_repo_map_reply
        else:
            files_content = self.gpt_prompts.files_no_full_files
            files_reply = "Ok."

        if files_content:
            chat_files_messages += [
                dict(role="user", content=files_content),
                dict(role="assistant", content=files_reply),
            ]

        images_message = self.get_images_message()
        if images_message is not None:
            chat_files_messages += [
                images_message,
                dict(role="assistant", content="Ok."),
            ]

        return chat_files_messages

    def get_images_message(self):
        if not self.main_model.accepts_images:
            return None

        image_messages = []
        for fname, content in self.get_abs_fnames_content():
            if is_image_file(fname):
                with open(fname, "rb") as image_file:
                    encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
                mime_type, _ = mimetypes.guess_type(fname)
                if mime_type and mime_type.startswith("image/"):
                    image_url = f"data:{mime_type};base64,{encoded_string}"
                    rel_fname = self.get_rel_fname(fname)
                    image_messages += [
                        {"type": "text", "text": f"Image file: {rel_fname}"},
                        {"type": "image_url", "image_url": {"url": image_url, "detail": "high"}},
                    ]

        if not image_messages:
            return None

        return {"role": "user", "content": image_messages}

    def run_stream(self, user_message):
        self.io.user_input(user_message)
        self.init_before_message()
        yield from self.send_message(user_message)

    def init_before_message(self):
        self.aider_edited_files = set()
        self.reflected_message = None
        self.num_reflections = 0
        self.lint_outcome = None
        self.test_outcome = None
        self.shell_commands = []

        if self.repo:
            self.commit_before_message.append(self.repo.get_head_commit_sha())

    def run(self, with_message=None, preproc=True):
        try:
            if with_message:
                self.io.user_input(with_message)
                self.run_one(with_message, preproc)
                return self.partial_response_content

            while True:
                try:
                    user_message = self.get_input()
                    self.run_one(user_message, preproc)
                    self.show_undo_hint()
                except KeyboardInterrupt:
                    self.keyboard_interrupt()
        except EOFError:
            return
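    # Hypothetical usage sketch (illustrative only): `run()` either processes
    # a single scripted message and returns the assistant's reply, or loops
    # on interactive input until EOF. The coder construction mirrors the
    # `create()` example above; the prompt text is made up.
    #
    #   coder = Coder.create(main_model=model, io=io, fnames=["hello.py"])
    #   reply = coder.run(with_message="add a docstring to hello()")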
    def get_input(self):
        inchat_files = self.get_inchat_relative_files()
        read_only_files = [self.get_rel_fname(fname) for fname in self.abs_read_only_fnames]
        all_files = sorted(set(inchat_files + read_only_files))
        edit_format = "" if self.edit_format == self.main_model.edit_format else self.edit_format
        return self.io.get_input(
            self.root,
            all_files,
            self.get_addable_relative_files(),
            self.commands,
            self.abs_read_only_fnames,
            edit_format=edit_format,
        )

    def preproc_user_input(self, inp):
        if not inp:
            return

        if self.commands.is_command(inp):
            return self.commands.run(inp)

        self.check_for_file_mentions(inp)
        self.check_for_urls(inp)

        return inp

    def run_one(self, user_message, preproc):
        self.init_before_message()

        if preproc:
            message = self.preproc_user_input(user_message)
        else:
            message = user_message

        while message:
            self.reflected_message = None
            list(self.send_message(message))

            if not self.reflected_message:
                break

            if self.num_reflections >= self.max_reflections:
                self.io.tool_warning(f"Only {self.max_reflections} reflections allowed, stopping.")
                return

            self.num_reflections += 1
            message = self.reflected_message

    def check_for_urls(self, inp):
        url_pattern = re.compile(r"(https?://[^\s/$.?#].[^\s]*[^\s,.])")
        urls = list(set(url_pattern.findall(inp)))  # Use set to remove duplicates
        added_urls = []
        group = ConfirmGroup(urls)
        for url in urls:
            if url not in self.rejected_urls:
                if self.io.confirm_ask("Add URL to the chat?", subject=url, group=group):
                    inp += "\n\n"
                    inp += self.commands.cmd_web(url)
                    added_urls.append(url)
                else:
                    self.rejected_urls.add(url)

        return added_urls

    def keyboard_interrupt(self):
        now = time.time()

        thresh = 2  # seconds
        if self.last_keyboard_interrupt and now - self.last_keyboard_interrupt < thresh:
            self.io.tool_warning("\n\n^C KeyboardInterrupt")
            sys.exit()

        self.io.tool_warning("\n\n^C again to exit")

        self.last_keyboard_interrupt = now

    def summarize_start(self):
        if not self.summarizer.too_big(self.done_messages):
            return

        self.summarize_end()

        if self.verbose:
            self.io.tool_output("Starting to summarize chat history.")

        self.summarizer_thread = threading.Thread(target=self.summarize_worker)
        self.summarizer_thread.start()

    def summarize_worker(self):
        try:
            self.summarized_done_messages = self.summarizer.summarize(self.done_messages)
        except ValueError as err:
            self.io.tool_warning(err.args[0])

        if self.verbose:
            self.io.tool_output("Finished summarizing chat history.")

    def summarize_end(self):
        if self.summarizer_thread is None:
            return

        self.summarizer_thread.join()
        self.summarizer_thread = None

        self.done_messages = self.summarized_done_messages
        self.summarized_done_messages = []

    def move_back_cur_messages(self, message):
        self.done_messages += self.cur_messages
        self.summarize_start()

        # TODO check for impact on image messages
        if message:
            self.done_messages += [
                dict(role="user", content=message),
                dict(role="assistant", content="Ok."),
            ]
        self.cur_messages = []
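    # Illustrative note: chat history is kept as two lists of OpenAI-style
    # message dicts. `cur_messages` holds the in-flight exchange and
    # `done_messages` holds completed exchanges, e.g.:
    #
    #   [{"role": "user", "content": "add a test"},
    #    {"role": "assistant", "content": "Ok."}]
    #
    # When `done_messages` grows too large, summarize_start() compacts it on
    # a background thread via ChatSummary.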
dict(role="assistant", content="Ok."), 858 ] 859 self.cur_messages = [] 860 861 def get_user_language(self): 862 if self.chat_language: 863 return self.chat_language 864 865 try: 866 lang = locale.getlocale()[0] 867 if lang: 868 return lang # Return the full language code, including country 869 except Exception: 870 pass 871 872 for env_var in ["LANG", "LANGUAGE", "LC_ALL", "LC_MESSAGES"]: 873 lang = os.environ.get(env_var) 874 if lang: 875 return lang.split(".")[ 876 0 877 ] # Return language and country, but remove encoding if present 878 879 return None 880 881 def get_platform_info(self): 882 platform_text = f"- Platform: {platform.platform()}\n" 883 shell_var = "COMSPEC" if os.name == "nt" else "SHELL" 884 shell_val = os.getenv(shell_var) 885 platform_text += f"- Shell: {shell_var}={shell_val}\n" 886 887 user_lang = self.get_user_language() 888 if user_lang: 889 platform_text += f"- Language: {user_lang}\n" 890 891 dt = datetime.now().astimezone().strftime("%Y-%m-%d") 892 platform_text += f"- Current date: {dt}\n" 893 894 if self.repo: 895 platform_text += "- The user is operating inside a git repository\n" 896 897 if self.lint_cmds: 898 if self.auto_lint: 899 platform_text += ( 900 "- The user's pre-commit runs these lint commands, don't suggest running" 901 " them:\n" 902 ) 903 else: 904 platform_text += "- The user prefers these lint commands:\n" 905 for lang, cmd in self.lint_cmds.items(): 906 if lang is None: 907 platform_text += f" - {cmd}\n" 908 else: 909 platform_text += f" - {lang}: {cmd}\n" 910 911 if self.test_cmd: 912 if self.auto_test: 913 platform_text += ( 914 "- The user's pre-commit runs this test command, don't suggest running them: " 915 ) 916 else: 917 platform_text += "- The user prefers this test command: " 918 platform_text += self.test_cmd + "\n" 919 920 return platform_text 921 922 def fmt_system_prompt(self, prompt): 923 lazy_prompt = self.gpt_prompts.lazy_prompt if self.main_model.lazy else "" 924 platform_text = self.get_platform_info() 925 926 prompt = prompt.format( 927 fence=self.fence, 928 lazy_prompt=lazy_prompt, 929 platform=platform_text, 930 ) 931 return prompt 932 933 def format_chat_chunks(self): 934 self.choose_fence() 935 main_sys = self.fmt_system_prompt(self.gpt_prompts.main_system) 936 937 example_messages = [] 938 if self.main_model.examples_as_sys_msg: 939 if self.gpt_prompts.example_messages: 940 main_sys += "\n# Example conversations:\n\n" 941 for msg in self.gpt_prompts.example_messages: 942 role = msg["role"] 943 content = self.fmt_system_prompt(msg["content"]) 944 main_sys += f"## {role.upper()}: {content}\n\n" 945 main_sys = main_sys.strip() 946 else: 947 for msg in self.gpt_prompts.example_messages: 948 example_messages.append( 949 dict( 950 role=msg["role"], 951 content=self.fmt_system_prompt(msg["content"]), 952 ) 953 ) 954 if self.gpt_prompts.example_messages: 955 example_messages += [ 956 dict( 957 role="user", 958 content=( 959 "I switched to a new code base. Please don't consider the above files" 960 " or try to edit them any longer." 
    def format_chat_chunks(self):
        self.choose_fence()
        main_sys = self.fmt_system_prompt(self.gpt_prompts.main_system)

        example_messages = []
        if self.main_model.examples_as_sys_msg:
            if self.gpt_prompts.example_messages:
                main_sys += "\n# Example conversations:\n\n"
            for msg in self.gpt_prompts.example_messages:
                role = msg["role"]
                content = self.fmt_system_prompt(msg["content"])
                main_sys += f"## {role.upper()}: {content}\n\n"
            main_sys = main_sys.strip()
        else:
            for msg in self.gpt_prompts.example_messages:
                example_messages.append(
                    dict(
                        role=msg["role"],
                        content=self.fmt_system_prompt(msg["content"]),
                    )
                )
            if self.gpt_prompts.example_messages:
                example_messages += [
                    dict(
                        role="user",
                        content=(
                            "I switched to a new code base. Please don't consider the above files"
                            " or try to edit them any longer."
                        ),
                    ),
                    dict(role="assistant", content="Ok."),
                ]

        if self.gpt_prompts.system_reminder:
            main_sys += "\n" + self.fmt_system_prompt(self.gpt_prompts.system_reminder)

        chunks = ChatChunks()

        chunks.system = [
            dict(role="system", content=main_sys),
        ]
        chunks.examples = example_messages

        self.summarize_end()
        chunks.done = self.done_messages

        chunks.repo = self.get_repo_messages()
        chunks.readonly_files = self.get_readonly_files_messages()
        chunks.chat_files = self.get_chat_files_messages()

        if self.gpt_prompts.system_reminder:
            reminder_message = [
                dict(
                    role="system", content=self.fmt_system_prompt(self.gpt_prompts.system_reminder)
                ),
            ]
        else:
            reminder_message = []

        chunks.cur = list(self.cur_messages)
        chunks.reminder = []

        # TODO review impact of token count on image messages
        messages_tokens = self.main_model.token_count(chunks.all_messages())
        reminder_tokens = self.main_model.token_count(reminder_message)
        cur_tokens = self.main_model.token_count(chunks.cur)

        if None not in (messages_tokens, reminder_tokens, cur_tokens):
            total_tokens = messages_tokens + reminder_tokens + cur_tokens
        else:
            # add the reminder anyway
            total_tokens = 0

        final = chunks.cur[-1]

        max_input_tokens = self.main_model.info.get("max_input_tokens") or 0
        # Add the reminder prompt if we still have room to include it.
        if (
            max_input_tokens is None or total_tokens < max_input_tokens
        ) and self.gpt_prompts.system_reminder:
            if self.main_model.reminder == "sys":
                chunks.reminder = reminder_message
            elif self.main_model.reminder == "user" and final["role"] == "user":
                # stuff it into the user message
                new_content = (
                    final["content"]
                    + "\n\n"
                    + self.fmt_system_prompt(self.gpt_prompts.system_reminder)
                )
                chunks.cur[-1] = dict(role=final["role"], content=new_content)

        return chunks

    def format_messages(self):
        chunks = self.format_chat_chunks()
        if self.add_cache_headers:
            chunks.add_cache_control_headers()

        return chunks

    def warm_cache(self, chunks):
        if not self.add_cache_headers:
            return
        if not self.num_cache_warming_pings:
            return

        delay = 5 * 60 - 5
        self.next_cache_warm = time.time() + delay
        self.warming_pings_left = self.num_cache_warming_pings
        self.cache_warming_chunks = chunks

        if self.cache_warming_thread:
            return

        def warm_cache_worker():
            while True:
                time.sleep(1)
                if self.warming_pings_left <= 0:
                    continue
                now = time.time()
                if now < self.next_cache_warm:
                    continue

                self.warming_pings_left -= 1
                self.next_cache_warm = time.time() + delay

                try:
                    completion = litellm.completion(
                        model=self.main_model.name,
                        messages=self.cache_warming_chunks.cacheable_messages(),
                        stream=False,
                        max_tokens=1,
                        extra_headers=self.main_model.extra_headers,
                    )
                except Exception as err:
                    self.io.tool_warning(f"Cache warming error: {str(err)}")
                    continue

                cache_hit_tokens = getattr(
                    completion.usage, "prompt_cache_hit_tokens", 0
                ) or getattr(completion.usage, "cache_read_input_tokens", 0)

                if self.verbose:
                    self.io.tool_output(f"Warmed {format_tokens(cache_hit_tokens)} cached tokens.")

        self.cache_warming_thread = threading.Timer(0, warm_cache_worker)
        self.cache_warming_thread.daemon = True
        self.cache_warming_thread.start()

        return chunks
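    # Illustrative note: the warming delay of `5 * 60 - 5` seconds keeps a
    # 1-token ping arriving just inside a (presumed) five-minute provider
    # cache TTL, so the cached prompt prefix stays warm between user
    # messages.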
    def send_message(self, inp):
        self.cur_messages += [
            dict(role="user", content=inp),
        ]

        chunks = self.format_messages()
        messages = chunks.all_messages()
        self.warm_cache(chunks)

        if self.verbose:
            utils.show_messages(messages, functions=self.functions)

        self.multi_response_content = ""
        if self.show_pretty() and self.stream:
            mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
            self.mdstream = MarkdownStream(mdargs=mdargs)
        else:
            self.mdstream = None

        retry_delay = 0.125

        self.usage_report = None
        exhausted = False
        interrupted = False
        try:
            while True:
                try:
                    yield from self.send(messages, functions=self.functions)
                    break
                except retry_exceptions() as err:
                    self.io.tool_warning(str(err))
                    retry_delay *= 2
                    if retry_delay > 60:
                        break
                    self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...")
                    time.sleep(retry_delay)
                    continue
                except KeyboardInterrupt:
                    interrupted = True
                    break
                except litellm.ContextWindowExceededError:
                    # The input is overflowing the context window!
                    exhausted = True
                    break
                except litellm.exceptions.BadRequestError as br_err:
                    self.io.tool_error(f"BadRequestError: {br_err}")
                    return
                except FinishReasonLength:
                    # We hit the output limit!
                    if not self.main_model.info.get("supports_assistant_prefill"):
                        exhausted = True
                        break

                    self.multi_response_content = self.get_multi_response_content()

                    if messages[-1]["role"] == "assistant":
                        messages[-1]["content"] = self.multi_response_content
                    else:
                        messages.append(
                            dict(role="assistant", content=self.multi_response_content, prefix=True)
                        )
        except Exception as err:
            self.io.tool_error(f"Unexpected error: {err}")
            lines = traceback.format_exception(type(err), err, err.__traceback__)
            self.io.tool_error("".join(lines))
            return
        finally:
            if self.mdstream:
                self.live_incremental_response(True)
                self.mdstream = None

            self.partial_response_content = self.get_multi_response_content(True)
            self.multi_response_content = ""

            self.io.tool_output()

        self.show_usage_report()

        if exhausted:
            self.show_exhausted_error()
            self.num_exhausted_context_windows += 1
            return

        if self.partial_response_function_call:
            args = self.parse_partial_args()
            if args:
                content = args.get("explanation") or ""
            else:
                content = ""
        elif self.partial_response_content:
            content = self.partial_response_content
        else:
            content = ""

        if interrupted:
            content += "\n^C KeyboardInterrupt"
            self.cur_messages += [dict(role="assistant", content=content)]
            return

        edited = self.apply_updates()

        self.update_cur_messages()

        if edited:
            self.aider_edited_files.update(edited)
            saved_message = self.auto_commit(edited)

            if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
                saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo

            self.move_back_cur_messages(saved_message)

        if self.reflected_message:
            return
context="Ran the linter") 1204 self.lint_outcome = not lint_errors 1205 if lint_errors: 1206 ok = self.io.confirm_ask("Attempt to fix lint errors?") 1207 if ok: 1208 self.reflected_message = lint_errors 1209 self.update_cur_messages() 1210 return 1211 1212 shared_output = self.run_shell_commands() 1213 if shared_output: 1214 self.cur_messages += [ 1215 dict(role="user", content=shared_output), 1216 dict(role="assistant", content="Ok"), 1217 ] 1218 1219 if edited and self.auto_test: 1220 test_errors = self.commands.cmd_test(self.test_cmd) 1221 self.test_outcome = not test_errors 1222 if test_errors: 1223 ok = self.io.confirm_ask("Attempt to fix test errors?") 1224 if ok: 1225 self.reflected_message = test_errors 1226 self.update_cur_messages() 1227 return 1228 1229 add_rel_files_message = self.check_for_file_mentions(content) 1230 if add_rel_files_message: 1231 if self.reflected_message: 1232 self.reflected_message += "\n\n" + add_rel_files_message 1233 else: 1234 self.reflected_message = add_rel_files_message 1235 1236 def show_exhausted_error(self): 1237 output_tokens = 0 1238 if self.partial_response_content: 1239 output_tokens = self.main_model.token_count(self.partial_response_content) 1240 max_output_tokens = self.main_model.info.get("max_output_tokens") or 0 1241 1242 input_tokens = self.main_model.token_count(self.format_messages().all_messages()) 1243 max_input_tokens = self.main_model.info.get("max_input_tokens") or 0 1244 1245 total_tokens = input_tokens + output_tokens 1246 1247 fudge = 0.7 1248 1249 out_err = "" 1250 if output_tokens >= max_output_tokens * fudge: 1251 out_err = " -- possibly exceeded output limit!" 1252 1253 inp_err = "" 1254 if input_tokens >= max_input_tokens * fudge: 1255 inp_err = " -- possibly exhausted context window!" 1256 1257 tot_err = "" 1258 if total_tokens >= max_input_tokens * fudge: 1259 tot_err = " -- possibly exhausted context window!" 1260 1261 res = ["", ""] 1262 res.append(f"Model {self.main_model.name} has hit a token limit!") 1263 res.append("Token counts below are approximate.") 1264 res.append("") 1265 res.append(f"Input tokens: ~{input_tokens:,} of {max_input_tokens:,}{inp_err}") 1266 res.append(f"Output tokens: ~{output_tokens:,} of {max_output_tokens:,}{out_err}") 1267 res.append(f"Total tokens: ~{total_tokens:,} of {max_input_tokens:,}{tot_err}") 1268 1269 if output_tokens >= max_output_tokens: 1270 res.append("") 1271 res.append("To reduce output tokens:") 1272 res.append("- Ask for smaller changes in each request.") 1273 res.append("- Break your code into smaller source files.") 1274 if "diff" not in self.main_model.edit_format: 1275 res.append( 1276 "- Use a stronger model like gpt-4o, sonnet or opus that can return diffs." 
        if input_tokens >= max_input_tokens or total_tokens >= max_input_tokens:
            res.append("")
            res.append("To reduce input tokens:")
            res.append("- Use /tokens to see token usage.")
            res.append("- Use /drop to remove unneeded files from the chat session.")
            res.append("- Use /clear to clear the chat history.")
            res.append("- Break your code into smaller source files.")

        res.append("")
        res.append(f"For more info: {urls.token_limits}")

        res = "".join([line + "\n" for line in res])
        self.io.tool_error(res)

    def lint_edited(self, fnames):
        res = ""
        for fname in fnames:
            errors = self.linter.lint(self.abs_root_path(fname))

            if errors:
                res += "\n"
                res += errors
                res += "\n"

        if res:
            self.io.tool_warning(res)

        return res

    def update_cur_messages(self):
        if self.partial_response_content:
            self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
        if self.partial_response_function_call:
            self.cur_messages += [
                dict(
                    role="assistant",
                    content=None,
                    function_call=self.partial_response_function_call,
                )
            ]
in fname or "_" in fname or "-" in fname: 1344 if fname not in fname_to_rel_fnames: 1345 fname_to_rel_fnames[fname] = [] 1346 fname_to_rel_fnames[fname].append(rel_fname) 1347 1348 for fname, rel_fnames in fname_to_rel_fnames.items(): 1349 if len(rel_fnames) == 1 and fname in words: 1350 mentioned_rel_fnames.add(rel_fnames[0]) 1351 1352 return mentioned_rel_fnames 1353 1354 def check_for_file_mentions(self, content): 1355 mentioned_rel_fnames = self.get_file_mentions(content) 1356 1357 new_mentions = mentioned_rel_fnames - self.ignore_mentions 1358 1359 if not new_mentions: 1360 return 1361 1362 added_fnames = [] 1363 group = ConfirmGroup(new_mentions) 1364 for rel_fname in sorted(new_mentions): 1365 if self.io.confirm_ask(f"Add {rel_fname} to the chat?", group=group): 1366 self.add_rel_fname(rel_fname) 1367 added_fnames.append(rel_fname) 1368 else: 1369 self.ignore_mentions.add(rel_fname) 1370 1371 if added_fnames: 1372 return prompts.added_files.format(fnames=", ".join(added_fnames)) 1373 1374 def send(self, messages, model=None, functions=None): 1375 if not model: 1376 model = self.main_model 1377 1378 self.partial_response_content = "" 1379 self.partial_response_function_call = dict() 1380 1381 self.io.log_llm_history("TO LLM", format_messages(messages)) 1382 1383 completion = None 1384 try: 1385 hash_object, completion = send_completion( 1386 model.name, 1387 messages, 1388 functions, 1389 self.stream, 1390 self.temperature, 1391 extra_headers=model.extra_headers, 1392 max_tokens=model.max_tokens, 1393 ) 1394 self.chat_completion_call_hashes.append(hash_object.hexdigest()) 1395 1396 if self.stream: 1397 yield from self.show_send_output_stream(completion) 1398 else: 1399 self.show_send_output(completion) 1400 except KeyboardInterrupt as kbi: 1401 self.keyboard_interrupt() 1402 raise kbi 1403 finally: 1404 self.io.log_llm_history( 1405 "LLM RESPONSE", 1406 format_content("ASSISTANT", self.partial_response_content), 1407 ) 1408 1409 if self.partial_response_content: 1410 self.io.ai_output(self.partial_response_content) 1411 elif self.partial_response_function_call: 1412 # TODO: push this into subclasses 1413 args = self.parse_partial_args() 1414 if args: 1415 self.io.ai_output(json.dumps(args, indent=4)) 1416 1417 self.calculate_and_show_tokens_and_cost(messages, completion) 1418 1419 def show_send_output(self, completion): 1420 if self.verbose: 1421 print(completion) 1422 1423 if not completion.choices: 1424 self.io.tool_error(str(completion)) 1425 return 1426 1427 show_func_err = None 1428 show_content_err = None 1429 try: 1430 if completion.choices[0].message.tool_calls: 1431 self.partial_response_function_call = ( 1432 completion.choices[0].message.tool_calls[0].function 1433 ) 1434 except AttributeError as func_err: 1435 show_func_err = func_err 1436 1437 try: 1438 self.partial_response_content = completion.choices[0].message.content or "" 1439 except AttributeError as content_err: 1440 show_content_err = content_err 1441 1442 resp_hash = dict( 1443 function_call=str(self.partial_response_function_call), 1444 content=self.partial_response_content, 1445 ) 1446 resp_hash = hashlib.sha1(json.dumps(resp_hash, sort_keys=True).encode()) 1447 self.chat_completion_response_hashes.append(resp_hash.hexdigest()) 1448 1449 if show_func_err and show_content_err: 1450 self.io.tool_error(show_func_err) 1451 self.io.tool_error(show_content_err) 1452 raise Exception("No data found in LLM response!") 1453 1454 show_resp = self.render_incremental_response(True) 1455 if self.show_pretty(): 1456 
    def send(self, messages, model=None, functions=None):
        if not model:
            model = self.main_model

        self.partial_response_content = ""
        self.partial_response_function_call = dict()

        self.io.log_llm_history("TO LLM", format_messages(messages))

        completion = None
        try:
            hash_object, completion = send_completion(
                model.name,
                messages,
                functions,
                self.stream,
                self.temperature,
                extra_headers=model.extra_headers,
                max_tokens=model.max_tokens,
            )
            self.chat_completion_call_hashes.append(hash_object.hexdigest())

            if self.stream:
                yield from self.show_send_output_stream(completion)
            else:
                self.show_send_output(completion)
        except KeyboardInterrupt as kbi:
            self.keyboard_interrupt()
            raise kbi
        finally:
            self.io.log_llm_history(
                "LLM RESPONSE",
                format_content("ASSISTANT", self.partial_response_content),
            )

            if self.partial_response_content:
                self.io.ai_output(self.partial_response_content)
            elif self.partial_response_function_call:
                # TODO: push this into subclasses
                args = self.parse_partial_args()
                if args:
                    self.io.ai_output(json.dumps(args, indent=4))

            self.calculate_and_show_tokens_and_cost(messages, completion)

    def show_send_output(self, completion):
        if self.verbose:
            print(completion)

        if not completion.choices:
            self.io.tool_error(str(completion))
            return

        show_func_err = None
        show_content_err = None
        try:
            if completion.choices[0].message.tool_calls:
                self.partial_response_function_call = (
                    completion.choices[0].message.tool_calls[0].function
                )
        except AttributeError as func_err:
            show_func_err = func_err

        try:
            self.partial_response_content = completion.choices[0].message.content or ""
        except AttributeError as content_err:
            show_content_err = content_err

        resp_hash = dict(
            function_call=str(self.partial_response_function_call),
            content=self.partial_response_content,
        )
        resp_hash = hashlib.sha1(json.dumps(resp_hash, sort_keys=True).encode())
        self.chat_completion_response_hashes.append(resp_hash.hexdigest())

        if show_func_err and show_content_err:
            self.io.tool_error(show_func_err)
            self.io.tool_error(show_content_err)
            raise Exception("No data found in LLM response!")

        show_resp = self.render_incremental_response(True)
        if self.show_pretty():
            show_resp = Markdown(
                show_resp, style=self.assistant_output_color, code_theme=self.code_theme
            )
        else:
            show_resp = Text(show_resp or "<no response>")

        self.io.console.print(show_resp)

        if (
            hasattr(completion.choices[0], "finish_reason")
            and completion.choices[0].finish_reason == "length"
        ):
            raise FinishReasonLength()

    def show_send_output_stream(self, completion):
        for chunk in completion:
            if len(chunk.choices) == 0:
                continue

            if (
                hasattr(chunk.choices[0], "finish_reason")
                and chunk.choices[0].finish_reason == "length"
            ):
                raise FinishReasonLength()

            try:
                func = chunk.choices[0].delta.function_call
                # dump(func)
                for k, v in func.items():
                    if k in self.partial_response_function_call:
                        self.partial_response_function_call[k] += v
                    else:
                        self.partial_response_function_call[k] = v
            except AttributeError:
                pass

            try:
                text = chunk.choices[0].delta.content
                if text:
                    self.partial_response_content += text
            except AttributeError:
                text = None

            if self.show_pretty():
                self.live_incremental_response(False)
            elif text:
                try:
                    sys.stdout.write(text)
                except UnicodeEncodeError:
                    # Safely encode and decode the text
                    safe_text = text.encode(sys.stdout.encoding, errors="backslashreplace").decode(
                        sys.stdout.encoding
                    )
                    sys.stdout.write(safe_text)
                sys.stdout.flush()
                yield text

    def live_incremental_response(self, final):
        show_resp = self.render_incremental_response(final)
        self.mdstream.update(show_resp, final=final)

    def render_incremental_response(self, final):
        return self.get_multi_response_content()

    def calculate_and_show_tokens_and_cost(self, messages, completion=None):
        prompt_tokens = 0
        completion_tokens = 0
        cache_hit_tokens = 0
        cache_write_tokens = 0

        if completion and hasattr(completion, "usage") and completion.usage is not None:
            prompt_tokens = completion.usage.prompt_tokens
            completion_tokens = completion.usage.completion_tokens
            cache_hit_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
                completion.usage, "cache_read_input_tokens", 0
            )
            cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0)

            if hasattr(completion.usage, "cache_read_input_tokens") or hasattr(
                completion.usage, "cache_creation_input_tokens"
            ):
                self.message_tokens_sent += prompt_tokens
                self.message_tokens_sent += cache_hit_tokens
                self.message_tokens_sent += cache_write_tokens
            else:
                self.message_tokens_sent += prompt_tokens

        else:
            prompt_tokens = self.main_model.token_count(messages)
            completion_tokens = self.main_model.token_count(self.partial_response_content)
            self.message_tokens_sent += prompt_tokens

        self.message_tokens_received += completion_tokens

        tokens_report = f"Tokens: {format_tokens(self.message_tokens_sent)} sent"

        if cache_write_tokens:
            tokens_report += f", {format_tokens(cache_write_tokens)} cache write"
        if cache_hit_tokens:
            tokens_report += f", {format_tokens(cache_hit_tokens)} cache hit"
        tokens_report += f", {format_tokens(self.message_tokens_received)} received."
        if not self.main_model.info.get("input_cost_per_token"):
            self.usage_report = tokens_report
            return

        cost = 0

        input_cost_per_token = self.main_model.info.get("input_cost_per_token") or 0
        output_cost_per_token = self.main_model.info.get("output_cost_per_token") or 0
        input_cost_per_token_cache_hit = (
            self.main_model.info.get("input_cost_per_token_cache_hit") or 0
        )

        # deepseek
        # prompt_cache_hit_tokens + prompt_cache_miss_tokens
        # == prompt_tokens == total tokens that were sent
        #
        # Anthropic
        # cache_creation_input_tokens + cache_read_input_tokens + prompt_tokens
        # == total tokens that were sent

        if input_cost_per_token_cache_hit:
            # must be deepseek
            cost += input_cost_per_token_cache_hit * cache_hit_tokens
            cost += (prompt_tokens - cache_hit_tokens) * input_cost_per_token
        else:
            # hard code the anthropic adjustments, no-ops for other models since cache_x_tokens==0
            cost += cache_write_tokens * input_cost_per_token * 1.25
            cost += cache_hit_tokens * input_cost_per_token * 0.10
            cost += prompt_tokens * input_cost_per_token

        cost += completion_tokens * output_cost_per_token

        self.total_cost += cost
        self.message_cost += cost

        def format_cost(value):
            if value == 0:
                return "0.00"
            magnitude = abs(value)
            if magnitude >= 0.01:
                return f"{value:.2f}"
            else:
                return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"

        cost_report = (
            f"Cost: ${format_cost(self.message_cost)} message,"
            f" ${format_cost(self.total_cost)} session."
        )

        if self.add_cache_headers and self.stream:
            warning = " Use --no-stream for accurate caching costs."
            self.usage_report = tokens_report + "\n" + cost_report + warning
            return

        if cache_hit_tokens and cache_write_tokens:
            sep = "\n"
        else:
            sep = " "

        self.usage_report = tokens_report + sep + cost_report
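    # Worked example (illustrative numbers): with Anthropic-style caching at
    # $3 per 1M input tokens, a request with 1,000 cache-write tokens,
    # 10,000 cache-hit tokens, and 500 uncached prompt tokens costs
    #   1,000 * 3e-6 * 1.25 + 10,000 * 3e-6 * 0.10 + 500 * 3e-6
    #   = 0.00375 + 0.003 + 0.0015 = $0.00825
    # before output tokens, matching the 1.25x write / 0.10x read
    # multipliers hard-coded above.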
    def show_usage_report(self):
        if self.usage_report:
            self.io.tool_output(self.usage_report)
            self.message_cost = 0.0
            self.message_tokens_sent = 0
            self.message_tokens_received = 0

    def get_multi_response_content(self, final=False):
        cur = self.multi_response_content or ""
        new = self.partial_response_content or ""

        if new.rstrip() != new and not final:
            new = new.rstrip()
        return cur + new

    def get_rel_fname(self, fname):
        try:
            return os.path.relpath(fname, self.root)
        except ValueError:
            return fname

    def get_inchat_relative_files(self):
        files = [self.get_rel_fname(fname) for fname in self.abs_fnames]
        return sorted(set(files))

    def is_file_safe(self, fname):
        try:
            return Path(self.abs_root_path(fname)).is_file()
        except OSError:
            return

    def get_all_relative_files(self):
        if self.repo:
            files = self.repo.get_tracked_files()
        else:
            files = self.get_inchat_relative_files()

        # This is quite slow in large repos
        # files = [fname for fname in files if self.is_file_safe(fname)]

        return sorted(set(files))

    def get_all_abs_files(self):
        files = self.get_all_relative_files()
        files = [self.abs_root_path(path) for path in files]
        return files

    def get_addable_relative_files(self):
        all_files = set(self.get_all_relative_files())
        inchat_files = set(self.get_inchat_relative_files())
        read_only_files = set(self.get_rel_fname(fname) for fname in self.abs_read_only_fnames)
        return all_files - inchat_files - read_only_files

    def check_for_dirty_commit(self, path):
        if not self.repo:
            return
        if not self.dirty_commits:
            return
        if not self.repo.is_dirty(path):
            return

        # We need a committed copy of the file in order to /undo, so skip this
        # fullp = Path(self.abs_root_path(path))
        # if not fullp.stat().st_size:
        #     return

        self.io.tool_output(f"Committing {path} before applying edits.")
        self.need_commit_before_edits.add(path)

    def allowed_to_edit(self, path):
        full_path = self.abs_root_path(path)
        if self.repo:
            need_to_add = not self.repo.path_in_repo(path)
        else:
            need_to_add = False

        if full_path in self.abs_fnames:
            self.check_for_dirty_commit(path)
            return True

        if not Path(full_path).exists():
            if not self.io.confirm_ask("Create new file?", subject=path):
                self.io.tool_output(f"Skipping edits to {path}")
                return

            if not self.dry_run:
                if not utils.touch_file(full_path):
                    self.io.tool_error(f"Unable to create {path}, skipping edits.")
                    return

                # Seems unlikely that we needed to create the file, but it was
                # actually already part of the repo.
                # But let's only add if we need to, just to be safe.
                if need_to_add:
                    self.repo.repo.git.add(full_path)

            self.abs_fnames.add(full_path)
            self.check_added_files()
            return True

        if not self.io.confirm_ask(
            "Allow edits to file that has not been added to the chat?",
            subject=path,
        ):
            self.io.tool_output(f"Skipping edits to {path}")
            return

        if need_to_add:
            self.repo.repo.git.add(full_path)

        self.abs_fnames.add(full_path)
        self.check_added_files()
        self.check_for_dirty_commit(path)

        return True

    warning_given = False

    def check_added_files(self):
        if self.warning_given:
            return

        warn_number_of_files = 4
        warn_number_of_tokens = 20 * 1024

        num_files = len(self.abs_fnames)
        if num_files < warn_number_of_files:
            return

        tokens = 0
        for fname in self.abs_fnames:
            if is_image_file(fname):
                continue
            content = self.io.read_text(fname)
            tokens += self.main_model.token_count(content)

        if tokens < warn_number_of_tokens:
            return

        self.io.tool_warning("Warning: it's best to only add files that need changes to the chat.")
        self.io.tool_warning(urls.edit_errors)
        self.warning_given = True

    def prepare_to_edit(self, edits):
        res = []
        seen = dict()

        self.need_commit_before_edits = set()

        for edit in edits:
            path = edit[0]
            if path is None:
                res.append(edit)
                continue
            if path == "python":
                dump(edits)
            if path in seen:
                allowed = seen[path]
            else:
                allowed = self.allowed_to_edit(path)
                seen[path] = allowed

            if allowed:
                res.append(edit)

        self.dirty_commit()
        self.need_commit_before_edits = set()

        return res

    def apply_updates(self):
        edited = set()
        try:
            edits = self.get_edits()
            edits = self.prepare_to_edit(edits)
            edited = set(edit[0] for edit in edits)
            self.apply_edits(edits)
        except ValueError as err:
            self.num_malformed_responses += 1

            err = err.args[0]

            self.io.tool_error("The LLM did not conform to the edit format.")
            self.io.tool_output(urls.edit_errors)
            self.io.tool_output()
            self.io.tool_output(str(err))

            self.reflected_message = str(err)
            return edited

        except ANY_GIT_ERROR as err:
            self.io.tool_error(str(err))
            return edited
        except Exception as err:
            self.io.tool_error("Exception while updating files:")
            self.io.tool_error(str(err), strip=False)

            traceback.print_exc()

            self.reflected_message = str(err)
            return edited

        for path in edited:
            if self.dry_run:
                self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
            else:
                self.io.tool_output(f"Applied edit to {path}")

        return edited

    def parse_partial_args(self):
        # dump(self.partial_response_function_call)

        data = self.partial_response_function_call.get("arguments")
        if not data:
            return

        try:
            return json.loads(data)
        except JSONDecodeError:
            pass

        try:
            return json.loads(data + "]}")
        except JSONDecodeError:
            pass

        try:
            return json.loads(data + "}]}")
        except JSONDecodeError:
            pass

        try:
            return json.loads(data + '"}]}')
        except JSONDecodeError:
            pass

    # commits...
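    # Illustrative note: parse_partial_args() above recovers arguments from a
    # streaming function call that was cut off mid-payload by appending
    # likely closing tokens. For instance (hypothetical key names), a payload
    # truncated inside a string value like
    #   '{"edits": [{"path": "a.py'
    # parses successfully once the '"}]}' closers are appended, yielding
    #   {"edits": [{"path": "a.py"}]}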
    def get_context_from_history(self, history):
        context = ""
        if history:
            for msg in history:
                context += "\n" + msg["role"].upper() + ": " + msg["content"] + "\n"

        return context

    def auto_commit(self, edited, context=None):
        if not self.repo or not self.auto_commits or self.dry_run:
            return

        if not context:
            context = self.get_context_from_history(self.cur_messages)

        try:
            res = self.repo.commit(fnames=edited, context=context, aider_edits=True)
            if res:
                self.show_auto_commit_outcome(res)
                commit_hash, commit_message = res
                return self.gpt_prompts.files_content_gpt_edits.format(
                    hash=commit_hash,
                    message=commit_message,
                )

            self.io.tool_output("No changes made to git tracked files.")
            return self.gpt_prompts.files_content_gpt_no_edits
        except ANY_GIT_ERROR as err:
            self.io.tool_error(f"Unable to commit: {str(err)}")
            return

    def show_auto_commit_outcome(self, res):
        commit_hash, commit_message = res
        self.last_aider_commit_hash = commit_hash
        self.aider_commit_hashes.add(commit_hash)
        self.last_aider_commit_message = commit_message
        if self.show_diffs:
            self.commands.cmd_diff()

    def show_undo_hint(self):
        if not self.commit_before_message:
            return
        if self.commit_before_message[-1] != self.repo.get_head_commit_sha():
            self.io.tool_output("You can use /undo to undo and discard each aider commit.")

    def dirty_commit(self):
        if not self.need_commit_before_edits:
            return
        if not self.dirty_commits:
            return
        if not self.repo:
            return

        self.repo.commit(fnames=self.need_commit_before_edits)

        # files changed, move cur messages back behind the files messages
        # self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
        return True

    def get_edits(self, mode="update"):
        return []

    def apply_edits(self, edits):
        return

    def run_shell_commands(self):
        if not self.suggest_shell_commands:
            return ""

        done = set()
        group = ConfirmGroup(set(self.shell_commands))
        accumulated_output = ""
        for command in self.shell_commands:
            if command in done:
                continue
            done.add(command)
            output = self.handle_shell_commands(command, group)
            if output:
                accumulated_output += output + "\n\n"
        return accumulated_output
    def handle_shell_commands(self, commands_str, group):
        commands = commands_str.strip().splitlines()
        command_count = sum(
            1 for cmd in commands if cmd.strip() and not cmd.strip().startswith("#")
        )
        prompt = "Run shell command?" if command_count == 1 else "Run shell commands?"
        if not self.io.confirm_ask(
            prompt, subject="\n".join(commands), explicit_yes_required=True, group=group
        ):
            return

        accumulated_output = ""
        for command in commands:
            command = command.strip()
            if not command or command.startswith("#"):
                continue

            self.io.tool_output()
            self.io.tool_output(f"Running {command}")
            # Add the command to input history
            self.io.add_to_input_history(f"/run {command.strip()}")
            exit_status, output = run_cmd(command, error_print=self.io.tool_error)
            if output:
                accumulated_output += f"Output from {command}\n{output}\n"

        if accumulated_output.strip() and not self.io.confirm_ask(
            "Add command output to the chat?"
        ):
            accumulated_output = ""

        return accumulated_output