# utils.py
import logging
import os
import pathlib
import re
import shutil
import tempfile
import urllib.parse
import zipfile
from io import BytesIO

from mlflow import tracking
from mlflow.entities import Param, SourceType
from mlflow.environment_variables import MLFLOW_EXPERIMENT_ID, MLFLOW_RUN_ID, MLFLOW_TRACKING_URI
from mlflow.exceptions import ExecutionException
from mlflow.projects import _project_spec
from mlflow.tracking import fluent
from mlflow.tracking.context.default_context import _get_user
from mlflow.utils.git_utils import get_git_commit, get_git_repo_url
from mlflow.utils.mlflow_tags import (
    MLFLOW_GIT_BRANCH,
    MLFLOW_GIT_COMMIT,
    MLFLOW_GIT_REPO_URL,
    MLFLOW_PARENT_RUN_ID,
    MLFLOW_PROJECT_ENTRY_POINT,
    MLFLOW_SOURCE_NAME,
    MLFLOW_SOURCE_TYPE,
    MLFLOW_USER,
)
from mlflow.utils.rest_utils import augmented_raise_for_status

_FILE_URI_REGEX = re.compile(r"^file://.+")
_ZIP_URI_REGEX = re.compile(r".+\.zip$")
MLFLOW_LOCAL_BACKEND_RUN_ID_CONFIG = "_mlflow_local_backend_run_id"
MLFLOW_DOCKER_WORKDIR_PATH = "/mlflow/projects/code/"

PROJECT_ENV_MANAGER = "ENV_MANAGER"
PROJECT_SYNCHRONOUS = "SYNCHRONOUS"
PROJECT_DOCKER_ARGS = "DOCKER_ARGS"
PROJECT_STORAGE_DIR = "STORAGE_DIR"
PROJECT_BUILD_IMAGE = "build_image"
PROJECT_DOCKER_AUTH = "docker_auth"
GIT_FETCH_DEPTH = 1


_logger = logging.getLogger(__name__)


def _parse_subdirectory(uri):
    """
    Split a project URI into the URI proper and an optional subdirectory.

    Surrounding quotes are stripped first. The first ``#`` acts as the delimiter: everything
    before it is the URI, everything after it is the subdirectory.

    Returns:
        A ``(parsed_uri, subdirectory)`` tuple; ``subdirectory`` is ``""`` when the URI
        contains no ``#``.

    Raises:
        ExecutionException: If the subdirectory contains a ``.`` character.
    """
    unquoted_uri = _strip_quotes(uri)
    # ``partition`` splits on the first "#" and yields an empty subdirectory when none exists.
    parsed_uri, _, subdirectory = unquoted_uri.partition("#")
    if subdirectory and "." in subdirectory:
        raise ExecutionException("'.' is not allowed in project subdirectory paths.")
    return parsed_uri, subdirectory


def _strip_quotes(uri):
    """Return ``uri`` with any leading/trailing single or double quote characters removed."""
    return uri.strip("'\"")


def _get_storage_dir(storage_dir):
    """
    Create and return a fresh temporary directory located under ``storage_dir``.

    ``storage_dir`` is created first if it does not exist. When it is ``None``,
    ``tempfile.mkdtemp`` picks its default location.
    """
    if storage_dir is not None:
        # ``exist_ok=True`` avoids the race between an existence check and the creation.
        os.makedirs(storage_dir, exist_ok=True)
    return tempfile.mkdtemp(dir=storage_dir)


def _expand_uri(uri):
    """Return the absolute path for ``uri`` if it is a local URI, else ``uri`` unchanged."""
    if _is_local_uri(uri):
        return os.path.abspath(uri)
    return uri


def _is_file_uri(uri):
    """Returns True if the passed-in URI is a file:// URI."""
    return _FILE_URI_REGEX.match(uri) is not None


