/ dev / update_ml_package_versions.py
update_ml_package_versions.py
  1  """
A script to update the minimum and maximum package versions in 'mlflow/ml-package-versions.yml'
and regenerate 'mlflow/ml_package_versions.py' from it.
  3  
  4  # Prerequisites:
  5  $ pip install packaging pyyaml
  6  
  7  # How to run (make sure you're in the repository root):
  8  $ python dev/update_ml_package_versions.py
  9  """
 10  
 11  import argparse
 12  import json
 13  import os
 14  import re
 15  import sys
 16  import time
 17  import urllib.error
 18  import urllib.request
 19  from dataclasses import dataclass
 20  from datetime import datetime, timedelta, timezone
 21  from pathlib import Path
 22  
 23  import yaml
 24  from packaging.version import Version
 25  
 26  
 27  def read_file(path):
 28      with open(path) as f:
 29          return f.read()
 30  
 31  
 32  def save_file(src, path):
 33      with open(path, "w") as f:
 34          f.write(src)
 35  
 36  
 37  RELEASE_CUTOFF_DAYS = 14
 38  PYPI_URL = os.environ.get("PYPI_URL", "https://pypi.org").rstrip("/")
 39  
 40  
 41  def check_pypi_accessibility() -> None:
 42      try:
 43          with urllib.request.urlopen(PYPI_URL, timeout=5):
 44              pass
 45      except (urllib.error.URLError, OSError):
 46          raise SystemExit(
 47              f"Error: Cannot connect to {PYPI_URL}. "
 48              "If it's not accessible, set the PYPI_URL environment variable to a PyPI proxy URL."
 49          )
 50  
 51  
 52  @dataclass
 53  class VersionInfo:
 54      version: str
 55      upload_time: datetime
 56  
 57  
 58  def get_package_version_infos(package_name: str) -> list[VersionInfo]:
 59      url = f"{PYPI_URL}/pypi/{package_name}/json"
 60      for _ in range(5):  # Retry up to 5 times
 61          try:
 62              with urllib.request.urlopen(url) as res:
 63                  data = json.load(res)
 64          except ConnectionResetError as e:
 65              sys.stderr.write(f"Retrying {url} due to {e}\n")
 66              time.sleep(1)
 67          else:
 68              break
 69      else:
 70          raise Exception(f"Failed to fetch {url}")
 71  
 72      def is_dev_or_pre_release(version_str):
 73          v = Version(version_str)
 74          return v.is_devrelease or v.is_prerelease
 75  
 76      cutoff = datetime.now(timezone.utc) - timedelta(days=RELEASE_CUTOFF_DAYS)
 77  
 78      def uploaded_within_cutoff(dist) -> bool:
 79          if ut := dist.get("upload_time_iso_8601"):
 80              return datetime.fromisoformat(ut.replace("Z", "+00:00")) >= cutoff
 81          return False
 82  
 83      return [
 84          VersionInfo(
 85              version=version,
 86              upload_time=datetime.fromisoformat(dist_files[0]["upload_time"]),
 87          )
 88          for version, dist_files in data["releases"].items()
 89          if (
 90              len(dist_files) > 0
 91              and not is_dev_or_pre_release(version)
 92              and not any(uploaded_within_cutoff(dist) for dist in dist_files)
 93              and not any(dist.get("yanked", False) for dist in dist_files)
 94          )
 95      ]
 96  
 97  
 98  def get_latest_version(candidates):
 99      return max(candidates, key=Version)
100  
101  
def update_version(src, key, new_version, category, update_max):
    """
    Replace the quoted ``maximum`` (if ``update_max``) or ``minimum`` version of
    ``category`` under the top-level ``key`` in the YAML text ``src``.

    The match is confined to ``key``'s own section. That section ends at the next line
    starting with a character that is neither whitespace nor ``#``. If the section has
    no such category or field, ``src`` is returned unchanged; a later flavor's version
    is never edited by mistake.

    Examples
    ========
    >>> src = '''
    ... sklearn:
    ...   ...
    ...   models:
    ...     minimum: "0.0.0"
    ...     maximum: "0.0.0"
    ... xgboost:
    ...   ...
    ...   autologging:
    ...     minimum: "1.1.1"
    ...     maximum: "1.1.1"
    ... '''.strip()
    >>> new_src = update_version(src, "sklearn", "0.1.0", "models", update_max=True)
    >>> new_src = update_version(new_src, "xgboost", "1.2.1", "autologging", update_max=True)
    >>> print(new_src)
    sklearn:
      ...
      models:
        minimum: "0.0.0"
        maximum: "0.1.0"
    xgboost:
      ...
      autologging:
        minimum: "1.1.1"
        maximum: "1.2.1"
    """
    match = "maximum" if update_max else "minimum"
    # One character that does not begin a new top-level key: a newline is only
    # consumed when the following line is indented, blank, or a comment.
    in_section = r"(?:(?!\n[^\s#]).)"
    # Matches the following pattern, capturing everything up to the quoted value:
    #
    # <key>:
    #   ...
    #   <category>:
    #     ...
    #     maximum: "1.2.3"
    #
    # The category and field must start their own (indented) line, so e.g.
    # "pin_maximum:" is never mistaken for "maximum:".
    pattern = (
        rf"((?:^|\n){re.escape(key)}:{in_section}*?"
        rf"\n[ \t]+{re.escape(category)}:{in_section}*?"
        rf"\n[ \t]+{match}: )"
        r'"[^"\n]*"'
    )
    # A callable replacement inserts new_version literally (no backslash/group escapes).
    return re.sub(pattern, lambda m: f'{m.group(1)}"{new_version}"', src, flags=re.DOTALL)
