# dev/set_matrix.py
  1  """
  2  A script to set a matrix for the cross version tests for MLflow Models / autologging integrations.
  3  
  4  # Usage:
  5  
  6  ```
  7  # Test all items
  8  python dev/set_matrix.py
  9  
 10  # Exclude items for dev versions
 11  python dev/set_matrix.py --no-dev
 12  
 13  # Test items affected by config file updates
 14  python dev/set_matrix.py --ref-versions-yaml /path/to/ref-versions.yml
 15  
 16  # Test items affected by flavor module updates
 17  python dev/set_matrix.py --changed-files "mlflow/sklearn/__init__.py"
 18  
 19  # Test a specific flavor
 20  python dev/set_matrix.py --flavors sklearn
 21  
 22  # Test a specific version
 23  python dev/set_matrix.py --versions 1.1.1
 24  ```
 25  """
 26  
 27  import argparse
 28  import functools
 29  import json
 30  import os
 31  import re
 32  import shlex
 33  import shutil
 34  import sys
 35  import warnings
 36  from collections import defaultdict
 37  from datetime import datetime, timedelta, timezone
 38  from pathlib import Path
 39  from typing import Any, Iterator, TypeVar
 40  
 41  import requests
 42  import yaml
 43  from packaging.specifiers import SpecifierSet
 44  from packaging.version import InvalidVersion
 45  from packaging.version import Version as OriginalVersion
 46  from pydantic import BaseModel, ConfigDict, field_validator
 47  
# Path of the config file that defines the cross-version test matrix.
VERSIONS_YAML_PATH = "mlflow/ml-package-versions.yml"
# Sentinel version string meaning "the in-development (unreleased) version".
DEV_VERSION = "dev"
# Treat "dev" as "newer than any existing versions"
DEV_NUMERIC = "9999.9999.9999"

# Generic type variable used by `_find_matches`.
T = TypeVar("T")
 54  
 55  
class Version(OriginalVersion):
    """A ``packaging`` version that also understands the sentinel ``"dev"``.

    ``"dev"`` is mapped to ``DEV_NUMERIC`` ("9999.9999.9999") so it compares
    newer than any real release, while still printing as ``"dev"``.
    """

    def __init__(self, version: str, release_date: datetime | None = None):
        # Remember whether this is the sentinel so __str__ can round-trip it.
        self._is_dev = version == DEV_VERSION
        self._release_date = release_date
        super().__init__(DEV_NUMERIC if self._is_dev else version)

    def __str__(self):
        # Render the sentinel as "dev" rather than its numeric stand-in.
        return DEV_VERSION if self._is_dev else super().__str__()

    @classmethod
    def create_dev(cls):
        # Alternate constructor: the dev version, "released" right now.
        return cls(DEV_VERSION, datetime.now(timezone.utc))

    @property
    def days_since_release(self) -> int | None:
        """
        Compute the number of days since this version was released.
        Returns None if release date is not available.
        """
        if self._release_date is None:
            return None
        delta = datetime.now(timezone.utc) - self._release_date
        return delta.days
 79  
 80  
class PackageInfo(BaseModel):
    """Static metadata about a pip package under test (from ml-package-versions.yml)."""

    model_config = ConfigDict(extra="forbid")

    # Name of the package on PyPI (e.g. "scikit-learn").
    pip_release: str
    # Shell command that installs the development (unreleased) version, if any.
    install_dev: str | None = None
    # Importable module name, when it differs from `pip_release`.
    module_name: str | None = None
    # Whether this package is a GenAI integration.
    genai: bool = False
    # GitHub repository URL (".../tree/HEAD[/subdir]"), used to look up
    # pyproject.toml metadata for the dev version.
    repo: str | None = None
 89  
 90  
