yaml_model.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package database 20 21 import ( 22 "os" 23 24 "github.com/Tencent/AI-Infra-Guard/internal/gologger" 25 "gopkg.in/yaml.v3" 26 ) 27 28 const YamlModelPath = "db/model.yaml" 29 30 // LoadYamlModels 加载YAML模型配置 31 func (s *ModelStore) LoadYamlModels() ([]*Model, error) { 32 data, err := os.ReadFile(YamlModelPath) 33 if err != nil { 34 if os.IsNotExist(err) { 35 return nil, nil 36 } 37 gologger.Errorf("读取模型配置文件失败: %v", err) 38 return nil, err 39 } 40 41 var models []*Model 42 if err := yaml.Unmarshal(data, &models); err != nil { 43 gologger.Errorf("解析模型配置文件失败: %v", err) 44 return nil, err 45 } 46 47 return models, nil 48 } 49 50 // GetYamlModel 获取指定的YAML模型 51 func (s *ModelStore) GetYamlModel(modelID string) *Model { 52 models, err := s.LoadYamlModels() 53 if err != nil { 54 return nil 55 } 56 for _, m := range models { 57 if m.ModelID == modelID { 58 return m 59 } 60 } 61 return nil 62 }