/ dev / generate_protos.py
generate_protos.py
  1  import platform
  2  import subprocess
  3  import tempfile
  4  import textwrap
  5  import urllib.request
  6  import zipfile
  7  from pathlib import Path
  8  from typing import Literal
  9  
 10  SYSTEM = platform.system()
 11  MACHINE = platform.machine()
 12  CACHE_DIR = Path(".cache/protobuf_cache")
 13  MLFLOW_PROTOS_DIR = Path("mlflow/protos")
 14  TEST_PROTOS_DIR = Path("tests/protos")
 15  OTEL_PROTOS_DIR = Path("mlflow/protos/opentelemetry")
 16  
 17  
 18  def gen_protos(
 19      proto_dir: Path,
 20      proto_files: list[Path],
 21      lang: Literal["python", "java"],
 22      protoc_bin: Path,
 23      protoc_include_paths: list[Path],
 24      out_dir: Path,
 25  ) -> None:
 26      assert lang in ["python", "java"]
 27      out_dir.mkdir(parents=True, exist_ok=True)
 28      subprocess.check_call([
 29          protoc_bin,
 30          "--fatal_warnings",
 31          *(f"-I={p}" for p in protoc_include_paths),
 32          f"-I={proto_dir}",
 33          f"--{lang}_out={out_dir}",
 34          *[proto_dir / pf for pf in proto_files],
 35      ])
 36  
 37  
 38  def gen_stub_files(
 39      proto_dir: Path,
 40      proto_files: list[Path],
 41      protoc_bin: Path,
 42      protoc_include_paths: list[Path],
 43      out_dir: Path,
 44  ) -> None:
 45      out_dir.mkdir(parents=True, exist_ok=True)
 46      subprocess.check_call([
 47          protoc_bin,
 48          "--fatal_warnings",
 49          *(f"-I={p}" for p in protoc_include_paths),
 50          f"-I={proto_dir}",
 51          f"--pyi_out={out_dir}",
 52          *[proto_dir / pf for pf in proto_files],
 53      ])
 54  
 55  
 56  def gen_proto_docs(
 57      proto_dir: Path,
 58      proto_files: list[Path],
 59      protoc_bin: Path,
 60      protoc_include_path: Path,
 61      out_dir: Path,
 62  ) -> None:
 63      plugin_path = Path("dev/proto-plugin.sh").resolve()
 64      subprocess.check_call([
 65          protoc_bin,
 66          f"-I={protoc_include_path}",
 67          f"-I={proto_dir}",
 68          f"--plugin=protoc-gen-doc={plugin_path}",
 69          f"--doc_out={out_dir}",
 70          *[proto_dir / pf for pf in proto_files],
 71      ])
 72  
 73  
 74  def apply_python_gencode_replacement(file_path: Path) -> None:
 75      content = file_path.read_text()
 76  
 77      for old, new in python_gencode_replacements:
 78          content = content.replace(old, new)
 79  
 80      file_path.write_text(content, encoding="UTF-8")
 81  
 82  
 83  def _get_python_output_path(proto_file_path: Path) -> Path:
 84      return proto_file_path.parent / (proto_file_path.stem + "_pb2.py")
 85  
 86  
 87  def to_paths(*args: str) -> list[Path]:
 88      return list(map(Path, args))
 89  
 90  
 91  basic_proto_files = to_paths(
 92      "databricks.proto",
 93      "service.proto",
 94      "model_registry.proto",
 95      "databricks_artifacts.proto",
 96      "mlflow_artifacts.proto",
 97      "internal.proto",
 98      "scalapb/scalapb.proto",
 99      "assessments.proto",
100      "datasets.proto",
101      "issues.proto",
102      "webhooks.proto",
103      "jobs.proto",
104      "prompt_optimization.proto",
105  )
106  uc_proto_files = to_paths(
107      "databricks_managed_catalog_messages.proto",
108      "databricks_managed_catalog_service.proto",
109      "databricks_uc_registry_messages.proto",
110      "databricks_uc_registry_service.proto",
111      "databricks_filesystem_service.proto",
112      "unity_catalog_oss_messages.proto",
113      "unity_catalog_oss_service.proto",
114      "unity_catalog_prompt_messages.proto",
115      "unity_catalog_prompt_service.proto",
116  )
117  tracing_proto_files = to_paths(
118      "databricks_exception_with_details.proto",
119      "databricks_tracing.proto",
120  )
121  facet_proto_files = to_paths("facet_feature_statistics.proto")
122  python_proto_files = basic_proto_files + uc_proto_files + facet_proto_files + tracing_proto_files
123  test_proto_files = to_paths("test_message.proto")
124  
125  
def _relative_import_rewrite(module: str, *, keep_alias: bool = True) -> tuple[str, str]:
    """
    Build an ``(old, new)`` pair turning protoc's absolute import of ``<module>_pb2``
    into a package-relative one.

    protoc aliases the import by doubling every ``_`` in ``module`` and appending
    ``__pb2``. With ``keep_alias=False`` the rewritten import instead binds the plain
    ``<module>_pb2`` name.
    """
    protoc_alias = module.replace("_", "__") + "__pb2"
    rewritten_alias = protoc_alias if keep_alias else f"{module}_pb2"
    return (
        f"import {module}_pb2 as {protoc_alias}",
        f"from . import {module}_pb2 as {rewritten_alias}",
    )


