/ internal / agent / phase_test.go
phase_test.go
  1  package agent
  2  
  3  import (
  4  	"sync"
  5  	"testing"
  6  	"time"
  7  )
  8  
  9  func TestTurnPhase_CountsAsIdle(t *testing.T) {
 10  	idle := map[TurnPhase]bool{
 11  		PhaseAwaitingLLM: true,
 12  		PhaseForceStop:   true,
 13  	}
 14  	all := []TurnPhase{
 15  		PhaseInit, PhaseSetup, PhaseAwaitingLLM, PhaseRetryingLLM,
 16  		PhaseCompacting, PhaseAwaitingApproval, PhaseExecutingTools,
 17  		PhaseInjectingMessage, PhaseForceStop, PhaseDone,
 18  	}
 19  	for _, p := range all {
 20  		want := idle[p]
 21  		if got := p.CountsAsIdle(); got != want {
 22  			t.Errorf("%s.CountsAsIdle() = %v, want %v", p, got, want)
 23  		}
 24  	}
 25  }
 26  
 27  func TestTurnPhase_String(t *testing.T) {
 28  	cases := map[TurnPhase]string{
 29  		PhaseInit: "init", PhaseSetup: "setup", PhaseAwaitingLLM: "awaiting_llm",
 30  		PhaseRetryingLLM: "retrying_llm", PhaseCompacting: "compacting",
 31  		PhaseAwaitingApproval: "awaiting_approval", PhaseExecutingTools: "executing_tools",
 32  		PhaseInjectingMessage: "injecting_message", PhaseForceStop: "force_stop",
 33  		PhaseDone: "done",
 34  	}
 35  	for p, want := range cases {
 36  		if got := p.String(); got != want {
 37  			t.Errorf("%d.String() = %q, want %q", int(p), got, want)
 38  		}
 39  	}
 40  	if got := TurnPhase(999).String(); got != "unknown" {
 41  		t.Errorf("unknown phase: %q", got)
 42  	}
 43  }
 44  
 45  func TestPhaseTracker_EnterAndCurrent(t *testing.T) {
 46  	tr := newPhaseTracker()
 47  	p, _, _ := tr.Current()
 48  	if p != PhaseInit {
 49  		t.Fatalf("initial phase = %s, want init", p)
 50  	}
 51  	tr.Enter(PhaseAwaitingLLM)
 52  	p, d, _ := tr.Current()
 53  	if p != PhaseAwaitingLLM {
 54  		t.Fatalf("after Enter: phase = %s", p)
 55  	}
 56  	if d < 0 || d > time.Second {
 57  		t.Fatalf("since-time unreasonable: %v", d)
 58  	}
 59  }
 60  
 61  func TestPhaseTracker_EnterTransient_RestoresPrev(t *testing.T) {
 62  	tr := newPhaseTracker()
 63  	tr.Enter(PhaseCompacting)
 64  
 65  	restore := tr.EnterTransient(PhaseAwaitingLLM)
 66  	p, _, _ := tr.Current()
 67  	if p != PhaseAwaitingLLM {
 68  		t.Fatalf("inside transient: phase = %s", p)
 69  	}
 70  
 71  	restore()
 72  	p, _, _ = tr.Current()
 73  	if p != PhaseCompacting {
 74  		t.Fatalf("after restore: phase = %s, want compacting", p)
 75  	}
 76  }
 77  
 78  func TestPhaseTracker_EnterTransient_NestedDoesNotLeak(t *testing.T) {
 79  	tr := newPhaseTracker()
 80  	tr.Enter(PhaseCompacting)
 81  
 82  	r1 := tr.EnterTransient(PhaseAwaitingLLM)
 83  	if p, _, _ := tr.Current(); p != PhaseAwaitingLLM {
 84  		t.Fatalf("first transient: %s", p)
 85  	}
 86  
 87  	// Imagine a nested call also needing AwaitingLLM (uncommon but legal).
 88  	r2 := tr.EnterTransient(PhaseAwaitingLLM)
 89  	r2()
 90  	if p, _, _ := tr.Current(); p != PhaseAwaitingLLM {
 91  		t.Fatalf("after inner restore, outer transient lost: %s", p)
 92  	}
 93  
 94  	r1()
 95  	if p, _, _ := tr.Current(); p != PhaseCompacting {
 96  		t.Fatalf("after outer restore: %s, want compacting", p)
 97  	}
 98  	tr.AssertClean() // depth should be 0
 99  }
