/ dev / pyproject.py
pyproject.py
  1  from __future__ import annotations
  2  
  3  import re
  4  import subprocess
  5  import sys
  6  from collections import Counter
  7  from enum import Enum
  8  from pathlib import Path
  9  from typing import Any, cast
 10  
 11  import toml
 12  import yaml
 13  from packaging.version import Version
 14  from pydantic import BaseModel, Field, RootModel
 15  
 16  
 17  class PackageType(Enum):
 18      SKINNY = "skinny"
 19      RELEASE = "release"
 20      DEV = "dev"
 21      TRACING = "tracing"
 22  
 23      def description(self) -> str:
 24          WARNING = "# Auto-generated by dev/pyproject.py. Do not edit manually."
 25  
 26          if self is PackageType.TRACING:
 27              return f"""{WARNING}
 28  # This file defines the package metadata of `mlflow-tracing`.
 29  """
 30  
 31          if self is PackageType.SKINNY:
 32              return f"""{WARNING}
 33  # This file defines the package metadata of `mlflow-skinny`.
 34  """
 35          if self is PackageType.RELEASE:
 36              return f"""{WARNING}
 37  # This file defines the package metadata of `mlflow`. `mlflow-skinny` and `mlflow-tracing`
 38  # are included in the requirements to prevent a version mismatch between `mlflow` and those
 39  # child packages. This file will replace `pyproject.toml` when releasing a new version.
 40  """
 41          if self is PackageType.DEV:
 42              return f"""{WARNING}
 43  # This file defines the package metadata of `mlflow` **during development**. To install `mlflow`
 44  # from the source code, `mlflow-skinny` and `mlflow-tracing` are NOT included in the requirements.
 45  # This file will be replaced by `pyproject.release.toml` when releasing a new version.
 46  """
 47          raise ValueError(f"Unreachable: {self}")
 48  
 49  
# Split marker inside the top-level pyproject.toml. For PackageType.DEV, `build()` regenerates
# the text before this marker and re-appends the marker plus the hand-maintained text that
# follows it, so this string must match the file byte-for-byte.
SEPARATOR = """
# Package metadata: can't be updated manually, use dev/pyproject.py
# -----------------------------------------------------------------
# Dev tool settings: can be updated manually

"""
 56  
# Banner for the `mlflow-skinny` package page. `build()` strips its leading whitespace,
# prepends it to the repository's README.md and writes the result to
# libs/skinny/README_SKINNY.md, which is the `readme` of the skinny package.
SKINNY_README = """
<!--  Autogenerated by dev/pyproject.py. Do not edit manually.  -->

📣 This is the `mlflow-skinny` package, a lightweight MLflow package without SQL storage, server, UI, or data science dependencies.
Additional dependencies can be installed to leverage the full feature set of MLflow. For example:

- To use the `mlflow.sklearn` component of MLflow Models, install `scikit-learn`, `numpy` and `pandas`.
- To use SQL-based metadata storage, install `sqlalchemy`, `alembic`, and `sqlparse`.
- To use serving-based features, install `flask` and `pandas`.

**Note:** When using `mlflow-skinny`, set the tracking URI to your remote MLflow server:

```bash
export MLFLOW_TRACKING_URI="http://your-mlflow-server:5000"
```

---

<br>
<br>

"""  # noqa: E501
 79  
