phase_test.go
1 package agent 2 3 import ( 4 "sync" 5 "testing" 6 "time" 7 ) 8 9 func TestTurnPhase_CountsAsIdle(t *testing.T) { 10 idle := map[TurnPhase]bool{ 11 PhaseAwaitingLLM: true, 12 PhaseForceStop: true, 13 } 14 all := []TurnPhase{ 15 PhaseInit, PhaseSetup, PhaseAwaitingLLM, PhaseRetryingLLM, 16 PhaseCompacting, PhaseAwaitingApproval, PhaseExecutingTools, 17 PhaseInjectingMessage, PhaseForceStop, PhaseDone, 18 } 19 for _, p := range all { 20 want := idle[p] 21 if got := p.CountsAsIdle(); got != want { 22 t.Errorf("%s.CountsAsIdle() = %v, want %v", p, got, want) 23 } 24 } 25 } 26 27 func TestTurnPhase_String(t *testing.T) { 28 cases := map[TurnPhase]string{ 29 PhaseInit: "init", PhaseSetup: "setup", PhaseAwaitingLLM: "awaiting_llm", 30 PhaseRetryingLLM: "retrying_llm", PhaseCompacting: "compacting", 31 PhaseAwaitingApproval: "awaiting_approval", PhaseExecutingTools: "executing_tools", 32 PhaseInjectingMessage: "injecting_message", PhaseForceStop: "force_stop", 33 PhaseDone: "done", 34 } 35 for p, want := range cases { 36 if got := p.String(); got != want { 37 t.Errorf("%d.String() = %q, want %q", int(p), got, want) 38 } 39 } 40 if got := TurnPhase(999).String(); got != "unknown" { 41 t.Errorf("unknown phase: %q", got) 42 } 43 } 44 45 func TestPhaseTracker_EnterAndCurrent(t *testing.T) { 46 tr := newPhaseTracker() 47 p, _, _ := tr.Current() 48 if p != PhaseInit { 49 t.Fatalf("initial phase = %s, want init", p) 50 } 51 tr.Enter(PhaseAwaitingLLM) 52 p, d, _ := tr.Current() 53 if p != PhaseAwaitingLLM { 54 t.Fatalf("after Enter: phase = %s", p) 55 } 56 if d < 0 || d > time.Second { 57 t.Fatalf("since-time unreasonable: %v", d) 58 } 59 } 60 61 func TestPhaseTracker_EnterTransient_RestoresPrev(t *testing.T) { 62 tr := newPhaseTracker() 63 tr.Enter(PhaseCompacting) 64 65 restore := tr.EnterTransient(PhaseAwaitingLLM) 66 p, _, _ := tr.Current() 67 if p != PhaseAwaitingLLM { 68 t.Fatalf("inside transient: phase = %s", p) 69 } 70 71 restore() 72 p, _, _ = tr.Current() 73 if p != PhaseCompacting { 74 t.Fatalf("after restore: phase = %s, want compacting", p) 75 } 76 } 77 78 func TestPhaseTracker_EnterTransient_NestedDoesNotLeak(t *testing.T) { 79 tr := newPhaseTracker() 80 tr.Enter(PhaseCompacting) 81 82 r1 := tr.EnterTransient(PhaseAwaitingLLM) 83 if p, _, _ := tr.Current(); p != PhaseAwaitingLLM { 84 t.Fatalf("first transient: %s", p) 85 } 86 87 // Imagine a nested call also needing AwaitingLLM (uncommon but legal). 88 r2 := tr.EnterTransient(PhaseAwaitingLLM) 89 r2() 90 if p, _, _ := tr.Current(); p != PhaseAwaitingLLM { 91 t.Fatalf("after inner restore, outer transient lost: %s", p) 92 } 93 94 r1() 95 if p, _, _ := tr.Current(); p != PhaseCompacting { 96 t.Fatalf("after outer restore: %s, want compacting", p) 97 } 98 tr.AssertClean() // depth should be 0 99 } 100 101 func TestPhaseTracker_RestoreIdempotent(t *testing.T) { 102 tr := newPhaseTracker() 103 tr.Enter(PhaseSetup) 104 restore := tr.EnterTransient(PhaseAwaitingLLM) 105 restore() 106 restore() // second call must not underflow depth or change phase 107 if p, _, _ := tr.Current(); p != PhaseSetup { 108 t.Fatalf("after double restore: %s", p) 109 } 110 tr.AssertClean() 111 } 112 113 func TestPhaseTracker_SeqBumpsOnEveryTransition(t *testing.T) { 114 tr := newPhaseTracker() 115 _, _, s0 := tr.Current() 116 117 tr.Enter(PhaseAwaitingLLM) 118 _, _, s1 := tr.Current() 119 if s1 <= s0 { 120 t.Fatalf("seq did not bump on Enter: s0=%d s1=%d", s0, s1) 121 } 122 123 // Re-entering the same phase type must still bump seq (so observers can 124 // re-arm dedupes on transition, not phase-type identity). 125 tr.Enter(PhaseAwaitingLLM) 126 _, _, s2 := tr.Current() 127 if s2 <= s1 { 128 t.Fatalf("seq must bump on same-phase re-entry: s1=%d s2=%d", s1, s2) 129 } 130 131 // EnterTransient bumps; restore bumps again. 132 restore := tr.EnterTransient(PhaseAwaitingLLM) 133 _, _, s3 := tr.Current() 134 if s3 <= s2 { 135 t.Fatalf("seq did not bump on EnterTransient: s2=%d s3=%d", s2, s3) 136 } 137 restore() 138 _, _, s4 := tr.Current() 139 if s4 <= s3 { 140 t.Fatalf("seq did not bump on transient restore: s3=%d s4=%d", s3, s4) 141 } 142 } 143 144 func TestPhaseTracker_InvalidFlag(t *testing.T) { 145 tr := newPhaseTracker() 146 if tr.Invalid() { 147 t.Fatal("new tracker should not be invalid") 148 } 149 150 // Trigger a violation via forgotten restore + AssertClean. Under 151 // testing.Testing() this panics, so guard + recover. 152 _ = tr.EnterTransient(PhaseAwaitingLLM) // intentionally drop 153 func() { 154 defer func() { _ = recover() }() 155 tr.AssertClean() 156 }() 157 if !tr.Invalid() { 158 t.Fatal("expected tracker to be marked invalid after violation") 159 } 160 } 161 162 func TestPhaseTracker_AssertClean_DetectsForgottenRestore(t *testing.T) { 163 tr := newPhaseTracker() 164 _ = tr.EnterTransient(PhaseAwaitingLLM) // intentionally drop restore 165 166 defer func() { 167 r := recover() 168 if r == nil { 169 t.Fatal("AssertClean should have panicked for forgotten transient") 170 } 171 }() 172 tr.AssertClean() 173 } 174 175 func TestPhaseTracker_Enter_PanicsInsideTransient(t *testing.T) { 176 tr := newPhaseTracker() 177 tr.Enter(PhaseSetup) 178 restore := tr.EnterTransient(PhaseAwaitingLLM) 179 defer restore() 180 181 defer func() { 182 r := recover() 183 if r == nil { 184 t.Fatal("Enter inside active transient should panic in test mode") 185 } 186 }() 187 tr.Enter(PhaseExecutingTools) // violates layering 188 } 189 190 func TestPhaseTracker_Dirty(t *testing.T) { 191 tr := newPhaseTracker() 192 if tr.TakeDirty() { 193 t.Fatal("new tracker should not be dirty") 194 } 195 tr.MarkDirty() 196 if !tr.TakeDirty() { 197 t.Fatal("MarkDirty should set dirty") 198 } 199 if tr.TakeDirty() { 200 t.Fatal("TakeDirty should clear on read") 201 } 202 } 203 204 func TestPhaseTracker_ConcurrentReadDuringWrite(t *testing.T) { 205 tr := newPhaseTracker() 206 tr.Enter(PhaseSetup) 207 208 const N = 200 209 var wg sync.WaitGroup 210 wg.Add(2) 211 212 // Writer: flip phases rapidly. 213 go func() { 214 defer wg.Done() 215 phases := []TurnPhase{PhaseAwaitingLLM, PhaseExecutingTools, PhaseSetup} 216 for i := 0; i < N; i++ { 217 tr.Enter(phases[i%len(phases)]) 218 } 219 }() 220 221 // Reader: poll concurrently, verify return values are well-typed (no torn reads). 222 go func() { 223 defer wg.Done() 224 for i := 0; i < N; i++ { 225 p, d, seq := tr.Current() 226 if p < PhaseInit || p > PhaseDone { 227 t.Errorf("torn phase read: %d", int(p)) 228 return 229 } 230 if d < 0 { 231 t.Errorf("negative since-duration: %v", d) 232 return 233 } 234 if seq < 0 { 235 t.Errorf("negative seq: %d", seq) 236 return 237 } 238 } 239 }() 240 241 wg.Wait() 242 } 243 244 func TestPhaseTracker_SinceRearms(t *testing.T) { 245 tr := newPhaseTracker() 246 tr.Enter(PhaseAwaitingLLM) 247 time.Sleep(5 * time.Millisecond) 248 _, d1, _ := tr.Current() 249 if d1 < 5*time.Millisecond { 250 t.Fatalf("since should be >= 5ms, got %v", d1) 251 } 252 253 tr.Enter(PhaseAwaitingLLM) // same phase re-entered 254 _, d2, _ := tr.Current() 255 if d2 > d1 { 256 t.Fatalf("re-entry should reset since: got %v, prior was %v", d2, d1) 257 } 258 }