dev/proto_to_graphql/parsing_utils.py
parsing_utils.py
 1  from autogeneration_utils import get_method_name
 2  from google.protobuf.descriptor import FieldDescriptor
 3  
 4  from mlflow.protos import databricks_pb2
 5  
 6  
 7  def get_method_type(method_descriptor):
 8      return method_descriptor.GetOptions().Extensions[databricks_pb2.rpc].endpoints[0].method
 9  
10  
def process_method(method_descriptor, state):
    """
    Register a GraphQL-annotated RPC method and everything it references in the
    GenerateSchemaState.

    The method is skipped when it lacks the `databricks_pb2.graphql` option, when its
    `rpc` visibility is INTERNAL, or when a method with the same name was already
    recorded. Otherwise:

    - the method is added to ``state.queries`` if its request method is "GET", and to
      ``state.mutations`` otherwise;
    - its output type is added to ``state.outputs``;
    - the message/enum types reachable from its input type (as inputs) and from its
      output type (as output types) are collected.
    """
    options = method_descriptor.GetOptions()
    if not options.HasExtension(databricks_pb2.graphql):
        return

    rpc_options = options.Extensions[databricks_pb2.rpc]
    # Internal methods are never exposed through the generated schema.
    if rpc_options.visibility == databricks_pb2.INTERNAL:
        return

    method_name = get_method_name(method_descriptor)
    # De-duplicate by generated method name.
    if method_name in state.method_names:
        return
    state.method_names.add(method_name)

    is_query = get_method_type(method_descriptor) == "GET"
    destination = state.queries if is_query else state.mutations
    destination.add(method_descriptor)

    state.outputs.add(method_descriptor.output_type)
    # Each walk gets its own visited set so input and output traversals are independent.
    populate_message_types(method_descriptor.input_type, state, True, set())
    populate_message_types(method_descriptor.output_type, state, False, set())
32  
33  
def populate_message_types(field_descriptor, state, is_input, visited=None):
    """
    Recursively collect the message and enum types reachable from a message descriptor.

    Despite the parameter name, ``field_descriptor`` is a message descriptor: its
    ``.fields`` are iterated, and callers pass a method's ``input_type`` /
    ``output_type`` or a field's ``message_type``.

    The message itself is moved to the front of ``state.inputs`` when ``is_input`` is
    true, and of ``state.types`` otherwise. Every message- or group-typed field is
    followed recursively into the same list. Every enum-typed field's enum type is added
    to ``state.enums``.

    Args:
        field_descriptor: The message descriptor to walk.
        state: The GenerateSchemaState being populated.
        is_input: Whether this message is reached from a method's input type.
        visited: Descriptors already walked in this traversal, used to stop on recursive
            types. Defaults to a fresh empty set.
    """
    if visited is None:
        visited = set()
    if field_descriptor in visited:
        # Break the loop for recursive types.
        return
    visited.add(field_descriptor)

    target_list = state.inputs if is_input else state.types
    add_message_descriptor_to_list(field_descriptor, target_list)

    for sub_field in field_descriptor.fields:
        # Named `field_type` rather than `type` to avoid shadowing the builtin.
        field_type = sub_field.type
        if field_type in (FieldDescriptor.TYPE_MESSAGE, FieldDescriptor.TYPE_GROUP):
            populate_message_types(sub_field.message_type, state, is_input, visited)
        elif field_type == FieldDescriptor.TYPE_ENUM:
            state.enums.add(sub_field.enum_type)
        # Scalar fields reference no additional schema types.
56  
57  
def add_message_descriptor_to_list(descriptor, target_list):
    """
    Move ``descriptor`` to the front of ``target_list``, inserting it if absent.

    Referenced messages are always placed at the beginning, so that when the schema is
    generated the ordering corresponds to the reference graph. ``list.remove()`` and
    ``insert(0)`` are not optimal in terms of efficiency, but they are fine because the
    amount of data here is very small.

    Args:
        descriptor: The item to move or insert.
        target_list: The list to mutate in place.
    """
    # Drop any existing occurrence first, so the item ends up exactly once, at index 0.
    if descriptor in target_list:
        target_list.remove(descriptor)
    target_list.insert(0, descriptor)