class TestConfig(BaseModel):
    """Per-category ("models" / "autologging") test settings for one flavor."""

    # Oldest package version to test.
    minimum: Version
    # Newest officially supported package version.
    maximum: Version
    # Version specifier sets (e.g. "< 1.2") to skip entirely.
    unsupported: list[SpecifierSet] | None = None
    # Extra pip requirements keyed by version specifier, e.g. {">= 1.4": ["numpy>=2.0"]}.
    requirements: dict[str, list[str]] | None = None
    # Python version overrides keyed by version specifier.
    python: dict[str, str] | None = None
    # GitHub Actions runner overrides keyed by version specifier.
    runs_on: dict[str, str] | None = None
    # Java version overrides keyed by version specifier.
    java: dict[str, str] | None = None
    # Shell command(s) that run the tests.
    run: str
    allow_unreleased_max_version: bool | None = None
    # Optional command to run before the tests.
    pre_test: str | None = None
    # When > 1, only every n-th minor version is tested.
    test_every_n_versions: int = 1
    # Also run a tracing-SDK job against the latest stable version.
    test_tracing_sdk: bool = False
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    @field_validator("minimum", mode="before")
    @classmethod
    def validate_minimum(cls, v):
        # Coerce the raw YAML string into our Version type.
        return Version(v)

    @field_validator("maximum", mode="before")
    @classmethod
    def validate_maximum(cls, v):
        # Coerce the raw YAML string into our Version type.
        return Version(v)

    @field_validator("unsupported", mode="before")
    @classmethod
    def validate_unsupported(cls, v):
        # Parse each raw specifier string; normalize falsy input to None.
        return [SpecifierSet(x) for x in v] if v else None

    @field_validator("python", mode="before")
    @classmethod
    def validate_python_requirements(cls, v):
        # Reject entries that merely restate the repository-wide minimum Python
        # version — they would be redundant noise in the config.
        if v is None:
            return v

        # Read the minimum Python version from .python-version file
        python_version_file = Path(".python-version")
        min_python_version = python_version_file.read_text().strip()

        # Check if any value in the python dict matches the minimum version
        for version in v.values():
            if version == min_python_version:
                raise ValueError(f"Unnecessary Python version requirement: {version}")

        return v
137  
138  
class FlavorConfig(BaseModel):
    """Per-flavor configuration: package metadata plus optional test configs."""

    model_config = ConfigDict(extra="forbid")

    package_info: PackageInfo
    models: TestConfig | None = None
    autologging: TestConfig | None = None

    @property
    def categories(self) -> list[tuple[str, TestConfig]]:
        """Return the configured (category name, config) pairs, models first."""
        pairs = (("models", self.models), ("autologging", self.autologging))
        return [(label, cfg) for label, cfg in pairs if cfg]
154  
155  
class MatrixItem(BaseModel):
    """A single job in the cross-version test matrix."""

    name: str
    flavor: str
    category: str
    job_name: str
    install: str
    run: str
    package: str
    version: Version
    python: str
    java: str
    supported: bool
    free_disk_space: bool
    runs_on: str
    pre_test: str | None = None
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    def __hash__(self):
        # Hash the (field, value) pairs. The previous implementation hashed
        # `frozenset(dict(self))`, i.e. a frozenset of the field *names* only,
        # which is identical for every MatrixItem and degrades set operations
        # to pairwise equality scans. All field values (str/bool/None/Version)
        # are hashable, so hashing the items is safe and stays consistent with
        # pydantic's field-based __eq__.
        return hash(frozenset(dict(self).items()))
175  
176  
def read_yaml(location, if_error=None):
    """Load a flavor-config YAML file into a {flavor name: FlavorConfig} dict.

    When loading or validation fails and `if_error` is given, print the error
    and return `if_error` instead of raising.
    """
    try:
        with open(location) as f:
            raw = yaml.safe_load(f)
        return {flavor: FlavorConfig(**spec) for flavor, spec in raw.items()}
    except Exception as e:
        if if_error is None:
            raise
        print(f"Failed to read '{location}' due to: `{e}`")
        return if_error
187  
188  
# Releases younger than this many days are considered too fresh to test.
RELEASE_CUTOFF_DAYS = 14


def get_released_versions(package_name: str) -> list[Version]:
    """Fetch the stable, non-yanked releases of `package_name` from PyPI.

    Excludes releases with no files, any yanked file, an unknown release date,
    a release date within `RELEASE_CUTOFF_DAYS`, unparsable version strings,
    and pre/dev releases.
    """
    data = pypi_json(package_name)
    cutoff = datetime.now(tz=timezone.utc) - timedelta(days=RELEASE_CUTOFF_DAYS)
    versions: list[Version] = []
    for version_str, distributions in data["releases"].items():
        # Skip releases without files and releases containing a yanked file.
        if len(distributions) == 0 or any(d.get("yanked", False) for d in distributions):
            continue

        # Extract the earliest upload time as the release date
        upload_times = [
            datetime.fromisoformat(ut.replace("Z", "+00:00"))
            for dist in distributions
            if (ut := dist.get("upload_time_iso_8601"))
        ]

        release_date = min(upload_times) if upload_times else None

        # Exclude versions with unknown release dates or released on/after the cutoff date
        if not release_date or release_date >= cutoff:
            continue

        try:
            version = Version(version_str, release_date)
        except InvalidVersion:
            # Ignore invalid versions such as https://pypi.org/project/pytz/2004d
            continue

        if version.is_devrelease or version.is_prerelease:
            continue

        versions.append(version)

    return versions
