multi_dataset.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  import os
 20  import json
 21  import pandas as pd
 22  import random
 23  from typing import List, Optional, Union, Dict, Any
 24  
 25  from deepteam.vulnerabilities.custom import CustomVulnerability
 26  from deepteam.vulnerabilities.multi_dataset import MultiDatasetVulnerabilityType
 27  
 28  import os
 29  import json
 30  import random
 31  import pandas as pd
 32  from typing import List, Dict, Any, Optional, Tuple
 33  
 34  class PromptLoader:
 35      official = False
 36      dataset_name = None
 37      def __init__(self, random_seed = None):
 38          # 设置随机种子
 39          if random_seed is not None:
 40              random.seed(random_seed)
 41          self.random_seed = random_seed
 42  
 43          # 定义可能的prompt列名,按优先级排序
 44          self.PROMPT_COLUMN_CANDIDATES = [
 45              'prompt', 'question', 'query', 'text', 
 46              'input', 'content', 'instruction', 'message'
 47          ]
 48      
 49      def load_prompts(self, file_path: str, num_prompts: int = -1, 
 50                      prompt_key: Optional[str] = None,
 51                      filter_conditions: Optional[Dict[str, Any]] = None) -> Tuple[List[str], List[Dict[str, Any]]]:
 52          """从各种文件格式加载prompts
 53          
 54          Args:
 55              file_path: 输入文件路径
 56              num_prompts: 要提取的prompt数量,-1表示全部
 57              prompt_key: 指定作为prompt的列名
 58              filter_conditions: 过滤条件字典 {列名: 值}
 59              
 60          Returns:
 61              tuple: (prompts列表, 元数据列表)
 62          """
 63          if not os.path.exists(file_path):
 64              raise FileNotFoundError(f"File not found: {file_path}")
 65          
 66          ext = os.path.splitext(file_path)[1].lower()
 67          
 68          if ext == '.json':
 69              return self._load_from_json(file_path, num_prompts, prompt_key, filter_conditions)
 70          elif ext == '.jsonl':
 71              return self._load_from_jsonlines(file_path, num_prompts, prompt_key, filter_conditions)
 72          elif ext in ('.csv', '.tsv'):
 73              return self._load_from_csv(file_path, num_prompts, prompt_key, filter_conditions)
 74          elif ext == '.parquet':
 75              return self._load_from_parquet(file_path, num_prompts, prompt_key, filter_conditions)
 76          elif ext in ('.xlsx', '.xls'):
 77              return self._load_from_excel(file_path, num_prompts, prompt_key, filter_conditions)
 78          elif ext == '.txt':
 79              return self._load_from_txt(file_path, num_prompts, filter_conditions)
 80          else:
 81              raise ValueError(f"Unsupported file format: {ext}")
 82      
 83      def _detect_prompt_column(self, df: pd.DataFrame) -> str:
 84          """自动检测DataFrame中最可能是prompt的列"""
 85          for col in self.PROMPT_COLUMN_CANDIDATES:
 86              if col in df.columns:
 87                  return col
 88          
 89          # 如果没有匹配的列名,尝试基于内容识别
 90          for col in df.columns:
 91              sample = str(df[col].iloc[0]) if len(df) > 0 else ""
 92              if len(sample.split()) >= 5:  # 假设prompt通常有较多单词
 93                  return col
 94          
 95          # 最后返回第一列
 96          return df.columns[0]
 97      
 98      def _apply_filters(self, df: pd.DataFrame, filter_conditions: Optional[Dict[str, Any]]) -> pd.DataFrame:
 99          """应用过滤条件到DataFrame"""
