# mlflow/pyfunc/spark_model_cache.py
 1  from mlflow.utils._spark_utils import _SparkDirectoryDistributor
 2  
 3  
 4  class SparkModelCache:
 5      """Caches models in memory on Spark Executors, to avoid continually reloading from disk.
 6  
 7      This class has to be part of a different module than the one that _uses_ it. This is
 8      because Spark will pickle classes that are defined in the local scope, but relies on
 9      Python's module loading behavior for classes in different modules. In this case, we
10      are relying on the fact that Python will load a module at-most-once, and can therefore
11      store per-process state in a static map.
12      """
13  
14      # Map from unique name --> (loaded model, local_model_path).
15      _models = {}
16  
17      # Number of cache hits we've had, for testing purposes.
18      _cache_hits = 0
19  
20      def __init__(self):
21          pass
22  
23      @staticmethod
24      def add_local_model(spark, model_path):
25          """Given a SparkSession and a model_path which refers to a pyfunc directory locally,
26          we will zip the directory up, enable it to be distributed to executors, and return
27          the "archive_path", which should be used as the path in get_or_load().
28          """
29          return _SparkDirectoryDistributor.add_dir(spark, model_path)
30  
31      @staticmethod
32      def get_or_load(archive_path):
33          """Given a path returned by add_local_model(), this method will return a tuple of
34          (loaded_model, local_model_path).
35          If this Python process ever loaded the model before, we will reuse that copy.
36          """
37          if archive_path in SparkModelCache._models:
38              SparkModelCache._cache_hits += 1
39              return SparkModelCache._models[archive_path]
40  
41          local_model_dir = _SparkDirectoryDistributor.get_or_extract(archive_path)
42  
43          # We must rely on a supposed cyclic import here because we want this behavior
44          # on the Spark Executors (i.e., don't try to pickle the load_model function).
45          from mlflow.pyfunc import load_model
46  
47          SparkModelCache._models[archive_path] = (load_model(local_model_dir), local_model_dir)
48          return SparkModelCache._models[archive_path]