225  
226  
def get_latest_micro_versions(versions):
    """
    Returns the latest micro version in each minor version.

    The result is ordered from newest minor to oldest (descending scan order).
    """
    latest_by_minor = {}
    for candidate in sorted(versions, reverse=True):
        minor_key = candidate.release[:2]
        # The first candidate seen per (major, minor) is the highest micro.
        if minor_key not in latest_by_minor:
            latest_by_minor[minor_key] = candidate
    return list(latest_by_minor.values())
235  
236  
def filter_versions(
    flavor: str,
    versions: list[Version],
    min_ver: Version,
    max_ver: Version,
    unsupported: list[SpecifierSet],
    allow_unreleased_max_version: bool = False,
):
    """
    Returns the versions that satisfy the following conditions:
    1. Newer than or equal to `min_ver`.
    2. Older than or equal to `max_ver.major`.
    3. Not in `unsupported`.

    NOTE(review): `flavor` and `allow_unreleased_max_version` are currently
    unused in the body; they are kept for caller compatibility — confirm
    whether the flag should actually influence `_check_max`.
    """

    def _is_supported(v):
        # A version is unsupported when it falls into any excluded specifier set.
        for specified_set in unsupported:
            if v in specified_set:
                return False
        return True

    def _check_max(v: Version) -> bool:
        # Accept anything at or below the declared maximum, plus newer versions
        # within the same major version that are not brand-new releases.
        return v <= max_ver or (
            # Exclude versions uploaded very recently to avoid testing unstable or potentially
            # buggy releases. Newly released versions may have unresolved issues
            # (see: https://github.com/huggingface/transformers/issues/34370).
            v.major <= max_ver.major and v.days_since_release and v.days_since_release >= 1
        )

    def _check_min(v: Version) -> bool:
        return v >= min_ver

    return [v for v in versions if _check_min(v) and _check_max(v) and _is_supported(v)]
270  
271  
# Maps a changed file under mlflow/ or tests/ to its flavor name, e.g.
# "mlflow/sklearn/__init__.py" -> "sklearn", "tests/pytorch_autolog.py" -> "pytorch".
FLAVOR_FILE_PATTERN = re.compile(r"^(mlflow|tests)/(.+?)(_autolog(ging)?)?(\.py|/)")


def get_changed_flavors(changed_files, flavors):
    """
    Detects changed flavors from a list of changed files.
    """
    matches = (FLAVOR_FILE_PATTERN.match(path) for path in changed_files)
    return {m.group(2) for m in matches if m and m.group(2) in flavors}
285  
286  
def _find_matches(spec: dict[str, T], version: str) -> Iterator[T]:
    """
    Args:
        spec: A dictionary with key as version specifier and value as the corresponding value.
            For example, {"< 1.0.0": "numpy<2.0", ">= 1.0.0": "numpy>=2.0"}.
        version: The version to match against the specifiers.

    Returns:
        An iterator of values that match the version.
    """
    for specifier, val in spec.items():
        # Substitute the "dev" sentinel with its numeric stand-in (on both the
        # specifier and the queried version) so packaging can parse/compare it.
        specifier_set = SpecifierSet(specifier.replace(DEV_VERSION, DEV_NUMERIC))
        if specifier_set.contains(DEV_NUMERIC if version == DEV_VERSION else version):
            yield val
301  
302  
def get_matched_requirements(requirements, version=None):
    """Return the sorted, de-duplicated requirements whose specifier matches `version`."""
    if not isinstance(requirements, dict):
        raise TypeError(
            f"Invalid object type for `requirements`: '{type(requirements)}'. Must be dict."
        )
    matched = {req for packages in _find_matches(requirements, version) for req in packages}
    return sorted(matched)
312  
313  
def get_java_version(java: dict[str, str] | None, version: str) -> str:
    # Look up the Java version for `version` in the specifier map; default to 17.
    return _get_spec_value(java, version, "17")
