"""generate_protos.py

Generate MLflow's protobuf-derived sources.

Running this script downloads protoc 3.19.4 and 26.0 into ``CACHE_DIR`` (if
they are not cached yet) and then produces:

* Python gencode for the protos under ``MLFLOW_PROTOS_DIR`` and
  ``TEST_PROTOS_DIR``; each final ``*_pb2.py`` contains the output of both
  protoc versions behind a runtime check of the installed ``protobuf`` version,
* Java gencode for ``basic_proto_files`` (protoc 3.19.4),
* ``.pyi`` stub files for ``python_proto_files`` (protoc 26.0),
* proto docs for ``basic_proto_files`` via the ``dev/proto-plugin.sh`` plugin.

All relative paths are resolved against the current working directory.
"""

import platform
import subprocess
import tempfile
import textwrap
import urllib.request
import zipfile
from pathlib import Path
from typing import Literal

SYSTEM = platform.system()
MACHINE = platform.machine()
CACHE_DIR = Path(".cache/protobuf_cache")
MLFLOW_PROTOS_DIR = Path("mlflow/protos")
TEST_PROTOS_DIR = Path("tests/protos")
OTEL_PROTOS_DIR = Path("mlflow/protos/opentelemetry")


def _run_protoc(
    protoc_bin: Path,
    proto_dir: Path,
    proto_files: list[Path],
    include_paths: list[Path],
    output_flags: list[str],
    *,
    fatal_warnings: bool = True,
) -> None:
    """Run ``protoc_bin`` on ``proto_files`` (paths relative to ``proto_dir``).

    The command line is assembled in a fixed order: optional
    ``--fatal_warnings``, one ``-I`` per entry of ``include_paths``, ``-I`` for
    ``proto_dir``, ``output_flags``, then the proto files.

    Raises:
        subprocess.CalledProcessError: if protoc exits with a non-zero status.
    """
    subprocess.check_call([
        protoc_bin,
        *(["--fatal_warnings"] if fatal_warnings else []),
        *(f"-I={p}" for p in include_paths),
        f"-I={proto_dir}",
        *output_flags,
        *[proto_dir / pf for pf in proto_files],
    ])


def gen_protos(
    proto_dir: Path,
    proto_files: list[Path],
    lang: Literal["python", "java"],
    protoc_bin: Path,
    protoc_include_paths: list[Path],
    out_dir: Path,
) -> None:
    """Generate ``lang`` gencode for ``proto_files`` into ``out_dir``.

    ``out_dir`` is created (including parents) if it does not exist. protoc is
    run with ``--fatal_warnings``.
    """
    assert lang in ["python", "java"]
    out_dir.mkdir(parents=True, exist_ok=True)
    _run_protoc(
        protoc_bin,
        proto_dir,
        proto_files,
        protoc_include_paths,
        [f"--{lang}_out={out_dir}"],
    )


