/ AIG-PromptSecurity / deepteam / red_teamer / risk_assessment.py
risk_assessment.py
  1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  #
  3  # Licensed under the Apache License, Version 2.0 (the "License");
  4  # you may not use this file except in compliance with the License.
  5  # You may obtain a copy of the License at
  6  #
  7  #     http://www.apache.org/licenses/LICENSE-2.0
  8  #
  9  # Unless required by applicable law or agreed to in writing, software
 10  # distributed under the License is distributed on an "AS IS" BASIS,
 11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  # See the License for the specific language governing permissions and
 13  # limitations under the License.
 14  #
 15  # Requirement: Any integration or derivative work must explicitly attribute
 16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  # documentation or user interface, as detailed in the NOTICE file.
 18  
 19  from pydantic import BaseModel, Field
 20  from typing import Dict, Optional, List
 21  import datetime
 22  import os
 23  import json
 24  from enum import Enum
 25  
 26  from deepteam.vulnerabilities.types import VulnerabilityType
 27  
 28  
 29  class RedTeamingTestCase(BaseModel):
 30      vulnerability: str
 31      vulnerability_type: VulnerabilityType
 32      risk_category: str = Field(alias="riskCategory")
 33      attack_method: Optional[str] = Field(None, alias="attackMethod")
 34      original_input: Optional[str] = None
 35      input: Optional[str] = None
 36      actual_output: Optional[str] = Field(
 37          None, serialization_alias="actualOutput"
 38      )
 39      score: Optional[float] = None
 40      reason: Optional[str] = None
 41      error: Optional[str] = None
 42      useless: bool = False
 43  
 44  
 45  class TestCasesList(list):
 46      def to_df(self):
 47          import pandas as pd
 48  
 49          data = []
 50          for case in self:
 51              data.append(
 52                  {
 53                      "Vulnerability": case.vulnerability,
 54                      "Vulnerability Type": str(case.vulnerability_type.value),
 55                      "Risk Category": case.risk_category,
 56                      "Attack Enhancement": case.attack_method,
 57                      "Input": case.input,
 58                      "Actual Output": case.actual_output,
 59                      "Score": case.score,
 60                      "Reason": case.reason,
 61                      "Error": case.error,
 62                      "Status": (
 63                          "Passed"
 64                          if case.score and case.score > 0
 65                          else "Errored" if case.error else "Failed"
 66                      ),
 67                  }
 68              )
 69          return pd.DataFrame(data)
 70  
 71  
class VulnerabilityTypeResult(BaseModel):
    """Pass/fail statistics for the test cases that share one vulnerability.

    Instances are built by ``construct_risk_assessment_overview``.
    """

    # Vulnerability label shared by the grouped test cases.
    vulnerability: str
    # Type reported for the group, taken from its test cases.
    vulnerability_type: VulnerabilityType
    # passing / (passing + failing), or 0.0 when that sum is zero; errored
    # and unused cases are excluded from the denominator.
    pass_rate: float
    passing: int  # cases with a score > 0
    failing: int  # cases that neither passed, errored, nor were unused
    errored: int  # cases that recorded an error
    unused: int  # cases flagged useless that did not error
 80  
 81  
class AttackMethodResult(BaseModel):
    """Pass/fail statistics for the test cases that used one attack method.

    Instances are built by ``construct_risk_assessment_overview``.
    """

    # Attack method of the grouped cases; None groups cases without one.
    attack_method: Optional[str] = None
    # passing / (passing + failing), or 0.0 when that sum is zero; errored
    # and unused cases are excluded from the denominator.
    pass_rate: float
    passing: int  # cases with a score > 0
    failing: int  # cases that neither passed, errored, nor were unused
    errored: int  # cases that recorded an error
    unused: int  # cases flagged useless that did not error
 89  
 90  
 91  class RedTeamingOverview(BaseModel):
 92      vulnerability_type_results: List[VulnerabilityTypeResult]
 93      attack_method_results: List[AttackMethodResult]
 94  
 95      def to_df(self):
 96          import pandas as pd
 97  
 98          data = []
 99          for result in self.vulnerability_type_results:
100              data.append(
101                  {
102                      "Vulnerability": result.vulnerability,
103                      "Vulnerability Type": str(result.vulnerability_type.value),
104                      "Total": result.passing + result.failing + result.errored,
105                      "Pass Rate": result.pass_rate,
106                      "Passing": result.passing,
107                      "Failing": result.failing,
108                      "Errored": result.errored,
109                  }
110              )
111          return pd.DataFrame(data)
112  
113  
class EnumEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that writes ``Enum`` members as their ``value``.

    Any other object the stock encoder cannot handle is passed to the base
    class, which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, Enum):
            # Not an Enum: defer to the standard behaviour.
            return super().default(obj)
        return obj.value
