# graphql_schema_extensions.py
"""GraphQL schema extensions for the MLflow server.

Extends the autogenerated GraphQL types with extra fields and resolvers and
assembles the module-level ``schema`` object from the extended ``Query`` and
``Mutation`` root types.
"""

import math

import graphene
from graphql import (
    DirectiveLocation,
    GraphQLArgument,
    GraphQLDirective,
    GraphQLNonNull,
    GraphQLString,
)

import mlflow
from mlflow.server.graphql.autogenerated_graphql_schema import (
    MlflowExperiment,
    MlflowMetric,
    MlflowModelVersion,
    MlflowRun,
    MlflowSearchRunsInput,
    MlflowSearchRunsResponse,
    MutationType,
    QueryType,
)
from mlflow.utils.proto_json_utils import parse_dict

# Component identifier, to keep compatible with Databricks in-house implementations.
# Declares ``@component(name: String!)``, usable on query and mutation operations.
ComponentDirective = GraphQLDirective(
    name="component",
    locations=[
        DirectiveLocation.QUERY,
        DirectiveLocation.MUTATION,
    ],
    args={"name": GraphQLArgument(GraphQLNonNull(GraphQLString))},
)


class Test(graphene.ObjectType):
    """Payload type of the ``test`` query field."""

    output = graphene.String(description="Echoes the input string")


class TestMutation(graphene.ObjectType):
    """Payload type of the ``testMutation`` mutation field."""

    output = graphene.String(description="Echoes the input string")


class MlflowRunExtension(MlflowRun):
    """``MlflowRun`` with its experiment and model versions resolvable inline."""

    experiment = graphene.Field(MlflowExperiment)
    model_versions = graphene.List(graphene.NonNull(MlflowModelVersion))

    def resolve_experiment(self, info):
        """Fetch this run's experiment through the GetExperiment handler."""
        input_dict = {"experiment_id": self.info.experiment_id}
        request_message = mlflow.protos.service_pb2.GetExperiment()
        parse_dict(input_dict, request_message)
        return mlflow.server.handlers.get_experiment_impl(request_message).experiment

    def resolve_model_versions(self, info):
        """Search model versions whose ``run_id`` equals this run's id."""
        # NOTE(review): the run id is interpolated into a filter string; this assumes run ids
        # never contain quote characters — confirm before feeding user-controlled values here.
        input_dict = {"filter": f"run_id='{self.info.run_id}'"}
        request_message = mlflow.protos.model_registry_pb2.SearchModelVersions()
        parse_dict(input_dict, request_message)
        return mlflow.server.handlers.search_model_versions_impl(request_message).model_versions


class MlflowMetricExtension(MlflowMetric):
    """``MlflowMetric`` whose ``value`` is nullable so non-finite values can be returned."""

    value = graphene.Float()

    # GraphQL ``Float`` can only represent finite numbers: serializing NaN or +/-Infinity
    # raises a validation error. As a workaround, non-finite values are returned as None.
    def resolve_value(self, info):
        value = self.value
        if value is None or not math.isfinite(value):
            return None
        return value


class Query(QueryType):
    """Root query type: the autogenerated queries plus extension fields."""

    test = graphene.Field(Test, input_string=graphene.String(), description="Simple echoing field")
    mlflow_search_runs = graphene.Field(MlflowSearchRunsResponse, input=MlflowSearchRunsInput())

    def resolve_test(self, info, input_string=None):
        # ``input_string`` is an optional GraphQL argument. The keyword may be omitted when the
        # client does not supply it, so it needs a default to avoid a TypeError.
        return {"output": input_string}

    def resolve_mlflow_search_runs(self, info, input=None):
        # The parameter must be named ``input`` because graphene passes arguments by their
        # GraphQL name. A missing input becomes an empty SearchRuns request.
        input_dict = vars(input) if input is not None else {}
        request_message = mlflow.protos.service_pb2.SearchRuns()
        parse_dict(input_dict, request_message)
        return mlflow.server.handlers.search_runs_impl(request_message)


class Mutation(MutationType):
    """Root mutation type: the autogenerated mutations plus extension fields."""

    # Graphene looks resolvers up as ``resolve_<python attribute name>``. The attribute is
    # therefore snake_case so that ``resolve_test_mutation`` is actually bound to this field.
    # ``name`` keeps the exposed GraphQL field name ``testMutation``.
    test_mutation = graphene.Field(
        TestMutation,
        input_string=graphene.String(),
        description="Simple echoing field",
        name="testMutation",
    )

    def resolve_test_mutation(self, info, input_string=None):
        # Optional argument; see ``Query.resolve_test``.
        return {"output": input_string}


schema = graphene.Schema(query=Query, mutation=Mutation, directives=[ComponentDirective])