"""
A script to set a matrix for the cross version tests for MLflow Models / autologging integrations.

# Usage:

```
# Test all items
python dev/set_matrix.py

# Exclude items for dev versions
python dev/set_matrix.py --no-dev

# Test items affected by config file updates
python dev/set_matrix.py --ref-versions-yaml /path/to/ref-versions.yml

# Test items affected by flavor module updates
python dev/set_matrix.py --changed-files "mlflow/sklearn/__init__.py"

# Test a specific flavor
python dev/set_matrix.py --flavors sklearn

# Test a specific version
python dev/set_matrix.py --versions 1.1.1
```
"""

import argparse
import functools
import json
import os
import re
import shlex
import shutil
import sys
import warnings
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Iterator, TypeVar

import requests
import yaml
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion
from packaging.version import Version as OriginalVersion
from pydantic import BaseModel, ConfigDict, field_validator

VERSIONS_YAML_PATH = "mlflow/ml-package-versions.yml"
DEV_VERSION = "dev"
# Treat "dev" as "newer than any existing versions"
DEV_NUMERIC = "9999.9999.9999"

T = TypeVar("T")


class Version(OriginalVersion):
    """A `packaging.version.Version` that additionally understands the sentinel
    string "dev" (mapped to `DEV_NUMERIC` so it sorts above all real releases)
    and optionally carries the release date of the version.
    """

    def __init__(self, version: str, release_date: datetime | None = None):
        self._is_dev = version == DEV_VERSION
        self._release_date = release_date
        super().__init__(DEV_NUMERIC if self._is_dev else version)

    def __str__(self):
        # Render the sentinel back as "dev" rather than the numeric stand-in.
        return DEV_VERSION if self._is_dev else super().__str__()

    @classmethod
    def create_dev(cls):
        """Create the "dev" sentinel version, dated now (UTC)."""
        return cls(DEV_VERSION, datetime.now(timezone.utc))

    @property
    def days_since_release(self) -> int | None:
        """
        Compute the number of days since this version was released.
        Returns None if release date is not available.
        """
        if self._release_date is None:
            return None
        delta = datetime.now(timezone.utc) - self._release_date
        return delta.days


class PackageInfo(BaseModel):
    """Static metadata about a pip package under test (from ml-package-versions.yml)."""

    model_config = ConfigDict(extra="forbid")

    # Name of the package on PyPI.
    pip_release: str
    # Command to install the development (unreleased) version, if any.
    install_dev: str | None = None
    # Importable module name when it differs from `pip_release`.
    module_name: str | None = None
    genai: bool = False
    # GitHub repository URL (used to infer requires-python for dev versions).
    repo: str | None = None


class TestConfig(BaseModel):
    """Per-category (models / autologging) test configuration for one flavor."""

    minimum: Version
    maximum: Version
    unsupported: list[SpecifierSet] | None = None
    # Maps a version specifier (e.g. ">= 1.4.0") to extra pip requirements.
    requirements: dict[str, list[str]] | None = None
    # Maps a version specifier to the Python version to test with.
    python: dict[str, str] | None = None
    # Maps a version specifier to the GitHub Actions runner label.
    runs_on: dict[str, str] | None = None
    # Maps a version specifier to the Java version to test with.
    java: dict[str, str] | None = None
    run: str
    allow_unreleased_max_version: bool | None = None
    pre_test: str | None = None
    test_every_n_versions: int = 1
    test_tracing_sdk: bool = False
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    @field_validator("minimum", mode="before")
    @classmethod
    def validate_minimum(cls, v):
        return Version(v)

    @field_validator("maximum", mode="before")
    @classmethod
    def validate_maximum(cls, v):
        return Version(v)

    @field_validator("unsupported", mode="before")
    @classmethod
    def validate_unsupported(cls, v):
        return [SpecifierSet(x) for x in v] if v else None

    @field_validator("python", mode="before")
    @classmethod
    def validate_python_requirements(cls, v):
        """Reject `python` entries that merely restate the repository-wide minimum."""
        if v is None:
            return v

        # Read the minimum Python version from .python-version file
        python_version_file = Path(".python-version")
        min_python_version = python_version_file.read_text().strip()

        # Check if any value in the python dict matches the minimum version
        for version in v.values():
            if version == min_python_version:
                raise ValueError(f"Unnecessary Python version requirement: {version}")

        return v


