/ common / websocket / update_api.go
update_api.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  // Package websocket provides the HTTP API handlers for the AIG web server.
 20  package websocket
 21  
import (
	"context"
	"encoding/json"
	"fmt"
	"io/fs"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)
 37  
 38  // ---------------------------------------------------------------------------
 39  // Constants & package-level state
 40  // ---------------------------------------------------------------------------
 41  
 42  const (
 43  	defaultGitHubRepo   = "https://github.com/Tencent/AI-Infra-Guard.git"
 44  	defaultGitHubBranch = "main"
 45  
 46  	// dataDirsDefault lists the sub-directories inside data/ that are synced by default.
 47  	dataDirsDefault = "fingerprints,vuln,vuln_en,mcp,eval,agents"
 48  )
 49  
 50  // refPattern allows only safe git ref characters: alphanumerics, dots, hyphens, underscores, forward slashes.
 51  // This prevents argument injection when ref is passed as a --branch value to git.
 52  var refPattern = regexp.MustCompile(`^[a-zA-Z0-9._\-/]+$`)
 53  
 54  // allowedDataDirs is the set of data/ sub-directories that may be requested by callers.
 55  // Any directory name outside this set is silently rejected to prevent path traversal.
 56  var allowedDataDirs = map[string]bool{
 57  	"fingerprints": true,
 58  	"vuln":         true,
 59  	"vuln_en":      true,
 60  	"mcp":          true,
 61  	"eval":         true,
 62  	"agents":       true,
 63  }
 64  
 65  // validateRef returns an error if ref contains characters outside the safe allowlist.
 66  func validateRef(ref string) error {
 67  	if ref == "" {
 68  		return fmt.Errorf("ref must not be empty")
 69  	}
 70  	if len(ref) > 200 {
 71  		return fmt.Errorf("ref too long (max 200 chars)")
 72  	}
 73  	if !refPattern.MatchString(ref) {
 74  		return fmt.Errorf("ref %q contains invalid characters: only [a-zA-Z0-9._-/] are allowed", ref)
 75  	}
 76  	return nil
 77  }
 78  
// UpdateStatus holds the current state of a data-sync operation.
//
// The live instance is reached through the package-level updateStatus pointer,
// and every access to it in this file happens while holding updateMu. The
// pointer fields (Success, FinishedAt) are always replaced with fresh pointers
// rather than mutated in place, so a struct copy taken under the lock is a
// stable snapshot.
//
// JSON encoding goes through the UpdateStatus.MarshalJSON method in this file.
type UpdateStatus struct {
	Running bool `json:"running"`
	// Success is nil until the most recently started sync finishes.
	Success   *bool     `json:"success,omitempty"`
	StartedAt time.Time `json:"started_at,omitempty"`
	// FinishedAt is nil while a sync is running or before any sync has run.
	FinishedAt   *time.Time `json:"finished_at,omitempty"`
	Message      string     `json:"message"`
	FilesUpdated int        `json:"files_updated"`
	Ref          string     `json:"ref,omitempty"`
}
 89  
