# dev/proto_plugin.py
  1  import json
  2  import sys
  3  import textwrap
  4  from dataclasses import asdict, dataclass
  5  from dataclasses import field as dataclass_field
  6  from enum import Enum
  7  
  8  from google.protobuf import descriptor_pb2
  9  from google.protobuf.compiler import plugin_pb2
 10  
 11  from mlflow.protos import databricks_pb2
 12  
 13  
class Visibility(Enum):
    """Visibility levels a proto element can be labeled with in the output."""

    PUBLIC = "public"
    INTERNAL = "internal"
    PUBLIC_UNDOCUMENTED = "public_undocumented"
    PUBLIC_UNDOCUMENTED_READ_ONLY = "public_undocumented_read_only"
 19  
 20  
@dataclass
class ProtoMessageField:
    """Documentation model for one message field, or for a oneof group."""

    description: str  # leading comment extracted from the .proto source
    field_name: str
    field_default: str | None  # explicit default from the descriptor, if set
    entity_type: str  # numeric FieldDescriptorProto.Type as a string, or "oneof"
    field_type: str  # display type name (e.g. "INT64") or qualified message/enum path
    full_path: list[str]  # package + enclosing message(s) + field name components
    visibility: str
    since: str  # version the field appeared in (currently always "")
    deprecated: bool
    repeated: bool  # True when the field label is LABEL_REPEATED
    validate_required: bool  # from the databricks validate_required extension
    # When this entry represents a oneof group, the member fields live here.
    oneof: list["ProtoMessageField"] = dataclass_field(default_factory=list)
 35  
 36  
@dataclass
class ProtoMessage:
    """Documentation model for a protobuf message, including nested types."""

    name: str
    full_path: list[str]  # package + enclosing message(s) + message name
    description: str  # leading comment from the .proto source
    visibility: str
    fields: list[ProtoMessageField]
    enums: list["ProtoEnum"]  # enums declared inside this message
    messages: list["ProtoMessage"]  # nested message declarations
 46  
 47  
@dataclass
class ProtoEnumValue:
    """Documentation model for a single enum value."""

    value: str  # value name as written in the .proto file
    full_path: list[str]  # package + enum + value name components
    visibility: str
    description: str  # leading comment from the .proto source
 54  
 55  
@dataclass
class ProtoEnum:
    """Documentation model for a protobuf enum (top-level or nested)."""

    name: str
    description: str  # leading comment from the .proto source
    full_path: list[str]
    values: list[ProtoEnumValue]
    visibility: str
 63  
 64  
@dataclass
class DatabricksRpcOptionsDescription:
    """Flattened view of the Databricks RPC options extension on a method."""

    path: str | None = None  # HTTP endpoint path (first endpoint only)
    method: str | None = None  # HTTP verb of the endpoint
    visibility: str = "internal"
    since_major: int | None = None  # version the endpoint was introduced in
    since_minor: int | None = None
    error_codes: list[int] | None = None  # normalized to [] in __post_init__
    rpc_doc_title: str = ""

    def __post_init__(self):
        # None is the declared default to avoid a shared mutable default list;
        # normalize to an empty list so callers can always iterate it.
        if self.error_codes is None:
            self.error_codes = []
 78  
 79  
@dataclass
class ProtoServiceMethod:
    """Documentation model for one RPC method of a service."""

    name: str
    full_path: list[str]  # package + service + method name components
    request_full_path: list[str]  # qualified request message name, split on "."
    response_full_path: list[str]  # qualified response message name, split on "."
    description: str  # leading comment from the .proto source
    rpc_options: DatabricksRpcOptionsDescription | None  # None when no extension is set
 88  
 89  
@dataclass
class ProtoService:
    """Documentation model for a protobuf service and its methods."""

    name: str
    full_path: list[str]  # package + service name components
    description: str  # leading comment from the .proto source
    visibility: str
    methods: list[ProtoServiceMethod]
 97  
 98  
@dataclass
class ProtoTopComment:
    """A standalone file-level comment carried through to the output."""

    content: str
    visibility: str
103  
104  
@dataclass
class ProtoFileElement:
    """One top-level element of a proto file; only one field is populated per element."""

    comment: ProtoTopComment | None = None
    enum: ProtoEnum | None = None
    message: ProtoMessage | None = None
    service: ProtoService | None = None
