/ 00-agent-agentcore / agent_files / video_analysis_tool.py
video_analysis_tool.py
  1  """
  2  Video analysis tool using TwelveLabs API.
  3  
  4  Uses TwelveLabs SDK for uploads (handles multipart/form-data) and
  5  requests for queries and listing. Video IDs and metadata are returned
  6  so the agent can store them in memory for follow-up questions.
  7  """
  8  
  9  import os
 10  import json
 11  import logging
 12  from typing import Dict
 13  
 14  import boto3
 15  import requests
 16  from strands import tool
 17  
 18  logger = logging.getLogger(__name__)
 19  
 20  TL_SECRET_ARN = os.environ.get("TL_SECRET_ARN", "")
 21  TL_BASE_URL = "https://api.twelvelabs.io/v1.3"
 22  REGION = os.environ.get("AWS_REGION", "us-east-1")
 23  DEFAULT_INDEX_NAME = os.environ.get("TL_INDEX_NAME", "whatsapp-video-index")
 24  TL_MODEL_NAME = os.environ.get("TL_MODEL_NAME", "pegasus1.2")
 25  S3_PRESIGNED_EXPIRY = 3600  # 1 hour
 26  
 27  _cached_api_key = None
 28  
 29  
 30  def _get_api_key() -> str:
 31      """Retrieve TwelveLabs API key from Secrets Manager (cached after first call)."""
 32      global _cached_api_key
 33      if _cached_api_key:
 34          return _cached_api_key
 35  
 36      if not TL_SECRET_ARN:
 37          raise ValueError("TL_SECRET_ARN environment variable not set.")
 38  
 39      sm = boto3.client("secretsmanager", region_name=REGION)
 40      resp = sm.get_secret_value(SecretId=TL_SECRET_ARN)
 41      secret = json.loads(resp["SecretString"])
 42      _cached_api_key = secret["TL_API_KEY"]
 43      return _cached_api_key
 44  
 45  
 46  def _get_tl_headers() -> dict:
 47      return {"x-api-key": _get_api_key()}
 48  
 49  
 50  def _generate_presigned_url(s3_uri: str) -> str:
 51      """Convert s3://bucket/key to a pre-signed URL for TwelveLabs."""
 52      parts = s3_uri.replace("s3://", "").split("/", 1)
 53      bucket = parts[0]
 54      key = parts[1]
 55      s3_client = boto3.client("s3", region_name=REGION)
 56      return s3_client.generate_presigned_url(
 57          "get_object",
 58          Params={"Bucket": bucket, "Key": key},
 59          ExpiresIn=S3_PRESIGNED_EXPIRY,
 60      )
 61  
 62  
 63  def _get_or_create_index(client, index_name: str):
 64      """Get existing index or create new one using TwelveLabs SDK."""
 65      from twelvelabs.indexes import IndexesCreateRequestModelsItem
 66  
 67      try:
 68          for index in client.indexes.list():
 69              if index.index_name == index_name:
 70                  return index
 71      except Exception:
 72          pass
 73  
 74      return client.indexes.create(
 75          index_name=index_name,
 76          models=[
 77              IndexesCreateRequestModelsItem(
 78                  model_name=TL_MODEL_NAME,
 79                  model_options=["visual", "audio"],
 80              )
 81          ],
 82      )
 83  
 84  
 85  @tool
 86  def video_analysis(
 87      action: str,
 88      video_path: str = None,
 89      video_name: str = None,
 90      prompt: str = None,
 91      index_name: str = DEFAULT_INDEX_NAME,
 92      temperature: float = 0.2,
 93  ) -> Dict:
 94      """Analyze videos using TwelveLabs API.
 95  
 96      Actions:
 97      - upload: Upload a video for indexing. Use video_path for the S3 URI (s3://bucket/key)
 98        or a public URL. Returns video_id and metadata (title, topics, hashtags).
 99        IMPORTANT: After uploading, tell the user the video_id and metadata so it is
100        stored in memory for future queries.
101      - query: Ask a question about an already-uploaded video. Requires video_path
102        set to the video_id and prompt with the question.
103      - list_videos: List all indexed videos with their IDs.
104  
105      Args:
106          action: Operation to perform (upload, query, list_videos).
107          video_path: S3 URI or URL for upload; video_id for query.
108          video_name: Human-readable name for the video (upload only).
109          prompt: Question about the video content (query only).
110          index_name: TwelveLabs index name.
111          temperature: Model temperature for query responses.
112  
113      Returns:
114          Dict with operation results.
115      """
116      if not TL_SECRET_ARN:
117          return {
118              "status": "error",
119              "content": [{"text": "TL_SECRET_ARN environment variable not set."}],
120          }
121  
122      try:
123          if action == "upload":
124              return _handle_upload(video_path, video_name, index_name)
125          elif action == "query":
126              return _handle_query(video_path, prompt, temperature)
127          elif action == "list_videos":
128              return _handle_list_videos()
129          else:
130              return {
131                  "status": "error",
132                  "content": [{"text": f"Invalid action: {action}. Use upload, query, or list_videos."}],
133              }
134      except Exception as e:
135          logger.error("video_analysis failed: action=%s, error=%s", action, str(e), exc_info=True)
136          return {
137              "status": "error",
138              "content": [{"text": f"Video analysis failed: {str(e)}"}],
139          }
140  
141  
def _fetch_gist(video_id: str) -> dict:
    """Best-effort lookup of title, topics and hashtags for an indexed video.

    Calls the TwelveLabs ``/gist`` REST endpoint directly, because the SDK's
    gist() is not available in every SDK version. Every failure returns an
    empty dict after logging a warning: a transport error, a non-200 status,
    or a body that is not a JSON object. By the time this runs the video is
    already indexed, and its video_id must still reach the caller.
    """
    try:
        resp = requests.post(
            f"{TL_BASE_URL}/gist",
            headers=_get_tl_headers(),
            json={"video_id": video_id, "types": ["title", "topic", "hashtag"]},
            timeout=30,
        )
    except requests.RequestException:
        logger.warning("TwelveLabs gist request failed: video_id=%s", video_id, exc_info=True)
        return {}

    if resp.status_code != 200:
        logger.warning(
            "TwelveLabs gist returned HTTP %s for video_id=%s: %s",
            resp.status_code,
            video_id,
            resp.text,
        )
        return {}

    try:
        data = resp.json()
    except ValueError:
        logger.warning("TwelveLabs gist returned a non-JSON body for video_id=%s", video_id)
        return {}
    return data if isinstance(data, dict) else {}


