/ mlflow / projects / utils.py
utils.py
  1  import logging
  2  import os
  3  import pathlib
  4  import re
  5  import shutil
  6  import tempfile
  7  import urllib.parse
  8  import zipfile
  9  from io import BytesIO
 10  
 11  from mlflow import tracking
 12  from mlflow.entities import Param, SourceType
 13  from mlflow.environment_variables import MLFLOW_EXPERIMENT_ID, MLFLOW_RUN_ID, MLFLOW_TRACKING_URI
 14  from mlflow.exceptions import ExecutionException
 15  from mlflow.projects import _project_spec
 16  from mlflow.tracking import fluent
 17  from mlflow.tracking.context.default_context import _get_user
 18  from mlflow.utils.git_utils import get_git_commit, get_git_repo_url
 19  from mlflow.utils.mlflow_tags import (
 20      MLFLOW_GIT_BRANCH,
 21      MLFLOW_GIT_COMMIT,
 22      MLFLOW_GIT_REPO_URL,
 23      MLFLOW_PARENT_RUN_ID,
 24      MLFLOW_PROJECT_ENTRY_POINT,
 25      MLFLOW_SOURCE_NAME,
 26      MLFLOW_SOURCE_TYPE,
 27      MLFLOW_USER,
 28  )
 29  from mlflow.utils.rest_utils import augmented_raise_for_status
 30  
# Matches strings that start with "file://" followed by at least one character.
_FILE_URI_REGEX = re.compile(r"^file://.+")
# Matches strings that end with ".zip" and have at least one character before it.
_ZIP_URI_REGEX = re.compile(r".+\.zip$")
# NOTE(review): not referenced in this module; presumably a backend-config key that carries
# the run ID for the local backend — confirm against callers.
MLFLOW_LOCAL_BACKEND_RUN_ID_CONFIG = "_mlflow_local_backend_run_id"
# NOTE(review): not referenced in this module; presumably the directory project code is
# placed in inside a Docker container — confirm against callers.
MLFLOW_DOCKER_WORKDIR_PATH = "/mlflow/projects/code/"

# NOTE(review): the PROJECT_* names below are not referenced in this module; they look like
# keys into a backend configuration dictionary — confirm against callers.
PROJECT_ENV_MANAGER = "ENV_MANAGER"
PROJECT_SYNCHRONOUS = "SYNCHRONOUS"
PROJECT_DOCKER_ARGS = "DOCKER_ARGS"
PROJECT_STORAGE_DIR = "STORAGE_DIR"
PROJECT_BUILD_IMAGE = "build_image"
PROJECT_DOCKER_AUTH = "docker_auth"
# Value passed as ``depth`` to ``origin.fetch`` in ``_fetch_git_repo``.
GIT_FETCH_DEPTH = 1


# Module-level logger named after this module.
_logger = logging.getLogger(__name__)
 46  
 47  
 48  def _parse_subdirectory(uri):
 49      # Parses a uri and returns the uri and subdirectory as separate values.
 50      # Uses '#' as a delimiter.
 51      unquoted_uri = _strip_quotes(uri)
 52      subdirectory = ""
 53      parsed_uri = unquoted_uri
 54      if "#" in unquoted_uri:
 55          subdirectory = unquoted_uri[unquoted_uri.find("#") + 1 :]
 56          parsed_uri = unquoted_uri[: unquoted_uri.find("#")]
 57      if subdirectory and "." in subdirectory:
 58          raise ExecutionException("'.' is not allowed in project subdirectory paths.")
 59      return parsed_uri, subdirectory
 60  
 61  
 62  def _strip_quotes(uri):
 63      return uri.strip("'\"")
 64  
 65  
 66  def _get_storage_dir(storage_dir):
 67      if storage_dir is not None and not os.path.exists(storage_dir):
 68          os.makedirs(storage_dir)
 69      return tempfile.mkdtemp(dir=storage_dir)
 70  
 71  
 72  def _expand_uri(uri):
 73      if _is_local_uri(uri):
 74          return os.path.abspath(uri)
 75      return uri
 76  
 77  
 78  def _is_file_uri(uri):
 79      """Returns True if the passed-in URI is a file:// URI."""
 80      return _FILE_URI_REGEX.match(uri)
 81  
 82  
 83  def _is_git_repo(path) -> bool:
 84      """Returns True if passed-in path is a valid git repository"""
 85      import git
 86  
 87      try:
 88          repo = git.Repo(path)
 89          if len(repo.branches) > 0:
 90              return True
 91      except git.exc.InvalidGitRepositoryError:
 92          pass
 93      return False
 94  
 95  
 96  def _parse_file_uri(uri: str) -> str:
 97      """Converts file URIs to filesystem paths"""
 98      if _is_file_uri(uri):
 99          parsed_file_uri = urllib.parse.urlparse(uri)