def _is_git_repo(path) -> bool:
    """Returns True if passed-in path is a valid git repository with at least one branch."""
    import git

    try:
        repo = git.Repo(path)
    except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
        # Neither a non-repository directory nor a missing path is a git repository.
        return False
    return len(repo.branches) > 0


def _parse_file_uri(uri: str) -> str:
    """
    Converts file URIs to filesystem paths.

    For a ``file://`` URI the network location, path and fragment components are joined into
    a single path. Any other string is returned unchanged.
    """
    if _is_file_uri(uri):
        parsed_file_uri = urllib.parse.urlparse(uri)
        return str(
            pathlib.Path(parsed_file_uri.netloc, parsed_file_uri.path, parsed_file_uri.fragment)
        )
    return uri


def _is_local_uri(uri: str) -> bool:
    """Returns True if passed-in URI resolves to an existing path on the local filesystem."""
    resolved_uri = pathlib.Path(_parse_file_uri(uri)).resolve()
    return resolved_uri.exists()


def _is_zip_uri(uri):
    """Returns True if the passed-in URI points to a ZIP file (i.e. ends with ``.zip``)."""
    return _ZIP_URI_REGEX.match(uri) is not None


def _is_valid_branch_name(work_dir, version):
    """
    Returns True if the ``version`` is the name of a branch in a Git project.
    ``work_dir`` must be the working directory in a git repo.
    """
    if version is None:
        return False

    from git import Repo
    from git.exc import GitCommandError

    repo = Repo(work_dir, search_parent_directories=True)
    try:
        return repo.git.rev_parse("--verify", f"refs/heads/{version}") != ""
    except GitCommandError:
        # ``rev-parse --verify`` fails when no such local branch exists.
        return False


def fetch_and_validate_project(uri, version, entry_point, parameters):
    """
    Fetch the project at ``uri`` and validate ``parameters`` against ``entry_point``.

    Validation is only performed when the project returns an entry point object for
    ``entry_point``. ``parameters`` may be ``None``, which is treated as an empty dict.

    Returns:
        The path to the local project directory.
    """
    parameters = parameters or {}
    work_dir = _fetch_project(uri=uri, version=version)
    project = _project_spec.load_project(work_dir)
    if entry_point_obj := project.get_entry_point(entry_point):
        entry_point_obj._validate_parameters(parameters)
    return work_dir


def load_project(work_dir):
    """Load and return the project specification found in ``work_dir``."""
    return _project_spec.load_project(work_dir)


def _fetch_project(uri, version=None):
    """
    Fetch a project into a local directory, returning the path to the local project directory.

    ZIP URIs (local or remote) are extracted into a temporary directory, non-local URIs are
    fetched as git repositories into a temporary directory, and local directories are used in
    place (checking out ``version`` when one is given).

    Raises:
        ExecutionException: If ``version`` is set for a local directory that is not a git
            repository, or if the requested subdirectory does not exist in the fetched
            project.
    """
    parsed_uri, subdirectory = _parse_subdirectory(uri)
    use_temp_dst_dir = _is_zip_uri(parsed_uri) or not _is_local_uri(parsed_uri)
    dst_dir = tempfile.mkdtemp() if use_temp_dst_dir else _parse_file_uri(parsed_uri)

    if use_temp_dst_dir:
        _logger.info("=== Fetching project from %s into %s ===", uri, dst_dir)
    if _is_zip_uri(parsed_uri):
        parsed_uri = _parse_file_uri(parsed_uri)
        _unzip_repo(
            zip_file=(parsed_uri if _is_local_uri(parsed_uri) else _fetch_zip_repo(parsed_uri)),
            dst_dir=dst_dir,
        )
    elif _is_local_uri(parsed_uri):
        # NOTE(review): ``use_temp_dst_dir`` is only true for ZIP or non-local URIs, so this
        # copy looks unreachable in this branch - confirm before removing.
        if use_temp_dst_dir:
            shutil.copytree(parsed_uri, dst_dir, dirs_exist_ok=True)
        if version is not None:
            if not _is_git_repo(_parse_file_uri(parsed_uri)):
                raise ExecutionException("Setting a version is only supported for Git project URIs")
            _fetch_git_repo(parsed_uri, version, dst_dir)
    else:
        _fetch_git_repo(parsed_uri, version, dst_dir)
    res = os.path.abspath(os.path.join(dst_dir, subdirectory))
    if not os.path.exists(res):
        raise ExecutionException(f"Could not find subdirectory {subdirectory} of {dst_dir}")
    return res


