sanitize.py
"""Post-processing LLM-generated Python code implemented using tree-sitter."""

import os
import pathlib
from collections import deque
from typing import Dict, Generator, List, Optional, Set, Tuple

from tqdm import tqdm
from tree_sitter import Node
from tree_sitter_languages import get_parser

from data import (
    get_bigcodebench,
    load_solutions,
    write_directory,
    write_jsonl,
)
from syncheck import syntax_check

# tree-sitter node type names that the functions below compare against.
CLASS_TYPE = "class_definition"
FUNCTION_TYPE = "function_definition"
IMPORT_TYPE = ["import_statement", "import_from_statement"]
IDENTIFIER_TYPE = "identifier"
ATTRIBUTE_TYPE = "attribute"
RETURN_TYPE = "return_statement"
EXPRESSION_TYPE = "expression_statement"
ASSIGNMENT_TYPE = "assignment"


def code_extract(text: str) -> str:
    """Return the run of consecutive lines of *text* that best looks like code.

    Every span ``lines[start..end]`` with ``end > start`` is a candidate.  Among
    the candidates accepted by ``syntax_check``, the one with the most
    non-blank lines wins; on ties the first one found (smallest ``start``,
    then smallest ``end``) is kept.  If no candidate is accepted, the first
    line of *text* is returned.
    """
    lines = text.split("\n")

    # nonblank_before[k] is the number of non-blank lines in lines[:k], so the
    # count for any span is a subtraction instead of a rescan.
    nonblank_before = [0]
    for line in lines:
        nonblank_before.append(nonblank_before[-1] + (1 if line.strip() else 0))

    best_span = (0, 0)
    best_count = 0
    for start in range(len(lines)):
        for end in range(start + 1, len(lines)):
            count = nonblank_before[end + 1] - nonblank_before[start]
            # A span that cannot beat the current best is never selected, so
            # skip the (comparatively expensive) syntax check for it.
            # NOTE(review): assumes syntax_check has no side effects.
            if count <= best_count:
                continue
            if syntax_check("\n".join(lines[start : end + 1])):
                best_count = count
                best_span = (start, end)

    first, last = best_span
    return "\n".join(lines[first : last + 1])


def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """Map each definition name to the set of identifier texts found under its node.

    Args:
        nodes: ``(name, node)`` pairs, one per top-level definition.

    Returns:
        A dict from ``name`` to every identifier string that appears anywhere
        in the subtree of the corresponding node.
    """

    def dfs_get_deps(node: Node, deps: Set[str]) -> None:
        # Collect identifier children; recurse into everything else.
        for child in node.children:
            if child.type == IDENTIFIER_TYPE:
                deps.add(child.text.decode("utf8"))
            else:
                dfs_get_deps(child, deps)

    name2deps = {}
    for name, node in nodes:
        deps: Set[str] = set()
        dfs_get_deps(node, deps)
        name2deps[name] = deps
    return name2deps