# Tracing SDK should only include the minimum set of MLflow modules
# to minimize the size of the package.
# Used only for `mlflow-tracing`, as the `tool.setuptools.packages.find.include` patterns
# (dotted names; `*` is a wildcard). Every other package includes ["mlflow", "mlflow.*"].
TRACING_INCLUDE_FILES = [
    "mlflow",
    # Flavors for which we support auto tracing
    "mlflow.agno*",
    "mlflow.anthropic*",
    "mlflow.autogen*",
    "mlflow.bedrock*",
    "mlflow.crewai*",
    "mlflow.dspy*",
    "mlflow.gemini*",
    "mlflow.groq*",
    "mlflow.langchain*",
    "mlflow.litellm*",
    "mlflow.llama_index*",
    "mlflow.mistral*",
    "mlflow.openai*",
    "mlflow.strands*",
    "mlflow.haystack*",
    # Other necessary modules
    "mlflow.azure*",
    "mlflow.entities*",
    "mlflow.environment_variables",
    "mlflow.exceptions",
    "mlflow.legacy_databricks_cli*",
    "mlflow.prompt*",
    "mlflow.protos*",
    "mlflow.pydantic_ai*",
    "mlflow.smolagents*",
    "mlflow.store*",
    "mlflow.telemetry*",
    "mlflow.tracing*",
    "mlflow.tracking*",
    "mlflow.types*",
    "mlflow.utils*",
    "mlflow.version",
]
# Used only for `mlflow-tracing`, as the `tool.setuptools.packages.find.exclude` patterns.
# Every other package excludes only ["tests", "tests.*"].
# NOTE(review): setuptools package discovery matches these patterns against dotted package
# names, so the slash-separated module file paths below may not exclude anything on their
# own. Confirm that the built wheel really omits these *_pb2.py files.
TRACING_EXCLUDE_FILES = [
    # Large proto files that are not needed in the package
    "mlflow/protos/databricks_artifacts_pb2.py",
    "mlflow/protos/databricks_filesystem_service_pb2.py",
    "mlflow/protos/databricks_uc_registry_messages_pb2.py",
    "mlflow/protos/databricks_uc_registry_service_pb2.py",
    "mlflow/protos/model_registry_pb2.py",
    "mlflow/protos/unity_catalog_oss_messages_pb2.py",
    "mlflow/protos/unity_catalog_oss_service_pb2.py",
    # Test files
    "tests",
    "tests.*",
]
131  
132  
def find_duplicates(seq: list[str]) -> list[str]:
    """
    Return the items that occur more than once in ``seq``.

    Each duplicated item is reported once, in the order of its first occurrence.
    """
    tally = Counter(seq)
    return [item for item in tally if tally[item] >= 2]
136  
137  
def write_file_if_changed(file_path: Path, new_content: str) -> None:
    """
    Write ``new_content`` to ``file_path`` unless the file already holds exactly that text.

    Both the comparison read and the write use UTF-8 explicitly. Generated content can
    contain non-ASCII characters (e.g. the emoji in ``SKINNY_README``), which the
    platform's default encoding (e.g. cp1252 on Windows) may be unable to encode or
    decode correctly.
    """
    if file_path.exists():
        existing_content = file_path.read_text(encoding="utf-8")
        if existing_content == new_content:
            print(f"No changes in {file_path}, skipping write.")
            return

    print(f"Writing changes to {file_path}.")
    file_path.write_text(new_content, encoding="utf-8")
147  
148  
def format_content_with_taplo(content: str) -> str:
    """
    Format TOML ``content`` with the repository-local ``bin/taplo`` binary.

    The text is passed to ``taplo fmt -`` on stdin and the formatted result is read from
    stdout. Surrounding whitespace is stripped and a single trailing newline is appended.
    The pipe is encoded and decoded as UTF-8 explicitly instead of with the locale's
    preferred encoding (``text=True``), which is not UTF-8 on every platform.

    Raises:
        subprocess.CalledProcessError: If taplo exits with a non-zero status.
    """
    formatted = subprocess.check_output(
        ["bin/taplo", "fmt", "-"],
        input=content,
        encoding="utf-8",
    )
    return formatted.strip() + "\n"
158  
159  
def write_toml_file_if_changed(
    file_path: Path, description: str, toml_data: dict[str, Any]
) -> None:
    """
    Serialize ``toml_data`` below the ``description`` comment header and save it.

    The combined text is run through taplo first, so the comparison with the existing
    file (and the decision to skip the write) is made on the formatted output.
    """
    unformatted = f"{description}\n{toml.dumps(toml_data)}"
    write_file_if_changed(file_path, format_content_with_taplo(unformatted))