def _unzip_repo(zip_file, dst_dir):
    """Extract every member of ``zip_file`` (a path or file-like object) into ``dst_dir``."""
    with zipfile.ZipFile(zip_file) as zip_in:
        zip_in.extractall(dst_dir)


_HEAD_BRANCH_REGEX = re.compile(r"^\s*HEAD branch:\s+(?P<branch>\S+)")


def _get_head_branch(remote_show_output):
    """
    Extract the HEAD branch name from the output of ``git remote show origin``.

    Returns:
        The branch name from the first ``HEAD branch: <name>`` line, or ``None`` if no line
        matches.
    """
    for line in remote_show_output.splitlines():
        if match := _HEAD_BRANCH_REGEX.match(line):
            return match.group("branch")
    return None


def _fetch_git_repo(uri, version, dst_dir):
    """
    Clone the git repo at ``uri`` into ``dst_dir``, checking out commit ``version`` (or defaulting
    to the remote's HEAD branch if version is unspecified), then initialize submodules.
    Assumes authentication parameters are specified by the environment, e.g. by a Git credential
    helper.

    Raises:
        ExecutionException: If ``version`` cannot be fetched and checked out, or if the
            remote's HEAD branch cannot be determined.
    """
    # We defer importing git until the last moment, because the import requires that the git
    # executable is available on the PATH, so we only want to fail if we actually need it.
    import git

    repo = git.Repo.init(dst_dir)
    # Reuse the first configured remote if there is one; otherwise register ``uri`` as origin.
    origin = next(iter(repo.remotes), None)
    if origin is None:
        origin = repo.create_remote("origin", uri)
    if version is not None:
        try:
            origin.fetch(refspec=version, depth=GIT_FETCH_DEPTH, tags=True)
            repo.git.checkout(version)
        except git.exc.GitCommandError as e:
            raise ExecutionException(
                f"Unable to checkout version '{version}' of git repo {uri} "
                "- please ensure that the version exists in the repo. "
                f"Error: {e}"
            ) from e
    else:
        g = git.cmd.Git(dst_dir)
        cmd = ["git", "remote", "show", "origin"]
        output = g.execute(cmd)
        head_branch = _get_head_branch(output)
        if head_branch is None:
            raise ExecutionException(
                f"Failed to find HEAD branch. Output of `{' '.join(cmd)}`:\n{output}"
            )
        origin.fetch(head_branch, depth=GIT_FETCH_DEPTH)
        ref = origin.refs[0]
        _logger.info("Fetched '%s' branch", head_branch)
        repo.create_head(head_branch, ref)
        repo.heads[head_branch].checkout()
    repo.git.execute(command=["git", "submodule", "update", "--init", "--recursive"])


def _fetch_zip_repo(uri):
    """
    Download the ZIP archive at ``uri`` over HTTP and return its content as a ``BytesIO``.

    Raises:
        ExecutionException: If the HTTP response indicates an error.
    """
    import requests

    # TODO (dbczumar): Replace HTTP resolution via ``requests.get`` with an invocation of
    # ```mlflow.data.download_uri()`` when the API supports the same set of available stores as
    # the artifact repository (Azure, FTP, etc). See the following issue:
    # https://github.com/mlflow/mlflow/issues/763.
    response = requests.get(uri)
    try:
        augmented_raise_for_status(response)
    except requests.HTTPError as error:
        raise ExecutionException(f"Unable to retrieve ZIP file. Reason: {error!s}") from error
    return BytesIO(response.content)


