/ internal / daemon / client_test.go
client_test.go
  1  package daemon
  2  
  3  import (
  4  	"context"
  5  	"encoding/json"
  6  	"net/http"
  7  	"net/http/httptest"
  8  	"strings"
  9  	"sync/atomic"
 10  	"testing"
 11  	"time"
 12  
 13  	"github.com/gorilla/websocket"
 14  )
 15  
 16  func TestRunWithReconnect_CancelledContextExitsImmediately(t *testing.T) {
 17  	ctx, cancel := context.WithCancel(context.Background())
 18  	client := NewClient("ws://localhost:99999/nonexistent", "key", func(msg MessagePayload) string { return "" }, nil)
 19  
 20  	done := make(chan struct{})
 21  	go func() {
 22  		client.RunWithReconnect(ctx)
 23  		close(done)
 24  	}()
 25  
 26  	time.Sleep(100 * time.Millisecond)
 27  	cancel()
 28  
 29  	select {
 30  	case <-done:
 31  	case <-time.After(2 * time.Second):
 32  		t.Fatal("RunWithReconnect did not exit within 2s after cancel")
 33  	}
 34  }
 35  
 36  func TestClient_SendEnvelope_WritesToConn(t *testing.T) {
 37  	received := make(chan DaemonMessage, 1)
 38  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 39  		upgrader := websocket.Upgrader{}
 40  		conn, err := upgrader.Upgrade(w, r, nil)
 41  		if err != nil {
 42  			return
 43  		}
 44  		defer conn.Close()
 45  		var dm DaemonMessage
 46  		if err := conn.ReadJSON(&dm); err != nil {
 47  			return
 48  		}
 49  		received <- dm
 50  	}))
 51  	defer srv.Close()
 52  
 53  	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
 54  	c := NewClient(wsURL, "", nil, nil)
 55  	if err := c.Connect(context.Background()); err != nil {
 56  		t.Fatal(err)
 57  	}
 58  	defer c.Close()
 59  
 60  	if err := c.sendEnvelope(DaemonMessage{Type: MsgTypeClaim, MessageID: "msg-123"}); err != nil {
 61  		t.Fatal(err)
 62  	}
 63  
 64  	select {
 65  	case dm := <-received:
 66  		if dm.Type != MsgTypeClaim || dm.MessageID != "msg-123" {
 67  			t.Errorf("got type=%q id=%q, want type=%q id=%q", dm.Type, dm.MessageID, MsgTypeClaim, "msg-123")
 68  		}
 69  	case <-time.After(2 * time.Second):
 70  		t.Fatal("server did not receive message")
 71  	}
 72  }
 73  
 74  func TestClient_ConnectionState(t *testing.T) {
 75  	c := NewClient("ws://localhost:1/x", "", nil, nil)
 76  	if c.IsConnected() {
 77  		t.Error("should not be connected initially")
 78  	}
 79  	if c.Uptime() < 0 {
 80  		t.Error("uptime should be non-negative")
 81  	}
 82  	if c.ActiveAgent() != "" {
 83  		t.Error("no active agent initially")
 84  	}
 85  }
 86  
 87  func TestClient_ClaimHandshake_Granted(t *testing.T) {
 88  	var receivedClaim DaemonMessage
 89  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 90  		upgrader := websocket.Upgrader{}
 91  		conn, err := upgrader.Upgrade(w, r, nil)
 92  		if err != nil {
 93  			return
 94  		}
 95  		defer conn.Close()
 96  
 97  		// Send a message to the daemon.
 98  		payload, _ := json.Marshal(MessagePayload{Channel: "slack", Text: "hi", ThreadID: "t1"})
 99  		conn.WriteJSON(ServerMessage{Type: MsgTypeMessage, MessageID: "msg-001", Payload: payload})
100  
101  		// Read the claim.
102  		conn.ReadJSON(&receivedClaim)
103  
104  		// Grant the claim.
105  		ackPayload, _ := json.Marshal(ClaimAckPayload{Granted: true})
106  		conn.WriteJSON(ServerMessage{Type: MsgTypeClaimAck, MessageID: "msg-001", Payload: ackPayload})
107  
108  		// Read messages until we get a reply (may get progress first).
109  		for {
110  			var reply DaemonMessage
111  			if err := conn.ReadJSON(&reply); err != nil {
112  				return
113  			}
114  			if reply.Type == MsgTypeReply {
115  				var rp ReplyPayload
116  				json.Unmarshal(reply.Payload, &rp)
117  				if rp.Text != "agent-result" {
118  					t.Errorf("reply text = %q, want %q", rp.Text, "agent-result")
119  				}
120  				return
121  			}
122  		}
123  	}))
124  	defer srv.Close()
125  
126  	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
127  	onMsgCalled := make(chan struct{})
128  	c := NewClient(wsURL, "", func(msg MessagePayload) string {
129  		close(onMsgCalled)
130  		return "agent-result"
131  	}, nil)
132  
133  	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
134  	defer cancel()
135  
136  	if err := c.Connect(ctx); err != nil {
137  		t.Fatal(err)
138  	}
139  	defer c.Close()
140  
141  	go c.Listen(ctx)
142  
143  	select {
144  	case <-onMsgCalled:
145  	case <-ctx.Done():
146  		t.Fatal("onMsg was never called")
147  	}
148  
149  	if receivedClaim.Type != MsgTypeClaim || receivedClaim.MessageID != "msg-001" {
150  		t.Errorf("expected claim for msg-001, got %+v", receivedClaim)
151  	}
152  }
153  
154  func TestClient_ClaimHandshake_Denied(t *testing.T) {
155  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
156  		upgrader := websocket.Upgrader{}
157  		conn, err := upgrader.Upgrade(w, r, nil)
158  		if err != nil {
159  			return
160  		}
161  		defer conn.Close()
162  
163  		payload, _ := json.Marshal(MessagePayload{Channel: "slack", Text: "hi"})
164  		conn.WriteJSON(ServerMessage{Type: MsgTypeMessage, MessageID: "msg-002", Payload: payload})
165  
166  		var dm DaemonMessage
167  		conn.ReadJSON(&dm)
168  
169  		ackPayload, _ := json.Marshal(ClaimAckPayload{Granted: false})
170  		conn.WriteJSON(ServerMessage{Type: MsgTypeClaimAck, MessageID: "msg-002", Payload: ackPayload})
171  
172  		// Keep connection open briefly so the client can process the denial.
173  		time.Sleep(500 * time.Millisecond)
174  	}))
175  	defer srv.Close()
176  
177  	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
178  	onMsgCalled := false
179  	c := NewClient(wsURL, "", func(msg MessagePayload) string {
180  		onMsgCalled = true
181  		return ""
182  	}, nil)
183  
184  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
185  	defer cancel()
186  	if err := c.Connect(ctx); err != nil {
187  		t.Fatal(err)
188  	}
189  	defer c.Close()
190  
191  	go c.Listen(ctx)
192  	time.Sleep(500 * time.Millisecond)
193  	cancel()
194  
195  	if onMsgCalled {
196  		t.Error("onMsg should NOT be called when claim is denied")
197  	}
198  }
199  
200  func TestClient_GracefulDisconnect(t *testing.T) {
201  	msgs := make(chan DaemonMessage, 10)
202  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
203  		upgrader := websocket.Upgrader{}
204  		conn, _ := upgrader.Upgrade(w, r, nil)
205  		defer conn.Close()
206  		for {
207  			var dm DaemonMessage
208  			if err := conn.ReadJSON(&dm); err != nil {
209  				return
210  			}
211  			msgs <- dm
212  		}
213  	}))
214  	defer srv.Close()
215  
216  	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
217  	c := NewClient(wsURL, "", func(msg MessagePayload) string { return "" }, nil)
218  
219  	ctx, cancel := context.WithCancel(context.Background())
220  	if err := c.Connect(ctx); err != nil {
221  		t.Fatal(err)
222  	}
223  	go c.Listen(ctx)
224  	time.Sleep(100 * time.Millisecond)
225  
226  	cancel()
227  	time.Sleep(200 * time.Millisecond)
228  
229  	// Check if disconnect was the last message
230  	var lastMsg DaemonMessage
231  	for {
232  		select {
233  		case m := <-msgs:
234  			lastMsg = m
235  		default:
236  			goto done
237  		}
238  	}
239  done:
240  	if lastMsg.Type != MsgTypeDisconnect {
241  		t.Errorf("expected disconnect message, got type=%q", lastMsg.Type)
242  	}
243  }
244  
245  func TestClient_SystemMessage(t *testing.T) {
246  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
247  		upgrader := websocket.Upgrader{}
248  		conn, err := upgrader.Upgrade(w, r, nil)
249  		if err != nil {
250  			return
251  		}
252  		defer conn.Close()
253  		payload := json.RawMessage(`"Quota warning: 90% used"`)
254  		conn.WriteJSON(ServerMessage{Type: MsgTypeSystem, Payload: payload})
255  		time.Sleep(500 * time.Millisecond)
256  	}))
257  	defer srv.Close()
258  
259  	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
260  	systemCh := make(chan string, 1)
261  	c := NewClient(wsURL, "", func(msg MessagePayload) string { return "" }, func(text string) {
262  		systemCh <- text
263  	})
264  
265  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
266  	defer cancel()
267  	if err := c.Connect(ctx); err != nil {
268  		t.Fatal(err)
269  	}
270  	defer c.Close()
271  
272  	go c.Listen(ctx)
273  
274  	select {
275  	case msg := <-systemCh:
276  		if msg != "Quota warning: 90% used" {
277  			t.Errorf("system message = %q, want %q", msg, "Quota warning: 90% used")
278  		}
279  	case <-time.After(2 * time.Second):
280  		t.Fatal("system message not received")
281  	}
282  }
283  
284  func TestClient_SendEvent_WireFormat(t *testing.T) {
285  	received := make(chan DaemonMessage, 1)
286  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
287  		upgrader := websocket.Upgrader{}
288  		conn, err := upgrader.Upgrade(w, r, nil)
289  		if err != nil {
290  			return
291  		}
292  		defer conn.Close()
293  		var dm DaemonMessage
294  		if err := conn.ReadJSON(&dm); err != nil {
295  			return
296  		}
297  		received <- dm
298  	}))
299  	defer srv.Close()
300  
301  	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
302  	c := NewClient(wsURL, "", nil, nil)
303  	if err := c.Connect(context.Background()); err != nil {
304  		t.Fatal(err)
305  	}
306  	defer c.Close()
307  
308  	if err := c.SendEvent("msg-100", "tool_start", "running web_search", map[string]interface{}{"tool": "web_search"}); err != nil {
309  		t.Fatal(err)
310  	}
311  
312  	select {
313  	case dm := <-received:
314  		if dm.Type != MsgTypeEvent {
315  			t.Errorf("type = %q, want %q", dm.Type, MsgTypeEvent)
316  		}
317  		if dm.MessageID != "msg-100" {
318  			t.Errorf("message_id = %q, want %q", dm.MessageID, "msg-100")
319  		}
320  		var ep DaemonEventPayload
321  		if err := json.Unmarshal(dm.Payload, &ep); err != nil {
322  			t.Fatalf("unmarshal payload: %v", err)
323  		}
324  		if ep.EventType != "tool_start" {
325  			t.Errorf("event_type = %q, want %q", ep.EventType, "tool_start")
326  		}
327  		if ep.Seq != 1 {
328  			t.Errorf("seq = %d, want 1", ep.Seq)
329  		}
330  		if ep.Timestamp == "" {
331  			t.Error("timestamp should not be empty")
332  		}
333  	case <-time.After(2 * time.Second):
334  		t.Fatal("server did not receive event")
335  	}
336  }
337  
338  func TestClient_SendEvent_SeqIncrementsPerMessage(t *testing.T) {
339  	received := make(chan DaemonMessage, 3)
340  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
341  		upgrader := websocket.Upgrader{}
342  		conn, err := upgrader.Upgrade(w, r, nil)
343  		if err != nil {
344  			return
345  		}
346  		defer conn.Close()
347  		for i := 0; i < 3; i++ {
348  			var dm DaemonMessage
349  			if err := conn.ReadJSON(&dm); err != nil {
350  				return
351  			}
352  			received <- dm
353  		}
354  	}))
355  	defer srv.Close()
356  
357  	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
358  	c := NewClient(wsURL, "", nil, nil)
359  	if err := c.Connect(context.Background()); err != nil {
360  		t.Fatal(err)
361  	}
362  	defer c.Close()
363  
364  	// Send 2 events for msg-A, 1 for msg-B.
365  	c.SendEvent("msg-A", "step", "one", nil)
366  	c.SendEvent("msg-A", "step", "two", nil)
367  	c.SendEvent("msg-B", "step", "one", nil)
368  
369  	seqs := make(map[string][]int64)
370  	for i := 0; i < 3; i++ {
371  		select {
372  		case dm := <-received:
373  			var ep DaemonEventPayload
374  			json.Unmarshal(dm.Payload, &ep)
375  			seqs[dm.MessageID] = append(seqs[dm.MessageID], ep.Seq)
376  		case <-time.After(2 * time.Second):
377  			t.Fatal("timeout waiting for events")
378  		}
379  	}
380  
381  	if got := seqs["msg-A"]; len(got) != 2 || got[0] != 1 || got[1] != 2 {
382  		t.Errorf("msg-A seqs = %v, want [1 2]", got)
383  	}
384  	if got := seqs["msg-B"]; len(got) != 1 || got[0] != 1 {
385  		t.Errorf("msg-B seqs = %v, want [1]", got)
386  	}
387  }
388  
389  func TestClient_SendReply_CleansUpSeq(t *testing.T) {
390  	c := NewClient("ws://localhost:1/x", "", nil, nil)
391  	// Pre-populate a seq counter.
392  	c.eventSeqs.Store("msg-cleanup", new(atomic.Int64))
393  
394  	// SendReply will fail (no connection) but should still clean up eventSeqs.
395  	_ = c.SendReply("msg-cleanup", ReplyPayload{Text: "done"})
396  
397  	if _, loaded := c.eventSeqs.Load("msg-cleanup"); loaded {
398  		t.Error("eventSeqs entry should have been deleted by SendReply")
399  	}
400  }