def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Return every name reachable from *entrypoint* in *call_graph*.

    Args:
        entrypoint: Name to start the breadth-first search from.
        call_graph: Mapping from a name to the names it refers to (the shape
            produced by ``get_deps``).

    Returns:
        The set of visited names.  It always contains *entrypoint*, even when
        *entrypoint* is not a key of *call_graph*.
    """
    queue = deque([entrypoint])
    visited = {entrypoint}
    while queue:
        current = queue.popleft()
        # Names without an entry (builtins, imports, locals, ...) are leaves.
        for neighbour in call_graph.get(current, ()):
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(neighbour)
    return visited


def get_definition_name(node: Node) -> Optional[str]:
    """Return the text of the first direct ``identifier`` child of *node*.

    Returns ``None`` when *node* has no direct identifier child.
    """
    for child in node.children:
        if child.type == IDENTIFIER_TYPE:
            return child.text.decode("utf8")
    return None


def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """Yield nodes reached by walking a tree cursor that starts at *node*.

    A node is yielded the first time the cursor lands on it, before its
    children are visited.
    """
    cursor = node.walk()
    # Incremented whenever the cursor sits on a node without children and
    # decremented whenever it successfully moves up to a parent.
    # NOTE(review): the walk ends when this counter is 0 after moving up;
    # confirm this is the intended stop condition.
    depth = 0

    visited_children = False
    while True:
        if not visited_children:
            yield cursor.node
            if not cursor.goto_first_child():
                depth += 1
                visited_children = True
        elif cursor.goto_next_sibling():
            visited_children = False
        elif not cursor.goto_parent() or depth == 0:
            break
        else:
            depth -= 1


def has_return_statement(node: Node) -> bool:
    """Return True if ``traverse_tree(node)`` yields a ``return_statement`` node."""
    return any(visited.type == RETURN_TYPE for visited in traverse_tree(node))


def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """Reduce *code* to its imports and the top-level definitions that are needed.

    Steps:
      1. ``code_extract`` picks the span of lines to work on.
      2. Top-level imports, classes, functions and assignments are collected;
         when a name is defined more than once only the first definition is kept.
      3. If *entrypoint* is given, definitions whose name is not reachable from
         it (per ``get_deps`` / ``get_function_dependency``) are dropped.
      4. If *entrypoint* is given, trailing non-indented lines are trimmed
         starting from the earliest such line that mentions *entrypoint*.

    Args:
        code: Raw text that contains Python code.
        entrypoint: Name of the definition that must be kept, or ``None`` to
            keep every collected definition.

    Returns:
        The sanitized source: imports first, then the kept definitions, joined
        with newlines.
    """
    code = code_extract(code.strip())
    code_bytes = bytes(code, "utf8")
    parser = get_parser("python")
    tree = parser.parse(code_bytes)

    import_nodes = []
    definition_nodes = []
    # Names already claimed by a class, function or variable definition.
    defined_names = set()

    for child in tree.root_node.children:
        if child.type in IMPORT_TYPE:
            import_nodes.append(child)
            continue
        if child.type in (CLASS_TYPE, FUNCTION_TYPE):
            target = child
        elif child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE:
            # Keep the assignment itself rather than the wrapping statement.
            target = child.children[0]
        else:
            continue
        name = get_definition_name(target)
        if name not in defined_names:
            definition_nodes.append((name, target))
            defined_names.add(name)

    reachable = None
    if entrypoint:
        reachable = get_function_dependency(entrypoint, get_deps(definition_nodes))

    chunks = [code_bytes[node.start_byte : node.end_byte] for node in import_nodes]
    chunks.extend(
        code_bytes[node.start_byte : node.end_byte]
        for name, node in definition_nodes
        if reachable is None or name in reachable
    )
    sanitized_output = b"\n".join(chunks).decode("utf8")

    # Without an entrypoint there is nothing to trim.  (Testing
    # ``None in line`` below would raise TypeError.)
    if not entrypoint:
        return sanitized_output

    # ad-hoc approach to remove unnecessary lines, but it works:
    # scan upwards from the last line until the first line that starts with a
    # space, remembering the non-indented lines that mention the entrypoint.
    lines = sanitized_output.splitlines()
    outer_lines = []
    for i in range(len(lines) - 1, -1, -1):
        if lines[i].startswith(" "):
            break
        if entrypoint in lines[i]:
            outer_lines.append(i)
    if outer_lines:
        # outer_lines[-1] is the smallest recorded index; cut from there on.
        sanitized_output = "\n".join(lines[: outer_lines[-1]])
    return sanitized_output


def script(
    samples: str,
    inplace: bool = False,
    debug_task: Optional[str] = None,
    calibrate: bool = False,
):
    """Sanitize every solution in *samples* and write the results.

    Args:
        samples: Path to a directory or a ``.jsonl`` file of solutions, as
            accepted by ``load_solutions``.
        inplace: When true, write to *samples* itself; otherwise write next to
            it using a ``-sanitized`` (or ``-sanitized-calibrated``) suffix.
        debug_task: If given, only the solution(s) with this task id are
            processed.
        calibrate: When true, the task's ``complete_prompt`` is spliced into
            each ``solution`` right after the opening code fence before
            sanitizing, and the ``-calibrated`` output name is used.
    """
    # task_id -> entry_point
    entry_point = {}
    # shallow copy of the benchmark's task_id -> problem mapping
    dataset = {**get_bigcodebench()}

    for task_id, problem in dataset.items():
        entry_point[task_id] = problem["entry_point"]

    # make a new folder with "-sanitized" suffix
    is_folder = os.path.isdir(samples)
    target_path = pathlib.Path(samples)
    if not inplace:
        suffix = "-sanitized-calibrated" if calibrate else "-sanitized"
        if is_folder:
            new_name = target_path.name + suffix
        else:
            new_name = target_path.name.replace(".jsonl", suffix + ".jsonl")
        target_path = target_path.parent / new_name
    target_path = str(target_path)

    nsan = 0
    ntotal = 0

    new_solutions = []

    for solution in tqdm(load_solutions(samples)):
        task_id = solution["task_id"]
        if task_id not in dataset:
            print(f"Skiping {task_id} as it does not existing in the latest EvalPlus dataset.")
            continue

        function_name = entry_point.get(task_id)
        dbg_identifier = solution["_identifier"]
        if debug_task is not None and task_id != debug_task:
            continue

        ntotal += 1
        if "solution" in solution:
            old_code = solution["solution"]
            if calibrate:
                # NOTE(review): the whitespace inside these two literals may have
                # been collapsed in the reviewed copy of this file -- confirm it
                # matches the indentation the prompts actually use.
                old_code = solution["solution"].replace(
                    "```python\n ",
                    "```python\n" + dataset[task_id]["complete_prompt"] + " ",
                )
        else:
            assert "completion" in solution
            old_code = dataset[task_id]["complete_prompt"] + "\n" + solution["completion"]

        new_code = sanitize(code=old_code, entrypoint=function_name)
        # if changed, count it (the message is only built for debugging)
        if new_code != old_code:
            msg = "Sanitized: " + dbg_identifier
            if is_folder:
                msg += " -> " + dbg_identifier.replace(samples, target_path)
            # print(msg)
            nsan += 1

        new_solutions.append({"task_id": task_id, "solution": new_code})

    if is_folder:
        write_directory(target_path, new_solutions)
    else:
        write_jsonl(target_path, new_solutions)

    if nsan > 0:
        print(f"Sanitized {nsan} out of {ntotal} files.")
    else:
        print("All files seems valid -- no files are sanitized.")
    print(f"Check the sanitized files at {target_path}")


def main():
    """Expose ``script`` as a command-line interface through ``fire``."""
    # Imported here so that importing this module does not require ``fire``.
    from fire import Fire

    Fire(script)


if __name__ == "__main__":
    main()