code_generator.py
import os

from autogeneration_utils import AUTOGENERATED_SCHEMA, AUTOGENERATED_SDL_SCHEMA
from parsing_utils import process_method
from schema_autogeneration import generate_schema

from mlflow.protos import model_registry_pb2, service_pb2
from mlflow.server.graphql.graphql_schema_extensions import schema

# Add proto descriptors to onboard RPCs to graphql.
ONBOARDED_DESCRIPTORS = [service_pb2.DESCRIPTOR, model_registry_pb2.DESCRIPTOR]


class GenerateSchemaState:
    """Accumulator passed to every ``process_method`` call and then to ``generate_schema``.

    The trailing comment on each attribute names the kind of object it is
    meant to hold.
    """

    def __init__(self):
        self.queries = set()  # method_descriptor
        self.mutations = set()  # method_descriptor
        self.inputs = []  # field_descriptor
        self.outputs = set()  # field_descriptor
        self.types = []  # field_descriptor
        self.enums = set()  # enum_descriptor
        self.method_names = set()  # package_name_method_name


def _write_generated_file(path, content):
    """Write ``content`` to ``path`` as UTF-8, creating the parent directory if needed.

    Args:
        path: Destination file path.
        content: Text to write; any existing file at ``path`` is overwritten.
    """
    directory = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    # Pin the encoding so the generated files do not depend on the platform's
    # default locale encoding.
    with open(path, "w", encoding="utf-8") as f:
        f.write(content)


# Entry point for generating the GraphQL schema.
def generate_code():
    """Generate the GraphQL schema module and its SDL counterpart.

    Every method of every service in ``ONBOARDED_DESCRIPTORS`` is fed to
    ``process_method``. The resulting state is rendered by
    ``generate_schema`` and written to ``AUTOGENERATED_SCHEMA``. The string
    form of ``schema`` (from ``graphql_schema_extensions``), prefixed with a
    "generated file" header, is then written to ``AUTOGENERATED_SDL_SCHEMA``
    for TypeScript type generation.
    """
    state = GenerateSchemaState()
    for file_descriptor in ONBOARDED_DESCRIPTORS:
        # Only the descriptors are needed; the name keys are not used.
        for service_descriptor in file_descriptor.services_by_name.values():
            for method_descriptor in service_descriptor.methods_by_name.values():
                process_method(method_descriptor, state)

    generated_schema = generate_schema(state)
    _write_generated_file(AUTOGENERATED_SCHEMA, generated_schema)

    # Generate the sdl schema for typescript type generation.
    sdl_schema = str(schema)
    sdl_schema = f"""# GENERATED FILE. PLEASE DON'T MODIFY.
# Run uv run ./dev/proto_to_graphql/code_generator.py to regenerate.

{sdl_schema}

"""
    _write_generated_file(AUTOGENERATED_SDL_SCHEMA, sdl_schema)


def main():
    generate_code()


if __name__ == "__main__":
    main()