class FlavorConfig(BaseModel):
    """One flavor's entry in ml-package-versions.yml: package info plus the
    optional "models" and "autologging" test configurations."""

    model_config = ConfigDict(extra="forbid")

    package_info: PackageInfo
    models: TestConfig | None = None
    autologging: TestConfig | None = None

    @property
    def categories(self) -> list[tuple[str, TestConfig]]:
        """The configured (category name, config) pairs, in a fixed order."""
        cs = []
        if self.models:
            cs.append(("models", self.models))
        if self.autologging:
            cs.append(("autologging", self.autologging))
        return cs


class MatrixItem(BaseModel):
    """A single job in the generated GitHub Actions test matrix."""

    name: str
    flavor: str
    category: str
    job_name: str
    install: str
    run: str
    package: str
    version: Version
    python: str
    java: str
    supported: bool
    free_disk_space: bool
    runs_on: str
    pre_test: str | None = None
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    def __hash__(self):
        # Hash the (field, value) pairs. Iterating `dict(self)` directly (as the
        # previous implementation did) yields only the field *names*, which are
        # the same for every item — every MatrixItem hashed equal and set
        # operations degenerated into linear equality scans.
        return hash(frozenset(dict(self).items()))


def read_yaml(location, if_error=None):
    """Load a versions yaml into ``{flavor_name: FlavorConfig}``.

    On failure, return `if_error` (after printing the error) when it is
    provided; otherwise re-raise.
    """
    try:
        with open(location) as f:
            yaml_dict = yaml.safe_load(f)
        return {name: FlavorConfig(**cfg) for name, cfg in yaml_dict.items()}
    except Exception as e:
        if if_error is not None:
            print(f"Failed to read '{location}' due to: `{e}`")
            return if_error
        raise


# Releases younger than this many days are excluded as potentially unstable.
RELEASE_CUTOFF_DAYS = 14


def get_released_versions(package_name: str) -> list[Version]:
    """Fetch the stable, non-yanked versions of `package_name` from PyPI that
    were released at least `RELEASE_CUTOFF_DAYS` days ago."""
    data = pypi_json(package_name)
    cutoff = datetime.now(tz=timezone.utc) - timedelta(days=RELEASE_CUTOFF_DAYS)
    versions: list[Version] = []
    for version_str, distributions in data["releases"].items():
        if len(distributions) == 0 or any(d.get("yanked", False) for d in distributions):
            continue

        # Extract the earliest upload time as the release date
        upload_times = [
            datetime.fromisoformat(ut.replace("Z", "+00:00"))
            for dist in distributions
            if (ut := dist.get("upload_time_iso_8601"))
        ]

        release_date = min(upload_times) if upload_times else None

        # Exclude versions with unknown release dates or released on/after the cutoff date
        if not release_date or release_date >= cutoff:
            continue

        try:
            version = Version(version_str, release_date)
        except InvalidVersion:
            # Ignore invalid versions such as https://pypi.org/project/pytz/2004d
            continue

        if version.is_devrelease or version.is_prerelease:
            continue

        versions.append(version)

    return versions


def get_latest_micro_versions(versions):
    """
    Returns the latest micro version in each minor version.
    """
    by_minor = {}
    # Iterate newest-first; setdefault keeps only the first (i.e. latest) micro
    # version seen for each (major, minor) pair.
    for ver in sorted(versions, reverse=True):
        by_minor.setdefault(ver.release[:2], ver)
    return list(by_minor.values())


def filter_versions(
    flavor: str,
    versions: list[Version],
    min_ver: Version,
    max_ver: Version,
    unsupported: list[SpecifierSet],
    allow_unreleased_max_version: bool = False,
):
    """
    Returns the versions that satisfy the following conditions:
    1. Newer than or equal to `min_ver`.
    2. Older than or equal to `max_ver.major`.
    3. Not in `unsupported`.

    NOTE(review): `flavor` and `allow_unreleased_max_version` are accepted but
    currently unused — kept for interface compatibility with callers; confirm
    whether `allow_unreleased_max_version` was meant to relax `_check_max`.
    """

    def _is_supported(v):
        for specified_set in unsupported:
            if v in specified_set:
                return False
        return True

    def _check_max(v: Version) -> bool:
        return v <= max_ver or (
            # Exclude versions uploaded very recently to avoid testing unstable or potentially
            # buggy releases. Newly released versions may have unresolved issues
            # (see: https://github.com/huggingface/transformers/issues/34370).
            v.major <= max_ver.major and v.days_since_release and v.days_since_release >= 1
        )

    def _check_min(v: Version) -> bool:
        return v >= min_ver

    return [v for v in versions if _check_min(v) and _check_max(v) and _is_supported(v)]


