# proto_plugin.py
"""Protoc plugin that serializes proto documentation to a JSON file.

The plugin reads a ``CodeGeneratorRequest`` from stdin, walks the messages,
enums and services of every file that protoc asked to generate, and answers
with a ``CodeGeneratorResponse`` that contains a single ``protos.json`` file.
"""

import json
import sys
import textwrap
from dataclasses import asdict, dataclass
from dataclasses import field as dataclass_field
from enum import Enum

from google.protobuf import descriptor_pb2
from google.protobuf.compiler import plugin_pb2

from mlflow.protos import databricks_pb2


class Visibility(Enum):
    """Visibility levels that can be attached to a documented element."""

    PUBLIC = "public"
    INTERNAL = "internal"
    PUBLIC_UNDOCUMENTED = "public_undocumented"
    PUBLIC_UNDOCUMENTED_READ_ONLY = "public_undocumented_read_only"


@dataclass
class ProtoMessageField:
    """Documentation record for one message field or one oneof group.

    A oneof group is represented as a field whose ``entity_type`` and
    ``field_type`` are both ``"oneof"`` and whose members are listed in
    ``oneof``.
    """

    description: str
    field_name: str
    field_default: str | None
    entity_type: str
    field_type: str
    full_path: list[str]
    visibility: str
    since: str
    deprecated: bool
    repeated: bool
    validate_required: bool
    oneof: list["ProtoMessageField"] = dataclass_field(default_factory=list)


@dataclass
class ProtoMessage:
    """Documentation record for a message, including nested enums/messages."""

    name: str
    full_path: list[str]
    description: str
    visibility: str
    fields: list[ProtoMessageField]
    enums: list["ProtoEnum"]
    messages: list["ProtoMessage"]


@dataclass
class ProtoEnumValue:
    """Documentation record for a single enum value."""

    value: str
    full_path: list[str]
    visibility: str
    description: str


@dataclass
class ProtoEnum:
    """Documentation record for an enum and its values."""

    name: str
    description: str
    full_path: list[str]
    values: list[ProtoEnumValue]
    visibility: str


@dataclass
class DatabricksRpcOptionsDescription:
    """RPC options read from the ``databricks_pb2.rpc`` method extension."""

    path: str | None = None
    method: str | None = None
    visibility: str = "internal"
    since_major: int | None = None
    since_minor: int | None = None
    error_codes: list[int] | None = None
    rpc_doc_title: str = ""

    def __post_init__(self):
        # Normalize a missing list so consumers always see a list.
        if self.error_codes is None:
            self.error_codes = []


@dataclass
class ProtoServiceMethod:
    """Documentation record for one RPC method of a service."""

    name: str
    full_path: list[str]
    request_full_path: list[str]
    response_full_path: list[str]
    description: str
    rpc_options: DatabricksRpcOptionsDescription | None


@dataclass
class ProtoService:
    """Documentation record for a service and its methods."""

    name: str
    full_path: list[str]
    description: str
    visibility: str
    methods: list[ProtoServiceMethod]


@dataclass
class ProtoTopComment:
    """A free-standing comment element of a proto file."""

    content: str
    visibility: str


@dataclass
class ProtoFileElement:
    """One top-level element of a proto file.

    ``process_file`` sets exactly one of ``message``, ``enum`` or ``service``
    on each element it creates.
    """

    comment: ProtoTopComment | None = None
    enum: ProtoEnum | None = None
    message: ProtoMessage | None = None
    service: ProtoService | None = None


@dataclass
class ProtoFile:
    """Documentation extracted from a single proto file."""

    filename: str
    requested_visibility: str
    content: list[ProtoFileElement]


@dataclass
class ProtoAllContent:
    """Root object that is serialized to JSON."""

    requested_visibility: str
    files: list[ProtoFile]


