# sanitize.py
"""Post-processing LLM-generated Python code implemented using tree-sitter."""

import os
import pathlib
from collections import deque
from typing import Dict, Generator, List, Optional, Set, Tuple

from tqdm import tqdm
from tree_sitter import Node
from tree_sitter_languages import get_parser

from data import (
    get_bigcodebench,
    load_solutions,
    write_directory,
    write_jsonl,
)
from syncheck import syntax_check

# tree-sitter grammar node-type names used throughout this module.
CLASS_TYPE = "class_definition"
FUNCTION_TYPE = "function_definition"
IMPORT_TYPE = ["import_statement", "import_from_statement"]
IDENTIFIER_TYPE = "identifier"
ATTRIBUTE_TYPE = "attribute"
RETURN_TYPE = "return_statement"
EXPRESSION_TYPE = "expression_statement"
ASSIGNMENT_TYPE = "assignment"


def code_extract(text: str) -> str:
    """Return the largest contiguous line-window of *text* that parses as Python.

    Every (start, end) line window is syntax-checked; among the valid windows,
    the one with the most non-blank lines wins.  Quadratic in the number of
    lines, which is acceptable for short LLM outputs.
    """
    lines = text.split("\n")
    longest_line_pair = (0, 0)
    longest_so_far = 0

    for i in range(len(lines)):
        for j in range(i + 1, len(lines)):
            current_lines = "\n".join(lines[i : j + 1])
            if syntax_check(current_lines):
                # Rank windows by non-blank line count, not raw span length.
                current_length = sum(1 for line in lines[i : j + 1] if line.strip())
                if current_length > longest_so_far:
                    longest_so_far = current_length
                    longest_line_pair = (i, j)

    return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])


def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """Map each definition name to the set of identifiers its subtree mentions."""

    def dfs_get_deps(node: Node, deps: Set[str]) -> None:
        for child in node.children:
            if child.type == IDENTIFIER_TYPE:
                deps.add(child.text.decode("utf8"))
            else:
                dfs_get_deps(child, deps)

    name2deps: Dict[str, Set[str]] = {}
    for name, node in nodes:
        deps: Set[str] = set()
        dfs_get_deps(node, deps)
        name2deps[name] = deps
    return name2deps


def get_function_dependency(
    entrypoint: str, call_graph: Dict[str, Set[str]]
) -> Set[str]:
    """Return all names transitively reachable from *entrypoint* in *call_graph*.

    Breadth-first search; names absent from the graph are treated as leaves.
    (Annotation fixed: values are sets of names, not single strings, as the
    iteration over ``call_graph[current]`` shows.)
    """
    # deque.popleft() is O(1); the original list.pop(0) was O(n) per pop.
    queue = deque([entrypoint])
    visited = {entrypoint}
    while queue:
        current = queue.popleft()
        if current not in call_graph:
            continue
        for neighbour in call_graph[current]:
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(neighbour)
    return visited


def get_definition_name(node: Node) -> Optional[str]:
    """Return the defined name (first identifier child) of *node*, or None."""
    for child in node.children:
        if child.type == IDENTIFIER_TYPE:
            return child.text.decode("utf8")
    return None


def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """Yield every node in the subtree rooted at *node* in pre-order.

    Uses a tree-sitter cursor so no Python recursion is needed.
    NOTE(review): the depth bookkeeping is unconventional (depth is bumped on
    childless nodes) but the walk terminates once the cursor can no longer
    ascend; kept byte-equivalent to avoid behavior drift.
    """
    cursor = node.walk()
    depth = 0

    visited_children = False
    while True:
        if not visited_children:
            yield cursor.node
            if not cursor.goto_first_child():
                depth += 1
                visited_children = True
        elif cursor.goto_next_sibling():
            visited_children = False
        elif not cursor.goto_parent() or depth == 0:
            break
        else:
            depth -= 1


def has_return_statement(node: Node) -> bool:
    """True iff the subtree rooted at *node* contains a return statement."""
    return any(sub.type == RETURN_TYPE for sub in traverse_tree(node))


