# src/input_sources/s3_source.py
  1  from __future__ import annotations
  2  
  3  """S3 input source implementation."""
  4  
  5  import logging
  6  import tempfile
  7  from pathlib import Path
  8  from typing import Any
  9  
 10  import boto3
 11  
 12  from .protocol import InputSource
 13  
 14  
 15  class S3InputSource:
 16      """Input source for AWS S3 buckets.
 17  
 18      Downloads files from S3 to temporary locations for processing.
 19      Handles cleanup of temporary files.
 20      """
 21  
 22      def __init__(self, config: dict[str, Any]):
 23          """Initialize the S3 input source.
 24  
 25          Parameters
 26          ----------
 27          config
 28              Configuration dictionary with keys:
 29              - bucket_name: str (required) - S3 bucket name
 30              - prefix: str (optional) - S3 key prefix to filter files
 31              - aws_access_key_id: str (optional) - AWS access key
 32              - aws_secret_access_key: str (optional) - AWS secret key
 33              - aws_session_token: str (optional) - AWS session token (for temporary credentials)
 34              - region_name: str (optional) - AWS region
 35              - endpoint_url: str (optional) - Custom S3 endpoint (for S3-compatible services)
 36          """
 37          self.config = config
 38          self.logger = logging.getLogger(__name__)
 39          self.temp_files: list[Path] = []
 40  
 41          # Validate required config
 42          if "bucket_name" not in config:
 43              raise ValueError("bucket_name is required for S3 input source")
 44  
 45          self.bucket_name = config["bucket_name"]
 46          self.prefix = config.get("prefix", "")
 47  
 48          self._s3_client = None
 49  
 50      @property
 51      def s3_client(self):
 52          """Get or create S3 client (lazy initialization)."""
 53          if self._s3_client is None:
 54              boto3_config = {
 55                  "aws_access_key_id": self.config.get("aws_access_key_id"),
 56                  "aws_secret_access_key": self.config.get("aws_secret_access_key"),
 57                  "aws_session_token": self.config.get("aws_session_token"),
 58                  "region_name": self.config.get("region_name"),
 59                  "endpoint_url": self.config.get("endpoint_url"),
 60              }
 61  
 62              self._s3_client = boto3.client("s3", **boto3_config)
 63              self.logger.info(f"Initialized S3 client for bucket: {self.bucket_name}")
 64  
 65          return self._s3_client
 66  
 67      def list_files(self, path: str = "", extensions: list[str] | None = None) -> list[str]:
 68          """List files in S3 bucket with optional prefix and extension filtering.
 69  
 70          Parameters
 71          ----------
 72          path
 73              S3 prefix to list files from (relative to configured prefix).
 74              If empty string or not provided, uses the prefix from config.
 75              If provided, it will be appended to the configured prefix.
 76          extensions
 77              Optional list of file extensions to filter by.
 78  
 79          Returns
 80          -------
 81          List of S3 URIs in format: s3://bucket/key
 82          """
 83          # Combine configured prefix with path
 84          full_prefix = self.prefix
 85          if path and path not in {".", ""}:
 86              full_prefix = f"{self.prefix.rstrip('/')}/{path.lstrip('/')}"
 87  
 88          self.logger.info(f"Listing S3 objects: s3://{self.bucket_name}/{full_prefix}")
 89  
 90          files = []
 91          paginator = self.s3_client.get_paginator("list_objects_v2")
 92  
 93          try:
 94              for page in paginator.paginate(Bucket=self.bucket_name, Prefix=full_prefix):
 95                  if "Contents" not in page:
 96                      continue
 97  
 98                  for obj in page["Contents"]:
 99                      key = obj["Key"]
100  
101                      # Skip directories (keys ending with /)
102                      if key.endswith("/"):
103                          continue
104  
105                      # Filter by extensions if provided
106                      if extensions is not None:
107                          file_ext = Path(key).suffix.lower()
108                          if file_ext not in extensions:
109                              continue
110  
111                      # Return as S3 URI
112                      s3_uri = f"s3://{self.bucket_name}/{key}"
113                      files.append(s3_uri)
114  
115          except Exception as e:
116              self.logger.info(f"Failed to list S3 objects: {e}")
117              raise RuntimeError(f"Failed to list S3 objects: {e}") from e
118  
119          self.logger.info(f"Found {len(files)} file(s) in S3")
120          return sorted(files)
121  
122      def get_file(self, file_id: str) -> Path:
123          """Download S3 file to temporary location.
124  
125          Parameters
126          ----------
127          file_id
128              S3 URI in format: s3://bucket/key
129  
130          Returns
131          -------
132          Path to the downloaded temporary file.
133          """
134          # Parse S3 URI
135          if not file_id.startswith("s3://"):
136              raise ValueError(f"Invalid S3 URI: {file_id}")
137  
138          parts = file_id[5:].split("/", 1)
139          if len(parts) != 2:
140              raise ValueError(f"Invalid S3 URI format: {file_id}")
141  
142          bucket, key = parts
143  
144          # Verify bucket matches
145          if bucket != self.bucket_name:
146              raise ValueError(
147                  f"Bucket mismatch: expected {self.bucket_name}, got {bucket}"
148              )
149  
150          # Create temporary file with same extension
151          suffix = Path(key).suffix
152          with tempfile.NamedTemporaryFile(
153              delete=False, suffix=suffix, prefix="s3_"
154          ) as temp_file:
155              temp_path = Path(temp_file.name)
156  
157          self.logger.info(f"Downloading {file_id} to {temp_path}")
158  
159          try:
160              self.s3_client.download_file(bucket, key, str(temp_path))
161              self.temp_files.append(temp_path)
162              return temp_path
163          except Exception as e:
164              # Clean up temp file on error
165              if temp_path.exists():
166                  temp_path.unlink()
167              self.logger.info(f"Failed to download {file_id}: {e}")
168              raise RuntimeError(f"Failed to download {file_id}: {e}") from e
169  
170      def cleanup(self) -> None:
171          """Clean up all temporary downloaded files."""
172          self.logger.info(f"Cleaning up {len(self.temp_files)} temporary file(s)")
173  
174          for temp_file in self.temp_files:
175              try:
176                  if temp_file.exists():
177                      temp_file.unlink()
178                      self.logger.debug(f"Deleted temporary file: {temp_file}")
179              except Exception as e:
180                  self.logger.warning(f"Failed to delete {temp_file}: {e}")
181  
182          self.temp_files.clear()
183  
184  
def create_s3_source(config: dict[str, Any]) -> InputSource:
    """Factory for an S3-backed input source.

    Parameters
    ----------
    config
        Configuration dictionary with S3 settings; must contain
        ``bucket_name``.

    Returns
    -------
    InputSource instance for S3.

    Raises
    ------
    ValueError
        If required configuration is missing.
    """
    source: InputSource = S3InputSource(config)
    return source