utils.py
  1  # The MIT License
  2  #
  3  # Copyright (c) OpenAI (https://openai.com)
  4  #
  5  # Permission is hereby granted, free of charge, to any person obtaining a copy
  6  # of this software and associated documentation files (the "Software"), to deal
  7  # in the Software without restriction, including without limitation the rights
  8  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9  # copies of the Software, and to permit persons to whom the Software is
 10  # furnished to do so, subject to the following conditions:
 11  #
 12  # The above copyright notice and this permission notice shall be included in
 13  # all copies or substantial portions of the Software.
 14  #
 15  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 16  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 17  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 18  # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 19  # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 20  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 21  # THE SOFTWARE.
 22  
 23  import contextlib
 24  import faulthandler
 25  import io
 26  import os
 27  import platform
 28  import signal
 29  import tempfile
 30  import subprocess
 31  import multiprocessing
 32  from typing import Optional
 33  
 34  TIMEOUT_LIMIT=240.0
 35  
 36  @contextlib.contextmanager
 37  def swallow_subprocess_output():
 38      """Context manager to swallow stdout and stderr for subprocesses."""
 39      original_popen = subprocess.Popen
 40      original_run = subprocess.run
 41  
 42      def _popen_patch(*args, **kwargs):
 43          if 'capture_output' in kwargs and kwargs['capture_output']:
 44              # Avoid setting stdout or stderr if capture_output is True
 45              kwargs.pop('stdout', None)
 46              kwargs.pop('stderr', None)
 47          else:
 48              kwargs.setdefault('stdout', subprocess.PIPE)
 49              kwargs.setdefault('stderr', subprocess.PIPE)
 50          return original_popen(*args, **kwargs)
 51  
 52      def _run_patch(*args, **kwargs):
 53          if 'capture_output' in kwargs and kwargs['capture_output']:
 54              # Avoid setting stdout or stderr if capture_output is True
 55              kwargs.pop('stdout', None)
 56              kwargs.pop('stderr', None)
 57          else:
 58              kwargs.setdefault('stdout', subprocess.PIPE)
 59              kwargs.setdefault('stderr', subprocess.PIPE)
 60          return original_run(*args, **kwargs)
 61  
 62      subprocess.Popen = _popen_patch
 63      subprocess.run = _run_patch
 64      try:
 65          yield
 66      finally:
 67          subprocess.Popen = original_popen
 68          subprocess.run = original_run
 69  
 70  @contextlib.contextmanager
 71  def swallow_io():
 72      stream = WriteOnlyStringIO()
 73      with contextlib.redirect_stdout(stream):
 74          with contextlib.redirect_stderr(stream):
 75              with redirect_stdin(stream):
 76                  with swallow_subprocess_output():
 77                      yield
 78  
 79  
 80  @contextlib.contextmanager
 81  def time_limit(seconds: float):
 82      def signal_handler(signum, frame):
 83          raise TimeoutException("Timed out!")
 84  
 85      signal.setitimer(signal.ITIMER_REAL, seconds)
 86      signal.signal(signal.SIGALRM, signal_handler)
 87      try:
 88          yield
 89      finally:
 90          signal.setitimer(signal.ITIMER_REAL, 0)
 91  
 92  
 93  @contextlib.contextmanager
 94  def create_tempdir():
 95      with tempfile.TemporaryDirectory() as dirname:
 96          with chdir(dirname):
 97              yield dirname
 98  
 99  
@contextlib.contextmanager
def chdir(root):
    """Temporarily change the process's working directory to *root*.

    The previous working directory is restored on exit, whether the block
    finishes normally or raises. ``"."`` is a no-op: neither ``os.getcwd``
    nor ``os.chdir`` is called.
    """
    if root == ".":
        yield
        return
    cwd = os.getcwd()
    # chdir outside the try: if it fails there is nothing to undo.
    os.chdir(root)
    try:
        yield
    finally:
        os.chdir(cwd)
