/ src / query_engine.py
query_engine.py
  1  from __future__ import annotations
  2  
  3  import json
  4  from dataclasses import dataclass, field
  5  from uuid import uuid4
  6  
  7  from .commands import build_command_backlog
  8  from .models import PermissionDenial, UsageSummary
  9  from .port_manifest import PortManifest, build_port_manifest
 10  from .session_store import StoredSession, load_session, save_session
 11  from .tools import build_tool_backlog
 12  from .transcript import TranscriptStore
 13  
 14  
 15  @dataclass(frozen=True)
 16  class QueryEngineConfig:
 17      max_turns: int = 8
 18      max_budget_tokens: int = 2000
 19      compact_after_turns: int = 12
 20      structured_output: bool = False
 21      structured_retry_limit: int = 2
 22  
 23  
 24  @dataclass(frozen=True)
 25  class TurnResult:
 26      prompt: str
 27      output: str
 28      matched_commands: tuple[str, ...]
 29      matched_tools: tuple[str, ...]
 30      permission_denials: tuple[PermissionDenial, ...]
 31      usage: UsageSummary
 32      stop_reason: str
 33  
 34  
 35  @dataclass
 36  class QueryEnginePort:
 37      manifest: PortManifest
 38      config: QueryEngineConfig = field(default_factory=QueryEngineConfig)
 39      session_id: str = field(default_factory=lambda: uuid4().hex)
 40      mutable_messages: list[str] = field(default_factory=list)
 41      permission_denials: list[PermissionDenial] = field(default_factory=list)
 42      total_usage: UsageSummary = field(default_factory=UsageSummary)
 43      transcript_store: TranscriptStore = field(default_factory=TranscriptStore)
 44  
 45      @classmethod
 46      def from_workspace(cls) -> 'QueryEnginePort':
 47          return cls(manifest=build_port_manifest())
 48  
 49      @classmethod
 50      def from_saved_session(cls, session_id: str) -> 'QueryEnginePort':
 51          stored = load_session(session_id)
 52          transcript = TranscriptStore(entries=list(stored.messages), flushed=True)
 53          return cls(
 54              manifest=build_port_manifest(),
 55              session_id=stored.session_id,
 56              mutable_messages=list(stored.messages),
 57              total_usage=UsageSummary(stored.input_tokens, stored.output_tokens),
 58              transcript_store=transcript,
 59          )
 60  
 61      def submit_message(
 62          self,
 63          prompt: str,
 64          matched_commands: tuple[str, ...] = (),
 65          matched_tools: tuple[str, ...] = (),
 66          denied_tools: tuple[PermissionDenial, ...] = (),
 67      ) -> TurnResult:
 68          if len(self.mutable_messages) >= self.config.max_turns:
 69              output = f'Max turns reached before processing prompt: {prompt}'
 70              return TurnResult(
 71                  prompt=prompt,
 72                  output=output,
 73                  matched_commands=matched_commands,
 74                  matched_tools=matched_tools,
 75                  permission_denials=denied_tools,
 76                  usage=self.total_usage,
 77                  stop_reason='max_turns_reached',
 78              )
 79  
 80          summary_lines = [
 81              f'Prompt: {prompt}',
 82              f'Matched commands: {", ".join(matched_commands) if matched_commands else "none"}',
 83              f'Matched tools: {", ".join(matched_tools) if matched_tools else "none"}',
 84              f'Permission denials: {len(denied_tools)}',
 85          ]
 86          output = self._format_output(summary_lines)
 87          projected_usage = self.total_usage.add_turn(prompt, output)
 88          stop_reason = 'completed'
 89          if projected_usage.input_tokens + projected_usage.output_tokens > self.config.max_budget_tokens:
 90              stop_reason = 'max_budget_reached'
 91          self.mutable_messages.append(prompt)
 92          self.transcript_store.append(prompt)
 93          self.permission_denials.extend(denied_tools)
 94          self.total_usage = projected_usage
 95          self.compact_messages_if_needed()
 96          return TurnResult(
 97              prompt=prompt,
 98              output=output,
 99              matched_commands=matched_commands,
100              matched_tools=matched_tools,
101              permission_denials=denied_tools,
102              usage=self.total_usage,
103              stop_reason=stop_reason,
104          )
105  
106      def stream_submit_message(
107          self,
108          prompt: str,
109          matched_commands: tuple[str, ...] = (),
110          matched_tools: tuple[str, ...] = (),
111          denied_tools: tuple[PermissionDenial, ...] = (),
112      ):
113          yield {'type': 'message_start', 'session_id': self.session_id, 'prompt': prompt}
114          if matched_commands:
115              yield {'type': 'command_match', 'commands': matched_commands}
116          if matched_tools:
117              yield {'type': 'tool_match', 'tools': matched_tools}
118          if denied_tools:
119              yield {'type': 'permission_denial', 'denials': [denial.tool_name for denial in denied_tools]}
120          result = self.submit_message(prompt, matched_commands, matched_tools, denied_tools)
121          yield {'type': 'message_delta', 'text': result.output}
122          yield {
123              'type': 'message_stop',
124              'usage': {'input_tokens': result.usage.input_tokens, 'output_tokens': result.usage.output_tokens},
125              'stop_reason': result.stop_reason,
126              'transcript_size': len(self.transcript_store.entries),
127          }
128  
129      def compact_messages_if_needed(self) -> None:
130          if len(self.mutable_messages) > self.config.compact_after_turns:
131              self.mutable_messages[:] = self.mutable_messages[-self.config.compact_after_turns :]
132          self.transcript_store.compact(self.config.compact_after_turns)
133  
134      def replay_user_messages(self) -> tuple[str, ...]:
135          return self.transcript_store.replay()
136  
137      def flush_transcript(self) -> None:
138          self.transcript_store.flush()
139  
140      def persist_session(self) -> str:
141          self.flush_transcript()
142          path = save_session(
143              StoredSession(
144                  session_id=self.session_id,
145                  messages=tuple(self.mutable_messages),
146                  input_tokens=self.total_usage.input_tokens,
147                  output_tokens=self.total_usage.output_tokens,
148              )
149          )
150          return str(path)
151  
152      def _format_output(self, summary_lines: list[str]) -> str:
153          if self.config.structured_output:
154              payload = {
155                  'summary': summary_lines,
156                  'session_id': self.session_id,
157              }
158              return self._render_structured_output(payload)
159          return '\n'.join(summary_lines)
160  
161      def _render_structured_output(self, payload: dict[str, object]) -> str:
162          last_error: Exception | None = None
163          for _ in range(self.config.structured_retry_limit):
164              try:
165                  return json.dumps(payload, indent=2)
166              except (TypeError, ValueError) as exc:  # pragma: no cover - defensive branch
167                  last_error = exc
168                  payload = {'summary': ['structured output retry'], 'session_id': self.session_id}
169          raise RuntimeError('structured output rendering failed') from last_error
170  
171      def render_summary(self) -> str:
172          command_backlog = build_command_backlog()
173          tool_backlog = build_tool_backlog()
174          sections = [
175              '# Python Porting Workspace Summary',
176              '',
177              self.manifest.to_markdown(),
178              '',
179              f'Command surface: {len(command_backlog.modules)} mirrored entries',
180              *command_backlog.summary_lines()[:10],
181              '',
182              f'Tool surface: {len(tool_backlog.modules)} mirrored entries',
183              *tool_backlog.summary_lines()[:10],
184              '',
185              f'Session id: {self.session_id}',
186              f'Conversation turns stored: {len(self.mutable_messages)}',
187              f'Permission denials tracked: {len(self.permission_denials)}',
188              f'Usage totals: in={self.total_usage.input_tokens} out={self.total_usage.output_tokens}',
189              f'Max turns: {self.config.max_turns}',
190              f'Max budget tokens: {self.config.max_budget_tokens}',
191              f'Transcript flushed: {self.transcript_store.flushed}',
192          ]
193          return '\n'.join(sections)