100          return str(
101              pathlib.Path(parsed_file_uri.netloc, parsed_file_uri.path, parsed_file_uri.fragment)
102          )
103      return uri
104  
105  
def _is_local_uri(uri: str) -> bool:
    """
    Check whether the passed-in URI refers to something on the local filesystem.

    The URI is converted to a path (file:// URIs are unwrapped), resolved, and tested for
    existence.
    """
    local_path = pathlib.Path(_parse_file_uri(uri))
    return local_path.resolve().exists()
110  
111  
def _is_zip_uri(uri):
    """
    Check whether the passed-in URI points to a ZIP file (i.e. ends with ".zip").

    Returns the regex match object (truthy) when it does, and None otherwise.
    """
    zip_uri_match = _ZIP_URI_REGEX.match(uri)
    return zip_uri_match
115  
116  
117  def _is_valid_branch_name(work_dir, version):
118      """
119      Returns True if the ``version`` is the name of a branch in a Git project.
120      ``work_dir`` must be the working directory in a git repo.
121      """
122      if version is not None:
123          from git import Repo
124          from git.exc import GitCommandError
125  
126          repo = Repo(work_dir, search_parent_directories=True)
127          try:
128              return repo.git.rev_parse("--verify", f"refs/heads/{version}") != ""
129          except GitCommandError:
130              return False
131      return False
132  
133  
def fetch_and_validate_project(uri, version, entry_point, parameters):
    """
    Fetch the project at ``uri`` and validate ``parameters`` against its entry point.

    Args:
        uri: Project URI, forwarded to ``_fetch_project``.
        version: Project version, forwarded to ``_fetch_project``.
        entry_point: Name of the entry point whose parameters are validated.
        parameters: Parameters dictionary; a falsy value is treated as an empty dictionary.

    Returns:
        Path to the local project directory.
    """
    work_dir = _fetch_project(uri=uri, version=version)
    entry_point_obj = _project_spec.load_project(work_dir).get_entry_point(entry_point)
    # Validation is skipped when the project yields no (truthy) entry point object.
    if entry_point_obj:
        entry_point_obj._validate_parameters(parameters or {})
    return work_dir
141  
142  
def load_project(work_dir):
    """Load the project located in ``work_dir`` via ``_project_spec.load_project``."""
    project = _project_spec.load_project(work_dir)
    return project
145  
146  
def _fetch_project(uri, version=None):
    """
    Fetch a project into a local directory, returning the path to the local project directory.

    Args:
        uri: Project URI. An optional ``#subdirectory`` suffix is split off by
            ``_parse_subdirectory``.
        version: Optional version forwarded to ``_fetch_git_repo``. For a URI that exists
            locally it is only accepted when that location is a git repository.

    Returns:
        Absolute path of the destination directory joined with the subdirectory.

    Raises:
        ExecutionException: If ``version`` is set for a local path that is not a git
            repository, or if the resulting directory does not exist.
    """
    parsed_uri, subdirectory = _parse_subdirectory(uri)
    # A temporary destination is used for ZIP URIs and for anything that does not exist on
    # the local filesystem; an existing local non-ZIP path is used in place.
    use_temp_dst_dir = _is_zip_uri(parsed_uri) or not _is_local_uri(parsed_uri)
    dst_dir = tempfile.mkdtemp() if use_temp_dst_dir else _parse_file_uri(parsed_uri)

    if use_temp_dst_dir:
        _logger.info("=== Fetching project from %s into %s ===", uri, dst_dir)
    if _is_zip_uri(parsed_uri):
        # Unwrap a file:// URI so that a local archive can be opened by path; a non-local
        # archive is downloaded into memory first.
        parsed_uri = _parse_file_uri(parsed_uri)
        _unzip_repo(
            zip_file=(parsed_uri if _is_local_uri(parsed_uri) else _fetch_zip_repo(parsed_uri)),
            dst_dir=dst_dir,
        )
    elif _is_local_uri(parsed_uri):
        # NOTE(review): in this branch the URI is local and not a ZIP, so ``use_temp_dst_dir``
        # appears to always be false and the copy below looks unreachable — confirm.
        if use_temp_dst_dir:
            shutil.copytree(parsed_uri, dst_dir, dirs_exist_ok=True)
        if version is not None:
            if not _is_git_repo(_parse_file_uri(parsed_uri)):
                raise ExecutionException("Setting a version is only supported for Git project URIs")
            _fetch_git_repo(parsed_uri, version, dst_dir)
    else:
        # Neither a ZIP nor an existing local path: treat the URI as a git repository.
        _fetch_git_repo(parsed_uri, version, dst_dir)
    res = os.path.abspath(os.path.join(dst_dir, subdirectory))
    if not os.path.exists(res):
        raise ExecutionException(f"Could not find subdirectory {subdirectory} of {dst_dir}")
    return res
