# sanitize.py
  1  """Post-processing LLM-generated Python code implemented using tree-sitter."""
  2  
import os
import pathlib
from collections import deque
from typing import Dict, Generator, List, Optional, Set, Tuple

from tqdm import tqdm
from tree_sitter import Node
from tree_sitter_languages import get_parser

from data import (
    get_bigcodebench,
    load_solutions,
    write_directory,
    write_jsonl,
)
from syncheck import syntax_check
 18  
# Tree-sitter node-type names used when walking the parsed Python AST.
CLASS_TYPE = "class_definition"
FUNCTION_TYPE = "function_definition"
# Covers both ``import x`` and ``from x import y`` statements.
IMPORT_TYPE = ["import_statement", "import_from_statement"]
IDENTIFIER_TYPE = "identifier"
ATTRIBUTE_TYPE = "attribute"
RETURN_TYPE = "return_statement"
# A top-level ``x = ...`` appears as an expression_statement wrapping an assignment.
EXPRESSION_TYPE = "expression_statement"
ASSIGNMENT_TYPE = "assignment"
 27  
 28  
 29  def code_extract(text: str) -> str:
 30      lines = text.split("\n")
 31      longest_line_pair = (0, 0)
 32      longest_so_far = 0
 33  
 34      for i in range(len(lines)):
 35          for j in range(i + 1, len(lines)):
 36              current_lines = "\n".join(lines[i : j + 1])
 37              if syntax_check(current_lines):
 38                  current_length = sum(1 for line in lines[i : j + 1] if line.strip())
 39                  if current_length > longest_so_far:
 40                      longest_so_far = current_length
 41                      longest_line_pair = (i, j)
 42  
 43      return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
 44  
 45  
 46  def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
 47  
 48      def dfs_get_deps(node: Node, deps: Set[str]) -> None:
 49          for child in node.children:
 50              if child.type == IDENTIFIER_TYPE:
 51                  deps.add(child.text.decode("utf8"))
 52              else:
 53                  dfs_get_deps(child, deps)
 54  
 55      name2deps = {}
 56      for name, node in nodes:
 57          deps = set()
 58          dfs_get_deps(node, deps)
 59          name2deps[name] = deps
 60      return name2deps
 61  
 62  
 63  def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
 64      queue = [entrypoint]
 65      visited = {entrypoint}
 66      while queue:
 67          current = queue.pop(0)
 68          if current not in call_graph:
 69              continue
 70          for neighbour in call_graph[current]:
 71              if not (neighbour in visited):
 72                  visited.add(neighbour)
 73                  queue.append(neighbour)
 74      return visited
 75  
 76  
 77  def get_definition_name(node: Node) -> str:
 78      for child in node.children:
 79          if child.type == IDENTIFIER_TYPE:
 80              return child.text.decode("utf8")
 81  
 82  
def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """Yield every node in the subtree rooted at *node*, pre-order.

    Uses the tree-sitter cursor API (goto_first_child / goto_next_sibling /
    goto_parent) rather than recursion.
    """
    cursor = node.walk()
    # NOTE(review): ``depth`` is incremented only when descent fails (leaf
    # nodes) and decremented on successful goto_parent — it does not track
    # true tree depth. Termination normally comes from goto_parent() returning
    # False at the root; the ``depth == 0`` guard looks redundant — confirm
    # before relying on it.
    depth = 0

    visited_children = False
    while True:
        if not visited_children:
            # First visit to this node: emit it, then try to descend.
            yield cursor.node
            if not cursor.goto_first_child():
                depth += 1
                visited_children = True
        elif cursor.goto_next_sibling():
            # Moved to an unvisited sibling subtree.
            visited_children = False
        elif not cursor.goto_parent() or depth == 0:
            # Cannot ascend further: traversal complete.
            break
        else:
            depth -= 1
100  
101  
def has_return_statement(node: Node) -> bool:
    """True iff any node in *node*'s subtree is a return statement."""
    return any(n.type == RETURN_TYPE for n in traverse_tree(node))
