# test_traces_generator.py
from collections import Counter

import pytest

from mlflow import MlflowClient, get_experiment_by_name, set_experiment
from mlflow.demo.base import DEMO_EXPERIMENT_NAME, DemoFeature, DemoResult
from mlflow.demo.generators.traces import (
    DEMO_TRACE_TYPE_TAG,
    DEMO_VERSION_TAG,
    TracesDemoGenerator,
)

# Upper bound passed to ``search_traces`` by every test in this module. It must
# stay comfortably above the total number of demo traces (42 today, asserted in
# ``test_traces_have_version_metadata``) so that count assertions are never
# evaluated against a truncated result page.
_TRACE_SEARCH_LIMIT = 100


def _search_demo_traces(max_results=_TRACE_SEARCH_LIMIT):
    """Return the traces logged to the demo experiment.

    Args:
        max_results: Maximum number of traces to request from the tracking store.

    Fails the calling test with a readable message if the demo experiment does
    not exist, instead of raising ``AttributeError`` on ``None``.
    """
    experiment = get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    assert experiment is not None, f"Demo experiment {DEMO_EXPERIMENT_NAME!r} was not created"
    client = MlflowClient()
    return client.search_traces(locations=[experiment.experiment_id], max_results=max_results)


@pytest.fixture
def traces_generator():
    """Yield a ``TracesDemoGenerator`` and restore the class-level ``version`` afterwards.

    Tests using this fixture may assign ``TracesDemoGenerator.version`` to simulate
    a generator upgrade. Teardown writes the original class attribute back so the
    override cannot leak into other tests. Teardown runs even if the test fails.
    """
    # Read from the class (not an instance) so we restore exactly what tests overwrite.
    original_version = TracesDemoGenerator.version
    yield TracesDemoGenerator()
    TracesDemoGenerator.version = original_version


def test_generator_attributes():
    generator = TracesDemoGenerator()
    assert generator.name == DemoFeature.TRACES
    assert generator.version == 2


def test_data_exists_false_when_no_experiment():
    generator = TracesDemoGenerator()
    assert generator._data_exists() is False


def test_data_exists_false_when_experiment_empty():
    set_experiment(DEMO_EXPERIMENT_NAME)
    generator = TracesDemoGenerator()
    assert generator._data_exists() is False


def test_generate_creates_traces():
    generator = TracesDemoGenerator()
    result = generator.generate()

    assert isinstance(result, DemoResult)
    assert result.feature == DemoFeature.TRACES
    assert len(result.entity_ids) > 0
    assert "experiments" in result.navigation_url


def test_generate_creates_experiment():
    generator = TracesDemoGenerator()
    generator.generate()

    experiment = get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    assert experiment is not None
    assert experiment.lifecycle_stage == "active"


def test_data_exists_true_after_generate():
    generator = TracesDemoGenerator()
    assert generator._data_exists() is False

    generator.generate()

    assert generator._data_exists() is True


def test_delete_demo_removes_traces():
    generator = TracesDemoGenerator()
    generator.generate()
    assert generator._data_exists() is True

    generator.delete_demo()

    assert generator._data_exists() is False


def test_traces_have_expected_structure():
    TracesDemoGenerator().generate()

    traces = _search_demo_traces()
    assert len(traces) > 0

    all_span_names = {span.name for trace in traces for span in trace.data.spans}
    expected_span_names = {
        "rag_pipeline",
        "agent",
        "chat_agent",
        "prompt_chain",
        "render_prompt",
        "embed_query",
        "retrieve_docs",
        "generate_response",
    }
    # A set difference reports every missing span at once, not just the first.
    missing = expected_span_names - all_span_names
    assert not missing, f"Missing expected span names: {sorted(missing)}"


def test_traces_have_version_metadata():
    TracesDemoGenerator().generate()

    traces = _search_demo_traces()
    version_counts = Counter(t.info.trace_metadata.get(DEMO_VERSION_TAG) for t in traces)

    # 2 RAG + 2 agent + 6 prompt + 4 multimodal + 7 session = 21 per version
    assert version_counts["v1"] == 21
    assert version_counts["v2"] == 21
    assert len(traces) == 42


def test_traces_have_type_metadata():
    TracesDemoGenerator().generate()

    # Uses the shared search limit; a limit close to the 42-trace total could
    # silently truncate results and skew the per-type counts.
    traces = _search_demo_traces()
    type_counts = Counter(t.info.trace_metadata.get(DEMO_TRACE_TYPE_TAG) for t in traces)

    # 2 RAG per version = 4 total
    # 2 agent per version = 4 total
    # 6 prompt per version = 12 total
    # 7 session per version = 14 total
    assert type_counts["rag"] == 4
    assert type_counts["agent"] == 4
    assert type_counts["prompt"] == 12
    assert type_counts["session"] == 14


def test_is_generated_checks_version(traces_generator):
    traces_generator.generate()
    traces_generator.store_version()

    assert traces_generator.is_generated() is True

    # Simulate a newer generator release; the fixture restores the original value.
    TracesDemoGenerator.version = 99
    assert traces_generator.is_generated() is False