/ agent / google_code_assist.py
google_code_assist.py
  1  """Google Code Assist API client — project discovery, onboarding, quota.
  2  
  3  The Code Assist API powers Google's official gemini-cli. It sits at
  4  ``cloudcode-pa.googleapis.com`` and provides:
  5  
  6  - Free tier access (generous daily quota) for personal Google accounts
  7  - Paid tier access via GCP projects with billing / Workspace / Standard / Enterprise
  8  
  9  This module handles the control-plane dance needed before inference:
 10  
 11  1. ``load_code_assist()`` — probe the user's account to learn what tier they're on
 12     and whether a ``cloudaicompanionProject`` is already assigned.
 13  2. ``onboard_user()`` — if the user hasn't been onboarded yet (new account, fresh
 14     free tier, etc.), call this with the chosen tier + project id. Supports LRO
 15     polling for slow provisioning.
 16  3. ``retrieve_user_quota()`` — fetch the ``buckets[]`` array showing remaining
 17     quota per model, used by the ``/gquota`` slash command.
 18  
 19  VPC-SC handling: enterprise accounts under a VPC Service Controls perimeter
 20  will get ``SECURITY_POLICY_VIOLATED`` on ``load_code_assist``. We catch this
 21  and force the account to ``standard-tier`` so the call chain still succeeds.
 22  
 23  Derived from opencode-gemini-auth (MIT) and clawdbot/extensions/google. The
 24  request/response shapes are specific to Google's internal Code Assist API,
 25  documented nowhere public — we copy them from the reference implementations.
 26  """
 27  
 28  from __future__ import annotations
 29  
 30  import json
 31  import logging
 32  import time
 33  import urllib.error
 34  import urllib.parse
 35  import urllib.request
 36  import uuid
 37  from dataclasses import dataclass, field
 38  from typing import Any, Dict, List, Optional
 39  
 40  logger = logging.getLogger(__name__)
 41  
 42  
 43  # =============================================================================
 44  # Constants
 45  # =============================================================================
 46  
 47  CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com"
 48  
 49  # Fallback endpoints tried when prod returns an error during project discovery
 50  FALLBACK_ENDPOINTS = [
 51      "https://daily-cloudcode-pa.sandbox.googleapis.com",
 52      "https://autopush-cloudcode-pa.sandbox.googleapis.com",
 53  ]
 54  
 55  # Tier identifiers that Google's API uses
 56  FREE_TIER_ID = "free-tier"
 57  LEGACY_TIER_ID = "legacy-tier"
 58  STANDARD_TIER_ID = "standard-tier"
 59  
 60  # Default HTTP headers matching gemini-cli's fingerprint.
 61  # Google may reject unrecognized User-Agents on these internal endpoints.
 62  _GEMINI_CLI_USER_AGENT = "google-api-nodejs-client/9.15.1 (gzip)"
 63  _X_GOOG_API_CLIENT = "gl-node/24.0.0"
 64  _DEFAULT_REQUEST_TIMEOUT = 30.0
 65  _ONBOARDING_POLL_ATTEMPTS = 12
 66  _ONBOARDING_POLL_INTERVAL_SECONDS = 5.0
 67  
 68  
 69  class CodeAssistError(RuntimeError):
 70      """Exception raised by the Code Assist (``cloudcode-pa``) integration.
 71  
 72      Carries HTTP status / response / retry-after metadata so the agent's
 73      ``error_classifier._extract_status_code`` and the main loop's Retry-After
 74      handling (which walks ``error.response.headers``) pick up the right
 75      signals.  Without these, 429s from the OAuth path look like opaque
 76      ``RuntimeError`` and skip the rate-limit path.
 77      """
 78  
 79      def __init__(
 80          self,
 81          message: str,
 82          *,
 83          code: str = "code_assist_error",
 84          status_code: Optional[int] = None,
 85          response: Any = None,
 86          retry_after: Optional[float] = None,
 87          details: Optional[Dict[str, Any]] = None,
 88      ) -> None:
 89          super().__init__(message)
 90          self.code = code
 91          # ``status_code`` is picked up by ``agent.error_classifier._extract_status_code``
 92          # so a 429 from Code Assist classifies as FailoverReason.rate_limit and
 93          # triggers the main loop's fallback_providers chain the same way SDK
 94          # errors do.
 95          self.status_code = status_code
 96          # ``response`` is the underlying ``httpx.Response`` (or a shim with a
 97          # ``.headers`` mapping and ``.json()`` method).  The main loop reads
 98          # ``error.response.headers["Retry-After"]`` to honor Google's retry
 99          # hints when the backend throttles us.