100  
101  func TestPhaseTracker_RestoreIdempotent(t *testing.T) {
102  	tr := newPhaseTracker()
103  	tr.Enter(PhaseSetup)
104  	restore := tr.EnterTransient(PhaseAwaitingLLM)
105  	restore()
106  	restore() // second call must not underflow depth or change phase
107  	if p, _, _ := tr.Current(); p != PhaseSetup {
108  		t.Fatalf("after double restore: %s", p)
109  	}
110  	tr.AssertClean()
111  }
112  
113  func TestPhaseTracker_SeqBumpsOnEveryTransition(t *testing.T) {
114  	tr := newPhaseTracker()
115  	_, _, s0 := tr.Current()
116  
117  	tr.Enter(PhaseAwaitingLLM)
118  	_, _, s1 := tr.Current()
119  	if s1 <= s0 {
120  		t.Fatalf("seq did not bump on Enter: s0=%d s1=%d", s0, s1)
121  	}
122  
123  	// Re-entering the same phase type must still bump seq (so observers can
124  	// re-arm dedupes on transition, not phase-type identity).
125  	tr.Enter(PhaseAwaitingLLM)
126  	_, _, s2 := tr.Current()
127  	if s2 <= s1 {
128  		t.Fatalf("seq must bump on same-phase re-entry: s1=%d s2=%d", s1, s2)
129  	}
130  
131  	// EnterTransient bumps; restore bumps again.
132  	restore := tr.EnterTransient(PhaseAwaitingLLM)
133  	_, _, s3 := tr.Current()
134  	if s3 <= s2 {
135  		t.Fatalf("seq did not bump on EnterTransient: s2=%d s3=%d", s2, s3)
136  	}
137  	restore()
138  	_, _, s4 := tr.Current()
139  	if s4 <= s3 {
140  		t.Fatalf("seq did not bump on transient restore: s3=%d s4=%d", s3, s4)
141  	}
142  }
143  
144  func TestPhaseTracker_InvalidFlag(t *testing.T) {
145  	tr := newPhaseTracker()
146  	if tr.Invalid() {
147  		t.Fatal("new tracker should not be invalid")
148  	}
149  
150  	// Trigger a violation via forgotten restore + AssertClean. Under
151  	// testing.Testing() this panics, so guard + recover.
152  	_ = tr.EnterTransient(PhaseAwaitingLLM) // intentionally drop
153  	func() {
154  		defer func() { _ = recover() }()
155  		tr.AssertClean()
156  	}()
157  	if !tr.Invalid() {
158  		t.Fatal("expected tracker to be marked invalid after violation")
159  	}
160  }
161  
162  func TestPhaseTracker_AssertClean_DetectsForgottenRestore(t *testing.T) {
163  	tr := newPhaseTracker()
164  	_ = tr.EnterTransient(PhaseAwaitingLLM) // intentionally drop restore
165  
166  	defer func() {
167  		r := recover()
168  		if r == nil {
169  			t.Fatal("AssertClean should have panicked for forgotten transient")
170  		}
171  	}()
172  	tr.AssertClean()
173  }
174  
175  func TestPhaseTracker_Enter_PanicsInsideTransient(t *testing.T) {
176  	tr := newPhaseTracker()
177  	tr.Enter(PhaseSetup)
178  	restore := tr.EnterTransient(PhaseAwaitingLLM)
179  	defer restore()
180  
181  	defer func() {
182  		r := recover()
183  		if r == nil {
184  			t.Fatal("Enter inside active transient should panic in test mode")
185  		}
186  	}()
187  	tr.Enter(PhaseExecutingTools) // violates layering
188  }
189  
190  func TestPhaseTracker_Dirty(t *testing.T) {
191  	tr := newPhaseTracker()
192  	if tr.TakeDirty() {
193  		t.Fatal("new tracker should not be dirty")
194  	}
195  	tr.MarkDirty()
196  	if !tr.TakeDirty() {
197  		t.Fatal("MarkDirty should set dirty")
198  	}
199  	if tr.TakeDirty() {
200  		t.Fatal("TakeDirty should clear on read")
201  	}
202  }
203  
204  func TestPhaseTracker_ConcurrentReadDuringWrite(t *testing.T) {
205  	tr := newPhaseTracker()
206  	tr.Enter(PhaseSetup)
207  
208  	const N = 200
209  	var wg sync.WaitGroup
210  	wg.Add(2)
211  
212  	// Writer: flip phases rapidly.
213  	go func() {
214  		defer wg.Done()
215  		phases := []TurnPhase{PhaseAwaitingLLM, PhaseExecutingTools, PhaseSetup}
216  		for i := 0; i < N; i++ {
217  			tr.Enter(phases[i%len(phases)])
218  		}
219  	}()
220  
221  	// Reader: poll concurrently, verify return values are well-typed (no torn reads).
222  	go func() {
223  		defer wg.Done()
224  		for i := 0; i < N; i++ {
225  			p, d, seq := tr.Current()
226  			if p < PhaseInit || p > PhaseDone {
227  				t.Errorf("torn phase read: %d", int(p))
228  				return
229  			}
230  			if d < 0 {
231  				t.Errorf("negative since-duration: %v", d)
232  				return
233  			}
234  			if seq < 0 {
235  				t.Errorf("negative seq: %d", seq)
236  				return
237  			}
238  		}
239  	}()
240  
241  	wg.Wait()
242  }
243  
244  func TestPhaseTracker_SinceRearms(t *testing.T) {
245  	tr := newPhaseTracker()
246  	tr.Enter(PhaseAwaitingLLM)
247  	time.Sleep(5 * time.Millisecond)
248  	_, d1, _ := tr.Current()
249  	if d1 < 5*time.Millisecond {
250  		t.Fatalf("since should be >= 5ms, got %v", d1)
251  	}
252  
253  	tr.Enter(PhaseAwaitingLLM) // same phase re-entered
254  	_, d2, _ := tr.Current()
255  	if d2 > d1 {
256  		t.Fatalf("re-entry should reset since: got %v, prior was %v", d2, d1)
257  	}
258  }