_spark_utils.py
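"""Internal Spark utilities used by MLflow: locating or creating Spark sessions,
detecting the cluster's Scala version, distributing local directories from the driver
to executors, and temporarily patching environment variables.
"""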
import contextlib
import multiprocessing
import os
import shutil
import tempfile
import zipfile


def _get_active_spark_session():
    try:
        from pyspark.sql import SparkSession
    except ImportError:
        # Return None if the user doesn't have PySpark installed
        return None
    try:
        # getActiveSession() only exists in Spark 3.0 and above
        return SparkSession.getActiveSession()
    except Exception:
        # Fall back to this internal field for Spark 2.x and below.
        return SparkSession._instantiatedSession


# Suppose a parent process has already initiated a spark session connected to a spark cluster
# and then spawns a child process. If the child process directly creates a local spark session,
# it does not work correctly, because PYSPARK_GATEWAY_PORT and PYSPARK_GATEWAY_SECRET are
# inherited from the parent process, so the child's pyspark session tries to connect to that
# port and fails.
# The two `pop` calls below clear 'PYSPARK_GATEWAY_PORT' and 'PYSPARK_GATEWAY_SECRET' to
# force launching a new pyspark JVM gateway.
def _prepare_subprocess_environ_for_creating_local_spark_session():
    from mlflow.utils.databricks_utils import is_in_databricks_runtime

    if is_in_databricks_runtime():
        os.environ["SPARK_DIST_CLASSPATH"] = "/databricks/jars/*"

    os.environ.pop("PYSPARK_GATEWAY_PORT", None)
    os.environ.pop("PYSPARK_GATEWAY_SECRET", None)


def _get_spark_scala_version_from_spark_session(spark):
    version = spark._jvm.scala.util.Properties.versionNumberString().split(".", 2)
    return f"{version[0]}.{version[1]}"


def _get_spark_scala_version_child_proc_target(result_queue):
    from pyspark.sql import SparkSession

    _prepare_subprocess_environ_for_creating_local_spark_session()
    with SparkSession.builder.master("local[1]").getOrCreate() as spark_session:
        scala_version = _get_spark_scala_version_from_spark_session(spark_session)
        result_queue.put(scala_version)


def _get_spark_scala_version():
    from mlflow.utils.databricks_utils import is_in_databricks_runtime

    if is_in_databricks_runtime() and "SPARK_SCALA_VERSION" in os.environ:
        return os.environ["SPARK_SCALA_VERSION"]

    if spark := _get_active_spark_session():
        return _get_spark_scala_version_from_spark_session(spark)

    result_queue = multiprocessing.Queue()

    # If we need to create a new local spark session just to read the scala version, we have
    # to create that temporary spark session in a child process: if we create it in the
    # current process, then after terminating it, creating another spark session with the
    # "spark.jars.packages" configuration doesn't work.
    proc = multiprocessing.Process(
        target=_get_spark_scala_version_child_proc_target, args=(result_queue,)
    )
    proc.start()
    proc.join()
    if proc.exitcode != 0:
        raise RuntimeError("Failed to read scala version.")

    return result_queue.get()


def _create_local_spark_session_for_loading_spark_model():
    from pyspark.sql import SparkSession

    return (
        SparkSession.builder.config("spark.python.worker.reuse", "true")
        # This config is a workaround for avoiding a databricks delta cache issue when loading
        # some specific models such as ALSModel.
        .config("spark.databricks.io.cache.enabled", "false")
        # In Spark 3.1 and above, we need to set this conf explicitly to enable creating
        # a SparkSession on the workers.
        .config("spark.executor.allowSparkContext", "true")
        # Binding "spark.driver.host" to 127.0.0.1 helps avoid some local hostname related
        # issues (e.g. https://github.com/mlflow/mlflow/issues/5733).
        # Note that we should set "spark.driver.host" instead of "spark.driver.bindAddress":
        # the latter only sets the server binding host; it doesn't set the destination host
        # for client-side requests.
        .config("spark.driver.host", "127.0.0.1")
        .config(
            "spark.driver.extraJavaOptions",
            "-Dlog4j.configuration=file:/usr/local/spark/conf/log4j.properties",
        )
        .master("local[1]")
        .getOrCreate()
    )
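# Illustrative sketch (not used by this module): a caller could combine the helpers above to
# build a Scala-version-suffixed package coordinate for "spark.jars.packages", which is the
# scenario the comment in _get_spark_scala_version refers to. The coordinate string below is
# purely hypothetical.
#
#   scala_version = _get_spark_scala_version()  # e.g. "2.12"
#   spark = (
#       SparkSession.builder
#       .config("spark.jars.packages", f"com.example:some-package_{scala_version}:1.0.0")
#       .getOrCreate()
#   )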
_NFS_PATH_PREFIX = "nfs:"


def _get_spark_distributor_nfs_cache_dir():
    from mlflow.utils.nfs_on_spark import get_nfs_cache_root_dir  # avoid circular import

    if (nfs_root_dir := get_nfs_cache_root_dir()) is not None:
        cache_dir = os.path.join(nfs_root_dir, "mlflow_distributor_cache_dir")
        os.makedirs(cache_dir, exist_ok=True)
        return cache_dir
    return None


class _SparkDirectoryDistributor:
    """Distribute a local directory from the Spark driver to executors."""

    _extracted_dir_paths = {}

    def __init__(self):
        pass

    @staticmethod
    def add_dir(spark, dir_path):
        """Given a SparkSession and a local directory path (e.g. a pyfunc model directory),
        zip the directory up, make it available to executors, and return the "archive_path",
        which should be used as the argument to get_or_extract().
        """
        _, archive_basepath = tempfile.mkstemp()
        # NB: We must archive the directory because SparkContext.addFile does not support
        # non-DFS directories when recursive=True.
        archive_path = shutil.make_archive(archive_basepath, "zip", dir_path)

        if (nfs_cache_dir := _get_spark_distributor_nfs_cache_dir()) is not None:
            # If an NFS directory (shared by all spark nodes) is available, use it instead of
            # `SparkContext.addFile` to distribute files, because `SparkContext.addFile` is
            # not secure and is therefore not allowed on shared clusters.
            dest_path = os.path.join(nfs_cache_dir, os.path.basename(archive_path))
            shutil.copy(archive_path, dest_path)
            return _NFS_PATH_PREFIX + dest_path

        spark.sparkContext.addFile(archive_path)
        return archive_path

    @staticmethod
    def get_or_extract(archive_path):
        """Given an archive_path returned by add_dir(), return the local directory the
        archive was extracted to.
        If this Python process has extracted the archive before, the cached directory is
        reused.
        """
        from pyspark.files import SparkFiles

        if archive_path in _SparkDirectoryDistributor._extracted_dir_paths:
            return _SparkDirectoryDistributor._extracted_dir_paths[archive_path]

        # BUG: Despite the documentation of SparkContext.addFile() and SparkFiles.get() in
        # Scala and Python, it turns out that we actually need to use the basename as the
        # input to SparkFiles.get(), as opposed to the (absolute) path.
        if archive_path.startswith(_NFS_PATH_PREFIX):
            local_path = archive_path[len(_NFS_PATH_PREFIX) :]
        else:
            archive_path_basename = os.path.basename(archive_path)
            local_path = SparkFiles.get(archive_path_basename)
        temp_dir = tempfile.mkdtemp()
        zip_ref = zipfile.ZipFile(local_path, "r")
        zip_ref.extractall(temp_dir)
        zip_ref.close()

        _SparkDirectoryDistributor._extracted_dir_paths[archive_path] = temp_dir
        return _SparkDirectoryDistributor._extracted_dir_paths[archive_path]
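# Illustrative sketch (assumed calling pattern; the paths and executor-side wiring are
# hypothetical): the driver archives a local model directory once, and executor-side code
# extracts it lazily, reusing the per-process cache in _extracted_dir_paths.
#
#   # On the driver:
#   archive_path = _SparkDirectoryDistributor.add_dir(spark, "/tmp/my_model")
#
#   # Inside an executor task (e.g. within a pandas UDF):
#   local_dir = _SparkDirectoryDistributor.get_or_extract(archive_path)
#   # ... load the model from `local_dir` ...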
@contextlib.contextmanager
def modified_environ(update):
    """Temporarily update the ``os.environ`` dictionary in place.

    ``os.environ`` is updated in place (rather than replaced) so that the modification is
    visible in all situations, and the original values are restored on exit.

    Args:
        update: Dictionary of environment variables and values to add/update.
    """
    update = update or {}
    original_env = {k: os.environ.get(k) for k in update}

    try:
        os.environ.update(update)
        yield
    finally:
        for k, v in original_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
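# Illustrative sketch (the variable name and value are assumptions for demonstration only):
# temporarily override an environment variable while creating a local session; the original
# value is restored afterwards even if an exception is raised.
#
#   with modified_environ({"SPARK_LOCAL_IP": "127.0.0.1"}):
#       spark = _create_local_spark_session_for_loading_spark_model()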