import json
import os
import re
from functools import partial
from typing import Callable, List

import editdistance
import numpy as np
import timeout_decorator
import torch.multiprocessing as mp
from fuzzywuzzy import fuzz
from tqdm import tqdm
from tree_sitter import Language, Parser
from tree_sitter.binding import Node as TSNode

# tree-sitter parser; initialised per language in the compute_metric_* entry
# points and inherited by the multiprocessing worker processes.
parser = None


def cal_edit_sim(references, hypotheses):
    total = len(references)
    edit_sim = 0.0
    for pred, gt in zip(hypotheses, references):
        pred = pred.strip()
        gt = gt.strip()
        edit_sim += fuzz.ratio(pred, gt)
    return edit_sim / total
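
# Illustrative: fuzz.ratio is on a 0-100 scale, so near-identical strings
# score close to 100 and identical strings score exactly 100.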


def cal_edit_sim_repoeval(references, hypotheses):
    total = len(references)
    edit_sim = 0.0
    for pred, gt in zip(hypotheses, references):
        pred = pred.strip()
        gt = gt.strip()
        if max(len(pred), len(gt)) == 0:
            continue
        edit_sim += 1 - editdistance.eval(pred, gt) / max(len(pred), len(gt))
    return edit_sim / total
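
# Illustrative: unlike cal_edit_sim, this variant is on a 0-1 scale,
# 1 - edit_distance / max_len; e.g. pred "abcd" vs gt "abce" has edit
# distance 1 over max length 4, contributing 0.75.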


def tokenize_code(code):
    code = re.sub(r"([^A-Za-z0-9_])", r" \1 ", code)
    code = re.sub(r"([a-z])([A-Z])", r"\1 \2", code)
    code = re.sub(r"\s+", " ", code)
    code = code.replace('"', "`")
    code = code.replace("'", "`")
    tokens = [t for t in code.split(" ") if t]
    return tokens
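
# Illustrative: tokenize_code("fooBar(x)") -> ["foo", "Bar", "(", "x", ")"];
# camelCase is split, punctuation becomes standalone tokens, and both quote
# styles map to "`", so the exact match below ignores whitespace and quoting.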


def cal_exact_match(references, hypotheses):
    em_score = []
    for pred, gold in zip(hypotheses, references):
        em_score.append(tokenize_code(pred) == tokenize_code(gold))
    return np.mean(em_score)


def remove_comments(code):
    # Regex-based stripping of Python line comments; note it will also hit
    # '#' characters inside string literals.
    code = re.sub(r'#.*', '', code)
    return code


def is_parse_valid(parser, code):
    def syntax_error(node):
        if node.type == "ERROR":
            return True
        try:
            for child in node.children:
                if syntax_error(child):
                    return True
        except RecursionError:
            return True

        return False

    tree = get_ast(parser, code)
    if tree is not None:
        return not syntax_error(tree.root_node)
    return False


def get_valid_completion(prompt, completion, parser):
    # Back off from the full completion one character at a time until the
    # concatenated code parses without errors.
    for i in range(len(completion), -1, -1):
        code = prompt + completion[:i]
        if is_parse_valid(parser, code):
            return "parseable", completion[:i].rstrip()

    return "not_parseable", completion
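
# Illustrative: with prompt "x = (" and completion "1 + 2) +", the longest
# prefix that parses is "1 + 2) ", so this returns ("parseable", "1 + 2)").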


def dfs(
        node: TSNode,
        node_types: List[str],
        callback: Callable,
        ignore_node_types: List[str] = None,
):
    """
    Helper to traverse a parsed AST, invoking `callback` on every node whose
    type is in `node_types`.
    """
    if node.type in node_types:
        callback(node)

    for child in node.children:
        if not ignore_node_types or child.type not in ignore_node_types:
            dfs(child, node_types, callback, ignore_node_types)


def collect_nodes(root_node, node_types, ignore_node_types=None):
    """
    Collect all nodes that belong to certain types.
    """
    result = list()

    def _cb(n):
        result.append(n)

    if root_node is not None:
        try:
            dfs(root_node, node_types, _cb, ignore_node_types)
        except RecursionError:
            print('collection of nodes failed due to RecursionError')
            return []

    return result
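
# Illustrative: collect_nodes(tree.root_node, ["function_definition"]) returns
# every function-definition node in the module, nested definitions included.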


@timeout_decorator.timeout(5)
def get_ast(parser, code):
    assert isinstance(code, (str, bytes))
    if isinstance(code, str):
        code = bytes(code, "utf8")
    try:
        return parser.parse(code)
    except Exception:
        return None


def get_functions(parser, code):
    """
    Return the body text of every function in the code (whether or not it is
    nested inside a class), as a list of strings.
    """
    try:
        tree = get_ast(parser, code)
    except Exception:
        return []
    if tree is None:
        return []

    functions = []
    function_nodes = collect_nodes(tree.root_node, ["function_definition"])
    for fnode in function_nodes:
        assert fnode.children[-1].type == "block"
        body_text = fnode.children[-1].text.decode("utf-8")
        functions.append(body_text)

    return functions


def get_function_completion(prompt, completion, parser):
    # Close the in-progress function with "pass" so its index among all
    # functions in the file can be located.
    code = prompt + "pass"
    target_fn_idx = len(get_functions(parser, code)) - 1
    # assert target_fn_idx != -1

    # Read the same index from the fully completed code.
    code = prompt + completion
    function_body = get_functions(parser, code)[target_fn_idx]
    return function_body
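
# Illustrative: for a prompt ending in "def add(a, b):\n    ", appending
# "pass" makes the in-progress function the last complete function_definition
# in the file, so target_fn_idx points at the same function once the real
# completion is appended.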


def process_examples(task, args):
    sample, ex = args
    global parser

    prediction = sample["pred"]
    target = ex["groundtruth"]
    origin = ""

    if task == "function_completion":
        status, prediction = get_valid_completion(ex["prompt"], prediction, parser)
        if status == "parseable":
            try:
                origin = prediction
                prediction = get_function_completion(ex["prompt"], prediction, parser)
                target = get_function_completion(ex["prompt"], target, parser)
            except Exception:
                print(f'[warning] parsing failed: task_id:{ex["task_id"]}')
        else:
            print(f'[warning] parsing failed: task_id:{ex["task_id"]}')
    else:
        # Truncate the prediction to the ground truth's number of non-empty lines.
        num_target_lines = sum([1 for l in target.split("\n") if l.strip()])
        pred_lines = [l for l in prediction.split("\n") if l.strip()][:num_target_lines]
        prediction = "\n".join(pred_lines)

    trunc_s = {
        "task_id": sample["task_id"],
        "pred": prediction,
        "target": target,
        # "origin": origin
    }

    return trunc_s


def compute_metric_stmt(args):
    all_task_results = {}

    for task in args.tasks:
        print(f"\nComputing metrics for task: {task}")
        with open(f"{args.output_dir}/{args.dataset}/{args.language}/{task}/prediction.jsonl", "r") as f_pred:
            samples = []
            for l in f_pred.readlines():
                samples.append(json.loads(l))

        # Build the task-specific prompt file path.
        task_prompt_file = args.prompt_file.replace('TASK', task)
        examples = {}
        with open(task_prompt_file, "r") as f_in:
            for l in f_in.readlines():
                ex = json.loads(l)
                if hasattr(args, "focused_repo") and args.focused_repo and args.focused_repo not in re.sub('/', '_', ex['metadata']['repository']):
                    continue
                examples[ex["metadata"]["task_id"]] = {
                    "task_id": ex["metadata"]["task_id"],
                    "prompt": ex["prompt"],
                    "groundtruth": ex["groundtruth"]
                }

        if len(samples) != len(examples):
            print('Warning: len(samples) ({}) != len(examples) ({})'.format(len(samples), len(examples)))

        global parser
        ts_lang = args.language
        if ts_lang == 'csharp':
            ts_lang = 'c_sharp'
        language = Language(args.ts_lib, ts_lang)
        parser = Parser()
        parser.set_language(language)

        truncated_samples = []
        print("post-processing samples ...")
        pool = mp.Pool(mp.cpu_count() - 1)
        worker = partial(process_examples, task)

        with tqdm(total=len(samples)) as pbar:
            for trunc_s in pool.imap_unordered(worker, zip(samples, [examples[s["task_id"]] for s in samples])):
                truncated_samples.append(trunc_s)
                pbar.update()

        pool.close()
        pool.join()

        task_output_dir = os.path.join(args.output_dir, args.dataset, args.language, task)
        # with open(f"{task_output_dir}/prediction_truncated.jsonl", 'w', encoding="utf-8") as pt:
        #     for trunc_s in truncated_samples:
        #         pt.write(json.dumps(trunc_s) + "\n")

        ### Score calculation
        detailed_results = []
        exact_match = 0
        edit_sim = 0
        edit_sim_repoeval = 0

        for idx, trunc_s in enumerate(truncated_samples):
            es = cal_edit_sim([trunc_s["target"]], [trunc_s["pred"]])
            es_repoeval = cal_edit_sim_repoeval([trunc_s["target"]], [trunc_s["pred"]])
            em = cal_exact_match([trunc_s["target"]], [trunc_s["pred"]])
            edit_sim += es
            edit_sim_repoeval += es_repoeval
            exact_match += em

            detailed_results.append({
                "task_id": trunc_s["task_id"],
                "pred": trunc_s["pred"],
                "target": trunc_s["target"],
                "em": em,
                "es": es,
                "es_repoeval": es_repoeval
            })

        total_samples = len(truncated_samples)
        em_ratio = round(exact_match / total_samples * 100, 2)
        edit_sim_avg = round(edit_sim / total_samples, 2)
        edit_sim_repoeval_avg = round(edit_sim_repoeval / total_samples * 100, 2)

        print(
            f"Code Matching for {task}: "
            f"EM {em_ratio:.2f}, "
            f"ES {edit_sim_avg:.2f}, "
            f"ES RepoEval {edit_sim_repoeval_avg:.2f}"
        )

        # Save per-sample results (written as JSON Lines).
        with open(f"{task_output_dir}/detailed_results.json", 'w') as f:
            for dr in detailed_results:
                f.write(json.dumps(dr) + "\n")

        # Save task-level results.
        task_results = {
            "em": em_ratio,
            "es": edit_sim_avg,
            "es_repoeval": edit_sim_repoeval_avg,
            "total": total_samples
        }

        with open(f"{task_output_dir}/results.json", 'w') as f:
            json.dump(task_results, f, indent=2)

        # Add the current task's results to the overall results dict.
        all_task_results[task] = task_results
    # Compute the weighted average over all tasks.
    total_samples = sum(res["total"] for res in all_task_results.values())
    weighted_em = sum(res["em"] * res["total"] for res in all_task_results.values()) / total_samples
    weighted_es = sum(res["es"] * res["total"] for res in all_task_results.values()) / total_samples
    weighted_es_repoeval = sum(res["es_repoeval"] * res["total"] for res in all_task_results.values()) / total_samples

    # Build the final merged results.
    merged_results = {
        "overall": {
            "em": round(weighted_em, 4),
            "es": round(weighted_es, 4),
            "es_repoeval": round(weighted_es_repoeval, 4),
            "total": total_samples
        },
        "per_task": all_task_results
    }

    # Save the merged results.
    with open(f"{args.output_dir}/{args.dataset}/results.json", 'w') as f:
        json.dump(merged_results, f, indent=2)

    print("\nOverall Results (Weighted Average):")
    print(f"EM: {weighted_em:.2f}")
    print(f"ES: {weighted_es:.2f}")
    print(f"ES RepoEval: {weighted_es_repoeval:.2f}")
    print(f"Total Samples: {total_samples}")
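
# Illustrative weighting: a task with EM 50.0 over 100 samples and a task with
# EM 80.0 over 50 samples give a weighted EM of (50*100 + 80*50) / 150 = 60.0.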


def compute_metric_stmt_multilang(args):
    all_task_results = {}

    for language in args.languages:
        print(f"\nComputing metrics for language: {language}")
        with open(f"{args.output_dir}/{args.dataset}/{language}/{args.task}/prediction.jsonl", "r") as f_pred:
            samples = []
            for l in f_pred.readlines():
                samples.append(json.loads(l))

        # Build the language-specific prompt file path.
        task_prompt_file = args.prompt_file.replace('LANGUAGE', language)
        examples = {}
        with open(task_prompt_file, "r") as f_in:
            for l in f_in.readlines():
                ex = json.loads(l)
                if hasattr(args, "focused_repo") and args.focused_repo and args.focused_repo not in re.sub('/', '_', ex['metadata']['repository']):
                    continue
                examples[ex["metadata"]["task_id"]] = {
                    "task_id": ex["metadata"]["task_id"],
                    "prompt": ex["prompt"],
                    "groundtruth": ex["groundtruth"]
                }

        if len(samples) != len(examples):
            print('Warning: len(samples) ({}) != len(examples) ({})'.format(len(samples), len(examples)))

        global parser
        ts_lang = language
        if ts_lang == 'csharp':
            ts_lang = 'c_sharp'

        ts_lib = args.ts_lib.replace('LANGUAGE', language)
        language_ts = Language(ts_lib, ts_lang)
        parser = Parser()
        parser.set_language(language_ts)

        truncated_samples = []
        print("post-processing samples ...")
        pool = mp.Pool(mp.cpu_count() - 1)
        worker = partial(process_examples, args.task)

        with tqdm(total=len(samples)) as pbar:
            for trunc_s in pool.imap_unordered(worker, zip(samples, [examples[s["task_id"]] for s in samples])):
                truncated_samples.append(trunc_s)
                pbar.update()

        pool.close()
        pool.join()

        task_output_dir = os.path.join(args.output_dir, args.dataset, language, args.task)
        # with open(f"{task_output_dir}/prediction_truncated.jsonl", 'w', encoding="utf-8") as pt:
        #     for trunc_s in truncated_samples:
        #         pt.write(json.dumps(trunc_s) + "\n")

        ### Score calculation
        detailed_results = []
        exact_match = 0
        edit_sim = 0
        edit_sim_repoeval = 0

        for idx, trunc_s in enumerate(truncated_samples):
            es = cal_edit_sim([trunc_s["target"]], [trunc_s["pred"]])
            es_repoeval = cal_edit_sim_repoeval([trunc_s["target"]], [trunc_s["pred"]])
            em = cal_exact_match([trunc_s["target"]], [trunc_s["pred"]])
            edit_sim += es
            edit_sim_repoeval += es_repoeval
            exact_match += em

            detailed_results.append({
                "task_id": trunc_s["task_id"],
                "pred": trunc_s["pred"],
                "target": trunc_s["target"],
                "em": em,
                "es": es,
                "es_repoeval": es_repoeval
            })

        total_samples = len(truncated_samples)
        em_ratio = round(exact_match / total_samples * 100, 2)
        edit_sim_avg = round(edit_sim / total_samples, 2)
        edit_sim_repoeval_avg = round(edit_sim_repoeval / total_samples * 100, 2)

        print(
            f"Code Matching for {language}: "
            f"EM {em_ratio:.2f}, "
            f"ES {edit_sim_avg:.2f}, "
            f"ES RepoEval {edit_sim_repoeval_avg:.2f}"
        )

        # Save per-sample results (written as JSON Lines).
        with open(f"{task_output_dir}/detailed_results.json", 'w') as f:
            for dr in detailed_results:
                f.write(json.dumps(dr) + "\n")

        # Save language-level results.
        task_results = {
            "em": em_ratio,
            "es": edit_sim_avg,
            "es_repoeval": edit_sim_repoeval_avg,
            "total": total_samples
        }

        with open(f"{task_output_dir}/results.json", 'w') as f:
            json.dump(task_results, f, indent=2)

        # Add the current language's results to the overall results dict.
        all_task_results[language] = task_results

    # Compute the weighted average over all languages.
    total_samples = sum(res["total"] for res in all_task_results.values())
    weighted_em = sum(res["em"] * res["total"] for res in all_task_results.values()) / total_samples
    weighted_es = sum(res["es"] * res["total"] for res in all_task_results.values()) / total_samples
    weighted_es_repoeval = sum(res["es_repoeval"] * res["total"] for res in all_task_results.values()) / total_samples

    # Build the final merged results.
    merged_results = {
        "overall": {
            "em": round(weighted_em, 4),
            "es": round(weighted_es, 4),
            "es_repoeval": round(weighted_es_repoeval, 4),
            "total": total_samples
        },
        "per_language": all_task_results
    }

    # Save the merged results.
    with open(f"{args.output_dir}/{args.dataset}/results.json", 'w') as f:
        json.dump(merged_results, f, indent=2)

    print("\nOverall Results (Weighted Average):")
    print(f"EM: {weighted_em:.2f}")
    print(f"ES: {weighted_es:.2f}")
    print(f"ES RepoEval: {weighted_es_repoeval:.2f}")
    print(f"Total Samples: {total_samples}")


def compute_metric_stmt_custom(predictions_file, prompt_file, output_dir,
                               ts_lib, task, focused_repo=None, anchor_file=None, out_f_suffix=""):
    eval_ids = set()

    if anchor_file:
        with open(anchor_file, "r") as f_pred:
            for l in f_pred.readlines():
                eval_ids.add(json.loads(l)['task_id'])

    with open(predictions_file, "r") as f_pred:
        samples = []
        for l in f_pred.readlines():
            if anchor_file:
                if json.loads(l)['task_id'] in eval_ids:
                    samples.append(json.loads(l))
            else:
                entry = json.loads(l)
                # entry['task_id'] = re.sub('-', '_', entry['task_id'])
                if entry['task_id'] in eval_ids:  # skip duplicate task_ids
                    continue
                if focused_repo is not None:
                    if isinstance(focused_repo, str) and focused_repo not in re.sub('/', '_', entry['task_id']):
                        continue
                    elif isinstance(focused_repo, list) and not any([x in re.sub('/', '_', entry['task_id']) for x in focused_repo]):
                        continue
                samples.append(entry)
                eval_ids.add(entry['task_id'])

    examples = {}
    with open(prompt_file, "r") as f_in:
        for l in f_in.readlines():
            ex = json.loads(l)
            if focused_repo is not None:
                if isinstance(focused_repo, str) and focused_repo not in re.sub('/', '_', ex['metadata']['repository']):
                    continue
                elif isinstance(focused_repo, list) and not any([x in re.sub('/', '_', ex['metadata']['repository']) for x in focused_repo]):
                    continue
            if ex["metadata"]["task_id"] not in eval_ids:
                continue
            examples[ex["metadata"]["task_id"]] = {
                "task_id": ex["metadata"]["task_id"],
                "prompt": ex["prompt"],
                "groundtruth": ex["groundtruth"]
            }

    assert len(samples) == len(examples), f"{len(samples)} != {len(examples)}"

    global parser
    language = Language(ts_lib, "python")
    parser = Parser()
    parser.set_language(language)

    truncated_samples = []
    print("post-processing samples ...")
    pool = mp.Pool(mp.cpu_count() - 1)
    worker = partial(process_examples, task)

    with tqdm(total=len(samples)) as pbar:
        for trunc_s in pool.imap_unordered(worker, zip(samples, [examples[s["task_id"]] for s in samples])):
            truncated_samples.append(trunc_s)
            pbar.update()

    pool.close()
    pool.join()

    with open(f"{output_dir}/prediction_truncated{out_f_suffix}.jsonl", 'w', encoding="utf-8") as pt:
        for trunc_s in truncated_samples:
            pt.write(json.dumps(trunc_s) + "\n")

    ### Score calculation

    detailed_results = []
    exact_match = 0
    edit_sim = 0
    edit_sim_repoeval = 0

    for idx, trunc_s in enumerate(truncated_samples):
        es = cal_edit_sim([trunc_s["target"]], [trunc_s["pred"]])
        es_repoeval = cal_edit_sim_repoeval([trunc_s["target"]], [trunc_s["pred"]])
        em = cal_exact_match([trunc_s["target"]], [trunc_s["pred"]])
        edit_sim += es
        edit_sim_repoeval += es_repoeval
        exact_match += em

        detailed_results.append({
            "task_id": trunc_s["task_id"],
            "em": em,
            "es": es,
            "es_repoeval": es_repoeval
        })

    em_ratio = round(exact_match / len(truncated_samples) * 100, 2)
    edit_sim = round(edit_sim / len(truncated_samples), 2)
    edit_sim_repoeval = round(edit_sim_repoeval / len(truncated_samples) * 100, 2)

    print(
        f"Code Matching: "
        f"EM {em_ratio:.2f}, "
        f"ES {edit_sim:.2f}, "
        f"ES RepoEval {edit_sim_repoeval:.2f}"
    )

    with open(f"{output_dir}/detailed_results{out_f_suffix}.json", 'w') as f:
        for dr in detailed_results:
            f.write(json.dumps(dr) + "\n")

    # Write the results to a file.
    with open(f"{output_dir}/results{out_f_suffix}.json", 'w') as f:
        res = {
            "em": em_ratio,
            "es": edit_sim,
            "es_repoeval": edit_sim_repoeval,
            "total": len(truncated_samples)
        }
        f.write(json.dumps(res, indent=2))


def extract_block(text: str) -> str:
    """Extract the content of the first fenced code block in the text,
    including any language tag on the opening line."""
    # If there is no opening fence, return the text unchanged.
    if '```' not in text:
        return text
    start = text.find('```') + 3
    end = text.find('```', start)
    # If the closing fence is missing, take everything after the opening one.
    if end == -1:
        return text[start:]
    return text[start:end]
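

# Minimal usage sketch (the paths and the compiled tree-sitter grammar below
# are hypothetical placeholders, not part of the benchmark layout; point them
# at your own files).
if __name__ == "__main__":
    compute_metric_stmt_custom(
        predictions_file="outputs/prediction.jsonl",
        prompt_file="data/prompts.jsonl",
        output_dir="outputs",
        ts_lib="build/python-languages.so",
        task="function_completion",
    )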