# google_code_assist.py
"""Google Code Assist API client — project discovery, onboarding, quota.

The Code Assist API powers Google's official gemini-cli. It sits at
``cloudcode-pa.googleapis.com`` and provides:

- Free tier access (generous daily quota) for personal Google accounts
- Paid tier access via GCP projects with billing / Workspace / Standard / Enterprise

This module handles the control-plane dance needed before inference:

1. ``load_code_assist()`` — probe the user's account to learn what tier they're on
   and whether a ``cloudaicompanionProject`` is already assigned.
2. ``onboard_user()`` — if the user hasn't been onboarded yet (new account, fresh
   free tier, etc.), call this with the chosen tier + project id. Supports LRO
   polling for slow provisioning.
3. ``retrieve_user_quota()`` — fetch the ``buckets[]`` array showing remaining
   quota per model, used by the ``/gquota`` slash command.

VPC-SC handling: enterprise accounts under a VPC Service Controls perimeter
will get ``SECURITY_POLICY_VIOLATED`` on ``load_code_assist``. We catch this
and force the account to ``standard-tier`` so the call chain still succeeds.

Derived from opencode-gemini-auth (MIT) and clawdbot/extensions/google. The
request/response shapes are specific to Google's internal Code Assist API,
documented nowhere public — we copy them from the reference implementations.
"""

from __future__ import annotations

import json
import logging
import time
import urllib.error
import urllib.parse
import urllib.request
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


# =============================================================================
# Constants
# =============================================================================

CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com"

# Fallback endpoints tried when prod returns an error during project discovery
FALLBACK_ENDPOINTS = [
    "https://daily-cloudcode-pa.sandbox.googleapis.com",
    "https://autopush-cloudcode-pa.sandbox.googleapis.com",
]

# Tier identifiers that Google's API uses
FREE_TIER_ID = "free-tier"
LEGACY_TIER_ID = "legacy-tier"
STANDARD_TIER_ID = "standard-tier"

# Default HTTP headers matching gemini-cli's fingerprint.
# Google may reject unrecognized User-Agents on these internal endpoints.
_GEMINI_CLI_USER_AGENT = "google-api-nodejs-client/9.15.1 (gzip)"
_X_GOOG_API_CLIENT = "gl-node/24.0.0"
_DEFAULT_REQUEST_TIMEOUT = 30.0
_ONBOARDING_POLL_ATTEMPTS = 12
_ONBOARDING_POLL_INTERVAL_SECONDS = 5.0


