testing_util.py
  1  import ast
  2  import json
  3  import sys
  4  import faulthandler
  5  import platform
  6  
  7  # used for debugging to time steps
  8  from datetime import datetime
  9  
 10  # to run the solution files we're using a timing based approach
 11  import signal
 12  
 13  import numpy as np
 14  
 15  # for capturing the stdout
 16  from io import StringIO
 17  
 18  # used for testing the code that reads from input
 19  from unittest.mock import patch, mock_open
 20  
 21  from pyext import RuntimeModule
 22  
 23  from enum import Enum
 24  
 25  
 26  class CODE_TYPE(Enum):
 27      call_based = 0
 28      standard_input = 1
 29  
 30  
 31  # stuff for setting up signal timer
 32  class TimeoutException(Exception):
 33      pass
 34  
 35  
 36  def timeout_handler(signum, frame):
 37      print("alarm went off")
 38      # return
 39      raise TimeoutException
 40  
 41  
# Install the handler when this module is imported, so that every expiring
# signal.alarm() raises TimeoutException in the code that is running.
signal.signal(signal.SIGALRM, timeout_handler)
 43  # timeout = 6  # seconds
 44  
 45  
 46  # used to capture stdout as a list
 47  # from https://stackoverflow.com/a/16571630/6416660
 48  # alternative use redirect_stdout() from contextlib
 49  class Capturing(list):
 50      def __enter__(self):
 51          self._stdout = sys.stdout
 52          sys.stdout = self._stringio = StringIO()
 53          # Make closing the StringIO a no-op
 54          self._stringio.close = lambda x: 1
 55          return self
 56  
 57      def __exit__(self, *args):
 58          self.extend(self._stringio.getvalue().splitlines())
 59          del self._stringio  # free up some memory
 60          sys.stdout = self._stdout
 61  
 62  
 63  def only_int_check(val):
 64      return isinstance(val, int)
 65  
 66  
 67  def string_int_check(val):
 68      return isinstance(val, str) and val.isdigit()
 69  
 70  
 71  def combined_int_check(val):
 72      return only_int_check(val) or string_int_check(val)
 73  
 74  
 75  def run_test(sample, test=None, debug=False, timeout=6):
 76      """
 77      if test(generated_code) is not None it'll try to run the code.
 78      otherwise it'll just return an input and output pair.
 79      """
 80      # Disable functionalities that can make destructive changes to the test.
 81      reliability_guard()
 82  
 83      if debug:
 84          print(f"start = {datetime.now().time()}")
 85  
 86      try:
 87          in_outs = json.loads(sample["input_output"])
 88      except ValueError:
 89          in_outs = None
 90      if in_outs:
 91          if in_outs.get("fn_name") is None:
 92              which_type = CODE_TYPE.standard_input  # Standard input
 93              method_name = None
 94          else:
 95              which_type = CODE_TYPE.call_based  # Call-based
 96              method_name = in_outs["fn_name"]
 97  
 98      if debug:
 99          print(f"loaded input_output = {datetime.now().time()}")
