/ pkg / database / task.go
task.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package database
 20  
 21  import (
 22  	"encoding/json"
 23  	"fmt"
 24  	"time"
 25  
 26  	"gorm.io/datatypes"
 27  	"gorm.io/gorm"
 28  )
 29  
 30  const (
 31  	publicUserUsername = "public_user"
 32  	demoTestUsername   = "demo-test"
 33  )
 34  
// User is the user table (extended version).
type User struct {
	UserID     string `gorm:"primaryKey;column:user_id" json:"user_id"`
	Username   string `gorm:"column:username;not null;uniqueIndex" json:"username"`        // user name (unique)
	Email      string `gorm:"column:email;not null;uniqueIndex" json:"email"`              // e-mail address (unique)
	IsActive   bool   `gorm:"column:is_active;not null;default:true" json:"is_active"`     // whether the account is active
	FirstLogin bool   `gorm:"column:first_login;not null;default:true" json:"first_login"` // whether this is the first login; defaults to true
	CreatedAt  int64  `gorm:"column:created_at;not null" json:"created_at"`                // creation time; CreateUser stores Unix milliseconds
}
 44  
// Session is the sessions table. One session corresponds to exactly one task,
// and the session ID doubles as the task ID.
//
// NOTE(review): the column for CountryIsoCode is spelled "contry_iso_code";
// presumably it is kept because existing databases already use that name —
// renaming it would need a data migration.
//
// NOTE(review): User is declared with foreignKey:Username but no references
// tag, so GORM presumably matches sessions.username against User's primary key
// (user_id), not users.username — confirm this is intended.
//
// NOTE(review): GORM auto-fills a field named UpdatedAt on updates that do not
// set it; for an int64 field without an `autoUpdateTime:milli` tag that is
// presumably Unix seconds, not the milliseconds used elsewhere — verify.
type Session struct {
	ID             string         `gorm:"primaryKey;column:id" json:"id"` // session ID, also used as the task ID
	Username       string         `gorm:"column:username;not null" json:"username"`
	Title          string         `gorm:"column:title" json:"title"`
	TaskType       string         `gorm:"column:task_type;not null" json:"task_type"`          // task type
	Content        string         `gorm:"column:content;not null" json:"content"`              // task content
	Params         datatypes.JSON `gorm:"column:params" json:"params"`                         // task parameters (raw JSON)
	Attachments    datatypes.JSON `gorm:"column:attachments" json:"attachments"`               // attachments (raw JSON)
	Status         string         `gorm:"column:status;not null;default:'todo'" json:"status"` // todo, doing, done, error (store code also uses terminated and failed)
	AssignedAgent  string         `gorm:"column:assigned_agent" json:"assigned_agent"`         // agent the task is assigned to
	CountryIsoCode string         `gorm:"column:contry_iso_code" json:"countryIsoCode"`        // identifies the language
	StartedAt      *int64         `gorm:"column:started_at" json:"started_at"`                 // timestamp in milliseconds
	CompletedAt    *int64         `gorm:"column:completed_at" json:"completed_at"`             // timestamp in milliseconds
	CreatedAt      int64          `gorm:"column:created_at;not null" json:"created_at"`        // timestamp in milliseconds
	UpdatedAt      int64          `gorm:"column:updated_at;not null" json:"updated_at"`        // timestamp in milliseconds

	// Associations
	User     User          `gorm:"foreignKey:Username" json:"user"`
	Messages []TaskMessage `gorm:"foreignKey:SessionID" json:"messages"` // linked directly to the Session
	Share    bool          `gorm:"column:share;not null;default:false" json:"share"` // shared sessions are listed for the public user by visibleSessionsQuery
}
 67  
// TaskMessage is the task_messages table; it stores every kind of event
// message recorded for a session.
type TaskMessage struct {
	ID        string         `gorm:"primaryKey;column:id" json:"id"`               // message ID (the conversation ID generated by the frontend)
	SessionID string         `gorm:"column:session_id;not null" json:"session_id"` // session ID (also the task ID)
	Type      string         `gorm:"column:type;not null" json:"type"`             // liveStatus, planUpdate, statusUpdate, toolUsed, etc.
	EventData datatypes.JSON `gorm:"column:event_data;not null" json:"event_data"` // event-specific payload as JSON
	Timestamp int64          `gorm:"column:timestamp;not null" json:"timestamp"`   // caller-supplied event time used for ordering; NOTE(review): unit presumably ms — confirm with callers
	CreatedAt int64          `gorm:"column:created_at;not null" json:"created_at"` // timestamp in milliseconds

	// Associations
	Session Session `gorm:"foreignKey:SessionID" json:"session"`
}
 80  