100          self.response = response
101          # Parsed ``Retry-After`` seconds (kept separately for convenience —
102          # Google returns retry hints in both the header and the error body's
103          # ``google.rpc.RetryInfo`` details, and we pick whichever we found).
104          self.retry_after = retry_after
105          # Parsed structured error details from the Google error envelope
106          # (e.g. ``{"reason": "MODEL_CAPACITY_EXHAUSTED", "status": "RESOURCE_EXHAUSTED"}``).
107          # Useful for logging and for tests that want to assert on specifics.
108          self.details = details or {}
109  
110  
111  class ProjectIdRequiredError(CodeAssistError):
112      def __init__(self, message: str = "GCP project id required for this tier") -> None:
113          super().__init__(message, code="code_assist_project_id_required")
114  
115  
116  # =============================================================================
117  # HTTP primitive (auth via Bearer token passed per-call)
118  # =============================================================================
119  
def _build_headers(access_token: str, *, user_agent_model: str = "") -> Dict[str, str]:
    """Assemble the HTTP headers for a single Code Assist request.

    The User-Agent imitates gemini-cli. When ``user_agent_model`` is given it
    gains a ``model/<name>`` suffix. A new random ``x-activity-request-id`` is
    generated on every call.
    """
    model_suffix = f" model/{user_agent_model}" if user_agent_model else ""
    headers: Dict[str, str] = {}
    headers["Content-Type"] = "application/json"
    headers["Accept"] = "application/json"
    headers["Authorization"] = f"Bearer {access_token}"
    headers["User-Agent"] = f"{_GEMINI_CLI_USER_AGENT}{model_suffix}"
    headers["X-Goog-Api-Client"] = _X_GOOG_API_CLIENT
    headers["x-activity-request-id"] = str(uuid.uuid4())
    return headers
132  
133  
134  def _client_metadata() -> Dict[str, str]:
135      """Match Google's gemini-cli exactly — unrecognized metadata may be rejected."""
136      return {
137          "ideType": "IDE_UNSPECIFIED",
138          "platform": "PLATFORM_UNSPECIFIED",
139          "pluginType": "GEMINI",
140      }
141  
142  
class _HTTPErrorResponse:
    """Minimal ``httpx.Response``-style view of a failed urllib request.

    ``CodeAssistError.response`` is expected to expose ``.headers`` (the
    agent reads ``Retry-After`` from it) and ``.json()``. ``HTTPError`` has no
    ``.json()``, and its body stream is consumed when ``_post_json`` reads it,
    so the status, headers and body text are snapshotted here.
    """

    def __init__(self, status_code: int, headers: Any, text: str) -> None:
        self.status_code = status_code
        # urllib hands us an ``http.client.HTTPMessage`` (case-insensitive
        # mapping). Fall back to an empty dict if no headers were attached.
        self.headers = headers if headers is not None else {}
        self.text = text

    def json(self) -> Any:
        """Decode the captured error body; raises ``ValueError`` if it isn't JSON."""
        return json.loads(self.text)


def _parse_retry_seconds(value: Any) -> Optional[float]:
    """Parse a non-negative, finite seconds value; ``None`` if unusable.

    Only the delta-seconds form of ``Retry-After`` is understood. The
    HTTP-date form yields ``None``.
    """
    if value is None:
        return None
    try:
        seconds = float(str(value).strip())
    except ValueError:
        return None
    # One comparison rejects negatives, NaN (all comparisons False) and inf.
    if not 0.0 <= seconds < float("inf"):
        return None
    return seconds