111  
112  
@dataclass
class ProtoFile:
    """Documentation model for a whole .proto file."""

    filename: str  # the descriptor's file name as reported by protoc
    requested_visibility: str  # the visibility level the run was generated for
    content: list[ProtoFileElement]  # top-level elements in file order
118  
119  
@dataclass
class ProtoAllContent:
    """Root container serialized to JSON: all processed files for one run."""

    requested_visibility: str
    files: list[ProtoFile]
124  
125  
class ProtobufDocGenerator:
    """Walks protobuf ``FileDescriptorProto`` objects and builds the
    documentation model (``ProtoFile`` / ``ProtoMessage`` / ``ProtoEnum`` /
    ``ProtoService``) declared above.

    Source-location lookups use the ``SourceCodeInfo.location.path`` encoding
    from descriptor.proto: alternating (field number, index) pairs, where the
    field numbers are those of the descriptor messages themselves. The numbers
    used below: 4 = message_type and 5 = enum_type at file level, 6 = service;
    within a message, 2 = field, 3 = nested_type, 4 = enum_type; within a
    service, 2 = method; within an enum, 2 = value.
    """

    # Display names for scalar field types. Hoisted to a class attribute so the
    # mapping is built once at class-definition time, not on every
    # get_field_type_name() call.
    _SCALAR_TYPE_NAMES = {
        descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: "DOUBLE",
        descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: "FLOAT",
        descriptor_pb2.FieldDescriptorProto.TYPE_INT64: "INT64",
        descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: "UINT64",
        descriptor_pb2.FieldDescriptorProto.TYPE_INT32: "INT32",
        descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: "FIXED64",
        descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: "FIXED32",
        descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: "BOOL",
        descriptor_pb2.FieldDescriptorProto.TYPE_STRING: "STRING",
        descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: "BYTES",
        descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: "UINT32",
        descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: "SFIXED32",
        descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: "SFIXED64",
        descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: "SINT32",
        descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: "SINT64",
    }

    def get_field_type_name(self, field: descriptor_pb2.FieldDescriptorProto) -> str:
        """Return a display type name for ``field``.

        Scalars map to an upper-case name ("INT64", ...); message and enum
        fields return their fully-qualified type name without the leading dot;
        anything else (e.g. groups) is reported as "unknown".
        """
        if field.type in self._SCALAR_TYPE_NAMES:
            return self._SCALAR_TYPE_NAMES[field.type]
        if field.type in (
            descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,
            descriptor_pb2.FieldDescriptorProto.TYPE_ENUM,
        ):
            # type_name is fully qualified and starts with "." — strip it.
            return field.type_name.removeprefix(".")
        return "unknown"

    def get_visibility(self, options) -> Visibility:
        """Return the visibility of an element.

        Stub: always returns PUBLIC. A full implementation would inspect
        custom visibility options on ``options``.
        """
        return Visibility.PUBLIC

    def get_validate_required(self, options) -> bool:
        """Extract validate_required value from field options.

        Checks if the field has the validate_required extension set;
        returns False when the extension is absent.
        """
        if options.HasExtension(databricks_pb2.validate_required):
            return options.Extensions[databricks_pb2.validate_required]
        return False

    def extract_rpc_options(self, options) -> DatabricksRpcOptionsDescription | None:
        """Extract RPC options from method options.

        Returns None when the Databricks RPC extension is not present.
        This requires databricks_pb2 to be available for full extraction.
        When running as a protoc plugin, the extension data is present in the
        file descriptors but needs the compiled extension definitions to parse.
        """
        if not options.HasExtension(databricks_pb2.rpc):
            # Explicit (rather than implicit fall-through) None for clarity.
            return None

        rpc_ext = options.Extensions[databricks_pb2.rpc]

        # Endpoint information: only the first endpoint is represented.
        path = None
        method = None
        since_major = None
        since_minor = None
        if rpc_ext.endpoints:
            endpoint = rpc_ext.endpoints[0]
            path = endpoint.path if endpoint.HasField("path") else None
            method = endpoint.method if endpoint.HasField("method") else None

            if endpoint.HasField("since") and endpoint.since:
                since_major = endpoint.since.major if endpoint.since.HasField("major") else None
                since_minor = endpoint.since.minor if endpoint.since.HasField("minor") else None

        # Map the visibility enum to its string form; unknown values keep the
        # "internal" default.
        visibility = "internal"
        if rpc_ext.HasField("visibility"):
            visibility_enum = rpc_ext.visibility
            if visibility_enum == databricks_pb2.PUBLIC:
                visibility = "public"
            elif visibility_enum == databricks_pb2.INTERNAL:
                visibility = "internal"
            elif visibility_enum == databricks_pb2.PUBLIC_UNDOCUMENTED:
                visibility = "public_undocumented"

        error_codes = list(rpc_ext.error_codes) if rpc_ext.error_codes else []

        rpc_doc_title = rpc_ext.rpc_doc_title if rpc_ext.HasField("rpc_doc_title") else ""

        return DatabricksRpcOptionsDescription(
            path=path,
            method=method,
            visibility=visibility,
            since_major=since_major,
            since_minor=since_minor,
            error_codes=error_codes,
            rpc_doc_title=rpc_doc_title,
        )

    def get_full_path_for_file(
        self, file: descriptor_pb2.FileDescriptorProto, name: str
    ) -> list[str]:
        """Return the package components of ``file`` followed by ``name``."""
        path_parts = []
        if file.package:
            path_parts.extend(file.package.split("."))
        path_parts.append(name)
        return path_parts

    def get_full_path_for_nested(self, parent_path: list[str], name: str) -> list[str]:
        """Return a new path list: ``parent_path`` with ``name`` appended."""
        return parent_path + [name]

    def get_documentation(self, source_location: descriptor_pb2.SourceCodeInfo.Location) -> str:
        """Return the dedented, stripped leading comment of a source location."""
        if source_location and source_location.leading_comments:
            return textwrap.dedent(source_location.leading_comments).strip()
        return ""

    def find_source_location(
        self, source_info: descriptor_pb2.SourceCodeInfo, path: list[int]
    ) -> descriptor_pb2.SourceCodeInfo.Location | None:
        """Return the location whose path matches ``path`` exactly, or None."""
        for location in source_info.location:
            if list(location.path) == path:
                return location
        return None

    def process_field(
        self,
        field: descriptor_pb2.FieldDescriptorProto,
        parent_path: list[str],
        field_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        message_path: list[int],
    ) -> ProtoMessageField:
        """Convert one field descriptor into a ProtoMessageField.

        ``field_index`` is the field's position in the enclosing message's
        ``field`` list; it is used to build the source-location path for the
        field's comments.
        """
        # 2 = DescriptorProto.field
        field_path = message_path + [2, field_index]
        location = self.find_source_location(source_info, field_path)

        default_value = field.default_value if field.HasField("default_value") else None

        return ProtoMessageField(
            description=self.get_documentation(location) if location else "",
            field_name=field.name,
            field_default=default_value,
            entity_type=str(field.type),
            field_type=self.get_field_type_name(field),
            full_path=self.get_full_path_for_nested(parent_path, field.name),
            visibility=self.get_visibility(field.options).value,
            since="",
            deprecated=field.options.deprecated if field.options.HasField("deprecated") else False,
            repeated=field.label == descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED,
            validate_required=self.get_validate_required(field.options),
            oneof=[],
        )

    def process_enum_value(
        self,
        value: descriptor_pb2.EnumValueDescriptorProto,
        parent_path: list[str],
        value_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        enum_path: list[int],
    ) -> ProtoEnumValue:
        """Convert one enum value descriptor into a ProtoEnumValue."""
        # 2 = EnumDescriptorProto.value
        value_path = enum_path + [2, value_index]
        location = self.find_source_location(source_info, value_path)

        return ProtoEnumValue(
            value=value.name,
            full_path=self.get_full_path_for_nested(parent_path, value.name),
            visibility=self.get_visibility(value.options).value,
            description=self.get_documentation(location) if location else "",
        )

    def process_enum(
        self,
        enum: descriptor_pb2.EnumDescriptorProto,
        parent_path: list[str],
        enum_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        parent_path_numbers: list[int],
        is_nested: bool = False,
    ) -> ProtoEnum:
        """Convert an enum descriptor (top-level or nested) into a ProtoEnum."""
        # 4 = DescriptorProto.enum_type (nested), 5 = FileDescriptorProto.enum_type.
        enum_path = parent_path_numbers + [4, enum_index] if is_nested else [5, enum_index]

        location = self.find_source_location(source_info, enum_path)

        full_path = self.get_full_path_for_nested(parent_path, enum.name)

        values = []
        for i, value in enumerate(enum.value):
            values.append(self.process_enum_value(value, full_path, i, source_info, enum_path))

        return ProtoEnum(
            name=enum.name,
            description=self.get_documentation(location) if location else "",
            full_path=full_path,
            values=values,
            visibility=self.get_visibility(enum.options).value,
        )

    def process_message(
        self,
        msg: descriptor_pb2.DescriptorProto,
        parent_path: list[str],
        msg_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        parent_path_numbers: list[int] | None = None,
        is_nested: bool = False,
    ) -> ProtoMessage:
        """Recursively convert a message descriptor into a ProtoMessage.

        Regular fields and oneof groups become ``fields``; nested enums and
        messages are processed recursively.
        """
        # 3 = DescriptorProto.nested_type, 4 = FileDescriptorProto.message_type.
        # Guard parent_path_numbers: it defaults to None, which previously
        # raised TypeError when is_nested=True was passed without it.
        if is_nested:
            message_path = (parent_path_numbers or []) + [3, msg_index]
        else:
            message_path = [4, msg_index]

        location = self.find_source_location(source_info, message_path)

        full_path = self.get_full_path_for_nested(parent_path, msg.name)

        fields = []
        # Regular (non-oneof) fields; oneof members are grouped below.
        for i, proto_field in enumerate(msg.field):
            if not proto_field.HasField("oneof_index"):
                fields.append(
                    self.process_field(proto_field, full_path, i, source_info, message_path)
                )

        # Each oneof becomes a synthetic field entry whose ``oneof`` list holds
        # the member fields.
        for oneof_index, oneof in enumerate(msg.oneof_decl):
            oneof_fields = []
            for i, proto_field in enumerate(msg.field):
                if proto_field.HasField("oneof_index") and proto_field.oneof_index == oneof_index:
                    oneof_fields.append(
                        self.process_field(proto_field, full_path, i, source_info, message_path)
                    )

            if oneof_fields:
                oneof_field = ProtoMessageField(
                    description="",
                    field_name=oneof.name,
                    field_default=None,
                    entity_type="oneof",
                    field_type="oneof",
                    full_path=self.get_full_path_for_nested(full_path, oneof.name),
                    visibility=self.get_visibility(oneof.options).value,
                    since="",
                    deprecated=False,
                    repeated=False,
                    validate_required=False,
                    oneof=oneof_fields,
                )
                fields.append(oneof_field)

        # Nested enums.
        enums = []
        for i, enum in enumerate(msg.enum_type):
            enums.append(
                self.process_enum(enum, full_path, i, source_info, message_path, is_nested=True)
            )

        # Nested messages (recursive).
        messages = []
        for i, nested in enumerate(msg.nested_type):
            messages.append(
                self.process_message(
                    nested, full_path, i, source_info, message_path, is_nested=True
                )
            )

        return ProtoMessage(
            name=msg.name,
            full_path=full_path,
            description=self.get_documentation(location) if location else "",
            visibility=self.get_visibility(msg.options).value,
            fields=fields,
            enums=enums,
            messages=messages,
        )

    def process_method(
        self,
        method: descriptor_pb2.MethodDescriptorProto,
        parent_path: list[str],
        method_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
        service_path: list[int],
    ) -> ProtoServiceMethod:
        """Convert one RPC method descriptor into a ProtoServiceMethod."""
        # 2 = ServiceDescriptorProto.method
        method_path = service_path + [2, method_index]
        location = self.find_source_location(source_info, method_path)

        # Type names are fully qualified with a leading dot — strip it, then
        # split into path components.
        input_path = method.input_type.removeprefix(".").split(".")
        output_path = method.output_type.removeprefix(".").split(".")

        # RPC options come from the Databricks custom extension, if present.
        rpc_options = self.extract_rpc_options(method.options)

        return ProtoServiceMethod(
            name=method.name,
            full_path=self.get_full_path_for_nested(parent_path, method.name),
            request_full_path=input_path,
            response_full_path=output_path,
            description=self.get_documentation(location) if location else "",
            rpc_options=rpc_options,
        )

    def process_service(
        self,
        service: descriptor_pb2.ServiceDescriptorProto,
        parent_path: list[str],
        service_index: int,
        source_info: descriptor_pb2.SourceCodeInfo,
    ) -> ProtoService:
        """Convert a service descriptor into a ProtoService."""
        # 6 = FileDescriptorProto.service
        service_path = [6, service_index]
        location = self.find_source_location(source_info, service_path)

        full_path = self.get_full_path_for_nested(parent_path, service.name)

        methods = []
        for i, method in enumerate(service.method):
            methods.append(self.process_method(method, full_path, i, source_info, service_path))

        return ProtoService(
            name=service.name,
            full_path=full_path,
            description=self.get_documentation(location) if location else "",
            visibility=self.get_visibility(service.options).value,
            methods=methods,
        )

    def process_file(
        self, file: descriptor_pb2.FileDescriptorProto, requested_vis: Visibility
    ) -> ProtoFile:
        """Convert a whole file descriptor into a ProtoFile.

        Top-level messages, enums, and services are emitted as elements in
        that order; files without source_code_info get empty documentation.
        """
        elements = []

        # Base path from the package declaration, e.g. "a.b" -> ["a", "b"].
        base_path = file.package.split(".") if file.package else []

        # Fall back to an empty SourceCodeInfo so lookups simply find nothing.
        source_info = file.source_code_info if file.HasField("source_code_info") else None
        if not source_info:
            source_info = descriptor_pb2.SourceCodeInfo()

        # Top-level messages.
        for i, msg in enumerate(file.message_type):
            elements.append(
                ProtoFileElement(message=self.process_message(msg, base_path, i, source_info))
            )

        # Top-level enums.
        for i, enum in enumerate(file.enum_type):
            elements.append(
                ProtoFileElement(
                    enum=self.process_enum(enum, base_path, i, source_info, [], is_nested=False)
                )
            )

        # Services.
        for i, service in enumerate(file.service):
            elements.append(
                ProtoFileElement(service=self.process_service(service, base_path, i, source_info))
            )

        return ProtoFile(
            filename=file.name, requested_visibility=requested_vis.value, content=elements
        )
