task.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package database 20 21 import ( 22 "encoding/json" 23 "fmt" 24 "time" 25 26 "gorm.io/datatypes" 27 "gorm.io/gorm" 28 ) 29 30 const ( 31 publicUserUsername = "public_user" 32 demoTestUsername = "demo-test" 33 ) 34 35 // User 用户表(扩展版本) 36 type User struct { 37 UserID string `gorm:"primaryKey;column:user_id" json:"user_id"` 38 Username string `gorm:"column:username;not null;uniqueIndex" json:"username"` // 用户名(唯一) 39 Email string `gorm:"column:email;not null;uniqueIndex" json:"email"` // 邮箱(唯一) 40 IsActive bool `gorm:"column:is_active;not null;default:true" json:"is_active"` // 是否激活 41 FirstLogin bool `gorm:"column:first_login;not null;default:true" json:"first_login"` // 是否首次登录,默认true 42 CreatedAt int64 `gorm:"column:created_at;not null" json:"created_at"` // 创建时间 43 } 44 45 // Session 会话表(一个会话对应一个任务) 46 type Session struct { 47 ID string `gorm:"primaryKey;column:id" json:"id"` // 会话ID,也是任务ID 48 Username string `gorm:"column:username;not null" json:"username"` 49 Title string `gorm:"column:title" json:"title"` 50 TaskType string `gorm:"column:task_type;not null" json:"task_type"` // 任务类型 51 Content string `gorm:"column:content;not null" json:"content"` // 任务内容 52 Params datatypes.JSON `gorm:"column:params" json:"params"` // 任务参数 53 Attachments datatypes.JSON `gorm:"column:attachments" json:"attachments"` // 附件 54 Status string `gorm:"column:status;not null;default:'todo'" json:"status"` // todo, doing, done, error 55 AssignedAgent string `gorm:"column:assigned_agent" json:"assigned_agent"` // 分配的Agent 56 CountryIsoCode string `gorm:"column:contry_iso_code" json:"countryIsoCode"` // 标识语言 57 StartedAt *int64 `gorm:"column:started_at" json:"started_at"` // 时间戳毫秒级 58 CompletedAt *int64 `gorm:"column:completed_at" json:"completed_at"` // 时间戳毫秒级 59 CreatedAt int64 `gorm:"column:created_at;not null" json:"created_at"` // 时间戳毫秒级 60 UpdatedAt int64 `gorm:"column:updated_at;not null" json:"updated_at"` // 时间戳毫秒级 61 62 // 关联关系 63 User User `gorm:"foreignKey:Username" json:"user"` 64 Messages []TaskMessage `gorm:"foreignKey:SessionID" json:"messages"` // 直接关联到Session 65 Share bool `gorm:"column:share;not null;default:false" json:"share"` 66 } 67 68 // TaskMessage 任务消息表(存储所有类型的事件消息) 69 type TaskMessage struct { 70 ID string `gorm:"primaryKey;column:id" json:"id"` // 消息ID(前端生成的对话ID) 71 SessionID string `gorm:"column:session_id;not null" json:"session_id"` // 会话ID(也是任务ID) 72 Type string `gorm:"column:type;not null" json:"type"` // liveStatus, planUpdate, statusUpdate, toolUsed等 73 EventData datatypes.JSON `gorm:"column:event_data;not null" json:"event_data"` // 存储事件的具体数据 74 Timestamp int64 `gorm:"column:timestamp;not null" json:"timestamp"` 75 CreatedAt int64 `gorm:"column:created_at;not null" json:"created_at"` // 时间戳毫秒级 76 77 // 关联关系 78 Session Session `gorm:"foreignKey:SessionID" json:"session"` 79 } 80 81 // TaskStore 任务数据存储 82 type TaskStore struct { 83 db *gorm.DB 84 } 85 86 // NewTaskStore 创建新的TaskStore实例 87 func NewTaskStore(db *gorm.DB) *TaskStore { 88 return &TaskStore{db: db} 89 } 90 91 // ResetRunningTasks 重置运行中的任务为失败 92 func (s *TaskStore) ResetRunningTasks() error { 93 return s.db.Model(&Session{}).Where("status = 'doing' or status = 'failed'").Updates(map[string]interface{}{ 94 "status": "error", 95 "updated_at": time.Now().UnixMilli(), 96 }).Error 97 } 98 99 // Init 自动迁移任务相关表结构 100 func (s *TaskStore) Init() error { 101 if err := s.db.AutoMigrate(&User{}, &Session{}, &TaskMessage{}); err != nil { 102 return err 103 } 104 return s.createIndexes() 105 } 106 107 // createIndexes 创建查询优化索引 108 func (s *TaskStore) createIndexes() error { 109 indexes := []string{ 110 // Session 表索引 111 "CREATE INDEX IF NOT EXISTS idx_sessions_username_created ON sessions(username, created_at DESC)", 112 "CREATE INDEX IF NOT EXISTS idx_sessions_username_tasktype ON sessions(username, task_type)", 113 "CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status)", 114 // TaskMessage 表索引 115 "CREATE INDEX IF NOT EXISTS idx_taskmessages_session_timestamp ON task_messages(session_id, timestamp)", 116 "CREATE INDEX IF NOT EXISTS idx_taskmessages_session_type ON task_messages(session_id, type)", 117 } 118 119 for _, sql := range indexes { 120 if err := s.db.Exec(sql).Error; err != nil { 121 return fmt.Errorf("创建索引失败: %s, error: %v", sql, err) 122 } 123 } 124 return nil 125 } 126 127 // CreateUser 创建用户 128 func (s *TaskStore) CreateUser(user *User) error { 129 now := time.Now().UnixMilli() 130 user.CreatedAt = now 131 return s.db.Create(user).Error 132 } 133 134 // GetUser 获取用户信息 135 func (s *TaskStore) GetUser(username string) (*User, error) { 136 var user User 137 err := s.db.First(&user, "username = ?", username).Error 138 if err != nil { 139 return nil, err 140 } 141 return &user, nil 142 } 143 144 // GetUserByEmail 根据邮箱获取用户 145 func (s *TaskStore) GetUserByEmail(email string) (*User, error) { 146 var user User 147 err := s.db.First(&user, "email = ?", email).Error 148 if err != nil { 149 return nil, err 150 } 151 return &user, nil 152 } 153 154 // CheckUserExists 检查用户是否存在 155 func (s *TaskStore) CheckUserExists(email string) (bool, error) { 156 var count int64 157 err := s.db.Model(&User{}).Where(" email = ?", email).Count(&count).Error 158 return count > 0, err 159 } 160 161 // CreateSession 创建会话(包含任务信息) 162 func (s *TaskStore) CreateSession(session *Session) error { 163 now := time.Now().UnixMilli() 164 session.CreatedAt = now 165 session.UpdatedAt = now 166 return s.db.Create(session).Error 167 } 168 169 // GetSession 获取会话信息 170 func (s *TaskStore) GetSession(id string) (*Session, error) { 171 var session Session 172 err := s.db.Preload("User").Preload("Messages").First(&session, "id = ?", id).Error 173 if err != nil { 174 return nil, err 175 } 176 return &session, nil 177 } 178 179 // SetShare 设置会话共享 180 func (s *TaskStore) SetShare(sessionID string, share bool) error { 181 return s.db.Model(&Session{}).Where("id = ?", sessionID).Update("share", share).Error 182 } 183 184 // UpdateSessionStatus 更新会话状态 185 func (s *TaskStore) UpdateSessionStatus(id string, status string) error { 186 now := time.Now().UnixMilli() 187 updates := map[string]interface{}{ 188 "status": status, 189 "updated_at": now, 190 } 191 192 if status == "doing" { 193 updates["started_at"] = &now 194 } else if status == "done" || status == "error" || status == "terminated" { 195 updates["completed_at"] = &now 196 } 197 198 return s.db.Model(&Session{}).Where("id = ?", id).Updates(updates).Error 199 } 200 201 // UpdateSessionAssignedAgent 更新会话的分配Agent和开始时间 202 func (s *TaskStore) UpdateSessionAssignedAgent(sessionID string, agentID string) error { 203 now := time.Now().UnixMilli() 204 updates := map[string]interface{}{ 205 "assigned_agent": agentID, 206 "status": "doing", 207 "started_at": &now, 208 } 209 210 return s.db.Model(&Session{}).Where("id = ?", sessionID).Updates(updates).Error 211 } 212 213 // UpdateSession 更新会话信息 214 func (s *TaskStore) UpdateSession(sessionID string, updates map[string]interface{}) error { 215 // 添加更新时间 216 updates["updated_at"] = time.Now().UnixMilli() 217 return s.db.Model(&Session{}).Where("id = ?", sessionID).Updates(updates).Error 218 } 219 220 // DeleteSession 删除会话 221 func (s *TaskStore) DeleteSession(sessionID string) error { 222 return s.db.Delete(&Session{}, "id = ?", sessionID).Error 223 } 224 225 // DeleteSessionMessages 删除会话的所有消息 226 func (s *TaskStore) DeleteSessionMessages(sessionID string) error { 227 return s.db.Where("session_id = ?", sessionID).Delete(&TaskMessage{}).Error 228 } 229 230 func (s *TaskStore) DeleteUser(email string) error { 231 return s.db.Where("email = ?", email).Delete(&User{}).Error 232 } 233 234 // UpdateUserFirstLogin 更新用户的首次登录状态 235 func (s *TaskStore) UpdateUserFirstLogin(username string, firstLogin bool) error { 236 return s.db.Model(&User{}).Where("username = ?", username).Update("first_login", firstLogin).Error 237 } 238 239 // DeleteSessionWithMessages 使用事务删除会话及其所有消息 240 func (s *TaskStore) DeleteSessionWithMessages(sessionID string) error { 241 return s.db.Transaction(func(tx *gorm.DB) error { 242 // 1. 删除会话的所有消息 243 if err := tx.Where("session_id = ?", sessionID).Delete(&TaskMessage{}).Error; err != nil { 244 return fmt.Errorf("删除会话消息失败: %v", err) 245 } 246 247 // 2. 删除会话记录 248 if err := tx.Delete(&Session{}, "id = ?", sessionID).Error; err != nil { 249 return fmt.Errorf("删除会话记录失败: %v", err) 250 } 251 252 return nil 253 }) 254 } 255 256 // CreateTaskMessage 创建任务消息 257 func (s *TaskStore) CreateTaskMessage(message *TaskMessage) error { 258 now := time.Now().UnixMilli() 259 message.CreatedAt = now 260 return s.db.Create(message).Error 261 } 262 263 // GetSessionMessages 获取会话的所有消息 264 func (s *TaskStore) GetSessionMessages(sessionID string) ([]*TaskMessage, error) { 265 var messages []*TaskMessage 266 err := s.db.Where("session_id = ?", sessionID).Order("timestamp ASC").Find(&messages).Error 267 if err != nil { 268 return nil, err 269 } 270 return messages, nil 271 } 272 273 // GetUserSessions 获取用户的所有会话 274 func (s *TaskStore) GetUserSessions(username string) ([]*Session, error) { 275 var sessions []*Session 276 err := s.visibleSessionsQuery(username). 277 Order("created_at DESC"). 278 Find(&sessions).Error 279 if err != nil { 280 return nil, err 281 } 282 return sessions, nil 283 } 284 285 // GetUserSessionsByType 获取用户的会话,支持可选的任务类型过滤 286 func (s *TaskStore) GetUserSessionsByType(username string, taskType string) ([]*Session, error) { 287 query := s.visibleSessionsQuery(username) 288 289 // 如果指定了任务类型,添加类型过滤 290 if taskType != "" { 291 query = query.Where("task_type = ?", taskType) 292 } 293 294 var sessions []*Session 295 err := query.Order("created_at DESC").Find(&sessions).Error 296 if err != nil { 297 return nil, err 298 } 299 return sessions, nil 300 } 301 302 // StoreEvent 存储事件消息 303 func (s *TaskStore) StoreEvent(id string, sessionID string, eventType string, eventData interface{}, timestamp int64) error { 304 // 将事件数据序列化为JSON 305 eventJSON, err := json.Marshal(eventData) 306 if err != nil { 307 return err 308 } 309 310 message := &TaskMessage{ 311 ID: id, 312 SessionID: sessionID, 313 Type: eventType, 314 EventData: datatypes.JSON(eventJSON), 315 Timestamp: timestamp, 316 } 317 318 return s.CreateTaskMessage(message) 319 } 320 321 // GetSessionEvents 获取会话的所有事件 322 func (s *TaskStore) GetSessionEvents(sessionID string) ([]*TaskMessage, error) { 323 return s.GetSessionMessages(sessionID) 324 } 325 326 // GetSessionEventsByType 根据类型获取会话事件 327 func (s *TaskStore) GetSessionEventsByType(sessionID string, eventType string) ([]*TaskMessage, error) { 328 var messages []*TaskMessage 329 err := s.db.Where("session_id = ? AND type = ?", sessionID, eventType).Order("timestamp ASC").Find(&messages).Error 330 if err != nil { 331 return nil, err 332 } 333 return messages, nil 334 } 335 336 // SearchUserSessionsSimple 使用单个查询参数搜索用户的会话,支持在title、content、task_type字段中搜索 337 func (s *TaskStore) SearchUserSessionsSimple(username string, searchParams SimpleSearchParams) ([]*Session, int64, error) { 338 query := s.visibleSessionsQuery(username) 339 340 // 如果指定了任务类型,添加类型过滤 341 if searchParams.TaskType != "" { 342 query = query.Where("task_type = ?", searchParams.TaskType) 343 } 344 345 // 如果有查询关键词,在多个字段中搜索 346 if searchParams.Query != "" { 347 query = query.Where("title LIKE ? OR content LIKE ? OR task_type LIKE ?", 348 "%"+searchParams.Query+"%", 349 "%"+searchParams.Query+"%", 350 "%"+searchParams.Query+"%") 351 } 352 353 // 获取总数 354 var total int64 355 if err := query.Count(&total).Error; err != nil { 356 return nil, 0, err 357 } 358 359 // 应用分页和排序 360 var sessions []*Session 361 err := query.Order("created_at DESC"). 362 Offset((searchParams.Page - 1) * searchParams.PageSize). 363 Limit(searchParams.PageSize). 364 Find(&sessions).Error 365 366 if err != nil { 367 return nil, 0, err 368 } 369 370 return sessions, total, nil 371 } 372 373 func (s *TaskStore) visibleSessionsQuery(username string) *gorm.DB { 374 query := s.db.Model(&Session{}) 375 376 if username == publicUserUsername || username == "" { 377 return query.Where( 378 "username = ? OR username = ? OR share = ?", 379 publicUserUsername, 380 demoTestUsername, 381 true, 382 ) 383 } 384 385 return query.Where( 386 "username = ? OR username = ?", 387 username, 388 demoTestUsername, 389 ) 390 } 391 392 // SimpleSearchParams 简化搜索参数结构 393 type SimpleSearchParams struct { 394 Query string `json:"query"` // 查询关键词,将在title、content、task_type字段中搜索 395 TaskType string `json:"task_type"` // 任务类型过滤 396 Page int `json:"page"` // 页码 397 PageSize int `json:"page_size"` // 每页大小 398 } 399 400 // generateMessageID 生成消息ID 401 func generateMessageID() string { 402 return time.Now().Format("20060102150405") + "_" + fmt.Sprintf("%d", time.Now().UnixNano()) 403 }