# test_trace_correlation.py
"""Tests for the NPMI helpers imported from ``mlflow.store.analytics.trace_correlation``."""

import math

import pytest

from mlflow.store.analytics.trace_correlation import (
    calculate_npmi_from_counts,
    calculate_smoothed_npmi,
)


@pytest.mark.parametrize(
    (
        "joint_count",
        "filter1_count",
        "filter2_count",
        "total_count",
        "expected_npmi",
        "expected_smoothed_range",
    ),
    [
        pytest.param(10, 10, 10, 100, 1.0, (0.95, 1.0), id="perfect_positive"),
        pytest.param(0, 20, 30, 100, -1.0, None, id="perfect_negative"),
        pytest.param(10, 20, 50, 100, 0.0, None, id="independence"),
        pytest.param(100, 100, 100, 100, 1.0, None, id="all_match_both"),
    ],
)
def test_npmi_correlations(
    joint_count, filter1_count, filter2_count, total_count, expected_npmi, expected_smoothed_range
):
    """Check the raw NPMI per scenario, and the smoothed NPMI when a range is supplied."""
    outcome = calculate_npmi_from_counts(joint_count, filter1_count, filter2_count, total_count)

    if expected_npmi != 0.0:
        assert outcome.npmi == expected_npmi
    else:
        # An expected value of zero is checked with a tolerance rather than exactly.
        assert abs(outcome.npmi) < 0.01

    if not expected_smoothed_range:
        return
    lower, upper = expected_smoothed_range
    assert lower < outcome.npmi_smoothed <= upper


@pytest.mark.parametrize(
    ("joint_count", "filter1_count", "filter2_count", "total_count"),
    [
        pytest.param(0, 0, 10, 100, id="zero_filter1"),
        pytest.param(0, 10, 0, 100, id="zero_filter2"),
        pytest.param(0, 0, 0, 100, id="both_zero"),
        pytest.param(0, 0, 0, 0, id="empty_dataset"),
        pytest.param(50, 30, 40, 100, id="inconsistent"),
    ],
)
def test_npmi_undefined_cases(joint_count, filter1_count, filter2_count, total_count):
    """Each of these count combinations must produce a NaN ``npmi``."""
    outcome = calculate_npmi_from_counts(joint_count, filter1_count, filter2_count, total_count)
    assert math.isnan(outcome.npmi)


def test_npmi_partial_overlap():
    """A partial overlap yields a small positive NPMI, between 0.1 and 0.2."""
    counts = {"joint_count": 15, "filter1_count": 40, "filter2_count": 30, "total_count": 100}
    outcome = calculate_npmi_from_counts(**counts)
    assert 0 < outcome.npmi < 1
    assert 0.1 < outcome.npmi < 0.2


def test_npmi_with_smoothing():
    """With a zero joint count the raw NPMI is -1 while the smoothed value is in (-1, 0)."""
    counts = {"joint_count": 0, "filter1_count": 2, "filter2_count": 3, "total_count": 10}

    outcome = calculate_npmi_from_counts(**counts)
    assert outcome.npmi == -1.0
    assert outcome.npmi_smoothed is not None
    assert -1.0 < outcome.npmi_smoothed < 0

    # The standalone smoothed helper is checked against the same open interval.
    smoothed_only = calculate_smoothed_npmi(**counts)
    assert -1.0 < smoothed_only < 0


def test_npmi_all_traces_match_both():
    """When every count equals the total, the raw NPMI is exactly 1."""
    counts = {"joint_count": 100, "filter1_count": 100, "filter2_count": 100, "total_count": 100}
    outcome = calculate_npmi_from_counts(**counts)
    assert outcome.npmi == 1.0


@pytest.mark.parametrize(
    ("joint_count", "filter1_count", "filter2_count", "total_count"),
    [
        pytest.param(50, 50, 50, 100, id="half_match"),
        pytest.param(1, 2, 3, 100, id="small_counts"),
        pytest.param(99, 99, 99, 100, id="near_all"),
        pytest.param(25, 50, 75, 100, id="quarter_match"),
    ],
)
def test_npmi_clamping(joint_count, filter1_count, filter2_count, total_count):
    """Any non-NaN ``npmi`` must lie within the closed interval [-1, 1]."""
    outcome = calculate_npmi_from_counts(joint_count, filter1_count, filter2_count, total_count)
    if math.isnan(outcome.npmi):
        # NaN results are not range-checked by this test.
        return
    assert -1.0 <= outcome.npmi <= 1.0


def test_both_npmi_values_returned():
    """The result exposes both ``npmi`` and ``npmi_smoothed``, and the two differ."""
    no_overlap = calculate_npmi_from_counts(
        joint_count=0, filter1_count=10, filter2_count=15, total_count=100
    )
    assert no_overlap.npmi == -1.0
    assert no_overlap.npmi_smoothed is not None
    assert -1.0 < no_overlap.npmi_smoothed < 0

    some_overlap = calculate_npmi_from_counts(
        joint_count=5, filter1_count=10, filter2_count=15, total_count=100
    )
    assert some_overlap.npmi > 0
    assert some_overlap.npmi_smoothed > 0
    # Raw and smoothed values must not coincide for this input.
    assert abs(some_overlap.npmi - some_overlap.npmi_smoothed) > 0.001


def test_symmetry():
    """Swapping the two filter counts leaves the NPMI unchanged (within 1e-10)."""
    forward = calculate_npmi_from_counts(15, 30, 40, 100)
    swapped = calculate_npmi_from_counts(15, 40, 30, 100)
    assert abs(forward.npmi - swapped.npmi) < 1e-10


def test_monotonicity_joint_count():
    """NPMI never decreases as the joint count goes from 0 to 20 with fixed filter counts."""
    series = [calculate_npmi_from_counts(joint, 30, 40, 100).npmi for joint in range(21)]

    for previous, current in zip(series, series[1:]):
        # Pairs involving a NaN are skipped rather than compared.
        if math.isnan(current) or math.isnan(previous):
            continue
        assert current >= previous


@pytest.mark.parametrize(
    ("joint_count", "filter1_count", "filter2_count", "total_count", "expected_range"),
    [
        pytest.param(30, 30, 50, 100, (0.5, 1.0), id="high_overlap"),
        pytest.param(1, 30, 50, 100, (-0.7, -0.5), id="low_overlap"),
    ],
)
def test_boundary_values(joint_count, filter1_count, filter2_count, total_count, expected_range):
    """The raw NPMI falls strictly inside the expected open interval."""
    lower, upper = expected_range
    outcome = calculate_npmi_from_counts(joint_count, filter1_count, filter2_count, total_count)
    assert lower < outcome.npmi < upper