model_api.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package websocket 20 21 import ( 22 "context" 23 "net/http" 24 "strings" 25 26 "github.com/Tencent/AI-Infra-Guard/common/utils/models" 27 28 "github.com/Tencent/AI-Infra-Guard/pkg/database" 29 "github.com/gin-gonic/gin" 30 "trpc.group/trpc-go/trpc-go/log" 31 ) 32 33 // ModelInfo 模型信息(用于创建) 34 type ModelInfo struct { 35 Model string `json:"model" binding:"required"` 36 Token string `json:"token" binding:"required"` 37 BaseURL string `json:"base_url" binding:"required"` 38 Limit int `json:"limit"` 39 Note string `json:"note"` 40 } 41 42 // CreateModelRequest 创建模型请求 43 type CreateModelRequest struct { 44 ModelID string `json:"model_id" binding:"required"` 45 Model ModelInfo `json:"model" binding:"required"` 46 } 47 48 // UpdateModelInfo 模型信息(用于更新) 49 // 这里不对 Token/BaseURL 使用 binding:"required",以支持“只改名称等字段”的场景。 50 type UpdateModelInfo struct { 51 Model string `json:"model"` 52 Token string `json:"token"` 53 BaseURL string `json:"base_url"` 54 Limit int `json:"limit"` 55 Note string `json:"note"` 56 } 57 58 // UpdateModelRequest 更新模型请求 59 type UpdateModelRequest struct { 60 Model UpdateModelInfo `json:"model" binding:"required"` 61 } 62 63 // DeleteModelRequest 删除模型请求 64 type DeleteModelRequest struct { 65 ModelIDs []string `json:"model_ids" binding:"required"` 66 } 67 68 // ModelManager 模型管理器 69 type ModelManager struct { 70 modelStore *database.ModelStore 71 } 72 73 const maskedToken = "********" 74 75 // maskToken 用于在对外返回模型信息时隐藏真实的 Token。 76 // 仅用于 JSON 返回,不影响数据库中实际存储的 Token。 77 func maskToken(token string) string { 78 if token == "" { 79 return "" 80 } 81 return maskedToken 82 } 83 84 // NewModelManager 创建新的ModelManager实例 85 func NewModelManager(modelStore *database.ModelStore) *ModelManager { 86 return &ModelManager{ 87 modelStore: modelStore, 88 } 89 } 90 91 // HandleGetModelList 获取模型列表接口 92 func HandleGetModelList(c *gin.Context, mm *ModelManager) { 93 traceID := getTraceID(c) 94 username := c.GetString("username") 95 96 log.Debugf("用户请求获取模型列表: trace_id=%s, username=%s", traceID, username) 97 98 var userModels []*database.Model 99 var err error 100 101 userModels, err = mm.modelStore.GetUserModels(username) 102 if err != nil { 103 log.Errorf("获取用户模型列表失败: trace_id=%s, username=%s, error=%v", traceID, username, err) 104 c.JSON(http.StatusOK, gin.H{ 105 "status": 1, 106 "message": "获取模型列表失败: " + err.Error(), 107 "data": nil, 108 }) 109 return 110 } 111 112 // 转换为期望的返回格式 113 var result []map[string]interface{} 114 115 for _, model := range userModels { 116 item := map[string]interface{}{ 117 "model_id": model.ModelID, 118 "model": map[string]interface{}{ 119 "model": model.ModelName, 120 // 对外返回时也对用户模型的 token 进行掩码处理 121 "token": maskToken(model.Token), 122 "base_url": model.BaseURL, 123 "note": model.Note, 124 "limit": model.Limit, 125 }, 126 } 127 if model.Default != nil { 128 item["default"] = model.Default 129 } 130 result = append(result, item) 131 } 132 133 log.Debugf("获取模型列表成功: trace_id=%s, username=%s, userModels=%d, publicModels=%d, total=%d", 134 traceID, username, len(userModels), len(userModels), len(result)) 135 136 c.JSON(http.StatusOK, gin.H{ 137 "status": 0, 138 "message": "获取模型列表成功", 139 "data": result, 140 }) 141 } 142 143 // HandleGetModelDetail 获取模型详情接口 144 func HandleGetModelDetail(c *gin.Context, mm *ModelManager) { 145 traceID := getTraceID(c) 146 modelID := c.Param("modelId") 147 username := c.GetString("username") 148 149 // 1. 字段校验 150 if modelID == "" { 151 log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username) 152 c.JSON(http.StatusOK, gin.H{ 153 "status": 1, 154 "message": "模型ID不能为空", 155 "data": nil, 156 }) 157 return 158 } 159 160 log.Debugf("用户请求获取模型详情: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) 161 162 // 2. 获取模型信息 163 model, err := mm.modelStore.GetModel(modelID) 164 if err != nil { 165 log.Errorf("获取模型详情失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) 166 c.JSON(http.StatusOK, gin.H{ 167 "status": 1, 168 "message": "模型不存在", 169 "data": nil, 170 }) 171 return 172 } 173 174 // 3. 身份校验(只有创建者可以查看) 175 if model.Username != username { 176 log.Errorf("无权限查看模型: trace_id=%s, modelID=%s, username=%s, owner=%s", traceID, modelID, username, model.Username) 177 c.JSON(http.StatusOK, gin.H{ 178 "status": 1, 179 "message": "无权限查看此模型", 180 "data": nil, 181 }) 182 return 183 } 184 185 log.Debugf("获取模型详情成功: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) 186 187 // 转换为期望的返回格式 188 result := map[string]interface{}{ 189 "model_id": model.ModelID, 190 "model": map[string]interface{}{ 191 "model": model.ModelName, 192 // 对外隐藏真实 token,前端如需修改,只能输入新 token 193 "token": maskToken(model.Token), 194 "base_url": model.BaseURL, 195 "note": model.Note, 196 "limit": model.Limit, 197 }, 198 "default": model.Default, 199 } 200 201 c.JSON(http.StatusOK, gin.H{ 202 "status": 0, 203 "message": "获取模型详情成功", 204 "data": result, 205 }) 206 } 207 208 // HandleCreateModel 创建模型接口 209 func HandleCreateModel(c *gin.Context, mm *ModelManager) { 210 traceID := getTraceID(c) 211 username := c.GetString("username") 212 213 // 1. 字段校验 214 var req CreateModelRequest 215 if err := c.ShouldBindJSON(&req); err != nil { 216 log.Errorf("请求参数解析失败: trace_id=%s, username=%s, error=%v", traceID, username, err) 217 c.JSON(http.StatusOK, gin.H{ 218 "status": 1, 219 "message": "请求参数错误: " + err.Error(), 220 "data": nil, 221 }) 222 return 223 } 224 225 // 2. 验证必填字段 226 if req.ModelID == "" { 227 log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username) 228 c.JSON(http.StatusOK, gin.H{ 229 "status": 1, 230 "message": "模型ID不能为空", 231 "data": nil, 232 }) 233 return 234 } 235 236 if req.Model.Model == "" { 237 log.Errorf("模型名称为空: trace_id=%s, username=%s", traceID, username) 238 c.JSON(http.StatusOK, gin.H{ 239 "status": 1, 240 "message": "模型名称不能为空", 241 "data": nil, 242 }) 243 return 244 } 245 246 if req.Model.Token == "" { 247 log.Errorf("API Token为空: trace_id=%s, username=%s", traceID, username) 248 c.JSON(http.StatusOK, gin.H{ 249 "status": 1, 250 "message": "API Token不能为空", 251 "data": nil, 252 }) 253 return 254 } 255 256 if req.Model.BaseURL == "" { 257 log.Errorf("基础URL为空: trace_id=%s, username=%s", traceID, username) 258 c.JSON(http.StatusOK, gin.H{ 259 "status": 1, 260 "message": "基础URL不能为空", 261 "data": nil, 262 }) 263 return 264 } 265 if req.Model.Limit == 0 { 266 req.Model.Limit = 1000 267 } 268 269 log.Debugf("用户请求创建模型: trace_id=%s, modelID=%s, modelName=%s, username=%s", traceID, req.ModelID, req.Model.Model, username) 270 271 // 3. 检查模型是否已存在 272 exists, err := mm.modelStore.CheckModelExists(req.ModelID) 273 if err != nil { 274 log.Errorf("检查模型是否存在失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err) 275 c.JSON(http.StatusOK, gin.H{ 276 "status": 1, 277 "message": "检查模型失败: " + err.Error(), 278 "data": nil, 279 }) 280 return 281 } 282 283 if exists { 284 log.Errorf("模型已存在: trace_id=%s, modelID=%s, username=%s", traceID, req.ModelID, username) 285 c.JSON(http.StatusOK, gin.H{ 286 "status": 1, 287 "message": "模型ID已存在", 288 "data": nil, 289 }) 290 return 291 } 292 // 校验模型 token base_url 293 ai := &models.OpenAI{ 294 Key: req.Model.Token, 295 Model: req.Model.Model, 296 BaseUrl: req.Model.BaseURL, 297 } 298 if !strings.HasSuffix(ai.BaseUrl, "/") { 299 ai.BaseUrl += "/" 300 } 301 err = ai.Vaild(context.Background()) 302 if err != nil { 303 log.Errorf("模型校验失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err) 304 c.JSON(http.StatusOK, gin.H{ 305 "status": 1, 306 "message": "模型校验失败: " + err.Error(), 307 "data": nil, 308 }) 309 return 310 } 311 312 // 4. 创建模型 313 model := &database.Model{ 314 ModelID: req.ModelID, 315 Username: username, 316 ModelName: req.Model.Model, 317 Token: req.Model.Token, 318 BaseURL: req.Model.BaseURL, 319 Note: req.Model.Note, 320 Limit: req.Model.Limit, 321 } 322 323 err = mm.modelStore.CreateModel(model) 324 if err != nil { 325 log.Errorf("创建模型失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err) 326 c.JSON(http.StatusOK, gin.H{ 327 "status": 1, 328 "message": "创建模型失败: " + err.Error(), 329 "data": nil, 330 }) 331 return 332 } 333 334 log.Debugf("创建模型成功: trace_id=%s, modelID=%s, modelName=%s, username=%s", traceID, req.ModelID, req.Model.Model, username) 335 336 c.JSON(http.StatusOK, gin.H{ 337 "status": 0, 338 "message": "模型创建成功", 339 "data": nil, 340 }) 341 } 342 343 // HandleUpdateModel 更新模型接口 344 func HandleUpdateModel(c *gin.Context, mm *ModelManager) { 345 traceID := getTraceID(c) 346 modelID := c.Param("modelId") 347 username := c.GetString("username") 348 349 // 1. 字段校验 350 if modelID == "" { 351 log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username) 352 c.JSON(http.StatusOK, gin.H{ 353 "status": 1, 354 "message": "模型ID不能为空", 355 "data": nil, 356 }) 357 return 358 } 359 360 var req UpdateModelRequest 361 if err := c.ShouldBindJSON(&req); err != nil { 362 log.Errorf("请求参数解析失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) 363 c.JSON(http.StatusOK, gin.H{ 364 "status": 1, 365 "message": "请求参数错误: " + err.Error(), 366 "data": nil, 367 }) 368 return 369 } 370 371 log.Infof("用户请求更新模型: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) 372 373 // 2. 身份校验(检查模型是否存在且属于该用户) 374 exists, err := mm.modelStore.CheckModelExistsByUser(modelID, username) 375 if err != nil { 376 log.Errorf("检查模型权限失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) 377 c.JSON(http.StatusOK, gin.H{ 378 "status": 1, 379 "message": "检查模型权限失败: " + err.Error(), 380 "data": nil, 381 }) 382 return 383 } 384 385 if !exists { 386 log.Errorf("模型不存在或无权限: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) 387 c.JSON(http.StatusOK, gin.H{ 388 "status": 1, 389 "message": "模型不存在或无权限", 390 "data": nil, 391 }) 392 return 393 } 394 395 // 3. 构造更新字段 396 // 支持“只改模型名称/备注,不改 key/base_url”的场景: 397 // - 前端在编辑时如果不填写 token/base_url,则保持数据库中的原值不变; 398 // - 只有在显式传入新 token/base_url 且不等于掩码串时才会更新。 399 updates := map[string]interface{}{ 400 "model_name": req.Model.Model, 401 "note": req.Model.Note, 402 "limit": req.Model.Limit, 403 } 404 if req.Model.Token != "" && req.Model.Token != maskedToken { 405 updates["token"] = req.Model.Token 406 } 407 if req.Model.BaseURL != "" { 408 updates["base_url"] = req.Model.BaseURL 409 } 410 411 err = mm.modelStore.UpdateModel(modelID, username, updates) 412 if err != nil { 413 log.Errorf("更新模型失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) 414 c.JSON(http.StatusOK, gin.H{ 415 "status": 1, 416 "message": "更新模型失败: " + err.Error(), 417 "data": nil, 418 }) 419 return 420 } 421 422 log.Infof("更新模型成功: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) 423 424 c.JSON(http.StatusOK, gin.H{ 425 "status": 0, 426 "message": "模型更新成功", 427 "data": nil, 428 }) 429 } 430 431 // HandleDeleteModel 删除模型接口(支持单个和批量) 432 func HandleDeleteModel(c *gin.Context, mm *ModelManager) { 433 traceID := getTraceID(c) 434 username := c.GetString("username") 435 436 // 1. 字段校验 437 var req DeleteModelRequest 438 if err := c.ShouldBindJSON(&req); err != nil { 439 log.Errorf("请求参数解析失败: trace_id=%s, username=%s, error=%v", traceID, username, err) 440 c.JSON(http.StatusOK, gin.H{ 441 "status": 1, 442 "message": "请求参数错误: " + err.Error(), 443 "data": nil, 444 }) 445 return 446 } 447 448 if len(req.ModelIDs) == 0 { 449 log.Errorf("模型ID列表为空: trace_id=%s, username=%s", traceID, username) 450 c.JSON(http.StatusOK, gin.H{ 451 "status": 1, 452 "message": "模型ID列表不能为空", 453 "data": nil, 454 }) 455 return 456 } 457 458 log.Infof("用户请求删除模型: trace_id=%s, modelIDs=%v, username=%s", traceID, req.ModelIDs, username) 459 460 // 2. 身份校验(检查所有模型是否属于该用户) 461 for _, modelID := range req.ModelIDs { 462 exists, err := mm.modelStore.CheckModelExistsByUser(modelID, username) 463 if err != nil { 464 log.Errorf("检查模型权限失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) 465 c.JSON(http.StatusOK, gin.H{ 466 "status": 1, 467 "message": "检查模型权限失败: " + err.Error(), 468 "data": nil, 469 }) 470 return 471 } 472 473 if !exists { 474 log.Errorf("模型不存在或无权限: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) 475 c.JSON(http.StatusOK, gin.H{ 476 "status": 1, 477 "message": "模型不存在或无权限", 478 "data": nil, 479 }) 480 return 481 } 482 } 483 484 // 3. 批量删除模型 485 deletedCount, err := mm.modelStore.BatchDeleteModels(req.ModelIDs, username) 486 if err != nil { 487 log.Errorf("删除模型失败: trace_id=%s, modelIDs=%v, username=%s, error=%v", traceID, req.ModelIDs, username, err) 488 c.JSON(http.StatusOK, gin.H{ 489 "status": 1, 490 "message": "删除模型失败: " + err.Error(), 491 "data": nil, 492 }) 493 return 494 } 495 496 log.Infof("删除模型成功: trace_id=%s, modelIDs=%v, username=%s, deletedCount=%d", traceID, req.ModelIDs, username, deletedCount) 497 498 c.JSON(http.StatusOK, gin.H{ 499 "status": 0, 500 "message": "删除成功", 501 "data": nil, 502 }) 503 }