/ pkg / database / config.go
config.go
 1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
 2  //
 3  // Licensed under the Apache License, Version 2.0 (the "License");
 4  // you may not use this file except in compliance with the License.
 5  // You may obtain a copy of the License at
 6  //
 7  //     http://www.apache.org/licenses/LICENSE-2.0
 8  //
 9  // Unless required by applicable law or agreed to in writing, software
10  // distributed under the License is distributed on an "AS IS" BASIS,
11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  // See the License for the specific language governing permissions and
13  // limitations under the License.
14  //
15  // Requirement: Any integration or derivative work must explicitly attribute
16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
17  // documentation or user interface, as detailed in the NOTICE file.
18  
19  package database
20  
21  import (
22  	"fmt"
23  	"os"
24  	"path/filepath"
25  
26  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
27  
28  	"github.com/glebarez/sqlite"
29  	"gorm.io/gorm"
30  )
31  
// Config holds the settings needed to open the task database.
type Config struct {
	// DBPath is the filesystem path of the SQLite database file.
	DBPath string
}
36  
37  // NewConfig 创建一个新的数据库配置
38  func NewConfig(dbPath string) *Config {
39  	return &Config{DBPath: dbPath}
40  }
41  
42  // LoadConfigFromEnv 从环境变量加载数据库配置
43  func LoadConfigFromEnv() *Config {
44  	// 默认数据库路径
45  	defaultDBPath := "db/tasks.db"
46  
47  	// 从环境变量读取数据库路径
48  	if dbPath := os.Getenv("DB_PATH"); dbPath != "" {
49  		defaultDBPath = dbPath
50  	}
51  
52  	return &Config{DBPath: defaultDBPath}
53  }
54  
55  // InitDB 用 GORM 初始化数据库连接并返回 *gorm.DB
56  func InitDB(config *Config) (*gorm.DB, error) {
57  	// 确保数据库目录存在
58  	dir := filepath.Dir(config.DBPath)
59  	if err := os.MkdirAll(dir, 0755); err != nil {
60  		return nil, fmt.Errorf("创建数据库目录失败: %v", err)
61  	}
62  
63  	//打开数据库连接 - 启用WAL模式和共享缓存以支持并发访问
64  	db, err := gorm.Open(sqlite.Open(config.DBPath+"?_journal=WAL&_timeout=5000&cache=shared"), &gorm.Config{})
65  	if err != nil {
66  		gologger.WithError(err).Fatalln("无法打开数据库连接")
67  	}
68  	// 获取底层的SQL DB以配置连接池
69  	sqlDB, err := db.DB()
70  	if err != nil {
71  		panic("failed to get database connection")
72  	}
73  
74  	// 设置连接池参数
75  	sqlDB.SetMaxIdleConns(1000)
76  	sqlDB.SetMaxOpenConns(1000)
77  
78  	return db, nil
79  }