def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """Keep only imports and top-level definitions from LLM-generated *code*.

    The longest syntactically valid region is extracted first; duplicate
    top-level names keep only their first definition.  When *entrypoint* is
    given, definitions not (transitively) referenced by it are dropped, as are
    trailing top-level lines that merely mention the entrypoint (e.g. example
    invocations the LLM appended).
    """
    code = code_extract(code.strip())
    code_bytes = bytes(code, "utf8")
    parser = get_parser("python")
    tree = parser.parse(code_bytes)
    class_names: Set[str] = set()
    function_names: Set[str] = set()
    variable_names: Set[str] = set()

    root_node = tree.root_node
    import_nodes = []
    # (name, node) pairs in source order: classes, functions, and assignments.
    definition_nodes = []

    for child in root_node.children:
        if child.type in IMPORT_TYPE:
            import_nodes.append(child)
        elif child.type == CLASS_TYPE:
            name = get_definition_name(child)
            # First definition of a name wins across all three namespaces.
            if not (
                name in class_names or name in variable_names or name in function_names
            ):
                definition_nodes.append((name, child))
                class_names.add(name)
        elif child.type == FUNCTION_TYPE:
            name = get_definition_name(child)
            if not (
                name in function_names or name in variable_names or name in class_names
            ):
                definition_nodes.append((name, child))
                # Reuse `name` instead of re-deriving it (was a redundant call).
                function_names.add(name)
        elif (
            child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE
        ):
            subchild = child.children[0]
            name = get_definition_name(subchild)
            if not (
                name in variable_names or name in function_names or name in class_names
            ):
                definition_nodes.append((name, subchild))
                variable_names.add(name)

    if entrypoint:
        name2deps = get_deps(definition_nodes)
        reachable = get_function_dependency(entrypoint, name2deps)

    sanitized_output = b""

    for node in import_nodes:
        sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"

    for name, node in definition_nodes:
        if entrypoint and name not in reachable:
            continue
        sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"

    # Drop the trailing newline appended above.
    sanitized_output = sanitized_output[:-1].decode("utf8")

    # Ad-hoc pass to strip trailing top-level entrypoint mentions.
    # BUG FIX: guarded by `entrypoint` — the original evaluated
    # `entrypoint in lines[i]` even when entrypoint was None (TypeError).
    if entrypoint:
        lines = sanitized_output.splitlines()
        outer_lines = []
        for i in range(len(lines) - 1, -1, -1):
            if lines[i].startswith(" "):
                break
            if not lines[i].startswith(" ") and entrypoint in lines[i]:
                outer_lines.append(i)
        if outer_lines:
            sanitized_output = "\n".join(lines[: outer_lines[-1]])
    return sanitized_output


def script(
    samples: str,
    inplace: bool = False,
    debug_task: Optional[str] = None,
    calibrate: bool = False,
):
    """Sanitize every solution in *samples* (a folder or a ``.jsonl`` file).

    Args:
        samples: Path to a solutions folder or ``.jsonl`` file.
        inplace: Overwrite *samples* instead of writing a ``-sanitized`` copy.
        debug_task: When set, only the matching task_id is processed.
        calibrate: Re-attach each task's complete prompt before sanitizing.
    """
    # task_id -> entry_point
    entry_point = {}
    # merge two datasets
    dataset = {**get_bigcodebench()}

    for task_id, problem in dataset.items():
        entry_point[task_id] = problem["entry_point"]

    # Derive the output path with a "-sanitized[-calibrated]" suffix unless
    # sanitizing in place.
    is_folder = os.path.isdir(samples)
    target_path = pathlib.Path(samples)
    if not inplace:
        suffix = "-sanitized-calibrated" if calibrate else "-sanitized"
        if is_folder:
            new_name = target_path.name + suffix
        else:
            new_name = target_path.name.replace(".jsonl", suffix + ".jsonl")
        target_path = target_path.parent / new_name
    target_path = str(target_path)

    nsan = 0
    ntotal = 0

    new_solutions = []

    for solution in tqdm(load_solutions(samples)):
        task_id = solution["task_id"]
        if task_id not in dataset:
            # Message typos fixed ("Skiping ... does not existing").
            print(
                f"Skipping {task_id} as it does not exist in the latest EvalPlus dataset."
            )
            continue

        function_name = entry_point.get(task_id)
        dbg_identifier = solution["_identifier"]
        if debug_task is not None and task_id != debug_task:
            continue

        ntotal += 1
        if "solution" in solution:
            old_code = solution["solution"]
            if calibrate:
                old_code = solution["solution"].replace(
                    "```python\n ",
                    "```python\n" + dataset[task_id]["complete_prompt"] + " ",
                )
        else:
            assert "completion" in solution
            old_code = dataset[task_id]["complete_prompt"] + "\n" + solution["completion"]

        new_code = sanitize(code=old_code, entrypoint=function_name)
        # If the sanitizer changed anything, report it.
        if new_code != old_code:
            msg = "Sanitized: " + dbg_identifier
            if is_folder:
                msg += " -> " + dbg_identifier.replace(samples, target_path)
            print(msg)
            nsan += 1

        new_solutions.append({"task_id": task_id, "solution": new_code})

    if is_folder:
        write_directory(target_path, new_solutions)
    else:
        write_jsonl(target_path, new_solutions)

    if nsan > 0:
        print(f"Sanitized {nsan} out of {ntotal} files.")
    else:
        print("All files seem valid -- no files are sanitized.")
    print(f"Check the sanitized files at {target_path}")


def main():
    """CLI entry point: expose `script` via python-fire."""
    from fire import Fire

    Fire(script)


if __name__ == "__main__":
    main()