config.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package database 20 21 import ( 22 "fmt" 23 "os" 24 "path/filepath" 25 26 "github.com/Tencent/AI-Infra-Guard/internal/gologger" 27 28 "github.com/glebarez/sqlite" 29 "gorm.io/gorm" 30 ) 31 32 // Config 用于保存数据库配置 33 type Config struct { 34 DBPath string 35 } 36 37 // NewConfig 创建一个新的数据库配置 38 func NewConfig(dbPath string) *Config { 39 return &Config{DBPath: dbPath} 40 } 41 42 // LoadConfigFromEnv 从环境变量加载数据库配置 43 func LoadConfigFromEnv() *Config { 44 // 默认数据库路径 45 defaultDBPath := "db/tasks.db" 46 47 // 从环境变量读取数据库路径 48 if dbPath := os.Getenv("DB_PATH"); dbPath != "" { 49 defaultDBPath = dbPath 50 } 51 52 return &Config{DBPath: defaultDBPath} 53 } 54 55 // InitDB 用 GORM 初始化数据库连接并返回 *gorm.DB 56 func InitDB(config *Config) (*gorm.DB, error) { 57 // 确保数据库目录存在 58 dir := filepath.Dir(config.DBPath) 59 if err := os.MkdirAll(dir, 0755); err != nil { 60 return nil, fmt.Errorf("创建数据库目录失败: %v", err) 61 } 62 63 //打开数据库连接 - 启用WAL模式和共享缓存以支持并发访问 64 db, err := gorm.Open(sqlite.Open(config.DBPath+"?_journal=WAL&_timeout=5000&cache=shared"), &gorm.Config{}) 65 if err != nil { 66 gologger.WithError(err).Fatalln("无法打开数据库连接") 67 } 68 // 获取底层的SQL DB以配置连接池 69 sqlDB, err := db.DB() 70 if err != nil { 71 panic("failed to get database connection") 72 } 73 74 // 设置连接池参数 75 sqlDB.SetMaxIdleConns(1000) 76 sqlDB.SetMaxOpenConns(1000) 77 78 return db, nil 79 }