100          if not filter_conditions:
101              return df
102          
103          for column, value in filter_conditions.items():
104              if column in df.columns:
105                  if isinstance(value, (list, tuple)):
106                      df = df[df[column].isin(value)]
107                  else:
108                      df = df[df[column] == value]
109          
110          return df
111      
112      def _process_dataframe(self, df: pd.DataFrame, num_prompts: int, 
113                           prompt_key: Optional[str], source_file: str) -> Tuple[List[str], List[Dict[str, Any]]]:
114          """处理DataFrame提取prompts和元数据"""
115          # 确定prompt列
116          prompt_col = prompt_key if prompt_key else self._detect_prompt_column(df)
117          
118          if prompt_col not in df.columns:
119              raise ValueError(f"Prompt column '{prompt_col}' not found in data")
120          
121          # 清理数据
122          df = df.dropna(subset=[prompt_col])
123          df = df.drop_duplicates(subset=[prompt_col])
124          
125          if df.empty:
126              raise ValueError("No valid prompts found after cleaning")
127          
128          # 随机筛选指定数量的prompt
129          if len(df) <= num_prompts or num_prompts == -1:
130              selected_df = df
131              if num_prompts != -1:
132                  print(f"WARNING: Requested {num_prompts} prompts but only {len(df)} available")
133          else:
134              selected_df = df.sample(n=num_prompts, random_state=self.random_seed)
135          
136          prompts = []
137          metadata = []
138          
139          for _, row in selected_df.iterrows():
140              prompt = str(row[prompt_col]).strip()
141              
142              if prompt:
143                  prompts.append(prompt)
144                  
145                  # 构建元数据
146                  meta = {
147                      "prompt": prompt,
148                      "source_file": os.path.basename(source_file),
149                      "row_index": row.name
150                  }
151                  
152                  # 添加所有其他字段作为元数据
153                  for col in df.columns:
154                      if col != prompt_col and pd.notna(row[col]):
155                          meta[col] = str(row[col])
156                  
157                  metadata.append(meta)
158          
159          if not prompts:
160              raise ValueError("No valid prompts found in file")
161              
162          return prompts, metadata
163      
164      def _load_from_json(self, json_file: str, num_prompts: int, 
165                         prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
166          """从JSON文件加载prompt列表和元数据"""
167          try:
168              with open(json_file, 'r', encoding='utf-8') as f:
169                  data = json.load(f)
170              
171              # 处理不同JSON结构
172              if isinstance(data, dict):
173                  self.official = data.get("official", False)
174                  self.dataset_name = data.get("name")
175                  prob_item = []
176                  for k, v in data.items():
177                      if isinstance(v, list):
178                          if k in ['data', 'examples']:
179                              items = v
180                              break
181                          elif len(v) > len(prob_item):
182                              prob_item = v
183                  else:
184                      # 如果没有匹配,找最长的列表
185                      if prob_item:
186                          items = prob_item
187                      # 整个字典就是数据
188                      else:  
189                          items = [data]
190              elif isinstance(data, list):
191                  items = data
192              else:
193                  raise ValueError(f"Invalid JSON format in {json_file}")
194              
195              df = pd.DataFrame(items)
196              df = self._apply_filters(df, filter_conditions)
197              
198              return self._process_dataframe(df, num_prompts, prompt_key, json_file)
199              
200          except Exception as e:
201              raise ValueError(f"Error loading prompts from JSON file {json_file}: {e}")
202      
203      def _load_from_jsonlines(self, jsonl_file: str, num_prompts: int, 
204                             prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
205          """从JSON Lines文件加载prompt列表和元数据"""
206          try:
207              items = []
208              with open(jsonl_file, 'r', encoding='utf-8') as f:
209                  for line in f:
210                      line = line.strip()
211                      if line:
212                          items.append(json.loads(line))
213              
214              if not items:
215                  raise ValueError(f"No data found in JSON Lines file: {jsonl_file}")
216              
217              df = pd.DataFrame(items)
218              df = self._apply_filters(df, filter_conditions)
219              
220              return self._process_dataframe(df, num_prompts, prompt_key, jsonl_file)
221              
222          except Exception as e:
223              raise ValueError(f"Error loading prompts from JSON Lines file {jsonl_file}: {e}")
224      
225      def _load_from_csv(self, csv_file: str, num_prompts: int, 
226                        prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
227          """从CSV/TSV文件加载prompt列表和元数据"""
228          try:
229              # 自动检测分隔符
230              sep = ',' if csv_file.endswith('.csv') else '\t'
231              
232              df = pd.read_csv(csv_file, sep=sep, encoding='utf-8')
233              df = self._apply_filters(df, filter_conditions)
234              
235              return self._process_dataframe(df, num_prompts, prompt_key, csv_file)
236              
237          except Exception as e:
238              raise ValueError(f"Error loading prompts from CSV file {csv_file}: {e}")
239      
240      def _load_from_parquet(self, parquet_file: str, num_prompts: int,
241                           prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
242          """从Parquet文件加载prompt列表和元数据"""
243          try:
244              df = pd.read_parquet(parquet_file)
245              df = self._apply_filters(df, filter_conditions)
246              
247              return self._process_dataframe(df, num_prompts, prompt_key, parquet_file)
248              
249          except Exception as e:
250              raise ValueError(f"Error loading prompts from Parquet file {parquet_file}: {e}")
251  
252      def _load_from_excel(self, excel_file: str, num_prompts: int, 
253                          prompt_key: Optional[str], filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
254          """从Excel文件加载prompt列表和元数据"""
255          try:
256              df = pd.read_excel(excel_file)
257              df = self._apply_filters(df, filter_conditions)
258              
259              return self._process_dataframe(df, num_prompts, prompt_key, excel_file)
260              
261          except Exception as e:
262              raise ValueError(f"Error loading prompts from Excel file {excel_file}: {e}")
263      
264      def _load_from_txt(self, txt_file: str, num_prompts: int, 
265                        filter_conditions: Optional[Dict[str, Any]]) -> Tuple[List[str], List[Dict[str, Any]]]:
266          """从文本文件加载prompt列表和元数据"""
267          try:
268              prompts = []
269              with open(txt_file, 'r', encoding='utf-8') as f:
270                  for line in f:
271                      line = line.strip()
272                      if line and not line.startswith(('#', '//')):  # 跳过注释行
273                          prompts.append(line)
274              
275              if not prompts:
276                  raise ValueError(f"No valid prompts found in text file: {txt_file}")
277              
278              # 随机筛选指定数量的prompt
279              if len(prompts) <= num_prompts or num_prompts == -1:
280                  selected_prompts = prompts
281                  if num_prompts != -1:
282                      print(f"WARNING: Requested {num_prompts} prompts but only {len(prompts)} available")
283              else:
284                  selected_prompts = random.sample(prompts, num_prompts)
285              
286              # 为文本文件创建简单元数据
287              metadata = [{
288                  "prompt": prompt,
289                  "source_file": os.path.basename(txt_file),
290                  "description": "Loaded from text file"
291              } for prompt in selected_prompts]
292              
293              return selected_prompts, metadata
294              
295          except Exception as e:
296              raise ValueError(f"Error loading prompts from text file {txt_file}: {e}")
297  
class MultiDatasetVulnerability(CustomVulnerability):
    """Vulnerability whose attack prompts are read from a dataset file.

    Prompts are loaded through PromptLoader (JSON, JSON Lines, CSV/TSV, Parquet,
    Excel or plain text), optionally filtered and randomly sampled. A single
    inline ``prompt`` may be supplied instead of a file.
    """

    def __init__(self, prompt = None, dataset_file: str = "", num_prompts: int = 10, random_seed: Optional[int] = None,
                 prompt_column: Optional[str] = None, filter_conditions: Optional[Dict[str, Any]] = None):
        """Initialize the vulnerability from an inline prompt or a dataset file.

        Args:
            prompt: A single prompt to use as-is. When given, dataset_file is
                not read.
            dataset_file: Path of the dataset file. Relative paths are resolved
                against this module's directory, not the working directory.
            num_prompts: Number of prompts to randomly sample (default 10);
                -1 keeps all of them.
            random_seed: Seed for reproducible sampling.
            prompt_column: Column holding the prompt; auto-detected when None.
            filter_conditions: Row filters such as
                ``{"category": "harmful", "language": "zh"}``.

        Raises:
            ValueError: Neither prompt nor dataset_file was given, or the file
                could not be parsed or yielded no prompts.
            FileNotFoundError: dataset_file does not exist.
        """
        self.loader = PromptLoader(random_seed)
        if prompt is not None:
            self.prompts = [prompt]
            self.metadata = [{
                "prompt": prompt,
                "category": "multi_dataset",
                "language": "unknown",
                "description": "Loaded from single prompt",
                "source_file": "Single prompt",
                "row_index": "prompt"  # placeholder: there is no source row for an inline prompt
            }]
        else:
            # An empty path would otherwise resolve to this module's directory
            # and fail later with a confusing "Unsupported file format" error.
            if not dataset_file:
                raise ValueError("dataset_file must be provided when no prompt is given")
            if not os.path.isabs(dataset_file):
                dataset_file = os.path.join(os.path.dirname(__file__), dataset_file)
            self.prompts, self.metadata = self.loader.load_prompts(
                dataset_file, num_prompts, prompt_column, filter_conditions)

        # The "official" flag is taken from the loaded JSON dataset, if any.
        if self.loader.official:
            dataset_type = MultiDatasetVulnerabilityType.OFFICIAL_MULTI_DATASET_VULNERABILITY
        else:
            dataset_type = MultiDatasetVulnerabilityType.MULTI_DATASET_VULNERABILITY

        self.dataset_name = self.loader.dataset_name
        super().__init__(
            name="Multi Dataset Vulnerability",
            types=[dataset_type],
            custom_prompt=self.prompts
        )

    def get_prompts(self) -> List[str]:
        """Return all loaded prompts."""
        return self.prompts

    def get_custom_prompt(self) -> Optional[str]:
        """Return the first prompt, or None if there are none (compatibility helper)."""
        return self.prompts[0] if self.prompts else None

    def get_dataframe_info(self) -> Dict[str, Any]:
        """Summarize the loaded prompts.

        Returns:
            A dict with "total_prompts", "source_file", and "available_columns".
            It also holds "category_distribution" / "language_distribution"
            (value -> count) when any metadata row carries a truthy value for
            those keys, or {"error": ...} when there is no metadata.
        """
        if not self.metadata:
            return {"error": "No metadata available"}

        info: Dict[str, Any] = {
            "total_prompts": len(self.prompts),
            "source_file": self.metadata[0].get("source_file", "unknown"),
            # Union of keys across all rows, in first-seen order. The loader
            # omits missing values per row, so the first row alone can
            # under-report the available columns.
            "available_columns": list(dict.fromkeys(
                key for meta in self.metadata for key in meta
            )),
        }

        for field, info_key in (("category", "category_distribution"),
                                ("language", "language_distribution")):
            values = [meta.get(field) for meta in self.metadata if meta.get(field)]
            if values:
                info[info_key] = pd.Series(values).value_counts().to_dict()

        return info
379  
# Manual smoke tests
if __name__ == "__main__":
    def _run_case(label, build, report):
        """Build a vulnerability and print its summary; print any failure instead of raising."""
        try:
            report(build())
        except Exception as e:
            print(f"{label} failed: {e}")

    def _report_defaults(vuln):
        """Summary printer for the default-argument case."""
        print(f"Test 1: {len(vuln.prompts)} prompts loaded")
        print(f"Sample prompts: {vuln.prompts[:3]}")
        print(f"Dataset info: {vuln.get_dataframe_info()}")

    # Test 1: default arguments
    _run_case("Test 1", lambda: MultiDatasetVulnerability(), _report_defaults)

    # Test 2: explicit count and random seed
    _run_case(
        "Test 2",
        lambda: MultiDatasetVulnerability(num_prompts=5, random_seed=42),
        lambda v: print(f"Test 2: {len(v.prompts)} prompts loaded"),
    )

    # Test 3: explicit prompt column
    _run_case(
        "Test 3",
        lambda: MultiDatasetVulnerability(num_prompts=3, prompt_column="text"),
        lambda v: print(f"Test 3: {len(v.prompts)} prompts loaded with specified column"),
    )

    # Test 4: row filters
    _run_case(
        "Test 4",
        lambda: MultiDatasetVulnerability(
            num_prompts=2,
            filter_conditions={"category": "harmful", "language": "zh"},
        ),
        lambda v: print(f"Test 4: {len(v.prompts)} prompts loaded with filters"),
    )

    # Test 5: metadata
    _run_case(
        "Test 5",
        lambda: MultiDatasetVulnerability(num_prompts=2),
        lambda v: print(f"Test 5: Metadata sample: {v.metadata[0] if v.metadata else 'No metadata'}"),
    )