/ common / websocket / sse_manager.go
sse_manager.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package websocket
 20  
 21  import (
 22  	"encoding/json"
 23  	"fmt"
 24  	"net/http"
 25  	"sync"
 26  	"time"
 27  
 28  	"trpc.group/trpc-go/trpc-go/log"
 29  )
 30  
// SSEConnection represents a single Server-Sent Events (SSE) connection
// tracked by an SSEManager.
type SSEConnection struct {
	// SessionID identifies the session; it is the key under which the
	// connection is stored in SSEManager.connections.
	SessionID string
	// Username is the user name supplied when the connection was added.
	Username  string
	// Writer is the HTTP response writer that SSE frames are written to.
	Writer    http.ResponseWriter
	// Flusher is Writer asserted to http.Flusher; it is flushed after each
	// frame is written.
	Flusher   http.Flusher
	// CloseChan is closed by the manager to tell the heartbeat goroutine to stop.
	CloseChan chan bool
	// LastPing is set at creation and updated after every successful write
	// (event or heartbeat).
	LastPing  time.Time
}
 40  
// SSEManager manages SSE connections and pushes events to them.
type SSEManager struct {
	// connections holds the currently registered connection for each session.
	connections map[string]*SSEConnection // sessionId -> connection
	// mutex guards access to the connections map.
	mutex       sync.RWMutex
}
 46  
 47  // NewSSEManager 创建新的SSE管理器
 48  func NewSSEManager() *SSEManager {
 49  	return &SSEManager{
 50  		connections: make(map[string]*SSEConnection),
 51  	}
 52  }
 53  
 54  // AddConnection 添加新的SSE连接
 55  func (sm *SSEManager) AddConnection(sessionID, username string, w http.ResponseWriter) error {
 56  	sm.mutex.Lock()
 57  	defer sm.mutex.Unlock()
 58  
 59  	// 检查是否已存在相同sessionId的连接
 60  	if existing, exists := sm.connections[sessionID]; exists {
 61  		// 关闭现有连接
 62  		close(existing.CloseChan)
 63  		log.Infof("SSE连接冲突,关闭现有连接: sessionId=%s, username=%s", sessionID, username)
 64  	}
 65  
 66  	// 检查是否支持SSE
 67  	flusher, ok := w.(http.Flusher)
 68  	if !ok {
 69  		log.Errorf("SSE流式传输不支持: sessionId=%s, username=%s", sessionID, username)
 70  		return fmt.Errorf("streaming unsupported")
 71  	}
 72  
 73  	// 设置SSE响应头
 74  	w.Header().Set("Content-Type", "text/event-stream")
 75  	w.Header().Set("Cache-Control", "no-cache")
 76  	w.Header().Set("Connection", "keep-alive")
 77  	w.Header().Set("Access-Control-Allow-Origin", "*")
 78  	w.Header().Set("Access-Control-Allow-Headers", "Cache-Control")
 79  
 80  	// 创建连接
 81  	conn := &SSEConnection{
 82  		SessionID: sessionID,
 83  		Username:  username,
 84  		Writer:    w,
 85  		Flusher:   flusher,
 86  		CloseChan: make(chan bool),
 87  		LastPing:  time.Now(),
 88  	}
 89  
 90  	sm.connections[sessionID] = conn
 91  	log.Infof("SSE连接建立: sessionId=%s, username=%s, totalConnections=%d", sessionID, username, len(sm.connections))
 92  
 93  	// 发送连接成功消息
 94  	sm.sendEventToConnection(conn, "connected", "connected", map[string]interface{}{
 95  		"message":   "SSE连接已建立",
 96  		"sessionId": sessionID,
 97  	})
 98  
 99  	// 启动心跳和连接保持
100  	go sm.keepConnectionAlive(conn)
101  
102  	return nil
103  }
104  
105  // keepConnectionAlive 保持连接活跃
106  func (sm *SSEManager) keepConnectionAlive(conn *SSEConnection) {
107  	ticker := time.NewTicker(10 * time.Second) // 改为10秒心跳,提高频率
108  	defer ticker.Stop()
109  
110  	log.Debugf("SSE心跳启动: sessionId=%s, username=%s", conn.SessionID, conn.Username)
111  
112  	for {
113  		select {
114  		case <-conn.CloseChan:
115  			log.Infof("SSE连接已关闭: sessionId=%s", conn.SessionID)
116  			log.Infof("SSE连接关闭: sessionId=%s, username=%s", conn.SessionID, conn.Username)
117  			return
118  		case <-ticker.C:
119  			// 发送liveStatus心跳消息
120  			heartbeat := TaskEventMessage{
121  				ID:        fmt.Sprintf("heartbeat_%d", time.Now().Unix()),
122  				Type:      "liveStatus", // 改为liveStatus类型
123  				SessionID: conn.SessionID,
124  				Timestamp: time.Now().Unix(),
125  				Event: LiveStatusEvent{
126  					ID:        fmt.Sprintf("heartbeat_%d", time.Now().Unix()),
127  					Type:      "liveStatus",
128  					Timestamp: time.Now().UnixMilli(),
129  					Text:      "思考中...", // 默认状态文本
130  				},
131  			}
132  
133  			eventData, err := json.Marshal(heartbeat)
134  			if err != nil {
135  				log.Errorf("SSE心跳序列化失败: sessionId=%s, error=%v", conn.SessionID, err)
136  				continue
137  			}
138  
139  			_, err = fmt.Fprintf(conn.Writer, "data: %s\n\n", eventData)
140  			if err != nil {
141  				log.Errorf("SSE心跳发送失败: sessionId=%s, error=%v", conn.SessionID, err)
142  				sm.RemoveConnection(conn.SessionID)
143  				return
144  			}
145  
146  			conn.Flusher.Flush()
147  			conn.LastPing = time.Now()
148  			log.Debugf("SSE心跳发送成功: sessionId=%s", conn.SessionID)
149  		}
150  	}
151  }
152  
153  // RemoveConnection 移除SSE连接
154  func (sm *SSEManager) RemoveConnection(sessionID string) {
155  	sm.mutex.Lock()
156  	defer sm.mutex.Unlock()
157  
158  	if conn, exists := sm.connections[sessionID]; exists {
159  		close(conn.CloseChan)
160  		delete(sm.connections, sessionID)
161  		log.Infof("SSE连接移除: sessionId=%s, username=%s, remainingConnections=%d", sessionID, conn.Username, len(sm.connections))
162  	}
163  }
164  
165  // SendEvent 向指定会话发送事件
166  func (sm *SSEManager) SendEvent(id string, sessionID string, eventType string, event interface{}) error {
167  	sm.mutex.RLock()
168  	conn, exists := sm.connections[sessionID]
169  	sm.mutex.RUnlock()
170  
171  	if !exists {
172  		log.Warnf("SSE连接不存在,跳过事件推送: sessionId=%s, eventType=%s", sessionID, eventType)
173  		return fmt.Errorf("连接不存在: sessionId=%s", sessionID)
174  	}
175  
176  	log.Debugf("SSE事件推送: sessionId=%s, eventType=%s, eventId=%s", sessionID, eventType, id)
177  	return sm.sendEventToConnection(conn, id, eventType, event)
178  }
179  
180  // sendEventToConnection 向单个连接发送事件
181  func (sm *SSEManager) sendEventToConnection(conn *SSEConnection, id string, eventType string, event interface{}) error {
182  	// 创建事件消息
183  	eventMessage := TaskEventMessage{
184  		ID:        id,
185  		Type:      eventType,
186  		SessionID: conn.SessionID,
187  		Timestamp: time.Now().Unix(),
188  		Event:     event,
189  	}
190  
191  	// 序列化事件
192  	eventData, err := json.Marshal(eventMessage)
193  	if err != nil {
194  		log.Errorf("SSE事件序列化失败: sessionId=%s, eventType=%s, error=%v", conn.SessionID, eventType, err)
195  		return fmt.Errorf("序列化事件失败: %v", err)
196  	}
197  
198  	// 按照SSE规范发送消息
199  	// 格式: id: <id>\nevent: <event_type>\ndata: <json_data>\n\n
200  	_, err = fmt.Fprintf(conn.Writer, "id: %s\nevent: %s\ndata: %s\n\n",
201  		id, eventType, eventData)
202  	if err != nil {
203  		log.Errorf("SSE事件发送失败: sessionId=%s, eventType=%s, error=%v", conn.SessionID, eventType, err)
204  		return fmt.Errorf("发送事件失败: %v", err)
205  	}
206  
207  	// 刷新缓冲区
208  	conn.Flusher.Flush()
209  	conn.LastPing = time.Now()
210  
211  	log.Infof("发送事件: sessionId=%s, eventType=%s", conn.SessionID, eventType)
212  	log.Debugf("SSE事件发送成功: sessionId=%s, eventType=%s, eventId=%s", conn.SessionID, eventType, id)
213  	return nil
214  }
215  
216  // GetConnectionCount 获取当前连接数
217  func (sm *SSEManager) GetConnectionCount() int {
218  	sm.mutex.RLock()
219  	defer sm.mutex.RUnlock()
220  	count := len(sm.connections)
221  	log.Debugf("SSE连接数统计: count=%d", count)
222  	return count
223  }
224  
225  // GetConnectionsByUser 获取指定用户的连接
226  func (sm *SSEManager) GetConnectionsByUser(username string) []string {
227  	sm.mutex.RLock()
228  	defer sm.mutex.RUnlock()
229  
230  	var sessionIDs []string
231  	for sessionID, conn := range sm.connections {
232  		if conn.Username == username {
233  			sessionIDs = append(sessionIDs, sessionID)
234  		}
235  	}
236  
237  	log.Debugf("用户SSE连接查询: username=%s, connectionCount=%d", username, len(sessionIDs))
238  	return sessionIDs
239  }
240  
241  // HasConnection 检查指定sessionId的连接是否存在
242  func (sm *SSEManager) HasConnection(sessionID string) bool {
243  	sm.mutex.RLock()
244  	defer sm.mutex.RUnlock()
245  	_, exists := sm.connections[sessionID]
246  	return exists
247  }