/ dev / proto_to_graphql / code_generator.py
code_generator.py
 1  import os
 2  
 3  from autogeneration_utils import AUTOGENERATED_SCHEMA, AUTOGENERATED_SDL_SCHEMA
 4  from parsing_utils import process_method
 5  from schema_autogeneration import generate_schema
 6  
 7  from mlflow.protos import model_registry_pb2, service_pb2
 8  from mlflow.server.graphql.graphql_schema_extensions import schema
 9  
# Proto file descriptors whose services are exposed through GraphQL.
# To onboard more RPCs, append the corresponding ``*_pb2.DESCRIPTOR`` here.
ONBOARDED_DESCRIPTORS = [
    service_pb2.DESCRIPTOR,
    model_registry_pb2.DESCRIPTOR,
]
12  
13  
class GenerateSchemaState:
    """Mutable accumulator threaded through GraphQL schema generation.

    ``generate_code`` creates a single instance, passes it to ``process_method``
    for every onboarded RPC method, and then hands it to ``generate_schema``.
    Every new instance starts with its own, empty containers.
    """

    def __init__(self):
        # Method descriptors for RPCs exposed as GraphQL queries / mutations
        # (presumably classified by ``process_method`` -- confirm there).
        self.queries: set = set()
        self.mutations: set = set()
        # Field descriptors for GraphQL input types. Kept as a list, unlike
        # ``outputs``; presumably ordering or hashability matters -- TODO confirm.
        self.inputs: list = []
        # Field descriptors for GraphQL output types.
        self.outputs: set = set()
        # Field descriptors for other generated types (also a list).
        self.types: list = []
        # Enum descriptors.
        self.enums: set = set()
        # Keys of the form package_name_method_name; presumably used to avoid
        # processing a method twice -- confirm against ``process_method``.
        self.method_names: set = set()
21          self.enums = set()  # enum_descriptor
22          self.method_names = set()  # package_name_method_name
23  
24  
25  # Entry point for generating the GraphQL schema.
26  def generate_code():
27      state = GenerateSchemaState()
28      for file_descriptor in ONBOARDED_DESCRIPTORS:
29          for service_name, service_descriptor in file_descriptor.services_by_name.items():
30              for method_name, method_descriptor in service_descriptor.methods_by_name.items():
31                  process_method(method_descriptor, state)
32  
33      generated_schema = generate_schema(state)
34  
35      os.makedirs(os.path.dirname(AUTOGENERATED_SCHEMA), exist_ok=True)
36  
37      with open(AUTOGENERATED_SCHEMA, "w") as file:
38          file.write(generated_schema)
39  
40      # Generate the sdl schema for typescript type generation.
41      sdl_schema = str(schema)
42      sdl_schema = f"""# GENERATED FILE. PLEASE DON'T MODIFY.
43  # Run uv run ./dev/proto_to_graphql/code_generator.py to regenerate.
44  
45  {sdl_schema}
46  
47  """
48  
49      with open(AUTOGENERATED_SDL_SCHEMA, "w") as f:
50          f.write(sdl_schema)
51  
52  
def main():
    """Command-line entry point; regenerates the GraphQL artifacts via ``generate_code``."""
    return generate_code()


if __name__ == "__main__":
    # Regenerate only when executed as a script, not when imported.
    main()