/ integrations / instrumented.py
instrumented.py
 1  """Transparent LLMProvider wrapper that captures timing and token usage."""
 2  
 3  __all__ = ["InstrumentedProvider"]
 4  
 5  import time
 6  from collections.abc import Sequence
 7  
 8  from pydantic import BaseModel
 9  
10  from integrations.llm import LLMProvider
11  from models.llm import Attachment
12  from telemetry import RunStats
13  
14  
class InstrumentedProvider:
    """Transparent LLMProvider wrapper that records per-call telemetry.

    Every ``complete`` call is timed; its duration and token usage are
    folded into a shared ``RunStats`` accumulator under a fixed phase
    label. Failures are recorded as errors and re-raised unchanged.

    Args:
        provider: The real LLM provider to delegate to.
        stats: Shared accumulator for the current pipeline run.
        phase: Phase label used when recording into stats
               (e.g. "extract", "analyze", "tailor", "main").
    """

    def __init__(self, provider: LLMProvider, stats: RunStats, *, phase: str) -> None:
        self._provider = provider
        self._stats = stats
        self._phase = phase

    def complete(
        self,
        system: str,
        user: str,
        *,
        temperature: float | None = None,
        seed: int | None = None,
        response_schema: type[BaseModel] | None = None,
        attachments: Sequence[Attachment] | None = None,
    ) -> str:
        """Delegate to the wrapped provider, recording timing and token usage."""
        started = time.perf_counter()
        try:
            reply = self._provider.complete(
                system,
                user,
                temperature=temperature,
                seed=seed,
                response_schema=response_schema,
                attachments=attachments,
            )
        # Broad catch is deliberate: this wraps six providers with
        # distinct exception hierarchies.
        except Exception:
            self._stats.record_llm_error(
                duration_s=time.perf_counter() - started, phase=self._phase
            )
            raise
        else:
            elapsed = time.perf_counter() - started
            prompt_tokens, completion_tokens = self._provider.last_usage
            self._stats.record_llm_call(
                duration_s=elapsed,
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
                phase=self._phase,
            )
            return reply

    def ping(self, temperature: float | None = None) -> None:
        """Delegates to the wrapped provider. Not instrumented."""
        self._provider.ping(temperature=temperature)

    @property
    def last_usage(self) -> tuple[int, int]:
        """Delegates to the wrapped provider.

        Reflects the most recent successful call on the current thread.
        """
        return self._provider.last_usage