# Ordered (old, new) substring pairs applied to each generated *_pb2.py of
# python_proto_files by apply_python_gencode_replacement: protoc's absolute imports of
# sibling modules become package-relative imports.
# NOTE(review): entries built with keep_alias=False also change the bound alias
# (e.g. databricks__uc__registry__messages__pb2 -> databricks_uc_registry_messages_pb2);
# any use of the original alias elsewhere in the gencode would stop resolving — confirm
# this is intended.
python_gencode_replacements = [
    # scalapb lives in a sub-package, so it gets its own relative path.
    (
        "from scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2",
        "from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2",
    ),
    _relative_import_rewrite("databricks"),
    _relative_import_rewrite("databricks_uc_registry_messages", keep_alias=False),
    _relative_import_rewrite("databricks_managed_catalog_messages", keep_alias=False),
    _relative_import_rewrite("unity_catalog_oss_messages", keep_alias=False),
    _relative_import_rewrite("unity_catalog_prompt_messages", keep_alias=False),
    _relative_import_rewrite("service"),
    _relative_import_rewrite("assessments"),
    _relative_import_rewrite("datasets"),
    _relative_import_rewrite("issues"),
    _relative_import_rewrite("webhooks"),
    _relative_import_rewrite("jobs"),
    _relative_import_rewrite("prompt_optimization"),
    _relative_import_rewrite("databricks_exception_with_details", keep_alias=False),
]
187  
188  
def gen_python_protos(protoc_bin: Path, protoc_include_paths: list[Path], out_dir: Path) -> None:
    """
    Emit Python gencode for the MLflow and test protos into ``out_dir``.

    Afterwards, only the MLflow outputs (``python_proto_files``, not the test protos) are
    post-processed by ``apply_python_gencode_replacement`` to use package-relative imports.
    """
    sources_by_root = (
        (MLFLOW_PROTOS_DIR, python_proto_files),
        (TEST_PROTOS_DIR, test_proto_files),
    )
    for proto_root, protos in sources_by_root:
        gen_protos(proto_root, protos, "python", protoc_bin, protoc_include_paths, out_dir)

    for generated_relpath in map(_get_python_output_path, python_proto_files):
        apply_python_gencode_replacement(out_dir / generated_relpath)
210  
211  
def download_file(url: str, output_path: Path) -> None:
    """Fetch ``url`` and save the response body to ``output_path``."""
    urllib.request.urlretrieve(url, filename=output_path)
214  
215  
def download_and_extract_protoc(version: Literal["3.19.4", "26.0"]) -> tuple[Path, Path]:
    """
    Download and extract a specific protoc release for Linux, caching it in CACHE_DIR.

    Returns the extracted protoc executable path and its ``include`` directory.

    The archive is unpacked into a staging directory inside CACHE_DIR and only renamed
    to ``protoc-<version>`` once extraction has finished. As a result, an interrupted run
    can no longer leave a half-extracted tree that later runs would accept as a valid
    cache entry. The executable bit is (re)applied on every call, which also repairs
    caches left behind by a run that stopped before its chmod.
    """
    import shutil

    assert SYSTEM == "Linux", "This script only supports Linux systems."
    assert MACHINE in ["x86_64", "aarch64"], (
        "This script only supports x86_64 or aarch64 CPU architectures."
    )

    cpu_type = "x86_64" if MACHINE == "x86_64" else "aarch_64"
    protoc_zip_filename = f"protoc-{version}-linux-{cpu_type}.zip"

    protoc_dir = CACHE_DIR / f"protoc-{version}"
    downloaded_protoc_bin = protoc_dir / "bin" / "protoc"
    downloaded_protoc_include_path = protoc_dir / "include"
    if not (downloaded_protoc_bin.is_file() and downloaded_protoc_include_path.is_dir()):
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        # Stage inside CACHE_DIR so the final rename stays on the same filesystem.
        with tempfile.TemporaryDirectory(dir=CACHE_DIR) as t:
            zip_path = Path(t) / protoc_zip_filename
            download_file(
                f"https://github.com/protocolbuffers/protobuf/releases/download/v{version}/{protoc_zip_filename}",
                zip_path,
            )
            staging_dir = Path(t) / "extracted"
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(staging_dir)
            # Discard leftovers of an earlier, incomplete extraction before publishing.
            if protoc_dir.exists():
                shutil.rmtree(protoc_dir)
            staging_dir.rename(protoc_dir)

    # Make protoc executable (zip extraction does not restore the executable bit).
    downloaded_protoc_bin.chmod(0o755)
    return downloaded_protoc_bin, downloaded_protoc_include_path
