# monkey_patch.py
import os
import random
import shutil
import uuid

from triton.runtime.cache import FileCacheManager


class LigerTritonFileCacheManager(FileCacheManager):
    """Triton ``FileCacheManager`` whose ``put`` stages writes in a private temp dir.

    Each ``put`` writes the payload into its own uniquely named directory under
    the cache dir and then moves it into place with ``os.replace``. See
    ``apply_liger_triton_cache_manager`` for how this class is selected.
    """

    def put(self, data, filename, binary=True) -> str:
        """Write ``data`` as ``filename`` inside the cache dir and return its path.

        Args:
            data: Payload to store. ``bytes`` are written in binary mode; any
                other value is converted with ``str()`` and written in text mode.
            filename: Name of the file to create inside the cache directory.
            binary: Accepted for signature compatibility but ignored; the write
                mode is inferred from whether ``data`` is ``bytes``.

        Returns:
            The final path of the cached file (from ``self._make_path``).

        Raises:
            RuntimeError: If the manager has no cache directory.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        # NOTE(review): lock_path is not used below; this check presumably
        # mirrors the base-class implementation — TODO confirm.
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # uuid4 draws from os.urandom rather than the global ``random`` module:
        # the name stays unique even when user code has seeded ``random`` and
        # PIDs coincide (e.g. PID 1 in every container sharing a cache dir),
        # and the user's RNG stream is not advanced as a side effect.
        rnd_id = uuid.uuid4().hex
        # The PID stays in the name so leftover temp dirs can be traced to a process.
        pid = os.getpid()
        # Stage in a private temp dir to be robust against program interruptions.
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)
        try:
            mode = "wb" if binary else "w"
            with open(temp_path, mode) as f:
                f.write(data)
            # Replace is guaranteed to be atomic on POSIX systems if it succeeds
            # so filepath cannot see a partial write
            os.replace(temp_path, filepath)
        finally:
            # After a successful replace the temp dir is empty; after a failure it
            # may hold a partial file. Remove it either way so errors don't leak
            # stray directories into the cache.
            shutil.rmtree(temp_dir, ignore_errors=True)
        return filepath


def apply_liger_triton_cache_manager():
    """
    Experimental feature to get around transient FileNotFoundError in triton compilation.
    For more details please see https://github.com/triton-lang/triton/pull/4295

    Sets the ``TRITON_CACHE_MANAGER`` environment variable to point at
    ``LigerTritonFileCacheManager`` in this module. It only mutates
    ``os.environ``; presumably Triton reads the variable when it first builds a
    cache manager, so call this before any kernel is compiled — TODO confirm.
    """
    os.environ["TRITON_CACHE_MANAGER"] = (
        "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
    )