import contextlib
import multiprocessing
import os
import shutil
import tempfile
import zipfile


def _get_active_spark_session():
    try:
        from pyspark.sql import SparkSession
    except ImportError:
        # Return None if user doesn't have PySpark installed
        return None
    try:
        # getActiveSession() only exists in Spark 3.0 and above
        return SparkSession.getActiveSession()
    except Exception:
        # Fall back to this internal field for Spark 2.x and below.
        return SparkSession._instantiatedSession

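# Illustrative usage sketch (variable names are for illustration only): callers are expected to
# handle the case where no session is active or PySpark is not installed, e.g.:
#
#     spark = _get_active_spark_session()
#     if spark is not None:
#         print(spark.version)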

# Suppose a parent process has already initiated a Spark session connected to a Spark cluster,
# and that parent process then spawns a child process. If the child process directly creates a
# local Spark session, it will not work correctly, because PYSPARK_GATEWAY_PORT and
# PYSPARK_GATEWAY_SECRET are inherited from the parent process, so the child's PySpark session
# tries to connect to that gateway port and fails.
# The two `os.environ.pop` calls below clear 'PYSPARK_GATEWAY_PORT' and 'PYSPARK_GATEWAY_SECRET'
# to force launching a new PySpark JVM gateway.
def _prepare_subprocess_environ_for_creating_local_spark_session():
    from mlflow.utils.databricks_utils import is_in_databricks_runtime

    if is_in_databricks_runtime():
        os.environ["SPARK_DIST_CLASSPATH"] = "/databricks/jars/*"

    os.environ.pop("PYSPARK_GATEWAY_PORT", None)
    os.environ.pop("PYSPARK_GATEWAY_SECRET", None)

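# A minimal sketch of the intended call pattern inside a spawned child process (see
# _get_spark_scala_version_child_proc_target below for an in-module call site):
#
#     from pyspark.sql import SparkSession
#
#     _prepare_subprocess_environ_for_creating_local_spark_session()
#     spark = SparkSession.builder.master("local[1]").getOrCreate()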

def _get_spark_scala_version_from_spark_session(spark):
    version = spark._jvm.scala.util.Properties.versionNumberString().split(".", 2)
    return f"{version[0]}.{version[1]}"

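# Example (illustrative values): for a JVM Scala version string such as "2.12.15", the helper
# above returns the "major.minor" form "2.12", which matches the Scala suffix used in Spark
# package coordinates (e.g. "spark-avro_2.12").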

def _get_spark_scala_version_child_proc_target(result_queue):
    from pyspark.sql import SparkSession

    _prepare_subprocess_environ_for_creating_local_spark_session()
    with SparkSession.builder.master("local[1]").getOrCreate() as spark_session:
        scala_version = _get_spark_scala_version_from_spark_session(spark_session)
        result_queue.put(scala_version)


def _get_spark_scala_version():
    from mlflow.utils.databricks_utils import is_in_databricks_runtime

    if is_in_databricks_runtime() and "SPARK_SCALA_VERSION" in os.environ:
        return os.environ["SPARK_SCALA_VERSION"]

    if spark := _get_active_spark_session():
        return _get_spark_scala_version_from_spark_session(spark)

    result_queue = multiprocessing.Queue()

    # If we need to create a new local Spark session just to read the Scala version, we have to
    # create that temporary session in a child process: if we created it in the current process,
    # then after stopping it, creating another Spark session with the "spark.jars.packages"
    # configuration would not work.
    proc = multiprocessing.Process(
        target=_get_spark_scala_version_child_proc_target, args=(result_queue,)
    )
    proc.start()
    proc.join()
    if proc.exitcode != 0:
        raise RuntimeError("Failed to read scala version.")

    return result_queue.get()

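# Illustrative sketch of how the result is typically consumed (the package coordinate below is
# hypothetical): the "major.minor" Scala version is interpolated into Scala-suffixed Maven
# coordinates passed via "spark.jars.packages", e.g.:
#
#     scala_version = _get_spark_scala_version()
#     packages = f"com.example:some-spark-package_{scala_version}:1.0.0"
#     spark = SparkSession.builder.config("spark.jars.packages", packages).getOrCreate()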

def _create_local_spark_session_for_loading_spark_model():
    from pyspark.sql import SparkSession

    return (
        SparkSession.builder
        .config("spark.python.worker.reuse", "true")
        # The config is a workaround for avoiding a Databricks delta cache issue when loading
        # certain models such as ALSModel.
        .config("spark.databricks.io.cache.enabled", "false")
        # In Spark 3.1 and above, we need to set this conf explicitly to enable creating
        # a SparkSession on the workers.
        .config("spark.executor.allowSparkContext", "true")
        # Binding "spark.driver.host" to 127.0.0.1 helps avoid some local-hostname-related
        # issues (e.g. https://github.com/mlflow/mlflow/issues/5733).
        # Note that we should set "spark.driver.host" instead of "spark.driver.bindAddress":
        # the latter only sets the server binding host; it doesn't set the client-side request
        # destination host.
        .config("spark.driver.host", "127.0.0.1")
        .config(
            "spark.driver.extraJavaOptions",
            "-Dlog4j.configuration=file:/usr/local/spark/conf/log4j.properties",
        )
        .master("local[1]")
        .getOrCreate()
    )

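# A minimal usage sketch (the model path below is hypothetical): create a throwaway local
# session, load a Spark MLlib model with it, then stop the session, e.g.:
#
#     from pyspark.ml.pipeline import PipelineModel
#
#     spark = _create_local_spark_session_for_loading_spark_model()
#     try:
#         model = PipelineModel.load("/tmp/sparkml_model")
#     finally:
#         spark.stop()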

_NFS_PATH_PREFIX = "nfs:"


def _get_spark_distributor_nfs_cache_dir():
    from mlflow.utils.nfs_on_spark import get_nfs_cache_root_dir  # avoid circular import

    if (nfs_root_dir := get_nfs_cache_root_dir()) is not None:
        cache_dir = os.path.join(nfs_root_dir, "mlflow_distributor_cache_dir")
        os.makedirs(cache_dir, exist_ok=True)
        return cache_dir
    return None


class _SparkDirectoryDistributor:
    """Distribute a local directory from the Spark driver to executors."""

    _extracted_dir_paths = {}

    def __init__(self):
        pass

    @staticmethod
    def add_dir(spark, dir_path):
        """Given a SparkSession and a ``dir_path`` referring to a local pyfunc directory,
        zip the directory up, make it available for distribution to executors, and return
        the "archive_path", which should be passed to get_or_extract().
        """
        _, archive_basepath = tempfile.mkstemp()
        # NB: We must archive the directory as SparkContext.addFile does not support non-DFS
        # directories when recursive=True.
        archive_path = shutil.make_archive(archive_basepath, "zip", dir_path)

        if (nfs_cache_dir := _get_spark_distributor_nfs_cache_dir()) is not None:
            # If an NFS directory (shared by all spark nodes) is available, use it instead of
            # `SparkContext.addFile` to distribute files, because `SparkContext.addFile` is not
            # secure and therefore is not allowed to be called on a shared cluster.
            dest_path = os.path.join(nfs_cache_dir, os.path.basename(archive_path))
            shutil.copy(archive_path, dest_path)
            return _NFS_PATH_PREFIX + dest_path

        spark.sparkContext.addFile(archive_path)
        return archive_path

    @staticmethod
    def get_or_extract(archive_path):
        """Given an archive path returned by add_dir(), extract the archive into a local
        temporary directory and return that directory's path.
        If this Python process has already extracted the archive, the existing extraction
        is reused.
        """
        from pyspark.files import SparkFiles

        if archive_path in _SparkDirectoryDistributor._extracted_dir_paths:
            return _SparkDirectoryDistributor._extracted_dir_paths[archive_path]

        # BUG: Despite the documentation of SparkContext.addFile() and SparkFiles.get() in Scala
        # and Python, it turns out that we actually need to use the basename as the input to
        # SparkFiles.get(), as opposed to the (absolute) path.
        if archive_path.startswith(_NFS_PATH_PREFIX):
            local_path = archive_path[len(_NFS_PATH_PREFIX) :]
        else:
            archive_path_basename = os.path.basename(archive_path)
            local_path = SparkFiles.get(archive_path_basename)
        temp_dir = tempfile.mkdtemp()
        with zipfile.ZipFile(local_path, "r") as zip_ref:
            zip_ref.extractall(temp_dir)

        _SparkDirectoryDistributor._extracted_dir_paths[archive_path] = temp_dir
        return _SparkDirectoryDistributor._extracted_dir_paths[archive_path]

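# Illustrative sketch of the intended driver/executor flow (the directory path below is
# hypothetical):
#
#     # On the driver:
#     archive_path = _SparkDirectoryDistributor.add_dir(spark, "/tmp/my_model_dir")
#
#     # Inside a task running on an executor (e.g. within a pandas UDF), where `archive_path`
#     # was captured by the closure:
#     local_model_dir = _SparkDirectoryDistributor.get_or_extract(archive_path)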

@contextlib.contextmanager
def modified_environ(update):
    """Temporarily updates the ``os.environ`` dictionary in place.

    ``os.environ`` is updated in place (rather than replaced) so that the modification is
    visible in all situations; the original values are restored when the context manager exits.

    Args:
        update: Dictionary of environment variables and values to add/update.
    """
    update = update or {}
    original_env = {k: os.environ.get(k) for k in update}

    try:
        os.environ.update(update)
        yield
    finally:
        for k, v in original_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
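

# Illustrative usage sketch (the variable name and value are hypothetical):
#
#     with modified_environ({"PYSPARK_PYTHON": "/usr/bin/python3"}):
#         ...  # code here sees the temporary value
#     # On exit, the previous value (or absence) of "PYSPARK_PYTHON" is restored.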