316  
317  
@functools.lru_cache(maxsize=128)
def pypi_json(package: str) -> dict[str, Any]:
    """Fetch (and cache per process) the PyPI JSON metadata for `package`.

    Raises:
        requests.HTTPError: If the request fails (e.g. unknown package).
    """
    resp = requests.get(f"https://pypi.org/pypi/{package}/json")
    resp.raise_for_status()
    return resp.json()
323  
324  
def _requires_python(package: str, version: str) -> str | None:
    """Return the `requires_python` metadata of `package==version` on PyPI, if any."""
    releases = pypi_json(package).get("releases", {})
    # Scan the release's distribution files for the first declared constraint.
    for dist in releases.get(version, []):
        if rp := dist.get("requires_python"):
            return rp
    return None
335  
336  
def _requires_python_from_repo(repo_url: str) -> str | None:
    """
    Fetch requires-python from repository's pyproject.toml for dev version inference.

    Args:
        repo_url: A GitHub URL of the form "https://github.com/<owner>/<repo>/tree/HEAD[/<subdir>]".

    Returns:
        The `requires-python` value, or None when pyproject.toml or the field is missing.

    Raises:
        ValueError: If `repo_url` does not match the expected format.
        requests.HTTPError: On non-404 HTTP failures.
    """
    match = re.match(r"https://github\.com/([^/]+/[^/]+)/tree/HEAD(?:/(.+))?", repo_url)
    if not match:
        raise ValueError(f"Invalid GitHub repository URL format: {repo_url}")

    owner_repo = match.group(1)
    subpath = match.group(2) or ""
    pyproject_path = f"{subpath}/pyproject.toml" if subpath else "pyproject.toml"
    raw_url = f"https://raw.githubusercontent.com/{owner_repo}/HEAD/{pyproject_path}"

    print(f"Fetching pyproject.toml from {owner_repo} (path: {pyproject_path})", file=sys.stderr)

    try:
        resp = requests.get(raw_url, timeout=10)
        resp.raise_for_status()
    except requests.HTTPError as e:
        # A missing pyproject.toml is expected for some repos; only 404 is tolerated.
        if e.response.status_code == 404:
            print(f"  pyproject.toml not found at {raw_url}", file=sys.stderr)
            return None
        raise

    # Cheap regex extraction; avoids needing a TOML parser for one field.
    if match := re.search(r'requires-python\s*=\s*["\']([^"\']+)["\']', resp.text):
        print(f"  Found requires-python: {match.group(1)}", file=sys.stderr)
        return match.group(1)

    print("  requires-python field not found in pyproject.toml", file=sys.stderr)
    return None
367  
368  
def infer_python_version(package: str, version: str, repo_url: str | None = None) -> str:
    """
    Infer the minimum Python version required by the package.

    Returns the oldest candidate that satisfies the package's requires-python
    constraint, falling back to the oldest candidate when no metadata exists.
    """
    # Python versions we are willing to run the tests on, oldest first.
    candidates = ("3.10", "3.11")

    # For dev versions, PyPI has no metadata — consult the repo's pyproject.toml.
    if version == DEV_VERSION and repo_url:
        if rp := _requires_python_from_repo(repo_url):
            spec = SpecifierSet(rp)
            return next(filter(spec.contains, candidates), candidates[0])

    # Otherwise use the requires-python metadata published on PyPI.
    if rp := _requires_python(package, version):
        spec = SpecifierSet(rp)
        return next(filter(spec.contains, candidates), candidates[0])

    return candidates[0]
385  
386  
def _get_spec_value(spec: dict[str, str] | None, version: str, default: str) -> str:
    # Return the first value in `spec` whose specifier matches `version`, or
    # `default` when `spec` is None/empty or nothing matches.
    if spec and (match := next(_find_matches(spec, version), None)):
        return match
    return default
391  
392  
def get_python_version(
    python: dict[str, str] | None, package: str, version: str, repo_url: str | None = None
) -> str:
    # An explicit override in the config wins; otherwise infer from metadata.
    if python and (match := next(_find_matches(python, version), None)):
        return match

    return infer_python_version(package, version, repo_url)
400  
401  
def get_runs_on(runs_on: dict[str, str] | None, version: str) -> str:
    # Look up the runner label for `version`; default to ubuntu-latest.
    return _get_spec_value(runs_on, version, "ubuntu-latest")
