# video_analysis_tool.py
"""
Video analysis tool using TwelveLabs API.

Uses TwelveLabs SDK for uploads (handles multipart/form-data) and
requests for queries and listing. Video IDs and metadata are returned
so the agent can store them in memory for follow-up questions.
"""

import os
import json
import logging
from typing import Dict, List, Optional, Tuple

import boto3
import requests
from strands import tool

logger = logging.getLogger(__name__)

TL_SECRET_ARN = os.environ.get("TL_SECRET_ARN", "")
TL_BASE_URL = "https://api.twelvelabs.io/v1.3"
REGION = os.environ.get("AWS_REGION", "us-east-1")
DEFAULT_INDEX_NAME = os.environ.get("TL_INDEX_NAME", "whatsapp-video-index")
TL_MODEL_NAME = os.environ.get("TL_MODEL_NAME", "pegasus1.2")
S3_PRESIGNED_EXPIRY = 3600  # 1 hour
# Items requested per page from TwelveLabs list endpoints.
# NOTE(review): assumed to be within the API's per-page maximum — confirm.
TL_PAGE_LIMIT = 50

_cached_api_key = None


def _get_api_key() -> str:
    """Retrieve the TwelveLabs API key from Secrets Manager.

    The key is cached in a module-level variable after the first successful
    lookup, so Secrets Manager is queried at most once per process.

    Raises:
        ValueError: If TL_SECRET_ARN is unset or the secret lacks TL_API_KEY.
    """
    global _cached_api_key
    if _cached_api_key:
        return _cached_api_key

    if not TL_SECRET_ARN:
        raise ValueError("TL_SECRET_ARN environment variable not set.")

    sm = boto3.client("secretsmanager", region_name=REGION)
    resp = sm.get_secret_value(SecretId=TL_SECRET_ARN)
    secret = json.loads(resp["SecretString"])
    api_key = secret.get("TL_API_KEY")
    if not api_key:
        raise ValueError("Secret referenced by TL_SECRET_ARN has no TL_API_KEY field.")
    _cached_api_key = api_key
    return _cached_api_key


def _get_tl_headers() -> dict:
    """Return the authentication headers for TwelveLabs REST calls."""
    return {"x-api-key": _get_api_key()}


def _generate_presigned_url(s3_uri: str) -> str:
    """Convert ``s3://bucket/key`` to a pre-signed GET URL for TwelveLabs.

    The ``s3://`` prefix is optional; ``bucket/key`` is accepted as well.

    Raises:
        ValueError: If the URI does not contain both a bucket and a key.
    """
    prefix = "s3://"
    # Strip only a leading prefix; str.replace would also mangle keys that
    # happen to contain "s3://".
    path = s3_uri[len(prefix):] if s3_uri.startswith(prefix) else s3_uri
    bucket, _, key = path.partition("/")
    if not bucket or not key:
        raise ValueError(f"Invalid S3 URI (expected s3://bucket/key): {s3_uri}")

    s3_client = boto3.client("s3", region_name=REGION)
    return s3_client.generate_presigned_url(
        "get_object",
        Params={"Bucket": bucket, "Key": key},
        ExpiresIn=S3_PRESIGNED_EXPIRY,
    )


def _get_or_create_index(client, index_name: str):
    """Return the index named ``index_name``, creating it if none is found.

    The lookup is best-effort: if listing fails, the failure is logged and
    creation is attempted anyway (creation may then fail if the index exists).
    """
    from twelvelabs.indexes import IndexesCreateRequestModelsItem

    try:
        for index in client.indexes.list():
            if index.index_name == index_name:
                return index
    except Exception:
        # Keep the original best-effort behaviour, but don't hide the cause.
        logger.warning(
            "Listing TwelveLabs indexes failed; attempting to create %r",
            index_name,
            exc_info=True,
        )

    logger.info("Creating TwelveLabs index %r (model=%s)", index_name, TL_MODEL_NAME)
    return client.indexes.create(
        index_name=index_name,
        models=[
            IndexesCreateRequestModelsItem(
                model_name=TL_MODEL_NAME,
                model_options=["visual", "audio"],
            )
        ],
    )


