query_engine.py
1 from __future__ import annotations 2 3 import json 4 from dataclasses import dataclass, field 5 from uuid import uuid4 6 7 from .commands import build_command_backlog 8 from .models import PermissionDenial, UsageSummary 9 from .port_manifest import PortManifest, build_port_manifest 10 from .session_store import StoredSession, load_session, save_session 11 from .tools import build_tool_backlog 12 from .transcript import TranscriptStore 13 14 15 @dataclass(frozen=True) 16 class QueryEngineConfig: 17 max_turns: int = 8 18 max_budget_tokens: int = 2000 19 compact_after_turns: int = 12 20 structured_output: bool = False 21 structured_retry_limit: int = 2 22 23 24 @dataclass(frozen=True) 25 class TurnResult: 26 prompt: str 27 output: str 28 matched_commands: tuple[str, ...] 29 matched_tools: tuple[str, ...] 30 permission_denials: tuple[PermissionDenial, ...] 31 usage: UsageSummary 32 stop_reason: str 33 34 35 @dataclass 36 class QueryEnginePort: 37 manifest: PortManifest 38 config: QueryEngineConfig = field(default_factory=QueryEngineConfig) 39 session_id: str = field(default_factory=lambda: uuid4().hex) 40 mutable_messages: list[str] = field(default_factory=list) 41 permission_denials: list[PermissionDenial] = field(default_factory=list) 42 total_usage: UsageSummary = field(default_factory=UsageSummary) 43 transcript_store: TranscriptStore = field(default_factory=TranscriptStore) 44 45 @classmethod 46 def from_workspace(cls) -> 'QueryEnginePort': 47 return cls(manifest=build_port_manifest()) 48 49 @classmethod 50 def from_saved_session(cls, session_id: str) -> 'QueryEnginePort': 51 stored = load_session(session_id) 52 transcript = TranscriptStore(entries=list(stored.messages), flushed=True) 53 return cls( 54 manifest=build_port_manifest(), 55 session_id=stored.session_id, 56 mutable_messages=list(stored.messages), 57 total_usage=UsageSummary(stored.input_tokens, stored.output_tokens), 58 transcript_store=transcript, 59 ) 60 61 def submit_message( 62 self, 63 prompt: str, 64 matched_commands: tuple[str, ...] = (), 65 matched_tools: tuple[str, ...] = (), 66 denied_tools: tuple[PermissionDenial, ...] = (), 67 ) -> TurnResult: 68 if len(self.mutable_messages) >= self.config.max_turns: 69 output = f'Max turns reached before processing prompt: {prompt}' 70 return TurnResult( 71 prompt=prompt, 72 output=output, 73 matched_commands=matched_commands, 74 matched_tools=matched_tools, 75 permission_denials=denied_tools, 76 usage=self.total_usage, 77 stop_reason='max_turns_reached', 78 ) 79 80 summary_lines = [ 81 f'Prompt: {prompt}', 82 f'Matched commands: {", ".join(matched_commands) if matched_commands else "none"}', 83 f'Matched tools: {", ".join(matched_tools) if matched_tools else "none"}', 84 f'Permission denials: {len(denied_tools)}', 85 ] 86 output = self._format_output(summary_lines) 87 projected_usage = self.total_usage.add_turn(prompt, output) 88 stop_reason = 'completed' 89 if projected_usage.input_tokens + projected_usage.output_tokens > self.config.max_budget_tokens: 90 stop_reason = 'max_budget_reached' 91 self.mutable_messages.append(prompt) 92 self.transcript_store.append(prompt) 93 self.permission_denials.extend(denied_tools) 94 self.total_usage = projected_usage 95 self.compact_messages_if_needed() 96 return TurnResult( 97 prompt=prompt, 98 output=output, 99 matched_commands=matched_commands, 100 matched_tools=matched_tools, 101 permission_denials=denied_tools, 102 usage=self.total_usage, 103 stop_reason=stop_reason, 104 ) 105 106 def stream_submit_message( 107 self, 108 prompt: str, 109 matched_commands: tuple[str, ...] = (), 110 matched_tools: tuple[str, ...] = (), 111 denied_tools: tuple[PermissionDenial, ...] = (), 112 ): 113 yield {'type': 'message_start', 'session_id': self.session_id, 'prompt': prompt} 114 if matched_commands: 115 yield {'type': 'command_match', 'commands': matched_commands} 116 if matched_tools: 117 yield {'type': 'tool_match', 'tools': matched_tools} 118 if denied_tools: 119 yield {'type': 'permission_denial', 'denials': [denial.tool_name for denial in denied_tools]} 120 result = self.submit_message(prompt, matched_commands, matched_tools, denied_tools) 121 yield {'type': 'message_delta', 'text': result.output} 122 yield { 123 'type': 'message_stop', 124 'usage': {'input_tokens': result.usage.input_tokens, 'output_tokens': result.usage.output_tokens}, 125 'stop_reason': result.stop_reason, 126 'transcript_size': len(self.transcript_store.entries), 127 } 128 129 def compact_messages_if_needed(self) -> None: 130 if len(self.mutable_messages) > self.config.compact_after_turns: 131 self.mutable_messages[:] = self.mutable_messages[-self.config.compact_after_turns :] 132 self.transcript_store.compact(self.config.compact_after_turns) 133 134 def replay_user_messages(self) -> tuple[str, ...]: 135 return self.transcript_store.replay() 136 137 def flush_transcript(self) -> None: 138 self.transcript_store.flush() 139 140 def persist_session(self) -> str: 141 self.flush_transcript() 142 path = save_session( 143 StoredSession( 144 session_id=self.session_id, 145 messages=tuple(self.mutable_messages), 146 input_tokens=self.total_usage.input_tokens, 147 output_tokens=self.total_usage.output_tokens, 148 ) 149 ) 150 return str(path) 151 152 def _format_output(self, summary_lines: list[str]) -> str: 153 if self.config.structured_output: 154 payload = { 155 'summary': summary_lines, 156 'session_id': self.session_id, 157 } 158 return self._render_structured_output(payload) 159 return '\n'.join(summary_lines) 160 161 def _render_structured_output(self, payload: dict[str, object]) -> str: 162 last_error: Exception | None = None 163 for _ in range(self.config.structured_retry_limit): 164 try: 165 return json.dumps(payload, indent=2) 166 except (TypeError, ValueError) as exc: # pragma: no cover - defensive branch 167 last_error = exc 168 payload = {'summary': ['structured output retry'], 'session_id': self.session_id} 169 raise RuntimeError('structured output rendering failed') from last_error 170 171 def render_summary(self) -> str: 172 command_backlog = build_command_backlog() 173 tool_backlog = build_tool_backlog() 174 sections = [ 175 '# Python Porting Workspace Summary', 176 '', 177 self.manifest.to_markdown(), 178 '', 179 f'Command surface: {len(command_backlog.modules)} mirrored entries', 180 *command_backlog.summary_lines()[:10], 181 '', 182 f'Tool surface: {len(tool_backlog.modules)} mirrored entries', 183 *tool_backlog.summary_lines()[:10], 184 '', 185 f'Session id: {self.session_id}', 186 f'Conversation turns stored: {len(self.mutable_messages)}', 187 f'Permission denials tracked: {len(self.permission_denials)}', 188 f'Usage totals: in={self.total_usage.input_tokens} out={self.total_usage.output_tokens}', 189 f'Max turns: {self.config.max_turns}', 190 f'Max budget tokens: {self.config.max_budget_tokens}', 191 f'Transcript flushed: {self.transcript_store.flushed}', 192 ] 193 return '\n'.join(sections)