client.go
1 package daemon 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "log" 8 "net/http" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "github.com/gorilla/websocket" 14 ) 15 16 // MaxConcurrentAgents limits how many agent loops can run simultaneously. 17 const MaxConcurrentAgents = 5 18 19 type Client struct { 20 endpoint string 21 apiKey string 22 conn *websocket.Conn 23 writeMu sync.Mutex 24 onMsg func(MessagePayload) string // returns reply text 25 onSystem func(string) // system notifications 26 sem chan struct{} 27 pendingClaims sync.Map // map[string]chan bool 28 activeMsgs sync.Map // map[string]context.CancelFunc 29 eventSeqs sync.Map // map[string]*atomic.Int64 30 connected atomic.Bool 31 activeAgent atomic.Value // stores string 32 startTime time.Time 33 broker *ApprovalBroker 34 eventBus *EventBus 35 } 36 37 // SetEventBus sets the event bus for emitting daemon events. 38 func (c *Client) SetEventBus(bus *EventBus) { 39 c.eventBus = bus 40 } 41 42 func NewClient(endpoint, apiKey string, onMsg func(MessagePayload) string, onSystem func(string)) *Client { 43 return &Client{ 44 endpoint: endpoint, 45 apiKey: apiKey, 46 onMsg: onMsg, 47 onSystem: onSystem, 48 sem: make(chan struct{}, MaxConcurrentAgents), 49 startTime: time.Now(), 50 } 51 } 52 53 func (c *Client) Connect(ctx context.Context) error { 54 header := http.Header{} 55 header.Set("Authorization", "Bearer "+c.apiKey) 56 dialer := websocket.Dialer{ 57 HandshakeTimeout: 10 * time.Second, 58 } 59 conn, _, err := dialer.DialContext(ctx, c.endpoint, header) 60 if err != nil { 61 return fmt.Errorf("websocket connect: %w", err) 62 } 63 c.conn = conn 64 return nil 65 } 66 67 // IsConnected reports whether the client has an active WebSocket connection. 68 func (c *Client) IsConnected() bool { 69 return c.connected.Load() 70 } 71 72 // ActiveAgent returns the name of the agent currently processing a message, 73 // or "" if idle. 
func (c *Client) ActiveAgent() string {
	// An unset Value loads as nil; the comma-ok assertion turns that into "".
	name, _ := c.activeAgent.Load().(string)
	return name
}

// Uptime reports the time elapsed since the client was created.
func (c *Client) Uptime() time.Duration {
	return time.Now().Sub(c.startTime)
}