@tool
def video_analysis(
    action: str,
    video_path: str = None,
    video_name: str = None,
    prompt: str = None,
    index_name: str = DEFAULT_INDEX_NAME,
    temperature: float = 0.2,
) -> Dict:
    """Analyze videos using TwelveLabs API.

    Actions:
    - upload: Upload a video for indexing. Use video_path for the S3 URI (s3://bucket/key)
      or a public URL. Returns video_id and metadata (title, topics, hashtags).
      IMPORTANT: After uploading, tell the user the video_id and metadata so it is
      stored in memory for future queries.
    - query: Ask a question about an already-uploaded video. Requires video_path
      set to the video_id and prompt with the question.
    - list_videos: List all indexed videos with their IDs.

    Args:
        action: Operation to perform (upload, query, list_videos).
        video_path: S3 URI or URL for upload; video_id for query.
        video_name: Human-readable name for the video (upload only).
        prompt: Question about the video content (query only).
        index_name: TwelveLabs index name.
        temperature: Model temperature for query responses.

    Returns:
        Dict with operation results.
    """
    # NOTE: the signature and docstring above are consumed by the Strands
    # @tool decorator to build the tool spec shown to the model; change them
    # only deliberately.
    if not TL_SECRET_ARN:
        return {
            "status": "error",
            "content": [{"text": "TL_SECRET_ARN environment variable not set."}],
        }

    try:
        if action == "upload":
            return _handle_upload(video_path, video_name, index_name)
        elif action == "query":
            return _handle_query(video_path, prompt, temperature)
        elif action == "list_videos":
            return _handle_list_videos()
        else:
            return {
                "status": "error",
                "content": [{"text": f"Invalid action: {action}. Use upload, query, or list_videos."}],
            }
    except Exception as e:
        # Tool boundary: report every failure to the agent instead of raising.
        logger.exception("video_analysis failed: action=%s, error=%s", action, str(e))
        return {
            "status": "error",
            "content": [{"text": f"Video analysis failed: {str(e)}"}],
        }


def _fetch_gist(video_id: str) -> dict:
    """Best-effort fetch of title/topics/hashtags for an indexed video.

    Any failure (network error, non-200 status, invalid or non-object JSON)
    is logged and yields ``{}``. This prevents a metadata problem from
    discarding a video that was already indexed successfully.
    """
    try:
        resp = requests.post(
            f"{TL_BASE_URL}/gist",
            headers=_get_tl_headers(),
            json={"video_id": video_id, "types": ["title", "topic", "hashtag"]},
            timeout=30,
        )
    except requests.RequestException:
        logger.warning("Gist request failed for video_id=%s", video_id, exc_info=True)
        return {}

    if resp.status_code != 200:
        logger.warning("Gist request for video_id=%s returned status=%s", video_id, resp.status_code)
        return {}

    try:
        data = resp.json()
    except ValueError:
        logger.warning("Gist response for video_id=%s was not valid JSON", video_id)
        return {}
    return data if isinstance(data, dict) else {}


def _handle_upload(video_path: str, video_name: str, index_name: str) -> Dict:
    """Upload a video to TwelveLabs for indexing using the SDK.

    S3 URIs are converted to pre-signed URLs first. The call blocks until the
    indexing task finishes. On success, the result contains the video_id plus
    best-effort gist metadata.
    """
    if not video_path:
        return {"status": "error", "content": [{"text": "video_path required for upload."}]}

    from twelvelabs import TwelveLabs

    api_key = _get_api_key()
    client = TwelveLabs(api_key=api_key)
    index = _get_or_create_index(client, index_name)

    # Convert S3 URI to pre-signed URL
    is_s3 = video_path.startswith("s3://")
    video_url = video_path
    if is_s3:
        video_url = _generate_presigned_url(video_path)
        logger.info("Generated pre-signed URL for %s", video_path)

    # Upload via SDK (handles multipart/form-data). Only the URL length is
    # logged so the pre-signed credentials do not end up in logs.
    logger.info("TwelveLabs upload: index=%s, url_length=%d", index.id, len(video_url))
    task = client.tasks.create(index_id=index.id, video_url=video_url)
    task = client.tasks.wait_for_done(task_id=task.id)

    if task.status != "ready":
        return {"status": "error", "content": [{"text": f"Video indexing failed: status={task.status}"}]}

    video_id = task.video_id
    logger.info("Video indexed: video_id=%s", video_id)

    # Metadata via REST API (SDK gist() not available in all versions).
    gist_data = _fetch_gist(video_id)

    return {
        "status": "success",
        "content": [{
            "json": {
                "video_id": video_id,
                "s3_uri": video_path if is_s3 else None,
                "title": gist_data.get("title") or video_name or "Untitled",
                "topics": gist_data.get("topics") or [],
                "hashtags": gist_data.get("hashtags") or [],
            }
        }],
    }


