sanitize.py
  1  """Post-processing LLM-generated Python code implemented using tree-sitter."""
  2  
  3  import os
  4  import pathlib
  5  from typing import Dict, Generator, List, Optional, Set, Tuple
  6  
  7  from tqdm import tqdm
  8  from tree_sitter import Node
  9  from tree_sitter_languages import get_parser
 10  
 11  from data import (
 12      get_bigcodebench,
 13      load_solutions,
 14      write_directory,
 15      write_jsonl,
 16  )
 17  from syncheck import syntax_check
 18  
# Node ``type`` strings that this module compares against the ``type``
# attribute of parsed tree nodes.
CLASS_TYPE = "class_definition"
FUNCTION_TYPE = "function_definition"
# A list because two distinct node types are both treated as imports.
IMPORT_TYPE = ["import_statement", "import_from_statement"]
IDENTIFIER_TYPE = "identifier"
ATTRIBUTE_TYPE = "attribute"
RETURN_TYPE = "return_statement"
# sanitize() treats an expression statement whose first child is an
# assignment node as a top-level variable definition.
EXPRESSION_TYPE = "expression_statement"
ASSIGNMENT_TYPE = "assignment"
 27  
 28  
 29  def code_extract(text: str) -> str:
 30      lines = text.split("\n")
 31      longest_line_pair = (0, 0)
 32      longest_so_far = 0
 33  
 34      for i in range(len(lines)):
 35          for j in range(i + 1, len(lines)):
 36              current_lines = "\n".join(lines[i : j + 1])
 37              if syntax_check(current_lines):
 38                  current_length = sum(1 for line in lines[i : j + 1] if line.strip())
 39                  if current_length > longest_so_far:
 40                      longest_so_far = current_length
 41                      longest_line_pair = (i, j)
 42  
 43      return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
 44  
 45  
 46  def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
 47  
 48      def dfs_get_deps(node: Node, deps: Set[str]) -> None:
 49          for child in node.children:
 50              if child.type == IDENTIFIER_TYPE:
 51                  deps.add(child.text.decode("utf8"))
 52              else:
 53                  dfs_get_deps(child, deps)
 54  
 55      name2deps = {}
 56      for name, node in nodes:
 57          deps = set()
 58          dfs_get_deps(node, deps)
 59          name2deps[name] = deps
 60      return name2deps
 61  
 62  
 63  def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
 64      queue = [entrypoint]
 65      visited = {entrypoint}
 66      while queue:
 67          current = queue.pop(0)
 68          if current not in call_graph:
 69              continue
 70          for neighbour in call_graph[current]:
 71              if not (neighbour in visited):
 72                  visited.add(neighbour)
 73                  queue.append(neighbour)
 74      return visited
 75  
 76  
 77  def get_definition_name(node: Node) -> str:
 78      for child in node.children:
 79          if child.type == IDENTIFIER_TYPE:
 80              return child.text.decode("utf8")
 81  
 82  
 83  def traverse_tree(node: Node) -> Generator[Node, None, None]:
 84      cursor = node.walk()
 85      depth = 0
 86  
 87      visited_children = False
 88      while True:
 89          if not visited_children:
 90              yield cursor.node
 91              if not cursor.goto_first_child():
 92                  depth += 1
 93                  visited_children = True
 94          elif cursor.goto_next_sibling():
 95              visited_children = False
 96          elif not cursor.goto_parent() or depth == 0:
 97              break
 98          else:
 99              depth -= 1
100  
101  
def has_return_statement(node: Node) -> bool:
    """Report whether any node yielded by ``traverse_tree(node)`` is a return statement.

    Args:
        node: The node handed to ``traverse_tree``.

    Returns:
        ``True`` as soon as a yielded node has type ``RETURN_TYPE``; ``False``
        if the traversal ends without one.
    """
    return any(visited.type == RETURN_TYPE for visited in traverse_tree(node))
108  
109  
def _trim_trailing_entrypoint_lines(text: str, entrypoint: str) -> str:
    """Cut *text* at the topmost trailing unindented line that mentions *entrypoint*."""
    lines = text.splitlines()
    cut_at = None
    # Walk upwards over the trailing run of unindented lines only; the first
    # indented line (space or tab) ends the scan.
    for index in range(len(lines) - 1, -1, -1):
        line = lines[index]
        if line.startswith((" ", "\t")):
            break
        if entrypoint in line:
            cut_at = index
    if cut_at is None:
        return text
    return "\n".join(lines[:cut_at])