def get_or_create_run(run_id, uri, experiment_id, work_dir, version, entry_point, parameters):
    """
    Return the run identified by ``run_id``, or create a new one when ``run_id`` is falsy.

    A new run is created via ``_create_run`` with the remaining arguments.
    """
    if run_id:
        return tracking.MlflowClient().get_run(run_id)
    return _create_run(uri, experiment_id, work_dir, version, entry_point, parameters)


def _create_run(uri, experiment_id, work_dir, version, entry_point, parameters):
    """
    Create a run against the current MLflow tracking server, logging metadata (e.g. the URI,
    entry point, and parameters of the project) about the run.

    Returns:
        The run object returned by ``MlflowClient.create_run``.
    """
    if _is_local_uri(uri):
        source_name = tracking._tracking_service.utils._get_git_url_if_present(_expand_uri(uri))
    else:
        source_name = _expand_uri(uri)
    source_version = get_git_commit(work_dir)
    existing_run = fluent.active_run()
    # A run that is already active becomes the parent of the run created here.
    parent_run_id = existing_run.info.run_id if existing_run else None

    tags = {
        MLFLOW_USER: _get_user(),
        MLFLOW_SOURCE_NAME: source_name,
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.PROJECT),
        MLFLOW_PROJECT_ENTRY_POINT: entry_point,
    }
    if source_version is not None:
        tags[MLFLOW_GIT_COMMIT] = source_version
    if parent_run_id is not None:
        tags[MLFLOW_PARENT_RUN_ID] = parent_run_id

    repo_url = get_git_repo_url(work_dir)
    if repo_url is not None:
        tags[MLFLOW_GIT_REPO_URL] = repo_url

    # Add branch name tag if a branch is specified through -version
    if _is_valid_branch_name(work_dir, version):
        tags[MLFLOW_GIT_BRANCH] = version
    client = tracking.MlflowClient()
    active_run = client.create_run(experiment_id=experiment_id, tags=tags)

    project = _project_spec.load_project(work_dir)
    # Consolidate parameters for logging.
    # `storage_dir` is `None` since we want to log actual path not downloaded local path
    if entry_point_obj := project.get_entry_point(entry_point):
        final_params, extra_params = entry_point_obj.compute_parameters(
            parameters, storage_dir=None
        )
        params_list = [
            Param(key, value)
            for key, value in [*final_params.items(), *extra_params.items()]
        ]
        client.log_batch(active_run.info.run_id, params=params_list)
    return active_run


def get_entry_point_command(project, entry_point, parameters, storage_dir):
    """
    Returns the shell command to execute in order to run the specified entry point.

    Args:
        project: Project containing the target entry point.
        entry_point: Entry point to run.
        parameters: Parameters (dictionary) for the entry point command.
        storage_dir: Base local directory to use for downloading remote artifacts passed to
            arguments of type 'path'. If None, a temporary base directory is used.

    Returns:
        A single-element list containing the computed command.
    """
    storage_dir_for_run = _get_storage_dir(storage_dir)
    _logger.info(
        "=== Created directory %s for downloading remote URIs passed to arguments of"
        " type 'path' ===",
        storage_dir_for_run,
    )
    return [project.get_entry_point(entry_point).compute_command(parameters, storage_dir_for_run)]


def get_run_env_vars(run_id, experiment_id):
    """
    Returns a dictionary of environment variable key-value pairs to set in subprocess launched
    to run MLflow projects.
    """
    return {
        MLFLOW_RUN_ID.name: run_id,
        MLFLOW_TRACKING_URI.name: tracking.get_tracking_uri(),
        MLFLOW_EXPERIMENT_ID.name: str(experiment_id),
    }