# multi_dataset.py
#
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import os
import json
import random
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd

from deepteam.vulnerabilities.custom import CustomVulnerability
from deepteam.vulnerabilities.multi_dataset import MultiDatasetVulnerabilityType


class PromptLoader:
    """Load prompts and per-row metadata from a variety of dataset file formats.

    Supported formats: .json, .jsonl, .csv/.tsv, .parquet, .xlsx/.xls, .txt.
    """

    # Whether the loaded dataset declares itself "official" (only set by the
    # JSON loader, from the top-level "official" key).
    official = False
    # Human-readable dataset name (only set by the JSON loader, from "name").
    dataset_name = None

    def __init__(self, random_seed: Optional[int] = None):
        """Initialize the loader.

        Args:
            random_seed: Seed for reproducible sampling. Seeds the module-level
                RNG (used by random.sample in the txt loader) and is reused as
                pandas' random_state in DataFrame.sample.
        """
        if random_seed is not None:
            random.seed(random_seed)
        self.random_seed = random_seed

        # Candidate column names that usually hold the prompt text,
        # checked in priority order by _detect_prompt_column.
        self.PROMPT_COLUMN_CANDIDATES = [
            'prompt', 'question', 'query', 'text',
            'input', 'content', 'instruction', 'message'
        ]

    def load_prompts(self, file_path: str, num_prompts: int = -1,
                     prompt_key: Optional[str] = None,
                     filter_conditions: Optional[Dict[str, Any]] = None) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Load prompts from a file, dispatching on its extension.

        Args:
            file_path: Path to the input file.
            num_prompts: Number of prompts to extract; -1 means all of them.
            prompt_key: Explicit column name to treat as the prompt; when None
                the column is auto-detected.
            filter_conditions: Mapping of {column: value} (value may also be a
                list/tuple of accepted values) applied before sampling.

        Returns:
            Tuple of (list of prompt strings, list of metadata dicts).

        Raises:
            FileNotFoundError: If file_path does not exist.
            ValueError: On unsupported extensions or malformed content.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        ext = os.path.splitext(file_path)[1].lower()

        if ext == '.json':
            return self._load_from_json(file_path, num_prompts, prompt_key, filter_conditions)
        elif ext == '.jsonl':
            return self._load_from_jsonlines(file_path, num_prompts, prompt_key, filter_conditions)
        elif ext in ('.csv', '.tsv'):
            return self._load_from_csv(file_path, num_prompts, prompt_key, filter_conditions)
        elif ext == '.parquet':
            return self._load_from_parquet(file_path, num_prompts, prompt_key, filter_conditions)
        elif ext in ('.xlsx', '.xls'):
            return self._load_from_excel(file_path, num_prompts, prompt_key, filter_conditions)
        elif ext == '.txt':
            return self._load_from_txt(file_path, num_prompts, filter_conditions)
        else:
            raise ValueError(f"Unsupported file format: {ext}")

    def _detect_prompt_column(self, df: pd.DataFrame) -> str:
        """Heuristically pick the DataFrame column most likely to hold prompts."""
        # 1) Well-known column names, in priority order.
        for col in self.PROMPT_COLUMN_CANDIDATES:
            if col in df.columns:
                return col

        # 2) No name matched: fall back to content-based detection — a prompt
        #    typically contains several words.
        for col in df.columns:
            sample = str(df[col].iloc[0]) if len(df) > 0 else ""
            if len(sample.split()) >= 5:
                return col

        # 3) Last resort: the first column.
        return df.columns[0]

    def _apply_filters(self, df: pd.DataFrame, filter_conditions: Optional[Dict[str, Any]]) -> pd.DataFrame:
        """Filter rows by {column: value} conditions; unknown columns are ignored."""
        if not filter_conditions:
            return df

        for column, value in filter_conditions.items():
            if column in df.columns:
                if isinstance(value, (list, tuple)):
                    # Membership test when several values are accepted.
                    df = df[df[column].isin(value)]
                else:
                    df = df[df[column] == value]

        return df

    def _process_dataframe(self, df: pd.DataFrame, num_prompts: int,
                           prompt_key: Optional[str], source_file: str) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Extract prompts and metadata from a cleaned, sampled DataFrame.

        Raises:
            ValueError: If the prompt column is missing or no valid prompts remain.
        """
        # Determine which column holds the prompt text.
        prompt_col = prompt_key if prompt_key else self._detect_prompt_column(df)

        if prompt_col not in df.columns:
            raise ValueError(f"Prompt column '{prompt_col}' not found in data")

        # Clean: drop rows with missing prompts, then deduplicate on prompt text.
        df = df.dropna(subset=[prompt_col])
        df = df.drop_duplicates(subset=[prompt_col])

        if df.empty:
            raise ValueError("No valid prompts found after cleaning")

        # Randomly sample the requested number of prompts (-1 means all).
        if len(df) <= num_prompts or num_prompts == -1:
            selected_df = df
            # FIX: only warn when the dataset is genuinely short — the original
            # also warned when len(df) == num_prompts (exactly enough rows).
            if num_prompts != -1 and len(df) < num_prompts:
                print(f"WARNING: Requested {num_prompts} prompts but only {len(df)} available")
        else:
            selected_df = df.sample(n=num_prompts, random_state=self.random_seed)

        prompts = []
        metadata = []

        for _, row in selected_df.iterrows():
            prompt = str(row[prompt_col]).strip()

            if prompt:
                prompts.append(prompt)

                # Build per-prompt metadata.
                meta = {
                    "prompt": prompt,
                    "source_file": os.path.basename(source_file),
                    "row_index": row.name
                }

                # Attach every other non-null field as (stringified) metadata.
                for col in df.columns:
                    if col != prompt_col and pd.notna(row[col]):
                        meta[col] = str(row[col])

                metadata.append(meta)

        if not prompts:
            raise ValueError("No valid prompts found in file")

        return prompts, metadata

    def _load_from_json(self, json_file: str, num_prompts: int,
                        prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Load prompts and metadata from a JSON file.

        Accepts either a top-level list of records, or a dict whose records live
        under "data"/"examples" (or, failing that, its longest list value).
        """
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Handle the different JSON layouts.
            if isinstance(data, dict):
                self.official = data.get("official", False)
                self.dataset_name = data.get("name")
                prob_item = []
                for k, v in data.items():
                    if isinstance(v, list):
                        if k in ['data', 'examples']:
                            items = v
                            break
                        elif len(v) > len(prob_item):
                            prob_item = v
                else:
                    # No "data"/"examples" key: use the longest list found,
                    # or treat the whole dict as a single record.
                    if prob_item:
                        items = prob_item
                    else:
                        items = [data]
            elif isinstance(data, list):
                items = data
            else:
                raise ValueError(f"Invalid JSON format in {json_file}")

            df = pd.DataFrame(items)
            df = self._apply_filters(df, filter_conditions)

            return self._process_dataframe(df, num_prompts, prompt_key, json_file)

        except Exception as e:
            # Chain the original cause so the traceback stays useful.
            raise ValueError(f"Error loading prompts from JSON file {json_file}: {e}") from e

    def _load_from_jsonlines(self, jsonl_file: str, num_prompts: int,
                             prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Load prompts and metadata from a JSON Lines file (one object per line)."""
        try:
            items = []
            with open(jsonl_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        items.append(json.loads(line))

            if not items:
                raise ValueError(f"No data found in JSON Lines file: {jsonl_file}")

            df = pd.DataFrame(items)
            df = self._apply_filters(df, filter_conditions)

            return self._process_dataframe(df, num_prompts, prompt_key, jsonl_file)

        except Exception as e:
            raise ValueError(f"Error loading prompts from JSON Lines file {jsonl_file}: {e}") from e

    def _load_from_csv(self, csv_file: str, num_prompts: int,
                       prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Load prompts and metadata from a CSV/TSV file."""
        try:
            # Pick the separator from the extension (.csv → comma, .tsv → tab).
            sep = ',' if csv_file.endswith('.csv') else '\t'

            df = pd.read_csv(csv_file, sep=sep, encoding='utf-8')
            df = self._apply_filters(df, filter_conditions)

            return self._process_dataframe(df, num_prompts, prompt_key, csv_file)

        except Exception as e:
            raise ValueError(f"Error loading prompts from CSV file {csv_file}: {e}") from e

    def _load_from_parquet(self, parquet_file: str, num_prompts: int,
                           prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Load prompts and metadata from a Parquet file."""
        try:
            df = pd.read_parquet(parquet_file)
            df = self._apply_filters(df, filter_conditions)

            return self._process_dataframe(df, num_prompts, prompt_key, parquet_file)

        except Exception as e:
            raise ValueError(f"Error loading prompts from Parquet file {parquet_file}: {e}") from e

    def _load_from_excel(self, excel_file: str, num_prompts: int,
                         prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Load prompts and metadata from an Excel file (first sheet)."""
        try:
            df = pd.read_excel(excel_file)
            df = self._apply_filters(df, filter_conditions)

            return self._process_dataframe(df, num_prompts, prompt_key, excel_file)

        except Exception as e:
            raise ValueError(f"Error loading prompts from Excel file {excel_file}: {e}") from e

    def _load_from_txt(self, txt_file: str, num_prompts: int,
                       filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Load prompts from a plain-text file, one prompt per line.

        NOTE: filter_conditions is accepted for signature parity with the other
        loaders but is not applied — text files carry no columns to filter on.
        """
        try:
            prompts = []
            with open(txt_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    # Skip blank lines and comment lines.
                    if line and not line.startswith(('#', '//')):
                        prompts.append(line)

            if not prompts:
                raise ValueError(f"No valid prompts found in text file: {txt_file}")

            # Randomly sample the requested number of prompts (-1 means all).
            if len(prompts) <= num_prompts or num_prompts == -1:
                selected_prompts = prompts
                # FIX: only warn on a genuine shortfall (see _process_dataframe).
                if num_prompts != -1 and len(prompts) < num_prompts:
                    print(f"WARNING: Requested {num_prompts} prompts but only {len(prompts)} available")
            else:
                selected_prompts = random.sample(prompts, num_prompts)

            # Minimal metadata for text files (no extra columns available).
            metadata = [{
                "prompt": prompt,
                "source_file": os.path.basename(txt_file),
                "description": "Loaded from text file"
            } for prompt in selected_prompts]

            return selected_prompts, metadata

        except Exception as e:
            raise ValueError(f"Error loading prompts from text file {txt_file}: {e}") from e


class MultiDatasetVulnerability(CustomVulnerability):
    """Vulnerability backed by prompts sampled from a dataset file.

    Uses PromptLoader (pandas-based) under the hood, so it supports JSON,
    JSONL, CSV/TSV, Parquet, Excel and plain-text datasets.
    """

    def __init__(self, prompt=None, dataset_file: str = "", num_prompts: int = 10, random_seed: Optional[int] = None,
                 prompt_column: Optional[str] = None, filter_conditions: Optional[Dict[str, Any]] = None):
        """Initialize the multi-dataset vulnerability.

        Args:
            prompt: A single prompt string; when given, dataset loading is skipped.
            dataset_file: Path to the dataset file (resolved relative to this
                module when not absolute).
            num_prompts: Number of prompts to sample (default 10, -1 means all).
            random_seed: Seed for reproducible sampling.
            prompt_column: Explicit prompt column name; auto-detected when None.
            filter_conditions: Filter dict, e.g. {"category": "harmful", "language": "zh"}.
        """
        # Resolve relative dataset paths against this module's directory.
        if not os.path.isabs(dataset_file):
            dataset_file = os.path.join(os.path.dirname(__file__), dataset_file)

        # Load prompts and metadata (or wrap the single provided prompt).
        self.loader = PromptLoader(random_seed)
        if prompt is not None:
            self.prompts = [prompt]
            self.metadata = [{
                "prompt": prompt,
                "category": "multi_dataset",
                "language": "unknown",
                "description": "Loaded from single prompt",
                "source_file": "Single prompt",
                "row_index": "prompt"  # placeholder for the pandas row index
            }]
        else:
            self.prompts, self.metadata = self.loader.load_prompts(dataset_file, num_prompts, prompt_column, filter_conditions)

        # "official" datasets get a dedicated vulnerability type.
        if self.loader.official:
            dataset_type = MultiDatasetVulnerabilityType.OFFICIAL_MULTI_DATASET_VULNERABILITY
        else:
            dataset_type = MultiDatasetVulnerabilityType.MULTI_DATASET_VULNERABILITY

        self.dataset_name = self.loader.dataset_name
        # Initialize the parent CustomVulnerability.
        super().__init__(
            name="Multi Dataset Vulnerability",
            types=[dataset_type],
            custom_prompt=self.prompts
        )

    def get_prompts(self) -> List[str]:
        """Return all loaded prompts."""
        return self.prompts

    def get_custom_prompt(self) -> Optional[str]:
        """Return the first prompt, or None when empty (compatibility method)."""
        return self.prompts[0] if self.prompts else None

    def get_dataframe_info(self) -> Dict[str, Any]:
        """Return summary statistics about the loaded dataset."""
        if not self.metadata:
            return {"error": "No metadata available"}

        # Basic counts and provenance.
        info = {
            "total_prompts": len(self.prompts),
            "source_file": self.metadata[0].get("source_file", "unknown"),
            "available_columns": list(self.metadata[0].keys()) if self.metadata else []
        }

        # Category distribution, when a category field is present.
        categories = [meta.get("category") for meta in self.metadata if meta.get("category")]
        if categories:
            info["category_distribution"] = pd.Series(categories).value_counts().to_dict()

        # Language distribution, when a language field is present.
        languages = [meta.get("language") for meta in self.metadata if meta.get("language")]
        if languages:
            info["language_distribution"] = pd.Series(languages).value_counts().to_dict()

        return info


# Smoke tests (each failure is reported, not fatal).
if __name__ == "__main__":
    # Test 1: default parameters.
    try:
        vuln1 = MultiDatasetVulnerability()
        print(f"Test 1: {len(vuln1.prompts)} prompts loaded")
        print(f"Sample prompts: {vuln1.prompts[:3]}")
        print(f"Dataset info: {vuln1.get_dataframe_info()}")
    except Exception as e:
        print(f"Test 1 failed: {e}")

    # Test 2: explicit count and random seed.
    try:
        vuln2 = MultiDatasetVulnerability(num_prompts=5, random_seed=42)
        print(f"Test 2: {len(vuln2.prompts)} prompts loaded")
    except Exception as e:
        print(f"Test 2 failed: {e}")

    # Test 3: explicit prompt column name.
    try:
        vuln3 = MultiDatasetVulnerability(num_prompts=3, prompt_column="text")
        print(f"Test 3: {len(vuln3.prompts)} prompts loaded with specified column")
    except Exception as e:
        print(f"Test 3 failed: {e}")

    # Test 4: filter conditions.
    try:
        filter_conditions = {"category": "harmful", "language": "zh"}
        vuln4 = MultiDatasetVulnerability(num_prompts=2, filter_conditions=filter_conditions)
        print(f"Test 4: {len(vuln4.prompts)} prompts loaded with filters")
    except Exception as e:
        print(f"Test 4 failed: {e}")

    # Test 5: metadata.
    try:
        vuln5 = MultiDatasetVulnerability(num_prompts=2)
        print(f"Test 5: Metadata sample: {vuln5.metadata[0] if vuln5.metadata else 'No metadata'}")
    except Exception as e:
        print(f"Test 5 failed: {e}")