class ProtobufDocGenerator:
    """Converts protobuf descriptors into the ``Proto*`` documentation records."""

    # Scalar field types and the names they are reported under. Built once at
    # class creation instead of on every get_field_type_name() call.
    _SCALAR_TYPE_NAMES = {
        descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: "DOUBLE",
        descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: "FLOAT",
        descriptor_pb2.FieldDescriptorProto.TYPE_INT64: "INT64",
        descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: "UINT64",
        descriptor_pb2.FieldDescriptorProto.TYPE_INT32: "INT32",
        descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: "FIXED64",
        descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: "FIXED32",
        descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: "BOOL",
        descriptor_pb2.FieldDescriptorProto.TYPE_STRING: "STRING",
        descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: "BYTES",
        descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: "UINT32",
        descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: "SFIXED32",
        descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: "SFIXED64",
        descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: "SINT32",
        descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: "SINT64",
    }

    # State of the lookup index used by find_source_location(). The index is
    # rebuilt whenever a different SourceCodeInfo object (or one whose number
    # of locations changed) is passed in. Keeping a reference to the indexed
    # object guarantees the identity check cannot match a recycled object.
    _indexed_source_info: descriptor_pb2.SourceCodeInfo | None = None
    _indexed_location_count: int = -1
    _location_index: dict | None = None

    def get_field_type_name(self, field: descriptor_pb2.FieldDescriptorProto) -> str:
        """Return the display name of a field's type.

        Scalar types map to their upper-case name, message and enum fields to
        their type name without the leading dot, and anything else to
        ``"unknown"``.
        """
        scalar_name = self._SCALAR_TYPE_NAMES.get(field.type)
        if scalar_name is not None:
            return scalar_name

        if field.type in (
            descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,
            descriptor_pb2.FieldDescriptorProto.TYPE_ENUM,
        ):
            # Type names of referenced types carry a leading dot; drop it.
            return field.type_name.removeprefix(".")

        return "unknown"

    def get_visibility(self, options) -> Visibility:
        """Return the visibility of an element.

        The ``options`` argument is currently not inspected; every element is
        reported as ``Visibility.PUBLIC``.
        """
        return Visibility.PUBLIC

    def get_validate_required(self, options) -> bool:
        """Extract validate_required value from field options.

        Returns the value of the ``validate_required`` extension when it is
        set on ``options`` and ``False`` otherwise.
        """
        if options.HasExtension(databricks_pb2.validate_required):
            return options.Extensions[databricks_pb2.validate_required]
        return False

    def extract_rpc_options(self, options) -> DatabricksRpcOptionsDescription | None:
        """Extract RPC options from method options.

        Returns ``None`` when the method does not carry the
        ``databricks_pb2.rpc`` extension. Only the first endpoint of the
        extension contributes the path, HTTP method and "since" version.
        """
        if not options.HasExtension(databricks_pb2.rpc):
            return None

        rpc_ext = options.Extensions[databricks_pb2.rpc]

        path = None
        method = None
        since_major = None
        since_minor = None

        if rpc_ext.endpoints:
            # Use the first endpoint for now
            endpoint = rpc_ext.endpoints[0]
            path = endpoint.path if endpoint.HasField("path") else None
            method = endpoint.method if endpoint.HasField("method") else None

            if endpoint.HasField("since") and endpoint.since:
                since = endpoint.since
                since_major = since.major if since.HasField("major") else None
                since_minor = since.minor if since.HasField("minor") else None

        # An unset or unrecognized visibility falls back to "internal".
        visibility = "internal"
        if rpc_ext.HasField("visibility"):
            visibility_names = {
                databricks_pb2.PUBLIC: "public",
                databricks_pb2.INTERNAL: "internal",
                databricks_pb2.PUBLIC_UNDOCUMENTED: "public_undocumented",
            }
            visibility = visibility_names.get(rpc_ext.visibility, "internal")

        rpc_doc_title = rpc_ext.rpc_doc_title if rpc_ext.HasField("rpc_doc_title") else ""

        return DatabricksRpcOptionsDescription(
            path=path,
            method=method,
            visibility=visibility,
            since_major=since_major,
            since_minor=since_minor,
            error_codes=list(rpc_ext.error_codes),
            rpc_doc_title=rpc_doc_title,
        )

    def get_full_path_for_file(
        self, file: descriptor_pb2.FileDescriptorProto, name: str
    ) -> list[str]:
        """Return the package components of ``file`` followed by ``name``."""
        path_parts = file.package.split(".") if file.package else []
        path_parts.append(name)
        return path_parts

    def get_full_path_for_nested(self, parent_path: list[str], name: str) -> list[str]:
        """Return a new list made of ``parent_path`` followed by ``name``."""
        return [*parent_path, name]

    def get_documentation(
        self, source_location: descriptor_pb2.SourceCodeInfo.Location | None
    ) -> str:
        """Return the dedented, stripped leading comment of a location.

        Returns an empty string when there is no location or the location has
        no leading comment.
        """
        if source_location and source_location.leading_comments:
            return textwrap.dedent(source_location.leading_comments).strip()
        return ""

    def find_source_location(
        self, source_info: descriptor_pb2.SourceCodeInfo, path: list[int]
    ) -> descriptor_pb2.SourceCodeInfo.Location | None:
        """Return the first location of ``source_info`` whose path equals ``path``.

        Returns ``None`` when no location matches. Lookups go through a
        per-``source_info`` index so that resolving every element of a file
        does not rescan the whole location list each time.
        """
        location_count = len(source_info.location)
        if (
            self._indexed_source_info is not source_info
            or self._indexed_location_count != location_count
        ):
            index = {}
            for location in source_info.location:
                # setdefault keeps the first location for a duplicated path.
                index.setdefault(tuple(location.path), location)
            self._location_index = index
            self._indexed_source_info = source_info
            self._indexed_location_count = location_count

        return self._location_index.get(tuple(path))

    def process_field(
        self,
        field: descriptor_pb2.FieldDescriptorProto,
        parent_path: list[str],
        field_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        message_path: list[int],
    ) -> ProtoMessageField:
        """Build the documentation record for one field of a message.

        ``field_index`` is the position of the field in the message's field
        list and ``message_path`` the source-info path of that message.
        """
        # 2 = field in message
        location = self.find_source_location(source_info, message_path + [2, field_index])

        default_value = field.default_value if field.HasField("default_value") else None

        return ProtoMessageField(
            description=self.get_documentation(location),
            field_name=field.name,
            field_default=default_value,
            entity_type=str(field.type),
            field_type=self.get_field_type_name(field),
            full_path=self.get_full_path_for_nested(parent_path, field.name),
            visibility=self.get_visibility(field.options).value,
            since="",
            deprecated=field.options.deprecated if field.options.HasField("deprecated") else False,
            repeated=field.label == descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED,
            validate_required=self.get_validate_required(field.options),
            oneof=[],
        )

    def process_enum_value(
        self,
        value: descriptor_pb2.EnumValueDescriptorProto,
        parent_path: list[str],
        value_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        enum_path: list[int],
    ) -> ProtoEnumValue:
        """Build the documentation record for one value of an enum."""
        # 2 = value in enum
        location = self.find_source_location(source_info, enum_path + [2, value_index])

        return ProtoEnumValue(
            value=value.name,
            full_path=self.get_full_path_for_nested(parent_path, value.name),
            visibility=self.get_visibility(value.options).value,
            description=self.get_documentation(location),
        )

    def process_enum(
        self,
        enum: descriptor_pb2.EnumDescriptorProto,
        parent_path: list[str],
        enum_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        parent_path_numbers: list[int],
        is_nested: bool = False,
    ) -> ProtoEnum:
        """Build the documentation record for an enum and all of its values.

        ``parent_path_numbers`` is only used when ``is_nested`` is true; a
        top-level enum is addressed from the file root.
        """
        if is_nested:
            enum_path = parent_path_numbers + [4, enum_index]  # 4 = enum in message
        else:
            enum_path = [5, enum_index]  # 5 = enum at file level

        location = self.find_source_location(source_info, enum_path)
        full_path = self.get_full_path_for_nested(parent_path, enum.name)

        values = [
            self.process_enum_value(value, full_path, i, source_info, enum_path)
            for i, value in enumerate(enum.value)
        ]

        return ProtoEnum(
            name=enum.name,
            description=self.get_documentation(location),
            full_path=full_path,
            values=values,
            visibility=self.get_visibility(enum.options).value,
        )

    def process_message(
        self,
        msg: descriptor_pb2.DescriptorProto,
        parent_path: list[str],
        msg_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        parent_path_numbers: list[int] | None = None,
        is_nested: bool = False,
    ) -> ProtoMessage:
        """Build the documentation record for a message, recursively.

        Regular fields come first in declaration order, followed by one
        synthetic ``"oneof"`` field per oneof declaration that has members.
        Nested enums and nested messages are processed recursively.
        """
        if is_nested:
            # A missing parent path is treated as an empty prefix rather than
            # failing on ``None + list``.
            message_path = (parent_path_numbers or []) + [3, msg_index]  # 3 = nested message
        else:
            message_path = [4, msg_index]  # 4 = message at file level

        location = self.find_source_location(source_info, message_path)
        full_path = self.get_full_path_for_nested(parent_path, msg.name)

        # Single pass over the fields: plain fields go straight to the
        # result, oneof members are grouped by the oneof they belong to.
        fields = []
        members_by_oneof: dict[int, list[ProtoMessageField]] = {}
        for i, proto_field in enumerate(msg.field):
            processed = self.process_field(proto_field, full_path, i, source_info, message_path)
            if proto_field.HasField("oneof_index"):
                members_by_oneof.setdefault(proto_field.oneof_index, []).append(processed)
            else:
                fields.append(processed)

        for oneof_index, oneof in enumerate(msg.oneof_decl):
            oneof_fields = members_by_oneof.get(oneof_index)
            if not oneof_fields:
                continue
            fields.append(
                ProtoMessageField(
                    description="",
                    field_name=oneof.name,
                    field_default=None,
                    entity_type="oneof",
                    field_type="oneof",
                    full_path=self.get_full_path_for_nested(full_path, oneof.name),
                    visibility=self.get_visibility(oneof.options).value,
                    since="",
                    deprecated=False,
                    repeated=False,
                    validate_required=False,
                    oneof=oneof_fields,
                )
            )

        enums = [
            self.process_enum(enum, full_path, i, source_info, message_path, is_nested=True)
            for i, enum in enumerate(msg.enum_type)
        ]

        messages = [
            self.process_message(nested, full_path, i, source_info, message_path, is_nested=True)
            for i, nested in enumerate(msg.nested_type)
        ]

        return ProtoMessage(
            name=msg.name,
            full_path=full_path,
            description=self.get_documentation(location),
            visibility=self.get_visibility(msg.options).value,
            fields=fields,
            enums=enums,
            messages=messages,
        )

    def process_method(
        self,
        method: descriptor_pb2.MethodDescriptorProto,
        parent_path: list[str],
        method_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        service_path: list[int],
    ) -> ProtoServiceMethod:
        """Build the documentation record for one RPC method of a service.

        Request and response types are reported as their dotted type name,
        without the leading dot, split into components.
        """
        # 2 = method in service
        location = self.find_source_location(source_info, service_path + [2, method_index])

        return ProtoServiceMethod(
            name=method.name,
            full_path=self.get_full_path_for_nested(parent_path, method.name),
            request_full_path=method.input_type.removeprefix(".").split("."),
            response_full_path=method.output_type.removeprefix(".").split("."),
            description=self.get_documentation(location),
            rpc_options=self.extract_rpc_options(method.options),
        )

    def process_service(
        self,
        service: descriptor_pb2.ServiceDescriptorProto,
        parent_path: list[str],
        service_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
    ) -> ProtoService:
        """Build the documentation record for a service and all of its methods."""
        service_path = [6, service_index]  # 6 = service at file level
        location = self.find_source_location(source_info, service_path)

        full_path = self.get_full_path_for_nested(parent_path, service.name)

        methods = [
            self.process_method(method, full_path, i, source_info, service_path)
            for i, method in enumerate(service.method)
        ]

        return ProtoService(
            name=service.name,
            full_path=full_path,
            description=self.get_documentation(location),
            visibility=self.get_visibility(service.options).value,
            methods=methods,
        )

    def process_file(
        self, file: descriptor_pb2.FileDescriptorProto, requested_vis: Visibility
    ) -> ProtoFile:
        """Build the documentation record for a whole proto file.

        Elements are emitted in the order: top-level messages, top-level
        enums, services. ``requested_vis`` is only recorded on the result.
        """
        # Base path from package
        base_path = file.package.split(".") if file.package else []

        # Fall back to an empty SourceCodeInfo so lookups simply find nothing
        # when the descriptor carries no source information.
        if file.HasField("source_code_info"):
            source_info = file.source_code_info
        else:
            source_info = descriptor_pb2.SourceCodeInfo()

        elements = [
            ProtoFileElement(message=self.process_message(msg, base_path, i, source_info))
            for i, msg in enumerate(file.message_type)
        ]
        elements.extend(
            ProtoFileElement(
                enum=self.process_enum(enum, base_path, i, source_info, [], is_nested=False)
            )
            for i, enum in enumerate(file.enum_type)
        )
        elements.extend(
            ProtoFileElement(service=self.process_service(service, base_path, i, source_info))
            for i, service in enumerate(file.service)
        )

        return ProtoFile(
            filename=file.name, requested_visibility=requested_vis.value, content=elements
        )


