eval_metric.py
import re
import sys
import json
import timeout_decorator
import numpy as np

from tqdm import tqdm
from typing import Callable, List, Union
from fuzzywuzzy import fuzz
import editdistance
from functools import partial
import torch.multiprocessing as mp
from tree_sitter import Language, Parser
from tree_sitter.binding import Node as TSNode
import os

parser = None


def cal_edit_sim(references, hypotheses):
    """Average fuzz.ratio (0-100 scale) between stripped predictions and references."""
    total = len(references)
    edit_sim = 0.0
    for pred, gt in zip(hypotheses, references):
        pred = pred.strip()
        gt = gt.strip()
        edit_sim += fuzz.ratio(pred, gt)
    return edit_sim / total


def cal_edit_sim_repoeval(references, hypotheses):
    """Average normalized edit similarity (0-1 scale), RepoEval style: 1 - dist / max_len."""
    total = len(references)
    edit_sim = 0.0
    for pred, gt in zip(hypotheses, references):
        pred = pred.strip()
        gt = gt.strip()
        if max(len(pred), len(gt)) == 0:
            continue
        edit_sim += 1 - editdistance.eval(pred, gt) / max(len(pred), len(gt))
    return edit_sim / total


def tokenize_code(code):
    """Tokenize code: split on punctuation and camelCase boundaries, normalize quotes."""
    code = re.sub(r"([^A-Za-z0-9_])", r" \1 ", code)
    code = re.sub(r"([a-z])([A-Z])", r"\1 \2", code)
    code = re.sub(r"\s+", " ", code)
    code = code.replace('"', "`")
    code = code.replace("'", "`")
    tokens = [t for t in code.split(" ") if t]
    return tokens


def cal_exact_match(references, hypotheses):
    """Token-level exact-match rate between predictions and references."""
    em_score = []
    for pred, gold in zip(hypotheses, references):
        em_score.append(tokenize_code(pred) == tokenize_code(gold))
    return np.mean(em_score)
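
# Illustrative usage of the metric helpers above (added for clarity; not part of
# the original evaluation pipeline). Note the different scales: cal_edit_sim is
# already 0-100 (fuzz.ratio), while cal_edit_sim_repoeval is 0-1 and is only
# scaled by 100 when reported later.
#
#   cal_edit_sim(["foo"], ["foo"])              # 100.0
#   cal_edit_sim_repoeval(["abcd"], ["abce"])   # 0.75  (edit distance 1 over max length 4)
#   cal_exact_match(["x = 1"], ["x=1"])         # 1.0   (token-level comparison ignores spacing)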


def remove_comments(code):
    code = re.sub(r'#.*', '', code)
    return code


def is_parse_valid(parser, code):
    def syntax_error(node):
        if node.type == "ERROR":
            return True
        try:
            for child in node.children:
                if syntax_error(child):
                    return True
        except RecursionError as err:
            return True

        return False

    tree = get_ast(parser, code)
    if tree is not None:
        return not syntax_error(tree.root_node)
    return False


def get_valid_completion(prompt, completion, parser):
    """Trim the completion back to its longest prefix that still parses after the prompt."""
    for i in range(len(completion), -1, -1):
        code = prompt + completion[:i]
        if is_parse_valid(parser, code):
            return "parseable", completion[:i].rstrip()

    return "not_parseable", completion


def dfs(
    node: TSNode,
    node_types: List[str],
    callback: Callable,
    ignore_node_types: List[str] = None,
):
    """
    Helper to traverse parsed AST
    """
    if node.type in node_types:
        callback(node)

    for child in node.children:
        if not ignore_node_types or child.type not in ignore_node_types:
            dfs(child, node_types, callback, ignore_node_types)


def collect_nodes(root_node, node_types, ignore_node_types=None):
    """
    Collect all nodes that belong to certain types
    """
    result = list()

    def _cb(n):
        result.append(n)

    if root_node is not None:
        try:
            dfs(root_node, node_types, _cb, ignore_node_types)
        except RecursionError as err:
            print('collection of nodes failed due to RecursionError')
            return []

    return result


@timeout_decorator.timeout(5)
def get_ast(parser, code):
    assert isinstance(code, str) or isinstance(code, bytes)
    if isinstance(code, str):
        code = bytes(code, "utf8")
    try:
        tree = parser.parse(code)
        return tree
    except Exception as e:
        return None


def get_functions(parser, code):
    """
    Return the body text of every function definition in ``code`` (including
    functions defined inside classes), as a list of strings.
    :param code:
    :return: List[str]
    """
    try:
        tree = get_ast(parser, code)
    except:
        return []
    if tree is None:
        return []

    functions = []
    function_nodes = collect_nodes(tree.root_node, ["function_definition"])
    for fnode in function_nodes:
        assert fnode.children[-1].type == "block"
        body_text = fnode.children[-1].text.decode("utf-8")
        functions.append(body_text)

    return functions


def get_function_completion(prompt, completion, parser):
    """Body of the function under completion (the last function opened in the prompt)."""
    code = prompt + "pass"
    target_fn_idx = len(get_functions(parser, code)) - 1
    # assert target_fn_idx != -1

    code = prompt + completion
    function_body = get_functions(parser, code)[target_fn_idx]
    return function_body


def process_examples(task, args):
    """Post-process one (sample, example) pair: trim prediction and target for scoring."""
    sample, ex = args
    global parser

    prediction = sample["pred"]
    target = ex["groundtruth"]
    origin = ""

    if task == "function_completion":
        status, prediction = get_valid_completion(ex["prompt"], prediction, parser)
        if status == "parseable":
            try:
                origin = prediction
                prediction = get_function_completion(ex["prompt"], prediction, parser)
                target = get_function_completion(ex["prompt"], target, parser)
            except:
                print(f'[warning] parsing failed: task_id:{ex["task_id"]}')
        else:
            print(f'[warning] parsing failed: task_id:{ex["task_id"]}')
    else:
        num_target_lines = sum([1 for l in target.split("\n") if l.strip()])
        pred_lines = [l for l in prediction.split("\n") if l.strip()][:num_target_lines]
        prediction = "\n".join(pred_lines)

    trunc_s = {
        "task_id": sample["task_id"],
        "pred": prediction,
        "target": target,
        # "origin": origin
    }

    return trunc_s
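
# Illustrative post-processing behavior (added for clarity; the task name and
# values below are hypothetical). For tasks other than "function_completion",
# the prediction is truncated to the number of non-empty ground-truth lines:
#
#   sample = {"task_id": "repo/file.py:10", "pred": "x = 1\ny = 2\nz = 3"}
#   ex = {"task_id": "repo/file.py:10", "prompt": "...", "groundtruth": "x = 1\ny = 2"}
#   process_examples("line_completion", (sample, ex))["pred"]   # "x = 1\ny = 2"
#
# For "function_completion", the completion is first trimmed to its longest
# parseable prefix (get_valid_completion) and then reduced to the body of the
# function being completed (get_function_completion).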


def compute_metric_stmt(args):
    """Evaluate predictions for each task in args.tasks; write per-task and merged results."""
    all_task_results = {}

    for task in args.tasks:
        print(f"\nComputing metrics for task: {task}")
        with open(f"{args.output_dir}/{args.dataset}/{args.language}/{task}/prediction.jsonl", "r") as f_pred:
            samples = []
            for l in f_pred.readlines():
                samples.append(json.loads(l))

        # Build the task-specific prompt file path
        task_prompt_file = args.prompt_file.replace('TASK', task)
        examples = {}
        with open(task_prompt_file, "r") as f_in:
            for l in f_in.readlines():
                ex = json.loads(l)
                if hasattr(args, "focused_repo") and args.focused_repo and args.focused_repo not in re.sub('/', '_', ex['metadata']['repository']):
                    continue
                examples[ex["metadata"]["task_id"]] = {
                    "task_id": ex["metadata"]["task_id"],
                    "prompt": ex["prompt"],
                    "groundtruth": ex["groundtruth"]
                }

        if len(samples) != len(examples):
            print('Warning: len(samples) ({}) != len(examples) ({})'.format(len(samples), len(examples)))

        global parser
        ts_lang = args.language
        if ts_lang == 'csharp':
            ts_lang = 'c_sharp'
        language = Language(args.ts_lib, ts_lang)
        parser = Parser()
        parser.set_language(language)

        truncated_samples = []
        print("post-processing samples ...")
        pool = mp.Pool(mp.cpu_count() - 1)
        worker = partial(process_examples, task)

        with tqdm(total=len(samples)) as pbar:
            for trunc_s in pool.imap_unordered(worker, zip(samples, [examples[s["task_id"]] for s in samples])):
                truncated_samples.append(trunc_s)
                pbar.update()

        pool.close()
        pool.join()

        task_output_dir = os.path.join(args.output_dir, args.dataset, args.language, task)
        # with open(f"{task_output_dir}/prediction_truncated.jsonl", 'w', encoding="utf-8") as pt:
        #     for trunc_s in truncated_samples:
        #         pt.write(json.dumps(trunc_s) + "\n")

        ### Score calculation
        detailed_results = []
        exact_match = 0
        edit_sim = 0
        edit_sim_repoeval = 0

        for idx, trunc_s in enumerate(truncated_samples):
            es = cal_edit_sim([trunc_s["target"]], [trunc_s["pred"]])
            es_repoeval = cal_edit_sim_repoeval([trunc_s["target"]], [trunc_s["pred"]])
            em = cal_exact_match([trunc_s["target"]], [trunc_s["pred"]])
            edit_sim += es
            edit_sim_repoeval += es_repoeval
            exact_match += em

            detailed_results.append({
                "task_id": trunc_s["task_id"],
                "pred": trunc_s["pred"],
                "target": trunc_s["target"],
                "em": em,
                "es": es,
                "es_repoeval": es_repoeval
            })

        total_samples = len(truncated_samples)
        em_ratio = round(exact_match / total_samples * 100, 2)
        edit_sim_avg = round(edit_sim / total_samples, 2)
        edit_sim_repoeval_avg = round(edit_sim_repoeval / total_samples * 100, 2)

        print(
            f"Code Matching for {task}: "
            f"EM {em_ratio:.2f}, "
            f"ES {edit_sim_avg:.2f}, "
            f"ES RepoEval {edit_sim_repoeval_avg:.2f}"
        )

        # Save detailed per-sample results (one JSON object per line)
        with open(f"{task_output_dir}/detailed_results.json", 'w') as f:
            for dr in detailed_results:
                f.write(json.dumps(dr) + "\n")

        # Save task-level results
        task_results = {
            "em": em_ratio,
            "es": edit_sim_avg,
            "es_repoeval": edit_sim_repoeval_avg,
            "total": total_samples
        }

        with open(f"{task_output_dir}/results.json", 'w') as f:
            json.dump(task_results, f, indent=2)

        # Add the current task's results to the overall results dict
        all_task_results[task] = task_results

    # Compute the sample-weighted average across all tasks
    total_samples = sum(res["total"] for res in all_task_results.values())
    weighted_em = sum(res["em"] * res["total"] for res in all_task_results.values()) / total_samples
    weighted_es = sum(res["es"] * res["total"] for res in all_task_results.values()) / total_samples
    weighted_es_repoeval = sum(res["es_repoeval"] * res["total"] for res in all_task_results.values()) / total_samples

    # Build the final merged results
    merged_results = {
        "overall": {
            "em": round(weighted_em, 4),
            "es": round(weighted_es, 4),
            "es_repoeval": round(weighted_es_repoeval, 4),
            "total": total_samples
        },
        "per_task": all_task_results
    }

    # Save the merged results
    with open(f"{args.output_dir}/{args.dataset}/results.json", 'w') as f:
        json.dump(merged_results, f, indent=2)

    print("\nOverall Results (Weighted Average):")
    print(f"EM: {weighted_em:.2f}")
    print(f"ES: {weighted_es:.2f}")
    print(f"ES RepoEval: {weighted_es_repoeval:.2f}")
    print(f"Total Samples: {total_samples}")
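
# Expected on-disk layout for compute_metric_stmt (illustrative; the field names
# come from the reading code above, everything else is a hypothetical example):
#
#   {output_dir}/{dataset}/{language}/{task}/prediction.jsonl
#     {"task_id": "repo_a/utils.py:42", "pred": "return x + 1"}
#
#   prompt_file (with 'TASK' replaced by the task name)
#     {"prompt": "def add_one(x):\n    ", "groundtruth": "return x + 1",
#      "metadata": {"task_id": "repo_a/utils.py:42", "repository": "org/repo_a"}}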


def compute_metric_stmt_multilang(args):
    """Evaluate predictions for a single task (args.task) across each language in args.languages."""
    all_task_results = {}

    for language in args.languages:
        print(f"\nComputing metrics for language: {language}")
        with open(f"{args.output_dir}/{args.dataset}/{language}/{args.task}/prediction.jsonl", "r") as f_pred:
            samples = []
            for l in f_pred.readlines():
                samples.append(json.loads(l))

        # Build the language-specific prompt file path
        task_prompt_file = args.prompt_file.replace('LANGUAGE', language)
        examples = {}
        with open(task_prompt_file, "r") as f_in:
            for l in f_in.readlines():
                ex = json.loads(l)
                if hasattr(args, "focused_repo") and args.focused_repo and args.focused_repo not in re.sub('/', '_', ex['metadata']['repository']):
                    continue
                examples[ex["metadata"]["task_id"]] = {
                    "task_id": ex["metadata"]["task_id"],
                    "prompt": ex["prompt"],
                    "groundtruth": ex["groundtruth"]
                }

        if len(samples) != len(examples):
            print('Warning: len(samples) ({}) != len(examples) ({})'.format(len(samples), len(examples)))

        global parser
        ts_lang = language
        if ts_lang == 'csharp':
            ts_lang = 'c_sharp'

        ts_lib = args.ts_lib.replace('LANGUAGE', language)
        language_ts = Language(ts_lib, ts_lang)
        parser = Parser()
        parser.set_language(language_ts)

        truncated_samples = []
        print("post-processing samples ...")
        pool = mp.Pool(mp.cpu_count() - 1)
        worker = partial(process_examples, args.task)

        with tqdm(total=len(samples)) as pbar:
            for trunc_s in pool.imap_unordered(worker, zip(samples, [examples[s["task_id"]] for s in samples])):
                truncated_samples.append(trunc_s)
                pbar.update()

        pool.close()
        pool.join()

        task_output_dir = os.path.join(args.output_dir, args.dataset, language, args.task)
        # with open(f"{task_output_dir}/prediction_truncated.jsonl", 'w', encoding="utf-8") as pt:
        #     for trunc_s in truncated_samples:
        #         pt.write(json.dumps(trunc_s) + "\n")

        ### Score calculation
        detailed_results = []
        exact_match = 0
        edit_sim = 0
        edit_sim_repoeval = 0

        for idx, trunc_s in enumerate(truncated_samples):
            es = cal_edit_sim([trunc_s["target"]], [trunc_s["pred"]])
            es_repoeval = cal_edit_sim_repoeval([trunc_s["target"]], [trunc_s["pred"]])
            em = cal_exact_match([trunc_s["target"]], [trunc_s["pred"]])
            edit_sim += es
            edit_sim_repoeval += es_repoeval
            exact_match += em

            detailed_results.append({
                "task_id": trunc_s["task_id"],
                "pred": trunc_s["pred"],
                "target": trunc_s["target"],
                "em": em,
                "es": es,
                "es_repoeval": es_repoeval
            })

        total_samples = len(truncated_samples)
        em_ratio = round(exact_match / total_samples * 100, 2)
        edit_sim_avg = round(edit_sim / total_samples, 2)
        edit_sim_repoeval_avg = round(edit_sim_repoeval / total_samples * 100, 2)

        print(
            f"Code Matching for {language}: "
            f"EM {em_ratio:.2f}, "
            f"ES {edit_sim_avg:.2f}, "
            f"ES RepoEval {edit_sim_repoeval_avg:.2f}"
        )

        # Save detailed per-sample results (one JSON object per line)
        with open(f"{task_output_dir}/detailed_results.json", 'w') as f:
            for dr in detailed_results:
                f.write(json.dumps(dr) + "\n")

        # Save language-level results
        task_results = {
            "em": em_ratio,
            "es": edit_sim_avg,
            "es_repoeval": edit_sim_repoeval_avg,
            "total": total_samples
        }

        with open(f"{task_output_dir}/results.json", 'w') as f:
            json.dump(task_results, f, indent=2)

        # Add the current language's results to the overall results dict
        all_task_results[language] = task_results

    # Compute the sample-weighted average across all languages
    total_samples = sum(res["total"] for res in all_task_results.values())
    weighted_em = sum(res["em"] * res["total"] for res in all_task_results.values()) / total_samples
    weighted_es = sum(res["es"] * res["total"] for res in all_task_results.values()) / total_samples
    weighted_es_repoeval = sum(res["es_repoeval"] * res["total"] for res in all_task_results.values()) / total_samples

    # Build the final merged results
    merged_results = {
        "overall": {
            "em": round(weighted_em, 4),
            "es": round(weighted_es, 4),
            "es_repoeval": round(weighted_es_repoeval, 4),
            "total": total_samples
        },
        "per_language": all_task_results
    }

    # Save the merged results
    with open(f"{args.output_dir}/{args.dataset}/results.json", 'w') as f:
        json.dump(merged_results, f, indent=2)

    print("\nOverall Results (Weighted Average):")
    print(f"EM: {weighted_em:.2f}")
    print(f"ES: {weighted_es:.2f}")
    print(f"ES RepoEval: {weighted_es_repoeval:.2f}")
    print(f"Total Samples: {total_samples}")


def compute_metric_stmt_custom(predictions_file, prompt_file, output_dir,
                               ts_lib, task, focused_repo=None, anchor_file=None, out_f_suffix=""):
    """Evaluate a single predictions file against a prompt file (Python grammar only)."""
    eval_ids = set()

    if anchor_file:
        with open(anchor_file, "r") as f_pred:
            for l in f_pred.readlines():
                eval_ids.add(json.loads(l)['task_id'])

    with open(predictions_file, "r") as f_pred:
        samples = []
        for l in f_pred.readlines():
            if anchor_file:
                if json.loads(l)['task_id'] in eval_ids:
                    samples.append(json.loads(l))
            else:
                entry = json.loads(l)
                # entry['task_id'] = re.sub('-', '_', entry['task_id'])
                if entry['task_id'] in eval_ids:
                    continue
                if focused_repo is not None:
                    if type(focused_repo) == str and focused_repo not in re.sub('/', '_', entry['task_id']):
                        continue
                    elif type(focused_repo) == list and not any([x in re.sub('/', '_', entry['task_id']) for x in focused_repo]):
                        continue
                samples.append(entry)
                eval_ids.add(entry['task_id'])

    examples = {}
    with open(prompt_file, "r") as f_in:
        for l in f_in.readlines():
            ex = json.loads(l)
            if focused_repo is not None:
                if type(focused_repo) == str and focused_repo not in re.sub('/', '_', ex['metadata']['repository']):
                    continue
                elif type(focused_repo) == list and not any([x in re.sub('/', '_', ex['metadata']['repository']) for x in focused_repo]):
                    continue
            if ex["metadata"]["task_id"] not in eval_ids:
                continue
            examples[ex["metadata"]["task_id"]] = {
                "task_id": ex["metadata"]["task_id"],
                "prompt": ex["prompt"],
                "groundtruth": ex["groundtruth"]
            }

    assert len(samples) == len(examples), f"{len(samples)} != {len(examples)}"

    global parser
    language = Language(ts_lib, "python")
    parser = Parser()
    parser.set_language(language)

    truncated_samples = []
    print("post-processing samples ...")
    pool = mp.Pool(mp.cpu_count() - 1)
    worker = partial(process_examples, task)

    with tqdm(total=len(samples)) as pbar:
        for trunc_s in pool.imap_unordered(worker, zip(samples, [examples[s["task_id"]] for s in samples])):
            truncated_samples.append(trunc_s)
            pbar.update()

    with open(f"{output_dir}/prediction_truncated{out_f_suffix}.jsonl", 'w', encoding="utf-8") as pt:
        for trunc_s in truncated_samples:
            pt.write(json.dumps(trunc_s) + "\n")

    ### Score calculation

    detailed_results = []
    exact_match = 0
    edit_sim = 0
    edit_sim_repoeval = 0

    for idx, trunc_s in enumerate(truncated_samples):
        es = cal_edit_sim([trunc_s["target"]], [trunc_s["pred"]])
        es_repoeval = cal_edit_sim_repoeval([trunc_s["target"]], [trunc_s["pred"]])
        em = cal_exact_match([trunc_s["target"]], [trunc_s["pred"]])
        edit_sim += es
        edit_sim_repoeval += es_repoeval
        exact_match += em

        detailed_results.append({
            "task_id": trunc_s["task_id"],
            "em": em,
            "es": es,
            "es_repoeval": es_repoeval
        })

    em_ratio = round(exact_match / len(truncated_samples) * 100, 2)
    edit_sim = round(edit_sim / len(truncated_samples), 2)
    edit_sim_repoeval = round(edit_sim_repoeval / len(truncated_samples) * 100, 2)

    print(
        f"Code Matching: "
        f"EM {em_ratio:.2f}, "
        f"ES {edit_sim:.2f}, "
        f"ES RepoEval {edit_sim_repoeval:.2f}"
    )

    with open(f"{output_dir}/detailed_results{out_f_suffix}.json", 'w') as f:
        for dr in detailed_results:
            f.write(json.dumps(dr) + "\n")

    # write the results to a file
    with open(f"{output_dir}/results{out_f_suffix}.json", 'w') as f:
        res = {
            "em": em_ratio,
            "es": edit_sim,
            "es_repoeval": edit_sim_repoeval,
            "total": len(truncated_samples)
        }
        f.write(json.dumps(res, indent=2))
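
# Illustrative call of compute_metric_stmt_custom (the paths below are
# hypothetical placeholders; only the function signature and the task name
# come from the code above):
#
#   compute_metric_stmt_custom(
#       predictions_file="outputs/prediction.jsonl",
#       prompt_file="data/prompts.jsonl",
#       output_dir="outputs",
#       ts_lib="build/python-languages.so",
#       task="function_completion",
#   )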


def extract_block(text: str) -> str:
    """Extract the content of the first fenced code block (```...```) in the text."""
    start = text.find('```') + 3
    end = text.find('```', start)
    content = text[start:end]
    return content
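
# Minimal usage sketch for extract_block (added for illustration; assumes a plain
# ``` fence with no language tag -- if a tag like ```python is present, it is kept
# in the returned content).
if __name__ == "__main__":
    reply = "Here is the fix:\n```\nreturn x + 1\n```\nDone."
    print(extract_block(reply))  # prints "return x + 1" with surrounding newlines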