170  
171  
# One entry of a requirements YAML file, as validated by `read_requirements_yaml`.
# `pip_release` and `max_major_version` are mandatory; all other fields default to None.
# (Comments rather than a class docstring, because pydantic would publish a docstring as
# the model's schema description.)
class PackageRequirement(BaseModel):
    # Distribution name that starts the rendered requirement string.
    pip_release: str = Field(description="The pip package name")
    # Rendered as the exclusive upper bound `<max_major_version + 1`.
    max_major_version: int = Field(description="Maximum major version allowed")
    # Rendered as `>=minimum` when set.
    minimum: str | None = Field(default=None, description="Minimum version required")
    # Each version is rendered as `!=version`.
    unsupported: list[str] | None = Field(
        default=None, description="List of unsupported versions"
    )
    # Appended after "; " as an environment marker.
    markers: str | None = Field(
        default=None, description="Environment markers for conditional installation"
    )
    # Rendered as `[extra1,extra2]` right after the package name.
    extras: list[str] | None = Field(default=None, description="Package extras to install")
    # Not read anywhere in this script.
    freeze: bool | None = Field(
        default=None, description="Whether to freeze this package version"
    )


# A whole requirements YAML file: a mapping from entry key to its requirement spec.
RequirementsYaml = RootModel[dict[str, PackageRequirement]]
185  
186  
def generate_requirements_from_yaml(requirements_yaml: RequirementsYaml) -> list[str]:
    """
    Render every validated YAML entry as a PEP 508 requirement string.

    Each entry becomes ``<pip_release>[<extras>]<specifiers>; <markers>``. The specifiers
    are, in order:

    - an exclusive upper bound ``<max_major_version + 1``;
    - an optional ``>=minimum``;
    - one ``!=version`` per unsupported version.

    The extras and markers parts are omitted when they are empty.

    Returns:
        The requirement strings, sorted lexicographically.
    """

    def render(entry: PackageRequirement) -> str:
        specifiers = [f"<{entry.max_major_version + 1}"]
        if entry.minimum:
            specifiers.append(f">={entry.minimum}")
        specifiers += [f"!={version}" for version in entry.unsupported or ()]
        extras_part = f"[{','.join(entry.extras)}]" if entry.extras else ""
        marker_part = f"; {entry.markers}" if entry.markers else ""
        return f"{entry.pip_release}{extras_part}{','.join(specifiers)}{marker_part}"

    return sorted(render(entry) for entry in requirements_yaml.root.values())
212  
213  
def read_requirements_yaml(yaml_path: Path) -> list[str]:
    """
    Read a requirements YAML file and return its pip requirement strings.

    The parsed data is validated as ``RequirementsYaml`` (entry key ->
    ``PackageRequirement``) and rendered by ``generate_requirements_from_yaml``, so the
    result is sorted. The file is decoded as UTF-8 explicitly, so parsing does not depend
    on the platform's default encoding.

    Raises:
        pydantic.ValidationError: If the file content does not match the schema.
    """
    with yaml_path.open(encoding="utf-8") as f:
        requirements_data = yaml.safe_load(f)

    return generate_requirements_from_yaml(RequirementsYaml(requirements_data))
220  
221  
def read_package_versions_yml() -> dict[str, Any]:
    """
    Load ``mlflow/ml-package-versions.yml`` (relative to the CWD) as a dictionary.

    The file is decoded as UTF-8 explicitly so the result does not depend on the
    platform's default encoding. The ``cast`` only informs type checkers; the top-level
    YAML value is assumed to be a mapping.
    """
    with open("mlflow/ml-package-versions.yml", encoding="utf-8") as f:
        return cast(dict[str, Any], yaml.safe_load(f))