def gen_stub_files(
    proto_dir: Path,
    proto_files: list[Path],
    protoc_bin: Path,
    protoc_include_paths: list[Path],
    out_dir: Path,
) -> None:
    """Generate ``.pyi`` stubs (``--pyi_out``) for ``proto_files`` into ``out_dir``.

    ``out_dir`` is created (including parents) if it does not exist. protoc is
    run with ``--fatal_warnings``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    _run_protoc(
        protoc_bin,
        proto_dir,
        proto_files,
        protoc_include_paths,
        [f"--pyi_out={out_dir}"],
    )


def gen_proto_docs(
    proto_dir: Path,
    proto_files: list[Path],
    protoc_bin: Path,
    protoc_include_path: Path,
    out_dir: Path,
) -> None:
    """Generate docs for ``proto_files`` with the ``protoc-gen-doc`` plugin.

    The plugin executable is ``dev/proto-plugin.sh``, resolved against the
    current working directory. Unlike the other generators this takes a single
    include path and does not pass ``--fatal_warnings``.
    """
    plugin_path = Path("dev/proto-plugin.sh").resolve()
    # Create the output directory up front, as the sibling generators do.
    out_dir.mkdir(parents=True, exist_ok=True)
    _run_protoc(
        protoc_bin,
        proto_dir,
        proto_files,
        [protoc_include_path],
        [f"--plugin=protoc-gen-doc={plugin_path}", f"--doc_out={out_dir}"],
        fatal_warnings=False,
    )


def apply_python_gencode_replacement(file_path: Path) -> None:
    """Rewrite ``file_path`` in place using ``python_gencode_replacements``.

    Every ``(old, new)`` pair is applied in list order as a plain substring
    replacement. The file is read and written as UTF-8 so the result does not
    depend on the locale's default encoding.
    """
    content = file_path.read_text(encoding="UTF-8")

    for old, new in python_gencode_replacements:
        content = content.replace(old, new)

    file_path.write_text(content, encoding="UTF-8")


def _get_python_output_path(proto_file_path: Path) -> Path:
    """Map ``dir/name.proto`` to the gencode path ``dir/name_pb2.py``."""
    return proto_file_path.parent / (proto_file_path.stem + "_pb2.py")


def to_paths(*args: str) -> list[Path]:
    """Convert each string argument to a ``Path``, preserving order."""
    return [Path(arg) for arg in args]


# Proto files, given relative to MLFLOW_PROTOS_DIR.
basic_proto_files = to_paths(
    "databricks.proto",
    "service.proto",
    "model_registry.proto",
    "databricks_artifacts.proto",
    "mlflow_artifacts.proto",
    "internal.proto",
    "scalapb/scalapb.proto",
    "assessments.proto",
    "datasets.proto",
    "issues.proto",
    "webhooks.proto",
    "jobs.proto",
    "prompt_optimization.proto",
)
uc_proto_files = to_paths(
    "databricks_managed_catalog_messages.proto",
    "databricks_managed_catalog_service.proto",
    "databricks_uc_registry_messages.proto",
    "databricks_uc_registry_service.proto",
    "databricks_filesystem_service.proto",
    "unity_catalog_oss_messages.proto",
    "unity_catalog_oss_service.proto",
    "unity_catalog_prompt_messages.proto",
    "unity_catalog_prompt_service.proto",
)
tracing_proto_files = to_paths(
    "databricks_exception_with_details.proto",
    "databricks_tracing.proto",
)
facet_proto_files = to_paths("facet_feature_statistics.proto")
python_proto_files = basic_proto_files + uc_proto_files + facet_proto_files + tracing_proto_files
# Proto files, given relative to TEST_PROTOS_DIR.
test_proto_files = to_paths("test_message.proto")


# (old, new) substring pairs applied to the generated Python modules; each one
# turns a top-level import emitted by protoc into a package-relative import.
python_gencode_replacements = [
    (
        "from scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2",
        "from .scalapb import scalapb_pb2 as scalapb_dot_scalapb__pb2",
    ),
    (
        "import databricks_pb2 as databricks__pb2",
        "from . import databricks_pb2 as databricks__pb2",
    ),
    (
        "import databricks_uc_registry_messages_pb2 as databricks__uc__registry__messages__pb2",
        "from . import databricks_uc_registry_messages_pb2 as databricks_uc_registry_messages_pb2",
    ),
    (
        "import databricks_managed_catalog_messages_pb2 as databricks__managed__catalog__"
        "messages__pb2",
        "from . import databricks_managed_catalog_messages_pb2 as databricks_managed_"
        "catalog_messages_pb2",
    ),
    (
        "import unity_catalog_oss_messages_pb2 as unity__catalog__oss__messages__pb2",
        "from . import unity_catalog_oss_messages_pb2 as unity_catalog_oss_messages_pb2",
    ),
    (
        "import unity_catalog_prompt_messages_pb2 as unity__catalog__prompt__messages__pb2",
        "from . import unity_catalog_prompt_messages_pb2 as unity_catalog_prompt_messages_pb2",
    ),
    (
        "import service_pb2 as service__pb2",
        "from . import service_pb2 as service__pb2",
    ),
    (
        "import assessments_pb2 as assessments__pb2",
        "from . import assessments_pb2 as assessments__pb2",
    ),
    (
        "import datasets_pb2 as datasets__pb2",
        "from . import datasets_pb2 as datasets__pb2",
    ),
    (
        "import issues_pb2 as issues__pb2",
        "from . import issues_pb2 as issues__pb2",
    ),
    (
        "import webhooks_pb2 as webhooks__pb2",
        "from . import webhooks_pb2 as webhooks__pb2",
    ),
    (
        "import jobs_pb2 as jobs__pb2",
        "from . import jobs_pb2 as jobs__pb2",
    ),
    (
        "import prompt_optimization_pb2 as prompt__optimization__pb2",
        "from . import prompt_optimization_pb2 as prompt__optimization__pb2",
    ),
    (
        "import databricks_exception_with_details_pb2 as databricks__exception__with__details__pb2",
        "from . import databricks_exception_with_details_pb2 as databricks_exception_"
        "with_details_pb2",
    ),
]


def gen_python_protos(protoc_bin: Path, protoc_include_paths: list[Path], out_dir: Path) -> None:
    """Generate Python gencode for the MLflow and test protos into ``out_dir``.

    After generation, the import replacements are applied to the modules
    generated from ``python_proto_files`` only; the test protos' gencode is
    left as protoc emitted it.
    """
    gen_protos(
        MLFLOW_PROTOS_DIR,
        python_proto_files,
        "python",
        protoc_bin,
        protoc_include_paths,
        out_dir,
    )

    gen_protos(
        TEST_PROTOS_DIR,
        test_proto_files,
        "python",
        protoc_bin,
        protoc_include_paths,
        out_dir,
    )

    for proto_file in python_proto_files:
        apply_python_gencode_replacement(out_dir / _get_python_output_path(proto_file))


def download_file(url: str, output_path: Path) -> None:
    """Download ``url`` to ``output_path``."""
    urllib.request.urlretrieve(url, output_path)


def download_and_extract_protoc(version: Literal["3.19.4", "26.0"]) -> tuple[Path, Path]:
    """
    Download and extract specific version protoc tool for Linux systems,
    return extracted protoc executable file path and include path.

    The release zip is extracted into ``CACHE_DIR / f"protoc-{version}"``. The
    download is skipped when both ``bin/protoc`` and ``include`` already exist
    there.
    """
    assert SYSTEM == "Linux", "This script only supports Linux systems."
    assert MACHINE in ["x86_64", "aarch64"], (
        "This script only supports x86_64 or aarch64 CPU architectures."
    )

    # The release asset name spells the ARM architecture "aarch_64".
    cpu_type = "x86_64" if MACHINE == "x86_64" else "aarch_64"
    protoc_zip_filename = f"protoc-{version}-linux-{cpu_type}.zip"

    downloaded_protoc_bin = CACHE_DIR / f"protoc-{version}" / "bin" / "protoc"
    downloaded_protoc_include_path = CACHE_DIR / f"protoc-{version}" / "include"
    if not (downloaded_protoc_bin.is_file() and downloaded_protoc_include_path.is_dir()):
        with tempfile.TemporaryDirectory() as t:
            zip_path = Path(t) / protoc_zip_filename
            download_file(
                f"https://github.com/protocolbuffers/protobuf/releases/download/v{version}/{protoc_zip_filename}",
                zip_path,
            )
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(CACHE_DIR / f"protoc-{version}")

        # Make protoc executable
        downloaded_protoc_bin.chmod(0o755)
    return downloaded_protoc_bin, downloaded_protoc_include_path


def generate_final_python_gencode(
    gencode3194_path: Path, gencode5260_path: Path, out_path: Path
) -> None:
    """Merge two gencode variants into one module selected at import time.

    The file written to ``out_path`` runs the ``gencode5260_path`` code when
    the installed ``google.protobuf`` major version is >= 5 and the
    ``gencode3194_path`` code otherwise. Inputs are read, and the output is
    written, as UTF-8.
    """
    gencode3194 = gencode3194_path.read_text(encoding="UTF-8")
    gencode5260 = gencode5260_path.read_text(encoding="UTF-8")

    # Each variant is indented with a uniform prefix so that it nests under
    # its ``if``/``else`` branch.
    merged_code = f"""