108  
109  
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """Extract the syntactically valid core of LLM-generated *code*.

    Keeps top-level imports plus class / function / variable definitions
    (first definition of each name wins). When *entrypoint* is given, only
    definitions transitively referenced from it are retained, and trailing
    top-level lines mentioning the entrypoint (stray example calls) are
    dropped.

    Bug fixes vs. the original: the trailing-line cleanup ran even when
    *entrypoint* was None, making ``entrypoint in lines[i]`` raise TypeError
    on any non-indented final line; ``get_definition_name`` was also called
    twice for function nodes.
    """
    code = code_extract(code.strip())
    code_bytes = bytes(code, "utf8")
    parser = get_parser("python")
    tree = parser.parse(code_bytes)
    class_names = set()
    function_names = set()
    variable_names = set()

    root_node = tree.root_node
    import_nodes = []
    definition_nodes = []

    for child in root_node.children:
        if child.type in IMPORT_TYPE:
            import_nodes.append(child)
        elif child.type == CLASS_TYPE:
            name = get_definition_name(child)
            # First definition of a name wins, across all three namespaces.
            if not (
                name in class_names or name in variable_names or name in function_names
            ):
                definition_nodes.append((name, child))
                class_names.add(name)
        elif child.type == FUNCTION_TYPE:
            name = get_definition_name(child)
            if not (
                name in function_names or name in variable_names or name in class_names
            ):
                definition_nodes.append((name, child))
                function_names.add(name)  # fix: was recomputed via a second call
        elif (
            child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE
        ):
            subchild = child.children[0]
            name = get_definition_name(subchild)
            if not (
                name in variable_names or name in function_names or name in class_names
            ):
                definition_nodes.append((name, subchild))
                variable_names.add(name)

    reachable = None
    if entrypoint:
        name2deps = get_deps(definition_nodes)
        reachable = get_function_dependency(entrypoint, name2deps)

    sanitized_output = b""

    for node in import_nodes:
        sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"

    for name, node in definition_nodes:
        if reachable is not None and name not in reachable:
            continue
        sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"

    # Drop the final "\n" appended above before decoding.
    sanitized_output = sanitized_output[:-1].decode("utf8")

    # Ad-hoc cleanup: strip trailing top-level lines that mention the
    # entrypoint (e.g. example invocations the model appended). Only
    # meaningful — and only safe — when an entrypoint was supplied.
    if entrypoint:
        lines = sanitized_output.splitlines()
        outer_lines = []
        for i in range(len(lines) - 1, -1, -1):
            if lines[i].startswith(" "):
                break
            if entrypoint in lines[i]:
                outer_lines.append(i)
        if outer_lines:
            sanitized_output = "\n".join(lines[: outer_lines[-1]])
    return sanitized_output
179  
180  
def script(
    samples: str,
    inplace: bool = False,
    debug_task: Optional[str] = None,
    calibrate: bool = False,
):
    """Sanitize every solution in *samples* (a .jsonl file or a directory).

    Writes the cleaned solutions next to the input with a ``-sanitized``
    (or ``-sanitized-calibrated``) suffix unless *inplace* is set. When
    *debug_task* is given, only that task id is processed.
    """
    # task_id -> entry_point
    entry_point = {}
    dataset = {**get_bigcodebench()}

    for task_id, problem in dataset.items():
        entry_point[task_id] = problem["entry_point"]

    # Decide the output location ("-sanitized" sibling unless in-place).
    is_folder = os.path.isdir(samples)
    target_path = pathlib.Path(samples)
    if not inplace:
        suffix = "-sanitized-calibrated" if calibrate else "-sanitized"
        if is_folder:
            new_name = target_path.name + suffix
        else:
            new_name = target_path.name.replace(".jsonl", suffix + ".jsonl")
        target_path = target_path.parent / new_name
    target_path = str(target_path)

    nsan = 0
    ntotal = 0

    new_solutions = []

    for solution in tqdm(load_solutions(samples)):
        task_id = solution["task_id"]
        if task_id not in dataset:
            # Message fixed: was "Skiping ... does not existing in ...".
            print(
                f"Skipping {task_id} as it does not exist in the latest EvalPlus dataset."
            )
            continue

        function_name = entry_point.get(task_id)
        dbg_identifier = solution["_identifier"]
        if debug_task is not None and task_id != debug_task:
            continue

        ntotal += 1
        if "solution" in solution:
            old_code = solution["solution"]
            if calibrate:
                # Re-insert the complete prompt so the calibrated snippet
                # parses on its own before sanitization.
                old_code = solution["solution"].replace(
                    "```python\n    ",
                    "```python\n" + dataset[task_id]["complete_prompt"] + "    ",
                )
        else:
            assert "completion" in solution
            old_code = dataset[task_id]["complete_prompt"] + "\n" + solution["completion"]

        new_code = sanitize(code=old_code, entrypoint=function_name)
        # Report every solution that was actually modified.
        if new_code != old_code:
            msg = "Sanitized: " + dbg_identifier
            if is_folder:
                msg += " -> " + dbg_identifier.replace(samples, target_path)
            print(msg)
            nsan += 1

        new_solutions.append({"task_id": task_id, "solution": new_code})

    if is_folder:
        write_directory(target_path, new_solutions)
    else:
        write_jsonl(target_path, new_solutions)

    if nsan > 0:
        print(f"Sanitized {nsan} out of {ntotal} files.")
    else:
        print("All files seem valid -- no files are sanitized.")
    print(f"Check the sanitized files at {target_path}")
257  
258  
def main():
    """CLI entry point: expose ``script`` through python-fire."""
    import fire

    fire.Fire(script)
263  
264  
265  if __name__ == "__main__":
266      main()