225  
226  
227  def build(package_type: PackageType) -> None:
228      requirements_dir = Path("requirements")
229      tracing_requirements = read_requirements_yaml(requirements_dir / "tracing-requirements.yaml")
230      skinny_requirements = read_requirements_yaml(requirements_dir / "skinny-requirements.yaml")
231      _check_skinny_tracing_mismatch(
232          skinny_reqs=skinny_requirements, tracing_reqs=tracing_requirements
233      )
234      core_requirements = read_requirements_yaml(requirements_dir / "core-requirements.yaml")
235      gateways_requirements = read_requirements_yaml(requirements_dir / "gateway-requirements.yaml")
236      genai_requirements = read_requirements_yaml(requirements_dir / "genai-requirements.yaml")
237      version_match = re.search(
238          r'^VERSION = "([a-z0-9\.]+)"$', Path("mlflow", "version.py").read_text(), re.MULTILINE
239      )
240      if version_match is None:
241          raise ValueError(
242              'Could not find VERSION in mlflow/version.py. Expected format: VERSION = "x.y.z"'
243          )
244      package_version = version_match.group(1)
245      python_version = Path(".python-version").read_text().strip()
246      versions_yaml = read_package_versions_yml()
247      langchain_requirements = [
248          "langchain>={},<={}".format(
249              max(
250                  Version(versions_yaml["langchain"]["autologging"]["minimum"]),
251                  Version(versions_yaml["langchain"]["models"]["minimum"]),
252              ),
253              min(
254                  Version(versions_yaml["langchain"]["autologging"]["maximum"]),
255                  Version(versions_yaml["langchain"]["models"]["maximum"]),
256              ),
257          )
258      ]
259  
260      match package_type:
261          case PackageType.TRACING:
262              dependencies = sorted(tracing_requirements)
263          case PackageType.SKINNY:
264              dependencies = sorted(skinny_requirements)
265          case PackageType.RELEASE:
266              dependencies = [
267                  f"mlflow-skinny=={package_version}",
268                  f"mlflow-tracing=={package_version}",
269              ] + sorted(core_requirements)
270          case PackageType.DEV:
271              # skinny_requirements is an exact superset of tracing_requirements
272              # (validated above), so we don't need to include both below.
273              dependencies = sorted(core_requirements + skinny_requirements)
274          case _:
275              raise ValueError(f"Unreachable: {package_type}")
276  
277      if dep_duplicates := find_duplicates(dependencies):
278          raise RuntimeError(f"Duplicated dependencies are found: {dep_duplicates}")
279  
280      match package_type:
281          case PackageType.TRACING:
282              package_name = "mlflow-tracing"
283          case PackageType.SKINNY:
284              package_name = "mlflow-skinny"
285          case _:
286              package_name = "mlflow"
287  
288      description = (
289          "MLflow is an open source platform for the complete machine learning lifecycle"
290          if package_type != PackageType.TRACING
291          else (
292              "MLflow Tracing SDK is an open-source, lightweight Python package that only "
293              "includes the minimum set of dependencies and functionality to instrument "
294              "your code/models/agents with MLflow Tracing."
295          )
296      )
297  
298      data = {
299          "build-system": {
300              "requires": ["setuptools<=82.0.1"],
301              "build-backend": "setuptools.build_meta",
302          },
303          "project": {
304              "name": package_name,
305              "version": package_version,
306              "maintainers": [
307                  {"name": "Databricks", "email": "mlflow-oss-maintainers@googlegroups.com"}
308              ],
309              "description": description,
310              "readme": "README_SKINNY.md" if package_type == PackageType.SKINNY else "README.md",
311              "license": {
312                  "file": "LICENSE.txt",
313              },
314              "keywords": ["mlflow", "ai", "databricks"],
315              "classifiers": [
316                  "Development Status :: 5 - Production/Stable",
317                  "Intended Audience :: Developers",
318                  "Intended Audience :: End Users/Desktop",
319                  "Intended Audience :: Science/Research",
320                  "Intended Audience :: Information Technology",
321                  "Topic :: Scientific/Engineering :: Artificial Intelligence",
322                  "Topic :: Software Development :: Libraries :: Python Modules",
323                  "License :: OSI Approved :: Apache Software License",
324                  "Operating System :: OS Independent",
325                  f"Programming Language :: Python :: {python_version}",
326              ],
327              "requires-python": f">={python_version}",
328              "dependencies": dependencies,
329              "optional-dependencies": {
330                  "extras": [
331                      # Required to log artifacts and models to HDFS artifact locations
332                      "pyarrow",
333                      # Required to sign outgoing request with SigV4 signature
334                      "requests-auth-aws-sigv4",
335                      # Required to log artifacts and models to AWS S3 artifact locations
336                      "boto3",
337                      "botocore",
338                      # Required to log artifacts and models to GCS artifact locations
339                      "google-cloud-storage>=1.30.0",
340                      "azureml-core>=1.2.0",
341                      # Required to log artifacts to SFTP artifact locations
342                      "pysftp",
343                      # Required by the mlflow.projects module, when running projects against
344                      # a remote Kubernetes cluster
345                      "kubernetes",
346                      # Required for exporting metrics from the MLflow server to Prometheus
347                      # as part of the MLflow server monitoring add-on
348                      "prometheus-flask-exporter",
349                  ],
350                  "db": [
351                      # Required to use MySQL, PostgreSQL, or SQL Server as the backend store
352                      "PyMySQL",
353                      "psycopg2-binary",
354                      "pymssql",
355                  ],
356                  "databricks": [
357                      # Required to write model artifacts to unity catalog locations
358                      "azure-storage-file-datalake>12",
359                      "google-cloud-storage>=1.30.0",
360                      "boto3>1",
361                      "botocore",
362                      "databricks-agents>=1.2.0,<2.0",
363                  ],
364                  "mlserver": [
365                      # Required to serve models through MLServer
366                      "mlserver>=1.2.0,!=1.3.1,<2.0.0",
367                      "mlserver-mlflow>=1.2.0,!=1.3.1,<2.0.0",
368                  ],
369                  "gateway": gateways_requirements,
370                  "genai": genai_requirements,
371                  # click 8.3.0 causes MLflow MCP server to fail: https://github.com/mlflow/mlflow/issues/18747
372                  "mcp": ["fastmcp<4,>=2.0.0", "click!=8.3.0"],
373                  "azure": [
374                      # Required to log artifacts and models to Azure Blob Storage
375                      "azure-storage-blob>=12",
376                      "azure-identity>=1.6.1",
377                  ],
378                  "sqlserver": ["mlflow-dbstore"],
379                  "aliyun-oss": ["aliyunstoreplugin"],
380                  "jfrog": ["mlflow-jfrog-plugin"],
381                  "kubernetes": ["kubernetes"],
382                  "langchain": langchain_requirements,
383                  "auth": ["Flask-WTF<2"],
384              }
385              # Tracing SDK does not support extras
386              if package_type != PackageType.TRACING
387              else None,
388              "urls": {
389                  "homepage": "https://mlflow.org",
390                  "issues": "https://github.com/mlflow/mlflow/issues",
391                  "documentation": "https://mlflow.org/docs/latest",
392                  "repository": "https://github.com/mlflow/mlflow",
393              },
394              "scripts": {
395                  "mlflow": "mlflow.cli:cli",
396              }
397              if package_type != PackageType.TRACING
398              else None,
399              "entry-points": {
400                  "mlflow.app": {
401                      "basic-auth": "mlflow.server.auth:create_app",
402                  },
403                  "mlflow.app.client": {
404                      "basic-auth": "mlflow.server.auth.client:AuthServiceClient",
405                  },
406                  "mlflow.deployments": {
407                      "databricks": "mlflow.deployments.databricks",
408                      "http": "mlflow.deployments.mlflow",
409                      "https": "mlflow.deployments.mlflow",
410                      "openai": "mlflow.deployments.openai",
411                  },
412              }
413              if package_type != PackageType.TRACING
414              else None,
415          },
416          "tool": {
417              "setuptools": {
418                  "packages": {
419                      "find": {
420                          "where": ["."],
421                          "include": ["mlflow", "mlflow.*"]
422                          if package_type != PackageType.TRACING
423                          else TRACING_INCLUDE_FILES,
424                          "exclude": ["tests", "tests.*"]
425                          if package_type != PackageType.TRACING
426                          else TRACING_EXCLUDE_FILES,
427                          "namespaces": False,
428                      }
429                  },
430                  "package-data": _get_package_data(package_type),
431              }
432          },
433      }
434  
435      if package_type == PackageType.TRACING:
436          out_path = Path("libs/tracing/pyproject.toml")
437          write_toml_file_if_changed(out_path, package_type.description(), data)
438      elif package_type == PackageType.SKINNY:
439          out_path = Path("libs/skinny/pyproject.toml")
440          write_toml_file_if_changed(out_path, package_type.description(), data)
441  
442          skinny_readme_path = Path("libs/skinny/README_SKINNY.md")
443          new_readme_content = SKINNY_README.lstrip() + Path("README.md").read_text()
444          write_file_if_changed(skinny_readme_path, new_readme_content)
445  
446          for f in ["LICENSE.txt", "MANIFEST.in", "mlflow"]:
447              symlink = Path("libs/skinny", f)
448              if symlink.exists():
449                  symlink.unlink()
450              target = Path("../..", f)
451              symlink.symlink_to(target, target_is_directory=target.is_dir())
452      elif package_type == PackageType.RELEASE:
453          out_path = Path(f"pyproject.{package_type.value}.toml")
454          write_toml_file_if_changed(out_path, package_type.description(), data)
455      else:
456          out_path = Path("pyproject.toml")
457          original_manual_content = out_path.read_text().split(SEPARATOR)[1]
458          generated_part = package_type.description() + "\n" + toml.dumps(data)
459          formatted_generated_part = format_content_with_taplo(generated_part)
460          formatted_full_content = formatted_generated_part + SEPARATOR + original_manual_content
461  
462          write_file_if_changed(out_path, formatted_full_content)
463          subprocess.check_call(["uv", "lock"])
464  
465  
466  def _get_package_data(package_type: PackageType) -> dict[str, list[str]] | None:
467      if package_type == PackageType.TRACING:
468          return None
469  
470      package_data = {
471          "mlflow": [
472              "store/db_migrations/alembic.ini",
473              "temporary_db_migrations_for_pre_1_users/alembic.ini",
474              "pyspark/ml/log_model_allowlist.txt",
475              "server/auth/basic_auth.ini",
476              "server/auth/db/migrations/alembic.ini",
477              "server/uvicorn_log_config.yaml",
478              "models/notebook_resources/**/*",
479              "ai_commands/**/*.md",
480              "assistant/skills/**/*",
481          ]
482      }
483  
484      if package_type != PackageType.SKINNY:
485          package_data["mlflow"] += [
486              "models/container/**/*",
487              "server/js/build/**/*",
488              "utils/model_catalog/*.json",
489          ]
490  
491      return package_data
492  
493  
def _check_skinny_tracing_mismatch(*, skinny_reqs: list[str], tracing_reqs: list[str]) -> None:
    """
    Ensure every tracing requirement string also appears among the skinny requirements.

    NB: mlflow-tracing is deliberately not a hard dependency of mlflow-skinny, because
    that would complicate package management (it would need another .release.toml file
    for pyproject.release.toml to depend on). This check keeps the two lists in sync
    instead.

    Raises:
        RuntimeError: If some tracing requirement is missing from the skinny
            requirements; the message lists the offending entries.
    """
    missing_from_skinny = set(tracing_reqs) - set(skinny_reqs)
    if not missing_from_skinny:
        return
    raise RuntimeError(
        "Tracing requirements must be a subset of skinny requirements. "
        "Please check the requirements/skinny-requirements.yaml and "
        "requirements/tracing-requirements.yaml files.\n"
        f"Diff: {missing_from_skinny}"
    )
508  
509  
def main() -> None:
    """
    Regenerate the package metadata for every ``PackageType``, in declaration order.

    Exits with status 1 if the ``bin/taplo`` formatter is missing, because every
    generated TOML file is formatted with it. The diagnostic goes to stderr, not stdout.
    """
    if not Path("bin/taplo").exists():
        print(
            "taplo is required to generate pyproject.toml. "
            "Please run 'python bin/install.py' to install it.",
            file=sys.stderr,
        )
        sys.exit(1)

    for package_type in PackageType:
        build(package_type)


if __name__ == "__main__":
    main()