# schema_autogeneration.py
import ast

from autogeneration_utils import (
    DUMMY_FIELD,
    INDENT,
    INDENT2,
    SCHEMA_EXTENSION,
    SCHEMA_EXTENSION_MODULE,
    get_descriptor_full_pascal_name,
    get_method_name,
    method_descriptor_to_generated_pb2_file_name,
)
from google.protobuf.descriptor import FieldDescriptor
from string_utils import camel_to_snake, snake_to_pascal

# Mapping from proto descriptor type to graphene object type.
PROTO_TO_GRAPHENE_TYPE = {
    FieldDescriptor.TYPE_BOOL: "graphene.Boolean",
    FieldDescriptor.TYPE_FLOAT: "graphene.Float",
    FieldDescriptor.TYPE_INT32: "graphene.Int",
    FieldDescriptor.TYPE_INT64: "LongString",
    FieldDescriptor.TYPE_STRING: "graphene.String",
    FieldDescriptor.TYPE_DOUBLE: "graphene.Float",
    FieldDescriptor.TYPE_UINT32: "graphene.Int",
    FieldDescriptor.TYPE_UINT64: "LongString",
    FieldDescriptor.TYPE_SINT32: "graphene.Int",
    FieldDescriptor.TYPE_SINT64: "LongString",
    FieldDescriptor.TYPE_BYTES: "graphene.String",
    FieldDescriptor.TYPE_FIXED32: "graphene.Int",
    FieldDescriptor.TYPE_FIXED64: "LongString",
    FieldDescriptor.TYPE_SFIXED32: "graphene.Int",
    FieldDescriptor.TYPE_SFIXED64: "LongString",
    FieldDescriptor.TYPE_ENUM: "graphene.Enum",
}

# Preamble emitted at the top of every generated schema file.
_GENERATED_FILE_HEADER = (
    "# GENERATED FILE. PLEASE DON'T MODIFY.\n"
    "# Run uv run ./dev/proto_to_graphql/code_generator.py to regenerate.\n"
    "import graphene\n"
    "import mlflow\n"
    "from mlflow.server.graphql.graphql_custom_scalars import LongString\n"
    "from mlflow.server.graphql.graphql_errors import ApiError\n"
    "from mlflow.utils.proto_json_utils import parse_dict\n"
    "\n"
)


class ClassInheritanceVisitor(ast.NodeVisitor):
    """Collect a map from each extended class name to the class extending it.

    Based on graphql_schema_extensions.py, constructs a map from the name of
    the extended class to the name of the extending class.
    For example
        class AutogenExtension(OriginalAutogen)
    would give us {"OriginalAutogen": "AutogenExtension"}

    Only bases written as a bare name (``ast.Name``) are recorded; any other
    base expression (e.g. a dotted ``module.Base``) is ignored.
    """

    def __init__(self):
        # Maps extended (base) class name -> extending (derived) class name.
        self.inheritance_map = {}

    def visit_ClassDef(self, node):
        """Record the bare-name bases of ``node``, then visit its children.

        Raises:
            Exception: if a base class name has already been recorded as
                extended by another class.
        """
        for base in node.bases:
            if not isinstance(base, ast.Name):  # Not a direct superclass name
                continue
            if base.id in self.inheritance_map:
                raise Exception(
                    f"{base.id} is being extended more than once in {SCHEMA_EXTENSION}. "
                    "A GraphQL schema class should not be extended more than once."
                )
            self.inheritance_map[base.id] = node.name
        # Keep walking so class definitions nested in this one are seen too.
        self.generic_visit(node)


def get_manual_extensions():
    """Parse the SCHEMA_EXTENSION file and return its inheritance map.

    Returns:
        A dict mapping each extended class name to the name of the class
        extending it, as collected by ``ClassInheritanceVisitor``.
    """
    # Python source is UTF-8 by default, so do not rely on the platform's
    # locale encoding when reading it.
    with open(SCHEMA_EXTENSION, encoding="utf-8") as file:
        file_content = file.read()

    visitor = ClassInheritanceVisitor()
    visitor.visit(ast.parse(file_content))
    return visitor.inheritance_map