import google.protobuf
from packaging.version import Version
if Version(google.protobuf.__version__).major >= 5:
{textwrap.indent(gencode5260, " ")}
else:
{textwrap.indent(gencode3194, " ")}
"""
    out_path.write_text(merged_code, encoding="UTF-8")


def main() -> None:
    """Regenerate all Python, Java, stub and doc outputs from the proto files."""
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    with tempfile.TemporaryDirectory() as temp_gencode_dir:
        temp_gencode_path = Path(temp_gencode_dir)
        proto3194_out = temp_gencode_path / "3.19.4"
        proto5260_out = temp_gencode_path / "26.0"
        proto3194_out.mkdir(exist_ok=True)
        proto5260_out.mkdir(exist_ok=True)

        protoc3194, protoc3194_include = download_and_extract_protoc("3.19.4")
        protoc5260, protoc5260_include = download_and_extract_protoc("26.0")

        # Build include paths list
        protoc3194_includes = [protoc3194_include, OTEL_PROTOS_DIR]
        protoc5260_includes = [protoc5260_include, OTEL_PROTOS_DIR]

        gen_python_protos(protoc3194, protoc3194_includes, proto3194_out)
        gen_python_protos(protoc5260, protoc5260_includes, proto5260_out)

        # Combine the per-version gencode into the checked-in *_pb2.py files.
        for proto_files, protos_dir in [
            (python_proto_files, MLFLOW_PROTOS_DIR),
            (test_proto_files, TEST_PROTOS_DIR),
        ]:
            for proto_file in proto_files:
                gencode_path = _get_python_output_path(proto_file)

                generate_final_python_gencode(
                    proto3194_out / gencode_path,
                    proto5260_out / gencode_path,
                    protos_dir / gencode_path,
                )

    # generate java gencode using pinned protoc 3.19.4 version.
    gen_protos(
        MLFLOW_PROTOS_DIR,
        basic_proto_files,
        "java",
        protoc3194,
        protoc3194_includes,
        Path("mlflow/java/client/src/main/java"),
    )

    gen_stub_files(
        MLFLOW_PROTOS_DIR,
        python_proto_files,
        protoc5260,
        protoc5260_includes,
        Path("mlflow/protos/"),
    )

    gen_proto_docs(
        MLFLOW_PROTOS_DIR,
        basic_proto_files,
        protoc5260,
        protoc5260_include,
        Path("mlflow/protos"),
    )


if __name__ == "__main__":
    main()