# app/client.py
from __future__ import annotations

import json
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Iterable

import grpc
from google.protobuf.json_format import MessageToDict
from smg_grpc_proto.generated import vllm_engine_pb2, vllm_engine_pb2_grpc
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from app.discover_vllm_grpc import discover_surface
from app.schemas import FinalDecision, IncidentRequest, TriageUpdate

logger = logging.getLogger(__name__)

MetadataFactory = Callable[[], Iterable[tuple[str, str]]]


class VllmGrpcError(RuntimeError):
    """Raised when a transport-level gRPC failure occurs."""


@dataclass(slots=True)
class GenerationResult:
    """Aggregated outcome of a single Generate call, including latency metrics."""

    request_id: str
    selected_rpc: str
    raw_output: str
    updates: list[TriageUpdate]
    chunk_count: int
    time_to_first_update_ms: float | None
    end_to_end_latency_ms: float
    output_bytes_received: int
    finish_reason: str | None
    grpc_surface: dict[str, Any]


class VllmGrpcClient:
    """Synchronous client for the vLLM engine gRPC surface."""

    def __init__(
        self,
        endpoint: str = "localhost:8000",
        *,
        timeout: float = 120.0,
        metadata_factory: MetadataFactory | None = None,
        surface: dict[str, Any] | None = None,
    ) -> None:
        self.endpoint = endpoint
        self.timeout = timeout
        self.metadata_factory = metadata_factory
        self.surface = surface or discover_surface(endpoint=endpoint)
        self._channel = grpc.insecure_channel(
            endpoint,
            options=[
                # -1 removes gRPC's default message size caps; keepalive pings
                # protect long-lived streaming calls from idle disconnects.
                ("grpc.max_send_message_length", -1),
                ("grpc.max_receive_message_length", -1),
                ("grpc.keepalive_time_ms", 20000),
                ("grpc.keepalive_timeout_ms", 10000),
            ],
        )
        self._stub = vllm_engine_pb2_grpc.VllmEngineStub(self._channel)
        self._tokenizer: PreTrainedTokenizerBase | None = None
        self._model_info: dict[str, Any] | None = None

    def close(self) -> None:
        self._channel.close()

    def _metadata(self) -> list[tuple[str, str]] | None:
        if self.metadata_factory is None:
            return None
        return list(self.metadata_factory())

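    # Illustrative only: a metadata_factory is any zero-argument callable returning
    # (key, value) pairs to attach to every RPC. The header name below is a common
    # convention, not something this module requires:
    #
    #     client = VllmGrpcClient(
    #         "localhost:8000",
    #         metadata_factory=lambda: [("authorization", "Bearer <token>")],
    #     )
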
    def _rpc_error(self, exc: grpc.RpcError, rpc_name: str) -> VllmGrpcError:
        code = exc.code().name if exc.code() else "UNKNOWN"
        details = exc.details() or "no details"
        logger.error("gRPC transport error on %s: %s %s", rpc_name, code, details)
        return VllmGrpcError(f"{rpc_name} failed with {code}: {details}")

    def health_check(self) -> dict[str, Any]:
        try:
            response = self._stub.HealthCheck(
                vllm_engine_pb2.HealthCheckRequest(),
                timeout=self.timeout,
                metadata=self._metadata(),
            )
        except grpc.RpcError as exc:
            raise self._rpc_error(exc, "HealthCheck") from exc
        return MessageToDict(response, preserving_proto_field_name=True)

    def get_model_info(self) -> dict[str, Any]:
        if self._model_info is not None:
            return self._model_info
        try:
            response = self._stub.GetModelInfo(
                vllm_engine_pb2.GetModelInfoRequest(),
                timeout=self.timeout,
                metadata=self._metadata(),
            )
        except grpc.RpcError as exc:
            raise self._rpc_error(exc, "GetModelInfo") from exc
        self._model_info = MessageToDict(response, preserving_proto_field_name=True)
        return self._model_info

    def get_server_info(self) -> dict[str, Any]:
        try:
            response = self._stub.GetServerInfo(
                vllm_engine_pb2.GetServerInfoRequest(),
                timeout=self.timeout,
                metadata=self._metadata(),
            )
        except grpc.RpcError as exc:
            raise self._rpc_error(exc, "GetServerInfo") from exc
        return MessageToDict(response, preserving_proto_field_name=True)

    def _load_tokenizer(self) -> PreTrainedTokenizerBase:
        if self._tokenizer is not None:
            return self._tokenizer
        model_info = self.get_model_info()
        model_path = model_info["model_path"]
        logger.info("Loading tokenizer from discovered model path %s", model_path)
        self._tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        return self._tokenizer

    @staticmethod
    def _generation_schema() -> dict[str, Any]:
        raw_schema = FinalDecision.model_json_schema()

        def _prune(node: Any) -> Any:
            if isinstance(node, dict):
                return {
                    key: _prune(value)
                    for key, value in node.items()
                    if key not in {"title", "default", "examples"}
                }
            if isinstance(node, list):
                return [_prune(value) for value in node]
            return node

        return _prune(raw_schema)

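    # Illustration only (hypothetical schema fragment, not taken from FinalDecision):
    # pruning turns
    #   {"title": "FinalDecision", "properties": {"severity": {"title": "Severity", "type": "string"}}}
    # into
    #   {"properties": {"severity": {"type": "string"}}}
    # so the JSON schema handed to the server stays compact.
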
    @staticmethod
    def build_prompt(incident: IncidentRequest, strict_schema: bool) -> str:
        incident_blob = json.dumps(incident.model_dump(mode="json"), indent=2)
        contract = (
            "Return exactly one JSON object and nothing else."
            if strict_schema
            else (
                "Stream concise updates, then emit the final JSON object between "
                "BEGIN_FINAL_DECISION_JSON and END_FINAL_DECISION_JSON."
            )
        )
        return (
            "You are the enterprise Incident Commander for a live operations bridge.\n"
            "Use concise executive-safe language.\n"
            "Do not reveal chain-of-thought. Do not emit markdown or code fences.\n"
            f"{contract}\n"
            "Populate fields in this order: incident_id, executive_summary, severity, "
            "suspected_root_cause, impacted_assets, confidence, recommended_actions, "
            "escalation_team, change_risk, machine_json_valid.\n"
            "Recommended actions must only use allowed_actions as the basis for actions.\n"
            "Set machine_json_valid to true if your final JSON object is valid.\n"
            "Incident payload:\n"
            f"{incident_blob}\n"
        )

    @staticmethod
    def _stage_updates(
        incident_id: str,
        accumulated_text: str,
        elapsed_ms: float,
        bytes_received: int,
        emitted_stages: set[str],
    ) -> list[TriageUpdate]:
        stage_markers = [
            ("executive_summary", "situation_assessment"),
            ("suspected_root_cause", "probable_root_cause"),
            ("recommended_actions", "recommended_actions"),
            ("machine_json_valid", "final_decision_json"),
        ]
        updates: list[TriageUpdate] = []
        for marker, stage in stage_markers:
            if marker in accumulated_text and stage not in emitted_stages:
                emitted_stages.add(stage)
                updates.append(
                    TriageUpdate(
                        incident_id=incident_id,
                        stage=stage,  # type: ignore[arg-type]
                        text=f"Reached {stage} section in streamed JSON payload.",
                        elapsed_ms=elapsed_ms,
                        bytes_received=bytes_received,
                    )
                )
        return updates

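    # The discovered surface is expected to look roughly like the sketch below
    # (inferred from the lookup in _selected_generate_rpc, not a guaranteed
    # contract of discover_surface):
    #   {"services": [{"name": "...", "methods": [
    #       {"name": "Generate", "full_name": "/vllm.grpc.engine.VllmEngine/Generate"}]}]}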
    def _selected_generate_rpc(self) -> str:
        """Return the full RPC name for Generate from the discovered surface, with a fallback."""
        for service in self.surface["services"]:
            for method in service["methods"]:
                if method["name"] == "Generate":
                    return method["full_name"]
        return "/vllm.grpc.engine.VllmEngine/Generate"

    def generate_incident(
        self,
        incident: IncidentRequest,
        *,
        stream: bool = True,
        max_tokens: int = 900,
        timeout: float | None = None,
        on_text: Callable[[str], None] | None = None,
        on_update: Callable[[TriageUpdate], None] | None = None,
    ) -> GenerationResult:
        """Run one schema-constrained generation and collect streaming metrics."""
        tokenizer = self._load_tokenizer()
        request_id = f"{incident.incident_id}-{uuid.uuid4().hex[:8]}"
        selected_rpc = self._selected_generate_rpc()
        prompt = self.build_prompt(incident, strict_schema=True)
        request = vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            text=prompt,
            stream=stream,
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=0.0,
                top_p=1.0,
                max_tokens=max_tokens,
                json_schema=json.dumps(self._generation_schema(), separators=(",", ":")),
            ),
        )

        raw_parts: list[str] = []
        updates = [
            TriageUpdate(
                incident_id=incident.incident_id,
                stage="transport",
                text=f"Opened gRPC stream via {selected_rpc}",
                elapsed_ms=0.0,
                bytes_received=0,
            )
        ]
        if on_update is not None:
            on_update(updates[0])
        emitted_stages: set[str] = set()
        bytes_received = 0
        chunk_count = 0
        finish_reason: str | None = None
        started_at = time.perf_counter()
        first_update_ms: float | None = None

        try:
            responses = self._stub.Generate(
                request,
                timeout=timeout or self.timeout,
                metadata=self._metadata(),
            )
            for response in responses:
                bytes_received += response.ByteSize()
                elapsed_ms = (time.perf_counter() - started_at) * 1000
                if response.HasField("chunk"):
                    chunk_count += 1
                    if first_update_ms is None:
                        first_update_ms = elapsed_ms
                    # Decode raw token ids from the chunk; special tokens are kept so
                    # the accumulated text mirrors what the engine actually produced.
                    decoded = tokenizer.decode(
                        response.chunk.token_ids,
                        skip_special_tokens=False,
                        clean_up_tokenization_spaces=False,
                    )
                    raw_parts.append(decoded)
                    if on_text is not None:
                        on_text(decoded)
                    accumulated = "".join(raw_parts)
                    # Emit a TriageUpdate the first time each schema field marker
                    # appears in the accumulated output.
                    new_updates = self._stage_updates(
                        incident_id=incident.incident_id,
                        accumulated_text=accumulated,
                        elapsed_ms=elapsed_ms,
                        bytes_received=bytes_received,
                        emitted_stages=emitted_stages,
                    )
                    updates.extend(new_updates)
                    if on_update is not None:
                        for update in new_updates:
                            on_update(update)
                elif response.HasField("complete"):
                    finish_reason = response.complete.finish_reason or None
        except grpc.RpcError as exc:
            raise self._rpc_error(exc, "Generate") from exc

        return GenerationResult(
            request_id=request_id,
            selected_rpc=selected_rpc,
            raw_output="".join(raw_parts),
            updates=updates,
            chunk_count=chunk_count,
            time_to_first_update_ms=first_update_ms,
            end_to_end_latency_ms=(time.perf_counter() - started_at) * 1000,
            output_bytes_received=bytes_received,
            finish_reason=finish_reason,
            grpc_surface=self.surface,
        )
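

# Usage sketch (illustrative only): assumes a vLLM gRPC engine is reachable on the
# given endpoint. IncidentRequest construction is not shown because its fields are
# defined in app.schemas, not in this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    client = VllmGrpcClient("localhost:8000")
    try:
        print(client.health_check())
        print(client.get_model_info())
        # To run a full triage generation, build an IncidentRequest from app.schemas
        # and stream output to stdout, for example:
        #   result = client.generate_incident(
        #       incident,
        #       on_text=lambda text: print(text, end="", flush=True),
        #       on_update=lambda update: logger.info("stage=%s", update.stage),
        #   )
        #   print(result.finish_reason, result.end_to_end_latency_ms)
    finally:
        client.close()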