/ integrations / instrumented.py
instrumented.py
"""LLMProvider decorator that records per-call latency and token counts.

The wrapper exposes the same ``complete`` / ``ping`` / ``last_usage``
surface as the provider it wraps, so it can stand in wherever an
``LLMProvider`` is expected.
"""

__all__ = ["InstrumentedProvider"]

import time
from collections.abc import Sequence

from pydantic import BaseModel

from integrations.llm import LLMProvider
from models.llm import Attachment
from telemetry import RunStats


class InstrumentedProvider:
    """Forward calls to an ``LLMProvider`` while recording telemetry.

    Every ``complete`` call is timed. A successful call is recorded together
    with the token counts the wrapped provider reports. A failed call is
    recorded as an error, and the exception is re-raised unchanged.

    Args:
        provider: The underlying LLM provider that performs the real work.
        stats: Accumulator for the current pipeline run; receives one
            record per ``complete`` call.
        phase: Label attached to every record written to ``stats``
            (for example "extract", "analyze", "tailor", "main").
    """

    def __init__(
        self,
        provider: LLMProvider,
        stats: RunStats,
        *,
        phase: str,
    ) -> None:
        self._provider = provider
        self._stats = stats
        self._phase = phase

    def complete(
        self,
        system: str,
        user: str,
        *,
        temperature: float | None = None,
        seed: int | None = None,
        response_schema: type[BaseModel] | None = None,
        attachments: Sequence[Attachment] | None = None,
    ) -> str:
        """Run a completion on the wrapped provider and record its telemetry.

        Args:
            system: System prompt, passed through unchanged.
            user: User prompt, passed through unchanged.
            temperature: Passed through to the wrapped provider.
            seed: Passed through to the wrapped provider.
            response_schema: Passed through to the wrapped provider.
            attachments: Passed through to the wrapped provider.

        Returns:
            The string returned by the wrapped provider.

        Raises:
            Exception: Whatever the wrapped provider raises. It is propagated
                after an error record (with elapsed time) is written to
                ``stats``.
        """
        t0 = time.perf_counter()
        try:
            response = self._provider.complete(
                system,
                user,
                temperature=temperature,
                seed=seed,
                response_schema=response_schema,
                attachments=attachments,
            )
        # Deliberately broad: the wrapped object may be any of the six
        # supported providers, each with its own distinct exception tree.
        except Exception:
            self._stats.record_llm_error(
                duration_s=time.perf_counter() - t0, phase=self._phase
            )
            raise
        else:
            elapsed = time.perf_counter() - t0
            # Usage is read from the wrapped provider only once the call
            # has returned successfully; it is unpacked as (input, output).
            input_tokens, output_tokens = self._provider.last_usage
            self._stats.record_llm_call(
                duration_s=elapsed,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                phase=self._phase,
            )
            return response

    def ping(self, temperature: float | None = None) -> None:
        """Forward a ping to the wrapped provider; no telemetry is recorded."""
        self._provider.ping(temperature=temperature)

    @property
    def last_usage(self) -> tuple[int, int]:
        """Return the wrapped provider's ``(input_tokens, output_tokens)``.

        This reflects the most recent successful call on the current thread,
        as tracked by the wrapped provider.
        """
        return self._provider.last_usage