/ common / websocket / model_api.go
model_api.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package websocket
 20  
import (
	"context"
	"net/http"
	"strings"
	"time"

	"github.com/Tencent/AI-Infra-Guard/common/utils/models"
	"github.com/Tencent/AI-Infra-Guard/pkg/database"
	"github.com/gin-gonic/gin"
	"trpc.group/trpc-go/trpc-go/log"
)
 32  
// ModelInfo is the model payload accepted when creating a model.
// Model, Token and BaseURL carry `binding:"required"` for gin's request
// validation; Limit and Note are optional.
type ModelInfo struct {
	Model   string `json:"model" binding:"required"`
	Token   string `json:"token" binding:"required"`
	BaseURL string `json:"base_url" binding:"required"`
	// Limit is optional; HandleCreateModel replaces a zero value with 1000.
	// NOTE(review): the unit/meaning of Limit is not visible here — confirm
	// against database.Model.
	Limit int    `json:"limit"`
	Note  string `json:"note"`
}

// CreateModelRequest is the JSON body of the create-model API: the new
// model's ID plus its fields nested under "model".
type CreateModelRequest struct {
	ModelID string    `json:"model_id" binding:"required"`
	Model   ModelInfo `json:"model" binding:"required"`
}

// UpdateModelInfo is the model payload accepted when updating a model.
// Token and BaseURL intentionally have no `binding:"required"` so a client can
// change only some fields (for example just the name). HandleUpdateModel keeps
// the stored token when Token is empty or equals the masked placeholder, and
// keeps the stored base URL when BaseURL is empty.
type UpdateModelInfo struct {
	Model   string `json:"model"`
	Token   string `json:"token"`
	BaseURL string `json:"base_url"`
	Limit   int    `json:"limit"`
	Note    string `json:"note"`
}

// UpdateModelRequest is the JSON body of the update-model API. The model ID is
// not part of the body; HandleUpdateModel reads it from the "modelId" path
// parameter.
type UpdateModelRequest struct {
	Model UpdateModelInfo `json:"model" binding:"required"`
}

// DeleteModelRequest is the JSON body of the delete-model API; ModelIDs may
// hold one ID (single delete) or several (batch delete).
type DeleteModelRequest struct {
	ModelIDs []string `json:"model_ids" binding:"required"`
}
 67  