113  
114  
@contextlib.contextmanager
def safe_environment():
    """Temporarily swap process-control APIs for guarded versions.

    While the context is active:

    * ``os.kill`` only signals this process or a PID recorded by the patched
      ``subprocess.Popen``. ``os.killpg`` only signals this process's group
      or the group of a recorded child that still exists. Anything else is
      refused with a printed message.
    * ``os.system``, ``subprocess.call`` and ``subprocess.run`` skip commands
      that mention ``kill``/``killall`` and report success. ``os.popen``
      answers such commands with the output of ``echo Intercepted``.
    * ``subprocess.check_output`` returns ``b""`` for commands that mention
      ``ps``.
    * ``subprocess.Popen`` starts each child in a new session
      (``os.setsid``), records its PID, turns ``communicate`` timeouts into
      ``(None, None)``, and routes ``kill``/``terminate`` through the guarded
      kill (both send ``SIGTERM``).
    * ``os.execv``/``os.execvp``/``os.execvpe`` only print and return ``None``.

    Every patched attribute is restored on exit, even if the block raises.

    NOTE(review): the command checks are plain ``in`` tests, so for string
    commands they match substrings (``'ps'`` also matches ``"https"``). This
    is a best-effort guard, not a security boundary.
    """
    # Save original functions
    original_kill = os.kill
    original_killpg = os.killpg
    original_system = os.system
    original_subprocess_call = subprocess.call
    original_subprocess_check_output = subprocess.check_output
    original_subprocess_run = subprocess.run
    original_subprocess_popen = subprocess.Popen
    original_os_popen = os.popen
    original_os_execv = os.execv
    original_os_execvp = os.execvp
    original_os_execvpe = os.execvpe

    current_pid = os.getpid()
    current_pgid = os.getpgid(current_pid)
    # Manager-backed list, presumably so PIDs recorded from other processes
    # are visible here too — confirm with the callers.
    manager = multiprocessing.Manager()
    child_pids = manager.list()

    def safe_kill(pid, sig):
        try:
            # Doubles as an existence check: raises ProcessLookupError for a
            # PID that does not exist.
            os.getpgid(pid)
            if pid == current_pid or pid in child_pids:
                print(f"Allowed to kill PID {pid} with signal {sig}")
                original_kill(pid, sig)
            else:
                print(f"Prevented attempt to kill PID {pid} with signal {sig}")
        except ProcessLookupError:
            print(f"Process {pid} does not exist.")

    def _live_child_pgids():
        """Return the process-group IDs of recorded children that still exist."""
        pgids = set()
        for pid in list(child_pids):
            try:
                pgids.add(os.getpgid(pid))
            except (ProcessLookupError, PermissionError):
                # The child already exited and was reaped (or cannot be
                # queried); previously this aborted every killpg call.
                continue
        return pgids

    def safe_killpg(pgid, sig):
        if pgid == current_pgid or pgid in _live_child_pgids():
            print(f"Allowed to kill PGID {pgid} with signal {sig}")
            original_killpg(pgid, sig)
        else:
            print(f"Prevented attempt to kill PGID {pgid} with signal {sig}")

    def safe_system(command):
        print(f"Intercepted system command: {command}")
        if 'kill' in command or 'killall' in command:
            return 0  # Simulate successful execution without doing anything
        return original_system(command)

    def safe_subprocess_call(command, *args, **kwargs):
        print(f"Intercepted subprocess call: {command}")
        if 'kill' in command or 'killall' in command:
            return 0  # Simulate successful execution without doing anything
        return original_subprocess_call(command, *args, **kwargs)

    def safe_subprocess_check_output(command, *args, **kwargs):
        print(f"Intercepted command: {command}")
        if 'ps' in command:
            return b""  # Simulate no processes found
        return original_subprocess_check_output(command, *args, **kwargs)

    def safe_subprocess_run(*args, **kwargs):
        print(f"Intercepted subprocess run command: {args}")
        # The command may be passed positionally or as the ``args=`` keyword;
        # indexing ``args[0]`` alone crashed on the keyword form.
        command = args[0] if args else kwargs.get('args', ())
        if 'kill' in command or 'killall' in command:
            return subprocess.CompletedProcess(args, 0, b'', b'')  # Simulate successful execution
        return original_subprocess_run(*args, **kwargs)

    class SafePopen(subprocess.Popen):
        def __init__(self, *args, **kwargs):
            print(f"Intercepted Popen command: {args}")
            kwargs['preexec_fn'] = os.setsid  # Start the process in a new session
            super().__init__(*args, **kwargs)
            child_pids.append(self.pid)

        def communicate(self, *args, **kwargs):
            try:
                return super().communicate(*args, **kwargs)
            except subprocess.TimeoutExpired:
                print("Timeout expired, intercepted and returning None")
                return None, None

        def kill(self):
            print(f"Intercepted kill call for PID {self.pid}")
            safe_kill(self.pid, signal.SIGTERM)

        def terminate(self):
            print(f"Intercepted terminate call for PID {self.pid}")
            safe_kill(self.pid, signal.SIGTERM)

    def safe_os_popen(command):
        print(f"Intercepted os.popen command: {command}")
        if 'kill' in command or 'killall' in command:
            return os.popen('echo Intercepted')
        return original_os_popen(command)

    def safe_exec(*args, **kwargs):
        print(f"Intercepted exec command: {args}")

    # Override the risky functions with the safe versions
    os.kill = safe_kill
    os.killpg = safe_killpg
    os.system = safe_system
    subprocess.call = safe_subprocess_call
    subprocess.check_output = safe_subprocess_check_output
    subprocess.run = safe_subprocess_run
    subprocess.Popen = SafePopen
    os.popen = safe_os_popen
    os.execv = safe_exec
    os.execvp = safe_exec
    os.execvpe = safe_exec

    try:
        yield
    finally:
        # Restore original functions after the block
        os.kill = original_kill
        os.killpg = original_killpg
        os.system = original_system
        subprocess.call = original_subprocess_call
        subprocess.check_output = original_subprocess_check_output
        subprocess.run = original_subprocess_run
        subprocess.Popen = original_subprocess_popen
        os.popen = original_os_popen
        os.execv = original_os_execv
        os.execvp = original_os_execvp
        os.execvpe = original_os_execvpe