// updateDataResponse wraps UpdateStatus in the standard API envelope returned
// by both update-data handlers.
//
// Status is 0 when everything is fine (idle, running or succeeded). It is 1
// when the last sync failed; only HandleGetUpdateStatus reports 1. Message is
// a human-readable summary, and Data is a snapshot of the sync state.
type updateDataResponse struct {
	Status  int          `json:"status"`
	Message string       `json:"message"`
	Data    UpdateStatus `json:"data"`
}
 96  
 97  var (
 98  	updateMu     sync.Mutex
 99  	updateStatus = &UpdateStatus{Message: "idle"}
100  )
101  
102  // ---------------------------------------------------------------------------
103  // Request / Response types
104  // ---------------------------------------------------------------------------
105  
// UpdateDataRequest is the JSON body for POST /api/v1/system/update-data.
// It deliberately has no fields. The body is optional and its content is
// ignored, because every sync pulls the default branch (main) and updates all
// data/ sub-directories.
type UpdateDataRequest struct{}
110  
111  // ---------------------------------------------------------------------------
112  // Handlers
113  // ---------------------------------------------------------------------------
114  
115  // HandleGetUpdateStatus godoc
116  //
117  //	@Summary		Get data-sync status
118  //	@Description	Returns the current (or last) status of the automatic data directory sync.
119  //	@Tags			system
120  //	@Produce		json
121  //	@Success		200	{object}	updateDataResponse
122  //	@Router			/api/v1/system/update-data [get]
123  func HandleGetUpdateStatus(c *gin.Context) {
124  	updateMu.Lock()
125  	snap := *updateStatus
126  	updateMu.Unlock()
127  
128  	// Determine status code following the project convention:
129  	// 0 = ok (idle / running / success), 1 = last sync failed.
130  	apiStatus := 0
131  	if snap.Success != nil && !*snap.Success {
132  		apiStatus = 1
133  	}
134  
135  	c.JSON(http.StatusOK, updateDataResponse{
136  		Status:  apiStatus,
137  		Message: snap.Message,
138  		Data:    snap,
139  	})
140  }
141  
142  // HandleTriggerDataUpdate godoc
143  //
144  //	@Summary		Trigger data directory sync from GitHub
145  //	@Description	Clones the repository into a temporary directory and copies all
146  //	@Description	data/ sub-directories (fingerprints, vuln, vuln_en, mcp, eval, agents)
147  //	@Description	to the working directory. No GitHub token is required.
148  //	@Description	The operation runs asynchronously; poll GET /api/v1/system/update-data
149  //	@Description	for progress. Only one sync may run at a time.
150  //	@Tags			system
151  //	@Produce		json
152  //	@Success		200	{object}	updateDataResponse
153  //	@Router			/api/v1/system/update-data [post]
154  func HandleTriggerDataUpdate(c *gin.Context) {
155  	req := UpdateDataRequest{}
156  	_ = c.ShouldBindJSON(&req)
157  
158  	// Always sync from main branch with all directories.
159  	const ref = defaultGitHubBranch
160  	const dirs = dataDirsDefault
161  
162  	updateMu.Lock()
163  	if updateStatus.Running {
164  		snap := *updateStatus
165  		updateMu.Unlock()
166  		c.JSON(http.StatusOK, updateDataResponse{
167  			Status:  0,
168  			Message: "sync already running",
169  			Data:    snap,
170  		})
171  		return
172  	}
173  	updateStatus = &UpdateStatus{
174  		Running:   true,
175  		StartedAt: time.Now(),
176  		Message:   "cloning repository…",
177  		Ref:       ref,
178  	}
179  	updateMu.Unlock()
180  
181  	go runDataUpdate(ref, dirs)
182  
183  	updateMu.Lock()
184  	snap := *updateStatus
185  	updateMu.Unlock()
186  	c.JSON(http.StatusOK, updateDataResponse{
187  		Status:  0,
188  		Message: "sync started",
189  		Data:    snap,
190  	})
191  }
192  
193  // ---------------------------------------------------------------------------
194  // Core sync logic
195  // ---------------------------------------------------------------------------
196  
197  func runDataUpdate(ref, dirs string) {
198  	setStatus := func(msg string, filesUpdated int) {
199  		updateMu.Lock()
200  		updateStatus.Message = msg
201  		updateStatus.FilesUpdated = filesUpdated
202  		updateMu.Unlock()
203  	}
204  
205  	finish := func(success bool, msg string, filesUpdated int) {
206  		now := time.Now()
207  		updateMu.Lock()
208  		b := success
209  		updateStatus.Running = false
210  		updateStatus.Success = &b
211  		updateStatus.FinishedAt = &now
212  		updateStatus.Message = msg
213  		updateStatus.FilesUpdated = filesUpdated
214  		updateMu.Unlock()
215  	}
216  
217  	// 1. Create a temporary directory for the clone.
218  	tmpDir, err := os.MkdirTemp("", "aig-data-sync-*")
219  	if err != nil {
220  		finish(false, fmt.Sprintf("failed to create temp dir: %v", err), 0)
221  		return
222  	}
223  	defer os.RemoveAll(tmpDir)
224  
225  	// ref is the package-level constant defaultGitHubBranch ("main") — always valid.
226  	// validateRef is kept as a defence-in-depth guard.
227  	if err := validateRef(ref); err != nil {
228  		finish(false, fmt.Sprintf("invalid ref: %v", err), 0)
229  		return
230  	}
231  
232  	// git clone --depth 1 --branch main <repo> <tmpDir>
233  	setStatus(fmt.Sprintf("git clone --depth 1 --branch %s …", ref), 0)
234  	cloneArgs := []string{
235  		"clone", "--depth", "1",
236  		"--branch", ref, // constant "main" — no injection risk
237  		defaultGitHubRepo,
238  		tmpDir,
239  	}
240  	cloneCmd := exec.Command("git", cloneArgs...) // #nosec G204 — ref is a validated constant
241  	cloneCmd.Env = append(os.Environ(), "GIT_TERMINAL_PROMPT=0")
242  	if out, err := cloneCmd.CombinedOutput(); err != nil {
243  		finish(false, fmt.Sprintf("git clone failed: %v\n%s", err, strings.TrimSpace(string(out))), 0)
244  		return
245  	}
246  
247  	// 3. Copy all data/ sub-directories into the working directory.
248  	setStatus("copying data directories…", 0)
249  	dirsSlice := splitDirs(dirs)
250  	filesWritten, err := copyDataDirs(tmpDir, dirsSlice)
251  	if err != nil {
252  		finish(false, fmt.Sprintf("copy failed: %v", err), filesWritten)
253  		return
254  	}
255  
256  	finish(true, fmt.Sprintf("sync complete — %d file(s) updated from ref %q", filesWritten, ref), filesWritten)
257  }
258  
259  // copyDataDirs copies data/<dir>/ from srcRoot (the cloned repo) into the
260  // current working directory, overwriting existing files.
261  // Only directories present in allowedDataDirs are processed; others are skipped
262  // to prevent path traversal (e.g. a caller sending "../cmd").
263  func copyDataDirs(srcRoot string, dirs []string) (int, error) {
264  	total := 0
265  	for _, d := range dirs {
266  		d = strings.TrimSpace(d)
267  		if d == "" {
268  			continue
269  		}
270  		// Reject any directory name not on the allowlist.
271  		if !allowedDataDirs[d] {
272  			continue
273  		}
274  		// Use filepath.Join and then verify the result stays under srcRoot/data/
275  		// to guard against any residual path traversal after allowlist check.
276  		srcDir := filepath.Join(srcRoot, "data", d)
277  		rel, err := filepath.Rel(filepath.Join(srcRoot, "data"), srcDir)
278  		if err != nil || strings.HasPrefix(rel, "..") {
279  			continue // should never happen after allowlist, but defence-in-depth
280  		}
281  
282  		// dstDir is constructed from a validated constant name — no traversal possible.
283  		dstDir := filepath.Join("data", d)
284  
285  		if _, err := os.Stat(srcDir); os.IsNotExist(err) {
286  			// sub-directory not present in this ref — skip silently
287  			continue
288  		}
289  
290  		n, err := copyDir(srcDir, dstDir)
291  		if err != nil {
292  			return total, fmt.Errorf("copying data/%s: %w", d, err)
293  		}
294  		total += n
295  	}
296  	return total, nil
297  }
298  
// copyDir recursively copies the directory tree at src into dst. It creates
// dst (and any sub-directories) as needed and overwrites files that already
// exist. It returns the number of regular files written. On error, the count
// covers the files this call wrote before the failure.
//
// Only directories and regular files are copied. Symbolic links and other
// special entries (devices, sockets, named pipes) in src are skipped. A
// symlink in the clone could point outside the repository, and reading through
// it would copy arbitrary host files into data/.
//
// Security notes:
//   - src is always a sub-path of a system-generated os.MkdirTemp directory.
//   - dst is always a sub-path of the local "data/" directory with an
//     allowlist-validated name (see copyDataDirs).
//   - Files are read through os.DirFS, so the open call only sees the bare
//     filename returned by os.ReadDir.
//   - The destination confinement check is lexical. It rejects entry names
//     that would climb out of dst, but it does not resolve symlinks that may
//     already exist under dst.
func copyDir(src, dst string) (int, error) {
	// Resolve dst to an absolute path so the confinement check is reliable.
	absDst, err := filepath.Abs(dst)
	if err != nil {
		return 0, fmt.Errorf("resolving dst %q: %w", dst, err)
	}
	if err := os.MkdirAll(absDst, 0o755); err != nil {
		return 0, err
	}

	entries, err := os.ReadDir(src)
	if err != nil {
		return 0, err
	}

	// os.DirFS is rooted at src, so files are opened by their bare name.
	srcFS := os.DirFS(src)

	total := 0
	for _, e := range entries {
		name := e.Name()
		subDst := filepath.Join(absDst, name)

		// Confinement: only a ".." path element signals an escape. A plain
		// name that merely starts with dots (e.g. "..notes") is legitimate.
		rel, relErr := filepath.Rel(absDst, subDst)
		if relErr != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			continue
		}

		switch {
		case e.IsDir():
			// Recurse using the joined paths; os.DirFS is per-directory.
			n, err := copyDir(filepath.Join(src, name), subDst)
			if err != nil {
				return total, err
			}
			total += n
		case e.Type().IsRegular():
			data, err := fs.ReadFile(srcFS, name) // #nosec G304
			if err != nil {
				return total, fmt.Errorf("read %s: %w", name, err)
			}
			if err := os.WriteFile(subDst, data, 0o644); err != nil {
				return total, fmt.Errorf("write %s: %w", subDst, err)
			}
			total++
		default:
			// Symlink or special file: skipped on purpose (see doc comment).
		}
	}
	return total, nil
}
364  
// splitDirs breaks a comma-separated list into its entries. It trims
// surrounding whitespace from each entry and drops entries that end up empty.
// The result is never nil, even when no entries remain.
func splitDirs(s string) []string {
	result := make([]string, 0, strings.Count(s, ",")+1)
	rest := s
	for {
		head, tail, more := strings.Cut(rest, ",")
		if name := strings.TrimSpace(head); name != "" {
			result = append(result, name)
		}
		if !more {
			return result
		}
		rest = tail
	}
}
377  
378  // ---------------------------------------------------------------------------
// JSON encoding helpers for UpdateStatus
380  // ---------------------------------------------------------------------------
381  
382  // updateStatusJSON is used only for Swagger doc generation.
383  type updateStatusJSON struct {
384  	Running      bool       `json:"running"`
385  	Success      *bool      `json:"success,omitempty"`
386  	StartedAt    time.Time  `json:"started_at,omitempty"`
387  	FinishedAt   *time.Time `json:"finished_at,omitempty"`
388  	Message      string     `json:"message"`
389  	FilesUpdated int        `json:"files_updated"`
390  	Ref          string     `json:"ref,omitempty"`
391  }
392  
393  // MarshalJSON implements json.Marshaler so UpdateStatus can be serialised
394  // without exposing internal mutex state.
395  func (u UpdateStatus) MarshalJSON() ([]byte, error) {
396  	return json.Marshal(updateStatusJSON{
397  		Running:      u.Running,
398  		Success:      u.Success,
399  		StartedAt:    u.StartedAt,
400  		FinishedAt:   u.FinishedAt,
401  		Message:      u.Message,
402  		FilesUpdated: u.FilesUpdated,
403  		Ref:          u.Ref,
404  	})
405  }
406  
407  // Ensure encoding/json is used (MarshalJSON reference).
408  var _ interface{ MarshalJSON() ([]byte, error) } = UpdateStatus{}