# Matches e.g. "mlflow/sklearn/__init__.py" or "tests/sklearn_autologging/..." and
# captures the flavor name in group 2.
FLAVOR_FILE_PATTERN = re.compile(r"^(mlflow|tests)/(.+?)(_autolog(ging)?)?(\.py|/)")


def get_changed_flavors(changed_files, flavors):
    """
    Detects changed flavors from a list of changed files.
    """
    changed_flavors = set()
    for f in changed_files:
        match = FLAVOR_FILE_PATTERN.match(f)
        if match and match.group(2) in flavors:
            changed_flavors.add(match.group(2))
    return changed_flavors


def _find_matches(spec: dict[str, T], version: str) -> Iterator[T]:
    """
    Args:
        spec: A dictionary with key as version specifier and value as the corresponding value.
            For example, {"< 1.0.0": "numpy<2.0", ">= 1.0.0": "numpy>=2.0"}.
        version: The version to match against the specifiers.

    Returns:
        An iterator of values that match the version.
    """
    for specifier, val in spec.items():
        # Substitute the "dev" sentinel on both sides so it compares as the
        # numerically-largest version.
        specifier_set = SpecifierSet(specifier.replace(DEV_VERSION, DEV_NUMERIC))
        if specifier_set.contains(DEV_NUMERIC if version == DEV_VERSION else version):
            yield val


def get_matched_requirements(requirements, version=None):
    """Collect (sorted, de-duplicated) extra requirements whose specifier matches `version`."""
    if not isinstance(requirements, dict):
        raise TypeError(
            f"Invalid object type for `requirements`: '{type(requirements)}'. Must be dict."
        )
    reqs = set()
    for packages in _find_matches(requirements, version):
        reqs.update(packages)
    return sorted(reqs)


def get_java_version(java: dict[str, str] | None, version: str) -> str:
    """Resolve the Java version for `version`, defaulting to 17."""
    return _get_spec_value(java, version, "17")


@functools.lru_cache(maxsize=128)
def pypi_json(package: str) -> dict[str, Any]:
    """Fetch (and cache) the PyPI JSON metadata for `package`."""
    resp = requests.get(f"https://pypi.org/pypi/{package}/json")
    resp.raise_for_status()
    return resp.json()


def _requires_python(package: str, version: str) -> str | None:
    """Return the `requires_python` specifier PyPI advertises for this exact
    release, or None if absent."""
    package_json = pypi_json(package)
    for ver, dist in package_json.get("releases", {}).items():
        if ver != version:
            continue

        for d in dist:
            if rp := d.get("requires_python"):
                return rp
    return None


def _requires_python_from_repo(repo_url: str) -> str | None:
    """
    Fetch requires-python from repository's pyproject.toml for dev version inference.
    """
    match = re.match(r"https://github\.com/([^/]+/[^/]+)/tree/HEAD(?:/(.+))?", repo_url)
    if not match:
        raise ValueError(f"Invalid GitHub repository URL format: {repo_url}")

    owner_repo = match.group(1)
    subpath = match.group(2) or ""
    pyproject_path = f"{subpath}/pyproject.toml" if subpath else "pyproject.toml"
    raw_url = f"https://raw.githubusercontent.com/{owner_repo}/HEAD/{pyproject_path}"

    print(f"Fetching pyproject.toml from {owner_repo} (path: {pyproject_path})", file=sys.stderr)

    try:
        resp = requests.get(raw_url, timeout=10)
        resp.raise_for_status()
    except requests.HTTPError as e:
        if e.response.status_code == 404:
            print(f"  pyproject.toml not found at {raw_url}", file=sys.stderr)
            return None
        raise

    if match := re.search(r'requires-python\s*=\s*["\']([^"\']+)["\']', resp.text):
        print(f"  Found requires-python: {match.group(1)}", file=sys.stderr)
        return match.group(1)

    print("  requires-python field not found in pyproject.toml", file=sys.stderr)
    return None


