// client_test.go
1 package daemon 2 3 import ( 4 "context" 5 "encoding/json" 6 "net/http" 7 "net/http/httptest" 8 "strings" 9 "sync/atomic" 10 "testing" 11 "time" 12 13 "github.com/gorilla/websocket" 14 ) 15 16 func TestRunWithReconnect_CancelledContextExitsImmediately(t *testing.T) { 17 ctx, cancel := context.WithCancel(context.Background()) 18 client := NewClient("ws://localhost:99999/nonexistent", "key", func(msg MessagePayload) string { return "" }, nil) 19 20 done := make(chan struct{}) 21 go func() { 22 client.RunWithReconnect(ctx) 23 close(done) 24 }() 25 26 time.Sleep(100 * time.Millisecond) 27 cancel() 28 29 select { 30 case <-done: 31 case <-time.After(2 * time.Second): 32 t.Fatal("RunWithReconnect did not exit within 2s after cancel") 33 } 34 } 35 36 func TestClient_SendEnvelope_WritesToConn(t *testing.T) { 37 received := make(chan DaemonMessage, 1) 38 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 39 upgrader := websocket.Upgrader{} 40 conn, err := upgrader.Upgrade(w, r, nil) 41 if err != nil { 42 return 43 } 44 defer conn.Close() 45 var dm DaemonMessage 46 if err := conn.ReadJSON(&dm); err != nil { 47 return 48 } 49 received <- dm 50 })) 51 defer srv.Close() 52 53 wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") 54 c := NewClient(wsURL, "", nil, nil) 55 if err := c.Connect(context.Background()); err != nil { 56 t.Fatal(err) 57 } 58 defer c.Close() 59 60 if err := c.sendEnvelope(DaemonMessage{Type: MsgTypeClaim, MessageID: "msg-123"}); err != nil { 61 t.Fatal(err) 62 } 63 64 select { 65 case dm := <-received: 66 if dm.Type != MsgTypeClaim || dm.MessageID != "msg-123" { 67 t.Errorf("got type=%q id=%q, want type=%q id=%q", dm.Type, dm.MessageID, MsgTypeClaim, "msg-123") 68 } 69 case <-time.After(2 * time.Second): 70 t.Fatal("server did not receive message") 71 } 72 } 73 74 func TestClient_ConnectionState(t *testing.T) { 75 c := NewClient("ws://localhost:1/x", "", nil, nil) 76 if c.IsConnected() { 77 t.Error("should not be 
connected initially") 78 } 79 if c.Uptime() < 0 { 80 t.Error("uptime should be non-negative") 81 } 82 if c.ActiveAgent() != "" { 83 t.Error("no active agent initially") 84 } 85 } 86 87 func TestClient_ClaimHandshake_Granted(t *testing.T) { 88 var receivedClaim DaemonMessage 89 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 90 upgrader := websocket.Upgrader{} 91 conn, err := upgrader.Upgrade(w, r, nil) 92 if err != nil { 93 return 94 } 95 defer conn.Close() 96 97 // Send a message to the daemon. 98 payload, _ := json.Marshal(MessagePayload{Channel: "slack", Text: "hi", ThreadID: "t1"}) 99 conn.WriteJSON(ServerMessage{Type: MsgTypeMessage, MessageID: "msg-001", Payload: payload}) 100 101 // Read the claim. 102 conn.ReadJSON(&receivedClaim) 103 104 // Grant the claim. 105 ackPayload, _ := json.Marshal(ClaimAckPayload{Granted: true}) 106 conn.WriteJSON(ServerMessage{Type: MsgTypeClaimAck, MessageID: "msg-001", Payload: ackPayload}) 107 108 // Read messages until we get a reply (may get progress first). 
109 for { 110 var reply DaemonMessage 111 if err := conn.ReadJSON(&reply); err != nil { 112 return 113 } 114 if reply.Type == MsgTypeReply { 115 var rp ReplyPayload 116 json.Unmarshal(reply.Payload, &rp) 117 if rp.Text != "agent-result" { 118 t.Errorf("reply text = %q, want %q", rp.Text, "agent-result") 119 } 120 return 121 } 122 } 123 })) 124 defer srv.Close() 125 126 wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") 127 onMsgCalled := make(chan struct{}) 128 c := NewClient(wsURL, "", func(msg MessagePayload) string { 129 close(onMsgCalled) 130 return "agent-result" 131 }, nil) 132 133 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) 134 defer cancel() 135 136 if err := c.Connect(ctx); err != nil { 137 t.Fatal(err) 138 } 139 defer c.Close() 140 141 go c.Listen(ctx) 142 143 select { 144 case <-onMsgCalled: 145 case <-ctx.Done(): 146 t.Fatal("onMsg was never called") 147 } 148 149 if receivedClaim.Type != MsgTypeClaim || receivedClaim.MessageID != "msg-001" { 150 t.Errorf("expected claim for msg-001, got %+v", receivedClaim) 151 } 152 } 153 154 func TestClient_ClaimHandshake_Denied(t *testing.T) { 155 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 156 upgrader := websocket.Upgrader{} 157 conn, err := upgrader.Upgrade(w, r, nil) 158 if err != nil { 159 return 160 } 161 defer conn.Close() 162 163 payload, _ := json.Marshal(MessagePayload{Channel: "slack", Text: "hi"}) 164 conn.WriteJSON(ServerMessage{Type: MsgTypeMessage, MessageID: "msg-002", Payload: payload}) 165 166 var dm DaemonMessage 167 conn.ReadJSON(&dm) 168 169 ackPayload, _ := json.Marshal(ClaimAckPayload{Granted: false}) 170 conn.WriteJSON(ServerMessage{Type: MsgTypeClaimAck, MessageID: "msg-002", Payload: ackPayload}) 171 172 // Keep connection open briefly so the client can process the denial. 
173 time.Sleep(500 * time.Millisecond) 174 })) 175 defer srv.Close() 176 177 wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") 178 onMsgCalled := false 179 c := NewClient(wsURL, "", func(msg MessagePayload) string { 180 onMsgCalled = true 181 return "" 182 }, nil) 183 184 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 185 defer cancel() 186 if err := c.Connect(ctx); err != nil { 187 t.Fatal(err) 188 } 189 defer c.Close() 190 191 go c.Listen(ctx) 192 time.Sleep(500 * time.Millisecond) 193 cancel() 194 195 if onMsgCalled { 196 t.Error("onMsg should NOT be called when claim is denied") 197 } 198 } 199 200 func TestClient_GracefulDisconnect(t *testing.T) { 201 msgs := make(chan DaemonMessage, 10) 202 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 203 upgrader := websocket.Upgrader{} 204 conn, _ := upgrader.Upgrade(w, r, nil) 205 defer conn.Close() 206 for { 207 var dm DaemonMessage 208 if err := conn.ReadJSON(&dm); err != nil { 209 return 210 } 211 msgs <- dm 212 } 213 })) 214 defer srv.Close() 215 216 wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") 217 c := NewClient(wsURL, "", func(msg MessagePayload) string { return "" }, nil) 218 219 ctx, cancel := context.WithCancel(context.Background()) 220 if err := c.Connect(ctx); err != nil { 221 t.Fatal(err) 222 } 223 go c.Listen(ctx) 224 time.Sleep(100 * time.Millisecond) 225 226 cancel() 227 time.Sleep(200 * time.Millisecond) 228 229 // Check if disconnect was the last message 230 var lastMsg DaemonMessage 231 for { 232 select { 233 case m := <-msgs: 234 lastMsg = m 235 default: 236 goto done 237 } 238 } 239 done: 240 if lastMsg.Type != MsgTypeDisconnect { 241 t.Errorf("expected disconnect message, got type=%q", lastMsg.Type) 242 } 243 } 244 245 func TestClient_SystemMessage(t *testing.T) { 246 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 247 upgrader := websocket.Upgrader{} 248 conn, err := 
upgrader.Upgrade(w, r, nil) 249 if err != nil { 250 return 251 } 252 defer conn.Close() 253 payload := json.RawMessage(`"Quota warning: 90% used"`) 254 conn.WriteJSON(ServerMessage{Type: MsgTypeSystem, Payload: payload}) 255 time.Sleep(500 * time.Millisecond) 256 })) 257 defer srv.Close() 258 259 wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") 260 systemCh := make(chan string, 1) 261 c := NewClient(wsURL, "", func(msg MessagePayload) string { return "" }, func(text string) { 262 systemCh <- text 263 }) 264 265 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 266 defer cancel() 267 if err := c.Connect(ctx); err != nil { 268 t.Fatal(err) 269 } 270 defer c.Close() 271 272 go c.Listen(ctx) 273 274 select { 275 case msg := <-systemCh: 276 if msg != "Quota warning: 90% used" { 277 t.Errorf("system message = %q, want %q", msg, "Quota warning: 90% used") 278 } 279 case <-time.After(2 * time.Second): 280 t.Fatal("system message not received") 281 } 282 } 283 284 func TestClient_SendEvent_WireFormat(t *testing.T) { 285 received := make(chan DaemonMessage, 1) 286 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 287 upgrader := websocket.Upgrader{} 288 conn, err := upgrader.Upgrade(w, r, nil) 289 if err != nil { 290 return 291 } 292 defer conn.Close() 293 var dm DaemonMessage 294 if err := conn.ReadJSON(&dm); err != nil { 295 return 296 } 297 received <- dm 298 })) 299 defer srv.Close() 300 301 wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") 302 c := NewClient(wsURL, "", nil, nil) 303 if err := c.Connect(context.Background()); err != nil { 304 t.Fatal(err) 305 } 306 defer c.Close() 307 308 if err := c.SendEvent("msg-100", "tool_start", "running web_search", map[string]interface{}{"tool": "web_search"}); err != nil { 309 t.Fatal(err) 310 } 311 312 select { 313 case dm := <-received: 314 if dm.Type != MsgTypeEvent { 315 t.Errorf("type = %q, want %q", dm.Type, MsgTypeEvent) 316 } 317 if dm.MessageID != 
"msg-100" { 318 t.Errorf("message_id = %q, want %q", dm.MessageID, "msg-100") 319 } 320 var ep DaemonEventPayload 321 if err := json.Unmarshal(dm.Payload, &ep); err != nil { 322 t.Fatalf("unmarshal payload: %v", err) 323 } 324 if ep.EventType != "tool_start" { 325 t.Errorf("event_type = %q, want %q", ep.EventType, "tool_start") 326 } 327 if ep.Seq != 1 { 328 t.Errorf("seq = %d, want 1", ep.Seq) 329 } 330 if ep.Timestamp == "" { 331 t.Error("timestamp should not be empty") 332 } 333 case <-time.After(2 * time.Second): 334 t.Fatal("server did not receive event") 335 } 336 } 337 338 func TestClient_SendEvent_SeqIncrementsPerMessage(t *testing.T) { 339 received := make(chan DaemonMessage, 3) 340 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 341 upgrader := websocket.Upgrader{} 342 conn, err := upgrader.Upgrade(w, r, nil) 343 if err != nil { 344 return 345 } 346 defer conn.Close() 347 for i := 0; i < 3; i++ { 348 var dm DaemonMessage 349 if err := conn.ReadJSON(&dm); err != nil { 350 return 351 } 352 received <- dm 353 } 354 })) 355 defer srv.Close() 356 357 wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") 358 c := NewClient(wsURL, "", nil, nil) 359 if err := c.Connect(context.Background()); err != nil { 360 t.Fatal(err) 361 } 362 defer c.Close() 363 364 // Send 2 events for msg-A, 1 for msg-B. 
365 c.SendEvent("msg-A", "step", "one", nil) 366 c.SendEvent("msg-A", "step", "two", nil) 367 c.SendEvent("msg-B", "step", "one", nil) 368 369 seqs := make(map[string][]int64) 370 for i := 0; i < 3; i++ { 371 select { 372 case dm := <-received: 373 var ep DaemonEventPayload 374 json.Unmarshal(dm.Payload, &ep) 375 seqs[dm.MessageID] = append(seqs[dm.MessageID], ep.Seq) 376 case <-time.After(2 * time.Second): 377 t.Fatal("timeout waiting for events") 378 } 379 } 380 381 if got := seqs["msg-A"]; len(got) != 2 || got[0] != 1 || got[1] != 2 { 382 t.Errorf("msg-A seqs = %v, want [1 2]", got) 383 } 384 if got := seqs["msg-B"]; len(got) != 1 || got[0] != 1 { 385 t.Errorf("msg-B seqs = %v, want [1]", got) 386 } 387 } 388 389 func TestClient_SendReply_CleansUpSeq(t *testing.T) { 390 c := NewClient("ws://localhost:1/x", "", nil, nil) 391 // Pre-populate a seq counter. 392 c.eventSeqs.Store("msg-cleanup", new(atomic.Int64)) 393 394 // SendReply will fail (no connection) but should still clean up eventSeqs. 395 _ = c.SendReply("msg-cleanup", ReplyPayload{Text: "done"}) 396 397 if _, loaded := c.eventSeqs.Load("msg-cleanup"); loaded { 398 t.Error("eventSeqs entry should have been deleted by SendReply") 399 } 400 }