/ tests / demo / test_traces_generator.py
test_traces_generator.py
  1  import pytest
  2  
  3  from mlflow import MlflowClient, get_experiment_by_name, set_experiment
  4  from mlflow.demo.base import DEMO_EXPERIMENT_NAME, DemoFeature, DemoResult
  5  from mlflow.demo.generators.traces import (
  6      DEMO_TRACE_TYPE_TAG,
  7      DEMO_VERSION_TAG,
  8      TracesDemoGenerator,
  9  )
 10  
 11  
 12  @pytest.fixture
 13  def traces_generator():
 14      generator = TracesDemoGenerator()
 15      original_version = generator.version
 16      yield generator
 17      TracesDemoGenerator.version = original_version
 18  
 19  
 20  def test_generator_attributes():
 21      generator = TracesDemoGenerator()
 22      assert generator.name == DemoFeature.TRACES
 23      assert generator.version == 2
 24  
 25  
 26  def test_data_exists_false_when_no_experiment():
 27      generator = TracesDemoGenerator()
 28      assert generator._data_exists() is False
 29  
 30  
 31  def test_data_exists_false_when_experiment_empty():
 32      set_experiment(DEMO_EXPERIMENT_NAME)
 33      generator = TracesDemoGenerator()
 34      assert generator._data_exists() is False
 35  
 36  
 37  def test_generate_creates_traces():
 38      generator = TracesDemoGenerator()
 39      result = generator.generate()
 40  
 41      assert isinstance(result, DemoResult)
 42      assert result.feature == DemoFeature.TRACES
 43      assert len(result.entity_ids) > 0
 44      assert "experiments" in result.navigation_url
 45  
 46  
 47  def test_generate_creates_experiment():
 48      generator = TracesDemoGenerator()
 49      generator.generate()
 50  
 51      experiment = get_experiment_by_name(DEMO_EXPERIMENT_NAME)
 52      assert experiment is not None
 53      assert experiment.lifecycle_stage == "active"
 54  
 55  
 56  def test_data_exists_true_after_generate():
 57      generator = TracesDemoGenerator()
 58      assert generator._data_exists() is False
 59  
 60      generator.generate()
 61  
 62      assert generator._data_exists() is True
 63  
 64  
 65  def test_delete_demo_removes_traces():
 66      generator = TracesDemoGenerator()
 67      generator.generate()
 68      assert generator._data_exists() is True
 69  
 70      generator.delete_demo()
 71  
 72      assert generator._data_exists() is False
 73  
 74  
 75  def test_traces_have_expected_structure():
 76      generator = TracesDemoGenerator()
 77      generator.generate()
 78  
 79      experiment = get_experiment_by_name(DEMO_EXPERIMENT_NAME)
 80      client = MlflowClient()
 81      traces = client.search_traces(locations=[experiment.experiment_id], max_results=100)
 82  
 83      assert len(traces) > 0
 84  
 85      all_span_names = set()
 86      for trace in traces:
 87          all_span_names.update(span.name for span in trace.data.spans)
 88  
 89      assert "rag_pipeline" in all_span_names
 90      assert "agent" in all_span_names
 91      assert "chat_agent" in all_span_names
 92      assert "prompt_chain" in all_span_names
 93      assert "render_prompt" in all_span_names
 94      assert "embed_query" in all_span_names
 95      assert "retrieve_docs" in all_span_names
 96      assert "generate_response" in all_span_names
 97  
 98  
 99  def test_traces_have_version_metadata():
100      generator = TracesDemoGenerator()
101      generator.generate()
102  
103      experiment = get_experiment_by_name(DEMO_EXPERIMENT_NAME)
104      client = MlflowClient()
105      traces = client.search_traces(locations=[experiment.experiment_id], max_results=100)
106  
107      v1_traces = [t for t in traces if t.info.trace_metadata.get(DEMO_VERSION_TAG) == "v1"]
108      v2_traces = [t for t in traces if t.info.trace_metadata.get(DEMO_VERSION_TAG) == "v2"]
109  
110      # 2 RAG + 2 agent + 6 prompt + 4 multimodal + 7 session = 21 per version
111      assert len(v1_traces) == 21
112      assert len(v2_traces) == 21
113      assert len(traces) == 42
114  
115  
def test_traces_have_type_metadata():
    # Generates the demo traces and checks how many of them carry each
    # expected value of the DEMO_TRACE_TYPE_TAG entry in their trace metadata.
    generator = TracesDemoGenerator()
    generator.generate()

    experiment = get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    client = MlflowClient()
    # Use the same page size as the other trace-searching tests in this file,
    # so the result set is not capped just above the expected totals.
    traces = client.search_traces(locations=[experiment.experiment_id], max_results=100)

    expected_counts = {
        "rag": 4,  # 2 RAG per version = 4 total
        "agent": 4,  # 2 agent per version = 4 total
        "prompt": 12,  # 6 prompt per version = 12 total
        "session": 14,  # 7 session per version = 14 total
    }
    actual_counts = {
        trace_type: sum(
            1 for t in traces if t.info.trace_metadata.get(DEMO_TRACE_TYPE_TAG) == trace_type
        )
        for trace_type in expected_counts
    }

    # A single dict comparison reports every mismatching type in one failure.
    assert actual_counts == expected_counts
141  
142  
def test_is_generated_checks_version(traces_generator):
    # is_generated() must be True once data is generated and the version is
    # stored, and False after the class-level version is changed.
    traces_generator.generate()
    traces_generator.store_version()

    generated_at_stored_version = traces_generator.is_generated()
    assert generated_at_stored_version is True

    # The traces_generator fixture restores the class attribute on teardown.
    TracesDemoGenerator.version = 99
    generated_after_version_change = traces_generator.is_generated()
    assert generated_after_version_change is False