def infer_python_version(package: str, version: str, repo_url: str | None = None) -> str:
    """
    Infer the minimum Python version required by the package.
    """
    # Candidate Python versions to test with, lowest first; the first candidate
    # accepted by the package's requires-python specifier wins.
    candidates = ("3.10", "3.11")

    if version == DEV_VERSION and repo_url:
        if rp := _requires_python_from_repo(repo_url):
            spec = SpecifierSet(rp)
            return next(filter(spec.contains, candidates), candidates[0])

    if rp := _requires_python(package, version):
        spec = SpecifierSet(rp)
        return next(filter(spec.contains, candidates), candidates[0])

    return candidates[0]


def _get_spec_value(spec: dict[str, str] | None, version: str, default: str) -> str:
    """Return the first value in `spec` whose specifier matches `version`, else `default`."""
    if spec and (match := next(_find_matches(spec, version), None)):
        return match
    return default


def get_python_version(
    python: dict[str, str] | None, package: str, version: str, repo_url: str | None = None
) -> str:
    """Resolve the Python version for `version`: explicit config wins, otherwise
    infer it from PyPI / the repository's pyproject.toml."""
    if python and (match := next(_find_matches(python, version), None)):
        return match

    return infer_python_version(package, version, repo_url)


def get_runs_on(runs_on: dict[str, str] | None, version: str) -> str:
    """Resolve the runner label for `version`, defaulting to ubuntu-latest."""
    return _get_spec_value(runs_on, version, "ubuntu-latest")


def remove_comments(s):
    """Strip lines that are pure comments from a multi-line command string."""
    return "\n".join(l for l in s.strip().split("\n") if not l.strip().startswith("#"))


def make_pip_install_command(packages):
    """Build a single `uv pip install` command for the given requirement strings."""
    return "uv pip install --system " + " ".join(f"'{x}'" for x in packages)


def divider(title, length=None):
    """Return a centered `=== title ===` divider sized to the terminal width."""
    length = length or shutil.get_terminal_size(fallback=(80, 24))[0]
    return "\n" + f" {title} ".center(length, "=") + "\n"


def split_by_comma(x):
    """Split a comma-separated string, dropping empty items and surrounding whitespace."""
    return [s for item in x.split(",") if (s := item.strip())]


def parse_args(args):
    parser = argparse.ArgumentParser(description="Set a test matrix for the cross version tests")
    parser.add_argument(
        "--versions-yaml",
        required=False,
        default=VERSIONS_YAML_PATH,
        help=(
            "URL or local file path of the config yaml. Defaults to "
            "'mlflow/ml-package-versions.yml' on the branch where this script is running."
        ),
    )
    parser.add_argument(
        "--ref-versions-yaml",
        required=False,
        default=None,
        help=(
            "URL or local file path of the reference config yaml which will be compared with the "
            "config specified by `--versions-yaml` in order to identify the config updates."
        ),
    )
    parser.add_argument(
        "--changed-files",
        type=lambda x: [] if x.strip() == "" else x.strip().split("\n"),
        required=False,
        default=None,
        help=("A string that represents a list of changed files"),
    )

    parser.add_argument(
        "--flavors",
        required=False,
        type=split_by_comma,
        help=(
            "Comma-separated string specifying which flavors to test (e.g. 'sklearn, xgboost'). "
            "If unspecified, all flavors are tested."
        ),
    )
    parser.add_argument(
        "--versions",
        required=False,
        type=split_by_comma,
        help=(
            "Comma-separated string specifying which versions to test (e.g. '1.2.3, 4.5.6'). "
            "If unspecified, all versions are tested."
        ),
    )
    parser.add_argument(
        "--no-dev",
        action="store_true",
        default=False,
        help="If True, exclude dev versions in the test matrix.",
    )
    parser.add_argument(
        "--only-latest",
        action="store_true",
        default=False,
        help=(
            "If True, only test the latest version of each group. Useful when you want to avoid "
            "running too many GitHub Action jobs."
        ),
    )

    return parser.parse_args(args)


def get_flavor(name):
    """Map a config entry name to its flavor (some entries share a flavor)."""
    return {"pytorch-lightning": "pytorch"}.get(name, name)