100  
101      if test is None:
102          return in_outs
103      elif test is not None:
104          results = []
105          sol = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(6*10**5)\n"
106          if debug:
107              print(f"loading test code = {datetime.now().time()}")
108  
109          if which_type == CODE_TYPE.call_based:
110  
111              sol += test
112              if debug:
113                  print(f"sol = {sol}")
114              signal.alarm(timeout)
115              try:
116                  tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
117                  if "class Solution" not in test:
118                      tmp = tmp_sol
119                  else:
120                      tmp = tmp_sol.Solution()
121                  signal.alarm(0)
122              except Exception as e:
123                  signal.alarm(0)
124                  if debug:
125                      print(f"type 0 compilation error = {e}")
126                  results.append(-2)
127                  return results
128              signal.alarm(0)
129  
130          elif which_type == CODE_TYPE.standard_input:
131              # sol
132              # if code has if __name__ == "__main__": then remove it
133              try:
134                  astree = ast.parse(test)
135                  last_block = astree.body[-1]
136                  if isinstance(last_block, ast.If):
137                      condition = last_block.test
138                      if ast.unparse(condition).strip() == "__name__ == '__main__'":
139                          test = (
140                              ast.unparse(astree.body[:-1])
141                              + "\n"
142                              + ast.unparse(last_block.body)
143                          )
144              except:
145                  pass
146  
147              tmp_test = test.split("\n")
148  
149              new_test = []
150              for x in tmp_test:
151                  if (not x.startswith("from ")) and (not x.startswith("import ")):
152                      new_test.append("\t" + x + "\n")
153                  else:
154                      new_test.append(x + "\n")
155              tmp_test = new_test
156  
157              new_test = ""
158              started = False
159              for i in tmp_test:
160                  if i.startswith("\t") and not started:
161                      new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
162                      new_test += "def code():\n"
163                      new_test += i
164                      started = True
165                  elif started and ((i.startswith("from ")) or (i.startswith("import "))):
166                      new_test += "\t" + i
167                  else:
168                      new_test += i
169              tmp_test = new_test
170  
171              sol += tmp_test
172              if debug:
173                  print(f"sol = {sol}")
174              method_name = "code"
175              signal.alarm(timeout)
176              try:
177                  tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
178                  tmp = tmp_sol
179                  signal.alarm(0)
180              except Exception as e:
181                  signal.alarm(0)
182                  if debug:
183                      print(f"type 1 compilation error = {e}")
184                  results.append(-2)
185                  return results
186              signal.alarm(0)
187          if debug:
188              print(f"get method = {datetime.now().time()}")
189  
190          try:
191              method = getattr(tmp, method_name)  # get_attr second arg must be str
192          except:
193              signal.alarm(0)
194              e = sys.exc_info()
195              print(f"unable to get function error = {e}")
196              results.append(-2)
197              return results
198  
199          for index, inputs in enumerate(in_outs["inputs"]):
200              if which_type == CODE_TYPE.call_based:
201                  inputs = [json.loads(line) for line in inputs.split("\n")]
202                  in_outs["outputs"][index] = json.loads(in_outs["outputs"][index])
203              # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)
204              try:
205                  if isinstance(inputs[0], dict):
206                      inputs = [{int(k): v for k, v in inputs[0].items()}]
207              except:
208                  True
209              try:
210                  if isinstance(in_outs["outputs"][index], dict):
211                      in_outs["outputs"][index] = [
212                          {int(k): v for k, v in in_outs["outputs"][index].items()}
213                      ]
214              except:
215                  True
216              try:
217                  if isinstance(in_outs["outputs"][index][0], dict):
218                      in_outs["outputs"][index] = [
219                          {int(k): v for k, v in in_outs["outputs"][index][0].items()}
220                      ]
221              except:
222                  True
223  
224              if debug:
225                  print(
226                      f"time: {datetime.now().time()} testing index = {index}  inputs = {inputs}, {type(inputs)}. type = {which_type}"
227                  )
228              if which_type == CODE_TYPE.call_based:  # Call-based
229                  signal.alarm(timeout)
230                  faulthandler.enable()
231                  try:
232                      output = method(*inputs)
233  
234                      # ground truth sequences are not tuples
235                      if isinstance(output, tuple):
236                          output = list(output)
237  
238                      tmp_result = output == in_outs["outputs"][index]
239                      if (
240                          isinstance(in_outs["outputs"][index], list)
241                          and in_outs["outputs"][index]
242                      ):
243                          tmp_result = tmp_result or (
244                              output == in_outs["outputs"][index][0]
245                          )
246  
247                      # ground truth sequences are not tuples
248                      try:
249                          if isinstance(output[0], tuple):
250                              tmp_result = tmp_result or (
251                                  [list(x) for x in output]
252                                  == in_outs["outputs"][index][0]
253                              )
254                      except:
255                          True
256                      results.append(tmp_result)
257                      if tmp_result != True:
258                          return results
259                      # reset the alarm
260                      signal.alarm(0)
261                  except Exception as e:
262                      signal.alarm(0)
263                      faulthandler.disable()
264                      if debug:
265                          print(
266                              f"Standard input runtime error or time limit exceeded error = {e}"
267                          )
268                      results.append(-1)
269                      return results
270                  faulthandler.disable()
271                  signal.alarm(0)
272                  if debug:
273                      print(
274                          f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
275                      )
276              elif which_type == CODE_TYPE.standard_input:  # Standard input
277                  faulthandler.enable()
278                  passed = False
279  
280                  if isinstance(inputs, list):
281                      inputs = "\n".join(inputs)
282                  if isinstance(in_outs["outputs"][index], list):
283                      in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index])
284  
285                  signal.alarm(timeout)
286                  with Capturing() as output:
287                      try:
288                          call_method(method, inputs)
289                          # reset the alarm
290                          signal.alarm(0)
291                          passed = True
292                      except Exception as e:
293                          # runtime error or took too long
294                          signal.alarm(0)
295                          print(
296                              f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}"
297                          )
298                          results.append(-1)
299                          return results
300                      signal.alarm(0)
301  
302                  if not passed:
303                      if debug:
304                          nl = "\n"
305                          if not isinstance(inputs, list):
306                              print(
307                                  f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
308                              )
309                          else:
310                              print(
311                                  f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
312                              )
313                      continue
314  
315                  if passed and debug:
316                      print(
317                          f"==> output = {output}, test outputs = {in_outs['outputs'][index]}"
318                      )
319  
320                  if custom_compare_(output, in_outs["outputs"][index]):
321                      tmp_result = True
322                      results.append(tmp_result)
323                      continue
324  
325                  # ground truth sequences are expressed as lists not tuples
326                  if isinstance(output, tuple):
327                      output = list(output)
328  
329                  tmp_result = False
330                  try:
331                      tmp_result = output == [in_outs["outputs"][index]]
332                      if isinstance(in_outs["outputs"][index], list):
333                          tmp_result = tmp_result or (output == in_outs["outputs"][index])
334                          if isinstance(output[0], str):
335                              tmp_result = tmp_result or (
336                                  [e.strip() for e in output] == in_outs["outputs"][index]
337                              )
338                  except Exception as e:
339                      if debug:
340                          print(f"Failed check1 exception = {e}")
341                      pass
342  
343                  if tmp_result == True:
344                      results.append(tmp_result)
345                      continue
346  
347                  # try one more time without \n
348                  if isinstance(in_outs["outputs"][index], list):
349                      for tmp_index, i in enumerate(in_outs["outputs"][index]):
350                          in_outs["outputs"][index][tmp_index] = i.split("\n")
351                          in_outs["outputs"][index][tmp_index] = [
352                              x.strip() for x in in_outs["outputs"][index][tmp_index] if x
353                          ]
354                  else:
355                      in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
356                      in_outs["outputs"][index] = list(
357                          filter(len, in_outs["outputs"][index])
358                      )
359                      in_outs["outputs"][index] = list(
360                          map(lambda x: x.strip(), in_outs["outputs"][index])
361                      )
362  
363                  try:
364                      tmp_result = output == [in_outs["outputs"][index]]
365                      if isinstance(in_outs["outputs"][index], list):
366                          tmp_result = tmp_result or (output == in_outs["outputs"][index])
367                  except Exception as e:
368                      if debug:
369                          print(f"Failed check2 exception = {e}")
370                      pass
371  
372                  if tmp_result == True:
373                      results.append(tmp_result)
374                      continue
375  
376                  # try by converting the output into a split up list too
377                  if isinstance(output, list):
378                      output = list(filter(len, output))
379  
380                  if debug:
381                      nl = "\n"
382                      if not isinstance(inputs, list):
383                          print(
384                              f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
385                          )
386                      else:
387                          print(
388                              f"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}"
389                          )
390  
391                  if tmp_result == True:
392                      results.append(tmp_result)
393                      continue
394  
395                  if debug:
396                      print(f"{tmp_result=} @a")
397  
398                  try:
399                      tmp_result = output == [in_outs["outputs"][index]]
400                      if isinstance(in_outs["outputs"][index], list):
401                          tmp_result = tmp_result or (output == in_outs["outputs"][index])
402                  except Exception as e:
403                      if debug:
404                          print(f"Failed check3 exception = {e}")
405                      pass
406  
407                  if debug:
408                      print(f"{tmp_result=} @b")
409  
410                  try:
411                      all_ints = all(
412                          combined_int_check(e1) and combined_int_check(e2)
413                          for e1, e2 in zip(output, in_outs["outputs"][index])
414                      )
415                      if not all_ints:
416                          if debug:
417                              print(
418                                  [
419                                      combined_int_check(e1) and combined_int_check(e2)
420                                      for e1, e2 in zip(output, in_outs["outputs"][index])
421                                  ]
422                              )
423                          output_float = [float(e) for e in output]
424                          gt_float = [float(e) for e in in_outs["outputs"][index]]
425                          tmp_result = tmp_result or (
426                              (len(output_float) == len(gt_float))
427                              and np.allclose(output_float, gt_float)
428                          )
429                  except Exception as e:
430                      pass
431  
432                  if debug:
433                      print(f"{tmp_result=} @c")
434  
435                  try:
436                      if isinstance(output[0], list):
437                          all_ints = all(
438                              combined_int_check(e1) and combined_int_check(e2)
439                              for e1, e2 in zip(output[0], in_outs["outputs"][index])
440                          )
441                          if not all_ints:
442                              output_float = [float(e) for e in output[0]]
443                              gt_float = [float(e) for e in in_outs["outputs"][index][0]]
444                              tmp_result = tmp_result or (
445                                  (len(output_float) == len(gt_float))
446                                  and np.allclose(output_float, gt_float)
447                              )
448                  except Exception as e:
449                      pass
450  
451                  if tmp_result == True:
452                      results.append(tmp_result)
453                      continue
454  
455                  if debug:
456                      print(f"{tmp_result=} @d")
457                  # try by converting the stuff into split up list
458                  if isinstance(in_outs["outputs"][index], list):
459                      for tmp_index, i in enumerate(in_outs["outputs"][index]):
460                          in_outs["outputs"][index][tmp_index] = set(i.split())
461                  else:
462                      in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
463  
464                  if debug:
465                      print(f"{tmp_result=} @e")
466  
467                  try:
468                      tmp_result = output == in_outs["outputs"][index]
469                  except Exception as e:
470                      if debug:
471                          print(f"Failed check4 exception = {e}")
472                      continue
473  
474                  if tmp_result == True:
475                      results.append(tmp_result)
476                      continue
477  
478                  if debug:
479                      print(f"{tmp_result=} @f")
480  
481                  # try by converting the output into a split up list too
482                  if isinstance(output, list):
483                      for tmp_index, i in enumerate(output):
484                          output[tmp_index] = i.split()
485                      output = list(filter(len, output))
486                      for tmp_index, i in enumerate(output):
487                          output[tmp_index] = set(i)
488                  else:
489                      output = output.split()
490                      output = list(filter(len, output))
491                      output = set(output)
492  
493                  if debug:
494                      print(f"{tmp_result=} @g")
495                  # try:
496                  #     tmp_result = set(frozenset(s) for s in output) == set(
497                  #         frozenset(s) for s in in_outs["outputs"][index]
498                  #     )
499                  # except Exception as e:
500                  #     if debug:
501                  #         print(f"Failed check5 exception = {e}")
502  
503                  # if they are all numbers, round so that similar numbers are treated as identical
504                  # try:
505                  #     all_ints = all(
506                  #         combined_int_check(e1) and combined_int_check(e2)
507                  #         for e1, e2 in zip(output, in_outs["outputs"][index])
508                  #     )
509                  #     tmp_result = tmp_result or (
510                  #         set(frozenset(round(float(t), 3) for t in s) for s in output)
511                  #         == set(
512                  #             frozenset(round(float(t), 3) for t in s)
513                  #             for s in in_outs["outputs"][index]
514                  #         )
515                  #     )
516                  # except Exception as e:
517                  #     if debug:
518                  #         print(f"Failed check6 exception = {e}")
519  
520                  if debug:
521                      print(f"{tmp_result=} @h")
522  
523                  if tmp_result == True and debug:
524                      print("PASSED")
525  
526                  results.append(tmp_result)
527                  if tmp_result != True:
528                      return results
529  
530                  if debug:
531                      nl = "\n"
532                      if not isinstance(inputs, list):
533                          print(
534                              f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
535                          )
536                      else:
537                          print(
538                              f"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}"
539                          )
540  
541                      print(f"results = {results}")
542  
543      return results
544  
545  
def custom_compare_(output, ground_truth):
    """Compare captured output lines with the expected text, ignoring padding.

    ``output`` only matches when it is a list. Its lines are joined with
    newlines and compared to ``ground_truth`` with surrounding whitespace
    removed; if that fails, the comparison is repeated with every line
    stripped individually first. Returns ``True`` on a match, else ``False``.
    """
    if not isinstance(output, list):
        return False

    # Attempt 1: the lines exactly as captured.
    joined_raw = "\n".join(output)
    if stripped_string_compare(joined_raw, ground_truth):
        return True

    # Attempt 2: strip leading/trailing whitespace from each line.
    trimmed_lines = [line.strip() for line in output]
    joined_trimmed = "\n".join(trimmed_lines)
    if stripped_string_compare(joined_trimmed, ground_truth):
        return True

    return False