# The resulting map
EXTENDED_TO_EXTENDING = get_manual_extensions()


def _generate_enum_class(enum):
    """Return the source of the graphene.Enum class for one enum descriptor."""
    pascal_class_name = snake_to_pascal(get_descriptor_full_pascal_name(enum))
    parts = [f"\nclass {pascal_class_name}(graphene.Enum):"]
    # enum indices start from 1
    parts.extend(
        f"\n{INDENT}{value.name} = {index}" for index, value in enumerate(enum.values, start=1)
    )
    parts.append("\n\n")
    return "".join(parts)


def _generate_field_lines(descriptor, is_input):
    """Return one generated attribute line per field of ``descriptor``."""
    return [
        f"\n{INDENT}{camel_to_snake(field.name)} = {get_graphene_type_for_field(field, is_input)}"
        for field in descriptor.fields
    ]


def _generate_object_type_class(message_type, is_output):
    """Return the source of the graphene.ObjectType class for one message."""
    pascal_class_name = snake_to_pascal(get_descriptor_full_pascal_name(message_type))
    parts = [f"\nclass {pascal_class_name}(graphene.ObjectType):"]
    parts.extend(_generate_field_lines(message_type, False))

    if is_output:
        parts.append(f"\n{INDENT}apiError = graphene.Field(ApiError)")

    # A message without fields still needs a class body.
    if len(message_type.fields) == 0:
        parts.append(f"\n{INDENT}{DUMMY_FIELD}")

    parts.append("\n\n")
    return "".join(parts)


def _generate_input_type_class(input_type):
    """Return the source of the graphene.InputObjectType class for one message."""
    pascal_class_name = snake_to_pascal(get_descriptor_full_pascal_name(input_type)) + "Input"
    parts = [f"\nclass {pascal_class_name}(graphene.InputObjectType):"]
    parts.extend(_generate_field_lines(input_type, True))

    # A message without fields still needs a class body.
    if len(input_type.fields) == 0:
        parts.append(f"\n{INDENT}{DUMMY_FIELD}")

    parts.append("\n\n")
    return "".join(parts)


def _generate_operations_class(class_name, methods):
    """Return the source of a root operation class (fields, then resolvers)."""
    # Sort once; the same order is used for the fields and for the resolvers.
    ordered_methods = sorted(methods, key=lambda item: item.name)

    parts = [f"\nclass {class_name}(graphene.ObjectType):"]
    if len(ordered_methods) == 0:
        parts.append(f"\n{INDENT}pass")

    parts.extend(proto_method_to_graphql_operation(method) for method in ordered_methods)
    parts.append("\n")
    parts.extend(generate_resolver_function(method) for method in ordered_methods)
    return "".join(parts)


def generate_schema(state):
    """Given the GenerateSchemaState, generate the whole schema with Graphene.

    Args:
        state: Object exposing ``enums``, ``types``, ``inputs``, ``outputs``,
            ``queries`` and ``mutations``. Enums are emitted sorted by
            ``full_name`` and queries/mutations sorted by ``name``; types and
            inputs are emitted in iteration order.

    Returns:
        The generated Python source of the schema module, as a string.
    """
    parts = [_GENERATED_FILE_HEADER]

    parts.extend(
        _generate_enum_class(enum)
        for enum in sorted(state.enums, key=lambda item: item.full_name)
    )
    parts.extend(
        _generate_object_type_class(message_type, message_type in state.outputs)
        for message_type in state.types
    )
    parts.extend(_generate_input_type_class(input_type) for input_type in state.inputs)

    parts.append(_generate_operations_class("QueryType", state.queries))
    parts.append("\n")
    parts.append(_generate_operations_class("MutationType", state.mutations))

    return "".join(parts)