def validate_test_coverage(flavor: str, config: FlavorConfig):
    """
    Validate that all test files for the flavor are executed in the cross-version tests.

    This is done by parsing `run` commands in the `ml-package-versions.yml` to get the list
    of executed test files, and then comparing it with the actual test files in the directory.
    """
    test_dir = os.path.join("tests", flavor)
    tested_files = set()

    for category, cfg in config.categories:
        if not cfg.run:
            continue

        # Consolidate multi-line commands with "\" to a single line
        commands = cfg.run.replace("\\\n", "").split("\n")

        # Parse pytest commands to get the executed test files
        for cmd in commands:
            cmd = cmd.strip().rstrip(";")
            if cmd.startswith("pytest"):
                tested_files |= _get_test_files_from_pytest_command(cmd, test_dir)

    if untested_files := _get_test_files(test_dir) - tested_files:
        # TODO: Update this after fixing ml-package-versions.yml to
        # have all test files in the matrix.
        warnings.warn(
            f"Flavor '{flavor}' has test files that are not covered by the test matrix. \n"
            + "\n".join(f"\033[91m - {t}\033[0m" for t in untested_files)
            + f"\nPlease update {VERSIONS_YAML_PATH} to execute all test files. Note that this "
            "check does not handle complex syntax in test commands e.g. loop. It is generally "
            "recommended to use simple commands as we cannot test the test commands themselves."
        )


PYTEST_FILE_PATTERN = re.compile(r"^test_.*\.py$")


def _get_test_files(test_dir_or_path: str) -> set[Path]:
    """List all test files in the given directory or file path."""
    path = Path(test_dir_or_path)
    if path.is_dir():
        return set(path.rglob("test_*.py"))

    if PYTEST_FILE_PATTERN.match(path.name):
        return {path}

    return set()


