/ internal / agent / watchdog_integration_test.go
watchdog_integration_test.go
  1  package agent
  2  
  3  import (
  4  	"context"
  5  	"errors"
  6  	"sync"
  7  	"sync/atomic"
  8  	"testing"
  9  	"time"
 10  
 11  	"github.com/Kocoro-lab/ShanClaw/internal/client"
 12  )
 13  
 14  // hangingLLMClient is a minimal client.LLMClient whose Complete blocks on
 15  // ctx.Done. Used to exercise the watchdog end-to-end inside AgentLoop.Run.
 16  type hangingLLMClient struct {
 17  	calls atomic.Int32
 18  }
 19  
 20  func (h *hangingLLMClient) Complete(ctx context.Context, req client.CompletionRequest) (*client.CompletionResponse, error) {
 21  	h.calls.Add(1)
 22  	<-ctx.Done()
 23  	return nil, ctx.Err()
 24  }
 25  
 26  func (h *hangingLLMClient) CompleteStream(ctx context.Context, req client.CompletionRequest, _ func(client.StreamDelta)) (*client.CompletionResponse, error) {
 27  	return h.Complete(ctx, req)
 28  }
 29  
 30  // recordingHandler captures OnRunStatus events for assertions.
 31  type recordingHandler struct {
 32  	mockHandler
 33  	mu       sync.Mutex
 34  	statuses []string
 35  }
 36  
 37  func (h *recordingHandler) OnRunStatus(code, detail string) {
 38  	h.mu.Lock()
 39  	h.statuses = append(h.statuses, code+":"+detail)
 40  	h.mu.Unlock()
 41  }
 42  
 43  func (h *recordingHandler) Statuses() []string {
 44  	h.mu.Lock()
 45  	defer h.mu.Unlock()
 46  	out := make([]string, len(h.statuses))
 47  	copy(out, h.statuses)
 48  	return out
 49  }
 50  
 51  func (h *recordingHandler) HasCode(code string) bool {
 52  	for _, s := range h.Statuses() {
 53  		if len(s) >= len(code) && s[:len(code)] == code {
 54  			return true
 55  		}
 56  	}
 57  	return false
 58  }
 59  
 60  func TestAgentLoop_Watchdog_SoftStatus_HangingClient(t *testing.T) {
 61  	gw := &hangingLLMClient{}
 62  	loop := NewAgentLoop(gw, NewToolRegistry(), "medium", "", 25, 2000, 200, nil, nil, nil)
 63  	loop.SetEnableStreaming(false)
 64  	loop.idleSoftTimeout = 30 * time.Millisecond
 65  	loop.watchdogTick = 5 * time.Millisecond
 66  	handler := &recordingHandler{mockHandler: mockHandler{approveResult: true}}
 67  	loop.SetHandler(handler)
 68  
 69  	ctx, cancel := context.WithCancel(context.Background())
 70  	defer cancel()
 71  
 72  	go func() {
 73  		time.Sleep(150 * time.Millisecond)
 74  		cancel()
 75  	}()
 76  
 77  	_, _, err := loop.Run(ctx, "hello", nil, nil)
 78  	if err == nil {
 79  		t.Fatal("expected cancel error from hanging client")
 80  	}
 81  
 82  	if !handler.HasCode("idle_soft") {
 83  		t.Fatalf("expected idle_soft status, got: %v", handler.Statuses())
 84  	}
 85  }
 86  
 87  func TestAgentLoop_Watchdog_ForceStop_HardTimeout_SurfacesHardIdleError(t *testing.T) {
 88  	// Regression for finding #4: during PhaseForceStop, completeWithRetry
 89  	// must preserve ErrHardIdleTimeout in the error chain (via context.Cause)
 90  	// rather than collapsing it into ctx.Err() == context.Canceled.
 91  	gw := &hangingLLMClient{}
 92  	loop := NewAgentLoop(gw, NewToolRegistry(), "medium", "", 25, 2000, 200, nil, nil, nil)
 93  	loop.SetEnableStreaming(false)
 94  	handler := &recordingHandler{mockHandler: mockHandler{approveResult: true}}
 95  	loop.SetHandler(handler)
 96  
 97  	// Simulate ForceStop-style call directly by entering PhaseForceStop and
 98  	// running completeWithRetry against a ctx cancelled by a cause.
 99  	loop.tracker = newPhaseTracker()
100  	loop.tracker.Enter(PhaseForceStop)
101  
102  	ctx, cancel := context.WithCancelCause(context.Background())
103  	go func() {
104  		time.Sleep(20 * time.Millisecond)
105  		cancel(ErrHardIdleTimeout)
106  	}()
107  
108  	_, err := loop.completeWithRetry(ctx, client.CompletionRequest{})
109  	if err == nil {
110  		t.Fatal("expected cancel error")
111  	}
112  	if !errors.Is(err, ErrHardIdleTimeout) {
113  		t.Fatalf("want ErrHardIdleTimeout via context.Cause, got: %v", err)
114  	}
115  }
116  
117  func TestAgentLoop_Watchdog_HardTimeout_CancelsWithCause(t *testing.T) {
118  	gw := &hangingLLMClient{}
119  	loop := NewAgentLoop(gw, NewToolRegistry(), "medium", "", 25, 2000, 200, nil, nil, nil)
120  	loop.SetEnableStreaming(false)
121  	loop.idleSoftTimeout = 0
122  	loop.idleHardTimeout = 40 * time.Millisecond
123  	loop.watchdogTick = 5 * time.Millisecond
124  	handler := &recordingHandler{mockHandler: mockHandler{approveResult: true}}
125  	loop.SetHandler(handler)
126  
127  	_, _, err := loop.Run(context.Background(), "hello", nil, nil)
128  	if err == nil {
129  		t.Fatal("expected hard-timeout error")
130  	}
131  	if !errors.Is(err, ErrHardIdleTimeout) {
132  		t.Fatalf("want ErrHardIdleTimeout in error chain, got: %v", err)
133  	}
134  	status := loop.LastRunStatus()
135  	if !status.Partial {
136  		t.Errorf("expected Partial=true on hard-timeout, got: %+v", status)
137  	}
138  	if !handler.HasCode("idle_hard") {
139  		t.Errorf("expected idle_hard status event, got: %v", handler.Statuses())
140  	}
141  }
142  
143  func TestAgentLoop_Watchdog_HardZero_NoCancellation(t *testing.T) {
144  	// Regression guard: default rollout (hard=0) must not change
145  	// cancellation semantics. Run should complete on a cooperating client
146  	// without any watchdog-originated cancel.
147  	callCount := 0
148  	gw := fakeLLMClient{
149  		resp: func() *client.CompletionResponse {
150  			callCount++
151  			if callCount == 1 {
152  				return &client.CompletionResponse{
153  					OutputText:   "ok",
154  					FinishReason: "end_turn",
155  				}
156  			}
157  			return &client.CompletionResponse{OutputText: "", FinishReason: "end_turn"}
158  		},
159  	}
160  	loop := NewAgentLoop(&gw, NewToolRegistry(), "medium", "", 25, 2000, 200, nil, nil, nil)
161  	loop.SetEnableStreaming(false)
162  	loop.idleSoftTimeout = 10 * time.Millisecond // would fire if we stalled
163  	loop.idleHardTimeout = 0                     // disabled — must not cancel
164  	handler := &recordingHandler{mockHandler: mockHandler{approveResult: true}}
165  	loop.SetHandler(handler)
166  
167  	text, _, err := loop.Run(context.Background(), "hi", nil, nil)
168  	if err != nil {
169  		t.Fatalf("unexpected error with hard=0: %v", err)
170  	}
171  	if text != "ok" {
172  		t.Fatalf("want text %q, got %q", "ok", text)
173  	}
174  }
175  
176  // fakeLLMClient is a tiny cooperating client that returns a fixed response.
177  type fakeLLMClient struct {
178  	resp func() *client.CompletionResponse
179  }
180  
181  func (f *fakeLLMClient) Complete(ctx context.Context, _ client.CompletionRequest) (*client.CompletionResponse, error) {
182  	if ctx.Err() != nil {
183  		return nil, ctx.Err()
184  	}
185  	return f.resp(), nil
186  }
187  
188  func (f *fakeLLMClient) CompleteStream(ctx context.Context, req client.CompletionRequest, _ func(client.StreamDelta)) (*client.CompletionResponse, error) {
189  	return f.Complete(ctx, req)
190  }