watchdog_integration_test.go
1 package agent 2 3 import ( 4 "context" 5 "errors" 6 "sync" 7 "sync/atomic" 8 "testing" 9 "time" 10 11 "github.com/Kocoro-lab/ShanClaw/internal/client" 12 ) 13 14 // hangingLLMClient is a minimal client.LLMClient whose Complete blocks on 15 // ctx.Done. Used to exercise the watchdog end-to-end inside AgentLoop.Run. 16 type hangingLLMClient struct { 17 calls atomic.Int32 18 } 19 20 func (h *hangingLLMClient) Complete(ctx context.Context, req client.CompletionRequest) (*client.CompletionResponse, error) { 21 h.calls.Add(1) 22 <-ctx.Done() 23 return nil, ctx.Err() 24 } 25 26 func (h *hangingLLMClient) CompleteStream(ctx context.Context, req client.CompletionRequest, _ func(client.StreamDelta)) (*client.CompletionResponse, error) { 27 return h.Complete(ctx, req) 28 } 29 30 // recordingHandler captures OnRunStatus events for assertions. 31 type recordingHandler struct { 32 mockHandler 33 mu sync.Mutex 34 statuses []string 35 } 36 37 func (h *recordingHandler) OnRunStatus(code, detail string) { 38 h.mu.Lock() 39 h.statuses = append(h.statuses, code+":"+detail) 40 h.mu.Unlock() 41 } 42 43 func (h *recordingHandler) Statuses() []string { 44 h.mu.Lock() 45 defer h.mu.Unlock() 46 out := make([]string, len(h.statuses)) 47 copy(out, h.statuses) 48 return out 49 } 50 51 func (h *recordingHandler) HasCode(code string) bool { 52 for _, s := range h.Statuses() { 53 if len(s) >= len(code) && s[:len(code)] == code { 54 return true 55 } 56 } 57 return false 58 } 59 60 func TestAgentLoop_Watchdog_SoftStatus_HangingClient(t *testing.T) { 61 gw := &hangingLLMClient{} 62 loop := NewAgentLoop(gw, NewToolRegistry(), "medium", "", 25, 2000, 200, nil, nil, nil) 63 loop.SetEnableStreaming(false) 64 loop.idleSoftTimeout = 30 * time.Millisecond 65 loop.watchdogTick = 5 * time.Millisecond 66 handler := &recordingHandler{mockHandler: mockHandler{approveResult: true}} 67 loop.SetHandler(handler) 68 69 ctx, cancel := context.WithCancel(context.Background()) 70 defer cancel() 71 72 go func() { 73 time.Sleep(150 * time.Millisecond) 74 cancel() 75 }() 76 77 _, _, err := loop.Run(ctx, "hello", nil, nil) 78 if err == nil { 79 t.Fatal("expected cancel error from hanging client") 80 } 81 82 if !handler.HasCode("idle_soft") { 83 t.Fatalf("expected idle_soft status, got: %v", handler.Statuses()) 84 } 85 } 86 87 func TestAgentLoop_Watchdog_ForceStop_HardTimeout_SurfacesHardIdleError(t *testing.T) { 88 // Regression for finding #4: during PhaseForceStop, completeWithRetry 89 // must preserve ErrHardIdleTimeout in the error chain (via context.Cause) 90 // rather than collapsing it into ctx.Err() == context.Canceled. 91 gw := &hangingLLMClient{} 92 loop := NewAgentLoop(gw, NewToolRegistry(), "medium", "", 25, 2000, 200, nil, nil, nil) 93 loop.SetEnableStreaming(false) 94 handler := &recordingHandler{mockHandler: mockHandler{approveResult: true}} 95 loop.SetHandler(handler) 96 97 // Simulate ForceStop-style call directly by entering PhaseForceStop and 98 // running completeWithRetry against a ctx cancelled by a cause. 99 loop.tracker = newPhaseTracker() 100 loop.tracker.Enter(PhaseForceStop) 101 102 ctx, cancel := context.WithCancelCause(context.Background()) 103 go func() { 104 time.Sleep(20 * time.Millisecond) 105 cancel(ErrHardIdleTimeout) 106 }() 107 108 _, err := loop.completeWithRetry(ctx, client.CompletionRequest{}) 109 if err == nil { 110 t.Fatal("expected cancel error") 111 } 112 if !errors.Is(err, ErrHardIdleTimeout) { 113 t.Fatalf("want ErrHardIdleTimeout via context.Cause, got: %v", err) 114 } 115 } 116 117 func TestAgentLoop_Watchdog_HardTimeout_CancelsWithCause(t *testing.T) { 118 gw := &hangingLLMClient{} 119 loop := NewAgentLoop(gw, NewToolRegistry(), "medium", "", 25, 2000, 200, nil, nil, nil) 120 loop.SetEnableStreaming(false) 121 loop.idleSoftTimeout = 0 122 loop.idleHardTimeout = 40 * time.Millisecond 123 loop.watchdogTick = 5 * time.Millisecond 124 handler := &recordingHandler{mockHandler: mockHandler{approveResult: true}} 125 loop.SetHandler(handler) 126 127 _, _, err := loop.Run(context.Background(), "hello", nil, nil) 128 if err == nil { 129 t.Fatal("expected hard-timeout error") 130 } 131 if !errors.Is(err, ErrHardIdleTimeout) { 132 t.Fatalf("want ErrHardIdleTimeout in error chain, got: %v", err) 133 } 134 status := loop.LastRunStatus() 135 if !status.Partial { 136 t.Errorf("expected Partial=true on hard-timeout, got: %+v", status) 137 } 138 if !handler.HasCode("idle_hard") { 139 t.Errorf("expected idle_hard status event, got: %v", handler.Statuses()) 140 } 141 } 142 143 func TestAgentLoop_Watchdog_HardZero_NoCancellation(t *testing.T) { 144 // Regression guard: default rollout (hard=0) must not change 145 // cancellation semantics. Run should complete on a cooperating client 146 // without any watchdog-originated cancel. 147 callCount := 0 148 gw := fakeLLMClient{ 149 resp: func() *client.CompletionResponse { 150 callCount++ 151 if callCount == 1 { 152 return &client.CompletionResponse{ 153 OutputText: "ok", 154 FinishReason: "end_turn", 155 } 156 } 157 return &client.CompletionResponse{OutputText: "", FinishReason: "end_turn"} 158 }, 159 } 160 loop := NewAgentLoop(&gw, NewToolRegistry(), "medium", "", 25, 2000, 200, nil, nil, nil) 161 loop.SetEnableStreaming(false) 162 loop.idleSoftTimeout = 10 * time.Millisecond // would fire if we stalled 163 loop.idleHardTimeout = 0 // disabled — must not cancel 164 handler := &recordingHandler{mockHandler: mockHandler{approveResult: true}} 165 loop.SetHandler(handler) 166 167 text, _, err := loop.Run(context.Background(), "hi", nil, nil) 168 if err != nil { 169 t.Fatalf("unexpected error with hard=0: %v", err) 170 } 171 if text != "ok" { 172 t.Fatalf("want text %q, got %q", "ok", text) 173 } 174 } 175 176 // fakeLLMClient is a tiny cooperating client that returns a fixed response. 177 type fakeLLMClient struct { 178 resp func() *client.CompletionResponse 179 } 180 181 func (f *fakeLLMClient) Complete(ctx context.Context, _ client.CompletionRequest) (*client.CompletionResponse, error) { 182 if ctx.Err() != nil { 183 return nil, ctx.Err() 184 } 185 return f.resp(), nil 186 } 187 188 func (f *fakeLLMClient) CompleteStream(ctx context.Context, req client.CompletionRequest, _ func(client.StreamDelta)) (*client.CompletionResponse, error) { 189 return f.Complete(ctx, req) 190 }