class CodeAssistError(RuntimeError):
    """Exception raised by the Code Assist (``cloudcode-pa``) integration.

    Carries HTTP status / response / retry-after metadata so the agent's
    ``error_classifier._extract_status_code`` and the main loop's Retry-After
    handling (which walks ``error.response.headers``) pick up the right
    signals. Without these, 429s from the OAuth path look like opaque
    ``RuntimeError`` and skip the rate-limit path.
    """

    def __init__(
        self,
        message: str,
        *,
        code: str = "code_assist_error",
        status_code: Optional[int] = None,
        response: Any = None,
        retry_after: Optional[float] = None,
        details: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(message)
        # Machine-readable error kind, e.g. ``code_assist_http_429``.
        self.code = code
        # ``status_code`` is picked up by ``agent.error_classifier._extract_status_code``
        # so a 429 from Code Assist classifies as FailoverReason.rate_limit and
        # triggers the main loop's fallback_providers chain the same way SDK
        # errors do.
        self.status_code = status_code
        # ``response`` is an ``httpx.Response`` or, for this module's urllib
        # transport, a shim exposing ``.status_code``, ``.headers``, ``.text``
        # and ``.json()``. The main loop reads
        # ``error.response.headers["Retry-After"]`` to honor Google's retry
        # hints when the backend throttles us.
        self.response = response
        # Parsed ``Retry-After`` seconds (kept separately for convenience —
        # Google returns retry hints in both the header and the error body's
        # ``google.rpc.RetryInfo`` details, and we pick whichever we found).
        self.retry_after = retry_after
        # Parsed structured error details from the Google error envelope
        # (e.g. ``{"reason": "MODEL_CAPACITY_EXHAUSTED", "status": "RESOURCE_EXHAUSTED"}``).
        # Useful for logging and for tests that want to assert on specifics.
        self.details = details or {}


class ProjectIdRequiredError(CodeAssistError):
    """Raised when a paid tier is requested without a GCP project id."""

    def __init__(self, message: str = "GCP project id required for this tier") -> None:
        super().__init__(message, code="code_assist_project_id_required")


# =============================================================================
# HTTP primitive (auth via Bearer token passed per-call)
# =============================================================================

def _build_headers(access_token: str, *, user_agent_model: str = "") -> Dict[str, str]:
    """Build request headers mimicking gemini-cli, with a fresh activity id per call."""
    ua = _GEMINI_CLI_USER_AGENT
    if user_agent_model:
        ua = f"{ua} model/{user_agent_model}"
    return {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {access_token}",
        "User-Agent": ua,
        "X-Goog-Api-Client": _X_GOOG_API_CLIENT,
        "x-activity-request-id": str(uuid.uuid4()),
    }


def _client_metadata() -> Dict[str, str]:
    """Match Google's gemini-cli exactly — unrecognized metadata may be rejected."""
    return {
        "ideType": "IDE_UNSPECIFIED",
        "platform": "PLATFORM_UNSPECIFIED",
        "pluginType": "GEMINI",
    }


class _HTTPErrorResponse:
    """Minimal stand-in for an HTTP response attached to ``CodeAssistError``.

    Exposes the attributes downstream error handling reads from an
    ``httpx.Response``: ``status_code``, ``headers``, ``text`` and ``json()``.
    """

    def __init__(self, status_code: int, headers: Any, text: str) -> None:
        self.status_code = status_code
        self.headers = headers
        self.text = text

    def json(self) -> Any:
        return json.loads(self.text)


def _parse_error_envelope(body: str) -> Dict[str, Any]:
    """Extract ``code``/``status``/``message``/``reason``/``retryDelay`` from a Google error body.

    Returns ``{}`` when the body is not a JSON object with an ``error`` dict.
    Only the first ``reason`` and the first ``retryDelay`` found in
    ``error.details[]`` are kept.
    """
    if not body:
        return {}
    try:
        parsed = json.loads(body)
    except ValueError:
        return {}
    error = parsed.get("error") if isinstance(parsed, dict) else None
    if not isinstance(error, dict):
        return {}
    details: Dict[str, Any] = {
        key: error[key] for key in ("code", "status", "message") if error.get(key) is not None
    }
    raw_details = error.get("details")
    for item in raw_details if isinstance(raw_details, list) else []:
        if not isinstance(item, dict):
            continue
        if item.get("reason") and "reason" not in details:
            details["reason"] = item["reason"]
        # google.rpc.RetryInfo encodes the delay as a duration string, e.g. "12s".
        if item.get("retryDelay") and "retryDelay" not in details:
            details["retryDelay"] = item["retryDelay"]
    return details


def _parse_retry_after(header_value: Any, details: Dict[str, Any]) -> Optional[float]:
    """Return a retry delay in seconds, or ``None`` if no usable hint exists.

    Tries the ``Retry-After`` header first (delta-seconds form only; HTTP-date
    values are ignored), then ``RetryInfo.retryDelay`` from the error body
    (``"<seconds>s"``). Negative, NaN and infinite values are rejected.
    """
    for candidate in (header_value, details.get("retryDelay")):
        if candidate is None:
            continue
        text = str(candidate).strip()
        if text.endswith("s"):
            text = text[:-1]
        try:
            seconds = float(text)
        except ValueError:
            continue
        # Chained comparison also rejects NaN (all comparisons are False).
        if 0 <= seconds < float("inf"):
            return seconds
    return None


def _error_from_http_error(exc: urllib.error.HTTPError) -> CodeAssistError:
    """Convert a urllib ``HTTPError`` into a ``CodeAssistError`` with full metadata.

    VPC-SC violations get ``code="code_assist_vpc_sc"``; everything else gets
    ``code="code_assist_http_<status>"``. Both carry ``status_code``,
    ``response`` and parsed ``details``; non-VPC-SC errors also carry
    ``retry_after``.
    """
    try:
        detail = exc.read().decode("utf-8", errors="replace")
    except Exception as read_exc:  # best effort: body may be gone / socket closed
        logger.debug("Could not read Code Assist error body: %s", read_exc)
        detail = ""
    status = exc.code
    headers = exc.headers if exc.headers is not None else {}
    response = _HTTPErrorResponse(status, headers, detail)
    details = _parse_error_envelope(detail)
    # Special case: VPC-SC violation should be distinguishable
    if _is_vpc_sc_violation(detail):
        return CodeAssistError(
            f"VPC-SC policy violation: {detail}",
            code="code_assist_vpc_sc",
            status_code=status,
            response=response,
            details=details,
        )
    return CodeAssistError(
        f"Code Assist HTTP {exc.code}: {detail or exc.reason}",
        code=f"code_assist_http_{exc.code}",
        status_code=status,
        response=response,
        retry_after=_parse_retry_after(headers.get("Retry-After"), details),
        details=details,
    )


def _post_json(
    url: str,
    body: Dict[str, Any],
    access_token: str,
    *,
    timeout: float = _DEFAULT_REQUEST_TIMEOUT,
    user_agent_model: str = "",
) -> Dict[str, Any]:
    """POST ``body`` as JSON to ``url`` and return the decoded JSON object.

    An empty (or whitespace-only) response body yields ``{}``.

    Raises:
        CodeAssistError: ``code_assist_http_<status>`` / ``code_assist_vpc_sc``
            on HTTP errors (with ``status_code``, ``response``, ``retry_after``
            populated); ``code_assist_network_error`` on transport failures,
            including timeouts while reading; ``code_assist_invalid_response``
            when the body is not a JSON object.
    """
    data = json.dumps(body).encode("utf-8")
    request = urllib.request.Request(
        url,
        data=data,
        method="POST",
        headers=_build_headers(access_token, user_agent_model=user_agent_model),
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            raw = response.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as exc:
        raise _error_from_http_error(exc) from exc
    except OSError as exc:
        # URLError (DNS / connect failures) is an OSError subclass; a timeout
        # or connection reset while reading the body surfaces as a bare
        # OSError such as TimeoutError and must not escape untyped.
        raise CodeAssistError(
            f"Code Assist request failed: {exc}",
            code="code_assist_network_error",
        ) from exc

    if not raw.strip():
        return {}
    try:
        parsed = json.loads(raw)
    except ValueError as exc:
        raise CodeAssistError(
            f"Code Assist returned a non-JSON response: {raw[:200]}",
            code="code_assist_invalid_response",
        ) from exc
    if not isinstance(parsed, dict):
        raise CodeAssistError(
            f"Code Assist returned unexpected JSON ({type(parsed).__name__}); expected an object",
            code="code_assist_invalid_response",
        )
    return parsed


def _is_vpc_sc_violation(body: str) -> bool:
    """Detect a VPC Service Controls violation from a response body."""
    if not body:
        return False
    try:
        parsed = json.loads(body)
    except (json.JSONDecodeError, ValueError):
        return "SECURITY_POLICY_VIOLATED" in body
    # Walk the nested error structure Google uses
    error = parsed.get("error") if isinstance(parsed, dict) else None
    if not isinstance(error, dict):
        return False
    details = error.get("details") or []
    if isinstance(details, list):
        for item in details:
            if isinstance(item, dict):
                reason = item.get("reason") or ""
                if reason == "SECURITY_POLICY_VIOLATED":
                    return True
    msg = str(error.get("message", ""))
    return "SECURITY_POLICY_VIOLATED" in msg


# =============================================================================
# load_code_assist — discovers current tier + assigned project
# =============================================================================

@dataclass
class CodeAssistProjectInfo:
    """Result from ``load_code_assist``."""
    current_tier_id: str = ""
    cloudaicompanion_project: str = ""  # Google-managed project (free tier)
    allowed_tiers: List[str] = field(default_factory=list)
    raw: Dict[str, Any] = field(default_factory=dict)


def load_code_assist(
    access_token: str,
    *,
    project_id: str = "",
    user_agent_model: str = "",
) -> CodeAssistProjectInfo:
    """Call ``POST /v1internal:loadCodeAssist`` with prod → sandbox fallback.

    Returns whatever tier + project info Google reports. On VPC-SC violations,
    returns a synthetic ``standard-tier`` result so the chain can continue.
    Any other ``CodeAssistError`` moves on to the next endpoint; if every
    endpoint fails, the last error is re-raised.
    """
    body: Dict[str, Any] = {
        "metadata": {
            "duetProject": project_id,
            **_client_metadata(),
        },
    }
    if project_id:
        body["cloudaicompanionProject"] = project_id

    endpoints = [CODE_ASSIST_ENDPOINT] + FALLBACK_ENDPOINTS
    last_err: Optional[Exception] = None
    for endpoint in endpoints:
        url = f"{endpoint}/v1internal:loadCodeAssist"
        try:
            resp = _post_json(url, body, access_token, user_agent_model=user_agent_model)
            return _parse_load_response(resp)
        except CodeAssistError as exc:
            if exc.code == "code_assist_vpc_sc":
                logger.info("VPC-SC violation on %s — defaulting to standard-tier", endpoint)
                return CodeAssistProjectInfo(
                    current_tier_id=STANDARD_TIER_ID,
                    cloudaicompanion_project=project_id,
                )
            last_err = exc
            logger.warning("loadCodeAssist failed on %s: %s", endpoint, exc)
            continue
    if last_err:
        raise last_err
    return CodeAssistProjectInfo()


def _extract_project_id(value: Any) -> str:
    """Normalize a ``cloudaicompanionProject`` value to a bare project id.

    Accepts a plain string or an object carrying the id under ``"id"``
    (the shape gemini-cli reads from onboardUser responses).
    """
    if isinstance(value, dict):
        value = value.get("id")
    return str(value or "")


def _parse_load_response(resp: Dict[str, Any]) -> CodeAssistProjectInfo:
    """Convert a raw ``loadCodeAssist`` response into ``CodeAssistProjectInfo``.

    Malformed pieces are ignored: a non-dict response is treated as empty,
    and ``allowedTiers`` entries without a truthy ``id`` are skipped.
    """
    if not isinstance(resp, dict):
        resp = {}
    current_tier = resp.get("currentTier")
    tier_id = str(current_tier.get("id") or "") if isinstance(current_tier, dict) else ""
    allowed = resp.get("allowedTiers")
    allowed_ids = [
        str(tier.get("id"))
        for tier in (allowed if isinstance(allowed, list) else [])
        if isinstance(tier, dict) and tier.get("id")
    ]
    return CodeAssistProjectInfo(
        current_tier_id=tier_id,
        cloudaicompanion_project=_extract_project_id(resp.get("cloudaicompanionProject")),
        allowed_tiers=allowed_ids,
        raw=resp,
    )


# =============================================================================
# onboard_user — provisions a new user on a tier (with
# LRO polling)
# =============================================================================

def onboard_user(
    access_token: str,
    *,
    tier_id: str,
    project_id: str = "",
    user_agent_model: str = "",
) -> Dict[str, Any]:
    """Provision the user on ``tier_id`` via ``POST /v1internal:onboardUser``.

    Any tier other than free/legacy needs ``project_id``; calling without one
    raises ``ProjectIdRequiredError``. Free tiers may omit it, in which case
    no ``cloudaicompanionProject`` is sent.

    If the first reply is an unfinished long-running operation with a
    ``name``, ``/v1internal/<name>`` is polled up to
    ``_ONBOARDING_POLL_ATTEMPTS`` times, sleeping
    ``_ONBOARDING_POLL_INTERVAL_SECONDS`` before each poll (12 × 5s = 1 min by
    default). Poll failures are logged and retried. The first poll reply with
    ``done`` set is returned. Otherwise the initial reply is returned: when it
    is already done, has no operation name, or polling runs out.
    """
    if tier_id not in (FREE_TIER_ID, LEGACY_TIER_ID) and not project_id:
        raise ProjectIdRequiredError(
            f"Tier {tier_id!r} requires a GCP project id. "
            "Set HERMES_GEMINI_PROJECT_ID or GOOGLE_CLOUD_PROJECT."
        )

    payload: Dict[str, Any] = {"tierId": tier_id, "metadata": _client_metadata()}
    if project_id:
        payload["cloudaicompanionProject"] = project_id

    base = CODE_ASSIST_ENDPOINT
    initial = _post_json(
        f"{base}/v1internal:onboardUser",
        payload,
        access_token,
        user_agent_model=user_agent_model,
    )

    # Finished synchronously, or no operation handle to poll: nothing to wait for.
    if initial.get("done"):
        return initial
    operation = initial.get("name", "")
    if not operation:
        return initial

    poll_url = f"{base}/v1internal/{operation}"
    for attempt_no in range(1, _ONBOARDING_POLL_ATTEMPTS + 1):
        time.sleep(_ONBOARDING_POLL_INTERVAL_SECONDS)
        try:
            status = _post_json(poll_url, {}, access_token, user_agent_model=user_agent_model)
        except CodeAssistError as exc:
            logger.warning("Onboarding poll attempt %d failed: %s", attempt_no, exc)
            continue
        if status.get("done"):
            return status

    logger.warning("Onboarding did not complete within %d attempts", _ONBOARDING_POLL_ATTEMPTS)
    return initial
# =============================================================================
# retrieve_user_quota — for /gquota
# =============================================================================

@dataclass
class QuotaBucket:
    """One entry of the ``buckets[]`` array returned by ``retrieveUserQuota``."""
    model_id: str
    token_type: str = ""
    remaining_fraction: float = 0.0
    reset_time_iso: str = ""
    raw: Dict[str, Any] = field(default_factory=dict)


def _coerce_fraction(value: Any) -> float:
    """Return ``value`` as a float, treating missing or unparseable values as ``0.0``."""
    try:
        return float(value or 0.0)
    except (TypeError, ValueError):
        return 0.0


def _parse_quota_buckets(resp: Any) -> List[QuotaBucket]:
    """Turn a ``retrieveUserQuota`` response into ``QuotaBucket`` objects.

    Non-dict responses, a non-list ``buckets`` field and non-dict entries are
    ignored rather than raising.
    """
    raw_buckets = resp.get("buckets") if isinstance(resp, dict) else None
    if not isinstance(raw_buckets, list):
        return []
    return [
        QuotaBucket(
            model_id=str(b.get("modelId") or ""),
            token_type=str(b.get("tokenType") or ""),
            remaining_fraction=_coerce_fraction(b.get("remainingFraction")),
            reset_time_iso=str(b.get("resetTime") or ""),
            raw=b,
        )
        for b in raw_buckets
        if isinstance(b, dict)
    ]


def retrieve_user_quota(
    access_token: str,
    *,
    project_id: str = "",
    user_agent_model: str = "",
) -> List[QuotaBucket]:
    """Call ``POST /v1internal:retrieveUserQuota`` and parse ``buckets[]``."""
    body: Dict[str, Any] = {}
    if project_id:
        body["project"] = project_id
    url = f"{CODE_ASSIST_ENDPOINT}/v1internal:retrieveUserQuota"
    resp = _post_json(url, body, access_token, user_agent_model=user_agent_model)
    return _parse_quota_buckets(resp)


# =============================================================================
# Project context resolution
# =============================================================================

@dataclass
class ProjectContext:
    """Resolved state for a given OAuth session."""
    project_id: str = ""  # effective project id sent on requests
    managed_project_id: str = ""  # Google-assigned project (free tier)
    tier_id: str = ""
    source: str = ""  # "env", "config", "discovered", "onboarded"


def _project_id_from_onboard_response(onboard_resp: Any) -> str:
    """Pull the assigned project id out of an onboardUser operation result.

    ``response.cloudaicompanionProject`` is read as either an object with an
    ``"id"`` (the shape gemini-cli consumes) or a bare string. Returns ``""``
    when it is absent, e.g. when onboarding did not finish.
    """
    if not isinstance(onboard_resp, dict):
        return ""
    response_body = onboard_resp.get("response") or {}
    if not isinstance(response_body, dict):
        return ""
    project = response_body.get("cloudaicompanionProject")
    if isinstance(project, dict):
        project = project.get("id")
    return str(project or "")


def resolve_project_context(
    access_token: str,
    *,
    configured_project_id: str = "",
    env_project_id: str = "",
    user_agent_model: str = "",
) -> ProjectContext:
    """Figure out what project id + tier to use for requests.

    Priority:
      1. If configured_project_id or env_project_id is set, use that directly
         and short-circuit (no discovery needed). The tier is assumed to be
         ``standard-tier``.
      2. Otherwise call loadCodeAssist to see what Google says.
      3. If no tier assigned yet, onboard the user on the free tier and take
         the project id from the onboarding result when discovery had none.
    """
    # Short-circuit: caller provided a project id
    if configured_project_id:
        return ProjectContext(
            project_id=configured_project_id,
            tier_id=STANDARD_TIER_ID,  # assume paid since they specified one
            source="config",
        )
    if env_project_id:
        return ProjectContext(
            project_id=env_project_id,
            tier_id=STANDARD_TIER_ID,
            source="env",
        )

    # Discover via loadCodeAssist
    info = load_code_assist(access_token, user_agent_model=user_agent_model)

    effective_project = info.cloudaicompanion_project
    tier = info.current_tier_id

    if not tier:
        # User hasn't been onboarded — provision them on free tier
        onboard_resp = onboard_user(
            access_token,
            tier_id=FREE_TIER_ID,
            project_id="",
            user_agent_model=user_agent_model,
        )
        # The onboarding result nests the project as {"id": ..., "name": ...};
        # stringifying that object would produce a bogus project id.
        effective_project = effective_project or _project_id_from_onboard_response(onboard_resp)
        tier = FREE_TIER_ID
        source = "onboarded"
    else:
        source = "discovered"

    return ProjectContext(
        project_id=effective_project,
        managed_project_id=effective_project if tier == FREE_TIER_ID else "",
        tier_id=tier,
        source=source,
    )