def _parse_analyze_stream(body: str) -> str:
    """Concatenate the ``text`` of every ``text_generation`` event in an NDJSON body.

    Blank lines, lines that are not valid JSON, and JSON values that are not
    objects are skipped.
    """
    parts = []
    for line in body.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            event = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(event, dict) and event.get("event_type") == "text_generation":
            parts.append(event.get("text") or "")
    return "".join(parts)


def _handle_query(video_id: str, prompt: str, temperature: float) -> Dict:
    """Query an indexed video using the TwelveLabs analyze API."""
    if not video_id or not prompt:
        return {"status": "error", "content": [{"text": "video_id and prompt required for query."}]}

    headers = _get_tl_headers()

    response = requests.post(
        f"{TL_BASE_URL}/analyze",
        headers=headers,
        json={
            "video_id": video_id,
            "prompt": prompt,
            "temperature": temperature,
        },
        timeout=60,
    )

    if response.status_code != 200:
        return {"status": "error", "content": [{"text": f"Query failed: {response.text}"}]}

    return {
        "status": "success",
        "content": [{
            "json": {
                "video_id": video_id,
                "prompt": prompt,
                "response": _parse_analyze_stream(response.text),
            }
        }],
    }


def _fetch_all_pages(
    url: str, headers: dict, params: Optional[dict] = None
) -> Tuple[List[dict], Optional[requests.Response]]:
    """Collect the ``data`` items from every page of a TwelveLabs list endpoint.

    Returns:
        ``(items, failed_response)``. ``failed_response`` is None when all
        pages were fetched. Otherwise it is the first non-200 response, and
        ``items`` holds whatever was collected before that failure.
    """
    items: List[dict] = []
    page = 1
    while True:
        query = {**(params or {}), "page": page, "page_limit": TL_PAGE_LIMIT}
        resp = requests.get(url, headers=headers, params=query, timeout=30)
        if resp.status_code != 200:
            return items, resp
        body = resp.json()
        data = body.get("data") or []
        items.extend(data)
        total_pages = (body.get("page_info") or {}).get("total_page") or 1
        # Stop on the last page; an empty page also stops the loop as a guard.
        if not data or page >= total_pages:
            return items, None
        page += 1


def _handle_list_videos() -> Dict:
    """List indexed videos across all Pegasus-family indexes.

    Both the index list and each index's video list are paginated, so results
    are not truncated to the API's first page. A failure to list indexes is
    returned as an error. A failure to list one index's videos is logged, and
    that index contributes only the pages fetched before the failure.
    """
    headers = _get_tl_headers()

    indexes, failed = _fetch_all_pages(
        f"{TL_BASE_URL}/indexes",
        headers,
        params={"model_family": "pegasus"},
    )
    if failed is not None:
        return {"status": "error", "content": [{"text": f"Failed to list indexes: {failed.text}"}]}

    all_videos = []
    for index in indexes:
        # Skip empty indexes (or ones without a count) to save a request.
        if (index.get("video_count") or 0) <= 0:
            continue
        videos, failed = _fetch_all_pages(f"{TL_BASE_URL}/indexes/{index['_id']}/videos", headers)
        if failed is not None:
            logger.warning(
                "Failed to list videos for index %s: status=%s",
                index["_id"],
                failed.status_code,
            )
        index_label = index.get("index_name")
        all_videos.extend(
            {
                "video_id": video["_id"],
                "created_at": video.get("created_at"),
                "index_name": index_label,
            }
            for video in videos
        )

    return {
        "status": "success",
        "content": [{
            "json": {
                "videos": all_videos,
                "total_count": len(all_videos),
            }
        }],
    }