// model.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package database 20 21 import ( 22 "os" 23 "time" 24 25 "gorm.io/gorm" 26 ) 27 28 type ModelParams struct { 29 BaseUrl string `json:"base_url"` 30 Token string `json:"token"` 31 Model string `json:"model"` 32 Limit int `json:"limit"` 33 } 34 35 // Model 模型表 36 type Model struct { 37 ModelID string `gorm:"primaryKey;column:model_id" json:"model_id" yaml:"model_id"` // 模型ID 38 Username string `gorm:"column:username;not null" json:"username" yaml:"-"` // 创建者用户名 39 ModelName string `gorm:"column:model_name;not null" json:"model_name" yaml:"model_name"` // 模型名称 40 Token string `gorm:"column:token;not null" json:"token" yaml:"token"` // API Token 41 BaseURL string `gorm:"column:base_url;not null" json:"base_url" yaml:"base_url"` // 基础URL 42 Note string `gorm:"column:note" json:"note" yaml:"note,omitempty"` // 备注信息 43 Limit int `gorm:"column:limit" json:"limit" yaml:"limit,omitempty"` 44 Default []string `gorm:"-" json:"default,omitempty" yaml:"default,omitempty"` // 默认字段 45 CreatedAt int64 `gorm:"column:created_at;not null" json:"created_at" yaml:"-"` // 时间戳毫秒级 46 UpdatedAt int64 
`gorm:"column:updated_at;not null" json:"updated_at" yaml:"-"` // 时间戳毫秒级 47 48 // 关联关系 49 User User `gorm:"foreignKey:Username" json:"user" yaml:"-"` 50 } 51 52 // ModelStore 模型数据存储 53 type ModelStore struct { 54 db *gorm.DB 55 } 56 57 // NewModelStore 创建新的ModelStore实例 58 func NewModelStore(db *gorm.DB) *ModelStore { 59 return &ModelStore{db: db} 60 } 61 62 // Init 自动迁移模型相关表结构 63 func (s *ModelStore) Init() error { 64 if err := s.db.AutoMigrate(&Model{}); err != nil { 65 return err 66 } 67 // 创建索引优化查询 68 return s.db.Exec("CREATE INDEX IF NOT EXISTS idx_models_username_created ON models(username, created_at DESC)").Error 69 } 70 71 // CreateModel 创建模型 72 func (s *ModelStore) CreateModel(model *Model) error { 73 now := time.Now().UnixMilli() 74 model.CreatedAt = now 75 model.UpdatedAt = now 76 return s.db.Create(model).Error 77 } 78 79 // GetModel 获取模型信息 80 func (s *ModelStore) GetModel(modelID string) (*Model, error) { 81 var model Model 82 err := s.db.Preload("User").First(&model, "model_id = ?", modelID).Error 83 if err != nil { 84 // Try YAML model 85 if yamlModel := s.GetYamlModel(modelID); yamlModel != nil { 86 return yamlModel, nil 87 } 88 return nil, err 89 } 90 return &model, nil 91 } 92 93 // GetModelByUser 获取用户创建的模型 94 func (s *ModelStore) GetModelByUser(modelID string, username string) (*Model, error) { 95 var model Model 96 err := s.db.Preload("User").First(&model, "model_id = ? AND username = ?", modelID, username).Error 97 if err != nil { 98 return nil, err 99 } 100 return &model, nil 101 } 102 103 // GetAllModels 获取所有模型 104 func (s *ModelStore) GetAllModels() ([]*Model, error) { 105 var models []*Model 106 err := s.db.Preload("User").Order("created_at DESC").Find(&models).Error 107 if err != nil { 108 return nil, err 109 } 110 return models, nil 111 } 112 113 // GetUserModels 获取用户的所有模型 114 func (s *ModelStore) GetUserModels(username string) ([]*Model, error) { 115 var models []*Model 116 err := s.db.Preload("User").Where("username = ? 
or username = '' or username = 'public_user'", username).Order("created_at DESC").Find(&models).Error 117 if err != nil { 118 return nil, err 119 } 120 121 // Append YAML models 122 yamlModels, _ := s.LoadYamlModels() 123 if len(yamlModels) > 0 { 124 models = append(models, yamlModels...) 125 } 126 127 return models, nil 128 } 129 130 // UpdateModel 更新模型信息 131 func (s *ModelStore) UpdateModel(modelID string, username string, updates map[string]interface{}) error { 132 // 添加更新时间 133 updates["updated_at"] = time.Now().UnixMilli() 134 return s.db.Model(&Model{}).Where("model_id = ? AND username = ?", modelID, username).Updates(updates).Error 135 } 136 137 // DeleteModel 删除模型 138 func (s *ModelStore) DeleteModel(modelID string, username string) error { 139 return s.db.Delete(&Model{}, "model_id = ? AND username = ?", modelID, username).Error 140 } 141 142 // BatchDeleteModels 批量删除模型 143 func (s *ModelStore) BatchDeleteModels(modelIDs []string, username string) (int64, error) { 144 result := s.db.Delete(&Model{}, "model_id IN ? AND username = ?", modelIDs, username) 145 return result.RowsAffected, result.Error 146 } 147 148 // CheckModelExists 检查模型是否存在 149 func (s *ModelStore) CheckModelExists(modelID string) (bool, error) { 150 var count int64 151 err := s.db.Model(&Model{}).Where("model_id = ?", modelID).Count(&count).Error 152 if count > 0 { 153 return true, nil 154 } 155 // Check YAML 156 if s.GetYamlModel(modelID) != nil { 157 return true, nil 158 } 159 return false, err 160 } 161 162 // CheckModelExistsByUser 检查用户是否拥有该模型 163 func (s *ModelStore) CheckModelExistsByUser(modelID string, username string) (bool, error) { 164 var count int64 165 err := s.db.Model(&Model{}).Where("model_id = ? 
AND username = ?", modelID, username).Count(&count).Error 166 return count > 0, err 167 } 168 169 func (s *ModelStore) AutoAddModels() { 170 // 判断如果模型为空,并且环境变量存在 model token base_url,则自动添加模型 171 if s.db == nil { 172 return 173 } 174 var count int64 175 s.db.Model(&Model{}).Count(&count) 176 if count == 0 { 177 model := os.Getenv("MODEL") 178 token := os.Getenv("TOKEN") 179 baseUrl := os.Getenv("BASE_URL") 180 if model != "" && token != "" && baseUrl != "" { 181 s.CreateModel(&Model{ 182 ModelID: "system_default", 183 Username: "", 184 ModelName: model, 185 Token: token, 186 BaseURL: baseUrl, 187 Note: "系统默认内置", 188 CreatedAt: time.Now().UnixMilli(), 189 UpdatedAt: time.Now().UnixMilli(), 190 }) 191 } 192 } 193 }