parsing_utils.py
from autogeneration_utils import get_method_name
from google.protobuf.descriptor import FieldDescriptor

from mlflow.protos import databricks_pb2


def get_method_type(method_descriptor):
    """
    Return the ``method`` of the first endpoint in the method's ``databricks_pb2.rpc`` option.

    ``process_method`` compares this value against ``"GET"`` to decide whether the method is a
    query or a mutation. Assumes the rpc option declares at least one endpoint.
    """
    return method_descriptor.GetOptions().Extensions[databricks_pb2.rpc].endpoints[0].method


def process_method(method_descriptor, state):
    """
    Given a method descriptor, add information being referenced into the GenerateSchemaState.

    The method is skipped when it lacks the ``databricks_pb2.graphql`` option, when its rpc
    visibility is ``databricks_pb2.INTERNAL``, or when a method with the same name (as returned
    by ``get_method_name``) was already processed. Otherwise it is registered as a query (rpc
    endpoint method ``"GET"``) or a mutation (any other method), its output type is recorded
    in ``state.outputs``, and the message types reachable from its input and output types are
    collected via ``populate_message_types``.

    Args:
        method_descriptor: protobuf method descriptor to process.
        state: schema-generation state; its ``method_names``, ``queries``, ``mutations``,
            ``outputs``, ``inputs``, ``types`` and ``enums`` attributes are updated in place.
    """
    options = method_descriptor.GetOptions()
    if not options.HasExtension(databricks_pb2.graphql):
        return
    rpc_options = options.Extensions[databricks_pb2.rpc]
    # Only add those methods that are not internal.
    if rpc_options.visibility == databricks_pb2.INTERNAL:
        return
    name = get_method_name(method_descriptor)
    if name in state.method_names:
        return
    state.method_names.add(name)

    if get_method_type(method_descriptor) == "GET":
        state.queries.add(method_descriptor)
    else:
        state.mutations.add(method_descriptor)
    state.outputs.add(method_descriptor.output_type)
    # Each walk gets a fresh visited set, so the input and output graphs are tracked separately.
    populate_message_types(method_descriptor.input_type, state, True, set())
    populate_message_types(method_descriptor.output_type, state, False, set())


def populate_message_types(field_descriptor, state, is_input, visited):
    """
    Given a field descriptor, recursively walk through the referenced message types and add
    information being referenced into the GenerateSchemaState.

    Args:
        field_descriptor: the message descriptor to walk. Callers pass message types such as
            ``method_descriptor.input_type`` / ``output_type`` or ``sub_field.message_type``.
        state: schema-generation state. Every visited message descriptor is moved to the front
            of ``state.inputs`` (when ``is_input`` is true) or ``state.types`` (otherwise), and
            the enum types of enum-valued fields are added to ``state.enums``.
        is_input: True when walking a method's input type, False for its output type.
        visited: descriptors already seen during this walk; updated in place.
    """
    if field_descriptor in visited:
        # Break the loop for recursive types.
        return
    visited.add(field_descriptor)
    target_list = state.inputs if is_input else state.types
    add_message_descriptor_to_list(field_descriptor, target_list)

    for sub_field in field_descriptor.fields:
        field_type = sub_field.type
        if field_type in (FieldDescriptor.TYPE_MESSAGE, FieldDescriptor.TYPE_GROUP):
            # Nested message (or group): walk it with the same input/output classification.
            populate_message_types(sub_field.message_type, state, is_input, visited)
        elif field_type == FieldDescriptor.TYPE_ENUM:
            state.enums.add(sub_field.enum_type)
        # Scalar fields reference no further types.


def add_message_descriptor_to_list(descriptor, target_list):
    """
    Move ``descriptor`` to the front of ``target_list``, inserting it if it is not present.
    """
    # Always put the referenced message at the beginning, so that when generating the schema,
    # the ordering can be maintained in a way that corresponds to the reference graph:
    # populate_message_types adds a message before walking its fields, so messages it references
    # generally end up ahead of it.
    # list.remove() and insert(0) are not optimal in terms of efficiency but are fine because
    # the amount of data is very small here.
    if descriptor in target_list:
        target_list.remove(descriptor)
    target_list.insert(0, descriptor)