llm_proxy.py
1 """Claude-compatible LLM proxy endpoint.""" 2 from __future__ import annotations 3 4 import json 5 import logging 6 import os 7 import uuid 8 from datetime import datetime, timezone 9 from pathlib import Path 10 from typing import Any, AsyncIterator 11 12 import httpx 13 from fastapi import APIRouter, Depends, HTTPException, Request, status 14 from fastapi.responses import JSONResponse, StreamingResponse 15 16 from ..deps import get_proxy_caller_id 17 from ..llm_proxy.config import load_llm_proxy_config, ProxyConfigError 18 from ..llm_proxy.translator import ( 19 claude_to_openai_messages, 20 map_claude_tools, 21 openai_to_claude_response, 22 stream_openai_to_claude, 23 ) 24 from ...config import load_sandboxed_envs 25 26 logger = logging.getLogger(__name__) 27 28 router = APIRouter(prefix="/llm-proxy/v1", tags=["llm-proxy"]) 29 session_router = APIRouter(prefix="/llm-proxy/s/{session_id}/v1", tags=["llm-proxy"]) 30 31 # Debug directory for saving request/response pairs 32 DEBUG_DIR = Path(__file__).resolve().parents[3] / "data" / "llm_proxy_debug" 33 34 # Maximum number of debug files to keep (oldest deleted first) 35 DEBUG_MAX_FILES = 200 36 37 # Fields to redact in debug output (values replaced with "***REDACTED***") 38 _SENSITIVE_KEYS = frozenset({ 39 "api_key", "api-key", "x-api-key", "authorization", 40 "token", "access_token", "secret", "password", 41 "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY", 42 }) 43 44 45 def _redact_sensitive(data: Any, *, _depth: int = 0) -> Any: 46 """Recursively redact sensitive fields from data before writing to debug files.""" 47 if _depth > 20: 48 return data 49 if isinstance(data, dict): 50 return { 51 k: ("***REDACTED***" if k.lower() in _SENSITIVE_KEYS else _redact_sensitive(v, _depth=_depth + 1)) 52 for k, v in data.items() 53 } 54 if isinstance(data, list): 55 return [_redact_sensitive(item, _depth=_depth + 1) for item in data] 56 return data 57 58 59 def _is_debug_enabled() -> bool: 60 """Check if debug mode is enabled in config.""" 61 try: 62 config = load_llm_proxy_config() 63 return config.proxy.debug 64 except Exception: 65 return False 66 67 68 def _get_debug_dir(session_id: str | None = None) -> Path: 69 """Return the debug directory, optionally scoped to a session.""" 70 if session_id: 71 return DEBUG_DIR / session_id 72 return DEBUG_DIR 73 74 75 def _cleanup_debug_files(session_id: str | None = None) -> None: 76 """Remove oldest debug files if directory exceeds DEBUG_MAX_FILES.""" 77 try: 78 target_dir = _get_debug_dir(session_id) 79 if not target_dir.exists(): 80 return 81 files = sorted(target_dir.glob("*.json"), key=lambda f: f.stat().st_mtime) 82 if len(files) > DEBUG_MAX_FILES: 83 for f in files[: len(files) - DEBUG_MAX_FILES]: 84 f.unlink(missing_ok=True) 85 except Exception as e: 86 logger.debug("LLM Proxy debug: cleanup error: %s", e) 87 88 89 def _save_debug_file( 90 filename: str, 91 data: dict[str, Any], 92 session_id: str | None = None, 93 ) -> None: 94 """Save a debug JSON file with sensitive fields redacted. 95 96 Caller is responsible for checking debug mode. 97 When session_id is provided, files are saved under data/llm_proxy_debug/<session_id>/. 
98 """ 99 try: 100 target_dir = _get_debug_dir(session_id) 101 target_dir.mkdir(parents=True, exist_ok=True) 102 redacted = _redact_sensitive(data) 103 redacted["timestamp"] = datetime.now(timezone.utc).isoformat() 104 filepath = target_dir / filename 105 filepath.write_text(json.dumps(redacted, indent=2, default=str)) 106 logger.debug("LLM Proxy debug: saved %s", filepath) 107 _cleanup_debug_files(session_id) 108 except Exception as e: 109 logger.warning("LLM Proxy debug: failed to save %s: %s", filename, e) 110 111 112 def _log_debug_warning() -> None: 113 """Log a startup warning if debug mode is enabled.""" 114 if _is_debug_enabled(): 115 logger.warning( 116 "LLM Proxy debug mode is ENABLED. Request/response payloads will be " 117 "saved to %s. Sensitive fields are redacted, but disable debug mode " 118 "in production (set proxy.debug: false in llm-api-proxy.yaml).", 119 DEBUG_DIR, 120 ) 121 122 123 # Log warning at import time (module load = server startup) 124 _log_debug_warning() 125 126 127 def _resolve_target(model_name: str) -> tuple[str, str, dict[str, Any]]: 128 config = load_llm_proxy_config() 129 mapping = config.models.get(model_name) 130 if mapping is not None: 131 return mapping.provider, mapping.target_model, config.providers 132 133 if not config.routing.get("allow_unmapped_models", False): 134 raise HTTPException( 135 status_code=status.HTTP_400_BAD_REQUEST, 136 detail=f"Unknown model mapping for '{model_name}'", 137 ) 138 139 default_provider = config.routing.get("default_provider") 140 if not default_provider: 141 raise HTTPException( 142 status_code=status.HTTP_400_BAD_REQUEST, 143 detail="No default provider configured for unmapped models", 144 ) 145 return default_provider, model_name, config.providers 146 147 148 def _get_api_key(provider: str, providers: dict[str, Any]) -> str: 149 provider_config = providers.get(provider) 150 if not provider_config: 151 raise HTTPException( 152 status_code=status.HTTP_400_BAD_REQUEST, 153 detail=f"Unknown provider '{provider}'", 154 ) 155 api_key_env = provider_config.api_key_env 156 157 # First check environment variable 158 api_key = os.environ.get(api_key_env) 159 160 # Fall back to sandboxed_envs from secrets.yaml 161 if not api_key: 162 sandboxed_envs = load_sandboxed_envs() 163 api_key = sandboxed_envs.get(api_key_env) 164 165 if not api_key: 166 raise HTTPException( 167 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 168 detail=f"Missing API key for provider '{provider}' (env {api_key_env})", 169 ) 170 return api_key 171 172 173 async def _proxy_anthropic( 174 payload: dict[str, Any], 175 provider_config: Any, 176 api_key: str, 177 stream: bool, 178 session_id: str | None = None, 179 ) -> JSONResponse | StreamingResponse: 180 headers = { 181 "x-api-key": api_key, 182 "anthropic-version": payload.get("anthropic_version", "2023-06-01"), 183 } 184 185 # Debug: save request payload before forwarding 186 debug_enabled = _is_debug_enabled() 187 request_uid = str(uuid.uuid4())[:8] 188 if debug_enabled: 189 _save_debug_file(f"in_{request_uid}.json", { 190 "request_uid": request_uid, 191 "provider": "anthropic", 192 "target_model": payload.get("model", "unknown"), 193 "stream": stream, 194 "system_prompt_length": len(json.dumps(payload.get("system", ""))), 195 "messages_count": len(payload.get("messages", [])), 196 "payload": payload, 197 }, session_id=session_id) 198 199 if stream: 200 client = httpx.AsyncClient(timeout=120) 201 req = client.build_request( 202 "POST", 203 f"{provider_config.base_url}/v1/messages", 204 


async def _proxy_anthropic(
    payload: dict[str, Any],
    provider_config: Any,
    api_key: str,
    stream: bool,
    session_id: str | None = None,
) -> JSONResponse | StreamingResponse:
    headers = {
        "x-api-key": api_key,
        "anthropic-version": payload.get("anthropic_version", "2023-06-01"),
    }

    # Debug: save request payload before forwarding
    debug_enabled = _is_debug_enabled()
    request_uid = str(uuid.uuid4())[:8]
    if debug_enabled:
        _save_debug_file(f"in_{request_uid}.json", {
            "request_uid": request_uid,
            "provider": "anthropic",
            "target_model": payload.get("model", "unknown"),
            "stream": stream,
            "system_prompt_length": len(json.dumps(payload.get("system", ""))),
            "messages_count": len(payload.get("messages", [])),
            "payload": payload,
        }, session_id=session_id)

    if stream:
        client = httpx.AsyncClient(timeout=120)
        req = client.build_request(
            "POST",
            f"{provider_config.base_url}/v1/messages",
            headers=headers,
            json=payload,
        )
        response = await client.send(req, stream=True)

        async def stream_and_cleanup() -> AsyncIterator[bytes]:
            try:
                async for chunk in response.aiter_bytes():
                    yield chunk
            finally:
                await response.aclose()
                await client.aclose()

        return StreamingResponse(
            stream_and_cleanup(),
            media_type="text/event-stream",
            status_code=response.status_code,
        )

    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(
            f"{provider_config.base_url}/v1/messages",
            headers=headers,
            json=payload,
        )

    if debug_enabled:
        try:
            _save_debug_file(f"out_{request_uid}.json", {
                "request_uid": request_uid,
                "provider": "anthropic",
                "is_stream": False,
                "response": response.json(),
            }, session_id=session_id)
        except Exception:
            pass

    return JSONResponse(status_code=response.status_code, content=response.json())


async def _proxy_openai(
    payload: dict[str, Any],
    provider_config: Any,
    api_key: str,
    target_model: str,
    stream: bool = False,
    session_id: str | None = None,
) -> JSONResponse | StreamingResponse:
    messages = claude_to_openai_messages(payload)
    tools = payload.get("tools") or []

    body: dict[str, Any] = {
        "model": target_model,
        "messages": messages,
        "stream": stream,
    }

    # Enable stream_options to get usage in final chunk (OpenAI API)
    if stream:
        body["stream_options"] = {"include_usage": True}

    if tools:
        openai_tools = map_claude_tools(tools)
        body["tools"] = openai_tools
        # Debug: Log tool count and first tool's schema structure
        logger.debug(
            "Proxy request: model=%s, tools=%d, stream=%s",
            target_model,
            len(openai_tools),
            stream,
        )
        if openai_tools:
            first_tool = openai_tools[0]
            tool_name = first_tool.get("function", {}).get("name", "?")
            params = first_tool.get("function", {}).get("parameters", {})
            required = params.get("required", [])
            logger.debug(
                "First tool: name=%s, required_params=%s",
                tool_name,
                required,
            )

    for field in ("temperature", "max_tokens", "top_p"):
        if field in payload:
            body[field] = payload[field]

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    # Check debug once to avoid re-loading config per call
    debug_enabled = _is_debug_enabled()
    request_uid = str(uuid.uuid4())[:8]

    if debug_enabled:
        _save_debug_file(f"in_{request_uid}.json", {
            "request_uid": request_uid,
            "target_model": target_model,
            "payload": body,
        }, session_id=session_id)

    if stream:
        async def stream_response() -> AsyncIterator[str]:
            translated_chunks: list[str] = []

            async with httpx.AsyncClient(timeout=120) as client:
                async with client.stream(
                    "POST",
                    f"{provider_config.base_url}/chat/completions",
                    headers=headers,
                    json=body,
                ) as response:
                    response.raise_for_status()
                    async for chunk in stream_openai_to_claude(response, target_model):
                        if debug_enabled:
                            translated_chunks.append(chunk)
                        yield chunk

            if debug_enabled:
                _save_debug_file(f"out_{request_uid}.json", {
                    "request_uid": request_uid,
                    "is_stream": True,
                    "translated_chunks": translated_chunks,
                }, session_id=session_id)

        return StreamingResponse(
            stream_response(),
            media_type="text/event-stream",
        )

    # Non-streaming mode
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(
            f"{provider_config.base_url}/chat/completions",
            headers=headers,
            json=body,
        )
        response.raise_for_status()
        raw_response = response.json()
        translated = openai_to_claude_response(raw_response, target_model)

    if debug_enabled:
        _save_debug_file(f"out_{request_uid}.json", {
            "request_uid": request_uid,
            "is_stream": False,
            "raw_response": raw_response,
            "translated_response": translated,
        }, session_id=session_id)

    return JSONResponse(status_code=response.status_code, content=translated)
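
# For reference, the OpenAI-style request body assembled above looks roughly like the
# following (a sketch; the exact message and tool shapes come from
# claude_to_openai_messages and map_claude_tools, and the sampling fields are only
# forwarded when present in the Claude payload):
#
#   {
#     "model": "<target_model>",
#     "messages": [...],
#     "stream": true,
#     "stream_options": {"include_usage": true},
#     "tools": [{"type": "function", "function": {"name": "...", "parameters": {...}}}],
#     "temperature": ..., "max_tokens": ..., "top_p": ...
#   }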


async def _handle_proxy_messages(
    request: Request,
    session_id: str | None = None,
) -> JSONResponse | StreamingResponse:
    """Core proxy logic shared by both session-scoped and non-session routes."""
    try:
        payload = await request.json()
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid JSON payload",
        ) from exc

    model_name = payload.get("model")
    if not model_name:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing model in request payload",
        )

    logger.info("LLM Proxy: received request for model=%s (session=%s)", model_name, session_id)

    try:
        provider_name, target_model, providers = _resolve_target(model_name)
        logger.info(
            "LLM Proxy: resolved model=%s -> provider=%s, target_model=%s",
            model_name, provider_name, target_model,
        )
    except ProxyConfigError as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        ) from exc

    provider_config = providers.get(provider_name)
    if not provider_config:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown provider '{provider_name}'",
        )

    api_key = _get_api_key(provider_name, providers)
    stream = bool(payload.get("stream"))

    if provider_config.type == "anthropic":
        return await _proxy_anthropic(payload, provider_config, api_key, stream, session_id=session_id)
    if provider_config.type in {"openai", "openai-compatible"}:
        return await _proxy_openai(payload, provider_config, api_key, target_model, stream, session_id=session_id)

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"Unsupported provider type '{provider_config.type}'",
    )


# --- Non-session routes (backwards compatibility) ---

@router.post("/messages/count_tokens", response_model=None)
async def count_tokens(
    request: Request,
    caller_id: str = Depends(get_proxy_caller_id),
) -> JSONResponse:
    """No-op handler for the SDK's count_tokens call."""
    return JSONResponse(content={"input_tokens": 0})


@router.post("/messages", response_model=None)
async def proxy_messages(
    request: Request,
    user_id: str = Depends(get_proxy_caller_id),
) -> JSONResponse | StreamingResponse:
    return await _handle_proxy_messages(request)


# --- Session-scoped routes (debug files organized by session) ---

@session_router.post("/messages/count_tokens", response_model=None)
async def count_tokens_session(
    request: Request,
    session_id: str,
    caller_id: str = Depends(get_proxy_caller_id),
) -> JSONResponse:
    """No-op handler for the SDK's count_tokens call (session-scoped)."""
    return JSONResponse(content={"input_tokens": 0})


@session_router.post("/messages", response_model=None)
async def proxy_messages_session(
    request: Request,
    session_id: str,
    user_id: str = Depends(get_proxy_caller_id),
) -> JSONResponse | StreamingResponse:
    return await _handle_proxy_messages(request, session_id=session_id)
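
# Example client call (a sketch): the proxy accepts Anthropic Messages API payloads,
# so any Claude-compatible client can be pointed at it. Host, port, and the model
# alias below are illustrative; authentication is whatever get_proxy_caller_id enforces.
#
#   resp = httpx.post(
#       "http://localhost:8000/llm-proxy/v1/messages",
#       json={
#           "model": "claude-sonnet-4",
#           "max_tokens": 256,
#           "messages": [{"role": "user", "content": "hello"}],
#       },
#   )
#
# Session-scoped calls use /llm-proxy/s/<session_id>/v1/messages instead, which groups
# any debug files by session.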