/ mlflow / server / graphql / graphql_schema_extensions.py
graphql_schema_extensions.py
 1  import math
 2  
 3  import graphene
 4  from graphql import (
 5      DirectiveLocation,
 6      GraphQLArgument,
 7      GraphQLDirective,
 8      GraphQLNonNull,
 9      GraphQLString,
10  )
11  
12  import mlflow
13  from mlflow.server.graphql.autogenerated_graphql_schema import (
14      MlflowExperiment,
15      MlflowMetric,
16      MlflowModelVersion,
17      MlflowRun,
18      MlflowSearchRunsInput,
19      MlflowSearchRunsResponse,
20      MutationType,
21      QueryType,
22  )
23  from mlflow.utils.proto_json_utils import parse_dict
24  
# Declares the ``component`` directive, which takes one required string argument
# called ``name`` and may be attached to query and mutation operations. It acts
# as a component identifier and is kept for compatibility with Databricks
# in-house implementations.
ComponentDirective = GraphQLDirective(
    name="component",
    locations=[DirectiveLocation.QUERY, DirectiveLocation.MUTATION],
    args={
        "name": GraphQLArgument(GraphQLNonNull(GraphQLString)),
    },
)
34  
35  
# Object type with a single string field, ``output``.
# Documented with comments rather than a docstring: graphene publishes a class
# docstring as the type's description in the schema.
class Test(graphene.ObjectType):
    output = graphene.Field(graphene.String, description="Echoes the input string")
38  
39  
# Object type with a single string field, ``output``.
# Documented with comments rather than a docstring: graphene publishes a class
# docstring as the type's description in the schema.
class TestMutation(graphene.ObjectType):
    output = graphene.Field(graphene.String, description="Echoes the input string")
42  
43  
# Extends the autogenerated ``MlflowRun`` type with two extra fields, each
# resolved through an additional server handler call:
#   * ``experiment``     - the experiment the run belongs to
#   * ``model_versions`` - model versions whose ``run_id`` matches the run
# (Comments instead of a docstring: graphene publishes a class docstring as the
# type's description in the schema.)
class MlflowRunExtension(MlflowRun):
    experiment = graphene.Field(MlflowExperiment)
    model_versions = graphene.List(graphene.NonNull(MlflowModelVersion))

    def resolve_experiment(self, info):
        """Return the experiment referenced by this run's ``info.experiment_id``.

        A ``GetExperiment`` request message is populated from a plain dict and
        passed to ``get_experiment_impl``; the ``experiment`` member of the
        handler's response is returned.
        """
        payload = {"experiment_id": self.info.experiment_id}
        request = mlflow.protos.service_pb2.GetExperiment()
        parse_dict(payload, request)
        response = mlflow.server.handlers.get_experiment_impl(request)
        return response.experiment

    def resolve_model_versions(self, info):
        """Return the model versions whose ``run_id`` equals this run's id.

        A ``SearchModelVersions`` request is built with a ``run_id='...'``
        filter and passed to ``search_model_versions_impl``; the
        ``model_versions`` member of the handler's response is returned.
        """
        payload = {"filter": f"run_id='{self.info.run_id}'"}
        request = mlflow.protos.model_registry_pb2.SearchModelVersions()
        parse_dict(payload, request)
        response = mlflow.server.handlers.search_model_versions_impl(request)
        return response.model_versions
61  
62  
# Extends the autogenerated ``MlflowMetric`` type by redeclaring ``value`` with
# a resolver that filters out NaN.
class MlflowMetricExtension(MlflowMetric):
    value = graphene.Float()

    def resolve_value(self, info):
        """Return the metric value, or ``None`` when it is NaN.

        ``value`` is typed as Float, and a NaN value causes an error in GraphQL
        validation. Returning ``None`` instead is a workaround for that.
        """
        if math.isnan(self.value):
            return None
        return self.value
70  
71  
# Root query type: the autogenerated ``QueryType`` plus a ``test`` echo field
# and a ``mlflow_search_runs`` field backed by the search-runs handler.
# (Comments instead of a docstring: graphene publishes a class docstring as the
# type's description in the schema.)
class Query(QueryType):
    test = graphene.Field(Test, input_string=graphene.String(), description="Simple echoing field")
    mlflow_search_runs = graphene.Field(MlflowSearchRunsResponse, input=MlflowSearchRunsInput())

    def resolve_test(self, info, input_string=None):
        """Echo ``input_string`` back as the ``output`` of a ``Test`` payload.

        Args:
            info: GraphQL resolve info (unused).
            input_string: The string to echo. The GraphQL argument is declared
                as optional, and arguments the client omits are not passed to
                the resolver at all, so the parameter needs a default;
                otherwise omitting it fails with a ``TypeError``.

        Returns:
            A dict with the single key ``"output"`` holding ``input_string``.
        """
        return {"output": input_string}

    # The parameter has to be called ``input`` (shadowing the builtin) because
    # it is passed by keyword under the name the argument is declared with on
    # the field above.
    def resolve_mlflow_search_runs(self, info, input):
        """Run a search-runs request built from the GraphQL ``input`` object.

        The instance attributes of ``input`` are copied into a ``SearchRuns``
        request message, which is passed to ``search_runs_impl``.

        Returns:
            Whatever ``search_runs_impl`` returns for the request.
        """
        input_dict = vars(input)
        request_message = mlflow.protos.service_pb2.SearchRuns()
        parse_dict(input_dict, request_message)
        return mlflow.server.handlers.search_runs_impl(request_message)
84  
85  
# Root mutation type: the autogenerated ``MutationType`` plus a ``testMutation``
# echo field.
# (Comments instead of a docstring: graphene publishes a class docstring as the
# type's description in the schema.)
class Mutation(MutationType):
    # graphene binds a resolver by looking up ``resolve_<attribute name>``. With
    # the attribute spelled ``testMutation`` the ``resolve_test_mutation``
    # method below was never found, so the field always resolved to null. The
    # attribute is therefore snake_case, and ``name`` pins the published field
    # name so the schema is unchanged.
    test_mutation = graphene.Field(
        TestMutation,
        name="testMutation",
        input_string=graphene.String(),
        description="Simple echoing field",
    )

    def resolve_test_mutation(self, info, input_string=None):
        """Echo ``input_string`` back as the ``output`` of a ``TestMutation`` payload.

        Args:
            info: GraphQL resolve info (unused).
            input_string: The string to echo. The GraphQL argument is declared
                as optional, and arguments the client omits are not passed to
                the resolver at all, so the parameter defaults to ``None``.

        Returns:
            A dict with the single key ``"output"`` holding ``input_string``.
        """
        return {"output": input_string}
93  
94  
# Executable schema: the extended root query and mutation types plus the custom
# ``component`` directive.
schema = graphene.Schema(
    query=Query,
    mutation=Mutation,
    directives=[ComponentDirective],
)