def _error_object(body: str) -> Optional[Dict[str, Any]]:
    """Return the ``error`` object of a Google JSON error envelope, if present."""
    if not body:
        return None
    try:
        parsed = json.loads(body)
    except ValueError:
        return None
    error = parsed.get("error") if isinstance(parsed, dict) else None
    return error if isinstance(error, dict) else None


def _error_detail_items(error: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return the dict entries of ``error["details"]`` (empty if malformed)."""
    details = error.get("details")
    if not isinstance(details, list):
        return []
    return [item for item in details if isinstance(item, dict)]


def _retry_delay_from_error_body(body: str) -> Optional[float]:
    """Read ``google.rpc.RetryInfo.retryDelay`` (e.g. ``"12.5s"``) from an error body."""
    error = _error_object(body)
    if error is None:
        return None
    for item in _error_detail_items(error):
        if not str(item.get("@type", "")).endswith("google.rpc.RetryInfo"):
            continue
        delay = item.get("retryDelay")
        # proto3 JSON encodes a Duration as a decimal string ending in "s".
        if isinstance(delay, str) and delay.endswith("s"):
            return _parse_retry_seconds(delay[:-1])
    return None


def _error_details_from_body(body: str) -> Dict[str, Any]:
    """Flatten a Google error envelope into ``{code, status, message, reason}``.

    Only keys actually present are included. ``reason`` comes from the first
    detail entry that carries one. Returns ``{}`` for non-envelope bodies.
    """
    error = _error_object(body)
    if error is None:
        return {}
    summary: Dict[str, Any] = {
        key: error[key] for key in ("code", "status", "message") if key in error
    }
    for item in _error_detail_items(error):
        reason = item.get("reason")
        if reason:
            summary["reason"] = reason
            break
    return summary


def _post_json(
    url: str,
    body: Dict[str, Any],
    access_token: str,
    *,
    timeout: float = _DEFAULT_REQUEST_TIMEOUT,
    user_agent_model: str = "",
) -> Dict[str, Any]:
    """POST ``body`` as JSON to ``url`` with Bearer auth; return the decoded reply.

    An empty 2xx body decodes to ``{}``.

    Raises:
        CodeAssistError: with ``code`` set to one of the following:
            * ``"code_assist_vpc_sc"`` when an HTTP error body reports
              ``SECURITY_POLICY_VIOLATED``;
            * ``"code_assist_http_<status>"`` for any other HTTP error;
            * ``"code_assist_network_error"`` for transport failures,
              including timeouts raised while waiting for or reading the reply;
            * ``"code_assist_invalid_json"`` when a 2xx body is not a JSON object.
            HTTP errors also carry ``status_code``, ``response`` (a headers and
            body snapshot), ``retry_after`` and ``details``, so rate limits are
            classified and Retry-After hints are honored.
    """
    data = json.dumps(body).encode("utf-8")
    request = urllib.request.Request(
        url, data=data, method="POST",
        headers=_build_headers(access_token, user_agent_model=user_agent_model),
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            raw = response.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as exc:
        detail = ""
        try:
            detail = exc.read().decode("utf-8", errors="replace")
        except Exception:
            # Best effort: the body only enriches the error raised below.
            pass
        headers = getattr(exc, "headers", None)
        # Prefer the Retry-After header; fall back to google.rpc.RetryInfo.
        retry_after = _parse_retry_seconds(
            headers.get("Retry-After") if headers is not None else None
        )
        if retry_after is None:
            retry_after = _retry_delay_from_error_body(detail)
        error_kwargs: Dict[str, Any] = {
            "status_code": exc.code,
            "response": _HTTPErrorResponse(exc.code, headers, detail),
            "retry_after": retry_after,
            "details": _error_details_from_body(detail),
        }
        # Special case: VPC-SC violation should be distinguishable
        if _is_vpc_sc_violation(detail):
            raise CodeAssistError(
                f"VPC-SC policy violation: {detail}",
                code="code_assist_vpc_sc",
                **error_kwargs,
            ) from exc
        raise CodeAssistError(
            f"Code Assist HTTP {exc.code}: {detail or exc.reason}",
            code=f"code_assist_http_{exc.code}",
            **error_kwargs,
        ) from exc
    except urllib.error.URLError as exc:
        raise CodeAssistError(
            f"Code Assist request failed: {exc}",
            code="code_assist_network_error",
        ) from exc
    except OSError as exc:
        # Timeouts/resets raised by getresponse() or read() (socket.timeout,
        # ConnectionResetError, ...) are not wrapped in URLError by urllib.
        # Without this they would escape the callers' CodeAssistError handling.
        raise CodeAssistError(
            f"Code Assist request failed: {exc}",
            code="code_assist_network_error",
        ) from exc

    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except ValueError as exc:
        raise CodeAssistError(
            f"Code Assist returned invalid JSON: {raw[:200]}",
            code="code_assist_invalid_json",
        ) from exc
    if not isinstance(parsed, dict):
        # Every caller does ``resp.get(...)``; a non-object reply is unusable.
        raise CodeAssistError(
            f"Code Assist returned a non-object JSON body: {raw[:200]}",
            code="code_assist_invalid_json",
        )
    return parsed
181  
182  
183  def _is_vpc_sc_violation(body: str) -> bool:
184      """Detect a VPC Service Controls violation from a response body."""
185      if not body:
186          return False
187      try:
188          parsed = json.loads(body)
189      except (json.JSONDecodeError, ValueError):
190          return "SECURITY_POLICY_VIOLATED" in body
191      # Walk the nested error structure Google uses
192      error = parsed.get("error") if isinstance(parsed, dict) else None
193      if not isinstance(error, dict):
194          return False
195      details = error.get("details") or []
196      if isinstance(details, list):
197          for item in details:
198              if isinstance(item, dict):
199                  reason = item.get("reason") or ""
200                  if reason == "SECURITY_POLICY_VIOLATED":
201                      return True
202      msg = str(error.get("message", ""))
203      return "SECURITY_POLICY_VIOLATED" in msg
204  
205  
206  # =============================================================================
207  # load_code_assist — discovers current tier + assigned project
208  # =============================================================================
209  
@dataclass
class CodeAssistProjectInfo:
    """Parsed outcome of a ``load_code_assist`` call.

    Every field defaults to empty, which is also what ``load_code_assist``
    returns when it has nothing to report.
    """

    # Id of the tier Google says the account is on (e.g. ``free-tier``).
    current_tier_id: str = ""
    # Google-managed project (free tier).
    cloudaicompanion_project: str = ""
    # Ids from the response's ``allowedTiers`` list.
    allowed_tiers: List[str] = field(default_factory=list)
    # The untouched JSON response, kept for debugging.
    raw: Dict[str, Any] = field(default_factory=dict)
217  
218  
def load_code_assist(
    access_token: str,
    *,
    project_id: str = "",
    user_agent_model: str = "",
) -> CodeAssistProjectInfo:
    """Probe ``POST /v1internal:loadCodeAssist``, trying prod then each sandbox.

    The first endpoint that answers wins, and its response is parsed into a
    ``CodeAssistProjectInfo``. A VPC-SC violation on any endpoint
    short-circuits with a synthetic ``standard-tier`` result that carries the
    caller's ``project_id``. Other ``CodeAssistError``s are logged and the next
    endpoint is tried. When every endpoint fails, the last error is re-raised.
    """
    metadata: Dict[str, Any] = {"duetProject": project_id}
    metadata.update(_client_metadata())
    request_body: Dict[str, Any] = {"metadata": metadata}
    if project_id:
        request_body["cloudaicompanionProject"] = project_id

    failure: Optional[Exception] = None
    for base_url in (CODE_ASSIST_ENDPOINT, *FALLBACK_ENDPOINTS):
        try:
            payload = _post_json(
                f"{base_url}/v1internal:loadCodeAssist",
                request_body,
                access_token,
                user_agent_model=user_agent_model,
            )
        except CodeAssistError as exc:
            if exc.code != "code_assist_vpc_sc":
                failure = exc
                logger.warning("loadCodeAssist failed on %s: %s", base_url, exc)
                continue
            logger.info("VPC-SC violation on %s — defaulting to standard-tier", base_url)
            return CodeAssistProjectInfo(
                current_tier_id=STANDARD_TIER_ID,
                cloudaicompanion_project=project_id,
            )
        return _parse_load_response(payload)

    if failure:
        raise failure
    return CodeAssistProjectInfo()
259  
260  
def _parse_load_response(resp: Dict[str, Any]) -> CodeAssistProjectInfo:
    """Convert a raw ``loadCodeAssist`` JSON reply into ``CodeAssistProjectInfo``.

    Missing or malformed fields become empty values. Tier entries that are not
    objects, or whose ``id`` is empty, are skipped.
    """
    tier_obj = resp.get("currentTier") or {}
    tier_id = str(tier_obj.get("id") or "") if isinstance(tier_obj, dict) else ""
    project = str(resp.get("cloudaicompanionProject") or "")

    tier_entries = resp.get("allowedTiers") or []
    allowed_ids: List[str] = []
    if isinstance(tier_entries, list):
        candidate_ids = (
            str(entry.get("id") or "") for entry in tier_entries if isinstance(entry, dict)
        )
        allowed_ids = [candidate for candidate in candidate_ids if candidate]

    return CodeAssistProjectInfo(
        current_tier_id=tier_id,
        cloudaicompanion_project=project,
        allowed_tiers=allowed_ids,
        raw=resp,
    )
279  
280  
281  # =============================================================================
282  # onboard_user — provisions a new user on a tier (with LRO polling)
283  # =============================================================================
284  
def onboard_user(
    access_token: str,
    *,
    tier_id: str,
    project_id: str = "",
    user_agent_model: str = "",
) -> Dict[str, Any]:
    """Provision the user on ``tier_id`` via ``POST /v1internal:onboardUser``.

    Tiers other than free/legacy need ``project_id``; without it a
    ``ProjectIdRequiredError`` is raised before any request is made. On the
    free tier the id is optional and Google assigns a project.

    If the reply is an unfinished long-running operation with a ``name``, the
    operation is polled at ``/v1internal/<name>``. There are up to
    ``_ONBOARDING_POLL_ATTEMPTS`` polls, spaced
    ``_ONBOARDING_POLL_INTERVAL_SECONDS`` apart (12 × 5 s by default). Failed
    polls are logged and retried. The first poll reporting ``done`` is
    returned; otherwise the initial reply is returned.
    """
    if tier_id not in (FREE_TIER_ID, LEGACY_TIER_ID) and not project_id:
        raise ProjectIdRequiredError(
            f"Tier {tier_id!r} requires a GCP project id. "
            "Set HERMES_GEMINI_PROJECT_ID or GOOGLE_CLOUD_PROJECT."
        )

    payload: Dict[str, Any] = {"tierId": tier_id, "metadata": _client_metadata()}
    if project_id:
        payload["cloudaicompanionProject"] = project_id

    base_url = CODE_ASSIST_ENDPOINT
    initial = _post_json(
        f"{base_url}/v1internal:onboardUser",
        payload,
        access_token,
        user_agent_model=user_agent_model,
    )

    # Already finished, or not a pollable long-running operation.
    if initial.get("done"):
        return initial
    operation_name = initial.get("name", "")
    if not operation_name:
        return initial

    poll_url = f"{base_url}/v1internal/{operation_name}"
    for attempt_no in range(1, _ONBOARDING_POLL_ATTEMPTS + 1):
        time.sleep(_ONBOARDING_POLL_INTERVAL_SECONDS)
        try:
            polled = _post_json(poll_url, {}, access_token, user_agent_model=user_agent_model)
        except CodeAssistError as exc:
            logger.warning("Onboarding poll attempt %d failed: %s", attempt_no, exc)
            continue
        if polled.get("done"):
            return polled

    logger.warning("Onboarding did not complete within %d attempts", _ONBOARDING_POLL_ATTEMPTS)
    return initial
335  
336  
337  # =============================================================================
338  # retrieve_user_quota — for /gquota
339  # =============================================================================
340  
@dataclass
class QuotaBucket:
    """One entry of the ``buckets[]`` array returned by ``retrieveUserQuota``."""

    model_id: str
    token_type: str = ""
    remaining_fraction: float = 0.0
    reset_time_iso: str = ""
    # The bucket's original JSON object.
    raw: Dict[str, Any] = field(default_factory=dict)


def _parse_quota_buckets(resp: Dict[str, Any]) -> List[QuotaBucket]:
    """Turn a ``retrieveUserQuota`` reply into ``QuotaBucket``s, skipping non-objects."""
    entries = resp.get("buckets") or []
    if not isinstance(entries, list):
        return []
    return [
        QuotaBucket(
            model_id=str(entry.get("modelId") or ""),
            token_type=str(entry.get("tokenType") or ""),
            remaining_fraction=float(entry.get("remainingFraction") or 0.0),
            reset_time_iso=str(entry.get("resetTime") or ""),
            raw=entry,
        )
        for entry in entries
        if isinstance(entry, dict)
    ]


def retrieve_user_quota(
    access_token: str,
    *,
    project_id: str = "",
    user_agent_model: str = "",
) -> List[QuotaBucket]:
    """Fetch per-model quota via ``POST /v1internal:retrieveUserQuota``.

    ``project_id``, when non-empty, is sent as the request's ``project``.
    Missing fields in a bucket become empty strings or ``0.0``.
    """
    request_body: Dict[str, Any] = {"project": project_id} if project_id else {}
    reply = _post_json(
        f"{CODE_ASSIST_ENDPOINT}/v1internal:retrieveUserQuota",
        request_body,
        access_token,
        user_agent_model=user_agent_model,
    )
    return _parse_quota_buckets(reply)
377  
378  
379  # =============================================================================
380  # Project context resolution
381  # =============================================================================
382  
@dataclass
class ProjectContext:
    """Project id and tier resolved for one OAuth session.

    ``source`` records how the values were obtained: ``"env"``, ``"config"``,
    ``"discovered"`` or ``"onboarded"``.
    """

    # Effective project id sent on requests.
    project_id: str = ""
    # Google-assigned project (free tier).
    managed_project_id: str = ""
    tier_id: str = ""
    source: str = ""
390  
391  
392  def resolve_project_context(
393      access_token: str,
394      *,
395      configured_project_id: str = "",
396      env_project_id: str = "",
397      user_agent_model: str = "",
398  ) -> ProjectContext:
399      """Figure out what project id + tier to use for requests.
400  
401      Priority:
402        1. If configured_project_id or env_project_id is set, use that directly
403           and short-circuit (no discovery needed).
404        2. Otherwise call loadCodeAssist to see what Google says.
405        3. If no tier assigned yet, onboard the user (free tier default).
406      """
407      # Short-circuit: caller provided a project id
408      if configured_project_id:
409          return ProjectContext(
410              project_id=configured_project_id,
411              tier_id=STANDARD_TIER_ID,  # assume paid since they specified one
412              source="config",
413          )
414      if env_project_id:
415          return ProjectContext(
416              project_id=env_project_id,
417              tier_id=STANDARD_TIER_ID,
418              source="env",
419          )
420  
421      # Discover via loadCodeAssist
422      info = load_code_assist(access_token, user_agent_model=user_agent_model)
423  
424      effective_project = info.cloudaicompanion_project
425      tier = info.current_tier_id
426  
427      if not tier:
428          # User hasn't been onboarded — provision them on free tier
429          onboard_resp = onboard_user(
430              access_token,
431              tier_id=FREE_TIER_ID,
432              project_id="",
433              user_agent_model=user_agent_model,
434          )
435          # Re-parse from the onboard response
436          response_body = onboard_resp.get("response") or {}
437          if isinstance(response_body, dict):
438              effective_project = (
439                  effective_project
440                  or str(response_body.get("cloudaicompanionProject") or "")
441              )
442          tier = FREE_TIER_ID
443          source = "onboarded"
444      else:
445          source = "discovered"
446  
447      return ProjectContext(
448          project_id=effective_project,
449          managed_project_id=effective_project if tier == FREE_TIER_ID else "",
450          tier_id=tier,
451          source=source,
452      )