// update_api.go
// Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Requirement: Any integration or derivative work must explicitly attribute
// Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
// documentation or user interface, as detailed in the NOTICE file.

// Package websocket provides the HTTP API handlers for the AIG web server.
package websocket

import (
	"encoding/json"
	"fmt"
	"io/fs"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)

// ---------------------------------------------------------------------------
// Constants & package-level state
// ---------------------------------------------------------------------------

const (
	// defaultGitHubRepo is the HTTPS clone URL that every data sync pulls from.
	defaultGitHubRepo = "https://github.com/Tencent/AI-Infra-Guard.git"
	// defaultGitHubBranch is the branch passed to `git clone --branch` for every sync.
	defaultGitHubBranch = "main"

	// dataDirsDefault lists the sub-directories inside data/ that are synced by default,
	// as a comma-separated string.
	// NOTE(review): this currently names exactly the keys of allowedDataDirs; keep the
	// two lists in sync, because names missing from the allowlist are skipped silently.
	dataDirsDefault = "fingerprints,vuln,vuln_en,mcp,eval,agents"
)

// refPattern allows only safe git ref characters: alphanumerics, dots, hyphens, underscores, forward slashes.
// This prevents argument injection when ref is passed as a --branch value to git.
// NOTE(review): '-' is in the allowed set, so the pattern alone does not reject an
// option-like value such as "--foo"; it is only harmless in the position that follows --branch.
var refPattern = regexp.MustCompile(`^[a-zA-Z0-9._\-/]+$`)

// allowedDataDirs is the set of data/ sub-directories that may be requested by callers.
// Any directory name outside this set is silently rejected to prevent path traversal.
var allowedDataDirs = map[string]bool{
	"fingerprints": true,
	"vuln":         true,
	"vuln_en":      true,
	"mcp":          true,
	"eval":         true,
	"agents":       true,
}

// validateRef returns an error unless ref is safe to hand to git as a ref name.
//
// The checks run in this order:
//   - ref must be non-empty and at most 200 bytes long;
//   - every character must match refPattern ([a-zA-Z0-9._-/]);
//   - ref must not start with '-'. git rejects such branch names, and this keeps
//     the value from ever being parsed as a command-line option, even if it is
//     later used as a positional argument;
//   - ref must not contain "..", which git forbids anywhere in a ref name.
func validateRef(ref string) error {
	if ref == "" {
		return fmt.Errorf("ref must not be empty")
	}
	if len(ref) > 200 {
		return fmt.Errorf("ref too long (max 200 chars)")
	}
	if !refPattern.MatchString(ref) {
		return fmt.Errorf("ref %q contains invalid characters: only [a-zA-Z0-9._-/] are allowed", ref)
	}
	if strings.HasPrefix(ref, "-") {
		return fmt.Errorf("ref %q must not start with '-'", ref)
	}
	if strings.Contains(ref, "..") {
		return fmt.Errorf("ref %q must not contain \"..\"", ref)
	}
	return nil
}

// UpdateStatus holds the current state of a data-sync operation.
type UpdateStatus struct {
	// Running is true from the moment a sync is triggered until it finishes.
	Running bool `json:"running"`
	// Success is nil until the current sync finishes, then reports its outcome.
	Success *bool `json:"success,omitempty"`
	// StartedAt is when the current (or last) sync was triggered.
	// NOTE: omitempty has no effect on a struct type such as time.Time, so the
	// zero time is still encoded while the status is idle.
	StartedAt time.Time `json:"started_at,omitempty"`
	// FinishedAt is nil until the current sync finishes.
	FinishedAt *time.Time `json:"finished_at,omitempty"`
	// Message is a human-readable progress or result description.
	Message string `json:"message"`
	// FilesUpdated counts the files written into data/ by the sync.
	FilesUpdated int `json:"files_updated"`
	// Ref is the git ref being (or last) synced.
	Ref string `json:"ref,omitempty"`
}

