/ agent / transports / bedrock.py
bedrock.py
  1  """AWS Bedrock Converse API transport.
  2  
  3  Delegates to the existing adapter functions in agent/bedrock_adapter.py.
  4  Bedrock uses its own boto3 client (not the OpenAI SDK), so the transport
  5  owns format conversion and normalization, while client construction and
  6  boto3 calls stay on AIAgent.
  7  """
  8  
  9  from typing import Any, Dict, List, Optional
 10  
 11  from agent.transports.base import ProviderTransport
 12  from agent.transports.types import NormalizedResponse, ToolCall, Usage
 13  
 14  
 15  class BedrockTransport(ProviderTransport):
 16      """Transport for api_mode='bedrock_converse'."""
 17  
 18      @property
 19      def api_mode(self) -> str:
 20          return "bedrock_converse"
 21  
 22      def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
 23          """Convert OpenAI messages to Bedrock Converse format."""
 24          from agent.bedrock_adapter import convert_messages_to_converse
 25          return convert_messages_to_converse(messages)
 26  
 27      def convert_tools(self, tools: List[Dict[str, Any]]) -> Any:
 28          """Convert OpenAI tool schemas to Bedrock Converse toolConfig."""
 29          from agent.bedrock_adapter import convert_tools_to_converse
 30          return convert_tools_to_converse(tools)
 31  
 32      def build_kwargs(
 33          self,
 34          model: str,
 35          messages: List[Dict[str, Any]],
 36          tools: Optional[List[Dict[str, Any]]] = None,
 37          **params,
 38      ) -> Dict[str, Any]:
 39          """Build Bedrock converse() kwargs.
 40  
 41          Calls convert_messages and convert_tools internally.
 42  
 43          params:
 44              max_tokens: int — output token limit (default 4096)
 45              temperature: float | None
 46              guardrail_config: dict | None — Bedrock guardrails
 47              region: str — AWS region (default 'us-east-1')
 48          """
 49          from agent.bedrock_adapter import build_converse_kwargs
 50  
 51          region = params.get("region", "us-east-1")
 52          guardrail = params.get("guardrail_config")
 53  
 54          kwargs = build_converse_kwargs(
 55              model=model,
 56              messages=messages,
 57              tools=tools,
 58              max_tokens=params.get("max_tokens", 4096),
 59              temperature=params.get("temperature"),
 60              guardrail_config=guardrail,
 61          )
 62          # Sentinel keys for dispatch — agent pops these before the boto3 call
 63          kwargs["__bedrock_converse__"] = True
 64          kwargs["__bedrock_region__"] = region
 65          return kwargs
 66  
 67      def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse:
 68          """Normalize Bedrock response to NormalizedResponse.
 69  
 70          Handles two shapes:
 71          1. Raw boto3 dict (from direct converse() calls)
 72          2. Already-normalized SimpleNamespace with .choices (from dispatch site)
 73          """
 74          from agent.bedrock_adapter import normalize_converse_response
 75  
 76          # Normalize to OpenAI-compatible SimpleNamespace
 77          if hasattr(response, "choices") and response.choices:
 78              # Already normalized at dispatch site
 79              ns = response
 80          else:
 81              # Raw boto3 dict
 82              ns = normalize_converse_response(response)
 83  
 84          choice = ns.choices[0]
 85          msg = choice.message
 86          finish_reason = choice.finish_reason or "stop"
 87  
 88          tool_calls = None
 89          if msg.tool_calls:
 90              tool_calls = [
 91                  ToolCall(
 92                      id=tc.id,
 93                      name=tc.function.name,
 94                      arguments=tc.function.arguments,
 95                  )
 96                  for tc in msg.tool_calls
 97              ]
 98  
 99          usage = None
100          if hasattr(ns, "usage") and ns.usage:
101              u = ns.usage
102              usage = Usage(
103                  prompt_tokens=getattr(u, "prompt_tokens", 0) or 0,
104                  completion_tokens=getattr(u, "completion_tokens", 0) or 0,
105                  total_tokens=getattr(u, "total_tokens", 0) or 0,
106              )
107  
108          reasoning = getattr(msg, "reasoning", None) or getattr(msg, "reasoning_content", None)
109  
110          return NormalizedResponse(
111              content=msg.content,
112              tool_calls=tool_calls,
113              finish_reason=finish_reason,
114              reasoning=reasoning,
115              usage=usage,
116          )
117  
118      def validate_response(self, response: Any) -> bool:
119          """Check Bedrock response structure.
120  
121          After normalize_converse_response, the response has OpenAI-compatible
122          .choices — same check as chat_completions.
123          """
124          if response is None:
125              return False
126          # Raw Bedrock dict response — check for 'output' key
127          if isinstance(response, dict):
128              return "output" in response
129          # Already-normalized SimpleNamespace
130          if hasattr(response, "choices"):
131              return bool(response.choices)
132          return False
133  
134      def map_finish_reason(self, raw_reason: str) -> str:
135          """Map Bedrock stop reason to OpenAI finish_reason.
136  
137          The adapter already does this mapping inside normalize_converse_response,
138          so this is only used for direct access to raw responses.
139          """
140          _MAP = {
141              "end_turn": "stop",
142              "tool_use": "tool_calls",
143              "max_tokens": "length",
144              "stop_sequence": "stop",
145              "guardrail_intervened": "content_filter",
146              "content_filtered": "content_filter",
147          }
148          return _MAP.get(raw_reason, "stop")
149  
150  
# Auto-register on import: importing this module passes BedrockTransport to
# register_transport under the "bedrock_converse" key, which must stay in
# sync with BedrockTransport.api_mode.
# The import sits below the class definition (hence ``noqa: E402``);
# presumably this ordering avoids a circular import between this module and
# the agent.transports package — TODO confirm.
from agent.transports import register_transport  # noqa: E402

register_transport("bedrock_converse", BedrockTransport)