def _handle_upload(video_path: str, video_name: str, index_name: str) -> Dict:
    """Index a video in TwelveLabs and return its video_id plus metadata.

    Args:
        video_path: ``s3://bucket/key`` URI, which is converted to a pre-signed
            GET URL, or a URL that TwelveLabs can fetch directly.
        video_name: Title to report when the gist lookup yields none.
        index_name: Index to upload into. It is created if it does not exist.

    Returns:
        A tool result dict. On success, ``content[0]["json"]`` holds
        ``video_id``, ``s3_uri`` (None for non-S3 input), ``title``,
        ``topics`` and ``hashtags``. An error result is returned if
        ``video_path`` is empty or the indexing task ends with a status
        other than ``"ready"``.
    """
    if not video_path:
        return {"status": "error", "content": [{"text": "video_path required for upload."}]}

    from twelvelabs import TwelveLabs

    client = TwelveLabs(api_key=_get_api_key())
    index = _get_or_create_index(client, index_name)

    is_s3 = video_path.startswith("s3://")
    # TwelveLabs fetches the video itself, so a private S3 object is exposed
    # through a time-limited pre-signed URL instead of the raw s3:// URI.
    video_url = _generate_presigned_url(video_path) if is_s3 else video_path
    if is_s3:
        logger.info("Generated pre-signed URL for %s", video_path)

    # Upload via SDK (handles multipart/form-data). Only the URL length is
    # logged, since a pre-signed URL embeds a signature.
    logger.info("TwelveLabs upload: index=%s, url_length=%d", index.id, len(video_url))
    task = client.tasks.create(index_id=index.id, video_url=video_url)
    # Presumably polls until the task finishes (SDK helper). No timeout is
    # applied here.
    task = client.tasks.wait_for_done(task_id=task.id)

    if task.status != "ready":
        return {"status": "error", "content": [{"text": f"Video indexing failed: status={task.status}"}]}

    video_id = task.video_id
    logger.info("Video indexed: video_id=%s", video_id)

    gist_data = _fetch_gist(video_id)
    return {
        "status": "success",
        "content": [{
            "json": {
                "video_id": video_id,
                "s3_uri": video_path if is_s3 else None,
                # `or` rather than a .get() default, so null/empty gist
                # fields also fall back to the caller's name or defaults.
                "title": gist_data.get("title") or video_name or "Untitled",
                "topics": gist_data.get("topics") or [],
                "hashtags": gist_data.get("hashtags") or [],
            }
        }],
    }
