"""testing_util.py

Helpers for running candidate solutions against JSON-encoded test cases.

A solution is either "call-based" (a named function is called with the
decoded inputs) or "standard input" (the code is run with the input text
served as stdin and its captured stdout is compared with the expected text).

WARNING: this is NOT a security sandbox; see ``reliability_guard``.
"""
import ast
import json
import sys
import faulthandler
import platform

# used for debugging to time steps
from datetime import datetime

# to run the solution files we're using a timing based approach
import signal

import numpy as np

# for capturing the stdout
from io import StringIO

# used for testing the code that reads from input
from unittest.mock import patch, mock_open

from pyext import RuntimeModule

from enum import Enum


class CODE_TYPE(Enum):
    """How a test case drives the solution: function call or stdin/stdout."""

    call_based = 0
    standard_input = 1


# stuff for setting up signal timer
class TimeoutException(Exception):
    """Raised by the SIGALRM handler when a solution exceeds its time limit."""

    pass


def timeout_handler(signum, frame):
    """SIGALRM handler: abort the running code by raising TimeoutException."""
    print("alarm went off")
    raise TimeoutException


signal.signal(signal.SIGALRM, timeout_handler)


class Capturing(list):
    """Context manager that captures everything written to ``sys.stdout``.

    On exit the captured text is split into lines and appended to ``self``
    (the object is itself a list). Adapted from
    https://stackoverflow.com/a/16571630/6416660; ``contextlib.redirect_stdout``
    is an alternative.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Make closing the StringIO a no-op so code calling sys.stdout.close()
        # cannot destroy the buffer before it is read. ``close()`` is called
        # with no arguments, so the replacement must accept any arguments.
        self._stringio.close = lambda *args, **kwargs: None
        return self

    def __exit__(self, *args):
        try:
            self.extend(self._stringio.getvalue().splitlines())
            del self._stringio  # free up some memory
        finally:
            # Always restore the real stdout, even if collecting the output
            # fails (e.g. a pending alarm fires here).
            sys.stdout = self._stdout


def only_int_check(val):
    """Return True if ``val`` is an ``int`` instance."""
    return isinstance(val, int)


def string_int_check(val):
    """Return True if ``val`` is a string made only of digits."""
    return isinstance(val, str) and val.isdigit()


def combined_int_check(val):
    """Return True if ``val`` is an int or a digit-only string."""
    return only_int_check(val) or string_int_check(val)


# Errors that indexing / int-converting arbitrary JSON values can raise; used
# by the best-effort "restore integer dict keys" step in run_test.
_JSON_SHAPE_ERRORS = (TypeError, ValueError, IndexError, KeyError)


def _strip_main_guard(test):
    """Inline a trailing ``if __name__ == '__main__':`` block into module level.

    The solution is loaded as a module named ``tmp_sol``, so a main guard would
    never fire. On any failure (unparsable code, empty module, Python without
    ``ast.unparse``) ``test`` is returned unchanged.
    """
    try:
        astree = ast.parse(test)
        last_block = astree.body[-1]
        if isinstance(last_block, ast.If):
            condition = last_block.test
            if ast.unparse(condition).strip() == "__name__ == '__main__'":
                return (
                    ast.unparse(astree.body[:-1])
                    + "\n"
                    + ast.unparse(last_block.body)
                )
    except Exception:
        # Best effort: fall back to the code as written.
        pass
    return test


def _wrap_in_code_function(test):
    """Wrap stdin-style solution code into a module-level ``def code():``.

    Lines starting with ``from ``/``import `` before the first other line stay
    at module level; every other line is tab-indented into ``code()``, which is
    preceded by ``stdin``/``stdout`` aliases of ``sys.stdin``/``sys.stdout``.
    Import lines appearing after the body has started are indented as well.
    """
    import_prefixes = ("from ", "import ")
    lines = [
        line + "\n" if line.startswith(import_prefixes) else "\t" + line + "\n"
        for line in test.split("\n")
    ]

    parts = []
    started = False
    for line in lines:
        if line.startswith("\t") and not started:
            parts.append("stdin = sys.stdin\nstdout = sys.stdout\n")
            parts.append("def code():\n")
            parts.append(line)
            started = True
        elif started and line.startswith(import_prefixes):
            parts.append("\t" + line)
        else:
            parts.append(line)
    return "".join(parts)


def _inputs_for_debug(inputs):
    """Render test inputs on one line for debug output."""
    if isinstance(inputs, list):
        return inputs
    return inputs.replace("\n", " new-line ")


def run_test(sample, test=None, debug=False, timeout=6):
    """Run solution code ``test`` against the test cases in ``sample``.

    If ``test`` is None, the parsed ``sample["input_output"]`` is returned
    without running anything. Otherwise the code is compiled and executed on
    each test input, stopping at the first case that does not pass.

    Args:
        sample: mapping whose ``"input_output"`` entry is a JSON string with
            ``"inputs"`` and ``"outputs"`` lists and an optional ``"fn_name"``
            (present -> call-based, absent/null -> standard input).
        test: solution source code, or None.
        debug: print diagnostic information while running.
        timeout: time limit in seconds passed to ``signal.alarm`` for
            compilation and for each test case.

    Returns:
        The parsed in/out object (or None if it is not valid JSON) when
        ``test`` is None. Otherwise a list with one entry per executed case:
        a truthy ``True`` on pass, ``False`` on wrong answer, ``-1`` on runtime
        error / timeout, or a single ``-2`` on compilation error.

    Raises:
        ValueError: if ``test`` is given but ``sample["input_output"]`` is not
            valid JSON or is empty.
    """
    # Disable functionalities that can make destructive changes to the test.
    reliability_guard()

    if debug:
        print(f"start = {datetime.now().time()}")

    try:
        in_outs = json.loads(sample["input_output"])
    except ValueError:
        in_outs = None
    if in_outs:
        if in_outs.get("fn_name") is None:
            which_type = CODE_TYPE.standard_input  # Standard input
            method_name = None
        else:
            which_type = CODE_TYPE.call_based  # Call-based
            method_name = in_outs["fn_name"]

    if debug:
        print(f"loaded input_output = {datetime.now().time()}")

    if test is None:
        return in_outs
    if not in_outs:
        # Without test cases there is nothing to run and no way to tell
        # call-based from stdin-based problems.
        raise ValueError(
            "sample['input_output'] must be a non-empty JSON object "
            "when test code is given"
        )

    results = []
    sol = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(6*10**5)\n"
    if debug:
        print(f"loading test code = {datetime.now().time()}")

    if which_type == CODE_TYPE.call_based:
        sol += test
        if debug:
            print(f"sol = {sol}")
        signal.alarm(timeout)
        try:
            tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
            if "class Solution" not in test:
                tmp = tmp_sol
            else:
                tmp = tmp_sol.Solution()
        except Exception as e:
            signal.alarm(0)
            if debug:
                print(f"type 0 compilation error = {e}")
            results.append(-2)
            return results
        finally:
            # Never leave the alarm armed, even on BaseException (SystemExit).
            signal.alarm(0)

    elif which_type == CODE_TYPE.standard_input:
        # if code has if __name__ == "__main__": then remove it
        test = _strip_main_guard(test)
        sol += _wrap_in_code_function(test)
        if debug:
            print(f"sol = {sol}")
        method_name = "code"
        signal.alarm(timeout)
        try:
            tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
            tmp = tmp_sol
        except Exception as e:
            signal.alarm(0)
            if debug:
                print(f"type 1 compilation error = {e}")
            results.append(-2)
            return results
        finally:
            signal.alarm(0)

    if debug:
        print(f"get method = {datetime.now().time()}")

    try:
        method = getattr(tmp, method_name)  # get_attr second arg must be str
    except Exception:
        signal.alarm(0)
        e = sys.exc_info()
        print(f"unable to get function error = {e}")
        results.append(-2)
        return results

    for index, inputs in enumerate(in_outs["inputs"]):
        if which_type == CODE_TYPE.call_based:
            inputs = [json.loads(line) for line in inputs.split("\n")]
            in_outs["outputs"][index] = json.loads(in_outs["outputs"][index])

        # JSON forces dictionaries to have string keys; this undoes this
        # (assuming a singleton list). Best effort: anything that is not an
        # int-keyed dict is left untouched.
        try:
            if isinstance(inputs[0], dict):
                inputs = [{int(k): v for k, v in inputs[0].items()}]
        except _JSON_SHAPE_ERRORS:
            pass
        try:
            if isinstance(in_outs["outputs"][index], dict):
                in_outs["outputs"][index] = [
                    {int(k): v for k, v in in_outs["outputs"][index].items()}
                ]
        except _JSON_SHAPE_ERRORS:
            pass
        try:
            if isinstance(in_outs["outputs"][index][0], dict):
                in_outs["outputs"][index] = [
                    {int(k): v for k, v in in_outs["outputs"][index][0].items()}
                ]
        except _JSON_SHAPE_ERRORS:
            pass

        if debug:
            print(
                f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}"
            )

        if which_type == CODE_TYPE.call_based:  # Call-based
            signal.alarm(timeout)
            faulthandler.enable()
            try:
                output = method(*inputs)

                # ground truth sequences are not tuples
                if isinstance(output, tuple):
                    output = list(output)

                expected = in_outs["outputs"][index]
                tmp_result = output == expected
                if isinstance(expected, list) and expected:
                    tmp_result = tmp_result or (output == expected[0])

                # ground truth sequences are not tuples
                try:
                    if isinstance(output[0], tuple):
                        tmp_result = tmp_result or (
                            [list(x) for x in output] == expected[0]
                        )
                except Exception:
                    # Best effort: output may not be indexable/iterable.
                    pass
            except Exception as e:
                # Disarm first so the alarm cannot fire inside this handler.
                signal.alarm(0)
                faulthandler.disable()
                if debug:
                    print(
                        f"Call-based runtime error or time limit exceeded error = {e}"
                    )
                results.append(-1)
                return results
            finally:
                # Reset on every path, including the wrong-answer return below,
                # so no alarm stays armed after run_test returns.
                signal.alarm(0)
                faulthandler.disable()

            results.append(tmp_result)
            if debug:
                print(
                    f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
                )
            if tmp_result != True:
                return results

        elif which_type == CODE_TYPE.standard_input:  # Standard input
            if isinstance(inputs, list):
                inputs = "\n".join(inputs)
            if isinstance(in_outs["outputs"][index], list):
                in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index])

            signal.alarm(timeout)
            faulthandler.enable()
            try:
                with Capturing() as output:
                    call_method(method, inputs)
            except Exception as e:
                # runtime error or took too long
                signal.alarm(0)
                faulthandler.disable()
                if debug:
                    # Printed after stdout is restored, so it is actually shown.
                    print(
                        f"Standard input runtime error or time limit exceeded error = {repr(e)}{e}"
                    )
                results.append(-1)
                return results
            finally:
                signal.alarm(0)
                faulthandler.disable()

            if debug:
                print(
                    f"==> output = {output}, test outputs = {in_outs['outputs'][index]}"
                )

            if custom_compare_(output, in_outs["outputs"][index]):
                tmp_result = True
                results.append(tmp_result)
                continue

            # ground truth sequences are expressed as lists not tuples
            if isinstance(output, tuple):
                output = list(output)

            tmp_result = False
            try:
                tmp_result = output == [in_outs["outputs"][index]]
                if isinstance(in_outs["outputs"][index], list):
                    tmp_result = tmp_result or (output == in_outs["outputs"][index])
                    if isinstance(output[0], str):
                        tmp_result = tmp_result or (
                            [e.strip() for e in output] == in_outs["outputs"][index]
                        )
            except Exception as e:
                if debug:
                    print(f"Failed check1 exception = {e}")

            if tmp_result == True:
                results.append(tmp_result)
                continue

            # try one more time without \n
            if isinstance(in_outs["outputs"][index], list):
                for tmp_index, i in enumerate(in_outs["outputs"][index]):
                    in_outs["outputs"][index][tmp_index] = i.split("\n")
                    in_outs["outputs"][index][tmp_index] = [
                        x.strip() for x in in_outs["outputs"][index][tmp_index] if x
                    ]
            else:
                in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
                in_outs["outputs"][index] = list(
                    filter(len, in_outs["outputs"][index])
                )
                in_outs["outputs"][index] = list(
                    map(lambda x: x.strip(), in_outs["outputs"][index])
                )

            try:
                tmp_result = output == [in_outs["outputs"][index]]
                if isinstance(in_outs["outputs"][index], list):
                    tmp_result = tmp_result or (output == in_outs["outputs"][index])
            except Exception as e:
                if debug:
                    print(f"Failed check2 exception = {e}")

            if tmp_result == True:
                results.append(tmp_result)
                continue

            # try by converting the output into a split up list too
            if isinstance(output, list):
                output = list(filter(len, output))

            if debug:
                print(
                    f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {_inputs_for_debug(inputs)}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
                )

            if tmp_result == True:
                results.append(tmp_result)
                continue

            if debug:
                print(f"{tmp_result=} @a")

            try:
                tmp_result = output == [in_outs["outputs"][index]]
                if isinstance(in_outs["outputs"][index], list):
                    tmp_result = tmp_result or (output == in_outs["outputs"][index])
            except Exception as e:
                if debug:
                    print(f"Failed check3 exception = {e}")

            if debug:
                print(f"{tmp_result=} @b")

            # Numeric comparison with tolerance when not everything is an int.
            try:
                all_ints = all(
                    combined_int_check(e1) and combined_int_check(e2)
                    for e1, e2 in zip(output, in_outs["outputs"][index])
                )
                if not all_ints:
                    if debug:
                        print(
                            [
                                combined_int_check(e1) and combined_int_check(e2)
                                for e1, e2 in zip(output, in_outs["outputs"][index])
                            ]
                        )
                    output_float = [float(e) for e in output]
                    gt_float = [float(e) for e in in_outs["outputs"][index]]
                    tmp_result = tmp_result or (
                        (len(output_float) == len(gt_float))
                        and np.allclose(output_float, gt_float)
                    )
            except Exception:
                # Best effort: non-numeric values simply skip this check.
                pass

            if debug:
                print(f"{tmp_result=} @c")

            try:
                if isinstance(output[0], list):
                    all_ints = all(
                        combined_int_check(e1) and combined_int_check(e2)
                        for e1, e2 in zip(output[0], in_outs["outputs"][index])
                    )
                    if not all_ints:
                        output_float = [float(e) for e in output[0]]
                        gt_float = [float(e) for e in in_outs["outputs"][index][0]]
                        tmp_result = tmp_result or (
                            (len(output_float) == len(gt_float))
                            and np.allclose(output_float, gt_float)
                        )
            except Exception:
                # Best effort: empty or non-numeric output skips this check.
                pass

            if tmp_result == True:
                results.append(tmp_result)
                continue

            if debug:
                print(f"{tmp_result=} @d")
            # try by converting the stuff into split up list
            if isinstance(in_outs["outputs"][index], list):
                for tmp_index, i in enumerate(in_outs["outputs"][index]):
                    in_outs["outputs"][index][tmp_index] = set(i.split())
            else:
                in_outs["outputs"][index] = set(in_outs["outputs"][index].split())

            if debug:
                print(f"{tmp_result=} @e")

            try:
                tmp_result = output == in_outs["outputs"][index]
            except Exception as e:
                if debug:
                    print(f"Failed check4 exception = {e}")
                continue

            if tmp_result == True:
                results.append(tmp_result)
                continue

            if debug:
                print(f"{tmp_result=} @f")

            # try by converting the output into a split up list too
            if isinstance(output, list):
                for tmp_index, i in enumerate(output):
                    output[tmp_index] = i.split()
                output = list(filter(len, output))
                for tmp_index, i in enumerate(output):
                    output[tmp_index] = set(i)
            else:
                output = output.split()
                output = list(filter(len, output))
                output = set(output)

            if debug:
                print(f"{tmp_result=} @g")
            # NOTE(review): set-based comparisons of the split-up values above
            # ("check5"/"check6") are disabled, so tmp_result is not updated
            # after check4; the conversion only affects the debug output below.

            if debug:
                print(f"{tmp_result=} @h")

            if tmp_result == True and debug:
                print("PASSED")

            results.append(tmp_result)
            if tmp_result != True:
                return results

            if debug:
                print(
                    f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {_inputs_for_debug(inputs)}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
                )

                print(f"results = {results}")

    return results


def custom_compare_(output, ground_truth):
    """Return True if captured ``output`` lines textually match ``ground_truth``.

    Both attempts ignore leading/trailing whitespace of the whole text: first
    the lines are joined as-is, then each line is stripped before joining.
    Non-list ``output`` never matches.
    """
    if not isinstance(output, list):
        return False
    if stripped_string_compare("\n".join(output), ground_truth):
        return True
    return stripped_string_compare(
        "\n".join(o.strip() for o in output), ground_truth
    )


def stripped_string_compare(s1, s2):
    """Compare two strings ignoring leading and trailing whitespace."""
    return s1.strip() == s2.strip()


def call_method(method, inputs):
    """Call zero-argument ``method`` with ``inputs`` served as standard input.

    ``inputs`` (a string, or a list of lines that is joined with newlines) is
    exposed through a patched ``sys.stdin`` and ``open()``; ``read``,
    ``readline`` and ``readlines`` are patched as well. A ``SystemExit`` raised
    by ``method`` is swallowed so ``sys.exit()`` ends a solution normally.

    Returns:
        Whatever ``method`` returns, or None if it raised ``SystemExit``.
    """
    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    inputs_line_iterator = iter(inputs.split("\n"))

    # readline returns "" once the lines are exhausted, like a real file,
    # instead of leaking StopIteration into the solution.
    @patch("builtins.open", mock_open(read_data=inputs))
    @patch("sys.stdin", StringIO(inputs))
    @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator, ""))
    @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
    @patch("sys.stdin.read", lambda *args: inputs)
    def _inner_call_method(_method):
        try:
            return _method()
        except SystemExit:
            pass

    return _inner_call_method(method)


def reliability_guard(maximum_memory_bytes=None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    Args:
        maximum_memory_bytes: if not None, cap the address space and data
            segment (and, except on Darwin/macOS, the stack) to this many bytes
            via ``resource.setrlimit``.

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """

    if maximum_memory_bytes is not None:
        import resource

        resource.setrlimit(
            resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
        )
        resource.setrlimit(
            resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
        )
        if not platform.uname().system == "Darwin":
            resource.setrlimit(
                resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
            )

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None

    import os

    os.environ["OMP_NUM_THREADS"] = "1"

    os.kill = None
    os.system = None
    os.putenv = None
    os.remove = None
    os.removedirs = None
    os.rmdir = None
    os.fchdir = None
    os.setuid = None
    os.fork = None
    os.forkpty = None
    os.killpg = None
    os.rename = None
    os.renames = None
    os.truncate = None
    os.replace = None
    os.unlink = None
    os.fchmod = None
    os.fchown = None
    os.chmod = None
    os.chown = None
    os.chroot = None
    os.lchflags = None
    os.lchmod = None
    os.lchown = None
    os.getcwd = None
    os.chdir = None

    import shutil

    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    import subprocess

    subprocess.Popen = None  # type: ignore

    # ``__builtins__`` is a dict only when this file is imported as a module;
    # under ``__main__`` it is the builtins module and cannot be subscripted.
    builtins.help = None

    import sys

    sys.modules["ipdb"] = None
    sys.modules["joblib"] = None
    sys.modules["resource"] = None
    sys.modules["psutil"] = None
    sys.modules["tkinter"] = None