560  
561  
def stripped_string_compare(s1, s2):
    """Return whether two strings are equal once surrounding whitespace is removed."""
    return s1.strip() == s2.strip()
566  
567  
def call_method(method, inputs):
    """Call ``method()`` with ``inputs`` standing in for standard input.

    Args:
        method: Zero-argument callable to run.
        inputs: The input text, either as one string or as a list of lines
            (joined with newlines).

    Returns:
        Whatever ``method()`` returns, or ``None`` when it raises
        ``SystemExit`` (which is swallowed so that a script calling
        ``exit()`` does not count as a failure).

    While ``method`` runs, ``builtins.open`` is replaced by a ``mock_open``
    that yields ``inputs``, ``sys.stdin`` is replaced by a ``StringIO`` over
    ``inputs``, and ``read``/``readline``/``readlines`` are patched to serve
    ``inputs`` as well. All patches are undone when the call finishes.
    """
    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    inputs_line_iterator = iter(inputs.split("\n"))

    # NOTE(review): unittest.mock applies stacked patch decorators bottom-up,
    # so the three method patches presumably land on the object that is
    # sys.stdin when call_method is entered (the one a solution may have
    # bound earlier as ``stdin``), before sys.stdin itself is swapped for
    # the StringIO. Keep the decorator order as is.
    @patch("builtins.open", mock_open(read_data=inputs))
    @patch("sys.stdin", StringIO(inputs))
    # Return "" once the input is exhausted, like a real stream at EOF,
    # instead of leaking StopIteration into the code under test.
    @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator, ""))
    @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
    @patch("sys.stdin.read", lambda *args: inputs)
    def _inner_call_method(_method):
        try:
            return _method()
        except SystemExit:
            return None

    return _inner_call_method(method)
