# File: src/liger_kernel/triton/monkey_patch.py
import os
import random
import shutil
import uuid

from triton.runtime.cache import FileCacheManager
 5  
 6  
 7  class LigerTritonFileCacheManager(FileCacheManager):
 8      def put(self, data, filename, binary=True) -> str:
 9          if not self.cache_dir:
10              raise RuntimeError("Could not create or locate cache dir")
11          binary = isinstance(data, bytes)
12          if not binary:
13              data = str(data)
14          assert self.lock_path is not None
15          filepath = self._make_path(filename)
16          # Random ID to avoid any collisions
17          rnd_id = random.randint(0, 1000000)
18          # we use the PID incase a bunch of these around so we can see what PID made it
19          pid = os.getpid()
20          # use temp dir to be robust against program interruptions
21          temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
22          os.makedirs(temp_dir, exist_ok=True)
23          temp_path = os.path.join(temp_dir, filename)
24  
25          mode = "wb" if binary else "w"
26          with open(temp_path, mode) as f:
27              f.write(data)
28          # Replace is guaranteed to be atomic on POSIX systems if it succeeds
29          # so filepath cannot see a partial write
30          os.replace(temp_path, filepath)
31          os.removedirs(temp_dir)
32          return filepath
33  
34  
def apply_liger_triton_cache_manager():
    """
    Select LigerTritonFileCacheManager as Triton's cache manager.

    This is an experimental workaround for transient FileNotFoundError failures
    during Triton compilation. It only sets the ``TRITON_CACHE_MANAGER``
    environment variable to ``"<module>:<ClassName>"`` pointing at the Liger
    class; nothing else is patched.

    NOTE(review): this presumably has to run before Triton first builds its
    cache manager in this process. Confirm against ``triton.runtime.cache``.

    Background: https://github.com/triton-lang/triton/pull/4295
    """
    manager_spec = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
    os.environ["TRITON_CACHE_MANAGER"] = manager_spec