192  
193  
def _parse_analyze_stream(body: str) -> str:
    """Join the ``text`` of every ``text_generation`` event in an NDJSON body.

    The following are skipped instead of aborting the whole answer:
    - blank lines and lines that are not valid JSON;
    - JSON values that are not objects;
    - events whose ``text`` is not a string.
    """
    chunks = []
    # Split on "\n" only: str.splitlines() also breaks on U+2028/U+2029,
    # which JSON allows unescaped inside string values.
    for raw_line in body.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        try:
            event = json.loads(line)
        except json.JSONDecodeError:
            continue
        if not isinstance(event, dict) or event.get("event_type") != "text_generation":
            continue
        text = event.get("text")
        if isinstance(text, str):
            chunks.append(text)
    return "".join(chunks)


def _handle_query(video_id: str, prompt: str, temperature: float) -> Dict:
    """Ask a question about an indexed video via the TwelveLabs /analyze API.

    Args:
        video_id: ID of a video that has already been indexed.
        prompt: The question to ask about the video.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        A tool result dict. On success, ``content[0]["json"]`` holds
        ``video_id``, ``prompt`` and the concatenated ``response`` text. An
        error result is returned if an argument is missing or the HTTP status
        is not 200.
    """
    if not video_id or not prompt:
        return {"status": "error", "content": [{"text": "video_id and prompt required for query."}]}

    response = requests.post(
        f"{TL_BASE_URL}/analyze",
        headers=_get_tl_headers(),
        json={
            "video_id": video_id,
            "prompt": prompt,
            "temperature": temperature,
        },
        timeout=60,
    )

    if response.status_code != 200:
        return {"status": "error", "content": [{"text": f"Query failed: {response.text}"}]}

    # JSON exchanged between systems is UTF-8 (RFC 8259). Decode explicitly
    # instead of relying on requests' charset guess for `.text`, which can
    # garble non-ASCII answers when the Content-Type carries no charset.
    body = response.content.decode("utf-8", errors="replace")

    return {
        "status": "success",
        "content": [{
            "json": {
                "video_id": video_id,
                "prompt": prompt,
                "response": _parse_analyze_stream(body),
            }
        }],
    }
235  
236  
def _list_all_pages(url: str, headers: dict, params: dict = None, page_limit: int = 50) -> tuple:
    """Collect the ``data`` items from every page of a TwelveLabs list endpoint.

    Requests pages 1, 2, ..., each with ``page_limit`` items. Stops once
    ``page_info.total_page`` is reached or a page returns no items. A
    response without ``page_info`` is treated as a single page.

    Returns:
        ``(items, None)`` when every page succeeded, or
        ``(items_collected_so_far, error_text)`` at the first non-200 response.
    """
    items = []
    page = 1
    while True:
        query = dict(params or {})
        query.update(page=page, page_limit=page_limit)
        resp = requests.get(url, headers=headers, params=query, timeout=30)
        if resp.status_code != 200:
            return items, resp.text
        payload = resp.json()
        batch = payload.get("data") or []
        items.extend(batch)
        total_pages = (payload.get("page_info") or {}).get("total_page") or 1
        if not batch or page >= total_pages:
            return items, None
        page += 1


def _handle_list_videos() -> Dict:
    """List every video in every Pegasus-family index.

    The TwelveLabs list endpoints return results one page at a time, so all
    pages of indexes and of each index's videos are fetched, not only the
    first. An index whose video listing fails is logged and contributes
    whatever videos were fetched before the failure.

    Returns:
        A tool result dict whose ``content[0]["json"]`` has ``videos`` (each
        with ``video_id``, ``created_at`` and ``index_name``) and
        ``total_count``. An error result is returned if no index could be
        retrieved at all.
    """
    headers = _get_tl_headers()

    indexes, error = _list_all_pages(
        f"{TL_BASE_URL}/indexes",
        headers,
        params={"model_family": "pegasus"},
    )
    if error is not None:
        if not indexes:
            return {"status": "error", "content": [{"text": f"Failed to list indexes: {error}"}]}
        logger.warning("Index listing stopped early; using %d indexes: %s", len(indexes), error)

    all_videos = []
    for index in indexes:
        if (index.get("video_count") or 0) <= 0:
            continue
        videos, error = _list_all_pages(f"{TL_BASE_URL}/indexes/{index['_id']}/videos", headers)
        if error is not None:
            logger.warning("Failed to list videos for index %s: %s", index["_id"], error)
        all_videos.extend(
            {
                "video_id": video["_id"],
                "created_at": video.get("created_at"),
                "index_name": index.get("index_name"),
            }
            for video in videos
        )

    return {
        "status": "success",
        "content": [{
            "json": {
                "videos": all_videos,
                "total_count": len(all_videos),
            }
        }],
    }