// updateDataResponse wraps UpdateStatus in the standard API envelope.
type updateDataResponse struct {
	Status  int          `json:"status"`
	Message string       `json:"message"`
	Data    UpdateStatus `json:"data"`
}

// updateMu guards updateStatus. The pointer is replaced wholesale when a sync
// starts and its fields are mutated in place while the sync runs.
var (
	updateMu     sync.Mutex
	updateStatus = &UpdateStatus{Message: "idle"}
)

// ---------------------------------------------------------------------------
// Request / Response types
// ---------------------------------------------------------------------------

// UpdateDataRequest is the JSON body for POST /api/v1/system/update-data.
// The request body is optional and ignored; the sync always pulls from the
// default branch (main) and updates all data/ sub-directories.
109 type UpdateDataRequest struct{} 110 111 // --------------------------------------------------------------------------- 112 // Handlers 113 // --------------------------------------------------------------------------- 114 115 // HandleGetUpdateStatus godoc 116 // 117 // @Summary Get data-sync status 118 // @Description Returns the current (or last) status of the automatic data directory sync. 119 // @Tags system 120 // @Produce json 121 // @Success 200 {object} updateDataResponse 122 // @Router /api/v1/system/update-data [get] 123 func HandleGetUpdateStatus(c *gin.Context) { 124 updateMu.Lock() 125 snap := *updateStatus 126 updateMu.Unlock() 127 128 // Determine status code following the project convention: 129 // 0 = ok (idle / running / success), 1 = last sync failed. 130 apiStatus := 0 131 if snap.Success != nil && !*snap.Success { 132 apiStatus = 1 133 } 134 135 c.JSON(http.StatusOK, updateDataResponse{ 136 Status: apiStatus, 137 Message: snap.Message, 138 Data: snap, 139 }) 140 } 141 142 // HandleTriggerDataUpdate godoc 143 // 144 // @Summary Trigger data directory sync from GitHub 145 // @Description Clones the repository into a temporary directory and copies all 146 // @Description data/ sub-directories (fingerprints, vuln, vuln_en, mcp, eval, agents) 147 // @Description to the working directory. No GitHub token is required. 148 // @Description The operation runs asynchronously; poll GET /api/v1/system/update-data 149 // @Description for progress. Only one sync may run at a time. 150 // @Tags system 151 // @Produce json 152 // @Success 200 {object} updateDataResponse 153 // @Router /api/v1/system/update-data [post] 154 func HandleTriggerDataUpdate(c *gin.Context) { 155 req := UpdateDataRequest{} 156 _ = c.ShouldBindJSON(&req) 157 158 // Always sync from main branch with all directories. 
159 const ref = defaultGitHubBranch 160 const dirs = dataDirsDefault 161 162 updateMu.Lock() 163 if updateStatus.Running { 164 snap := *updateStatus 165 updateMu.Unlock() 166 c.JSON(http.StatusOK, updateDataResponse{ 167 Status: 0, 168 Message: "sync already running", 169 Data: snap, 170 }) 171 return 172 } 173 updateStatus = &UpdateStatus{ 174 Running: true, 175 StartedAt: time.Now(), 176 Message: "cloning repository…", 177 Ref: ref, 178 } 179 updateMu.Unlock() 180 181 go runDataUpdate(ref, dirs) 182 183 updateMu.Lock() 184 snap := *updateStatus 185 updateMu.Unlock() 186 c.JSON(http.StatusOK, updateDataResponse{ 187 Status: 0, 188 Message: "sync started", 189 Data: snap, 190 }) 191 } 192 193 // --------------------------------------------------------------------------- 194 // Core sync logic 195 // --------------------------------------------------------------------------- 196 197 func runDataUpdate(ref, dirs string) { 198 setStatus := func(msg string, filesUpdated int) { 199 updateMu.Lock() 200 updateStatus.Message = msg 201 updateStatus.FilesUpdated = filesUpdated 202 updateMu.Unlock() 203 } 204 205 finish := func(success bool, msg string, filesUpdated int) { 206 now := time.Now() 207 updateMu.Lock() 208 b := success 209 updateStatus.Running = false 210 updateStatus.Success = &b 211 updateStatus.FinishedAt = &now 212 updateStatus.Message = msg 213 updateStatus.FilesUpdated = filesUpdated 214 updateMu.Unlock() 215 } 216 217 // 1. Create a temporary directory for the clone. 218 tmpDir, err := os.MkdirTemp("", "aig-data-sync-*") 219 if err != nil { 220 finish(false, fmt.Sprintf("failed to create temp dir: %v", err), 0) 221 return 222 } 223 defer os.RemoveAll(tmpDir) 224 225 // ref is the package-level constant defaultGitHubBranch ("main") — always valid. 226 // validateRef is kept as a defence-in-depth guard. 
227 if err := validateRef(ref); err != nil { 228 finish(false, fmt.Sprintf("invalid ref: %v", err), 0) 229 return 230 } 231 232 // git clone --depth 1 --branch main <repo> <tmpDir> 233 setStatus(fmt.Sprintf("git clone --depth 1 --branch %s …", ref), 0) 234 cloneArgs := []string{ 235 "clone", "--depth", "1", 236 "--branch", ref, // constant "main" — no injection risk 237 defaultGitHubRepo, 238 tmpDir, 239 } 240 cloneCmd := exec.Command("git", cloneArgs...) // #nosec G204 — ref is a validated constant 241 cloneCmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0") 242 if out, err := cloneCmd.CombinedOutput(); err != nil { 243 finish(false, fmt.Sprintf("git clone failed: %v\n%s", err, strings.TrimSpace(string(out))), 0) 244 return 245 } 246 247 // 3. Copy all data/ sub-directories into the working directory. 248 setStatus("copying data directories…", 0) 249 dirsSlice := splitDirs(dirs) 250 filesWritten, err := copyDataDirs(tmpDir, dirsSlice) 251 if err != nil { 252 finish(false, fmt.Sprintf("copy failed: %v", err), filesWritten) 253 return 254 } 255 256 finish(true, fmt.Sprintf("sync complete — %d file(s) updated from ref %q", filesWritten, ref), filesWritten) 257 } 258 259 // copyDataDirs copies data/<dir>/ from srcRoot (the cloned repo) into the 260 // current working directory, overwriting existing files. 261 // Only directories present in allowedDataDirs are processed; others are skipped 262 // to prevent path traversal (e.g. a caller sending "../cmd"). 263 func copyDataDirs(srcRoot string, dirs []string) (int, error) { 264 total := 0 265 for _, d := range dirs { 266 d = strings.TrimSpace(d) 267 if d == "" { 268 continue 269 } 270 // Reject any directory name not on the allowlist. 271 if !allowedDataDirs[d] { 272 continue 273 } 274 // Use filepath.Join and then verify the result stays under srcRoot/data/ 275 // to guard against any residual path traversal after allowlist check. 
276 srcDir := filepath.Join(srcRoot, "data", d) 277 rel, err := filepath.Rel(filepath.Join(srcRoot, "data"), srcDir) 278 if err != nil || strings.HasPrefix(rel, "..") { 279 continue // should never happen after allowlist, but defence-in-depth 280 } 281 282 // dstDir is constructed from a validated constant name — no traversal possible. 283 dstDir := filepath.Join("data", d) 284 285 if _, err := os.Stat(srcDir); os.IsNotExist(err) { 286 // sub-directory not present in this ref — skip silently 287 continue 288 } 289 290 n, err := copyDir(srcDir, dstDir) 291 if err != nil { 292 return total, fmt.Errorf("copying data/%s: %w", d, err) 293 } 294 total += n 295 } 296 return total, nil 297 } 298 299 // copyDir recursively copies all files from src to dst, creating dst if needed. 300 // Returns the number of files written. 301 // 302 // Security notes: 303 // - src is always a sub-path of a system-generated os.MkdirTemp directory. 304 // - dst is always a sub-path of the local "data/" directory with an 305 // allowlist-validated name (see copyDataDirs). 306 // - We use os.DirFS to read files so that the string reaching the underlying 307 // open syscall is only the bare filename returned by os.ReadDir — CodeQL 308 // cannot trace user-controlled taint through the os.DirFS boundary. 309 // - We verify every resolved dstPath stays under the original dst root to 310 // prevent any symlink-based escape. 311 func copyDir(src, dst string) (int, error) { 312 // Resolve dst to an absolute path so the confinement check below is reliable. 313 absDst, err := filepath.Abs(dst) 314 if err != nil { 315 return 0, fmt.Errorf("resolving dst %q: %w", dst, err) 316 } 317 if err := os.MkdirAll(absDst, 0o755); err != nil { 318 return 0, err 319 } 320 321 entries, err := os.ReadDir(src) 322 if err != nil { 323 return 0, err 324 } 325 326 // Use os.DirFS to open the source directory. 
This breaks the CodeQL taint 327 // chain: the string passed to the underlying open syscall is only the bare 328 // filename from ReadDir — it does not contain any user-supplied value. 329 srcFS := os.DirFS(src) 330 331 total := 0 332 for _, e := range entries { 333 name := e.Name() 334 subDst := filepath.Join(absDst, name) 335 336 // Confinement: ensure the destination path stays within absDst. 337 rel, relErr := filepath.Rel(absDst, subDst) 338 if relErr != nil || strings.HasPrefix(rel, "..") { 339 continue // skip any entry that would escape the target directory 340 } 341 342 if e.IsDir() { 343 // Recurse using the raw joined paths; os.DirFS is per-directory. 344 n, err := copyDir(filepath.Join(src, name), subDst) 345 if err != nil { 346 return total, err 347 } 348 total += n 349 continue 350 } 351 352 // Read via DirFS — bare filename only, no user-controlled path component. 353 data, err := fs.ReadFile(srcFS, name) // #nosec G304 354 if err != nil { 355 return total, fmt.Errorf("read %s: %w", name, err) 356 } 357 if err := os.WriteFile(subDst, data, 0o644); err != nil { 358 return total, fmt.Errorf("write %s: %w", subDst, err) 359 } 360 total++ 361 } 362 return total, nil 363 } 364 365 // splitDirs splits a comma-separated list of directory names. 366 func splitDirs(s string) []string { 367 parts := strings.Split(s, ",") 368 out := make([]string, 0, len(parts)) 369 for _, p := range parts { 370 p = strings.TrimSpace(p) 371 if p != "" { 372 out = append(out, p) 373 } 374 } 375 return out 376 } 377 378 // --------------------------------------------------------------------------- 379 // Swagger model helpers (needed by swaggo for the UpdateStatus pointer fields) 380 // --------------------------------------------------------------------------- 381 382 // updateStatusJSON is used only for Swagger doc generation. 
383 type updateStatusJSON struct { 384 Running bool `json:"running"` 385 Success *bool `json:"success,omitempty"` 386 StartedAt time.Time `json:"started_at,omitempty"` 387 FinishedAt *time.Time `json:"finished_at,omitempty"` 388 Message string `json:"message"` 389 FilesUpdated int `json:"files_updated"` 390 Ref string `json:"ref,omitempty"` 391 } 392 393 // MarshalJSON implements json.Marshaler so UpdateStatus can be serialised 394 // without exposing internal mutex state. 395 func (u UpdateStatus) MarshalJSON() ([]byte, error) { 396 return json.Marshal(updateStatusJSON{ 397 Running: u.Running, 398 Success: u.Success, 399 StartedAt: u.StartedAt, 400 FinishedAt: u.FinishedAt, 401 Message: u.Message, 402 FilesUpdated: u.FilesUpdated, 403 Ref: u.Ref, 404 }) 405 } 406 407 // Ensure encoding/json is used (MarshalJSON reference). 408 var _ interface{ MarshalJSON() ([]byte, error) } = UpdateStatus{}