593  
594  
def reliability_guard(maximum_memory_bytes=None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)
    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.

    Args:
        maximum_memory_bytes: When not ``None``, used as both soft and hard
            limit for ``RLIMIT_AS`` and ``RLIMIT_DATA``, and for
            ``RLIMIT_STACK`` except on Darwin.

    The changes are process-wide and are not undone: the listed attributes of
    ``builtins``, ``os``, ``shutil`` and ``subprocess`` are set to ``None``,
    ``OMP_NUM_THREADS`` is set to ``"1"``, faulthandler is disabled, and the
    listed module names are mapped to ``None`` in ``sys.modules`` so that
    importing them fails.
    """

    if maximum_memory_bytes is not None:
        import resource

        limits = (maximum_memory_bytes, maximum_memory_bytes)
        resource.setrlimit(resource.RLIMIT_AS, limits)
        resource.setrlimit(resource.RLIMIT_DATA, limits)
        if not platform.uname().system == "Darwin":
            resource.setrlimit(resource.RLIMIT_STACK, limits)

    faulthandler.disable()

    import builtins
    import os
    import shutil
    import subprocess
    import sys

    # Use the ``builtins`` module rather than ``__builtins__``: the latter is
    # a dict in imported modules but a module in ``__main__``, where item
    # assignment would raise TypeError.
    for name in ("exit", "quit", "help"):
        setattr(builtins, name, None)

    # Keep this ahead of the os.* assignments below (os.putenv is one of them).
    os.environ["OMP_NUM_THREADS"] = "1"

    disabled_os_functions = (
        "kill",
        "system",
        "putenv",
        "remove",
        "removedirs",
        "rmdir",
        "fchdir",
        "setuid",
        "fork",
        "forkpty",
        "killpg",
        "rename",
        "renames",
        "truncate",
        "replace",
        "unlink",
        "fchmod",
        "fchown",
        "chmod",
        "chown",
        "chroot",
        "lchflags",
        "lchmod",
        "lchown",
        "getcwd",
        "chdir",
    )
    for name in disabled_os_functions:
        setattr(os, name, None)

    for name in ("rmtree", "move", "chown"):
        setattr(shutil, name, None)

    subprocess.Popen = None  # type: ignore

    # A ``None`` entry in sys.modules makes ``import <name>`` raise ImportError.
    for module_name in ("ipdb", "joblib", "resource", "psutil", "tkinter"):
        sys.modules[module_name] = None