176  
177  
178  def _unzip_repo(zip_file, dst_dir):
179      with zipfile.ZipFile(zip_file) as zip_in:
180          zip_in.extractall(dst_dir)
181  
182  
183  _HEAD_BRANCH_REGEX = re.compile(r"^\s*HEAD branch:\s+(?P<branch>\S+)")
184  
185  
186  def _get_head_branch(remote_show_output):
187      for line in remote_show_output.splitlines():
188          if match := _HEAD_BRANCH_REGEX.match(line):
189              return match.group("branch")
190  
191  
def _fetch_git_repo(uri, version, dst_dir):
    """
    Fetch the git repo at ``uri`` into ``dst_dir`` and check out ``version``.

    When ``version`` is unspecified, the HEAD branch reported by ``git remote show origin``
    is fetched and checked out instead. Submodules are initialized and updated recursively
    afterwards. Assumes authentication parameters are specified by the environment, e.g. by
    a Git credential helper.

    Args:
        uri: URI of the git repository; used as the "origin" remote when ``dst_dir`` has no
            remote yet.
        version: Ref to fetch and check out, or None to use the remote's HEAD branch.
        dst_dir: Directory in which the repository is initialized.

    Raises:
        ExecutionException: If ``version`` cannot be fetched or checked out, or if the HEAD
            branch cannot be determined.
    """
    # We defer importing git until the last moment, because the import requires that the git
    # executable is available on the PATH, so we only want to fail if we actually need it.
    import git

    repo = git.Repo.init(dst_dir)
    # Reuse the first existing remote, if any; otherwise register ``uri`` as "origin".
    origin = next(iter(repo.remotes), None)
    if origin is None:
        origin = repo.create_remote("origin", uri)
    if version is not None:
        try:
            origin.fetch(refspec=version, depth=GIT_FETCH_DEPTH, tags=True)
            repo.git.checkout(version)
        except git.exc.GitCommandError as e:
            # Chain the original git error so the full cause stays visible in tracebacks.
            raise ExecutionException(
                f"Unable to checkout version '{version}' of git repo {uri} "
                "- please ensure that the version exists in the repo. "
                f"Error: {e}"
            ) from e
    else:
        g = git.cmd.Git(dst_dir)
        cmd = ["git", "remote", "show", "origin"]
        output = g.execute(cmd)
        head_branch = _get_head_branch(output)
        if head_branch is None:
            cmd_str = " ".join(cmd)
            raise ExecutionException(
                f"Failed to find HEAD branch. Output of `{cmd_str}`:\n{output}"
            )
        origin.fetch(head_branch, depth=GIT_FETCH_DEPTH)
        # NOTE(review): assumes the first remote ref is the branch that was just fetched.
        ref = origin.refs[0]
        _logger.info("Fetched '%s' branch", head_branch)
        repo.create_head(head_branch, ref)
        repo.heads[head_branch].checkout()
    repo.git.execute(command=["git", "submodule", "update", "--init", "--recursive"])
234  
235  
def _fetch_zip_repo(uri):
    """
    Download the ZIP archive at ``uri`` over HTTP and return it as an in-memory buffer.

    Args:
        uri: URI of the ZIP file to download.

    Returns:
        A ``BytesIO`` wrapping the response body.

    Raises:
        ExecutionException: If the response status check raises ``requests.HTTPError``.
    """
    import requests

    # TODO (dbczumar): Replace HTTP resolution via ``requests.get`` with an invocation of
    # ```mlflow.data.download_uri()`` when the API supports the same set of available stores as
    # the artifact repository (Azure, FTP, etc). See the following issue:
    # https://github.com/mlflow/mlflow/issues/763.
    # NOTE(review): the request is issued without a timeout.
    response = requests.get(uri)
    try:
        augmented_raise_for_status(response)
    except requests.HTTPError as error:
        # Chain the HTTP error so the underlying cause stays attached to the exception.
        raise ExecutionException(f"Unable to retrieve ZIP file. Reason: {error!s}") from error
    return BytesIO(response.content)
