# update_ml_package_versions.py
1 """ 2 A script to update the maximum package versions in 'mlflow/ml-package-versions.yml'. 3 4 # Prerequisites: 5 $ pip install packaging pyyaml 6 7 # How to run (make sure you're in the repository root): 8 $ python dev/update_ml_package_versions.py 9 """ 10 11 import argparse 12 import json 13 import os 14 import re 15 import sys 16 import time 17 import urllib.error 18 import urllib.request 19 from dataclasses import dataclass 20 from datetime import datetime, timedelta, timezone 21 from pathlib import Path 22 23 import yaml 24 from packaging.version import Version 25 26 27 def read_file(path): 28 with open(path) as f: 29 return f.read() 30 31 32 def save_file(src, path): 33 with open(path, "w") as f: 34 f.write(src) 35 36 37 RELEASE_CUTOFF_DAYS = 14 38 PYPI_URL = os.environ.get("PYPI_URL", "https://pypi.org").rstrip("/") 39 40 41 def check_pypi_accessibility() -> None: 42 try: 43 with urllib.request.urlopen(PYPI_URL, timeout=5): 44 pass 45 except (urllib.error.URLError, OSError): 46 raise SystemExit( 47 f"Error: Cannot connect to {PYPI_URL}. " 48 "If it's not accessible, set the PYPI_URL environment variable to a PyPI proxy URL." 
49 ) 50 51 52 @dataclass 53 class VersionInfo: 54 version: str 55 upload_time: datetime 56 57 58 def get_package_version_infos(package_name: str) -> list[VersionInfo]: 59 url = f"{PYPI_URL}/pypi/{package_name}/json" 60 for _ in range(5): # Retry up to 5 times 61 try: 62 with urllib.request.urlopen(url) as res: 63 data = json.load(res) 64 except ConnectionResetError as e: 65 sys.stderr.write(f"Retrying {url} due to {e}\n") 66 time.sleep(1) 67 else: 68 break 69 else: 70 raise Exception(f"Failed to fetch {url}") 71 72 def is_dev_or_pre_release(version_str): 73 v = Version(version_str) 74 return v.is_devrelease or v.is_prerelease 75 76 cutoff = datetime.now(timezone.utc) - timedelta(days=RELEASE_CUTOFF_DAYS) 77 78 def uploaded_within_cutoff(dist) -> bool: 79 if ut := dist.get("upload_time_iso_8601"): 80 return datetime.fromisoformat(ut.replace("Z", "+00:00")) >= cutoff 81 return False 82 83 return [ 84 VersionInfo( 85 version=version, 86 upload_time=datetime.fromisoformat(dist_files[0]["upload_time"]), 87 ) 88 for version, dist_files in data["releases"].items() 89 if ( 90 len(dist_files) > 0 91 and not is_dev_or_pre_release(version) 92 and not any(uploaded_within_cutoff(dist) for dist in dist_files) 93 and not any(dist.get("yanked", False) for dist in dist_files) 94 ) 95 ] 96 97 98 def get_latest_version(candidates): 99 return max(candidates, key=Version) 100 101 102 def update_version(src, key, new_version, category, update_max): 103 """ 104 Examples 105 ======== 106 >>> src = ''' 107 ... sklearn: 108 ... ... 109 ... models: 110 ... minimum: "0.0.0" 111 ... maximum: "0.0.0" 112 ... xgboost: 113 ... ... 114 ... autologging: 115 ... minimum: "1.1.1" 116 ... maximum: "1.1.1" 117 ... '''.strip() 118 >>> new_src = update_version(src, "sklearn", "0.1.0", "models", update_max=True) 119 >>> new_src = update_version(new_src, "xgboost", "1.2.1", "autologging", update_max=True) 120 >>> print(new_src) 121 sklearn: 122 ... 
123 models: 124 minimum: "0.0.0" 125 maximum: "0.1.0" 126 xgboost: 127 ... 128 autologging: 129 minimum: "1.1.1" 130 maximum: "1.2.1" 131 """ 132 match = "maximum" if update_max else "minimum" 133 pattern = r"((^|\n){key}:.+?{category}:.+?{match}: )\".+?\"".format( 134 key=re.escape(key), category=category, match=match 135 ) 136 # Matches the following pattern: 137 # 138 # <key>: 139 # ... 140 # <category>: 141 # ... 142 # maximum: "1.2.3" 143 return re.sub(pattern, rf'\g<1>"{new_version}"', src, flags=re.DOTALL) 144 145 146 def extract_field(d, keys): 147 for key in keys: 148 if key in d: 149 d = d[key] 150 else: 151 return None 152 return d 153 154 155 def _get_autolog_flavor_module_map(config): 156 """ 157 Parse _ML_PACKAGE_VERSIONS to get the mapping of flavor name to 158 the module name to be imported for autologging. 159 """ 160 autolog_flavor_module_map = {} 161 for flavor, config in config.items(): 162 if "autologging" not in config: 163 continue 164 module_name = config["package_info"].get("module_name", flavor) 165 autolog_flavor_module_map[flavor] = module_name 166 167 return autolog_flavor_module_map 168 169 170 def update_ml_package_versions_py(config_path): 171 with open(config_path) as f: 172 genai_config = {} 173 non_genai_config = {} 174 175 for name, cfg in yaml.load(f, Loader=yaml.SafeLoader).items(): 176 # Extract required fields 177 pip_release = extract_field(cfg, ("package_info", "pip_release")) 178 module_name = extract_field(cfg, ("package_info", "module_name")) 179 min_version = extract_field(cfg, ("models", "minimum")) 180 max_version = extract_field(cfg, ("models", "maximum")) 181 genai = extract_field(cfg, ("package_info", "genai")) 182 config_to_update = genai_config if genai else non_genai_config 183 if min_version: 184 config_to_update[name] = { 185 "package_info": { 186 "pip_release": pip_release, 187 }, 188 "models": { 189 "minimum": min_version, 190 "maximum": max_version, 191 }, 192 } 193 else: 194 config_to_update[name] = { 195 
"package_info": { 196 "pip_release": pip_release, 197 } 198 } 199 if module_name: 200 config_to_update[name]["package_info"]["module_name"] = module_name 201 202 # Check for autologging configuration 203 autolog_min_version = extract_field(cfg, ("autologging", "minimum")) 204 autolog_max_version = extract_field(cfg, ("autologging", "maximum")) 205 if (pip_release, autolog_min_version, autolog_max_version).count(None) > 0: 206 continue 207 208 config_to_update[name].update( 209 { 210 "autologging": { 211 "minimum": autolog_min_version, 212 "maximum": autolog_max_version, 213 } 214 }, 215 ) 216 217 genai_flavor_module_mapping = _get_autolog_flavor_module_map(genai_config) 218 # We have "langgraph" entry in ml-package-versions.yml so that we can run test 219 # against multiple versions of langgraph. However, we don't have a flavor for 220 # langgraph and it is a part of the langchain flavor. 221 genai_flavor_module_mapping.pop("langgraph", None) 222 223 non_genai_flavor_module_mapping = _get_autolog_flavor_module_map(non_genai_config) 224 # Add special case for pyspark.ml (non-GenAI) 225 non_genai_flavor_module_mapping["pyspark.ml"] = "pyspark" 226 227 this_file = Path(__file__).name 228 dst = Path("mlflow", "ml_package_versions.py") 229 230 config_str = json.dumps(genai_config | non_genai_config, indent=4) 231 232 Path(dst).write_text( 233 f"""\ 234 # This file was auto-generated by {this_file}. 235 # Please do not edit it manually. 236 237 _ML_PACKAGE_VERSIONS = {config_str} 238 239 # A mapping of flavor name to the module name to be imported for autologging. 240 # This is used for checking version compatibility in autologging. 
241 # DO NOT EDIT MANUALLY 242 243 # GenAI packages 244 GENAI_FLAVOR_TO_MODULE_NAME = {json.dumps(genai_flavor_module_mapping, indent=4)} 245 246 # Non-GenAI packages 247 NON_GENAI_FLAVOR_TO_MODULE_NAME = {json.dumps(non_genai_flavor_module_mapping, indent=4)} 248 249 # Combined mapping for backward compatibility 250 FLAVOR_TO_MODULE_NAME = NON_GENAI_FLAVOR_TO_MODULE_NAME | GENAI_FLAVOR_TO_MODULE_NAME 251 """ 252 ) 253 254 255 def parse_args(): 256 parser = argparse.ArgumentParser(description="Update MLflow package versions") 257 parser.add_argument( 258 "--skip-yml", action="store_true", help="Skip updating ml-package-versions.yml" 259 ) 260 return parser.parse_args() 261 262 263 def get_min_supported_version(versions_infos: list[VersionInfo], genai: bool = False) -> str | None: 264 """ 265 Get the minimum version that is released within the past two years 266 """ 267 years = 1 if genai else 2 268 min_support_date = datetime.now() - timedelta(days=years * 365) 269 min_support_date = min_support_date.replace(tzinfo=None) 270 271 # Extract versions that were released in the past two years 272 recent_versions = [v for v in versions_infos if v.upload_time > min_support_date] 273 274 if not recent_versions: 275 return None 276 277 # Get minimum version according to upload date 278 return min(recent_versions, key=lambda v: v.upload_time).version 279 280 281 def update(skip_yml=False): 282 if not skip_yml: 283 check_pypi_accessibility() 284 yml_path = "mlflow/ml-package-versions.yml" 285 286 if not skip_yml: 287 old_src = read_file(yml_path) 288 new_src = old_src 289 config_dict = yaml.load(old_src, Loader=yaml.SafeLoader) 290 for flavor_key, config in config_dict.items(): 291 # We currently don't have bandwidth to support newer versions of these flavors. 
292 if flavor_key in ["litellm"]: 293 continue 294 package_name = config["package_info"]["pip_release"] 295 genai = config["package_info"].get("genai", False) 296 versions_and_upload_times = get_package_version_infos(package_name) 297 min_supported_version = get_min_supported_version( 298 versions_and_upload_times, genai=genai 299 ) 300 301 for category in ["autologging", "models"]: 302 print("Processing", flavor_key, category) 303 304 if category in config and "minimum" in config[category]: 305 old_min_version = config[category]["minimum"] 306 if flavor_key == "spark": 307 # We should support pyspark versions that are older than the cut off date. 308 pass 309 elif min_supported_version is None: 310 # The latest release version was 2 years ago. 311 # set the min version to be the same with the max version. 312 max_ver = config[category]["maximum"] 313 new_src = update_version( 314 new_src, flavor_key, max_ver, category, update_max=False 315 ) 316 elif Version(min_supported_version) > Version(old_min_version): 317 new_src = update_version( 318 new_src, flavor_key, min_supported_version, category, update_max=False 319 ) 320 321 if (category not in config) or config[category].get("pin_maximum", False): 322 continue 323 324 max_ver = config[category]["maximum"] 325 versions = [v.version for v in versions_and_upload_times] 326 unsupported = config[category].get("unsupported", []) 327 versions = set(versions).difference(unsupported) # exclude unsupported versions 328 latest_version = get_latest_version(versions) 329 330 if Version(latest_version) <= Version(max_ver): 331 continue 332 333 new_src = update_version( 334 new_src, flavor_key, latest_version, category, update_max=True 335 ) 336 337 save_file(new_src, yml_path) 338 339 update_ml_package_versions_py(yml_path) 340 341 342 def main(): 343 args = parse_args() 344 update(args.skip_yml) 345 346 347 if __name__ == "__main__": 348 main()