// sendEnvelope JSON-encodes dm and writes it as one text frame. Encoding
// happens outside the lock; the write itself runs under writeMu so concurrent
// senders never interleave frames, with a 10-second write deadline.
func (c *Client) sendEnvelope(dm DaemonMessage) error {
	if c.conn == nil {
		return fmt.Errorf("not connected")
	}
	frame, err := json.Marshal(dm)
	if err != nil {
		return err
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	// The deadline error is deliberately ignored; the write's own error is
	// what gets reported to the caller.
	_ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
	return c.conn.WriteMessage(websocket.TextMessage, frame)
}

// sendClaim asks the server for ownership of messageID. The answer comes back
// as a MsgTypeClaimAck message, which Listen routes to the waiting handler.
func (c *Client) sendClaim(messageID string) error {
	claim := DaemonMessage{Type: MsgTypeClaim, MessageID: messageID}
	return c.sendEnvelope(claim)
}

// sendProgress emits a payload-less heartbeat for messageID.
func (c *Client) sendProgress(messageID string) error {
	beat := DaemonMessage{Type: MsgTypeProgress, MessageID: messageID}
	return c.sendEnvelope(beat)
}

// SendProgressWithWorkflow sends a progress heartbeat for messageID whose
// payload is {"workflow_id": workflowID}. This tells Cloud to start streaming
// card replies for the originating channel.
func (c *Client) SendProgressWithWorkflow(messageID, workflowID string) error {
	// Marshaling a map[string]string cannot fail, so the error is dropped.
	body, _ := json.Marshal(map[string]string{"workflow_id": workflowID})
	return c.sendEnvelope(DaemonMessage{
		Type:      MsgTypeProgress,
		MessageID: messageID,
		Payload:   body,
	})
}

// SendEvent streams one agent-loop event for messageID to Cloud, stamped with
// a per-message sequence number (starting at 1) and a UTC RFC 3339 timestamp.
// It is fire-and-forget: callers should log a returned error and carry on.
func (c *Client) SendEvent(messageID string, eventType, message string, data map[string]interface{}) error {
	// One counter per message ID; SendReply discards it once the reply goes out.
	counter, _ := c.eventSeqs.LoadOrStore(messageID, new(atomic.Int64))
	ev := DaemonEventPayload{
		EventType: eventType,
		Message:   message,
		Data:      data,
		Seq:       counter.(*atomic.Int64).Add(1),
		Timestamp: time.Now().UTC().Format(time.RFC3339),
	}
	body, err := json.Marshal(ev)
	if err != nil {
		return err
	}
	envelope := DaemonMessage{Type: MsgTypeEvent, MessageID: messageID, Payload: body}
	return c.sendEnvelope(envelope)
}

// SendReply delivers the final reply for messageID. It first discards the
// message's event sequence counter and stops its heartbeat if one is still
// registered, then sends the reply envelope.
func (c *Client) SendReply(messageID string, payload ReplyPayload) error {
	c.eventSeqs.Delete(messageID)
	if v, found := c.activeMsgs.LoadAndDelete(messageID); found {
		stopHeartbeat := v.(context.CancelFunc)
		stopHeartbeat()
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	envelope := DaemonMessage{Type: MsgTypeReply, MessageID: messageID, Payload: body}
	return c.sendEnvelope(envelope)
}

// SendProactive pushes an unsolicited message from agentName to all channels
// mapped to that agent, with no claim/ack cycle. An empty agentName or text
// makes it a silent no-op that returns nil.
func (c *Client) SendProactive(agentName, text, sessionID string) error {
	if agentName == "" || text == "" {
		return nil
	}

	body, err := json.Marshal(ProactivePayload{
		AgentName: agentName,
		Text:      text,
		Format:    FormatText,
		SessionID: sessionID,
	})
	if err != nil {
		return fmt.Errorf("marshal proactive payload: %w", err)
	}
	envelope := DaemonMessage{Type: MsgTypeProactive, Payload: body}
	return c.sendEnvelope(envelope)
}

// sendDisconnect tells the server this daemon is going away.
func (c *Client) sendDisconnect() error {
	bye := DaemonMessage{Type: MsgTypeDisconnect}
	return c.sendEnvelope(bye)
}

// Close makes a best-effort attempt to send a disconnect message, then closes
// the WebSocket. It does nothing if no connection was ever established.
func (c *Client) Close() error {
	if c.conn != nil {
		// The disconnect notice is advisory; the socket is closed regardless.
		_ = c.sendDisconnect()
		return c.conn.Close()
	}
	return nil
}

// SetApprovalBroker installs the broker used for interactive tool approval.
// Listen resolves incoming approval responses through it and cancels all of
// its pending requests when the read loop exits. A nil broker disables both.
func (c *Client) SetApprovalBroker(b *ApprovalBroker) {
	c.broker = b
}

// SendApprovalRequest forwards an approval request to Cloud over the WebSocket.
func (c *Client) SendApprovalRequest(req ApprovalRequest) error {
	body, err := json.Marshal(req)
	if err != nil {
		return err
	}
	envelope := DaemonMessage{Type: MsgTypeApprovalRequest, Payload: body}
	return c.sendEnvelope(envelope)
}

// SendApprovalResolved notifies Cloud over the WebSocket that an approval
// request has been resolved.
func (c *Client) SendApprovalResolved(p ApprovalResolvedPayload) error {
	body, err := json.Marshal(p)
	if err != nil {
		return err
	}
	envelope := DaemonMessage{Type: MsgTypeApprovalResolved, Payload: body}
	return c.sendEnvelope(envelope)
}

// Listen runs the read loop on the current connection and dispatches each
// server message by type. It blocks until ctx is cancelled or the connection
// fails, and it always returns a non-nil error.
func (c *Client) Listen(ctx context.Context) error {
	// Bind to the connection this call reads from, so the cleanup and the ctx
	// watcher below close this socket even if c.conn is replaced later.
	conn := c.conn
	if conn == nil {
		return fmt.Errorf("not connected")
	}
	c.connected.Store(true)

	// done is closed when Listen returns, and the ctx watcher selects on it.
	// The watcher therefore exits together with this connection. Previously it
	// blocked until ctx was cancelled, which leaked one goroutine per reconnect;
	// at shutdown, all of the leaked watchers fired against the newest connection.
	done := make(chan struct{})
	defer func() {
		close(done)
		c.connected.Store(false)
		if c.broker != nil {
			c.broker.CancelAll()
		}
		conn.Close()
	}()

	// On cancellation, say goodbye (best effort) and close the socket so the
	// blocking ReadMessage below returns.
	go func() {
		select {
		case <-ctx.Done():
			_ = c.sendDisconnect()
			conn.Close()
		case <-done:
		}
	}()

	for {
		_, data, err := conn.ReadMessage()
		if err != nil {
			if ctx.Err() != nil {
				return ctx.Err()
			}
			return fmt.Errorf("read: %w", err)
		}

		var sm ServerMessage
		if err := json.Unmarshal(data, &sm); err != nil {
			log.Printf("daemon: invalid message: %v", err)
			continue
		}

		switch sm.Type {
		case MsgTypeConnected:
			log.Println("daemon: connected to Shannon Cloud")
		case MsgTypeMessage:
			go c.handleMessage(ctx, sm)
		case MsgTypeClaimAck:
			// The waiting handler's channel has a one-slot buffer; an ack that
			// finds it full is dropped rather than stalling the read loop.
			if ch, ok := c.pendingClaims.Load(sm.MessageID); ok {
				var ack ClaimAckPayload
				if err := json.Unmarshal(sm.Payload, &ack); err == nil {
					select {
					case ch.(chan bool) <- ack.Granted:
					default:
					}
				}
			}
		case MsgTypeApprovalResponse:
			var resp ApprovalResponse
			if err := json.Unmarshal(sm.Payload, &resp); err != nil {
				log.Printf("daemon: invalid approval_response: %v", err)
				continue
			}
			// Emit before Resolve so Ptfrog dismisses the card before seeing the reply.
			resolvedBy := resp.ResolvedBy
			if resolvedBy == "" {
				resolvedBy = "external"
			}
			if c.eventBus != nil {
				// Marshaling a map[string]string cannot fail.
				payload, _ := json.Marshal(map[string]string{
					"request_id":  resp.RequestID,
					"decision":    string(resp.Decision),
					"resolved_by": resolvedBy,
				})
				c.eventBus.Emit(Event{Type: EventApprovalResolved, Payload: payload})
			}
			if c.broker != nil {
				c.broker.Resolve(resp.RequestID, resp.Decision)
			}
		case MsgTypeSystem:
			if c.onSystem != nil {
				var text string
				if err := json.Unmarshal(sm.Payload, &text); err == nil {
					c.onSystem(text)
				}
			}
		default:
			log.Printf("daemon: unknown message type: %s", sm.Type)
		}
	}
}

// handleMessage processes one server message end to end:
//  1. decode the payload;
//  2. wait for a concurrency slot;
//  3. claim the message and wait up to 5s for the ack;
//  4. send a heartbeat every 15s while onMsg runs;
//  5. send onMsg's result as the reply.
//
// Failures before the claim is granted only produce a log line. A failed
// reply delivery is logged and, when an event bus is set, emitted as
// EventAgentError.
func (c *Client) handleMessage(ctx context.Context, sm ServerMessage) {
	var payload MessagePayload
	if err := json.Unmarshal(sm.Payload, &payload); err != nil {
		log.Printf("daemon: invalid message payload: %v", err)
		return
	}

	// Acquire semaphore for bounded concurrency with context check.
	select {
	case c.sem <- struct{}{}:
	case <-ctx.Done():
		log.Printf("daemon: context cancelled waiting for semaphore (message %s)", sm.MessageID)
		return
	}
	defer func() { <-c.sem }()

	// Register the ack channel before sending the claim so Listen always has
	// somewhere to deliver the ack.
	claimCh := make(chan bool, 1)
	c.pendingClaims.Store(sm.MessageID, claimCh)
	defer c.pendingClaims.Delete(sm.MessageID)

	if err := c.sendClaim(sm.MessageID); err != nil {
		log.Printf("daemon: failed to send claim: %v", err)
		return
	}

	// Wait for claim ack with 5s timeout.
	select {
	case granted := <-claimCh:
		if !granted {
			log.Printf("daemon: claim denied for %s", sm.MessageID)
			return
		}
	case <-time.After(5 * time.Second):
		log.Printf("daemon: claim timeout for %s", sm.MessageID)
		return
	case <-ctx.Done():
		return
	}

	// Heartbeat every 15s until the reply is ready. The cancel func is also
	// registered in activeMsgs so SendReply can stop it.
	heartbeatCtx, heartbeatCancel := context.WithCancel(ctx)
	c.activeMsgs.Store(sm.MessageID, heartbeatCancel)
	go func() {
		ticker := time.NewTicker(15 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-heartbeatCtx.Done():
				return
			case <-ticker.C:
				_ = c.sendProgress(sm.MessageID)
			}
		}
	}()

	// Attach envelope messageID so downstream tools can reference it.
	payload.MessageID = sm.MessageID

	// Set active agent.
	agentName := payload.AgentName
	if agentName == "" {
		agentName = "(default)"
	}
	c.activeAgent.Store(agentName)

	// Run agent callback.
	result := c.onMsg(payload)

	// Cleanup. Up to MaxConcurrentAgents messages run at once, so clear the
	// active agent only if it still shows ours. An unconditional Store("")
	// would wipe the name of another agent that started after this one.
	c.activeAgent.CompareAndSwap(agentName, "")
	heartbeatCancel()
	c.activeMsgs.Delete(sm.MessageID)

	// Send reply.
	if err := c.SendReply(sm.MessageID, ReplyPayload{
		Channel:  payload.Channel,
		ThreadID: payload.ThreadID,
		Text:     result,
		Format:   FormatText,
	}); err != nil {
		log.Printf("daemon: SendReply failed for message %s: %v", sm.MessageID, err)
		if c.eventBus != nil {
			// Match the source fallback applied at the WS callback entry point
			// (cmd/daemon.go) so consumers see a consistent source field during
			// Cloud rolling deploys where msg.Source may be empty.
			source := payload.Source
			if source == "" {
				source = payload.Channel
			}
			// Use the raw payload.AgentName (not the rewritten "(default)"
			// display value used by c.activeAgent) so consumers that route
			// on the agent identifier — including the desktop matcher that
			// expects "" / "default" for the default agent — see the wire
			// form rather than the local display string.
			errPayload, _ := json.Marshal(map[string]any{
				"agent":      payload.AgentName,
				"message_id": sm.MessageID,
				"source":     source,
				"error":      fmt.Sprintf("reply delivery failed: %v", err),
			})
			c.eventBus.Emit(Event{Type: EventAgentError, Payload: errPayload})
		}
	}
}

// RunWithReconnect keeps the client connected until ctx is cancelled. Both
// failed dials and dropped connections are retried with exponential backoff,
// starting at 1s and doubling up to 30s. The backoff resets only after a
// connection has stayed up for at least 30s. As a result, a server that
// accepts the socket and drops it immediately cannot drive a tight reconnect
// loop.
func (c *Client) RunWithReconnect(ctx context.Context) {
	const (
		minBackoff  = time.Second
		maxBackoff  = 30 * time.Second
		stableAfter = 30 * time.Second // connection lifetime that earns a backoff reset
	)
	backoff := minBackoff

	// wait sleeps for the current backoff and then doubles it (capped at
	// maxBackoff). It reports false if ctx was cancelled while waiting.
	wait := func() bool {
		select {
		case <-ctx.Done():
			return false
		case <-time.After(backoff):
		}
		backoff = min(backoff*2, maxBackoff)
		return true
	}

	for ctx.Err() == nil {
		if err := c.Connect(ctx); err != nil {
			if ctx.Err() != nil {
				return
			}
			log.Printf("daemon: connect failed: %v (retry in %v)", err, backoff)
			if !wait() {
				return
			}
			continue
		}

		connectedAt := time.Now()
		err := c.Listen(ctx)
		if ctx.Err() != nil {
			return
		}
		if time.Since(connectedAt) >= stableAfter {
			backoff = minBackoff
		}
		log.Printf("daemon: connection lost: %v (reconnecting in %v)", err, backoff)
		if !wait() {
			return
		}
	}
}