src/api/routes/llm_proxy.py
  1  """Claude-compatible LLM proxy endpoint."""
  2  from __future__ import annotations
  3  
  4  import json
  5  import logging
  6  import os
  7  import uuid
  8  from datetime import datetime, timezone
  9  from pathlib import Path
 10  from typing import Any, AsyncIterator
 11  
 12  import httpx
 13  from fastapi import APIRouter, Depends, HTTPException, Request, status
 14  from fastapi.responses import JSONResponse, StreamingResponse
 15  
 16  from ..deps import get_proxy_caller_id
 17  from ..llm_proxy.config import load_llm_proxy_config, ProxyConfigError
 18  from ..llm_proxy.translator import (
 19      claude_to_openai_messages,
 20      map_claude_tools,
 21      openai_to_claude_response,
 22      stream_openai_to_claude,
 23  )
 24  from ...config import load_sandboxed_envs
 25  
 26  logger = logging.getLogger(__name__)
 27  
 28  router = APIRouter(prefix="/llm-proxy/v1", tags=["llm-proxy"])
 29  session_router = APIRouter(prefix="/llm-proxy/s/{session_id}/v1", tags=["llm-proxy"])
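
# Note: both routers are assumed to be mounted on the main FastAPI app elsewhere
# (e.g. via app.include_router(router) and app.include_router(session_router));
# the mount point is not defined in this module.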

# Debug directory for saving request/response pairs
DEBUG_DIR = Path(__file__).resolve().parents[3] / "data" / "llm_proxy_debug"

# Maximum number of debug files to keep (oldest deleted first)
DEBUG_MAX_FILES = 200

# Fields to redact in debug output (values replaced with "***REDACTED***").
# Entries are lowercase because lookups compare key.lower() against this set.
_SENSITIVE_KEYS = frozenset({
    "api_key", "api-key", "x-api-key", "authorization",
    "token", "access_token", "secret", "password",
    "anthropic_api_key", "openai_api_key", "openrouter_api_key",
})


def _redact_sensitive(data: Any, *, _depth: int = 0) -> Any:
    """Recursively redact sensitive fields from data before writing to debug files."""
    if _depth > 20:
        return data
    if isinstance(data, dict):
        return {
            k: (
                "***REDACTED***"
                if isinstance(k, str) and k.lower() in _SENSITIVE_KEYS
                else _redact_sensitive(v, _depth=_depth + 1)
            )
            for k, v in data.items()
        }
    if isinstance(data, list):
        return [_redact_sensitive(item, _depth=_depth + 1) for item in data]
    return data
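
# Illustrative example of the redaction (not executed; values are made up):
#   _redact_sensitive({"headers": {"x-api-key": "sk-live-123"}, "model": "foo"})
#   -> {"headers": {"x-api-key": "***REDACTED***"}, "model": "foo"}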


def _is_debug_enabled() -> bool:
    """Check if debug mode is enabled in config."""
    try:
        config = load_llm_proxy_config()
        return config.proxy.debug
    except Exception:
        return False


def _get_debug_dir(session_id: str | None = None) -> Path:
    """Return the debug directory, optionally scoped to a session."""
    if session_id:
        return DEBUG_DIR / session_id
    return DEBUG_DIR


def _cleanup_debug_files(session_id: str | None = None) -> None:
    """Remove oldest debug files if directory exceeds DEBUG_MAX_FILES."""
    try:
        target_dir = _get_debug_dir(session_id)
        if not target_dir.exists():
            return
        files = sorted(target_dir.glob("*.json"), key=lambda f: f.stat().st_mtime)
        if len(files) > DEBUG_MAX_FILES:
            for f in files[: len(files) - DEBUG_MAX_FILES]:
                f.unlink(missing_ok=True)
    except Exception as e:
        logger.debug("LLM Proxy debug: cleanup error: %s", e)


def _save_debug_file(
    filename: str,
    data: dict[str, Any],
    session_id: str | None = None,
) -> None:
    """Save a debug JSON file with sensitive fields redacted.

    Caller is responsible for checking debug mode.
    When session_id is provided, files are saved under data/llm_proxy_debug/<session_id>/.
    """
    try:
        target_dir = _get_debug_dir(session_id)
        target_dir.mkdir(parents=True, exist_ok=True)
        redacted = _redact_sensitive(data)
        redacted["timestamp"] = datetime.now(timezone.utc).isoformat()
        filepath = target_dir / filename
        filepath.write_text(json.dumps(redacted, indent=2, default=str))
        logger.debug("LLM Proxy debug: saved %s", filepath)
        _cleanup_debug_files(session_id)
    except Exception as e:
        logger.warning("LLM Proxy debug: failed to save %s: %s", filename, e)


def _log_debug_warning() -> None:
    """Log a startup warning if debug mode is enabled."""
    if _is_debug_enabled():
        logger.warning(
            "LLM Proxy debug mode is ENABLED. Request/response payloads will be "
            "saved to %s. Sensitive fields are redacted, but disable debug mode "
            "in production (set proxy.debug: false in llm-api-proxy.yaml).",
            DEBUG_DIR,
        )


# Log warning at import time (module load = server startup)
_log_debug_warning()


def _resolve_target(model_name: str) -> tuple[str, str, dict[str, Any]]:
    """Resolve a requested model name to (provider, target_model, providers mapping)."""
    config = load_llm_proxy_config()
    mapping = config.models.get(model_name)
    if mapping is not None:
        return mapping.provider, mapping.target_model, config.providers

    if not config.routing.get("allow_unmapped_models", False):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown model mapping for '{model_name}'",
        )

    default_provider = config.routing.get("default_provider")
    if not default_provider:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No default provider configured for unmapped models",
        )
    return default_provider, model_name, config.providers
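
# Illustrative llm-api-proxy.yaml shape. The field names below are the ones this module
# reads; the overall layout and the example values are assumptions, and the authoritative
# schema lives in ..llm_proxy.config:
#
#   proxy:
#     debug: false
#   routing:
#     allow_unmapped_models: false
#     default_provider: openrouter
#   models:
#     claude-sonnet-4:
#       provider: openrouter
#       target_model: anthropic/claude-sonnet-4
#   providers:
#     openrouter:
#       type: openai-compatible
#       base_url: https://openrouter.ai/api/v1
#       api_key_env: OPENROUTER_API_KEY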


def _get_api_key(provider: str, providers: dict[str, Any]) -> str:
    """Return the API key for a provider, from the environment or sandboxed_envs."""
    provider_config = providers.get(provider)
    if not provider_config:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown provider '{provider}'",
        )
    api_key_env = provider_config.api_key_env

    # First check environment variable
    api_key = os.environ.get(api_key_env)

    # Fall back to sandboxed_envs from secrets.yaml
    if not api_key:
        sandboxed_envs = load_sandboxed_envs()
        api_key = sandboxed_envs.get(api_key_env)

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Missing API key for provider '{provider}' (env {api_key_env})",
        )
    return api_key


async def _proxy_anthropic(
    payload: dict[str, Any],
    provider_config: Any,
    api_key: str,
    stream: bool,
    session_id: str | None = None,
) -> JSONResponse | StreamingResponse:
    """Forward a Claude Messages request to an Anthropic-type provider unchanged."""
    headers = {
        "x-api-key": api_key,
        "anthropic-version": payload.get("anthropic_version", "2023-06-01"),
    }

    # Debug: save request payload before forwarding
    debug_enabled = _is_debug_enabled()
    request_uid = str(uuid.uuid4())[:8]
    if debug_enabled:
        _save_debug_file(f"in_{request_uid}.json", {
            "request_uid": request_uid,
            "provider": "anthropic",
            "target_model": payload.get("model", "unknown"),
            "stream": stream,
            "system_prompt_length": len(json.dumps(payload.get("system", ""))),
            "messages_count": len(payload.get("messages", [])),
            "payload": payload,
        }, session_id=session_id)

    if stream:
        client = httpx.AsyncClient(timeout=120)
        req = client.build_request(
            "POST",
            f"{provider_config.base_url}/v1/messages",
            headers=headers,
            json=payload,
        )
        try:
            response = await client.send(req, stream=True)
        except Exception:
            # Close the client if the upstream connection fails before streaming starts.
            await client.aclose()
            raise

        async def stream_and_cleanup() -> AsyncIterator[bytes]:
            try:
                async for chunk in response.aiter_bytes():
                    yield chunk
            finally:
                await response.aclose()
                await client.aclose()

        return StreamingResponse(
            stream_and_cleanup(),
            media_type="text/event-stream",
            status_code=response.status_code,
        )

    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(
            f"{provider_config.base_url}/v1/messages",
            headers=headers,
            json=payload,
        )

    if debug_enabled:
        try:
            _save_debug_file(f"out_{request_uid}.json", {
                "request_uid": request_uid,
                "provider": "anthropic",
                "is_stream": False,
                "response": response.json(),
            }, session_id=session_id)
        except Exception:
            pass

    return JSONResponse(status_code=response.status_code, content=response.json())
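
# Debug files are written in pairs that share a short random request_uid:
# in_<request_uid>.json holds the outgoing upstream request and out_<request_uid>.json
# holds the upstream response (or the translated stream chunks).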


async def _proxy_openai(
    payload: dict[str, Any],
    provider_config: Any,
    api_key: str,
    target_model: str,
    stream: bool = False,
    session_id: str | None = None,
) -> JSONResponse | StreamingResponse:
    """Translate a Claude Messages request to OpenAI chat-completions format, forward it,
    and translate the response (or stream) back into Claude format."""
    messages = claude_to_openai_messages(payload)
    tools = payload.get("tools") or []

    body: dict[str, Any] = {
        "model": target_model,
        "messages": messages,
        "stream": stream,
    }

    # Enable stream_options to get usage in final chunk (OpenAI API)
    if stream:
        body["stream_options"] = {"include_usage": True}

    if tools:
        openai_tools = map_claude_tools(tools)
        body["tools"] = openai_tools
        # Debug: log tool count and the first tool's schema structure
        logger.debug(
            "Proxy request: model=%s, tools=%d, stream=%s",
            target_model,
            len(openai_tools),
            stream,
        )
        if openai_tools:
            first_tool = openai_tools[0]
            tool_name = first_tool.get("function", {}).get("name", "?")
            params = first_tool.get("function", {}).get("parameters", {})
            required = params.get("required", [])
            logger.debug(
                "First tool: name=%s, required_params=%s",
                tool_name,
                required,
            )

    for field in ("temperature", "max_tokens", "top_p"):
        if field in payload:
            body[field] = payload[field]

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    # Check debug once to avoid re-loading config per call
    debug_enabled = _is_debug_enabled()
    request_uid = str(uuid.uuid4())[:8]

    if debug_enabled:
        _save_debug_file(f"in_{request_uid}.json", {
            "request_uid": request_uid,
            "target_model": target_model,
            "payload": body,
        }, session_id=session_id)

    if stream:
        async def stream_response() -> AsyncIterator[str]:
            translated_chunks: list[str] = []

            async with httpx.AsyncClient(timeout=120) as client:
                async with client.stream(
                    "POST",
                    f"{provider_config.base_url}/chat/completions",
                    headers=headers,
                    json=body,
                ) as response:
                    response.raise_for_status()
                    async for chunk in stream_openai_to_claude(response, target_model):
                        if debug_enabled:
                            translated_chunks.append(chunk)
                        yield chunk

            if debug_enabled:
                _save_debug_file(f"out_{request_uid}.json", {
                    "request_uid": request_uid,
                    "is_stream": True,
                    "translated_chunks": translated_chunks,
                }, session_id=session_id)

        return StreamingResponse(
            stream_response(),
            media_type="text/event-stream",
        )

    # Non-streaming mode
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(
            f"{provider_config.base_url}/chat/completions",
            headers=headers,
            json=body,
        )
        response.raise_for_status()
        raw_response = response.json()
        translated = openai_to_claude_response(raw_response, target_model)

        if debug_enabled:
            _save_debug_file(f"out_{request_uid}.json", {
                "request_uid": request_uid,
                "is_stream": False,
                "raw_response": raw_response,
                "translated_response": translated,
            }, session_id=session_id)

        return JSONResponse(status_code=response.status_code, content=translated)
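
# Rough shape of the translated OpenAI request body built above (illustrative; the exact
# message and tool conversion is handled by claude_to_openai_messages and map_claude_tools
# in ..llm_proxy.translator):
#
#   {
#     "model": "<target_model>",
#     "messages": [...],                              # Claude system/messages in chat format
#     "tools": [...],                                 # only when the Claude request includes tools
#     "stream": true,
#     "stream_options": {"include_usage": true},      # streaming requests only
#     "temperature": ..., "max_tokens": ..., "top_p": ...   # copied through when present
#   }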

async def _handle_proxy_messages(
    request: Request,
    session_id: str | None = None,
) -> JSONResponse | StreamingResponse:
    """Core proxy logic shared by both session-scoped and non-session routes."""
    try:
        payload = await request.json()
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid JSON payload",
        ) from exc

    model_name = payload.get("model")
    if not model_name:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing model in request payload",
        )

    logger.info("LLM Proxy: received request for model=%s (session=%s)", model_name, session_id)

    try:
        provider_name, target_model, providers = _resolve_target(model_name)
        logger.info(
            "LLM Proxy: resolved model=%s -> provider=%s, target_model=%s",
            model_name, provider_name, target_model,
        )
    except ProxyConfigError as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        ) from exc

    provider_config = providers.get(provider_name)
    if not provider_config:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown provider '{provider_name}'",
        )

    api_key = _get_api_key(provider_name, providers)
    stream = bool(payload.get("stream"))

    if provider_config.type == "anthropic":
        return await _proxy_anthropic(payload, provider_config, api_key, stream, session_id=session_id)
    if provider_config.type in {"openai", "openai-compatible"}:
        return await _proxy_openai(payload, provider_config, api_key, target_model, stream, session_id=session_id)

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"Unsupported provider type '{provider_config.type}'",
    )


# --- Non-session routes (backwards compatibility) ---

@router.post("/messages/count_tokens", response_model=None)
async def count_tokens(
    request: Request,
    caller_id: str = Depends(get_proxy_caller_id),
) -> JSONResponse:
    """No-op handler for the SDK's count_tokens call."""
    return JSONResponse(content={"input_tokens": 0})


@router.post("/messages", response_model=None)
async def proxy_messages(
    request: Request,
    caller_id: str = Depends(get_proxy_caller_id),
) -> JSONResponse | StreamingResponse:
    """Proxy an Anthropic Messages request; the caller_id dependency identifies the
    caller and is not used directly in the handler."""
    return await _handle_proxy_messages(request)


# --- Session-scoped routes (debug files organized by session) ---

@session_router.post("/messages/count_tokens", response_model=None)
async def count_tokens_session(
    request: Request,
    session_id: str,
    caller_id: str = Depends(get_proxy_caller_id),
) -> JSONResponse:
    """No-op handler for the SDK's count_tokens call (session-scoped)."""
    return JSONResponse(content={"input_tokens": 0})


@session_router.post("/messages", response_model=None)
async def proxy_messages_session(
    request: Request,
    session_id: str,
    caller_id: str = Depends(get_proxy_caller_id),
) -> JSONResponse | StreamingResponse:
    """Session-scoped variant of proxy_messages; debug output is grouped per session."""
    return await _handle_proxy_messages(request, session_id=session_id)
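

# --- Example usage (illustrative) ---
#
# With the routers mounted, a Claude-style client can be pointed at the proxy and send a
# standard Anthropic Messages payload. Host/port, the auth expected by get_proxy_caller_id,
# and the model name are placeholders here, not values defined in this module:
#
#   curl -sS -X POST http://localhost:8000/llm-proxy/v1/messages \
#     -H "Content-Type: application/json" \
#     -d '{"model": "claude-sonnet-4", "max_tokens": 256,
#          "messages": [{"role": "user", "content": "Hello"}]}'
#
# The session-scoped variant behaves identically but groups debug output per session:
#
#   POST /llm-proxy/s/<session_id>/v1/messages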