236  
237  
class TimeoutException(Exception):
    """Signals that a ``time_limit`` budget ran out before the block finished."""
240  
241  
class WriteOnlyStringIO(io.StringIO):
    """In-memory text sink that accepts writes but refuses every read.

    ``read``, ``readline`` and ``readlines`` all raise ``IOError``, and
    ``readable`` reports ``False``.
    """

    def _refuse_read(self, *args, **kwargs):
        raise IOError

    read = _refuse_read
    readline = _refuse_read
    readlines = _refuse_read

    def readable(self, *args, **kwargs):
        """Always ``False``: this stream must never be read from."""
        return False
257  
258  
class redirect_stdin(contextlib._RedirectStream):  # type: ignore
    """Context manager that temporarily replaces ``sys.stdin``.

    Sibling of ``contextlib.redirect_stdout``/``redirect_stderr``, built on
    the private ``contextlib._RedirectStream`` base class. Entering returns
    the new stream; exiting restores the previous one.
    """

    # Name of the ``sys`` attribute that the base class swaps.
    _stream = "stdin"
261  
262  
def reliability_guard(max_as_limit, max_data_limit, max_stack_limit):
    """
    Harden the current process before running generated code.

    Concretely, this:

    * pins the timezone to UTC (``TZ`` env var plus ``time.tzset()``);
    * sets ``OMP_NUM_THREADS=1``, ``TF_CPP_MIN_LOG_LEVEL=3`` and
      ``TF_ENABLE_ONEDNN_OPTS=0`` in the environment;
    * when all three limits are truthy, caps the address space, the data
      segment and (except on Darwin/macOS) the stack via ``resource.setrlimit``;
    * disables ``faulthandler``;
    * sets ``builtins.exit`` and ``builtins.quit`` to ``None``;
    * closes all matplotlib figures, if matplotlib is installed.

    Args:
        max_as_limit: Address-space limit, in MiB.
        max_data_limit: Data-segment limit, in MiB.
        max_stack_limit: Stack limit, in MiB.
            If any of the three is falsy (e.g. 0 or None), no resource
            limits are applied at all.

    Each limit is applied as both the soft and the hard limit. If the
    process's current hard limit is already lower than requested, that hard
    limit is used instead: an unprivileged process may not raise its hard
    limit, and ``setrlimit`` would raise ``ValueError``.

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """
    import time

    os.environ['TZ'] = 'UTC'
    time.tzset()

    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3"
    os.environ['TF_ENABLE_ONEDNN_OPTS'] = "0"

    if max_as_limit and max_data_limit and max_stack_limit:
        import resource

        def _cap(kind, limit_mib):
            """Set soft and hard *kind* limits to *limit_mib* MiB, clamped to the current hard limit."""
            limit = limit_mib * 1024 * 1024
            _soft, hard = resource.getrlimit(kind)
            if hard != resource.RLIM_INFINITY:
                limit = min(limit, hard)
            resource.setrlimit(kind, (limit, limit))

        _cap(resource.RLIMIT_AS, max_as_limit)
        _cap(resource.RLIMIT_DATA, max_data_limit)
        if not platform.uname().system == "Darwin":
            _cap(resource.RLIMIT_STACK, max_stack_limit)

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # matplotlib is optional; without it there are no figures to close.
        pass
    else:
        plt.close('all')