144  
145  
def extract_field(d, keys):
    """Follow ``keys`` into the nested mapping ``d`` and return the value found.

    Returns ``None`` if a key is missing, or if an intermediate value is not a mapping.
    For example, a YAML section declared with no children loads as ``None``.

    >>> extract_field({"a": {"b": 1}}, ("a", "b"))
    1
    >>> extract_field({"a": None}, ("a", "b")) is None
    True
    """
    from collections.abc import Mapping

    for key in keys:
        # Guard against None/str/etc. so a malformed section yields None, not TypeError.
        if not isinstance(d, Mapping) or key not in d:
            return None
        d = d[key]
    return d
153  
154  
def _get_autolog_flavor_module_map(config):
    """
    Map each flavor in ``config`` that has an ``autologging`` section to the module
    name imported for autologging. That is ``package_info["module_name"]`` when set,
    otherwise the flavor name itself.

    ``config`` maps flavor names to per-flavor dicts that each contain a
    ``package_info`` dict. Flavors keep their order from ``config``.
    """
    return {
        flavor: flavor_config["package_info"].get("module_name", flavor)
        for flavor, flavor_config in config.items()
        if "autologging" in flavor_config
    }
168  
169  
def _build_flavor_config(cfg):
    """Return the subset of one flavor's YAML entry that is exported to the generated file.

    The result always has a ``package_info["pip_release"]`` key. It also contains
    ``package_info["module_name"]``, ``models`` and ``autologging`` when the YAML entry
    provides them.
    """
    pip_release = extract_field(cfg, ("package_info", "pip_release"))
    module_name = extract_field(cfg, ("package_info", "module_name"))
    min_version = extract_field(cfg, ("models", "minimum"))
    max_version = extract_field(cfg, ("models", "maximum"))

    # json.dumps preserves insertion order, so keys are added in the order the generated
    # file has always used: package_info (pip_release, module_name), models, autologging.
    flavor_config = {"package_info": {"pip_release": pip_release}}
    if min_version:
        flavor_config["models"] = {"minimum": min_version, "maximum": max_version}
    if module_name:
        flavor_config["package_info"]["module_name"] = module_name

    autolog_min_version = extract_field(cfg, ("autologging", "minimum"))
    autolog_max_version = extract_field(cfg, ("autologging", "maximum"))
    # Autologging bounds are only exported when the package and both bounds are known.
    if None not in (pip_release, autolog_min_version, autolog_max_version):
        flavor_config["autologging"] = {
            "minimum": autolog_min_version,
            "maximum": autolog_max_version,
        }
    return flavor_config


def update_ml_package_versions_py(config_path):
    """Regenerate ``mlflow/ml_package_versions.py`` from the YAML file at ``config_path``.

    Flavors whose ``package_info.genai`` is truthy go into the GenAI mapping; all others
    go into the non-GenAI mapping. The output path is relative to the current working
    directory, so this is meant to be run from the repository root.
    """
    # Parse up front so the YAML file is not held open while the output is generated.
    with open(config_path, encoding="utf-8") as f:
        raw_config = yaml.load(f, Loader=yaml.SafeLoader)

    genai_config = {}
    non_genai_config = {}
    for name, cfg in raw_config.items():
        is_genai = extract_field(cfg, ("package_info", "genai"))
        target = genai_config if is_genai else non_genai_config
        target[name] = _build_flavor_config(cfg)

    genai_flavor_module_mapping = _get_autolog_flavor_module_map(genai_config)
    # We have "langgraph" entry in ml-package-versions.yml so that we can run test
    # against multiple versions of langgraph. However, we don't have a flavor for
    # langgraph and it is a part of the langchain flavor.
    genai_flavor_module_mapping.pop("langgraph", None)

    non_genai_flavor_module_mapping = _get_autolog_flavor_module_map(non_genai_config)
    # Add special case for pyspark.ml (non-GenAI)
    non_genai_flavor_module_mapping["pyspark.ml"] = "pyspark"

    this_file = Path(__file__).name
    dst = Path("mlflow", "ml_package_versions.py")

    config_str = json.dumps(genai_config | non_genai_config, indent=4)

    dst.write_text(
        f"""\
# This file was auto-generated by {this_file}.
# Please do not edit it manually.

_ML_PACKAGE_VERSIONS = {config_str}

# A mapping of flavor name to the module name to be imported for autologging.
# This is used for checking version compatibility in autologging.
# DO NOT EDIT MANUALLY

# GenAI packages
GENAI_FLAVOR_TO_MODULE_NAME = {json.dumps(genai_flavor_module_mapping, indent=4)}

# Non-GenAI packages
NON_GENAI_FLAVOR_TO_MODULE_NAME = {json.dumps(non_genai_flavor_module_mapping, indent=4)}

# Combined mapping for backward compatibility
FLAVOR_TO_MODULE_NAME = NON_GENAI_FLAVOR_TO_MODULE_NAME | GENAI_FLAVOR_TO_MODULE_NAME
""",
        encoding="utf-8",
    )