def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """Reduce *code* to its imports and top-level definitions.

    Steps:
      1. ``code_extract`` selects the portion of the stripped text to keep.
      2. The result is parsed with the "python" parser and the root's direct
         children are scanned: import statements, class definitions, function
         definitions and top-level assignments are collected.  Only the first
         definition for a given name is kept.
      3. If *entrypoint* is given, definitions whose name is not reachable
         from it (per ``get_deps`` / ``get_function_dependency``) are dropped.
      4. Imports come first in the output, followed by the kept definitions in
         source order, one per line group.
      5. If *entrypoint* is given, trailing unindented lines from the topmost
         one mentioning *entrypoint* onwards are removed.

    Args:
        code: Raw text containing Python code.
        entrypoint: Name of the definition the output must support.  When
            falsy, every collected definition is kept and step 5 is skipped.

    Returns:
        The sanitized source as a string.

    Note:
        ``get_definition_name`` returns ``None`` for a definition without a
        direct identifier child; all such definitions share the name ``None``,
        so only the first of them is collected.
        Step 5 is a textual heuristic: a one-line definition that mentions
        *entrypoint* and sits in the trailing unindented run is removed too.
    """
    code = code_extract(code.strip())
    code_bytes = bytes(code, "utf8")
    parser = get_parser("python")
    tree = parser.parse(code_bytes)

    import_nodes = []
    definition_nodes = []
    # Classes, functions and variables share one namespace here: a name is
    # collected only the first time it is defined, whatever its kind.
    seen_names = set()

    for child in tree.root_node.children:
        if child.type in IMPORT_TYPE:
            import_nodes.append(child)
            continue
        if child.type in (CLASS_TYPE, FUNCTION_TYPE):
            definition = child
        elif child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE:
            # Keep the assignment node itself, not the wrapping statement.
            definition = child.children[0]
        else:
            continue
        name = get_definition_name(definition)
        if name not in seen_names:
            seen_names.add(name)
            definition_nodes.append((name, definition))

    if entrypoint:
        reachable = get_function_dependency(entrypoint, get_deps(definition_nodes))
        kept_nodes = [node for name, node in definition_nodes if name in reachable]
    else:
        kept_nodes = [node for _, node in definition_nodes]

    chunks = [code_bytes[node.start_byte : node.end_byte] for node in import_nodes + kept_nodes]
    sanitized_output = b"\n".join(chunks).decode("utf8")

    # Without an entrypoint there is nothing to search for (and ``None in str``
    # would raise TypeError), so the trimming heuristic is skipped.
    if not entrypoint:
        return sanitized_output
    return _trim_trailing_entrypoint_lines(sanitized_output, entrypoint)
171  
172  
def script(samples: str, inplace: bool = False, debug_task: Optional[str] = None, calibrate: bool = False):
    """Sanitize every solution loaded from *samples* and write the results.

    Args:
        samples: Path to a directory or a file of solutions, as accepted by
            ``load_solutions``.
        inplace: When true, results are written back to *samples*; otherwise a
            sibling path with a ``-sanitized`` (or ``-sanitized-calibrated``)
            suffix is used.
        debug_task: If given, only the solution(s) with this task id are
            processed.
        calibrate: When true, a solution whose fenced code block starts with
            an indented line gets the task's ``complete_prompt`` inserted in
            front of it, and the output name uses the calibrated suffix.

    Each record must contain either ``"solution"`` or ``"completion"``; a
    completion is appended to the task's ``complete_prompt`` before sanitizing.
    """
    dataset = dict(get_bigcodebench())
    # task_id -> entry_point
    entry_point = {task_id: problem["entry_point"] for task_id, problem in dataset.items()}

    # Derive the output location next to the input unless writing in place.
    is_folder = os.path.isdir(samples)
    target_path = pathlib.Path(samples)
    if not inplace:
        suffix = "-sanitized-calibrated" if calibrate else "-sanitized"
        if is_folder:
            new_name = target_path.name + suffix
        else:
            # Insert the suffix before the extension.  Unlike a textual
            # replace of ".jsonl", this always yields a name different from
            # the input, so the source file cannot be overwritten by accident.
            new_name = target_path.stem + suffix + target_path.suffix
        target_path = target_path.parent / new_name
    target_path = str(target_path)

    nsan = 0
    ntotal = 0

    new_solutions = []

    for solution in tqdm(load_solutions(samples)):
        task_id = solution["task_id"]
        if task_id not in dataset:
            print(f"Skipping {task_id} as it does not exist in the latest BigCodeBench dataset.")
            continue
        if debug_task is not None and task_id != debug_task:
            continue

        ntotal += 1
        if "solution" in solution:
            old_code = solution["solution"]
            if calibrate:
                old_code = old_code.replace(
                    "```python\n    ",
                    "```python\n" + dataset[task_id]["complete_prompt"] + "    ",
                )
        else:
            assert "completion" in solution
            old_code = dataset[task_id]["complete_prompt"] + "\n" + solution["completion"]

        new_code = sanitize(code=old_code, entrypoint=entry_point[task_id])
        # Count every solution whose text was changed by sanitizing.
        if new_code != old_code:
            nsan += 1

        new_solutions.append({"task_id": task_id, "solution": new_code})

    if is_folder:
        write_directory(target_path, new_solutions)
    else:
        write_jsonl(target_path, new_solutions)

    if nsan > 0:
        print(f"Sanitized {nsan} out of {ntotal} files.")
    else:
        print("All files seems valid -- no files are sanitized.")
    print(f"Check the sanitized files at {target_path}")
245  
246  
def main():
    """Expose ``script`` as a command-line interface through ``fire``."""
    # Imported lazily so that importing this module does not require ``fire``.
    import fire

    fire.Fire(script)
251  
252  
# Start the command-line interface only when this file is run as a script,
# not when it is imported.
if __name__ == "__main__":
    main()