/ tests / store / analytics / test_trace_correlation.py
test_trace_correlation.py
  1  import math
  2  
  3  import pytest
  4  
  5  from mlflow.store.analytics.trace_correlation import (
  6      calculate_npmi_from_counts,
  7      calculate_smoothed_npmi,
  8  )
  9  
 10  
 11  @pytest.mark.parametrize(
 12      (
 13          "joint_count",
 14          "filter1_count",
 15          "filter2_count",
 16          "total_count",
 17          "expected_npmi",
 18          "expected_smoothed_range",
 19      ),
 20      [
 21          (10, 10, 10, 100, 1.0, (0.95, 1.0)),
 22          (0, 20, 30, 100, -1.0, None),
 23          (10, 20, 50, 100, 0.0, None),
 24          (100, 100, 100, 100, 1.0, None),
 25      ],
 26      ids=["perfect_positive", "perfect_negative", "independence", "all_match_both"],
 27  )
 28  def test_npmi_correlations(
 29      joint_count, filter1_count, filter2_count, total_count, expected_npmi, expected_smoothed_range
 30  ):
 31      result = calculate_npmi_from_counts(joint_count, filter1_count, filter2_count, total_count)
 32  
 33      if expected_npmi == 0.0:
 34          assert abs(result.npmi) < 0.01
 35      else:
 36          assert result.npmi == expected_npmi
 37  
 38      if expected_smoothed_range:
 39          assert expected_smoothed_range[0] < result.npmi_smoothed <= expected_smoothed_range[1]
 40  
 41  
 42  @pytest.mark.parametrize(
 43      ("joint_count", "filter1_count", "filter2_count", "total_count"),
 44      [
 45          (0, 0, 10, 100),
 46          (0, 10, 0, 100),
 47          (0, 0, 0, 100),
 48          (0, 0, 0, 0),
 49          (50, 30, 40, 100),
 50      ],
 51      ids=["zero_filter1", "zero_filter2", "both_zero", "empty_dataset", "inconsistent"],
 52  )
 53  def test_npmi_undefined_cases(joint_count, filter1_count, filter2_count, total_count):
 54      result = calculate_npmi_from_counts(joint_count, filter1_count, filter2_count, total_count)
 55      assert math.isnan(result.npmi)
 56  
 57  
 58  def test_npmi_partial_overlap():
 59      result = calculate_npmi_from_counts(
 60          joint_count=15, filter1_count=40, filter2_count=30, total_count=100
 61      )
 62      assert 0 < result.npmi < 1
 63      assert 0.1 < result.npmi < 0.2
 64  
 65  
 66  def test_npmi_with_smoothing():
 67      result = calculate_npmi_from_counts(
 68          joint_count=0, filter1_count=2, filter2_count=3, total_count=10
 69      )
 70      assert result.npmi == -1.0
 71      assert result.npmi_smoothed is not None
 72      assert -1.0 < result.npmi_smoothed < 0
 73  
 74      npmi_smooth = calculate_smoothed_npmi(
 75          joint_count=0, filter1_count=2, filter2_count=3, total_count=10
 76      )
 77      assert -1.0 < npmi_smooth < 0
 78  
 79  
 80  def test_npmi_all_traces_match_both():
 81      result = calculate_npmi_from_counts(
 82          joint_count=100, filter1_count=100, filter2_count=100, total_count=100
 83      )
 84      assert result.npmi == 1.0
 85  
 86  
 87  @pytest.mark.parametrize(
 88      ("joint_count", "filter1_count", "filter2_count", "total_count"),
 89      [
 90          (50, 50, 50, 100),
 91          (1, 2, 3, 100),
 92          (99, 99, 99, 100),
 93          (25, 50, 75, 100),
 94      ],
 95      ids=["half_match", "small_counts", "near_all", "quarter_match"],
 96  )
 97  def test_npmi_clamping(joint_count, filter1_count, filter2_count, total_count):
 98      result = calculate_npmi_from_counts(joint_count, filter1_count, filter2_count, total_count)
 99      if not math.isnan(result.npmi):
100          assert -1.0 <= result.npmi <= 1.0
101  
102  
def test_both_npmi_values_returned():
    """The result carries both a raw and a smoothed NPMI.

    - With zero co-occurrences, the raw NPMI is exactly -1.0 and the smoothed value
      lies strictly inside (-1, 0).
    - With some co-occurrence, both values are positive and differ by more than 0.001.
    """
    no_overlap = calculate_npmi_from_counts(
        joint_count=0, filter1_count=10, filter2_count=15, total_count=100
    )
    assert no_overlap.npmi == -1.0
    assert no_overlap.npmi_smoothed is not None
    assert -1.0 < no_overlap.npmi_smoothed < 0

    some_overlap = calculate_npmi_from_counts(
        joint_count=5, filter1_count=10, filter2_count=15, total_count=100
    )
    assert some_overlap.npmi > 0
    # Check for None first so a missing value fails as an assertion, not a TypeError.
    assert some_overlap.npmi_smoothed is not None
    assert some_overlap.npmi_smoothed > 0
    assert abs(some_overlap.npmi - some_overlap.npmi_smoothed) > 0.001
119  
120  
def test_symmetry():
    """Swapping the two filters' marginal counts leaves NPMI unchanged."""
    forward = calculate_npmi_from_counts(15, 30, 40, 100).npmi
    backward = calculate_npmi_from_counts(15, 40, 30, 100).npmi
    difference = abs(forward - backward)
    assert difference < 1e-10
125  
126  
def test_monotonicity_joint_count():
    """NPMI never decreases as the joint count grows while the marginals stay fixed.

    With marginals of 30 and 40 out of 100 traces, every joint count in 0..20 is
    consistent, so each NPMI must be defined. A NaN must fail the test rather than
    silently dropping the comparisons around it.
    """
    npmis = [calculate_npmi_from_counts(joint, 30, 40, 100).npmi for joint in range(21)]

    assert not any(math.isnan(value) for value in npmis)
    for previous, current in zip(npmis, npmis[1:]):
        assert current >= previous
136  
137  
@pytest.mark.parametrize(
    ("joint_count", "filter1_count", "filter2_count", "total_count", "expected_range"),
    [
        (30, 30, 50, 100, (0.5, 1.0)),
        (1, 30, 50, 100, (-0.7, -0.5)),
    ],
    ids=["high_overlap", "low_overlap"],
)
def test_boundary_values(joint_count, filter1_count, filter2_count, total_count, expected_range):
    """NPMI lies strictly inside an expected open interval for high and low overlap."""
    lower, upper = expected_range
    npmi = calculate_npmi_from_counts(joint_count, filter1_count, filter2_count, total_count).npmi
    assert lower < npmi < upper