119  
120  
class RiskAssessment(BaseModel):
    """Complete red-teaming result: the aggregate overview plus every test case."""

    overview: RedTeamingOverview
    test_cases: List[RedTeamingTestCase]

    def __init__(self, **data):
        super().__init__(**data)
        # Re-wrap the validated list so callers get ``TestCasesList.to_df``.
        self.test_cases = TestCasesList(self.test_cases)

    def save(self, to: str) -> str:
        """Write this assessment as pretty-printed JSON into directory ``to``.

        The file is named after the current local time
        (``YYYYmmdd_HHMMSS.json``). If that name is already taken (for
        example, two saves within the same second), a ``_1``, ``_2``, ...
        suffix is appended instead of overwriting the earlier report. Fields
        are dumped by alias and enum members are written as their values.

        Args:
            to: Target directory; created (including parents) if missing.

        Returns:
            Path of the written JSON file.

        Raises:
            OSError: If the directory cannot be created or the file cannot
                be written.
        """
        try:
            # exist_ok avoids the race between an existence check and creation.
            try:
                os.makedirs(to, exist_ok=True)
            except OSError as e:
                raise OSError(f"Cannot create directory '{to}': {e}") from e

            # Serialize before touching the filesystem so an unserializable
            # value cannot leave a truncated report behind.
            payload = json.dumps(
                self.model_dump(by_alias=True), indent=2, cls=EnumEncoder
            )

            # Timestamps have one-second resolution; never clobber an
            # earlier report that happens to share the same name.
            stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            full_file_path = os.path.join(to, f"{stamp}.json")
            suffix = 0
            while os.path.exists(full_file_path):
                suffix += 1
                full_file_path = os.path.join(to, f"{stamp}_{suffix}.json")

            with open(full_file_path, "w", encoding="utf-8") as f:
                f.write(payload)

            print(
                f"🎉 Success! 🎉 Your risk assessment file has been saved to:\n📁 {full_file_path} ✅"
            )

        except OSError as e:
            raise OSError(f"Failed to save file to '{to}': {e}") from e

        return full_file_path
156  
157  
def _tally_outcomes(test_cases: List[RedTeamingTestCase]) -> Dict[str, float]:
    """Bucket one group of test cases and compute its pass rate.

    Each case is counted exactly once, checked in this order: passing
    (``score > 0``), errored (``error`` is not None), unused (``useless`` is
    set), failing (everything else). The four counts therefore always sum to
    ``len(test_cases)`` and ``failing`` can never go negative.

    Returns:
        Keyword arguments shared by ``VulnerabilityTypeResult`` and
        ``AttackMethodResult``: ``pass_rate``, ``passing``, ``failing``,
        ``errored`` and ``unused``. ``pass_rate`` is
        ``passing / (passing + failing)``, or ``0.0`` when that sum is zero.
    """
    passing = failing = errored = unused = 0
    for tc in test_cases:
        if tc.score is not None and tc.score > 0:
            passing += 1
        elif tc.error is not None:
            errored += 1
        elif tc.useless:
            unused += 1
        else:
            failing += 1

    # Errored and unused cases are left out of the pass-rate denominator.
    valid_cases = passing + failing
    return {
        "pass_rate": (passing / valid_cases) if valid_cases > 0 else 0.0,
        "passing": passing,
        "failing": failing,
        "errored": errored,
        "unused": unused,
    }


def construct_risk_assessment_overview(
    red_teaming_test_cases: List[RedTeamingTestCase],
) -> RedTeamingOverview:
    """Aggregate test cases into per-vulnerability and per-attack-method stats.

    Args:
        red_teaming_test_cases: All evaluated test cases.

    Returns:
        A ``RedTeamingOverview`` with one ``VulnerabilityTypeResult`` per
        distinct ``vulnerability`` value and one ``AttackMethodResult`` per
        distinct ``attack_method`` (``None`` included), each in first-seen
        order.
    """
    # Group by vulnerability *name* (not vulnerability_type) and by attack
    # method; dicts keep insertion order, so results follow input order.
    vulnerability_to_cases: Dict[str, List[RedTeamingTestCase]] = {}
    attack_method_to_cases: Dict[Optional[str], List[RedTeamingTestCase]] = {}
    for test_case in red_teaming_test_cases:
        vulnerability_to_cases.setdefault(test_case.vulnerability, []).append(
            test_case
        )
        attack_method_to_cases.setdefault(test_case.attack_method, []).append(
            test_case
        )

    vulnerability_type_results = [
        VulnerabilityTypeResult(
            vulnerability=vuln,
            # The last case's type is reported for the group.
            # NOTE(review): assumes cases sharing a name share a type — confirm.
            vulnerability_type=cases[-1].vulnerability_type,
            **_tally_outcomes(cases),
        )
        for vuln, cases in vulnerability_to_cases.items()
    ]

    attack_method_results = [
        AttackMethodResult(attack_method=attack_method, **_tally_outcomes(cases))
        for attack_method, cases in attack_method_to_cases.items()
    ]

    return RedTeamingOverview(
        vulnerability_type_results=vulnerability_type_results,
        attack_method_results=attack_method_results,
    )