/ internal / daemon / runner_test.go
runner_test.go
  1  package daemon
  2  
  3  import (
  4  	"context"
  5  	"encoding/base64"
  6  	"encoding/json"
  7  	"errors"
  8  	"fmt"
  9  	"os"
 10  	"path/filepath"
 11  	"strconv"
 12  	"testing"
 13  	"time"
 14  
 15  	"github.com/Kocoro-lab/ShanClaw/internal/agent"
 16  	"github.com/Kocoro-lab/ShanClaw/internal/client"
 17  	"github.com/Kocoro-lab/ShanClaw/internal/mcp"
 18  	"github.com/Kocoro-lab/ShanClaw/internal/session"
 19  )
 20  
 21  func TestCacheSourceFromDaemonSource(t *testing.T) {
 22  	cases := []struct {
 23  		source string
 24  		want   string
 25  	}{
 26  		{"slack", "slack"},
 27  		{"Slack", "slack"},
 28  		{"  line  ", "line"},
 29  		{"feishu", "feishu"},
 30  		{"telegram", "telegram"},
 31  		{"tui", "tui"},
 32  		{"shanclaw", "shanclaw"},
 33  		// Empty source is defaulted to "shanclaw" in server.go before reaching
 34  		// this function; the dedicated empty-string case was removed. Falls
 35  		// through to "unknown" (5m) defensively in case the default is ever
 36  		// bypassed — matches the fail-cheap policy documented in
 37  		// docs/cache-strategy.md.
 38  		{"", "unknown"},
 39  		{"webhook", "webhook"},
 40  		{"cron", "cron"},
 41  		{"schedule", "schedule"},
 42  		{"mcp", "mcp"},
 43  		{"cache_bench", "cache_bench"},
 44  		{"never-classified-source", "unknown"},
 45  	}
 46  	for _, c := range cases {
 47  		if got := cacheSourceFromDaemonSource(c.source); got != c.want {
 48  			t.Errorf("cacheSourceFromDaemonSource(%q) = %q, want %q", c.source, got, c.want)
 49  		}
 50  	}
 51  }
 52  
 53  func TestRunAgentRequest_Validate_EmptyText(t *testing.T) {
 54  	req := RunAgentRequest{Text: ""}
 55  	if err := req.Validate(); err == nil {
 56  		t.Fatal("expected error for empty text")
 57  	}
 58  }
 59  
 60  func TestRunAgentRequest_Validate_WhitespaceOnly(t *testing.T) {
 61  	req := RunAgentRequest{Text: "   "}
 62  	if err := req.Validate(); err == nil {
 63  		t.Fatal("expected error for whitespace-only text")
 64  	}
 65  }
 66  
 67  func TestRunAgentRequest_Validate_NonEmpty(t *testing.T) {
 68  	req := RunAgentRequest{Text: "hello"}
 69  	if err := req.Validate(); err != nil {
 70  		t.Fatalf("unexpected error: %v", err)
 71  	}
 72  }
 73  
 74  func TestRunAgentRequest_Validate_WithAgent(t *testing.T) {
 75  	req := RunAgentRequest{Text: "do something", Agent: "ops-bot"}
 76  	if err := req.Validate(); err != nil {
 77  		t.Fatalf("unexpected error: %v", err)
 78  	}
 79  }
 80  
 81  func TestRunAgentRequest_Validate_WithSessionID(t *testing.T) {
 82  	req := RunAgentRequest{Text: "do something", SessionID: "sess-123"}
 83  	if err := req.Validate(); err != nil {
 84  		t.Fatalf("unexpected error: %v", err)
 85  	}
 86  }
 87  
 88  func TestRunAgentRequest_Ephemeral(t *testing.T) {
 89  	req := RunAgentRequest{
 90  		Text:      "test",
 91  		Agent:     "test-agent",
 92  		Source:    "heartbeat",
 93  		Ephemeral: true,
 94  	}
 95  	if err := req.Validate(); err != nil {
 96  		t.Fatalf("valid ephemeral request should not fail: %v", err)
 97  	}
 98  }
 99  