249  
250  
def get_or_create_run(run_id, uri, experiment_id, work_dir, version, entry_point, parameters):
    """
    Return the run identified by ``run_id``, or create a new run when ``run_id`` is falsy.

    The remaining arguments are only used when a new run is created and are forwarded to
    ``_create_run``.
    """
    if not run_id:
        return _create_run(uri, experiment_id, work_dir, version, entry_point, parameters)
    return tracking.MlflowClient().get_run(run_id)
256  
257  
def _create_run(uri, experiment_id, work_dir, version, entry_point, parameters):
    """
    Create a run against the current MLflow tracking server and log project metadata on it.

    The run is tagged with the user, source name, source type and entry point, plus the git
    commit, parent run ID, git repo URL and git branch when those are available. The entry
    point's computed parameters are logged to the run in a single batch.

    Args:
        uri: Project URI, used to derive the source name tag.
        experiment_id: ID of the experiment the run is created in.
        work_dir: Local project directory; used for git metadata and to load the project.
        version: Project version; recorded as the branch tag when it names a branch.
        entry_point: Name of the entry point being run.
        parameters: Parameters dictionary passed to the entry point.

    Returns:
        The run object returned by ``MlflowClient.create_run``.
    """
    source_name = (
        tracking._tracking_service.utils._get_git_url_if_present(_expand_uri(uri))
        if _is_local_uri(uri)
        else _expand_uri(uri)
    )
    git_commit = get_git_commit(work_dir)
    enclosing_run = fluent.active_run()
    parent_run_id = None
    if enclosing_run:
        parent_run_id = enclosing_run.info.run_id

    tags = {
        MLFLOW_USER: _get_user(),
        MLFLOW_SOURCE_NAME: source_name,
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.PROJECT),
        MLFLOW_PROJECT_ENTRY_POINT: entry_point,
    }
    # Optional tags are only set when a value is available.
    optional_tags = (
        (MLFLOW_GIT_COMMIT, git_commit),
        (MLFLOW_PARENT_RUN_ID, parent_run_id),
    )
    for tag_key, tag_value in optional_tags:
        if tag_value is not None:
            tags[tag_key] = tag_value

    if (repo_url := get_git_repo_url(work_dir)) is not None:
        tags[MLFLOW_GIT_REPO_URL] = repo_url

    # Add branch name tag if a branch is specified through -version
    if _is_valid_branch_name(work_dir, version):
        tags[MLFLOW_GIT_BRANCH] = version
    active_run = tracking.MlflowClient().create_run(experiment_id=experiment_id, tags=tags)

    # Consolidate parameters for logging.
    # `storage_dir` is `None` since we want to log actual path not downloaded local path
    entry_point_obj = _project_spec.load_project(work_dir).get_entry_point(entry_point)
    if entry_point_obj:
        final_params, extra_params = entry_point_obj.compute_parameters(
            parameters, storage_dir=None
        )
        params_list = [
            Param(name, value) for name, value in [*final_params.items(), *extra_params.items()]
        ]
        tracking.MlflowClient().log_batch(active_run.info.run_id, params=params_list)
    return active_run
305  
306  
def get_entry_point_command(project, entry_point, parameters, storage_dir):
    """
    Build the shell command(s) needed to run the specified entry point.

    Args:
        project: Project containing the target entry point.
        entry_point: Entry point to run.
        parameters: Parameters (dictionary) for the entry point command.
        storage_dir: Base local directory to use for downloading remote artifacts passed to
            arguments of type 'path'. If None, a temporary base directory is used.

    Returns:
        A single-element list holding the command computed by the entry point.
    """
    run_storage_dir = _get_storage_dir(storage_dir)
    _logger.info(
        "=== Created directory %s for downloading remote URIs passed to arguments of"
        " type 'path' ===",
        run_storage_dir,
    )
    entry_point_obj = project.get_entry_point(entry_point)
    return [entry_point_obj.compute_command(parameters, run_storage_dir)]
329  
330  
def get_run_env_vars(run_id, experiment_id):
    """
    Build the environment variables to set in the subprocess launched to run an MLflow
    project.

    Args:
        run_id: ID of the run, stored as is.
        experiment_id: ID of the experiment, stored as a string.

    Returns:
        A dictionary mapping environment variable names to the run ID, the current tracking
        URI and the experiment ID.
    """
    env_vars = {}
    env_vars[MLFLOW_RUN_ID.name] = run_id
    env_vars[MLFLOW_TRACKING_URI.name] = tracking.get_tracking_uri()
    env_vars[MLFLOW_EXPERIMENT_ID.name] = str(experiment_id)
    return env_vars