# s3_source.py
1 from __future__ import annotations 2 3 """S3 input source implementation.""" 4 5 import logging 6 import tempfile 7 from pathlib import Path 8 from typing import Any 9 10 import boto3 11 12 from .protocol import InputSource 13 14 15 class S3InputSource: 16 """Input source for AWS S3 buckets. 17 18 Downloads files from S3 to temporary locations for processing. 19 Handles cleanup of temporary files. 20 """ 21 22 def __init__(self, config: dict[str, Any]): 23 """Initialize the S3 input source. 24 25 Parameters 26 ---------- 27 config 28 Configuration dictionary with keys: 29 - bucket_name: str (required) - S3 bucket name 30 - prefix: str (optional) - S3 key prefix to filter files 31 - aws_access_key_id: str (optional) - AWS access key 32 - aws_secret_access_key: str (optional) - AWS secret key 33 - aws_session_token: str (optional) - AWS session token (for temporary credentials) 34 - region_name: str (optional) - AWS region 35 - endpoint_url: str (optional) - Custom S3 endpoint (for S3-compatible services) 36 """ 37 self.config = config 38 self.logger = logging.getLogger(__name__) 39 self.temp_files: list[Path] = [] 40 41 # Validate required config 42 if "bucket_name" not in config: 43 raise ValueError("bucket_name is required for S3 input source") 44 45 self.bucket_name = config["bucket_name"] 46 self.prefix = config.get("prefix", "") 47 48 self._s3_client = None 49 50 @property 51 def s3_client(self): 52 """Get or create S3 client (lazy initialization).""" 53 if self._s3_client is None: 54 boto3_config = { 55 "aws_access_key_id": self.config.get("aws_access_key_id"), 56 "aws_secret_access_key": self.config.get("aws_secret_access_key"), 57 "aws_session_token": self.config.get("aws_session_token"), 58 "region_name": self.config.get("region_name"), 59 "endpoint_url": self.config.get("endpoint_url"), 60 } 61 62 self._s3_client = boto3.client("s3", **boto3_config) 63 self.logger.info(f"Initialized S3 client for bucket: {self.bucket_name}") 64 65 return self._s3_client 
66 67 def list_files(self, path: str = "", extensions: list[str] | None = None) -> list[str]: 68 """List files in S3 bucket with optional prefix and extension filtering. 69 70 Parameters 71 ---------- 72 path 73 S3 prefix to list files from (relative to configured prefix). 74 If empty string or not provided, uses the prefix from config. 75 If provided, it will be appended to the configured prefix. 76 extensions 77 Optional list of file extensions to filter by. 78 79 Returns 80 ------- 81 List of S3 URIs in format: s3://bucket/key 82 """ 83 # Combine configured prefix with path 84 full_prefix = self.prefix 85 if path and path not in {".", ""}: 86 full_prefix = f"{self.prefix.rstrip('/')}/{path.lstrip('/')}" 87 88 self.logger.info(f"Listing S3 objects: s3://{self.bucket_name}/{full_prefix}") 89 90 files = [] 91 paginator = self.s3_client.get_paginator("list_objects_v2") 92 93 try: 94 for page in paginator.paginate(Bucket=self.bucket_name, Prefix=full_prefix): 95 if "Contents" not in page: 96 continue 97 98 for obj in page["Contents"]: 99 key = obj["Key"] 100 101 # Skip directories (keys ending with /) 102 if key.endswith("/"): 103 continue 104 105 # Filter by extensions if provided 106 if extensions is not None: 107 file_ext = Path(key).suffix.lower() 108 if file_ext not in extensions: 109 continue 110 111 # Return as S3 URI 112 s3_uri = f"s3://{self.bucket_name}/{key}" 113 files.append(s3_uri) 114 115 except Exception as e: 116 self.logger.info(f"Failed to list S3 objects: {e}") 117 raise RuntimeError(f"Failed to list S3 objects: {e}") from e 118 119 self.logger.info(f"Found {len(files)} file(s) in S3") 120 return sorted(files) 121 122 def get_file(self, file_id: str) -> Path: 123 """Download S3 file to temporary location. 124 125 Parameters 126 ---------- 127 file_id 128 S3 URI in format: s3://bucket/key 129 130 Returns 131 ------- 132 Path to the downloaded temporary file. 
133 """ 134 # Parse S3 URI 135 if not file_id.startswith("s3://"): 136 raise ValueError(f"Invalid S3 URI: {file_id}") 137 138 parts = file_id[5:].split("/", 1) 139 if len(parts) != 2: 140 raise ValueError(f"Invalid S3 URI format: {file_id}") 141 142 bucket, key = parts 143 144 # Verify bucket matches 145 if bucket != self.bucket_name: 146 raise ValueError( 147 f"Bucket mismatch: expected {self.bucket_name}, got {bucket}" 148 ) 149 150 # Create temporary file with same extension 151 suffix = Path(key).suffix 152 with tempfile.NamedTemporaryFile( 153 delete=False, suffix=suffix, prefix="s3_" 154 ) as temp_file: 155 temp_path = Path(temp_file.name) 156 157 self.logger.info(f"Downloading {file_id} to {temp_path}") 158 159 try: 160 self.s3_client.download_file(bucket, key, str(temp_path)) 161 self.temp_files.append(temp_path) 162 return temp_path 163 except Exception as e: 164 # Clean up temp file on error 165 if temp_path.exists(): 166 temp_path.unlink() 167 self.logger.info(f"Failed to download {file_id}: {e}") 168 raise RuntimeError(f"Failed to download {file_id}: {e}") from e 169 170 def cleanup(self) -> None: 171 """Clean up all temporary downloaded files.""" 172 self.logger.info(f"Cleaning up {len(self.temp_files)} temporary file(s)") 173 174 for temp_file in self.temp_files: 175 try: 176 if temp_file.exists(): 177 temp_file.unlink() 178 self.logger.debug(f"Deleted temporary file: {temp_file}") 179 except Exception as e: 180 self.logger.warning(f"Failed to delete {temp_file}: {e}") 181 182 self.temp_files.clear() 183 184 185 def create_s3_source(config: dict[str, Any]) -> InputSource: 186 """Create an S3 input source. 187 188 Parameters 189 ---------- 190 config 191 Configuration dictionary with S3 settings. 192 193 Returns 194 ------- 195 InputSource instance for S3. 196 197 Raises 198 ------ 199 ValueError 200 If required configuration is missing. 201 ImportError 202 If boto3 is not installed. 203 """ 204 return S3InputSource(config)