// ModelManager bundles the dependencies used by the model HTTP handlers in
// this file; currently it only holds the model store.
type ModelManager struct {
	modelStore *database.ModelStore
}
 72  
 73  const maskedToken = "********"
 74  
 75  // maskToken 用于在对外返回模型信息时隐藏真实的 Token。
 76  // 仅用于 JSON 返回,不影响数据库中实际存储的 Token。
 77  func maskToken(token string) string {
 78  	if token == "" {
 79  		return ""
 80  	}
 81  	return maskedToken
 82  }
 83  
 84  // NewModelManager 创建新的ModelManager实例
 85  func NewModelManager(modelStore *database.ModelStore) *ModelManager {
 86  	return &ModelManager{
 87  		modelStore: modelStore,
 88  	}
 89  }
 90  
 91  // HandleGetModelList 获取模型列表接口
 92  func HandleGetModelList(c *gin.Context, mm *ModelManager) {
 93  	traceID := getTraceID(c)
 94  	username := c.GetString("username")
 95  
 96  	log.Debugf("用户请求获取模型列表: trace_id=%s, username=%s", traceID, username)
 97  
 98  	var userModels []*database.Model
 99  	var err error
100  
101  	userModels, err = mm.modelStore.GetUserModels(username)
102  	if err != nil {
103  		log.Errorf("获取用户模型列表失败: trace_id=%s, username=%s, error=%v", traceID, username, err)
104  		c.JSON(http.StatusOK, gin.H{
105  			"status":  1,
106  			"message": "获取模型列表失败: " + err.Error(),
107  			"data":    nil,
108  		})
109  		return
110  	}
111  
112  	// 转换为期望的返回格式
113  	var result []map[string]interface{}
114  
115  	for _, model := range userModels {
116  		item := map[string]interface{}{
117  			"model_id": model.ModelID,
118  			"model": map[string]interface{}{
119  				"model": model.ModelName,
120  				// 对外返回时也对用户模型的 token 进行掩码处理
121  				"token":    maskToken(model.Token),
122  				"base_url": model.BaseURL,
123  				"note":     model.Note,
124  				"limit":    model.Limit,
125  			},
126  		}
127  		if model.Default != nil {
128  			item["default"] = model.Default
129  		}
130  		result = append(result, item)
131  	}
132  
133  	log.Debugf("获取模型列表成功: trace_id=%s, username=%s, userModels=%d, publicModels=%d, total=%d",
134  		traceID, username, len(userModels), len(userModels), len(result))
135  
136  	c.JSON(http.StatusOK, gin.H{
137  		"status":  0,
138  		"message": "获取模型列表成功",
139  		"data":    result,
140  	})
141  }
142  
// HandleGetModelDetail handles the "get model detail" API.
//
// The model ID is read from the "modelId" path parameter and the caller from
// the "username" value in the gin context; only the model's owner may view it.
// On success "data" holds the model ID, the model fields (token masked) and
// the "default" value. Every response is HTTP 200 with "status" 0 on success
// and 1 on failure.
func HandleGetModelDetail(c *gin.Context, mm *ModelManager) {
	traceID := getTraceID(c)
	modelID := c.Param("modelId")
	username := c.GetString("username")

	// reject writes the standard failure envelope with the given message.
	reject := func(message string) {
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": message,
			"data":    nil,
		})
	}

	// Guard: a model ID is mandatory.
	if len(modelID) == 0 {
		log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username)
		reject("模型ID不能为空")
		return
	}

	log.Debugf("用户请求获取模型详情: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username)

	model, err := mm.modelStore.GetModel(modelID)
	if err != nil {
		log.Errorf("获取模型详情失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err)
		reject("模型不存在")
		return
	}

	// Ownership guard: only the creator may view the model.
	if owner := model.Username; owner != username {
		log.Errorf("无权限查看模型: trace_id=%s, modelID=%s, username=%s, owner=%s", traceID, modelID, username, owner)
		reject("无权限查看此模型")
		return
	}

	log.Debugf("获取模型详情成功: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username)

	fields := map[string]interface{}{
		"model": model.ModelName,
		// The real token is never returned; to change it the client has to
		// submit a new one.
		"token":    maskToken(model.Token),
		"base_url": model.BaseURL,
		"note":     model.Note,
		"limit":    model.Limit,
	}

	c.JSON(http.StatusOK, gin.H{
		"status":  0,
		"message": "获取模型详情成功",
		"data": map[string]interface{}{
			"model_id": model.ModelID,
			"model":    fields,
			"default":  model.Default,
		},
	})
}
207  
// HandleCreateModel handles the "create model" API.
//
// The flow is:
//   - Bind a CreateModelRequest from the JSON body.
//   - Require a model ID, model name, token and base URL.
//   - Reject model IDs that already exist.
//   - Validate the token/base URL by calling the model endpoint.
//   - Store the model under the "username" value from the gin context.
//
// A zero Limit is replaced with 1000. Every response is HTTP 200 with
// "status" 0 on success and 1 on failure.
func HandleCreateModel(c *gin.Context, mm *ModelManager) {
	// validateTimeout bounds the outbound validation call, whose target is a
	// user-supplied URL, so a slow or unresponsive endpoint cannot hold this
	// request open indefinitely.
	const validateTimeout = 60 * time.Second

	traceID := getTraceID(c)
	username := c.GetString("username")

	// 1. Bind the request body.
	var req CreateModelRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		log.Errorf("请求参数解析失败: trace_id=%s, username=%s, error=%v", traceID, username, err)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "请求参数错误: " + err.Error(),
			"data":    nil,
		})
		return
	}

	// 2. Required fields (defensive; the binding tags also enforce these).
	if req.ModelID == "" {
		log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "模型ID不能为空",
			"data":    nil,
		})
		return
	}

	if req.Model.Model == "" {
		log.Errorf("模型名称为空: trace_id=%s, username=%s", traceID, username)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "模型名称不能为空",
			"data":    nil,
		})
		return
	}

	if req.Model.Token == "" {
		log.Errorf("API Token为空: trace_id=%s, username=%s", traceID, username)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "API Token不能为空",
			"data":    nil,
		})
		return
	}

	if req.Model.BaseURL == "" {
		log.Errorf("基础URL为空: trace_id=%s, username=%s", traceID, username)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "基础URL不能为空",
			"data":    nil,
		})
		return
	}
	// A zero limit means "not provided"; fall back to the default.
	if req.Model.Limit == 0 {
		req.Model.Limit = 1000
	}

	log.Debugf("用户请求创建模型: trace_id=%s, modelID=%s, modelName=%s, username=%s", traceID, req.ModelID, req.Model.Model, username)

	// 3. Reject duplicate model IDs.
	exists, err := mm.modelStore.CheckModelExists(req.ModelID)
	if err != nil {
		log.Errorf("检查模型是否存在失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "检查模型失败: " + err.Error(),
			"data":    nil,
		})
		return
	}

	if exists {
		log.Errorf("模型已存在: trace_id=%s, modelID=%s, username=%s", traceID, req.ModelID, username)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "模型ID已存在",
			"data":    nil,
		})
		return
	}

	// Validate the token and base URL against the endpoint. The trailing
	// slash is added for validation only; the stored BaseURL is kept exactly
	// as submitted.
	ai := &models.OpenAI{
		Key:     req.Model.Token,
		Model:   req.Model.Model,
		BaseUrl: req.Model.BaseURL,
	}
	if !strings.HasSuffix(ai.BaseUrl, "/") {
		ai.BaseUrl += "/"
	}
	// Derive from the request context so validation is abandoned when the
	// client goes away, and cap it with validateTimeout.
	validateCtx, cancel := context.WithTimeout(c.Request.Context(), validateTimeout)
	defer cancel()
	err = ai.Vaild(validateCtx)
	if err != nil {
		log.Errorf("模型校验失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "模型校验失败: " + err.Error(),
			"data":    nil,
		})
		return
	}

	// 4. Persist the model.
	model := &database.Model{
		ModelID:   req.ModelID,
		Username:  username,
		ModelName: req.Model.Model,
		Token:     req.Model.Token,
		BaseURL:   req.Model.BaseURL,
		Note:      req.Model.Note,
		Limit:     req.Model.Limit,
	}

	err = mm.modelStore.CreateModel(model)
	if err != nil {
		log.Errorf("创建模型失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "创建模型失败: " + err.Error(),
			"data":    nil,
		})
		return
	}

	log.Debugf("创建模型成功: trace_id=%s, modelID=%s, modelName=%s, username=%s", traceID, req.ModelID, req.Model.Model, username)

	c.JSON(http.StatusOK, gin.H{
		"status":  0,
		"message": "模型创建成功",
		"data":    nil,
	})
}
342  
// HandleUpdateModel handles the "update model" API.
//
// The model ID comes from the "modelId" path parameter and the caller from the
// "username" value in the gin context; only the owner may update the model.
// Fields are applied as a partial update:
//   - note is always written (so it can be cleared);
//   - model_name is written only when non-empty;
//   - limit is written only when non-zero (zero means "not provided");
//   - token is written only when non-empty and not the masked placeholder;
//   - base_url is written only when non-empty.
//
// Every response is HTTP 200 with "status" 0 on success and 1 on failure.
func HandleUpdateModel(c *gin.Context, mm *ModelManager) {
	traceID := getTraceID(c)
	modelID := c.Param("modelId")
	username := c.GetString("username")

	// 1. Validate input.
	if modelID == "" {
		log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "模型ID不能为空",
			"data":    nil,
		})
		return
	}

	var req UpdateModelRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		log.Errorf("请求参数解析失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "请求参数错误: " + err.Error(),
			"data":    nil,
		})
		return
	}

	log.Infof("用户请求更新模型: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username)

	// 2. Ownership check: the model must exist and belong to the caller.
	exists, err := mm.modelStore.CheckModelExistsByUser(modelID, username)
	if err != nil {
		log.Errorf("检查模型权限失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "检查模型权限失败: " + err.Error(),
			"data":    nil,
		})
		return
	}

	if !exists {
		log.Errorf("模型不存在或无权限: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "模型不存在或无权限",
			"data":    nil,
		})
		return
	}

	// 3. Build the column updates. Fields the client left empty keep their
	// stored values so "rename only" / "edit note only" requests do not wipe
	// the rest of the record.
	updates := map[string]interface{}{
		"note": req.Model.Note,
	}
	// Creation requires a model name, so an empty one means "not provided";
	// writing "" would leave the record without a name.
	if req.Model.Model != "" {
		updates["model_name"] = req.Model.Model
	}
	// The request cannot distinguish an omitted limit from 0, and create
	// already treats 0 as "unset", so keep the stored limit instead of
	// overwriting it with 0.
	if req.Model.Limit != 0 {
		updates["limit"] = req.Model.Limit
	}
	// A token equal to the placeholder returned by the read APIs is treated as
	// unchanged, so the placeholder is never stored.
	if req.Model.Token != "" && req.Model.Token != maskedToken {
		updates["token"] = req.Model.Token
	}
	if req.Model.BaseURL != "" {
		updates["base_url"] = req.Model.BaseURL
	}

	err = mm.modelStore.UpdateModel(modelID, username, updates)
	if err != nil {
		log.Errorf("更新模型失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err)
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": "更新模型失败: " + err.Error(),
			"data":    nil,
		})
		return
	}

	log.Infof("更新模型成功: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username)

	c.JSON(http.StatusOK, gin.H{
		"status":  0,
		"message": "模型更新成功",
		"data":    nil,
	})
}
430  
// HandleDeleteModel handles the "delete models" API for one or many models.
//
// It binds a DeleteModelRequest and checks that every listed ID exists and
// belongs to the caller (the "username" value in the gin context). Only then
// does it delete them all in a single BatchDeleteModels call. If any ID fails
// the check, the handler responds with an error before deleting anything.
// Every response is HTTP 200 with "status" 0 on success and 1 on failure.
func HandleDeleteModel(c *gin.Context, mm *ModelManager) {
	traceID := getTraceID(c)
	username := c.GetString("username")

	// reject writes the standard failure envelope with the given message.
	reject := func(message string) {
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": message,
			"data":    nil,
		})
	}

	var req DeleteModelRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		log.Errorf("请求参数解析失败: trace_id=%s, username=%s, error=%v", traceID, username, bindErr)
		reject("请求参数错误: " + bindErr.Error())
		return
	}

	ids := req.ModelIDs
	if len(ids) == 0 {
		log.Errorf("模型ID列表为空: trace_id=%s, username=%s", traceID, username)
		reject("模型ID列表不能为空")
		return
	}

	log.Infof("用户请求删除模型: trace_id=%s, modelIDs=%v, username=%s", traceID, ids, username)

	// Verify ownership of every ID before anything is deleted.
	for i := range ids {
		id := ids[i]
		owned, checkErr := mm.modelStore.CheckModelExistsByUser(id, username)
		switch {
		case checkErr != nil:
			log.Errorf("检查模型权限失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, id, username, checkErr)
			reject("检查模型权限失败: " + checkErr.Error())
			return
		case !owned:
			log.Errorf("模型不存在或无权限: trace_id=%s, modelID=%s, username=%s", traceID, id, username)
			reject("模型不存在或无权限")
			return
		}
	}

	deletedCount, delErr := mm.modelStore.BatchDeleteModels(ids, username)
	if delErr != nil {
		log.Errorf("删除模型失败: trace_id=%s, modelIDs=%v, username=%s, error=%v", traceID, ids, username, delErr)
		reject("删除模型失败: " + delErr.Error())
		return
	}

	log.Infof("删除模型成功: trace_id=%s, modelIDs=%v, username=%s, deletedCount=%d", traceID, ids, username, deletedCount)

	c.JSON(http.StatusOK, gin.H{
		"status":  0,
		"message": "删除成功",
		"data":    nil,
	})
}