sse_manager.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package websocket 20 21 import ( 22 "encoding/json" 23 "fmt" 24 "net/http" 25 "sync" 26 "time" 27 28 "trpc.group/trpc-go/trpc-go/log" 29 ) 30 31 // SSEConnection 表示一个SSE连接 32 type SSEConnection struct { 33 SessionID string 34 Username string 35 Writer http.ResponseWriter 36 Flusher http.Flusher 37 CloseChan chan bool 38 LastPing time.Time 39 } 40 41 // SSEManager 管理SSE连接和事件推送 42 type SSEManager struct { 43 connections map[string]*SSEConnection // sessionId -> connection 44 mutex sync.RWMutex 45 } 46 47 // NewSSEManager 创建新的SSE管理器 48 func NewSSEManager() *SSEManager { 49 return &SSEManager{ 50 connections: make(map[string]*SSEConnection), 51 } 52 } 53 54 // AddConnection 添加新的SSE连接 55 func (sm *SSEManager) AddConnection(sessionID, username string, w http.ResponseWriter) error { 56 sm.mutex.Lock() 57 defer sm.mutex.Unlock() 58 59 // 检查是否已存在相同sessionId的连接 60 if existing, exists := sm.connections[sessionID]; exists { 61 // 关闭现有连接 62 close(existing.CloseChan) 63 log.Infof("SSE连接冲突,关闭现有连接: sessionId=%s, username=%s", sessionID, username) 64 } 65 66 // 检查是否支持SSE 67 flusher, ok := w.(http.Flusher) 68 if !ok { 69 log.Errorf("SSE流式传输不支持: sessionId=%s, username=%s", sessionID, username) 70 return fmt.Errorf("streaming unsupported") 71 } 72 73 // 设置SSE响应头 74 w.Header().Set("Content-Type", "text/event-stream") 75 w.Header().Set("Cache-Control", "no-cache") 76 w.Header().Set("Connection", "keep-alive") 77 w.Header().Set("Access-Control-Allow-Origin", "*") 78 w.Header().Set("Access-Control-Allow-Headers", "Cache-Control") 79 80 // 创建连接 81 conn := &SSEConnection{ 82 SessionID: sessionID, 83 Username: username, 84 Writer: w, 85 Flusher: flusher, 86 CloseChan: make(chan bool), 87 LastPing: time.Now(), 88 } 89 90 sm.connections[sessionID] = conn 91 log.Infof("SSE连接建立: sessionId=%s, username=%s, totalConnections=%d", sessionID, username, len(sm.connections)) 92 93 // 发送连接成功消息 94 sm.sendEventToConnection(conn, "connected", "connected", map[string]interface{}{ 95 "message": "SSE连接已建立", 96 "sessionId": sessionID, 97 }) 98 99 // 启动心跳和连接保持 100 go sm.keepConnectionAlive(conn) 101 102 return nil 103 } 104 105 // keepConnectionAlive 保持连接活跃 106 func (sm *SSEManager) keepConnectionAlive(conn *SSEConnection) { 107 ticker := time.NewTicker(10 * time.Second) // 改为10秒心跳,提高频率 108 defer ticker.Stop() 109 110 log.Debugf("SSE心跳启动: sessionId=%s, username=%s", conn.SessionID, conn.Username) 111 112 for { 113 select { 114 case <-conn.CloseChan: 115 log.Infof("SSE连接已关闭: sessionId=%s", conn.SessionID) 116 log.Infof("SSE连接关闭: sessionId=%s, username=%s", conn.SessionID, conn.Username) 117 return 118 case <-ticker.C: 119 // 发送liveStatus心跳消息 120 heartbeat := TaskEventMessage{ 121 ID: fmt.Sprintf("heartbeat_%d", time.Now().Unix()), 122 Type: "liveStatus", // 改为liveStatus类型 123 SessionID: conn.SessionID, 124 Timestamp: time.Now().Unix(), 125 Event: LiveStatusEvent{ 126 ID: fmt.Sprintf("heartbeat_%d", time.Now().Unix()), 127 Type: "liveStatus", 128 Timestamp: time.Now().UnixMilli(), 129 Text: "思考中...", // 默认状态文本 130 }, 131 } 132 133 eventData, err := json.Marshal(heartbeat) 134 if err != nil { 135 log.Errorf("SSE心跳序列化失败: sessionId=%s, error=%v", conn.SessionID, err) 136 continue 137 } 138 139 _, err = fmt.Fprintf(conn.Writer, "data: %s\n\n", eventData) 140 if err != nil { 141 log.Errorf("SSE心跳发送失败: sessionId=%s, error=%v", conn.SessionID, err) 142 sm.RemoveConnection(conn.SessionID) 143 return 144 } 145 146 conn.Flusher.Flush() 147 conn.LastPing = time.Now() 148 log.Debugf("SSE心跳发送成功: sessionId=%s", conn.SessionID) 149 } 150 } 151 } 152 153 // RemoveConnection 移除SSE连接 154 func (sm *SSEManager) RemoveConnection(sessionID string) { 155 sm.mutex.Lock() 156 defer sm.mutex.Unlock() 157 158 if conn, exists := sm.connections[sessionID]; exists { 159 close(conn.CloseChan) 160 delete(sm.connections, sessionID) 161 log.Infof("SSE连接移除: sessionId=%s, username=%s, remainingConnections=%d", sessionID, conn.Username, len(sm.connections)) 162 } 163 } 164 165 // SendEvent 向指定会话发送事件 166 func (sm *SSEManager) SendEvent(id string, sessionID string, eventType string, event interface{}) error { 167 sm.mutex.RLock() 168 conn, exists := sm.connections[sessionID] 169 sm.mutex.RUnlock() 170 171 if !exists { 172 log.Warnf("SSE连接不存在,跳过事件推送: sessionId=%s, eventType=%s", sessionID, eventType) 173 return fmt.Errorf("连接不存在: sessionId=%s", sessionID) 174 } 175 176 log.Debugf("SSE事件推送: sessionId=%s, eventType=%s, eventId=%s", sessionID, eventType, id) 177 return sm.sendEventToConnection(conn, id, eventType, event) 178 } 179 180 // sendEventToConnection 向单个连接发送事件 181 func (sm *SSEManager) sendEventToConnection(conn *SSEConnection, id string, eventType string, event interface{}) error { 182 // 创建事件消息 183 eventMessage := TaskEventMessage{ 184 ID: id, 185 Type: eventType, 186 SessionID: conn.SessionID, 187 Timestamp: time.Now().Unix(), 188 Event: event, 189 } 190 191 // 序列化事件 192 eventData, err := json.Marshal(eventMessage) 193 if err != nil { 194 log.Errorf("SSE事件序列化失败: sessionId=%s, eventType=%s, error=%v", conn.SessionID, eventType, err) 195 return fmt.Errorf("序列化事件失败: %v", err) 196 } 197 198 // 按照SSE规范发送消息 199 // 格式: id: <id>\nevent: <event_type>\ndata: <json_data>\n\n 200 _, err = fmt.Fprintf(conn.Writer, "id: %s\nevent: %s\ndata: %s\n\n", 201 id, eventType, eventData) 202 if err != nil { 203 log.Errorf("SSE事件发送失败: sessionId=%s, eventType=%s, error=%v", conn.SessionID, eventType, err) 204 return fmt.Errorf("发送事件失败: %v", err) 205 } 206 207 // 刷新缓冲区 208 conn.Flusher.Flush() 209 conn.LastPing = time.Now() 210 211 log.Infof("发送事件: sessionId=%s, eventType=%s", conn.SessionID, eventType) 212 log.Debugf("SSE事件发送成功: sessionId=%s, eventType=%s, eventId=%s", conn.SessionID, eventType, id) 213 return nil 214 } 215 216 // GetConnectionCount 获取当前连接数 217 func (sm *SSEManager) GetConnectionCount() int { 218 sm.mutex.RLock() 219 defer sm.mutex.RUnlock() 220 count := len(sm.connections) 221 log.Debugf("SSE连接数统计: count=%d", count) 222 return count 223 } 224 225 // GetConnectionsByUser 获取指定用户的连接 226 func (sm *SSEManager) GetConnectionsByUser(username string) []string { 227 sm.mutex.RLock() 228 defer sm.mutex.RUnlock() 229 230 var sessionIDs []string 231 for sessionID, conn := range sm.connections { 232 if conn.Username == username { 233 sessionIDs = append(sessionIDs, sessionID) 234 } 235 } 236 237 log.Debugf("用户SSE连接查询: username=%s, connectionCount=%d", username, len(sessionIDs)) 238 return sessionIDs 239 } 240 241 // HasConnection 检查指定sessionId的连接是否存在 242 func (sm *SSEManager) HasConnection(sessionID string) bool { 243 sm.mutex.RLock() 244 defer sm.mutex.RUnlock() 245 _, exists := sm.connections[sessionID] 246 return exists 247 }