def apply_schema_extension(referenced_class_name):
    """Return the reference to use for a class, honoring manual extensions.

    If ``referenced_class_name`` is extended in the schema extension file, a
    quoted dotted path to the extending class is returned; otherwise the name
    is returned unchanged.
    """
    extending_class_name = EXTENDED_TO_EXTENDING.get(referenced_class_name)
    if extending_class_name is None:
        return referenced_class_name

    # Using dotted module path as pointed out in the linked GitHub issue.
    # This is an undocumented feature of Graphene.
    # https://github.com/graphql-python/graphene/issues/110#issuecomment-1219737639
    return f"'{SCHEMA_EXTENSION_MODULE}.{extending_class_name}'"


def get_graphene_type_for_field(field, is_input):
    """Return the graphene type expression (as source text) for a proto field.

    Args:
        field: Field descriptor; its ``type``, ``label`` and, depending on the
            type, ``enum_type`` or ``message_type`` are read.
        is_input: True when the field belongs to an input type, in which case
            message references get the ``Input`` suffix and ``InputField``.

    Returns:
        A string such as ``graphene.Field(X)``, ``graphene.List(...)`` or
        ``graphene.String()``.

    Raises:
        KeyError: if a scalar ``field.type`` has no entry in
            ``PROTO_TO_GRAPHENE_TYPE``.
    """
    is_repeated = field.label == FieldDescriptor.LABEL_REPEATED

    if field.type == FieldDescriptor.TYPE_ENUM:
        referenced_class_name = apply_schema_extension(
            get_descriptor_full_pascal_name(field.enum_type)
        )
        if is_repeated:
            return f"graphene.List(graphene.NonNull({referenced_class_name}))"
        return f"graphene.Field({referenced_class_name})"

    if field.type in (FieldDescriptor.TYPE_GROUP, FieldDescriptor.TYPE_MESSAGE):
        message_class_name = get_descriptor_full_pascal_name(field.message_type)
        if is_input:
            referenced_class_name = apply_schema_extension(f"{message_class_name}Input")
            field_wrapper = "graphene.InputField"
        else:
            referenced_class_name = apply_schema_extension(message_class_name)
            field_wrapper = "graphene.Field"
        if is_repeated:
            return f"graphene.List(graphene.NonNull({referenced_class_name}))"
        return f"{field_wrapper}({referenced_class_name})"

    field_type_base = PROTO_TO_GRAPHENE_TYPE[field.type]
    if is_repeated:
        return f"graphene.List({field_type_base})"
    return f"{field_type_base}()"


def proto_method_to_graphql_operation(method):
    """Return the class-attribute line declaring one query/mutation field."""
    method_name = get_method_name(method)
    input_class_name = get_descriptor_full_pascal_name(method.input_type) + "Input"
    output_class_name = get_descriptor_full_pascal_name(method.output_type)
    field_def = f"graphene.Field({output_class_name}, input={input_class_name}())"
    return f"\n{INDENT}{method_name} = {field_def}"


def generate_resolver_function(method):
    """Return the source of the ``resolve_*`` method for one proto method.

    The generated resolver copies the GraphQL input into a request message of
    the generated pb2 module and passes it to the matching ``*_impl`` handler.
    """
    full_method_name = get_method_name(method)
    snake_case_method_name = camel_to_snake(method.name)
    pascal_case_method_name = snake_to_pascal(snake_case_method_name)
    pb2_file_name = method_descriptor_to_generated_pb2_file_name(method)

    lines = [
        f"\n{INDENT}def resolve_{full_method_name}(self, info, input):",
        f"\n{INDENT2}input_dict = vars(input)",
        f"\n{INDENT2}request_message = "
        f"mlflow.protos.{pb2_file_name}.{pascal_case_method_name}()",
        f"\n{INDENT2}parse_dict(input_dict, request_message)",
        f"\n{INDENT2}return "
        f"mlflow.server.handlers.{snake_case_method_name}_impl(request_message)",
        "\n",
    ]
    return "".join(lines)