# client.py
1 from __future__ import annotations 2 3 import json 4 import logging 5 import time 6 import uuid 7 from dataclasses import dataclass 8 from typing import Any, Callable, Iterable 9 10 import grpc 11 from google.protobuf.json_format import MessageToDict 12 from smg_grpc_proto.generated import vllm_engine_pb2, vllm_engine_pb2_grpc 13 from transformers import AutoTokenizer, PreTrainedTokenizerBase 14 15 from app.discover_vllm_grpc import discover_surface 16 from app.schemas import FinalDecision, IncidentRequest, TriageUpdate 17 18 logger = logging.getLogger(__name__) 19 20 MetadataFactory = Callable[[], Iterable[tuple[str, str]]] 21 22 23 class VllmGrpcError(RuntimeError): 24 """Raised when a transport-level gRPC failure occurs.""" 25 26 27 @dataclass(slots=True) 28 class GenerationResult: 29 request_id: str 30 selected_rpc: str 31 raw_output: str 32 updates: list[TriageUpdate] 33 chunk_count: int 34 time_to_first_update_ms: float | None 35 end_to_end_latency_ms: float 36 output_bytes_received: int 37 finish_reason: str | None 38 grpc_surface: dict[str, Any] 39 40 41 class VllmGrpcClient: 42 def __init__( 43 self, 44 endpoint: str = "localhost:8000", 45 *, 46 timeout: float = 120.0, 47 metadata_factory: MetadataFactory | None = None, 48 surface: dict[str, Any] | None = None, 49 ) -> None: 50 self.endpoint = endpoint 51 self.timeout = timeout 52 self.metadata_factory = metadata_factory 53 self.surface = surface or discover_surface(endpoint=endpoint) 54 self._channel = grpc.insecure_channel( 55 endpoint, 56 options=[ 57 ("grpc.max_send_message_length", -1), 58 ("grpc.max_receive_message_length", -1), 59 ("grpc.keepalive_time_ms", 20000), 60 ("grpc.keepalive_timeout_ms", 10000), 61 ], 62 ) 63 self._stub = vllm_engine_pb2_grpc.VllmEngineStub(self._channel) 64 self._tokenizer: PreTrainedTokenizerBase | None = None 65 self._model_info: dict[str, Any] | None = None 66 67 def close(self) -> None: 68 self._channel.close() 69 70 def _metadata(self) -> list[tuple[str, str]] | 
None: 71 if self.metadata_factory is None: 72 return None 73 return list(self.metadata_factory()) 74 75 def _rpc_error(self, exc: grpc.RpcError, rpc_name: str) -> VllmGrpcError: 76 code = exc.code().name if exc.code() else "UNKNOWN" 77 details = exc.details() or "no details" 78 logger.error("gRPC transport error on %s: %s %s", rpc_name, code, details) 79 return VllmGrpcError(f"{rpc_name} failed with {code}: {details}") 80 81 def health_check(self) -> dict[str, Any]: 82 try: 83 response = self._stub.HealthCheck( 84 vllm_engine_pb2.HealthCheckRequest(), 85 timeout=self.timeout, 86 metadata=self._metadata(), 87 ) 88 except grpc.RpcError as exc: 89 raise self._rpc_error(exc, "HealthCheck") from exc 90 return MessageToDict(response, preserving_proto_field_name=True) 91 92 def get_model_info(self) -> dict[str, Any]: 93 if self._model_info is not None: 94 return self._model_info 95 try: 96 response = self._stub.GetModelInfo( 97 vllm_engine_pb2.GetModelInfoRequest(), 98 timeout=self.timeout, 99 metadata=self._metadata(), 100 ) 101 except grpc.RpcError as exc: 102 raise self._rpc_error(exc, "GetModelInfo") from exc 103 self._model_info = MessageToDict(response, preserving_proto_field_name=True) 104 return self._model_info 105 106 def get_server_info(self) -> dict[str, Any]: 107 try: 108 response = self._stub.GetServerInfo( 109 vllm_engine_pb2.GetServerInfoRequest(), 110 timeout=self.timeout, 111 metadata=self._metadata(), 112 ) 113 except grpc.RpcError as exc: 114 raise self._rpc_error(exc, "GetServerInfo") from exc 115 return MessageToDict(response, preserving_proto_field_name=True) 116 117 def _load_tokenizer(self) -> PreTrainedTokenizerBase: 118 if self._tokenizer is not None: 119 return self._tokenizer 120 model_info = self.get_model_info() 121 model_path = model_info["model_path"] 122 logger.info("Loading tokenizer from discovered model path %s", model_path) 123 self._tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 124 return 
self._tokenizer 125 126 @staticmethod 127 def _generation_schema() -> dict[str, Any]: 128 raw_schema = FinalDecision.model_json_schema() 129 130 def _prune(node: Any) -> Any: 131 if isinstance(node, dict): 132 return { 133 key: _prune(value) 134 for key, value in node.items() 135 if key not in {"title", "default", "examples"} 136 } 137 if isinstance(node, list): 138 return [_prune(value) for value in node] 139 return node 140 141 return _prune(raw_schema) 142 143 @staticmethod 144 def build_prompt(incident: IncidentRequest, strict_schema: bool) -> str: 145 incident_blob = json.dumps(incident.model_dump(mode="json"), indent=2) 146 contract = ( 147 "Return exactly one JSON object and nothing else." 148 if strict_schema 149 else ( 150 "Stream concise updates, then emit the final JSON object between " 151 "BEGIN_FINAL_DECISION_JSON and END_FINAL_DECISION_JSON." 152 ) 153 ) 154 return ( 155 "You are the enterprise Incident Commander for a live operations bridge.\n" 156 "Use concise executive-safe language.\n" 157 "Do not reveal chain-of-thought. 
Do not emit markdown or code fences.\n" 158 f"{contract}\n" 159 "Populate fields in this order: incident_id, executive_summary, severity, " 160 "suspected_root_cause, impacted_assets, confidence, recommended_actions, " 161 "escalation_team, change_risk, machine_json_valid.\n" 162 "Recommended actions must only use allowed_actions as the basis for actions.\n" 163 "Set machine_json_valid to true if your final JSON object is valid.\n" 164 "Incident payload:\n" 165 f"{incident_blob}\n" 166 ) 167 168 @staticmethod 169 def _stage_updates( 170 incident_id: str, 171 accumulated_text: str, 172 elapsed_ms: float, 173 bytes_received: int, 174 emitted_stages: set[str], 175 ) -> list[TriageUpdate]: 176 stage_markers = [ 177 ("executive_summary", "situation_assessment"), 178 ("suspected_root_cause", "probable_root_cause"), 179 ("recommended_actions", "recommended_actions"), 180 ("machine_json_valid", "final_decision_json"), 181 ] 182 updates: list[TriageUpdate] = [] 183 for marker, stage in stage_markers: 184 if marker in accumulated_text and stage not in emitted_stages: 185 emitted_stages.add(stage) 186 updates.append( 187 TriageUpdate( 188 incident_id=incident_id, 189 stage=stage, # type: ignore[arg-type] 190 text=f"Reached {stage} section in streamed JSON payload.", 191 elapsed_ms=elapsed_ms, 192 bytes_received=bytes_received, 193 ) 194 ) 195 return updates 196 197 def _selected_generate_rpc(self) -> str: 198 for service in self.surface["services"]: 199 for method in service["methods"]: 200 if method["name"] == "Generate": 201 return method["full_name"] 202 return "/vllm.grpc.engine.VllmEngine/Generate" 203 204 def generate_incident( 205 self, 206 incident: IncidentRequest, 207 *, 208 stream: bool = True, 209 max_tokens: int = 900, 210 timeout: float | None = None, 211 on_text: Callable[[str], None] | None = None, 212 on_update: Callable[[TriageUpdate], None] | None = None, 213 ) -> GenerationResult: 214 tokenizer = self._load_tokenizer() 215 request_id = 
f"{incident.incident_id}-{uuid.uuid4().hex[:8]}" 216 selected_rpc = self._selected_generate_rpc() 217 prompt = self.build_prompt(incident, strict_schema=True) 218 request = vllm_engine_pb2.GenerateRequest( 219 request_id=request_id, 220 text=prompt, 221 stream=stream, 222 sampling_params=vllm_engine_pb2.SamplingParams( 223 temperature=0.0, 224 top_p=1.0, 225 max_tokens=max_tokens, 226 json_schema=json.dumps(self._generation_schema(), separators=(",", ":")), 227 ), 228 ) 229 230 raw_parts: list[str] = [] 231 updates = [ 232 TriageUpdate( 233 incident_id=incident.incident_id, 234 stage="transport", 235 text=f"Opened gRPC stream via {selected_rpc}", 236 elapsed_ms=0.0, 237 bytes_received=0, 238 ) 239 ] 240 if on_update is not None: 241 on_update(updates[0]) 242 emitted_stages: set[str] = set() 243 bytes_received = 0 244 chunk_count = 0 245 finish_reason: str | None = None 246 started_at = time.perf_counter() 247 first_update_ms: float | None = None 248 249 try: 250 responses = self._stub.Generate( 251 request, 252 timeout=timeout or self.timeout, 253 metadata=self._metadata(), 254 ) 255 for response in responses: 256 bytes_received += response.ByteSize() 257 elapsed_ms = (time.perf_counter() - started_at) * 1000 258 if response.HasField("chunk"): 259 chunk_count += 1 260 if first_update_ms is None: 261 first_update_ms = elapsed_ms 262 decoded = tokenizer.decode( 263 response.chunk.token_ids, 264 skip_special_tokens=False, 265 clean_up_tokenization_spaces=False, 266 ) 267 raw_parts.append(decoded) 268 if on_text is not None: 269 on_text(decoded) 270 accumulated = "".join(raw_parts) 271 new_updates = self._stage_updates( 272 incident_id=incident.incident_id, 273 accumulated_text=accumulated, 274 elapsed_ms=elapsed_ms, 275 bytes_received=bytes_received, 276 emitted_stages=emitted_stages, 277 ) 278 updates.extend(new_updates) 279 if on_update is not None: 280 for update in new_updates: 281 on_update(update) 282 elif response.HasField("complete"): 283 finish_reason = 
response.complete.finish_reason or None 284 except grpc.RpcError as exc: 285 raise self._rpc_error(exc, "Generate") from exc 286 287 return GenerationResult( 288 request_id=request_id, 289 selected_rpc=selected_rpc, 290 raw_output="".join(raw_parts), 291 updates=updates, 292 chunk_count=chunk_count, 293 time_to_first_update_ms=first_update_ms, 294 end_to_end_latency_ms=(time.perf_counter() - started_at) * 1000, 295 output_bytes_received=bytes_received, 296 finish_reason=finish_reason, 297 grpc_surface=self.surface, 298 )