# spark_model_cache.py
from mlflow.utils._spark_utils import _SparkDirectoryDistributor


class SparkModelCache:
    """Caches models in memory on Spark Executors, to avoid continually reloading from disk.

    This class has to be part of a different module than the one that _uses_ it. This is
    because Spark will pickle classes that are defined in the local scope, but relies on
    Python's module loading behavior for classes in different modules. In this case, we
    are relying on the fact that Python will load a module at-most-once, and can therefore
    store per-process state in a static map.

    NOTE(review): the cache is a plain class-level dict with no locking — assumes each
    executor process accesses it from a single thread at a time; confirm before using
    from multi-threaded task code.
    """

    # Map from unique archive name --> (loaded model, local_model_path).
    _models = {}

    # Number of cache hits we've had, exposed for testing purposes only.
    # Unsynchronized counter; approximate under concurrent access.
    _cache_hits = 0

    # No __init__ is defined: the class is a static, per-process namespace and the
    # implicit default constructor keeps `SparkModelCache()` working for any caller.

    @staticmethod
    def add_local_model(spark, model_path):
        """Given a SparkSession and a model_path which refers to a pyfunc directory locally,
        we will zip the directory up, enable it to be distributed to executors, and return
        the "archive_path", which should be used as the path in get_or_load().
        """
        return _SparkDirectoryDistributor.add_dir(spark, model_path)

    @staticmethod
    def get_or_load(archive_path):
        """Given a path returned by add_local_model(), this method will return a tuple of
        (loaded_model, local_model_path).
        If this Python process ever loaded the model before, we will reuse that copy.
        """
        # Fast path: this process already loaded the model for this archive.
        if archive_path in SparkModelCache._models:
            SparkModelCache._cache_hits += 1
            return SparkModelCache._models[archive_path]

        local_model_dir = _SparkDirectoryDistributor.get_or_extract(archive_path)

        # We must rely on a supposed cyclic import here because we want this behavior
        # on the Spark Executors (i.e., don't try to pickle the load_model function).
        from mlflow.pyfunc import load_model

        SparkModelCache._models[archive_path] = (load_model(local_model_dir), local_model_dir)
        return SparkModelCache._models[archive_path]