404  
405  
def remove_comments(s):
    """Drop full-line '#' comments from a (multi-line) shell command string."""
    kept = []
    for line in s.strip().split("\n"):
        if not line.strip().startswith("#"):
            kept.append(line)
    return "\n".join(kept)
408  
409  
def make_pip_install_command(packages):
    """Build a `uv pip install` shell command, single-quoting each requirement."""
    quoted = " ".join("'{}'".format(pkg) for pkg in packages)
    return "uv pip install --system " + quoted
412  
413  
def divider(title, length=None):
    """Return `title` centered in a line of '=' padding, wrapped in blank lines."""
    # A falsy `length` (None or 0) falls back to the terminal width.
    length = length or shutil.get_terminal_size(fallback=(80, 24))[0]
    banner = f" {title} ".center(length, "=")
    return f"\n{banner}\n"
417  
418  
def split_by_comma(x):
    """Split a comma-separated string, trimming whitespace and dropping empties."""
    parts = []
    for raw in x.split(","):
        stripped = raw.strip()
        if stripped:
            parts.append(stripped)
    return parts
421  
422  
def parse_args(args):
    """Parse the command-line arguments for the matrix generator."""
    p = argparse.ArgumentParser(description="Set a test matrix for the cross version tests")
    p.add_argument(
        "--versions-yaml",
        default="mlflow/ml-package-versions.yml",
        help=(
            "URL or local file path of the config yaml. Defaults to "
            "'mlflow/ml-package-versions.yml' on the branch where this script is running."
        ),
    )
    p.add_argument(
        "--ref-versions-yaml",
        help=(
            "URL or local file path of the reference config yaml which will be compared with the "
            "config specified by `--versions-yaml` in order to identify the config updates."
        ),
    )
    p.add_argument(
        "--changed-files",
        # An empty string means "no changed files" rather than [""].
        type=lambda x: [] if x.strip() == "" else x.strip().split("\n"),
        help="A string that represents a list of changed files",
    )
    p.add_argument(
        "--flavors",
        type=split_by_comma,
        help=(
            "Comma-separated string specifying which flavors to test (e.g. 'sklearn, xgboost'). "
            "If unspecified, all flavors are tested."
        ),
    )
    p.add_argument(
        "--versions",
        type=split_by_comma,
        help=(
            "Comma-separated string specifying which versions to test (e.g. '1.2.3, 4.5.6'). "
            "If unspecified, all versions are tested."
        ),
    )
    p.add_argument(
        "--no-dev",
        action="store_true",
        help="If True, exclude dev versions in the test matrix.",
    )
    p.add_argument(
        "--only-latest",
        action="store_true",
        help=(
            "If True, only test the latest version of each group. Useful when you want to avoid "
            "running too many GitHub Action jobs."
        ),
    )
    return p.parse_args(args)
486  
487  
def get_flavor(name):
    """Map a config entry name to its test-directory flavor name."""
    aliases = {"pytorch-lightning": "pytorch"}
    return aliases.get(name, name)
490  
491  
def validate_test_coverage(flavor: str, config: FlavorConfig):
    """
    Validate that all test files for the flavor are executed in the cross-version tests.

    This is done by parsing `run` commands in the `ml-package-versions.yml` to get the list
    of executed test files, and then comparing it with the actual test files in the directory.
    """
    test_dir = os.path.join("tests", flavor)
    tested_files = set()

    for category, cfg in config.categories:
        if not cfg.run:
            continue

        # Consolidate multi-line commands with "\" to a single line
        commands = cfg.run.replace("\\\n", "").split("\n")

        # Parse pytest commands to get the executed test files
        for cmd in commands:
            cmd = cmd.strip().rstrip(";")
            if cmd.startswith("pytest"):
                tested_files |= _get_test_files_from_pytest_command(cmd, test_dir)

    if untested_files := _get_test_files(test_dir) - tested_files:
        # TODO: Update this after fixing ml-package-versions.yml to
        # have all test files in the matrix.
        # Warn (in red) instead of failing until the config covers everything.
        warnings.warn(
            f"Flavor '{flavor}' has test files that are not covered by the test matrix. \n"
            + "\n".join(f"\033[91m - {t}\033[0m" for t in untested_files)
            + f"\nPlease update {VERSIONS_YAML_PATH} to execute all test files. Note that this "
            "check does not handle complex syntax in test commands e.g. loop. It is generally "
            "recommended to use simple commands as we cannot test the test commands themselves."
        )