244  
245  
def generate_final_python_gencode(
    gencode3194_path: Path, gencode5260_path: Path, out_path: Path
) -> None:
    """
    Merge protoc 3.19.4 and protoc 26.0 Python gencode into a single module.

    The module written to ``out_path`` imports ``google.protobuf`` and, at import time,
    runs the 26.0 gencode when the installed protobuf major version is >= 5, otherwise
    the 3.19.4 gencode. Both bodies are indented two spaces to nest under the
    ``if``/``else``.

    Args:
        gencode3194_path: Module generated by protoc 3.19.4.
        gencode5260_path: Module generated by protoc 26.0.
        out_path: Destination of the merged module (written as UTF-8).
    """
    # Read with the same encoding used for the write below, so the result does not
    # depend on the machine's locale default encoding.
    gencode3194 = gencode3194_path.read_text(encoding="UTF-8")
    gencode5260 = gencode5260_path.read_text(encoding="UTF-8")

    merged_code = f"""
import google.protobuf
from packaging.version import Version
if Version(google.protobuf.__version__).major >= 5:
{textwrap.indent(gencode5260, "  ")}
else:
{textwrap.indent(gencode3194, "  ")}
"""
    out_path.write_text(merged_code, encoding="UTF-8")
261  
262  
def main() -> None:
    """
    Regenerate every protobuf artifact from the .proto sources.

    - Python gencode is produced with both protoc 3.19.4 and protoc 26.0 in a scratch
      directory, then merged file-by-file (generate_final_python_gencode) into
      MLFLOW_PROTOS_DIR and TEST_PROTOS_DIR.
    - Java gencode for basic_proto_files uses the pinned protoc 3.19.4.
    - .pyi stubs and proto docs are produced with protoc 26.0.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    # Obtain both protoc toolchains; the OTEL protos are an extra include root for each.
    protoc3194, protoc3194_include = download_and_extract_protoc("3.19.4")
    protoc5260, protoc5260_include = download_and_extract_protoc("26.0")
    protoc3194_includes = [protoc3194_include, OTEL_PROTOS_DIR]
    protoc5260_includes = [protoc5260_include, OTEL_PROTOS_DIR]

    with tempfile.TemporaryDirectory() as scratch:
        proto3194_out = Path(scratch, "3.19.4")
        proto5260_out = Path(scratch, "26.0")
        for gencode_dir in (proto3194_out, proto5260_out):
            gencode_dir.mkdir(exist_ok=True)

        gen_python_protos(protoc3194, protoc3194_includes, proto3194_out)
        gen_python_protos(protoc5260, protoc5260_includes, proto5260_out)

        merge_targets = (
            (MLFLOW_PROTOS_DIR, python_proto_files),
            (TEST_PROTOS_DIR, test_proto_files),
        )
        for destination_root, sources in merge_targets:
            for relative_gencode in map(_get_python_output_path, sources):
                generate_final_python_gencode(
                    proto3194_out / relative_gencode,
                    proto5260_out / relative_gencode,
                    destination_root / relative_gencode,
                )

    # Java gencode is generated with the pinned protoc 3.19.4.
    gen_protos(
        MLFLOW_PROTOS_DIR,
        basic_proto_files,
        "java",
        protoc3194,
        protoc3194_includes,
        Path("mlflow/java/client/src/main/java"),
    )

    # Type stubs and documentation come from protoc 26.0.
    gen_stub_files(
        MLFLOW_PROTOS_DIR,
        python_proto_files,
        protoc5260,
        protoc5260_includes,
        Path("mlflow/protos/"),
    )
    gen_proto_docs(
        MLFLOW_PROTOS_DIR,
        basic_proto_files,
        protoc5260,
        protoc5260_include,
        Path("mlflow/protos"),
    )
320  
321  
# Only regenerate when executed as a script, never on import.
if __name__ == "__main__":
    main()