# client.py
"""
This example demonstrates how to create a trace with multiple spans using the low-level MLflow client APIs.

With the client APIs every span that is started must be ended explicitly, so the
example also shows how to end the spans with an "ERROR" status when the traced
code raises, instead of leaving them open.
"""

import mlflow

exp = mlflow.set_experiment("mlflow-tracing-example")
exp_id = exp.experiment_id

# Initialize MLflow client.
client = mlflow.MlflowClient()


def run(x: int, y: int) -> int:
    """Compute ``(x + y) ** 2 + 1`` while recording it as a two-span trace.

    The trace has a root span named "my_trace" (inputs ``x``/``y``, tagged with
    "fruit" and "vegetable") and one child span named "child_span" that wraps the
    squaring step. On success both spans are ended with their outputs. If the
    traced computation raises, the open span(s) are ended with status "ERROR"
    (child first, then root) and the exception is re-raised.

    Args:
        x: First addend.
        y: Second addend.

    Returns:
        ``(x + y) ** 2 + 1``.
    """
    # Create a trace. The `start_trace` API returns a root span of the trace.
    root_span = client.start_trace(
        name="my_trace",
        inputs={"x": x, "y": y},
        # Tags are key-value pairs associated with the trace.
        # You can update the tags later using `client.set_trace_tag` API.
        tags={
            "fruit": "apple",
            "vegetable": "carrot",
        },
    )

    # Trace ID is a unique identifier for the trace. You will need this ID
    # to interact with the trace later using the MLflow client.
    trace_id = root_span.trace_id

    try:
        z = x + y

        # Create a child span of the root span.
        child_span = client.start_span(
            name="child_span",
            # Specify the trace ID to which the child span belongs.
            trace_id=trace_id,
            # Also specify the ID of the parent span to build the span hierarchy.
            # You can access the span ID via `span_id` property of the span object.
            parent_id=root_span.span_id,
            # Each span has its own inputs.
            inputs={"z": z},
            # Attributes are key-value pairs associated with the span.
            attributes={
                "model": "my_model",
                "temperature": 0.5,
            },
        )

        try:
            z = z**2
        except Exception:
            # The low-level APIs do not end spans for us: close the child span as
            # failed so it is not left open, then let the outer handler end the root.
            client.end_span(
                trace_id=trace_id,
                span_id=child_span.span_id,
                status="ERROR",
            )
            raise

        # End the child span. Please make sure to end the child span before ending the root span.
        client.end_span(
            trace_id=trace_id,
            span_id=child_span.span_id,
            # Set the output(s) of the span.
            outputs=z,
            # Set the completion status, such as "OK" (default), "ERROR", etc.
            status="OK",
        )

        z = z + 1
    except Exception:
        # Any child span has already been ended above, so the root span can be
        # closed as failed here before propagating the original error.
        client.end_trace(trace_id=trace_id, status="ERROR")
        raise

    # End the root span.
    client.end_trace(
        trace_id=trace_id,
        # Set the output(s) of the span.
        outputs=z,
    )

    return z


# Call `run` outside of `assert`: under `python -O` assert statements are removed,
# which would otherwise skip creating the trace that the code below relies on.
result = run(1, 2)
assert result == 10

# Retrieve the trace just created using get_last_active_trace_id() API.
trace_id = mlflow.get_last_active_trace_id()
trace = client.get_trace(trace_id)

# Alternatively, you can use search_traces() API
# to retrieve the traces from the tracking server.
# NOTE(review): `[0]` assumes the newest trace in the experiment is returned first;
# confirm the default ordering of search_traces() if the experiment is shared.
trace = client.search_traces(locations=[exp_id])[0]
assert trace.info.tags["fruit"] == "apple"
assert trace.info.tags["vegetable"] == "carrot"

# Update the tags using set_trace_tag() and delete_trace_tag() APIs.
client.set_trace_tag(trace.info.trace_id, "fruit", "orange")
client.delete_trace_tag(trace.info.trace_id, "vegetable")

trace = client.get_trace(trace.info.trace_id)
assert trace.info.tags["fruit"] == "orange"
assert "vegetable" not in trace.info.tags

# Print the trace in JSON format
print(trace.to_json(pretty=True))

print(
    "\033[92m"
    + "🤖Now run `mlflow server` and open MLflow UI to see the trace visualization!"
    + "\033[0m"
)