525  
526  
527  PYTEST_FILE_PATTERN = re.compile(r"^test_.*\.py$")
528  
529  
530  def _get_test_files(test_dir_or_path: str) -> set[Path]:
531      """List all test files in the given directory or file path."""
532      path = Path(test_dir_or_path)
533      if path.is_dir():
534          return set(path.rglob("test_*.py"))
535  
536      if PYTEST_FILE_PATTERN.match(path.name):
537          return {path}
538  
539      return set()
540  
541  
def _get_test_files_from_pytest_command(cmd, test_dir):
    """Extract the test files under `test_dir` that a `pytest ...` command runs."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--ignore", action="append")
    cli.add_argument("paths", nargs="*")
    # Ignore unknown pytest flags; only positional paths and --ignore matter.
    opts = cli.parse_known_args(shlex.split(cmd))[0]

    executed = set()
    for path in opts.paths:
        if path.startswith(test_dir):
            executed |= _get_test_files(path)

    ignored = set()
    for path in opts.ignore or []:
        if path.startswith(test_dir):
            ignored |= _get_test_files(path)

    return executed - ignored
557  
558  
def validate_requirements(
    requirements: dict[str, list[str]],
    name: str,
    category: str,
    package_info: PackageInfo,
    versions: list[Version],
) -> None:
    """
    Validate that the requirements specified in the config don't contain unused items.
    Here's an example of invalid requirements:

    ```
    sklearn:
        package_info:
            pip_release: "scikit-learn"
        autologging:
            minimum: "1.3.0"
            maximum: "1.5.0"
            requirements:
                "< 1.0.0": ["numpy<2.0"]    # Unused
                ">= 1.4.0": ["numpy>=2.0"]  # Used
    ```

    Raises:
        ValueError: If a specifier matches none of `versions`.
    """
    for specifier in requirements:
        # Specifiers mentioning "dev" are exempt when a dev install command
        # exists, since `versions` only contains released versions.
        if "dev" in specifier and package_info.install_dev:
            continue

        # Does this version specifier (e.g. '< 1.0.0') match at least one version?
        # If not, raise an error.
        spec_set = SpecifierSet(specifier)
        if not any(map(spec_set.contains, versions)):
            raise ValueError(
                f"Found unused requirements {specifier!r} for {name} / {category}. "
                "Please remove it or adjust the version specifier."
            )
594  
595  
def expand_config(config: dict[str, Any], *, is_ref: bool = False) -> set[MatrixItem]:
    """Expand a {flavor name: FlavorConfig} mapping into the set of MatrixItems.

    Args:
        config: Parsed ml-package-versions.yml content.
        is_ref: True when expanding the reference config; skips requirement
            validation, which only needs to run once on the current config.
    """
    matrix = set()
    for name, flavor_config in config.items():
        flavor = get_flavor(name)
        package_info = flavor_config.package_info
        all_versions = get_released_versions(package_info.pip_release)
        # These packages download large artifacts; free runner disk space first.
        free_disk_space = package_info.pip_release in (
            "transformers",
            "sentence-transformers",
            "torch",
        )
        validate_test_coverage(name, flavor_config)
        for category, cfg in flavor_config.categories:
            versions = filter_versions(
                flavor,
                all_versions,
                cfg.minimum,
                cfg.maximum,
                cfg.unsupported or [],
                allow_unreleased_max_version=cfg.allow_unreleased_max_version or False,
            )
            versions = get_latest_micro_versions(versions)

            # Test every n minor versions if specified
            if cfg.test_every_n_versions > 1:
                versions = sorted(versions)[:: -cfg.test_every_n_versions][::-1]

            # Always test the minimum version
            if cfg.minimum not in versions and cfg.minimum in all_versions:
                versions.append(cfg.minimum)

            if not is_ref and cfg.requirements:
                validate_requirements(cfg.requirements, name, category, package_info, versions)

            for ver in versions:
                requirements = [f"{package_info.pip_release}=={ver}"]
                requirements.extend(get_matched_requirements(cfg.requirements or {}, str(ver)))
                install = make_pip_install_command(requirements)
                run = remove_comments(cfg.run)
                python = get_python_version(
                    cfg.python, package_info.pip_release, str(ver), package_info.repo
                )
                # Pass str(ver) for consistency with the other spec lookups.
                runs_on = get_runs_on(cfg.runs_on, str(ver))
                java = get_java_version(cfg.java, str(ver))

                matrix.add(
                    MatrixItem(
                        name=name,
                        flavor=flavor,
                        category=category,
                        job_name=f"{name} / {category} / {ver}",
                        install=install,
                        run=run,
                        package=package_info.pip_release,
                        version=ver,
                        python=python,
                        java=java,
                        supported=ver <= cfg.maximum,
                        free_disk_space=free_disk_space,
                        runs_on=runs_on,
                        pre_test=cfg.pre_test,
                    )
                )

            # Add tracing SDK test with the latest stable version
            if len(versions) > 0 and category == "autologging" and cfg.test_tracing_sdk:
                version = max(versions)  # Test against the latest stable version
                # BUGFIX: compute install/python/runs_on/java for `version`
                # explicitly. The previous code reused the variables left over
                # from the last iteration of the loop above, which may belong to
                # a different version (e.g. the appended minimum version).
                requirements = [f"{package_info.pip_release}=={version}"]
                requirements.extend(
                    get_matched_requirements(cfg.requirements or {}, str(version))
                )
                install = make_pip_install_command(requirements)
                run = remove_comments(cfg.run)
                python = get_python_version(
                    cfg.python, package_info.pip_release, str(version), package_info.repo
                )
                runs_on = get_runs_on(cfg.runs_on, str(version))
                java = get_java_version(cfg.java, str(version))
                matrix.add(
                    MatrixItem(
                        name=f"{name}-tracing",
                        flavor=flavor,
                        category="tracing-sdk",
                        job_name=f"{name} / tracing-sdk / {version}",
                        install=install,
                        # --import-mode=importlib is required for testing tracing SDK
                        # (mlflow-tracing) works properly, without being affected by environment.
                        run=run.replace("pytest", "pytest --import-mode=importlib"),
                        package=package_info.pip_release,
                        version=version,
                        java=java,
                        supported=version <= cfg.maximum,
                        free_disk_space=free_disk_space,
                        python=python,
                        runs_on=runs_on,
                    )
                )

            # Skip dev version testing: install_dev installs from git, which
            # doesn't respect UV_EXCLUDE_NEWER.
            if False:  # package_info.install_dev:
                install_dev = remove_comments(package_info.install_dev)
                if requirements := get_matched_requirements(cfg.requirements or {}, DEV_VERSION):
                    install = make_pip_install_command(requirements) + "\n" + install_dev
                else:
                    install = install_dev
                python = get_python_version(
                    cfg.python, package_info.pip_release, DEV_VERSION, package_info.repo
                )
                runs_on = get_runs_on(cfg.runs_on, DEV_VERSION)
                java = get_java_version(cfg.java, DEV_VERSION)

                run = remove_comments(cfg.run)
                dev_version = Version.create_dev()
                matrix.add(
                    MatrixItem(
                        name=name,
                        flavor=flavor,
                        category=category,
                        job_name=f"{name} / {category} / {dev_version}",
                        install=install,
                        run=run,
                        package=package_info.pip_release,
                        version=dev_version,
                        python=python,
                        java=java,
                        supported=False,
                        free_disk_space=free_disk_space,
                        runs_on=runs_on,
                        pre_test=cfg.pre_test,
                    )
                )
    return matrix
718  
719  
def apply_changed_files(changed_files, matrix):
    """Narrow `matrix` down to the items whose flavor is affected by `changed_files`."""
    all_flavors = {item.flavor for item in matrix}
    if str(Path(__file__).relative_to(Path.cwd())) in changed_files:
        # A change to this script itself invalidates every flavor.
        affected = set(all_flavors)
    else:
        affected = get_changed_flavors(changed_files, all_flavors)

    # Run langchain tests if any tracing files have been changed
    if any(f.startswith("mlflow/tracing/") for f in changed_files):
        affected.add("langchain")

    return {item for item in matrix if item.flavor in affected}
734  
735  
def generate_matrix(args):
    """Build the final set of MatrixItems according to the CLI arguments."""
    args = parse_args(args)
    config = read_yaml(args.versions_yaml)
    # Without a reference config or a changed-file list, test everything.
    if (args.ref_versions_yaml, args.changed_files).count(None) == 2:
        matrix = expand_config(config)
    else:
        matrix = set()
        mat = expand_config(config)

        if args.ref_versions_yaml:
            # Keep only items that differ from the reference (i.e. config updates).
            ref_config = read_yaml(args.ref_versions_yaml, if_error={})
            ref_matrix = expand_config(ref_config, is_ref=True)
            matrix.update(mat.difference(ref_matrix))

        if args.changed_files:
            matrix.update(apply_changed_files(args.changed_files, mat))

    # Apply the filtering arguments
    # NOTE: `filter(...)` is lazy; the chained filters are consumed exactly once
    # by the `only_latest` loop or the final `set(matrix)`.
    if args.no_dev:
        matrix = filter(lambda x: x.version != Version.create_dev(), matrix)

    if args.flavors:
        matrix = filter(lambda x: x.flavor in args.flavors, matrix)

    if args.versions:
        matrix = filter(lambda x: x.version in map(Version, args.versions), matrix)

    if args.only_latest:
        # Keep only the newest version within each (name, category) group.
        groups = defaultdict(list)
        for item in matrix:
            groups[(item.name, item.category)].append(item)
        matrix = {max(group, key=lambda x: x.version) for group in groups.values()}

    return set(matrix)
770  
771  
class CustomEncoder(json.JSONEncoder):
    """JSON encoder that knows how to serialize MatrixItem and Version objects."""

    def default(self, o):
        if isinstance(o, MatrixItem):
            # Serialize the pydantic model as a plain dict, dropping None fields.
            return o.model_dump(exclude_none=True)
        if isinstance(o, Version):
            # Versions serialize to their string form ("1.2.3" or "dev").
            return str(o)
        return super().default(o)
779  
780  
def set_action_output(name, value):
    """Append a `name=value` line to the GitHub Actions step-output file.

    Raises:
        KeyError: If the `GITHUB_OUTPUT` environment variable is not set.
            (Previously `os.environ.get(...)` returned None and surfaced as an
            opaque `TypeError` from `open(None)`.)
    """
    with open(os.environ["GITHUB_OUTPUT"], "a") as f:
        f.write(f"{name}={value}\n")
784  
785  
def split(matrix, n):
    """Yield roughly `n` chunks of `matrix`, never splitting items that share a name."""
    by_name = defaultdict(list)
    for item in matrix:
        by_name[item.name].append(item)

    # Target chunk size; groups are appended whole, so chunks may exceed it.
    target = len(matrix) // n
    pending = []
    for group in by_name.values():
        pending += group
        if len(pending) >= target:
            yield pending
            pending = []

    if pending:
        yield pending
801  
802  
def main(args):
    """Generate the cross-version test matrices and publish them as action outputs."""
    # https://docs.github.com/en/actions/learn-github-actions/usage-limits-billing-and-administration#usage-limits
    # > A job matrix can generate a maximum of 256 jobs per workflow run.
    MAX_ITEMS = 256
    NUM_JOBS = 2

    print(divider("Parameters"))
    print(json.dumps(args, indent=2))
    matrix = generate_matrix(args)
    matrix = sorted(matrix, key=lambda x: (x.name, x.category, x.version))
    # The matrix is split across NUM_JOBS workflows, so the total budget is
    # MAX_ITEMS * NUM_JOBS. (The check previously hard-coded `* 2` while the
    # message used NUM_JOBS; keep the two consistent.)
    assert len(matrix) <= MAX_ITEMS * NUM_JOBS, (
        f"Too many jobs: {len(matrix)} > {MAX_ITEMS * NUM_JOBS}"
    )
    for idx, mat in enumerate(split(matrix, NUM_JOBS), start=1):
        payload = {"include": mat, "job_name": [x.job_name for x in mat]}
        print(divider(f"Matrix {idx}"))
        print(json.dumps(payload, indent=2, cls=CustomEncoder))
        if "GITHUB_ACTIONS" in os.environ:
            set_action_output(f"matrix{idx}", json.dumps(payload, cls=CustomEncoder))
            # Emptiness must be judged by the number of matrix items; the
            # previous `len(mat) == 0` tested the two-key payload dict and was
            # therefore always "false".
            set_action_output(
                f"is_matrix{idx}_empty", "true" if len(payload["include"]) == 0 else "false"
            )
821  
822  
if __name__ == "__main__":
    # Forward the CLI arguments (excluding the program name) to the entry point.
    main(sys.argv[1:])