// TaskStore is the data store for tasks: it persists users, sessions and task
// messages through a GORM database handle.
type TaskStore struct {
	db *gorm.DB
}
 85  
 86  // NewTaskStore 创建新的TaskStore实例
 87  func NewTaskStore(db *gorm.DB) *TaskStore {
 88  	return &TaskStore{db: db}
 89  }
 90  
 91  // ResetRunningTasks 重置运行中的任务为失败
 92  func (s *TaskStore) ResetRunningTasks() error {
 93  	return s.db.Model(&Session{}).Where("status = 'doing' or status = 'failed'").Updates(map[string]interface{}{
 94  		"status":     "error",
 95  		"updated_at": time.Now().UnixMilli(),
 96  	}).Error
 97  }
 98  
 99  // Init 自动迁移任务相关表结构
100  func (s *TaskStore) Init() error {
101  	if err := s.db.AutoMigrate(&User{}, &Session{}, &TaskMessage{}); err != nil {
102  		return err
103  	}
104  	return s.createIndexes()
105  }
106  
107  // createIndexes 创建查询优化索引
108  func (s *TaskStore) createIndexes() error {
109  	indexes := []string{
110  		// Session 表索引
111  		"CREATE INDEX IF NOT EXISTS idx_sessions_username_created ON sessions(username, created_at DESC)",
112  		"CREATE INDEX IF NOT EXISTS idx_sessions_username_tasktype ON sessions(username, task_type)",
113  		"CREATE INDEX IF NOT EXISTS idx_sessions_status ON sessions(status)",
114  		// TaskMessage 表索引
115  		"CREATE INDEX IF NOT EXISTS idx_taskmessages_session_timestamp ON task_messages(session_id, timestamp)",
116  		"CREATE INDEX IF NOT EXISTS idx_taskmessages_session_type ON task_messages(session_id, type)",
117  	}
118  
119  	for _, sql := range indexes {
120  		if err := s.db.Exec(sql).Error; err != nil {
121  			return fmt.Errorf("创建索引失败: %s, error: %v", sql, err)
122  		}
123  	}
124  	return nil
125  }
126  
127  // CreateUser 创建用户
128  func (s *TaskStore) CreateUser(user *User) error {
129  	now := time.Now().UnixMilli()
130  	user.CreatedAt = now
131  	return s.db.Create(user).Error
132  }
133  
134  // GetUser 获取用户信息
135  func (s *TaskStore) GetUser(username string) (*User, error) {
136  	var user User
137  	err := s.db.First(&user, "username = ?", username).Error
138  	if err != nil {
139  		return nil, err
140  	}
141  	return &user, nil
142  }
143  
144  // GetUserByEmail 根据邮箱获取用户
145  func (s *TaskStore) GetUserByEmail(email string) (*User, error) {
146  	var user User
147  	err := s.db.First(&user, "email = ?", email).Error
148  	if err != nil {
149  		return nil, err
150  	}
151  	return &user, nil
152  }
153  
154  // CheckUserExists 检查用户是否存在
155  func (s *TaskStore) CheckUserExists(email string) (bool, error) {
156  	var count int64
157  	err := s.db.Model(&User{}).Where(" email = ?", email).Count(&count).Error
158  	return count > 0, err
159  }
160  
161  // CreateSession 创建会话(包含任务信息)
162  func (s *TaskStore) CreateSession(session *Session) error {
163  	now := time.Now().UnixMilli()
164  	session.CreatedAt = now
165  	session.UpdatedAt = now
166  	return s.db.Create(session).Error
167  }
168  
169  // GetSession 获取会话信息
170  func (s *TaskStore) GetSession(id string) (*Session, error) {
171  	var session Session
172  	err := s.db.Preload("User").Preload("Messages").First(&session, "id = ?", id).Error
173  	if err != nil {
174  		return nil, err
175  	}
176  	return &session, nil
177  }
178  
179  // SetShare 设置会话共享
180  func (s *TaskStore) SetShare(sessionID string, share bool) error {
181  	return s.db.Model(&Session{}).Where("id = ?", sessionID).Update("share", share).Error
182  }
183  
184  // UpdateSessionStatus 更新会话状态
185  func (s *TaskStore) UpdateSessionStatus(id string, status string) error {
186  	now := time.Now().UnixMilli()
187  	updates := map[string]interface{}{
188  		"status":     status,
189  		"updated_at": now,
190  	}
191  
192  	if status == "doing" {
193  		updates["started_at"] = &now
194  	} else if status == "done" || status == "error" || status == "terminated" {
195  		updates["completed_at"] = &now
196  	}
197  
198  	return s.db.Model(&Session{}).Where("id = ?", id).Updates(updates).Error
199  }
200  
201  // UpdateSessionAssignedAgent 更新会话的分配Agent和开始时间
202  func (s *TaskStore) UpdateSessionAssignedAgent(sessionID string, agentID string) error {
203  	now := time.Now().UnixMilli()
204  	updates := map[string]interface{}{
205  		"assigned_agent": agentID,
206  		"status":         "doing",
207  		"started_at":     &now,
208  	}
209  
210  	return s.db.Model(&Session{}).Where("id = ?", sessionID).Updates(updates).Error
211  }
212  
213  // UpdateSession 更新会话信息
214  func (s *TaskStore) UpdateSession(sessionID string, updates map[string]interface{}) error {
215  	// 添加更新时间
216  	updates["updated_at"] = time.Now().UnixMilli()
217  	return s.db.Model(&Session{}).Where("id = ?", sessionID).Updates(updates).Error
218  }
219  
220  // DeleteSession 删除会话
221  func (s *TaskStore) DeleteSession(sessionID string) error {
222  	return s.db.Delete(&Session{}, "id = ?", sessionID).Error
223  }
224  
225  // DeleteSessionMessages 删除会话的所有消息
226  func (s *TaskStore) DeleteSessionMessages(sessionID string) error {
227  	return s.db.Where("session_id = ?", sessionID).Delete(&TaskMessage{}).Error
228  }
229  
230  func (s *TaskStore) DeleteUser(email string) error {
231  	return s.db.Where("email = ?", email).Delete(&User{}).Error
232  }
233  
234  // UpdateUserFirstLogin 更新用户的首次登录状态
235  func (s *TaskStore) UpdateUserFirstLogin(username string, firstLogin bool) error {
236  	return s.db.Model(&User{}).Where("username = ?", username).Update("first_login", firstLogin).Error
237  }
238  
239  // DeleteSessionWithMessages 使用事务删除会话及其所有消息
240  func (s *TaskStore) DeleteSessionWithMessages(sessionID string) error {
241  	return s.db.Transaction(func(tx *gorm.DB) error {
242  		// 1. 删除会话的所有消息
243  		if err := tx.Where("session_id = ?", sessionID).Delete(&TaskMessage{}).Error; err != nil {
244  			return fmt.Errorf("删除会话消息失败: %v", err)
245  		}
246  
247  		// 2. 删除会话记录
248  		if err := tx.Delete(&Session{}, "id = ?", sessionID).Error; err != nil {
249  			return fmt.Errorf("删除会话记录失败: %v", err)
250  		}
251  
252  		return nil
253  	})
254  }
255  
256  // CreateTaskMessage 创建任务消息
257  func (s *TaskStore) CreateTaskMessage(message *TaskMessage) error {
258  	now := time.Now().UnixMilli()
259  	message.CreatedAt = now
260  	return s.db.Create(message).Error
261  }
262  
263  // GetSessionMessages 获取会话的所有消息
264  func (s *TaskStore) GetSessionMessages(sessionID string) ([]*TaskMessage, error) {
265  	var messages []*TaskMessage
266  	err := s.db.Where("session_id = ?", sessionID).Order("timestamp ASC").Find(&messages).Error
267  	if err != nil {
268  		return nil, err
269  	}
270  	return messages, nil
271  }
272  
273  // GetUserSessions 获取用户的所有会话
274  func (s *TaskStore) GetUserSessions(username string) ([]*Session, error) {
275  	var sessions []*Session
276  	err := s.visibleSessionsQuery(username).
277  		Order("created_at DESC").
278  		Find(&sessions).Error
279  	if err != nil {
280  		return nil, err
281  	}
282  	return sessions, nil
283  }
284  
285  // GetUserSessionsByType 获取用户的会话,支持可选的任务类型过滤
286  func (s *TaskStore) GetUserSessionsByType(username string, taskType string) ([]*Session, error) {
287  	query := s.visibleSessionsQuery(username)
288  
289  	// 如果指定了任务类型,添加类型过滤
290  	if taskType != "" {
291  		query = query.Where("task_type = ?", taskType)
292  	}
293  
294  	var sessions []*Session
295  	err := query.Order("created_at DESC").Find(&sessions).Error
296  	if err != nil {
297  		return nil, err
298  	}
299  	return sessions, nil
300  }
301  
302  // StoreEvent 存储事件消息
303  func (s *TaskStore) StoreEvent(id string, sessionID string, eventType string, eventData interface{}, timestamp int64) error {
304  	// 将事件数据序列化为JSON
305  	eventJSON, err := json.Marshal(eventData)
306  	if err != nil {
307  		return err
308  	}
309  
310  	message := &TaskMessage{
311  		ID:        id,
312  		SessionID: sessionID,
313  		Type:      eventType,
314  		EventData: datatypes.JSON(eventJSON),
315  		Timestamp: timestamp,
316  	}
317  
318  	return s.CreateTaskMessage(message)
319  }
320  
321  // GetSessionEvents 获取会话的所有事件
322  func (s *TaskStore) GetSessionEvents(sessionID string) ([]*TaskMessage, error) {
323  	return s.GetSessionMessages(sessionID)
324  }
325  
326  // GetSessionEventsByType 根据类型获取会话事件
327  func (s *TaskStore) GetSessionEventsByType(sessionID string, eventType string) ([]*TaskMessage, error) {
328  	var messages []*TaskMessage
329  	err := s.db.Where("session_id = ? AND type = ?", sessionID, eventType).Order("timestamp ASC").Find(&messages).Error
330  	if err != nil {
331  		return nil, err
332  	}
333  	return messages, nil
334  }
335  
336  // SearchUserSessionsSimple 使用单个查询参数搜索用户的会话,支持在title、content、task_type字段中搜索
337  func (s *TaskStore) SearchUserSessionsSimple(username string, searchParams SimpleSearchParams) ([]*Session, int64, error) {
338  	query := s.visibleSessionsQuery(username)
339  
340  	// 如果指定了任务类型,添加类型过滤
341  	if searchParams.TaskType != "" {
342  		query = query.Where("task_type = ?", searchParams.TaskType)
343  	}
344  
345  	// 如果有查询关键词,在多个字段中搜索
346  	if searchParams.Query != "" {
347  		query = query.Where("title LIKE ? OR content LIKE ? OR task_type LIKE ?",
348  			"%"+searchParams.Query+"%",
349  			"%"+searchParams.Query+"%",
350  			"%"+searchParams.Query+"%")
351  	}
352  
353  	// 获取总数
354  	var total int64
355  	if err := query.Count(&total).Error; err != nil {
356  		return nil, 0, err
357  	}
358  
359  	// 应用分页和排序
360  	var sessions []*Session
361  	err := query.Order("created_at DESC").
362  		Offset((searchParams.Page - 1) * searchParams.PageSize).
363  		Limit(searchParams.PageSize).
364  		Find(&sessions).Error
365  
366  	if err != nil {
367  		return nil, 0, err
368  	}
369  
370  	return sessions, total, nil
371  }
372  
373  func (s *TaskStore) visibleSessionsQuery(username string) *gorm.DB {
374  	query := s.db.Model(&Session{})
375  
376  	if username == publicUserUsername || username == "" {
377  		return query.Where(
378  			"username = ? OR username = ? OR share = ?",
379  			publicUserUsername,
380  			demoTestUsername,
381  			true,
382  		)
383  	}
384  
385  	return query.Where(
386  		"username = ? OR username = ?",
387  		username,
388  		demoTestUsername,
389  	)
390  }
391  
// SimpleSearchParams holds the simplified search parameters consumed by
// SearchUserSessionsSimple.
type SimpleSearchParams struct {
	Query    string `json:"query"`     // search keyword, matched with LIKE against title, content and task_type; empty disables it
	TaskType string `json:"task_type"` // exact task-type filter; empty disables it
	Page     int    `json:"page"`      // page number (1-based)
	PageSize int    `json:"page_size"` // page size, used as the LIMIT
}
399  
// generateMessageID builds a message ID of the form
// "<YYYYMMDDhhmmss>_<UnixNano>" from the current local time.
//
// The clock is read once so that both parts describe the same instant.
// Reading it twice could produce a seconds prefix that disagrees with the
// nanosecond suffix across a second boundary. Two calls that observe the same
// clock reading will produce the same ID.
func generateMessageID() string {
	now := time.Now()
	return fmt.Sprintf("%s_%d", now.Format("20060102150405"), now.UnixNano())
}