100  func TestRunAgentRequest_ModelOverride(t *testing.T) {
101  	req := RunAgentRequest{
102  		Text:          "test",
103  		ModelOverride: "small",
104  	}
105  	if err := req.Validate(); err != nil {
106  		t.Fatalf("valid model override request should not fail: %v", err)
107  	}
108  }
109  
110  func TestRunAgentRequest_Validate_WithValidCWD(t *testing.T) {
111  	req := RunAgentRequest{
112  		Text: "test",
113  		CWD:  t.TempDir(),
114  	}
115  	if err := req.Validate(); err != nil {
116  		t.Fatalf("valid cwd request should not fail: %v", err)
117  	}
118  }
119  
120  func TestRunAgentRequest_Validate_WithInvalidCWD(t *testing.T) {
121  	req := RunAgentRequest{
122  		Text: "test",
123  		CWD:  "/nonexistent/path/for/inject-validation",
124  	}
125  	if err := req.Validate(); err == nil {
126  		t.Fatal("expected invalid cwd error")
127  	}
128  }
129  
130  func TestComputeRouteKey_BypassRouting(t *testing.T) {
131  	req := RunAgentRequest{Agent: "my-agent", BypassRouting: true}
132  	if got := ComputeRouteKey(req); got != "" {
133  		t.Errorf("ComputeRouteKey with BypassRouting=true returned %q, want empty", got)
134  	}
135  }
136  
137  func TestComputeRouteKey_AgentWithoutBypass(t *testing.T) {
138  	req := RunAgentRequest{Agent: "my-agent"}
139  	if got := ComputeRouteKey(req); got != "agent:my-agent" {
140  		t.Errorf("ComputeRouteKey returned %q, want %q", got, "agent:my-agent")
141  	}
142  }
143  
144  func TestRouteTitle(t *testing.T) {
145  	tests := []struct {
146  		source, channel, sender, want string
147  	}{
148  		{"slack", "slack", "Wayland", "Slack · Wayland"},
149  		{"slack", "slack", "", "Slack"},
150  		{"line", "line", "Tanaka", "Line · Tanaka"},
151  		{"feishu", "feishu", "", "Feishu"},
152  		{"slack", "#general", "", "Slack · #general"},
153  		{"slack", "#general", "Alice", "Slack · Alice"},
154  		{"webhook", "hook-1", "", "Webhook · hook-1"},
155  		{"", "slack", "Wayland", ""},
156  		{"slack", "", "Wayland", "Slack · Wayland"},
157  		{"", "", "", ""},
158  	}
159  	for _, tt := range tests {
160  		got := routeTitle(tt.source, tt.channel, tt.sender)
161  		if got != tt.want {
162  			t.Errorf("routeTitle(%q, %q, %q) = %q, want %q",
163  				tt.source, tt.channel, tt.sender, got, tt.want)
164  		}
165  	}
166  }
167  
168  func TestOutputFormatForSource(t *testing.T) {
169  	tests := []struct {
170  		source string
171  		want   string
172  	}{
173  		// Cloud-distributed channel sources → plain
174  		{"slack", "plain"},
175  		{"line", "plain"},
176  		{"webhook", "plain"},
177  		{"feishu", "plain"},
178  		{"lark", "plain"},
179  		{"telegram", "plain"},
180  		{"Slack", "plain"}, // case-insensitive
181  		{"LINE", "plain"},  // case-insensitive
182  		// Everything else → markdown (local, cron, schedule, web, unknown)
183  		{"shanclaw", "markdown"},
184  		{"desktop", "markdown"},
185  		{"web", "markdown"},
186  		{"cron", "markdown"},
187  		{"schedule", "markdown"},
188  		{"heartbeat", "markdown"},
189  		{"", "markdown"},
190  		{"unknown", "markdown"},
191  		{"custom-bot", "markdown"},
192  	}
193  	for _, tt := range tests {
194  		got := outputFormatForSource(tt.source)
195  		if got != tt.want {
196  			t.Errorf("outputFormatForSource(%q) = %q, want %q", tt.source, got, tt.want)
197  		}
198  	}
199  }
200  
201  func TestRunAgentRequestSource(t *testing.T) {
202  	req := RunAgentRequest{
203  		Text:   "hello",
204  		Agent:  "test",
205  		Source: "slack",
206  	}
207  	data, _ := json.Marshal(req)
208  	var decoded RunAgentRequest
209  	json.Unmarshal(data, &decoded)
210  	if decoded.Source != "slack" {
211  		t.Fatalf("expected source 'slack', got %q", decoded.Source)
212  	}
213  }
214  
215  // context.Canceled and context.DeadlineExceeded must be treated as soft errors
216  // (like ErrMaxIterReached) so the full conversation from RunMessages() is
217  // persisted, not just a friendly error stub.
218  func TestIsSoftRunError(t *testing.T) {
219  	tests := []struct {
220  		name string
221  		err  error
222  		want bool
223  	}{
224  		{"nil", nil, false},
225  		{"context.Canceled", context.Canceled, true},
226  		{"context.DeadlineExceeded", context.DeadlineExceeded, true},
227  		{"ErrMaxIterReached", agent.ErrMaxIterReached, true},
228  		{"ErrHardIdleTimeout", agent.ErrHardIdleTimeout, true},
229  		{"wrapped ErrHardIdleTimeout", fmt.Errorf("turn aborted: %w", agent.ErrHardIdleTimeout), true},
230  		{"wrapped Canceled", errors.Join(errors.New("loop"), context.Canceled), true},
231  		{"random error", errors.New("something broke"), false},
232  		{"API error", errors.New("429 rate limited"), false},
233  	}
234  	for _, tt := range tests {
235  		t.Run(tt.name, func(t *testing.T) {
236  			got := isSoftRunError(tt.err)
237  			if got != tt.want {
238  				t.Errorf("isSoftRunError(%v) = %v, want %v", tt.err, got, tt.want)
239  			}
240  		})
241  	}
242  }
243  
244  func TestResumeNamedAgentColdStart_ResumesPersistedEmptySession(t *testing.T) {
245  	sessionsDir := t.TempDir()
246  	storedCWD := t.TempDir()
247  	store := session.NewStore(sessionsDir)
248  	if err := store.Save(&session.Session{
249  		ID:    "persisted-empty",
250  		Title: "Persisted empty session",
251  		CWD:   storedCWD,
252  	}); err != nil {
253  		t.Fatalf("save session: %v", err)
254  	}
255  
256  	mgr := session.NewManager(sessionsDir)
257  	resumed, err := resumeNamedAgentColdStart(mgr)
258  	if err != nil {
259  		t.Fatalf("resumeNamedAgentColdStart error: %v", err)
260  	}
261  	if !resumed {
262  		t.Fatal("expected persisted empty session to count as resumed")
263  	}
264  	if got := mgr.Current(); got == nil || got.CWD != storedCWD {
265  		t.Fatalf("expected stored CWD %q, got %#v", storedCWD, got)
266  	}
267  }
268  
269  func TestResumeNamedAgentColdStart_NoPersistedSessionKeepsFreshCurrent(t *testing.T) {
270  	sessionsDir := t.TempDir()
271  	mgr := session.NewManager(sessionsDir)
272  	fresh := mgr.NewSession()
273  
274  	resumed, err := resumeNamedAgentColdStart(mgr)
275  	if err != nil {
276  		t.Fatalf("resumeNamedAgentColdStart error: %v", err)
277  	}
278  	if resumed {
279  		t.Fatal("expected no persisted session to remain fresh")
280  	}
281  	if got := mgr.Current(); got == nil || got.ID != fresh.ID {
282  		t.Fatalf("expected fresh current session %q to be preserved, got %#v", fresh.ID, got)
283  	}
284  }
285  
286  func TestResolveContentBlocks_TextAndImage(t *testing.T) {
287  	blocks := []RequestContentBlock{
288  		{Type: "text", Text: "hello"},
289  		{Type: "image", Source: &client.ImageSource{Type: "base64", MediaType: "image/png", Data: "abc123"}},
290  	}
291  	resolved := resolveContentBlocks(blocks)
292  	if len(resolved) != 2 {
293  		t.Fatalf("expected 2 blocks, got %d", len(resolved))
294  	}
295  	if resolved[0].Type != "text" || resolved[0].Text != "hello" {
296  		t.Errorf("text block mismatch: %+v", resolved[0])
297  	}
298  	if resolved[1].Type != "image" || resolved[1].Source == nil || resolved[1].Source.Data != "abc123" {
299  		t.Errorf("image block mismatch: %+v", resolved[1])
300  	}
301  }
302  
303  func TestResolveContentBlocks_FileRef(t *testing.T) {
304  	dir := t.TempDir()
305  	path := filepath.Join(dir, "test.txt")
306  	os.WriteFile(path, []byte("file content here"), 0644)
307  
308  	blocks := []RequestContentBlock{
309  		{Type: "file_ref", FilePath: path, Filename: "test.txt", ByteSize: 17},
310  	}
311  	resolved := resolveContentBlocks(blocks)
312  	if len(resolved) != 1 {
313  		t.Fatalf("expected 1 block, got %d", len(resolved))
314  	}
315  	if resolved[0].Type != "text" {
316  		t.Fatalf("expected text type, got %s", resolved[0].Type)
317  	}
318  	expected := "[User attached file: test.txt (17 bytes) at path: " + path + " — use the file_read tool to read its contents]"
319  	if resolved[0].Text != expected {
320  		t.Errorf("file ref text mismatch:\ngot:  %q\nwant: %q", resolved[0].Text, expected)
321  	}
322  }
323  
324  func TestResolveContentBlocks_ImageFileRef(t *testing.T) {
325  	dir := t.TempDir()
326  	path := filepath.Join(dir, "photo.png")
327  	raw := []byte("fake-png-data")
328  	if err := os.WriteFile(path, raw, 0644); err != nil {
329  		t.Fatalf("write image: %v", err)
330  	}
331  
332  	blocks := []RequestContentBlock{
333  		{Type: "file_ref", FilePath: path, Filename: "photo.png", ByteSize: int64(len(raw))},
334  	}
335  	resolved := resolveContentBlocks(blocks)
336  	if len(resolved) != 2 {
337  		t.Fatalf("expected 2 blocks, got %d", len(resolved))
338  	}
339  	if resolved[0].Type != "text" {
340  		t.Fatalf("expected first block to be text, got %s", resolved[0].Type)
341  	}
342  	expectedText := "[User attached image: photo.png (" + strconv.Itoa(len(raw)) + " bytes) at path: " + path + " — the image is included inline below for vision. Use the path if a tool needs the original file.]"
343  	if resolved[0].Text != expectedText {
344  		t.Errorf("image file ref text mismatch:\ngot:  %q\nwant: %q", resolved[0].Text, expectedText)
345  	}
346  	if resolved[1].Type != "image" || resolved[1].Source == nil {
347  		t.Fatalf("expected second block to be image, got %+v", resolved[1])
348  	}
349  	if resolved[1].Source.MediaType != "image/png" {
350  		t.Fatalf("expected image/png, got %q", resolved[1].Source.MediaType)
351  	}
352  	if resolved[1].Source.Data != base64.StdEncoding.EncodeToString(raw) {
353  		t.Errorf("image data mismatch: got %q", resolved[1].Source.Data)
354  	}
355  }
356  
357  func TestResolveContentBlocks_FileRefMissing(t *testing.T) {
358  	blocks := []RequestContentBlock{
359  		{Type: "file_ref", FilePath: "/nonexistent/path/file.log", Filename: "file.log"},
360  	}
361  	resolved := resolveContentBlocks(blocks)
362  	if len(resolved) != 1 {
363  		t.Fatalf("expected 1 block, got %d", len(resolved))
364  	}
365  	if resolved[0].Type != "text" {
366  		t.Fatalf("expected text type, got %s", resolved[0].Type)
367  	}
368  	expected := "[User attached file: file.log (0 bytes) at path: /nonexistent/path/file.log — use the file_read tool to read its contents]"
369  	if resolved[0].Text != expected {
370  		t.Errorf("error text mismatch:\ngot:  %q\nwant: %q", resolved[0].Text, expected)
371  	}
372  }
373  
374  func TestResolveContentBlocks_UnknownTypeSkipped(t *testing.T) {
375  	blocks := []RequestContentBlock{
376  		{Type: "text", Text: "keep"},
377  		{Type: "unknown_type", Text: "skip"},
378  	}
379  	resolved := resolveContentBlocks(blocks)
380  	if len(resolved) != 1 {
381  		t.Fatalf("expected 1 block (unknown skipped), got %d", len(resolved))
382  	}
383  	if resolved[0].Text != "keep" {
384  		t.Errorf("expected 'keep', got %q", resolved[0].Text)
385  	}
386  }
387  
388  func TestRunAgentRequest_ContentJSON(t *testing.T) {
389  	raw := `{
390  		"text": "analyze this",
391  		"content": [
392  			{"type": "text", "text": "describe the image"},
393  			{"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR"}}
394  		],
395  		"source": "shanclaw"
396  	}`
397  	var req RunAgentRequest
398  	if err := json.Unmarshal([]byte(raw), &req); err != nil {
399  		t.Fatalf("unmarshal error: %v", err)
400  	}
401  	if req.Text != "analyze this" {
402  		t.Errorf("text mismatch: %q", req.Text)
403  	}
404  	if len(req.Content) != 2 {
405  		t.Fatalf("expected 2 content blocks, got %d", len(req.Content))
406  	}
407  	if req.Content[0].Type != "text" || req.Content[0].Text != "describe the image" {
408  		t.Errorf("content[0] mismatch: %+v", req.Content[0])
409  	}
410  	if req.Content[1].Type != "image" || req.Content[1].Source == nil || req.Content[1].Source.Data != "iVBOR" {
411  		t.Errorf("content[1] mismatch: %+v", req.Content[1])
412  	}
413  }
414  
415  func TestRunAgentRequest_NoContent(t *testing.T) {
416  	raw := `{"text": "just text"}`
417  	var req RunAgentRequest
418  	if err := json.Unmarshal([]byte(raw), &req); err != nil {
419  		t.Fatalf("unmarshal error: %v", err)
420  	}
421  	if req.Text != "just text" {
422  		t.Errorf("text mismatch: %q", req.Text)
423  	}
424  	if req.Content != nil {
425  		t.Errorf("expected nil content, got %v", req.Content)
426  	}
427  }
428  
429  func TestExtractUserFilePaths(t *testing.T) {
430  	blocks := []RequestContentBlock{
431  		{Type: "text", Text: "analyze these"},
432  		{Type: "file_ref", FilePath: "/tmp/report.pdf", Filename: "report.pdf"},
433  		{Type: "image", Source: &client.ImageSource{Type: "base64", MediaType: "image/png", Data: "abc"}},
434  		{Type: "file_ref", FilePath: "/tmp/data.csv", Filename: "data.csv"},
435  	}
436  	paths := extractUserFilePaths(blocks)
437  	if len(paths) != 2 {
438  		t.Fatalf("expected 2 paths, got %d: %v", len(paths), paths)
439  	}
440  	if paths[0] != "/tmp/report.pdf" || paths[1] != "/tmp/data.csv" {
441  		t.Errorf("unexpected paths: %v", paths)
442  	}
443  }
444  
445  func TestExtractUserFilePaths_Empty(t *testing.T) {
446  	paths := extractUserFilePaths(nil)
447  	if len(paths) != 0 {
448  		t.Errorf("expected empty, got %v", paths)
449  	}
450  	paths = extractUserFilePaths([]RequestContentBlock{{Type: "text", Text: "hello"}})
451  	if len(paths) != 0 {
452  		t.Errorf("expected empty for text-only, got %v", paths)
453  	}
454  }
455  
456  func TestCleanupPlaywrightAfterTurn_CDPOnDemandStopsBrowser(t *testing.T) {
457  	mgr := mcp.NewClientManager()
458  	mgr.SeedConfig("playwright", mcp.MCPServerConfig{
459  		Command:   "dummy",
460  		Args:      []string{"--cdp-endpoint", "http://127.0.0.1:9223"},
461  		KeepAlive: false,
462  	})
463  
464  	oldIdle := disconnectPlaywrightAfterIdleFn
465  	oldNow := disconnectPlaywrightNowFn
466  	oldStop := stopPlaywrightChromeFn
467  	defer func() {
468  		disconnectPlaywrightAfterIdleFn = oldIdle
469  		disconnectPlaywrightNowFn = oldNow
470  		stopPlaywrightChromeFn = oldStop
471  	}()
472  
473  	idleCalls := 0
474  	nowCalls := 0
475  	stopCalls := 0
476  	disconnectPlaywrightAfterIdleFn = func(*mcp.ClientManager, time.Duration) { idleCalls++ }
477  	disconnectPlaywrightNowFn = func(*mcp.ClientManager) { nowCalls++ }
478  	stopPlaywrightChromeFn = func() { stopCalls++ }
479  
480  	cleanupPlaywrightAfterTurn(mgr)
481  
482  	if idleCalls != 0 {
483  		t.Fatalf("expected no idle disconnect scheduling, got %d", idleCalls)
484  	}
485  	if nowCalls != 1 {
486  		t.Fatalf("expected immediate disconnect once, got %d", nowCalls)
487  	}
488  	if stopCalls != 1 {
489  		t.Fatalf("expected dedicated Chrome stop once, got %d", stopCalls)
490  	}
491  }
492  
493  func TestCleanupPlaywrightAfterTurn_KeepAliveLeavesBrowserRunning(t *testing.T) {
494  	mgr := mcp.NewClientManager()
495  	mgr.SeedConfig("playwright", mcp.MCPServerConfig{
496  		Command:   "dummy",
497  		Args:      []string{"--cdp-endpoint", "http://127.0.0.1:9223"},
498  		KeepAlive: true,
499  	})
500  
501  	oldIdle := disconnectPlaywrightAfterIdleFn
502  	oldNow := disconnectPlaywrightNowFn
503  	oldStop := stopPlaywrightChromeFn
504  	defer func() {
505  		disconnectPlaywrightAfterIdleFn = oldIdle
506  		disconnectPlaywrightNowFn = oldNow
507  		stopPlaywrightChromeFn = oldStop
508  	}()
509  
510  	idleCalls := 0
511  	nowCalls := 0
512  	stopCalls := 0
513  	disconnectPlaywrightAfterIdleFn = func(*mcp.ClientManager, time.Duration) { idleCalls++ }
514  	disconnectPlaywrightNowFn = func(*mcp.ClientManager) { nowCalls++ }
515  	stopPlaywrightChromeFn = func() { stopCalls++ }
516  
517  	cleanupPlaywrightAfterTurn(mgr)
518  
519  	if idleCalls != 0 || nowCalls != 0 || stopCalls != 0 {
520  		t.Fatalf("expected no teardown while keepAlive=true, got idle=%d disconnect=%d stop=%d", idleCalls, nowCalls, stopCalls)
521  	}
522  }
523  
524  func TestCleanupPlaywrightAfterTurn_NonCDPUsesIdleDisconnect(t *testing.T) {
525  	mgr := mcp.NewClientManager()
526  	mgr.SeedConfig("playwright", mcp.MCPServerConfig{
527  		Command:   "dummy",
528  		Args:      []string{"--some-stdio-mode"},
529  		KeepAlive: false,
530  	})
531  
532  	oldIdle := disconnectPlaywrightAfterIdleFn
533  	oldNow := disconnectPlaywrightNowFn
534  	oldStop := stopPlaywrightChromeFn
535  	defer func() {
536  		disconnectPlaywrightAfterIdleFn = oldIdle
537  		disconnectPlaywrightNowFn = oldNow
538  		stopPlaywrightChromeFn = oldStop
539  	}()
540  
541  	idleCalls := 0
542  	var idleDuration time.Duration
543  	nowCalls := 0
544  	stopCalls := 0
545  	disconnectPlaywrightAfterIdleFn = func(_ *mcp.ClientManager, d time.Duration) {
546  		idleCalls++
547  		idleDuration = d
548  	}
549  	disconnectPlaywrightNowFn = func(*mcp.ClientManager) { nowCalls++ }
550  	stopPlaywrightChromeFn = func() { stopCalls++ }
551  
552  	cleanupPlaywrightAfterTurn(mgr)
553  
554  	if idleCalls != 1 {
555  		t.Fatalf("expected idle disconnect scheduling once, got %d", idleCalls)
556  	}
557  	if idleDuration != 5*time.Minute {
558  		t.Fatalf("expected 5m idle disconnect, got %v", idleDuration)
559  	}
560  	if nowCalls != 0 || stopCalls != 0 {
561  		t.Fatalf("expected no immediate teardown in non-CDP mode, got disconnect=%d stop=%d", nowCalls, stopCalls)
562  	}
563  }