class ProtocPlugin:
    """Protoc plugin implementation."""

    def __init__(self):
        self.generator = ProtobufDocGenerator()

    def process_request(
        self, request: plugin_pb2.CodeGeneratorRequest
    ) -> plugin_pb2.CodeGeneratorResponse:
        """Turn a code generator request into a response holding ``protos.json``.

        Every file listed in ``file_to_generate`` that has a descriptor in the
        request is documented; names without a descriptor are skipped.
        """
        response = plugin_pb2.CodeGeneratorResponse()

        # Index descriptors by file name once; setdefault keeps the first
        # descriptor if a name were to appear more than once.
        descriptors_by_name = {}
        for descriptor in request.proto_file:
            descriptors_by_name.setdefault(descriptor.name, descriptor)

        files = []
        for file_name in request.file_to_generate:
            file_descriptor = descriptors_by_name.get(file_name)
            if file_descriptor:
                files.append(self.generator.process_file(file_descriptor, Visibility.PUBLIC))

        doc_content = ProtoAllContent(requested_visibility=Visibility.PUBLIC.value, files=files)

        # Emit the documentation as a single protos.json file
        doc_file = response.file.add()
        doc_file.name = "protos.json"
        doc_file.content = json.dumps(asdict(doc_content), indent=2)

        return response


def main():
    """Run as a protoc plugin: request on stdin, response on stdout."""
    # Read CodeGeneratorRequest from stdin
    data = sys.stdin.buffer.read()
    request = plugin_pb2.CodeGeneratorRequest.FromString(data)

    # Process request
    plugin = ProtocPlugin()
    response = plugin.process_request(request)

    # Write response to stdout
    sys.stdout.buffer.write(response.SerializeToString())


if __name__ == "__main__":
    main()