503  
504  
class ProtocPlugin:
    """Protoc plugin implementation.

    Consumes a ``CodeGeneratorRequest`` and emits a single generated file,
    ``protos.json``, containing the extracted documentation model.
    """

    def __init__(self):
        self.generator = ProtobufDocGenerator()

    def process_request(
        self, request: plugin_pb2.CodeGeneratorRequest
    ) -> plugin_pb2.CodeGeneratorResponse:
        """Build the JSON documentation for every file protoc asked to generate.

        Files listed in ``file_to_generate`` with no matching descriptor in
        ``proto_file`` are silently skipped (matching prior behavior).
        """
        response = plugin_pb2.CodeGeneratorResponse()

        # Index descriptors by filename once, instead of re-scanning the whole
        # proto_file list for every requested file (was O(files * descriptors)).
        descriptors_by_name = {pf.name: pf for pf in request.proto_file}

        files = []
        for file_name in request.file_to_generate:
            file_descriptor = descriptors_by_name.get(file_name)
            if file_descriptor:
                files.append(self.generator.process_file(file_descriptor, Visibility.PUBLIC))

        # Wrap everything in the top-level container and serialize it as the
        # plugin's single output file.
        doc_content = ProtoAllContent(requested_visibility=Visibility.PUBLIC.value, files=files)

        doc_file = response.file.add()
        doc_file.name = "protos.json"
        doc_file.content = json.dumps(asdict(doc_content), indent=2)

        return response
541  
542  
def main():
    """Entry point when run as a protoc plugin: read a serialized
    CodeGeneratorRequest from stdin, process it, and write the serialized
    CodeGeneratorResponse to stdout."""
    raw_request = sys.stdin.buffer.read()
    parsed_request = plugin_pb2.CodeGeneratorRequest.FromString(raw_request)

    generated = ProtocPlugin().process_request(parsed_request)

    sys.stdout.buffer.write(generated.SerializeToString())


if __name__ == "__main__":
    main()