pyproject.py
from __future__ import annotations

import re
import subprocess
import sys
from collections import Counter
from enum import Enum
from pathlib import Path
from typing import Any, cast

import toml
import yaml
from packaging.version import Version
from pydantic import BaseModel, Field, RootModel


class PackageType(Enum):
    SKINNY = "skinny"
    RELEASE = "release"
    DEV = "dev"
    TRACING = "tracing"

    def description(self) -> str:
        WARNING = "# Auto-generated by dev/pyproject.py. Do not edit manually."

        if self is PackageType.TRACING:
            return f"""{WARNING}
# This file defines the package metadata of `mlflow-tracing`.
"""

        if self is PackageType.SKINNY:
            return f"""{WARNING}
# This file defines the package metadata of `mlflow-skinny`.
"""
        if self is PackageType.RELEASE:
            return f"""{WARNING}
# This file defines the package metadata of `mlflow`. `mlflow-skinny` and `mlflow-tracing`
# are included in the requirements to prevent a version mismatch between `mlflow` and those
# child packages. This file will replace `pyproject.toml` when releasing a new version.
"""
        if self is PackageType.DEV:
            return f"""{WARNING}
# This file defines the package metadata of `mlflow` **during development**. To install `mlflow`
# from the source code, `mlflow-skinny` and `mlflow-tracing` are NOT included in the requirements.
# This file will be replaced by `pyproject.release.toml` when releasing a new version.
"""
        raise ValueError(f"Unreachable: {self}")


SEPARATOR = """
# Package metadata: can't be updated manually, use dev/pyproject.py
# -----------------------------------------------------------------
# Dev tool settings: can be updated manually

"""
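
# A minimal sketch of how SEPARATOR is used (illustrative layout, not the real
# file contents): the checked-in pyproject.toml is laid out as
#
#   <generated package metadata>   <- rewritten by this script
#   <SEPARATOR>
#   <dev tool settings>            <- preserved verbatim
#
# build() for PackageType.DEV splits the existing file on SEPARATOR, keeps
# everything after it, and regenerates everything before it.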

SKINNY_README = """
<!-- Autogenerated by dev/pyproject.py. Do not edit manually. -->

📣 This is the `mlflow-skinny` package, a lightweight MLflow package without SQL storage, server, UI, or data science dependencies.
Additional dependencies can be installed to leverage the full feature set of MLflow. For example:

- To use the `mlflow.sklearn` component of MLflow Models, install `scikit-learn`, `numpy` and `pandas`.
- To use SQL-based metadata storage, install `sqlalchemy`, `alembic`, and `sqlparse`.
- To use serving-based features, install `flask` and `pandas`.

**Note:** When using `mlflow-skinny`, set the tracking URI to your remote MLflow server:

```bash
export MLFLOW_TRACKING_URI="http://your-mlflow-server:5000"
```

---

<br>
<br>

"""  # noqa: E501

# The tracing SDK should include only the minimum set of MLflow modules
# to minimize the size of the package.
TRACING_INCLUDE_FILES = [
    "mlflow",
    # Flavors for which auto tracing is supported
    "mlflow.agno*",
    "mlflow.anthropic*",
    "mlflow.autogen*",
    "mlflow.bedrock*",
    "mlflow.crewai*",
    "mlflow.dspy*",
    "mlflow.gemini*",
    "mlflow.groq*",
    "mlflow.langchain*",
    "mlflow.litellm*",
    "mlflow.llama_index*",
    "mlflow.mistral*",
    "mlflow.openai*",
    "mlflow.strands*",
    "mlflow.haystack*",
    # Other necessary modules
    "mlflow.azure*",
    "mlflow.entities*",
    "mlflow.environment_variables",
    "mlflow.exceptions",
    "mlflow.legacy_databricks_cli*",
    "mlflow.prompt*",
    "mlflow.protos*",
    "mlflow.pydantic_ai*",
    "mlflow.smolagents*",
    "mlflow.store*",
    "mlflow.telemetry*",
    "mlflow.tracing*",
    "mlflow.tracking*",
    "mlflow.types*",
    "mlflow.utils*",
    "mlflow.version",
]
TRACING_EXCLUDE_FILES = [
    # Large proto files that are not needed in the package
    "mlflow/protos/databricks_artifacts_pb2.py",
    "mlflow/protos/databricks_filesystem_service_pb2.py",
    "mlflow/protos/databricks_uc_registry_messages_pb2.py",
    "mlflow/protos/databricks_uc_registry_service_pb2.py",
    "mlflow/protos/model_registry_pb2.py",
    "mlflow/protos/unity_catalog_oss_messages_pb2.py",
    "mlflow/protos/unity_catalog_oss_service_pb2.py",
    # Test files
    "tests",
    "tests.*",
]
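
# For orientation, an abbreviated (and illustrative, not exact) excerpt of
# what the two lists above become in the generated pyproject.toml for
# PackageType.TRACING:
#
#   [tool.setuptools.packages.find]
#   where = ["."]
#   include = ["mlflow", "mlflow.agno*", ...]
#   exclude = ["mlflow/protos/databricks_artifacts_pb2.py", ..., "tests", "tests.*"]
#   namespaces = false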


def find_duplicates(seq: list[str]) -> list[str]:
    counted = Counter(seq)
    return [item for item, count in counted.items() if count > 1]


def write_file_if_changed(file_path: Path, new_content: str) -> None:
    if file_path.exists():
        existing_content = file_path.read_text()
        if existing_content == new_content:
            print(f"No changes in {file_path}, skipping write.")
            return

    print(f"Writing changes to {file_path}.")
    file_path.write_text(new_content)


def format_content_with_taplo(content: str) -> str:
    return (
        subprocess.check_output(
            ["bin/taplo", "fmt", "-"],
            input=content,
            text=True,
        ).strip()
        + "\n"
    )


def write_toml_file_if_changed(
    file_path: Path, description: str, toml_data: dict[str, Any]
) -> None:
    """
    Write a TOML file, prefixed with ``description``, only if its content has changed.
    The content is formatted with taplo before comparison.
    """
    new_content = description + "\n" + toml.dumps(toml_data)
    formatted_content = format_content_with_taplo(new_content)
    write_file_if_changed(file_path, formatted_content)


class PackageRequirement(BaseModel):
    pip_release: str = Field(..., description="The pip package name")
    max_major_version: int = Field(..., description="Maximum major version allowed")
    minimum: str | None = Field(None, description="Minimum version required")
    unsupported: list[str] | None = Field(None, description="List of unsupported versions")
    markers: str | None = Field(
        None, description="Environment markers for conditional installation"
    )
    extras: list[str] | None = Field(None, description="Package extras to install")
    freeze: bool | None = Field(None, description="Whether to freeze this package version")


RequirementsYaml = RootModel[dict[str, PackageRequirement]]


def generate_requirements_from_yaml(requirements_yaml: RequirementsYaml) -> list[str]:
    """Generate pip requirement strings from a validated YAML specification."""
    requirement_strs: list[str] = []
    for package_entry in requirements_yaml.root.values():
        pip_release = package_entry.pip_release
        version_specs: list[str] = []

        extras = f"[{','.join(package_entry.extras)}]" if package_entry.extras else ""

        # Cap the version below the next major release.
        max_major_version = package_entry.max_major_version
        version_specs.append(f"<{max_major_version + 1}")

        if package_entry.minimum:
            version_specs.append(f">={package_entry.minimum}")

        if package_entry.unsupported:
            version_specs.extend(f"!={version}" for version in package_entry.unsupported)

        markers = f"; {package_entry.markers}" if package_entry.markers else ""

        requirement_str = f"{pip_release}{extras}{','.join(version_specs)}{markers}"
        requirement_strs.append(requirement_str)

    requirement_strs.sort()
    return requirement_strs
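
# A quick doctest-style illustration of the mapping (the entry here is
# hypothetical; real entries live in the requirements/*.yaml files):
#
#   >>> spec = RequirementsYaml(
#   ...     {
#   ...         "sqlalchemy": PackageRequirement(
#   ...             pip_release="sqlalchemy", max_major_version=2, minimum="1.4.0"
#   ...         )
#   ...     }
#   ... )
#   >>> generate_requirements_from_yaml(spec)
#   ['sqlalchemy<3,>=1.4.0']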


def read_requirements_yaml(yaml_path: Path) -> list[str]:
    """Read and parse a YAML requirements file into pip requirement strings."""
    with yaml_path.open() as f:
        requirements_data = yaml.safe_load(f)

    return generate_requirements_from_yaml(RequirementsYaml(requirements_data))


def read_package_versions_yml() -> dict[str, Any]:
    with open("mlflow/ml-package-versions.yml") as f:
        return cast(dict[str, Any], yaml.safe_load(f))


def build(package_type: PackageType) -> None:
    requirements_dir = Path("requirements")
    tracing_requirements = read_requirements_yaml(requirements_dir / "tracing-requirements.yaml")
    skinny_requirements = read_requirements_yaml(requirements_dir / "skinny-requirements.yaml")
    _check_skinny_tracing_mismatch(
        skinny_reqs=skinny_requirements, tracing_reqs=tracing_requirements
    )
    core_requirements = read_requirements_yaml(requirements_dir / "core-requirements.yaml")
    gateways_requirements = read_requirements_yaml(requirements_dir / "gateway-requirements.yaml")
    genai_requirements = read_requirements_yaml(requirements_dir / "genai-requirements.yaml")
    version_match = re.search(
        r'^VERSION = "([a-z0-9\.]+)"$', Path("mlflow", "version.py").read_text(), re.MULTILINE
    )
    if version_match is None:
        raise ValueError(
            'Could not find VERSION in mlflow/version.py. Expected format: VERSION = "x.y.z"'
        )
    package_version = version_match.group(1)
    python_version = Path(".python-version").read_text().strip()
    versions_yaml = read_package_versions_yml()
    # Constrain langchain to the version range supported by both autologging
    # and model logging.
    langchain_requirements = [
        "langchain>={},<={}".format(
            max(
                Version(versions_yaml["langchain"]["autologging"]["minimum"]),
                Version(versions_yaml["langchain"]["models"]["minimum"]),
            ),
            min(
                Version(versions_yaml["langchain"]["autologging"]["maximum"]),
                Version(versions_yaml["langchain"]["models"]["maximum"]),
            ),
        )
    ]
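
    # Worked example (hypothetical YAML values): if ml-package-versions.yml
    # pinned langchain autologging to minimum "0.1.0" / maximum "0.3.0" and
    # langchain models to minimum "0.2.0" / maximum "0.4.0", the expression
    # above would yield ["langchain>=0.2.0,<=0.3.0"].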

    match package_type:
        case PackageType.TRACING:
            dependencies = sorted(tracing_requirements)
        case PackageType.SKINNY:
            dependencies = sorted(skinny_requirements)
        case PackageType.RELEASE:
            dependencies = [
                f"mlflow-skinny=={package_version}",
                f"mlflow-tracing=={package_version}",
            ] + sorted(core_requirements)
        case PackageType.DEV:
            # skinny_requirements is a superset of tracing_requirements
            # (validated above), so we don't need to include both below.
            dependencies = sorted(core_requirements + skinny_requirements)
        case _:
            raise ValueError(f"Unreachable: {package_type}")

    if dep_duplicates := find_duplicates(dependencies):
        raise RuntimeError(f"Duplicate dependencies found: {dep_duplicates}")

    match package_type:
        case PackageType.TRACING:
            package_name = "mlflow-tracing"
        case PackageType.SKINNY:
            package_name = "mlflow-skinny"
        case _:
            package_name = "mlflow"

    description = (
        "MLflow is an open source platform for the complete machine learning lifecycle"
        if package_type != PackageType.TRACING
        else (
            "MLflow Tracing SDK is an open-source, lightweight Python package that only "
            "includes the minimum set of dependencies and functionality to instrument "
            "your code/models/agents with MLflow Tracing."
        )
    )

    data = {
        "build-system": {
            "requires": ["setuptools<=82.0.1"],
            "build-backend": "setuptools.build_meta",
        },
        "project": {
            "name": package_name,
            "version": package_version,
            "maintainers": [
                {"name": "Databricks", "email": "mlflow-oss-maintainers@googlegroups.com"}
            ],
            "description": description,
            "readme": "README_SKINNY.md" if package_type == PackageType.SKINNY else "README.md",
            "license": {
                "file": "LICENSE.txt",
            },
            "keywords": ["mlflow", "ai", "databricks"],
            "classifiers": [
                "Development Status :: 5 - Production/Stable",
                "Intended Audience :: Developers",
                "Intended Audience :: End Users/Desktop",
                "Intended Audience :: Science/Research",
                "Intended Audience :: Information Technology",
                "Topic :: Scientific/Engineering :: Artificial Intelligence",
                "Topic :: Software Development :: Libraries :: Python Modules",
                "License :: OSI Approved :: Apache Software License",
                "Operating System :: OS Independent",
                f"Programming Language :: Python :: {python_version}",
            ],
            "requires-python": f">={python_version}",
            "dependencies": dependencies,
            "optional-dependencies": {
                "extras": [
                    # Required to log artifacts and models to HDFS artifact locations
                    "pyarrow",
                    # Required to sign outgoing requests with a SigV4 signature
                    "requests-auth-aws-sigv4",
                    # Required to log artifacts and models to AWS S3 artifact locations
                    "boto3",
                    "botocore",
                    # Required to log artifacts and models to GCS artifact locations
                    "google-cloud-storage>=1.30.0",
                    "azureml-core>=1.2.0",
                    # Required to log artifacts to SFTP artifact locations
                    "pysftp",
                    # Required by the mlflow.projects module, when running projects against
                    # a remote Kubernetes cluster
                    "kubernetes",
                    # Required for exporting metrics from the MLflow server to Prometheus
                    # as part of the MLflow server monitoring add-on
                    "prometheus-flask-exporter",
                ],
                "db": [
                    # Required to use MySQL, PostgreSQL, or SQL Server as the backend store
                    "PyMySQL",
                    "psycopg2-binary",
                    "pymssql",
                ],
                "databricks": [
                    # Required to write model artifacts to Unity Catalog locations
                    "azure-storage-file-datalake>12",
                    "google-cloud-storage>=1.30.0",
                    "boto3>1",
                    "botocore",
                    "databricks-agents>=1.2.0,<2.0",
                ],
                "mlserver": [
                    # Required to serve models through MLServer
                    "mlserver>=1.2.0,!=1.3.1,<2.0.0",
                    "mlserver-mlflow>=1.2.0,!=1.3.1,<2.0.0",
                ],
                "gateway": gateways_requirements,
                "genai": genai_requirements,
                # click 8.3.0 causes the MLflow MCP server to fail:
                # https://github.com/mlflow/mlflow/issues/18747
                "mcp": ["fastmcp<4,>=2.0.0", "click!=8.3.0"],
                "azure": [
                    # Required to log artifacts and models to Azure Blob Storage
                    "azure-storage-blob>=12",
                    "azure-identity>=1.6.1",
                ],
                "sqlserver": ["mlflow-dbstore"],
                "aliyun-oss": ["aliyunstoreplugin"],
                "jfrog": ["mlflow-jfrog-plugin"],
                "kubernetes": ["kubernetes"],
                "langchain": langchain_requirements,
                "auth": ["Flask-WTF<2"],
            }
            # The tracing SDK does not support extras
            if package_type != PackageType.TRACING
            else None,
            "urls": {
                "homepage": "https://mlflow.org",
                "issues": "https://github.com/mlflow/mlflow/issues",
                "documentation": "https://mlflow.org/docs/latest",
                "repository": "https://github.com/mlflow/mlflow",
            },
            "scripts": {
                "mlflow": "mlflow.cli:cli",
            }
            if package_type != PackageType.TRACING
            else None,
            "entry-points": {
                "mlflow.app": {
                    "basic-auth": "mlflow.server.auth:create_app",
                },
                "mlflow.app.client": {
                    "basic-auth": "mlflow.server.auth.client:AuthServiceClient",
                },
                "mlflow.deployments": {
                    "databricks": "mlflow.deployments.databricks",
                    "http": "mlflow.deployments.mlflow",
                    "https": "mlflow.deployments.mlflow",
                    "openai": "mlflow.deployments.openai",
                },
            }
            if package_type != PackageType.TRACING
            else None,
        },
        "tool": {
            "setuptools": {
                "packages": {
                    "find": {
                        "where": ["."],
                        "include": ["mlflow", "mlflow.*"]
                        if package_type != PackageType.TRACING
                        else TRACING_INCLUDE_FILES,
                        "exclude": ["tests", "tests.*"]
                        if package_type != PackageType.TRACING
                        else TRACING_EXCLUDE_FILES,
                        "namespaces": False,
                    }
                },
                "package-data": _get_package_data(package_type),
            }
        },
    }
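
    # For orientation, an abbreviated excerpt of the [project] table this dict
    # produces for a hypothetical 3.2.0 release build:
    #
    #   [project]
    #   name = "mlflow"
    #   version = "3.2.0"
    #   dependencies = ["mlflow-skinny==3.2.0", "mlflow-tracing==3.2.0", ...]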

    if package_type == PackageType.TRACING:
        out_path = Path("libs/tracing/pyproject.toml")
        write_toml_file_if_changed(out_path, package_type.description(), data)
    elif package_type == PackageType.SKINNY:
        out_path = Path("libs/skinny/pyproject.toml")
        write_toml_file_if_changed(out_path, package_type.description(), data)

        skinny_readme_path = Path("libs/skinny/README_SKINNY.md")
        new_readme_content = SKINNY_README.lstrip() + Path("README.md").read_text()
        write_file_if_changed(skinny_readme_path, new_readme_content)

        for f in ["LICENSE.txt", "MANIFEST.in", "mlflow"]:
            symlink = Path("libs/skinny", f)
            if symlink.exists():
                symlink.unlink()
            target = Path("../..", f)
            symlink.symlink_to(target, target_is_directory=target.is_dir())
    elif package_type == PackageType.RELEASE:
        out_path = Path(f"pyproject.{package_type.value}.toml")
        write_toml_file_if_changed(out_path, package_type.description(), data)
    else:
        out_path = Path("pyproject.toml")
        # Preserve the manually maintained dev tool settings below SEPARATOR.
        original_manual_content = out_path.read_text().split(SEPARATOR)[1]
        generated_part = package_type.description() + "\n" + toml.dumps(data)
        formatted_generated_part = format_content_with_taplo(generated_part)
        formatted_full_content = formatted_generated_part + SEPARATOR + original_manual_content

        write_file_if_changed(out_path, formatted_full_content)
        subprocess.check_call(["uv", "lock"])


def _get_package_data(package_type: PackageType) -> dict[str, list[str]] | None:
    if package_type == PackageType.TRACING:
        return None

    package_data = {
        "mlflow": [
            "store/db_migrations/alembic.ini",
            "temporary_db_migrations_for_pre_1_users/alembic.ini",
            "pyspark/ml/log_model_allowlist.txt",
            "server/auth/basic_auth.ini",
            "server/auth/db/migrations/alembic.ini",
            "server/uvicorn_log_config.yaml",
            "models/notebook_resources/**/*",
            "ai_commands/**/*.md",
            "assistant/skills/**/*",
        ]
    }

    if package_type != PackageType.SKINNY:
        package_data["mlflow"] += [
            "models/container/**/*",
            "server/js/build/**/*",
            "utils/model_catalog/*.json",
        ]

    return package_data
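
# Doctest-style sanity examples, derived directly from the logic above:
#
#   >>> _get_package_data(PackageType.TRACING) is None
#   True
#   >>> "server/js/build/**/*" in _get_package_data(PackageType.RELEASE)["mlflow"]
#   True
#   >>> "server/js/build/**/*" in _get_package_data(PackageType.SKINNY)["mlflow"]
#   False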


def _check_skinny_tracing_mismatch(*, skinny_reqs: list[str], tracing_reqs: list[str]) -> None:
    """
    Check that the tracing requirements are a subset of the skinny requirements.

    NB: We don't make mlflow-tracing a hard dependency of mlflow-skinny because
    doing so would complicate package management (it would require another
    .release.toml file that pyproject.release.toml depends on).
    """
    if diff := set(tracing_reqs) - set(skinny_reqs):
        raise RuntimeError(
            "Tracing requirements must be a subset of skinny requirements. "
            "Please check the requirements/skinny-requirements.yaml and "
            "requirements/tracing-requirements.yaml files.\n"
            f"Diff: {diff}"
        )


def main() -> None:
    if not Path("bin/taplo").exists():
        print(
            "taplo is required to generate pyproject.toml. "
            "Please run 'python bin/install.py' to install it."
        )
        sys.exit(1)

    for package_type in PackageType:
        build(package_type)


if __name__ == "__main__":
    main()
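
# Typical invocation (assuming bin/taplo has been installed via
# `python bin/install.py`):
#
#   python dev/pyproject.py
#
# This regenerates pyproject.toml (dev) and pyproject.release.toml at the repo
# root, writes libs/skinny/pyproject.toml and libs/tracing/pyproject.toml, and
# refreshes uv.lock.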