/ internal / daemon / client.go
client.go
  1  package daemon
  2  
  3  import (
  4  	"context"
  5  	"encoding/json"
  6  	"fmt"
  7  	"log"
  8  	"net/http"
  9  	"sync"
 10  	"sync/atomic"
 11  	"time"
 12  
 13  	"github.com/gorilla/websocket"
 14  )
 15  
 16  // MaxConcurrentAgents limits how many agent loops can run simultaneously.
 17  const MaxConcurrentAgents = 5
 18  
 19  type Client struct {
 20  	endpoint      string
 21  	apiKey        string
 22  	conn          *websocket.Conn
 23  	writeMu       sync.Mutex
 24  	onMsg         func(MessagePayload) string // returns reply text
 25  	onSystem      func(string)                // system notifications
 26  	sem           chan struct{}
 27  	pendingClaims sync.Map   // map[string]chan bool
 28  	activeMsgs    sync.Map   // map[string]context.CancelFunc
 29  	eventSeqs     sync.Map   // map[string]*atomic.Int64
 30  	connected     atomic.Bool
 31  	activeAgent   atomic.Value // stores string
 32  	startTime     time.Time
 33  	broker        *ApprovalBroker
 34  	eventBus      *EventBus
 35  }
 36  
 37  // SetEventBus sets the event bus for emitting daemon events.
 38  func (c *Client) SetEventBus(bus *EventBus) {
 39  	c.eventBus = bus
 40  }
 41  
 42  func NewClient(endpoint, apiKey string, onMsg func(MessagePayload) string, onSystem func(string)) *Client {
 43  	return &Client{
 44  		endpoint:  endpoint,
 45  		apiKey:    apiKey,
 46  		onMsg:     onMsg,
 47  		onSystem:  onSystem,
 48  		sem:       make(chan struct{}, MaxConcurrentAgents),
 49  		startTime: time.Now(),
 50  	}
 51  }
 52  
 53  func (c *Client) Connect(ctx context.Context) error {
 54  	header := http.Header{}
 55  	header.Set("Authorization", "Bearer "+c.apiKey)
 56  	dialer := websocket.Dialer{
 57  		HandshakeTimeout: 10 * time.Second,
 58  	}
 59  	conn, _, err := dialer.DialContext(ctx, c.endpoint, header)
 60  	if err != nil {
 61  		return fmt.Errorf("websocket connect: %w", err)
 62  	}
 63  	c.conn = conn
 64  	return nil
 65  }
 66  
 67  // IsConnected reports whether the client has an active WebSocket connection.
 68  func (c *Client) IsConnected() bool {
 69  	return c.connected.Load()
 70  }
 71  
 72  // ActiveAgent returns the name of the agent currently processing a message,
 73  // or "" if idle.
 74  func (c *Client) ActiveAgent() string {
 75  	if v := c.activeAgent.Load(); v != nil {
 76  		return v.(string)
 77  	}
 78  	return ""
 79  }
 80  
 81  // Uptime returns how long since the client was created.
 82  func (c *Client) Uptime() time.Duration {
 83  	return time.Since(c.startTime)
 84  }
 85  
 86  func (c *Client) sendEnvelope(dm DaemonMessage) error {
 87  	if c.conn == nil {
 88  		return fmt.Errorf("not connected")
 89  	}
 90  	data, err := json.Marshal(dm)
 91  	if err != nil {
 92  		return err
 93  	}
 94  	c.writeMu.Lock()
 95  	defer c.writeMu.Unlock()
 96  	_ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
 97  	return c.conn.WriteMessage(websocket.TextMessage, data)
 98  }
 99  