253  
254  
def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns an ``argparse.Namespace`` whose ``skip_yml`` attribute is True when
    ``--skip-yml`` was passed.
    """
    arg_parser = argparse.ArgumentParser(description="Update MLflow package versions")
    arg_parser.add_argument(
        "--skip-yml",
        action="store_true",
        help="Skip updating ml-package-versions.yml",
    )
    parsed = arg_parser.parse_args()
    return parsed
261  
262  
def get_min_supported_version(versions_infos: list[VersionInfo], genai: bool = False) -> str | None:
    """
    Return the earliest-uploaded version released within the support window.

    The window covers the past two years (2 * 365 days), or only the past year when
    ``genai`` is true. "Earliest" is by upload time, not by version number. Returns
    ``None`` when no version was uploaded inside the window.

    ``upload_time`` values are compared as naive datetimes. The reference "now" is
    therefore taken in UTC and stripped of its tzinfo. This assumes the naive upload
    times are UTC, as PyPI's ``upload_time_iso_8601`` field (with its ``Z`` suffix)
    suggests.
    """
    years = 1 if genai else 2
    # Naive UTC "now", consistent with the UTC cutoff used in get_package_version_infos.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    min_support_date = now_utc - timedelta(days=years * 365)

    # Keep only versions uploaded inside the support window.
    recent_versions = [v for v in versions_infos if v.upload_time > min_support_date]

    if not recent_versions:
        return None

    # Get minimum version according to upload date
    return min(recent_versions, key=lambda v: v.upload_time).version
279  
280  
def _update_flavor_versions(src, flavor_key, config):
    """Return the YAML text ``src`` with one flavor's minimum/maximum versions updated.

    ``config`` is the flavor's parsed YAML entry. Both the ``autologging`` and ``models``
    categories are processed.
    """
    package_name = config["package_info"]["pip_release"]
    genai = config["package_info"].get("genai", False)
    version_infos = get_package_version_infos(package_name)
    min_supported_version = get_min_supported_version(version_infos, genai=genai)

    for category in ["autologging", "models"]:
        print("Processing", flavor_key, category)

        if category in config and "minimum" in config[category]:
            old_min_version = config[category]["minimum"]
            if flavor_key == "spark":
                # We should support pyspark versions that are older than the cut off date.
                pass
            elif min_supported_version is None:
                # No eligible release was uploaded inside the support window, so
                # set the min version to be the same as the max version.
                max_ver = config[category]["maximum"]
                src = update_version(src, flavor_key, max_ver, category, update_max=False)
            elif Version(min_supported_version) > Version(old_min_version):
                src = update_version(
                    src, flavor_key, min_supported_version, category, update_max=False
                )

        if (category not in config) or config[category].get("pin_maximum", False):
            continue

        max_ver = config[category]["maximum"]
        unsupported = config[category].get("unsupported", [])
        # Exclude versions explicitly marked as unsupported.
        versions = {v.version for v in version_infos}.difference(unsupported)
        if not versions:
            # Every release was filtered out (pre-release, yanked, too recent, or
            # unsupported); max() would raise on the empty set, so skip this category.
            print(
                f"No candidate versions for {flavor_key} ({category}); keeping maximum",
                file=sys.stderr,
            )
            continue

        latest_version = get_latest_version(versions)
        if Version(latest_version) > Version(max_ver):
            src = update_version(src, flavor_key, latest_version, category, update_max=True)

    return src


def update(skip_yml=False):
    """Refresh the package version ranges, then regenerate the Python mapping file.

    Unless ``skip_yml`` is true, PyPI reachability is checked first. Each flavor's
    version bounds in ``mlflow/ml-package-versions.yml`` are then updated from PyPI.
    In all cases ``mlflow/ml_package_versions.py`` is regenerated from the YAML file.
    Paths are relative to the current working directory.
    """
    yml_path = "mlflow/ml-package-versions.yml"

    if not skip_yml:
        check_pypi_accessibility()
        old_src = read_file(yml_path)
        config_dict = yaml.load(old_src, Loader=yaml.SafeLoader)
        new_src = old_src
        for flavor_key, config in config_dict.items():
            # We currently don't have bandwidth to support newer versions of these flavors.
            if flavor_key in ["litellm"]:
                continue
            new_src = _update_flavor_versions(new_src, flavor_key, config)

        save_file(new_src, yml_path)

    update_ml_package_versions_py(yml_path)
340  
341  
def main():
    """Command-line entry point: parse the options, then run the update."""
    cli_options = parse_args()
    update(skip_yml=cli_options.skip_yml)


if __name__ == "__main__":
    # Run when executed as a script (see the module docstring for usage).
    main()