// pkg/database/model.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package database
 20  
 21  import (
 22  	"os"
 23  	"time"
 24  
 25  	"gorm.io/gorm"
 26  )
 27  
 28  type ModelParams struct {
 29  	BaseUrl string `json:"base_url"`
 30  	Token   string `json:"token"`
 31  	Model   string `json:"model"`
 32  	Limit   int    `json:"limit"`
 33  }
 34  
 35  // Model 模型表
 36  type Model struct {
 37  	ModelID            string   `gorm:"primaryKey;column:model_id" json:"model_id" yaml:"model_id"`                              // 模型ID
 38  	Username           string   `gorm:"column:username;not null" json:"username" yaml:"-"`                                        // 创建者用户名
 39  	ModelName          string   `gorm:"column:model_name;not null" json:"model_name" yaml:"model_name"`                           // 模型名称
 40  	Token              string   `gorm:"column:token;not null" json:"token" yaml:"token"`                                          // API Token
 41  	BaseURL            string   `gorm:"column:base_url;not null" json:"base_url" yaml:"base_url"`                                 // 基础URL
 42  	Note               string   `gorm:"column:note" json:"note" yaml:"note,omitempty"`                                            // 备注信息
 43  	Limit              int      `gorm:"column:limit" json:"limit" yaml:"limit,omitempty"`
 44  	Default            []string `gorm:"-" json:"default,omitempty" yaml:"default,omitempty"`                                      // 默认字段
 45  	CreatedAt          int64    `gorm:"column:created_at;not null" json:"created_at" yaml:"-"`                                    // 时间戳毫秒级
 46  	UpdatedAt          int64    `gorm:"column:updated_at;not null" json:"updated_at" yaml:"-"`                                    // 时间戳毫秒级
 47  
 48  	// 关联关系
 49  	User User `gorm:"foreignKey:Username" json:"user" yaml:"-"`
 50  }
 51  
 52  // ModelStore 模型数据存储
 53  type ModelStore struct {
 54  	db *gorm.DB
 55  }
 56  
 57  // NewModelStore 创建新的ModelStore实例
 58  func NewModelStore(db *gorm.DB) *ModelStore {
 59  	return &ModelStore{db: db}
 60  }
 61  
 62  // Init 自动迁移模型相关表结构
 63  func (s *ModelStore) Init() error {
 64  	if err := s.db.AutoMigrate(&Model{}); err != nil {
 65  		return err
 66  	}
 67  	// 创建索引优化查询
 68  	return s.db.Exec("CREATE INDEX IF NOT EXISTS idx_models_username_created ON models(username, created_at DESC)").Error
 69  }
 70  
 71  // CreateModel 创建模型
 72  func (s *ModelStore) CreateModel(model *Model) error {
 73  	now := time.Now().UnixMilli()
 74  	model.CreatedAt = now
 75  	model.UpdatedAt = now
 76  	return s.db.Create(model).Error
 77  }
 78  
 79  // GetModel 获取模型信息
 80  func (s *ModelStore) GetModel(modelID string) (*Model, error) {
 81  	var model Model
 82  	err := s.db.Preload("User").First(&model, "model_id = ?", modelID).Error
 83  	if err != nil {
 84  		// Try YAML model
 85  		if yamlModel := s.GetYamlModel(modelID); yamlModel != nil {
 86  			return yamlModel, nil
 87  		}
 88  		return nil, err
 89  	}
 90  	return &model, nil
 91  }
 92  
 93  // GetModelByUser 获取用户创建的模型
 94  func (s *ModelStore) GetModelByUser(modelID string, username string) (*Model, error) {
 95  	var model Model
 96  	err := s.db.Preload("User").First(&model, "model_id = ? AND username = ?", modelID, username).Error
 97  	if err != nil {
 98  		return nil, err
 99  	}
100  	return &model, nil
101  }
102  
103  // GetAllModels 获取所有模型
104  func (s *ModelStore) GetAllModels() ([]*Model, error) {
105  	var models []*Model
106  	err := s.db.Preload("User").Order("created_at DESC").Find(&models).Error
107  	if err != nil {
108  		return nil, err
109  	}
110  	return models, nil
111  }
112  
113  // GetUserModels 获取用户的所有模型
114  func (s *ModelStore) GetUserModels(username string) ([]*Model, error) {
115  	var models []*Model
116  	err := s.db.Preload("User").Where("username = ? or username = '' or username = 'public_user'", username).Order("created_at DESC").Find(&models).Error
117  	if err != nil {
118  		return nil, err
119  	}
120  
121  	// Append YAML models
122  	yamlModels, _ := s.LoadYamlModels()
123  	if len(yamlModels) > 0 {
124  		models = append(models, yamlModels...)
125  	}
126  
127  	return models, nil
128  }
129  
130  // UpdateModel 更新模型信息
131  func (s *ModelStore) UpdateModel(modelID string, username string, updates map[string]interface{}) error {
132  	// 添加更新时间
133  	updates["updated_at"] = time.Now().UnixMilli()
134  	return s.db.Model(&Model{}).Where("model_id = ? AND username = ?", modelID, username).Updates(updates).Error
135  }
136  
137  // DeleteModel 删除模型
138  func (s *ModelStore) DeleteModel(modelID string, username string) error {
139  	return s.db.Delete(&Model{}, "model_id = ? AND username = ?", modelID, username).Error
140  }
141  
142  // BatchDeleteModels 批量删除模型
143  func (s *ModelStore) BatchDeleteModels(modelIDs []string, username string) (int64, error) {
144  	result := s.db.Delete(&Model{}, "model_id IN ? AND username = ?", modelIDs, username)
145  	return result.RowsAffected, result.Error
146  }
147  
148  // CheckModelExists 检查模型是否存在
149  func (s *ModelStore) CheckModelExists(modelID string) (bool, error) {
150  	var count int64
151  	err := s.db.Model(&Model{}).Where("model_id = ?", modelID).Count(&count).Error
152  	if count > 0 {
153  		return true, nil
154  	}
155  	// Check YAML
156  	if s.GetYamlModel(modelID) != nil {
157  		return true, nil
158  	}
159  	return false, err
160  }
161  
162  // CheckModelExistsByUser 检查用户是否拥有该模型
163  func (s *ModelStore) CheckModelExistsByUser(modelID string, username string) (bool, error) {
164  	var count int64
165  	err := s.db.Model(&Model{}).Where("model_id = ? AND username = ?", modelID, username).Count(&count).Error
166  	return count > 0, err
167  }
168  
169  func (s *ModelStore) AutoAddModels() {
170  	// 判断如果模型为空,并且环境变量存在 model token base_url,则自动添加模型
171  	if s.db == nil {
172  		return
173  	}
174  	var count int64
175  	s.db.Model(&Model{}).Count(&count)
176  	if count == 0 {
177  		model := os.Getenv("MODEL")
178  		token := os.Getenv("TOKEN")
179  		baseUrl := os.Getenv("BASE_URL")
180  		if model != "" && token != "" && baseUrl != "" {
181  			s.CreateModel(&Model{
182  				ModelID:   "system_default",
183  				Username:  "",
184  				ModelName: model,
185  				Token:     token,
186  				BaseURL:   baseUrl,
187  				Note:      "系统默认内置",
188  				CreatedAt: time.Now().UnixMilli(),
189  				UpdatedAt: time.Now().UnixMilli(),
190  			})
191  		}
192  	}
193  }