def _get_test_files_from_pytest_command(cmd, test_dir):
    """Extract the test files a pytest command executes (positional paths minus --ignore paths)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ignore", action="append")
    parser.add_argument("paths", nargs="*")
    args = parser.parse_known_args(shlex.split(cmd))[0]

    executed_files = set()
    ignore_files = set()
    for path in args.paths:
        if path.startswith(test_dir):
            executed_files |= _get_test_files(path)
    for ignore_path in args.ignore or []:
        if ignore_path.startswith(test_dir):
            ignore_files |= _get_test_files(ignore_path)
    return executed_files - ignore_files


def validate_requirements(
    requirements: dict[str, list[str]],
    name: str,
    category: str,
    package_info: PackageInfo,
    versions: list[Version],
) -> None:
    """
    Validate that the requirements specified in the config don't contain unused items.
    Here's an example of invalid requirements:

    ```
    sklearn:
      package_info:
        pip_release: "scikit-learn"
      autologging:
        minimum: "1.3.0"
        maximum: "1.5.0"
        requirements:
          "< 1.0.0": ["numpy<2.0"]  # Unused
          ">= 1.4.0": ["numpy>=2.0"]  # Used
    ```
    """
    for specifier in requirements:
        # "dev" specifiers are only exercised when a dev install is configured.
        if "dev" in specifier and package_info.install_dev:
            continue

        # Does this version specifier (e.g. '< 1.0.0') match at least one version?
        # If not, raise an error.
        spec_set = SpecifierSet(specifier)
        if not any(map(spec_set.contains, versions)):
            raise ValueError(
                f"Found unused requirements {specifier!r} for {name} / {category}. "
                "Please remove it or adjust the version specifier."
            )


def expand_config(config: dict[str, FlavorConfig], *, is_ref: bool = False) -> set[MatrixItem]:
    """Expand a parsed versions config into the full set of matrix items.

    Args:
        config: Mapping of flavor name to its FlavorConfig.
        is_ref: True when expanding the reference config for diffing; skips
            requirement validation.
    """
    matrix = set()
    for name, flavor_config in config.items():
        flavor = get_flavor(name)
        package_info = flavor_config.package_info
        all_versions = get_released_versions(package_info.pip_release)
        # These packages pull in very large wheels; their jobs free disk first.
        free_disk_space = package_info.pip_release in (
            "transformers",
            "sentence-transformers",
            "torch",
        )
        validate_test_coverage(name, flavor_config)
        for category, cfg in flavor_config.categories:
            versions = filter_versions(
                flavor,
                all_versions,
                cfg.minimum,
                cfg.maximum,
                cfg.unsupported or [],
                allow_unreleased_max_version=cfg.allow_unreleased_max_version or False,
            )
            versions = get_latest_micro_versions(versions)

            # Test every n minor versions if specified
            if cfg.test_every_n_versions > 1:
                versions = sorted(versions)[:: -cfg.test_every_n_versions][::-1]

            # Always test the minimum version
            if cfg.minimum not in versions and cfg.minimum in all_versions:
                versions.append(cfg.minimum)

            if not is_ref and cfg.requirements:
                validate_requirements(cfg.requirements, name, category, package_info, versions)

            run = remove_comments(cfg.run)
            for ver in versions:
                requirements = [f"{package_info.pip_release}=={ver}"]
                requirements.extend(get_matched_requirements(cfg.requirements or {}, str(ver)))
                install = make_pip_install_command(requirements)
                python = get_python_version(
                    cfg.python, package_info.pip_release, str(ver), package_info.repo
                )
                runs_on = get_runs_on(cfg.runs_on, str(ver))
                java = get_java_version(cfg.java, str(ver))

                matrix.add(
                    MatrixItem(
                        name=name,
                        flavor=flavor,
                        category=category,
                        job_name=f"{name} / {category} / {ver}",
                        install=install,
                        run=run,
                        package=package_info.pip_release,
                        version=ver,
                        python=python,
                        java=java,
                        supported=ver <= cfg.maximum,
                        free_disk_space=free_disk_space,
                        runs_on=runs_on,
                        pre_test=cfg.pre_test,
                    )
                )

            # Add tracing SDK test with the latest stable version
            if len(versions) > 0 and category == "autologging" and cfg.test_tracing_sdk:
                version = max(versions)  # Test against the latest stable version
                # Fix: recompute per-version values for `version` instead of
                # reusing the variables leaked from the loop above, whose last
                # iteration is the *minimum* version (it is appended last) —
                # the tracing job previously installed the minimum version
                # while being labeled with the latest one.
                requirements = [f"{package_info.pip_release}=={version}"]
                requirements.extend(
                    get_matched_requirements(cfg.requirements or {}, str(version))
                )
                install = make_pip_install_command(requirements)
                python = get_python_version(
                    cfg.python, package_info.pip_release, str(version), package_info.repo
                )
                runs_on = get_runs_on(cfg.runs_on, str(version))
                java = get_java_version(cfg.java, str(version))
                matrix.add(
                    MatrixItem(
                        name=f"{name}-tracing",
                        flavor=flavor,
                        category="tracing-sdk",
                        job_name=f"{name} / tracing-sdk / {version}",
                        install=install,
                        # --import-mode=importlib is required for testing tracing SDK
                        # (mlflow-tracing) works properly, without being affected by environment.
                        run=run.replace("pytest", "pytest --import-mode=importlib"),
                        package=package_info.pip_release,
                        version=version,
                        java=java,
                        supported=version <= cfg.maximum,
                        free_disk_space=free_disk_space,
                        python=python,
                        runs_on=runs_on,
                    )
                )

            # Skip dev version testing: install_dev installs from git, which
            # doesn't respect UV_EXCLUDE_NEWER.
            if False:  # package_info.install_dev:
                install_dev = remove_comments(package_info.install_dev)
                if requirements := get_matched_requirements(cfg.requirements or {}, DEV_VERSION):
                    install = make_pip_install_command(requirements) + "\n" + install_dev
                else:
                    install = install_dev
                python = get_python_version(
                    cfg.python, package_info.pip_release, DEV_VERSION, package_info.repo
                )
                runs_on = get_runs_on(cfg.runs_on, DEV_VERSION)
                java = get_java_version(cfg.java, DEV_VERSION)

                run = remove_comments(cfg.run)
                dev_version = Version.create_dev()
                matrix.add(
                    MatrixItem(
                        name=name,
                        flavor=flavor,
                        category=category,
                        job_name=f"{name} / {category} / {dev_version}",
                        install=install,
                        run=run,
                        package=package_info.pip_release,
                        version=dev_version,
                        python=python,
                        java=java,
                        supported=False,
                        free_disk_space=free_disk_space,
                        runs_on=runs_on,
                        pre_test=cfg.pre_test,
                    )
                )
    return matrix


def apply_changed_files(changed_files, matrix):
    """Keep only the matrix items whose flavor is affected by `changed_files`."""
    all_flavors = {x.flavor for x in matrix}
    changed_flavors = (
        # If this file has been changed, re-run all tests
        all_flavors
        if str(Path(__file__).relative_to(Path.cwd())) in changed_files
        else get_changed_flavors(changed_files, all_flavors)
    )

    # Run langchain tests if any tracing files have been changed
    if any(f.startswith("mlflow/tracing/") for f in changed_files):
        changed_flavors.add("langchain")

    return set(filter(lambda x: x.flavor in changed_flavors, matrix))


def generate_matrix(args):
    """Build the final matrix from CLI args: expand the config, restrict it to
    config-diff / changed-file selections, then apply the filtering flags."""
    args = parse_args(args)
    config = read_yaml(args.versions_yaml)
    if (args.ref_versions_yaml, args.changed_files).count(None) == 2:
        # Neither selector given: test everything.
        matrix = expand_config(config)
    else:
        matrix = set()
        mat = expand_config(config)

        if args.ref_versions_yaml:
            ref_config = read_yaml(args.ref_versions_yaml, if_error={})
            ref_matrix = expand_config(ref_config, is_ref=True)
            # Only items that differ from the reference config.
            matrix.update(mat.difference(ref_matrix))

        if args.changed_files:
            matrix.update(apply_changed_files(args.changed_files, mat))

    # Apply the filtering arguments
    if args.no_dev:
        matrix = filter(lambda x: x.version != Version.create_dev(), matrix)

    if args.flavors:
        matrix = filter(lambda x: x.flavor in args.flavors, matrix)

    if args.versions:
        matrix = filter(lambda x: x.version in map(Version, args.versions), matrix)

    if args.only_latest:
        groups = defaultdict(list)
        for item in matrix:
            groups[(item.name, item.category)].append(item)
        matrix = {max(group, key=lambda x: x.version) for group in groups.values()}

    return set(matrix)


class CustomEncoder(json.JSONEncoder):
    """JSON encoder that understands MatrixItem and Version."""

    def default(self, o):
        if isinstance(o, MatrixItem):
            return o.model_dump(exclude_none=True)
        elif isinstance(o, Version):
            return str(o)
        return super().default(o)


def set_action_output(name, value):
    """Append a `name=value` output for the current GitHub Actions step."""
    with open(os.environ.get("GITHUB_OUTPUT"), "a") as f:
        f.write(f"{name}={value}\n")


def split(matrix, n):
    """Yield `n`-ish chunks of `matrix`, never splitting items that share a name."""
    grouped_by_name = defaultdict(list)
    for item in matrix:
        grouped_by_name[item.name].append(item)

    num = len(matrix) // n
    chunk = []
    for group in grouped_by_name.values():
        chunk.extend(group)
        if len(chunk) >= num:
            yield chunk
            chunk = []

    if chunk:
        yield chunk


def main(args):
    # https://docs.github.com/en/actions/learn-github-actions/usage-limits-billing-and-administration#usage-limits
    # > A job matrix can generate a maximum of 256 jobs per workflow run.
    MAX_ITEMS = 256
    NUM_JOBS = 2

    print(divider("Parameters"))
    print(json.dumps(args, indent=2))
    matrix = generate_matrix(args)
    matrix = sorted(matrix, key=lambda x: (x.name, x.category, x.version))
    # Fix: the condition previously hard-coded `MAX_ITEMS * 2` while the message
    # used `MAX_ITEMS * NUM_JOBS`; use NUM_JOBS in both (same value today).
    assert len(matrix) <= MAX_ITEMS * NUM_JOBS, (
        f"Too many jobs: {len(matrix)} > {MAX_ITEMS * NUM_JOBS}"
    )
    for idx, mat in enumerate(split(matrix, NUM_JOBS), start=1):
        mat = {"include": mat, "job_name": [x.job_name for x in mat]}
        print(divider(f"Matrix {idx}"))
        print(json.dumps(mat, indent=2, cls=CustomEncoder))
        if "GITHUB_ACTIONS" in os.environ:
            set_action_output(f"matrix{idx}", json.dumps(mat, cls=CustomEncoder))
            set_action_output(f"is_matrix{idx}_empty", "true" if len(mat) == 0 else "false")


if __name__ == "__main__":
    main(sys.argv[1:])