/ dev / proto_to_graphql / schema_autogeneration.py
schema_autogeneration.py
  1  import ast
  2  
  3  from autogeneration_utils import (
  4      DUMMY_FIELD,
  5      INDENT,
  6      INDENT2,
  7      SCHEMA_EXTENSION,
  8      SCHEMA_EXTENSION_MODULE,
  9      get_descriptor_full_pascal_name,
 10      get_method_name,
 11      method_descriptor_to_generated_pb2_file_name,
 12  )
 13  from google.protobuf.descriptor import FieldDescriptor
 14  from string_utils import camel_to_snake, snake_to_pascal
 15  
 16  # Mapping from proto descriptor type to graphene object type.
 17  PROTO_TO_GRAPHENE_TYPE = {
 18      FieldDescriptor.TYPE_BOOL: "graphene.Boolean",
 19      FieldDescriptor.TYPE_FLOAT: "graphene.Float",
 20      FieldDescriptor.TYPE_INT32: "graphene.Int",
 21      FieldDescriptor.TYPE_INT64: "LongString",
 22      FieldDescriptor.TYPE_STRING: "graphene.String",
 23      FieldDescriptor.TYPE_DOUBLE: "graphene.Float",
 24      FieldDescriptor.TYPE_UINT32: "graphene.Int",
 25      FieldDescriptor.TYPE_UINT64: "LongString",
 26      FieldDescriptor.TYPE_SINT32: "graphene.Int",
 27      FieldDescriptor.TYPE_SINT64: "LongString",
 28      FieldDescriptor.TYPE_BYTES: "graphene.String",
 29      FieldDescriptor.TYPE_FIXED32: "graphene.Int",
 30      FieldDescriptor.TYPE_FIXED64: "LongString",
 31      FieldDescriptor.TYPE_SFIXED32: "graphene.Int",
 32      FieldDescriptor.TYPE_SFIXED64: "LongString",
 33      FieldDescriptor.TYPE_ENUM: "graphene.Enum",
 34  }
 35  
 36  """
 37  Based on graphql_schema_extensions.py, constructs a map from the name of the
 38  extended class to the name of the extending class.
 39  For example
 40      class AutogenExtension(OriginalAutogen)
 41  would give us {"OriginalAutogen": "AutogenExtension"}
 42  """
 43  
 44  
 45  class ClassInheritanceVisitor(ast.NodeVisitor):
 46      def __init__(self):
 47          self.inheritance_map = {}
 48  
 49      def visit_ClassDef(self, node):
 50          for base in node.bases:
 51              if isinstance(base, ast.Name):  # Direct superclass
 52                  if base.id in self.inheritance_map:
 53                      raise Exception(
 54                          f"{base.id} is being extended more than once in {SCHEMA_EXTENSION}. "
 55                          + "A GraphQL schema class should not be extended more than once."
 56                      )
 57                  self.inheritance_map[base.id] = node.name
 58          self.generic_visit(node)
 59  
 60  
 61  def get_manual_extensions():
 62      with open(SCHEMA_EXTENSION) as file:
 63          file_content = file.read()
 64  
 65      parsed_content = ast.parse(file_content)
 66      visitor = ClassInheritanceVisitor()
 67      visitor.visit(parsed_content)
 68  
 69      return visitor.inheritance_map
 70  
 71  
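# Map from each extended class name to the name of the class extending it,
# computed once at import time from the schema extension file.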
 73  EXTENDED_TO_EXTENDING = get_manual_extensions()
 74  
 75  """
 76  Given the GenerateSchemaState, generate the whole schema with Graphene.
 77  """
 78  
 79  
 80  def generate_schema(state):
 81      schema_builder = ""
 82      schema_builder += "# GENERATED FILE. PLEASE DON'T MODIFY.\n"
 83      schema_builder += "# Run uv run ./dev/proto_to_graphql/code_generator.py to regenerate.\n"
 84      schema_builder += "import graphene\n"
 85      schema_builder += "import mlflow\n"
 86      schema_builder += "from mlflow.server.graphql.graphql_custom_scalars import LongString\n"
 87      schema_builder += "from mlflow.server.graphql.graphql_errors import ApiError\n"
 88      schema_builder += "from mlflow.utils.proto_json_utils import parse_dict\n"
 89      schema_builder += "\n"
 90  
 91      for enum in sorted(state.enums, key=lambda item: item.full_name):
 92          pascal_class_name = snake_to_pascal(get_descriptor_full_pascal_name(enum))
 93          schema_builder += f"\nclass {pascal_class_name}(graphene.Enum):"
 94          for i in range(len(enum.values)):
 95              value = enum.values[i]
 96              # enum indices start from 1
 97              schema_builder += f"""\n{INDENT}{value.name} = {i + 1}"""
 98          schema_builder += "\n\n"
 99  
100      for type in state.types:
101          pascal_class_name = snake_to_pascal(get_descriptor_full_pascal_name(type))
102          schema_builder += f"\nclass {pascal_class_name}(graphene.ObjectType):"
103          for field in type.fields:
104              graphene_type = get_graphene_type_for_field(field, False)
105              schema_builder += f"\n{INDENT}{camel_to_snake(field.name)} = {graphene_type}"
106  
107          if type in state.outputs:
108              schema_builder += f"\n{INDENT}apiError = graphene.Field(ApiError)"
109  
110          if len(type.fields) == 0:
111              schema_builder += f"\n{INDENT}{DUMMY_FIELD}"
112  
113          schema_builder += "\n\n"
114  
115      for input in state.inputs:
116          pascal_class_name = snake_to_pascal(get_descriptor_full_pascal_name(input)) + "Input"
117          schema_builder += f"\nclass {pascal_class_name}(graphene.InputObjectType):"
118          for field in input.fields:
119              graphene_type = get_graphene_type_for_field(field, True)
120              schema_builder += f"\n{INDENT}{camel_to_snake(field.name)} = {graphene_type}"
121          if len(input.fields) == 0:
122              schema_builder += f"\n{INDENT}{DUMMY_FIELD}"
123  
124          schema_builder += "\n\n"
125  
126      schema_builder += "\nclass QueryType(graphene.ObjectType):"
127  
128      if len(state.queries) == 0:
129          schema_builder += f"\n{INDENT}pass"
130  
131      for query in sorted(state.queries, key=lambda item: item.name):
132          schema_builder += proto_method_to_graphql_operation(query)
133  
134      schema_builder += "\n"
135  
136      for query in sorted(state.queries, key=lambda item: item.name):
137          schema_builder += generate_resolver_function(query)
138  
139      schema_builder += "\n"
140      schema_builder += "\nclass MutationType(graphene.ObjectType):"
141  
142      if len(state.mutations) == 0:
143          schema_builder += f"\n{INDENT}pass"
144  
145      for mutation in sorted(state.mutations, key=lambda item: item.name):
146          schema_builder += proto_method_to_graphql_operation(mutation)
147  
148      schema_builder += "\n"
149  
150      for mutation in sorted(state.mutations, key=lambda item: item.name):
151          schema_builder += generate_resolver_function(mutation)
152  
153      return schema_builder
154  
155  
def apply_schema_extension(referenced_class_name):
    """Return the type reference to emit for ``referenced_class_name``.

    If the class has a manual extension registered in
    ``EXTENDED_TO_EXTENDING``, the result is the quoted dotted path of the
    extending class inside ``SCHEMA_EXTENSION_MODULE``; otherwise the class
    name is returned unchanged.
    """
    try:
        extending_class_name = EXTENDED_TO_EXTENDING[referenced_class_name]
    except KeyError:
        # No manual extension: reference the generated class directly.
        return referenced_class_name

    # Using dotted module path as pointed out in the linked GitHub issue.
    # This is an undocumented feature of Graphene.
    # https://github.com/graphql-python/graphene/issues/110#issuecomment-1219737639
    return f"'{SCHEMA_EXTENSION_MODULE}.{extending_class_name}'"
164  
165  
def get_graphene_type_for_field(field, is_input):
    """Return the Graphene type expression (as source text) for a proto field.

    Args:
        field: protobuf field descriptor; its ``type``, ``label`` and, where
            relevant, ``enum_type`` / ``message_type`` are read.
        is_input: truthy when the field belongs to a generated input type, in
            which case message references get an ``Input`` suffix and are
            wrapped in ``graphene.InputField``.

    Returns:
        str: expression text such as ``graphene.Field(Foo)``,
        ``graphene.List(graphene.NonNull(Foo))`` or ``graphene.Int()``.

    Raises:
        KeyError: for a scalar field type missing from
            ``PROTO_TO_GRAPHENE_TYPE``.
    """
    repeated = field.label == FieldDescriptor.LABEL_REPEATED

    if field.type == FieldDescriptor.TYPE_ENUM:
        wrapper = "graphene.Field"
        target = apply_schema_extension(get_descriptor_full_pascal_name(field.enum_type))
    elif field.type in (FieldDescriptor.TYPE_GROUP, FieldDescriptor.TYPE_MESSAGE):
        message_class_name = get_descriptor_full_pascal_name(field.message_type)
        if is_input:
            wrapper = "graphene.InputField"
            target = apply_schema_extension(f"{message_class_name}Input")
        else:
            wrapper = "graphene.Field"
            target = apply_schema_extension(message_class_name)
    else:
        # Scalars: repeated fields list the bare type, singular ones instantiate it.
        scalar_type = PROTO_TO_GRAPHENE_TYPE[field.type]
        return f"graphene.List({scalar_type})" if repeated else f"{scalar_type}()"

    # Enum and message references share the same repeated / singular shapes.
    if repeated:
        return f"graphene.List(graphene.NonNull({target}))"
    return f"{wrapper}({target})"
196  
197  
def proto_method_to_graphql_operation(method):
    """Render the class-body line declaring the GraphQL field for a proto method.

    The field is named after ``get_method_name(method)``, is typed with the
    class generated for the method's output message, and takes a single
    ``input`` argument typed with the ``...Input`` class of its input message.

    Returns:
        str: one indented assignment line, prefixed with a newline.
    """
    operation_name = get_method_name(method)
    request_class_name = get_descriptor_full_pascal_name(method.input_type) + "Input"
    response_class_name = get_descriptor_full_pascal_name(method.output_type)
    return (
        f"\n{INDENT}{operation_name} = "
        f"graphene.Field({response_class_name}, input={request_class_name}())"
    )
206  
207  
def generate_resolver_function(method):
    """Render the source of the ``resolve_<name>`` method for a proto method.

    The generated resolver turns the GraphQL input into a dict, parses it into
    the request message class from the method's generated pb2 module, and
    returns the result of the matching ``<snake_case_name>_impl`` handler.

    Returns:
        str: the indented method source, starting and ending with a newline.
    """
    resolver_suffix = get_method_name(method)
    handler_prefix = camel_to_snake(method.name)
    request_class_name = snake_to_pascal(handler_prefix)
    pb2_module_name = method_descriptor_to_generated_pb2_file_name(method)

    resolver_lines = (
        f"{INDENT}def resolve_{resolver_suffix}(self, info, input):",
        f"{INDENT2}input_dict = vars(input)",
        f"{INDENT2}request_message = mlflow.protos.{pb2_module_name}.{request_class_name}()",
        f"{INDENT2}parse_dict(input_dict, request_message)",
        f"{INDENT2}return mlflow.server.handlers.{handler_prefix}_impl(request_message)",
    )
    # Every line is preceded by a newline; one trailing newline closes the method.
    return "".join(f"\n{line}" for line in resolver_lines) + "\n"