# risk_assessment.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

import datetime
import json
import os
from enum import Enum
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from deepteam.vulnerabilities.types import VulnerabilityType


class RedTeamingTestCase(BaseModel):
    """One red-teaming test case: the attack input, the target output and its evaluation."""

    vulnerability: str
    vulnerability_type: VulnerabilityType
    # Validated from the camelCase key ``riskCategory`` (pydantic alias).
    risk_category: str = Field(alias="riskCategory")
    attack_method: Optional[str] = Field(None, alias="attackMethod")
    original_input: Optional[str] = None
    input: Optional[str] = None
    # Serialized as ``actualOutput`` when dumped with ``by_alias=True``.
    actual_output: Optional[str] = Field(
        None, serialization_alias="actualOutput"
    )
    # A score > 0 counts as "passed" (see ``_case_outcome``).
    score: Optional[float] = None
    reason: Optional[str] = None
    # Any non-None value marks the case as "errored".
    error: Optional[str] = None
    # Marks the case as "unused": excluded from the pass/fail tallies.
    useless: bool = False


# Outcome labels shared by the per-case table and the aggregated overview so
# both views classify a test case the same way.
_PASSED = "Passed"
_FAILED = "Failed"
_ERRORED = "Errored"
_UNUSED = "Unused"


def _case_outcome(case: RedTeamingTestCase) -> str:
    """Classify a test case into exactly one outcome bucket.

    Precedence is: errored (``error`` is set), then unused (``useless`` is
    true), then passed (``score`` > 0), otherwise failed. Making the buckets
    mutually exclusive guarantees the overview counts sum to the number of
    cases, so ``failing`` can never go negative.
    """
    if case.error is not None:
        return _ERRORED
    if case.useless:
        return _UNUSED
    if case.score is not None and case.score > 0:
        return _PASSED
    return _FAILED


class TestCasesList(list):
    """A ``list`` of ``RedTeamingTestCase`` objects with a DataFrame export."""

    def to_df(self):
        """Return a pandas DataFrame with one row per test case.

        ``Status`` is "Passed", "Failed", "Errored" or "Unused", using the
        same classification as ``construct_risk_assessment_overview``.
        pandas is imported lazily so it is only required when this is called.
        """
        import pandas as pd

        rows = [
            {
                "Vulnerability": case.vulnerability,
                "Vulnerability Type": str(case.vulnerability_type.value),
                "Risk Category": case.risk_category,
                "Attack Enhancement": case.attack_method,
                "Input": case.input,
                "Actual Output": case.actual_output,
                "Score": case.score,
                "Reason": case.reason,
                "Error": case.error,
                "Status": _case_outcome(case),
            }
            for case in self
        ]
        return pd.DataFrame(rows)


class VulnerabilityTypeResult(BaseModel):
    """Aggregated outcome counts for one vulnerability."""

    vulnerability: str
    vulnerability_type: VulnerabilityType
    # passing / (passing + failing); 0.0 when that denominator is zero.
    pass_rate: float
    passing: int
    failing: int
    errored: int
    unused: int


class AttackMethodResult(BaseModel):
    """Aggregated outcome counts for one attack method (``None`` = no method)."""

    attack_method: Optional[str] = None
    # passing / (passing + failing); 0.0 when that denominator is zero.
    pass_rate: float
    passing: int
    failing: int
    errored: int
    unused: int


class RedTeamingOverview(BaseModel):
    """Per-vulnerability and per-attack-method summaries of a red-teaming run."""

    vulnerability_type_results: List[VulnerabilityTypeResult]
    attack_method_results: List[AttackMethodResult]

    def to_df(self):
        """Return a pandas DataFrame with one row per vulnerability result.

        ``Total`` is passing + failing + errored; unused cases are not counted.
        """
        import pandas as pd

        data = []
        for result in self.vulnerability_type_results:
            data.append(
                {
                    "Vulnerability": result.vulnerability,
                    "Vulnerability Type": str(result.vulnerability_type.value),
                    "Total": result.passing + result.failing + result.errored,
                    "Pass Rate": result.pass_rate,
                    "Passing": result.passing,
                    "Failing": result.failing,
                    "Errored": result.errored,
                }
            )
        return pd.DataFrame(data)


class EnumEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``Enum`` members by their ``value``."""

    def default(self, obj):
        if isinstance(obj, Enum):
            return obj.value
        return super().default(obj)


class RiskAssessment(BaseModel):
    """A complete red-teaming report: the overview plus every test case."""

    overview: RedTeamingOverview
    test_cases: List[RedTeamingTestCase]

    def __init__(self, **data):
        super().__init__(**data)
        # Re-wrap so callers get ``test_cases.to_df()``. A plain call is used
        # instead of ``TestCasesList[...](...)``, which needs Python 3.9+.
        self.test_cases = TestCasesList(self.test_cases)

    def save(self, to: str) -> str:
        """Write this assessment as indented JSON into directory ``to``.

        The directory is created if missing. The file name is the current
        local time formatted as ``%Y%m%d_%H%M%S.json``, so two saves into the
        same directory within the same second overwrite each other.

        Returns:
            The path of the written JSON file.

        Raises:
            OSError: if the directory cannot be created or the file cannot
                be written.
        """
        try:
            new_filename = (
                datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + ".json"
            )

            try:
                # exist_ok avoids the race between a separate exists() check
                # and makedirs().
                os.makedirs(to, exist_ok=True)
            except OSError as e:
                raise OSError(f"Cannot create directory '{to}': {e}") from e

            full_file_path = os.path.join(to, new_filename)

            # model_dump() (python mode) can leave Enum members in the dict;
            # EnumEncoder serializes them by value.
            data = self.model_dump(by_alias=True)

            with open(full_file_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, cls=EnumEncoder)

            print(
                f"🎉 Success! 🎉 Your risk assessment file has been saved to:\n📁 {full_file_path} ✅ "
            )
            return full_file_path

        except OSError as e:
            raise OSError(f"Failed to save file to '{to}': {e}") from e


def _summarize_cases(
    test_cases: List[RedTeamingTestCase],
) -> Tuple[float, int, int, int, int]:
    """Return ``(pass_rate, passing, failing, errored, unused)`` for a group.

    Every case lands in exactly one bucket (see ``_case_outcome``). The pass
    rate ignores errored and unused cases and is 0.0 when no case passed or
    failed.
    """
    counts = {_PASSED: 0, _FAILED: 0, _ERRORED: 0, _UNUSED: 0}
    for case in test_cases:
        counts[_case_outcome(case)] += 1

    passing = counts[_PASSED]
    failing = counts[_FAILED]
    valid_cases = passing + failing
    pass_rate = (passing / valid_cases) if valid_cases > 0 else 0.0
    return pass_rate, passing, failing, counts[_ERRORED], counts[_UNUSED]


def construct_risk_assessment_overview(
    red_teaming_test_cases: List[RedTeamingTestCase],
) -> RedTeamingOverview:
    """Aggregate test cases into per-vulnerability and per-attack-method results.

    Groups appear in first-seen order. Cases without an attack method are
    grouped together under ``attack_method=None``.
    """
    vulnerability_to_cases: Dict[str, List[RedTeamingTestCase]] = {}
    attack_method_to_cases: Dict[Optional[str], List[RedTeamingTestCase]] = {}

    for test_case in red_teaming_test_cases:
        vulnerability_to_cases.setdefault(test_case.vulnerability, []).append(
            test_case
        )
        attack_method_to_cases.setdefault(test_case.attack_method, []).append(
            test_case
        )

    vulnerability_type_results = []
    for vuln, test_cases in vulnerability_to_cases.items():
        pass_rate, passing, failing, errored, unused = _summarize_cases(
            test_cases
        )
        vulnerability_type_results.append(
            VulnerabilityTypeResult(
                vulnerability=vuln,
                # Groups are keyed by ``vulnerability`` and are never empty;
                # if cases in one group carry different types, the last
                # case's type is reported.
                vulnerability_type=test_cases[-1].vulnerability_type,
                pass_rate=pass_rate,
                passing=passing,
                failing=failing,
                errored=errored,
                unused=unused,
            )
        )

    attack_method_results = []
    for attack_method, test_cases in attack_method_to_cases.items():
        pass_rate, passing, failing, errored, unused = _summarize_cases(
            test_cases
        )
        attack_method_results.append(
            AttackMethodResult(
                attack_method=attack_method,
                pass_rate=pass_rate,
                passing=passing,
                failing=failing,
                errored=errored,
                unused=unused,
            )
        )

    return RedTeamingOverview(
        vulnerability_type_results=vulnerability_type_results,
        attack_method_results=attack_method_results,
    )