# huggingface_utils.py
import functools
import logging
import os
import time

from mlflow.environment_variables import _MLFLOW_TESTING
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST

_logger = logging.getLogger(__name__)


# NB: The maxsize=1 is added for encouraging the cache refresh so the user doesn't get stale
# commit hash from the cache. This doesn't work perfectly because it only updates cache
# when the user calls it with a different repo name, but it's better than nothing.
@functools.lru_cache(maxsize=1)
def get_latest_commit_for_repo(repo: str) -> str:
    """
    Fetches the latest commit hash for a repository from the HuggingFace model hub.

    Args:
        repo: Repository identifier passed to ``HfApi.model_info``.

    Returns:
        The ``sha`` attribute of the model info returned by the hub for ``repo``.

    Raises:
        MlflowException: If the ``huggingface_hub`` package cannot be imported, or if
            every attempt was rejected with a rate limit (HTTP 429) error.
        HfHubHTTPError: Re-raised as-is when the ``_MLFLOW_TESTING`` flag is not set, or
            when the failure is anything other than a rate limit error.
    """
    try:
        import huggingface_hub as hub
    except ImportError as e:
        raise MlflowException(
            "Unable to fetch model commit hash from the HuggingFace model hub. "
            "This is required for saving a model without base model "
            "weights, while ensuring the version consistency of the model. "
            "Please install the `huggingface-hub` package and retry.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        ) from e

    from huggingface_hub.errors import HfHubHTTPError

    # Exponential backoff: attempt i sleeps 2**i seconds, i.e. 1, 2, 4, ..., 64.
    max_attempts = 7

    api = hub.HfApi()
    for attempt in range(max_attempts):
        try:
            return api.model_info(repo).sha
        except HfHubHTTPError as e:
            # Retrying only happens when the testing flag is set; otherwise the hub
            # error is surfaced to the caller immediately.
            if not _MLFLOW_TESTING.get():
                raise

            # The error may carry no response object. Read the status code defensively
            # so a missing response re-raises the hub error instead of an AttributeError.
            status_code = getattr(e.response, "status_code", None)

            # Retry on rate limit error
            if status_code == 429:
                delay = 2**attempt
                _logger.warning(
                    f"Rate limit exceeded while fetching commit hash for repo {repo}. "
                    f"Retrying in {delay} seconds. Error: {e}",
                )
                time.sleep(delay)
                continue
            raise

    # Reached only when every attempt hit the rate limit.
    raise MlflowException(
        "Unable to fetch model commit hash from the HuggingFace model hub. "
        "This is required for saving a model without base model "
        "weights, while ensuring the version consistency of the model. ",
        error_code=RESOURCE_DOES_NOT_EXIST,
    )


def is_valid_hf_repo_id(maybe_repo_id: str | None) -> bool:
    """
    Check if the given string is a valid HuggingFace repo identifier e.g. "username/repo_id".

    Args:
        maybe_repo_id: Candidate identifier. ``None``, an empty string, or a path to an
            existing local directory is never treated as a repo identifier.

    Returns:
        True if ``huggingface_hub.utils.validate_repo_id`` accepts the value, False
        otherwise. A warning is logged when validation rejects the value.

    Raises:
        MlflowException: If the ``huggingface_hub`` package cannot be imported.
    """

    # An existing local directory is excluded before the hub validation runs.
    if not maybe_repo_id or os.path.isdir(maybe_repo_id):
        return False

    try:
        from huggingface_hub.utils import HFValidationError, validate_repo_id
    except ImportError as e:
        raise MlflowException(
            "Unable to validate the repository identifier for the HuggingFace model hub "
            "because the `huggingface-hub` package is not installed. Please install the "
            "package with `pip install huggingface-hub` command and retry."
        ) from e

    try:
        validate_repo_id(maybe_repo_id)
        return True
    except HFValidationError as e:
        _logger.warning(f"The repository identified {maybe_repo_id} is invalid: {e}")
        return False