100  func (c *Client) sendClaim(messageID string) error {
101  	return c.sendEnvelope(DaemonMessage{Type: MsgTypeClaim, MessageID: messageID})
102  }
103  
104  func (c *Client) sendProgress(messageID string) error {
105  	return c.sendEnvelope(DaemonMessage{Type: MsgTypeProgress, MessageID: messageID})
106  }
107  
108  // SendProgressWithWorkflow sends a progress heartbeat with a workflow_id payload.
109  // This tells Cloud to start streaming card replies for the originating channel.
110  func (c *Client) SendProgressWithWorkflow(messageID, workflowID string) error {
111  	payload, _ := json.Marshal(map[string]string{"workflow_id": workflowID})
112  	return c.sendEnvelope(DaemonMessage{Type: MsgTypeProgress, MessageID: messageID, Payload: payload})
113  }
114  
115  // SendEvent sends a daemon agent loop event to Cloud for channel streaming.
116  // Fire-and-forget: errors are returned but callers should log and continue.
117  func (c *Client) SendEvent(messageID string, eventType, message string, data map[string]interface{}) error {
118  	val, _ := c.eventSeqs.LoadOrStore(messageID, new(atomic.Int64))
119  	seq := val.(*atomic.Int64).Add(1)
120  
121  	payload, err := json.Marshal(DaemonEventPayload{
122  		EventType: eventType,
123  		Message:   message,
124  		Data:      data,
125  		Seq:       seq,
126  		Timestamp: time.Now().UTC().Format(time.RFC3339),
127  	})
128  	if err != nil {
129  		return err
130  	}
131  	return c.sendEnvelope(DaemonMessage{
132  		Type:      MsgTypeEvent,
133  		MessageID: messageID,
134  		Payload:   payload,
135  	})
136  }
137  
138  // SendReply sends the final reply for a message and cancels its heartbeat.
139  func (c *Client) SendReply(messageID string, payload ReplyPayload) error {
140  	c.eventSeqs.Delete(messageID)
141  	if cancel, ok := c.activeMsgs.LoadAndDelete(messageID); ok {
142  		cancel.(context.CancelFunc)()
143  	}
144  	payloadBytes, err := json.Marshal(payload)
145  	if err != nil {
146  		return err
147  	}
148  	return c.sendEnvelope(DaemonMessage{Type: MsgTypeReply, MessageID: messageID, Payload: payloadBytes})
149  }
150  
151  // SendProactive sends an unsolicited message to all channels mapped to the agent.
152  // This is fire-and-forget — no claim/ack cycle.
153  func (c *Client) SendProactive(agentName, text, sessionID string) error {
154  	if agentName == "" || text == "" {
155  		return nil
156  	}
157  	payload, err := json.Marshal(ProactivePayload{
158  		AgentName: agentName,
159  		Text:      text,
160  		Format:    FormatText,
161  		SessionID: sessionID,
162  	})
163  	if err != nil {
164  		return fmt.Errorf("marshal proactive payload: %w", err)
165  	}
166  	return c.sendEnvelope(DaemonMessage{
167  		Type:    MsgTypeProactive,
168  		Payload: payload,
169  	})
170  }
171  
172  func (c *Client) sendDisconnect() error {
173  	return c.sendEnvelope(DaemonMessage{Type: MsgTypeDisconnect})
174  }
175  
176  // Close sends a disconnect message and closes the WebSocket connection.
177  func (c *Client) Close() error {
178  	if c.conn == nil {
179  		return nil
180  	}
181  	_ = c.sendDisconnect()
182  	return c.conn.Close()
183  }
184  
185  // SetApprovalBroker sets the broker for interactive tool approval.
186  func (c *Client) SetApprovalBroker(b *ApprovalBroker) {
187  	c.broker = b
188  }
189  
190  // SendApprovalRequest sends an approval_request message over WS.
191  func (c *Client) SendApprovalRequest(req ApprovalRequest) error {
192  	payload, err := json.Marshal(req)
193  	if err != nil {
194  		return err
195  	}
196  	return c.sendEnvelope(DaemonMessage{Type: MsgTypeApprovalRequest, Payload: payload})
197  }
198  
199  // SendApprovalResolved sends an approval_resolved message over WS to Cloud.
200  func (c *Client) SendApprovalResolved(p ApprovalResolvedPayload) error {
201  	payload, err := json.Marshal(p)
202  	if err != nil {
203  		return err
204  	}
205  	return c.sendEnvelope(DaemonMessage{
206  		Type:    MsgTypeApprovalResolved,
207  		Payload: payload,
208  	})
209  }
210  
211  // Listen reads messages from the WebSocket and dispatches them.
212  // It blocks until the context is cancelled or the connection drops.
213  func (c *Client) Listen(ctx context.Context) error {
214  	if c.conn == nil {
215  		return fmt.Errorf("not connected")
216  	}
217  	c.connected.Store(true)
218  	defer func() {
219  		c.connected.Store(false)
220  		if c.broker != nil {
221  			c.broker.CancelAll()
222  		}
223  		c.conn.Close()
224  	}()
225  
226  	go func() {
227  		<-ctx.Done()
228  		_ = c.sendDisconnect()
229  		c.conn.Close()
230  	}()
231  
232  	for {
233  		_, data, err := c.conn.ReadMessage()
234  		if err != nil {
235  			if ctx.Err() != nil {
236  				return ctx.Err()
237  			}
238  			return fmt.Errorf("read: %w", err)
239  		}
240  
241  		var sm ServerMessage
242  		if err := json.Unmarshal(data, &sm); err != nil {
243  			log.Printf("daemon: invalid message: %v", err)
244  			continue
245  		}
246  
247  		switch sm.Type {
248  		case MsgTypeConnected:
249  			log.Println("daemon: connected to Shannon Cloud")
250  		case MsgTypeMessage:
251  			go c.handleMessage(ctx, sm)
252  		case MsgTypeClaimAck:
253  			if ch, ok := c.pendingClaims.Load(sm.MessageID); ok {
254  				var ack ClaimAckPayload
255  				if err := json.Unmarshal(sm.Payload, &ack); err == nil {
256  					select {
257  					case ch.(chan bool) <- ack.Granted:
258  					default:
259  					}
260  				}
261  			}
262  		case MsgTypeApprovalResponse:
263  			var resp ApprovalResponse
264  			if err := json.Unmarshal(sm.Payload, &resp); err != nil {
265  				log.Printf("daemon: invalid approval_response: %v", err)
266  				continue
267  			}
268  			// Emit before Resolve so Ptfrog dismisses the card before seeing the reply.
269  			resolvedBy := resp.ResolvedBy
270  			if resolvedBy == "" {
271  				resolvedBy = "external"
272  			}
273  			if c.eventBus != nil {
274  				payload, _ := json.Marshal(map[string]string{
275  					"request_id":  resp.RequestID,
276  					"decision":    string(resp.Decision),
277  					"resolved_by": resolvedBy,
278  				})
279  				c.eventBus.Emit(Event{Type: EventApprovalResolved, Payload: payload})
280  			}
281  			if c.broker != nil {
282  				c.broker.Resolve(resp.RequestID, resp.Decision)
283  			}
284  		case MsgTypeSystem:
285  			if c.onSystem != nil {
286  				var text string
287  				if err := json.Unmarshal(sm.Payload, &text); err == nil {
288  					c.onSystem(text)
289  				}
290  			}
291  		default:
292  			log.Printf("daemon: unknown message type: %s", sm.Type)
293  		}
294  	}
295  }
296  
297  func (c *Client) handleMessage(ctx context.Context, sm ServerMessage) {
298  	var payload MessagePayload
299  	if err := json.Unmarshal(sm.Payload, &payload); err != nil {
300  		log.Printf("daemon: invalid message payload: %v", err)
301  		return
302  	}
303  
304  	// Acquire semaphore for bounded concurrency with context check.
305  	select {
306  	case c.sem <- struct{}{}:
307  	case <-ctx.Done():
308  		log.Printf("daemon: context cancelled waiting for semaphore (message %s)", sm.MessageID)
309  		return
310  	}
311  	defer func() { <-c.sem }()
312  
313  	// Send claim.
314  	claimCh := make(chan bool, 1)
315  	c.pendingClaims.Store(sm.MessageID, claimCh)
316  	defer c.pendingClaims.Delete(sm.MessageID)
317  
318  	if err := c.sendClaim(sm.MessageID); err != nil {
319  		log.Printf("daemon: failed to send claim: %v", err)
320  		return
321  	}
322  
323  	// Wait for claim ack with 5s timeout.
324  	select {
325  	case granted := <-claimCh:
326  		if !granted {
327  			log.Printf("daemon: claim denied for %s", sm.MessageID)
328  			return
329  		}
330  	case <-time.After(5 * time.Second):
331  		log.Printf("daemon: claim timeout for %s", sm.MessageID)
332  		return
333  	case <-ctx.Done():
334  		return
335  	}
336  
337  	// Start heartbeat.
338  	heartbeatCtx, heartbeatCancel := context.WithCancel(ctx)
339  	c.activeMsgs.Store(sm.MessageID, heartbeatCancel)
340  	go func() {
341  		ticker := time.NewTicker(15 * time.Second)
342  		defer ticker.Stop()
343  		for {
344  			select {
345  			case <-heartbeatCtx.Done():
346  				return
347  			case <-ticker.C:
348  				_ = c.sendProgress(sm.MessageID)
349  			}
350  		}
351  	}()
352  
353  	// Attach envelope messageID so downstream tools can reference it.
354  	payload.MessageID = sm.MessageID
355  
356  	// Set active agent.
357  	agentName := payload.AgentName
358  	if agentName == "" {
359  		agentName = "(default)"
360  	}
361  	c.activeAgent.Store(agentName)
362  
363  	// Run agent callback.
364  	result := c.onMsg(payload)
365  
366  	// Cleanup.
367  	c.activeAgent.Store("")
368  	heartbeatCancel()
369  	c.activeMsgs.Delete(sm.MessageID)
370  
371  	// Send reply.
372  	if err := c.SendReply(sm.MessageID, ReplyPayload{
373  		Channel:  payload.Channel,
374  		ThreadID: payload.ThreadID,
375  		Text:     result,
376  		Format:   FormatText,
377  	}); err != nil {
378  		log.Printf("daemon: SendReply failed for message %s: %v", sm.MessageID, err)
379  		if c.eventBus != nil {
380  			// Match the source fallback applied at the WS callback entry point
381  			// (cmd/daemon.go) so consumers see a consistent source field during
382  			// Cloud rolling deploys where msg.Source may be empty.
383  			source := payload.Source
384  			if source == "" {
385  				source = payload.Channel
386  			}
387  			// Use the raw payload.AgentName (not the rewritten "(default)"
388  			// display value used by c.activeAgent) so consumers that route
389  			// on the agent identifier — including the desktop matcher that
390  			// expects "" / "default" for the default agent — see the wire
391  			// form rather than the local display string.
392  			errPayload, _ := json.Marshal(map[string]any{
393  				"agent":      payload.AgentName,
394  				"message_id": sm.MessageID,
395  				"source":     source,
396  				"error":      fmt.Sprintf("reply delivery failed: %v", err),
397  			})
398  			c.eventBus.Emit(Event{Type: EventAgentError, Payload: errPayload})
399  		}
400  	}
401  }
402  
403  // RunWithReconnect connects to the server and reconnects on failure with
404  // exponential backoff. It blocks until the context is cancelled.
405  func (c *Client) RunWithReconnect(ctx context.Context) {
406  	backoff := time.Second
407  	maxBackoff := 30 * time.Second
408  	for {
409  		select {
410  		case <-ctx.Done():
411  			return
412  		default:
413  		}
414  		if err := c.Connect(ctx); err != nil {
415  			if ctx.Err() != nil {
416  				return
417  			}
418  			log.Printf("daemon: connect failed: %v (retry in %v)", err, backoff)
419  			select {
420  			case <-ctx.Done():
421  				return
422  			case <-time.After(backoff):
423  			}
424  			backoff = min(backoff*2, maxBackoff)
425  			continue
426  		}
427  		backoff = time.Second
428  		if err := c.Listen(ctx); err != nil {
429  			if ctx.Err() != nil {
430  				return
431  			}
432  			log.Printf("daemon: connection lost: %v (reconnecting)", err)
433  		}
434  	}
435  }