/ mlflow / utils / huggingface_utils.py
huggingface_utils.py
 1  import functools
 2  import logging
 3  import os
 4  import time
 5  
 6  from mlflow.environment_variables import _MLFLOW_TESTING
 7  from mlflow.exceptions import MlflowException
 8  from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
 9  
# Module-level logger, named after this module so its records can be filtered by module.
_logger = logging.getLogger(name=__name__)
11  
12  
# NB: maxsize=1 is deliberate. Keeping only the most recently requested repo in the cache
#    encourages refreshes, so users are less likely to get a stale commit hash. It is not
#    perfect: the cached entry is only evicted when the function is called with a different
#    repo name, but it's better than nothing. (lru_cache only stores successful return
#    values, so failures are never cached.)
@functools.lru_cache(maxsize=1)
def get_latest_commit_for_repo(repo: str) -> str:
    """
    Fetches the latest commit hash for a repository from the HuggingFace model hub.

    Args:
        repo: HuggingFace repository identifier, e.g. "username/repo_id".

    Returns:
        The ``sha`` attribute of ``HfApi().model_info(repo)``.

    Raises:
        MlflowException: If the ``huggingface-hub`` package is not installed, or if
            ``_MLFLOW_TESTING`` is enabled and every attempt fails with HTTP 429
            (rate limit). In the latter case the last HTTP error is chained as the cause.
        HfHubHTTPError: Re-raised immediately for any hub HTTP error when
            ``_MLFLOW_TESTING`` is not enabled, and for any non-429 error otherwise.
    """
    try:
        import huggingface_hub as hub
    except ImportError as e:
        raise MlflowException(
            "Unable to fetch model commit hash from the HuggingFace model hub. "
            "This is required for saving a model without base model "
            "weights, while ensuring the version consistency of the model. "
            "Please install the `huggingface-hub` package and retry.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        ) from e

    from huggingface_hub.errors import HfHubHTTPError

    max_attempts = 7
    api = hub.HfApi()
    last_error = None
    for attempt in range(max_attempts):
        try:
            return api.model_info(repo).sha
        except HfHubHTTPError as e:
            # Retrying is only enabled when _MLFLOW_TESTING is set; otherwise surface
            # the hub error to the caller right away.
            if not _MLFLOW_TESTING.get():
                raise

            # `response` is optional on HfHubHTTPError and may be None, so don't
            # dereference it unconditionally.
            status_code = getattr(e.response, "status_code", None)
            if status_code != 429:
                raise

            last_error = e
            if attempt == max_attempts - 1:
                # No attempts left: give up now instead of sleeping for nothing.
                break

            # Exponential backoff: 1, 2, 4, ... seconds between attempts.
            delay = 2**attempt
            _logger.warning(
                f"Rate limit exceeded while fetching commit hash for repo {repo}. "
                f"Retrying in {delay} seconds. Error: {e}",
            )
            time.sleep(delay)

    raise MlflowException(
        "Unable to fetch model commit hash from the HuggingFace model hub. "
        "This is required for saving a model without base model "
        "weights, while ensuring the version consistency of the model. ",
        error_code=RESOURCE_DOES_NOT_EXIST,
    ) from last_error
58  
59  
60  def is_valid_hf_repo_id(maybe_repo_id: str | None) -> bool:
61      """
62      Check if the given string is a valid HuggingFace repo identifier e.g. "username/repo_id".
63      """
64  
65      if not maybe_repo_id or os.path.isdir(maybe_repo_id):
66          return False
67  
68      try:
69          from huggingface_hub.utils import HFValidationError, validate_repo_id
70      except ImportError:
71          raise MlflowException(
72              "Unable to validate the repository identifier for the HuggingFace model hub "
73              "because the `huggingface-hub` package is not installed. Please install the "
74              "package with `pip install huggingface-hub` command and retry."
75          )
76  
77      try:
78          validate_repo_id(maybe_repo_id)
79          return True
80      except HFValidationError as e:
81          _logger.warning(f"The repository identified {maybe_repo_id} is invalid: {e}")
82          return False