/ common / utils / api.go
api.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package utils
 20  
 21  import (
 22  	"bytes"
 23  	"encoding/json"
 24  	"fmt"
 25  	"io"
 26  	"mime/multipart"
 27  	"net/http"
 28  	"os"
 29  	"path/filepath"
 30  	"strings"
 31  
 32  	"github.com/Tencent/AI-Infra-Guard/common/fingerprints/parser"
 33  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
 34  )
 35  
 36  // DownloadFile 下载文件
 37  // path 参数必须由调用方在调用前完成路径安全校验(防止路径穿越),
 38  // 本函数仅负责 HTTP 下载写入,不做路径验证。
 39  func DownloadFile(server, sessionId, uri, path string) error {
 40  	// Validate that path is not empty and does not contain path traversal sequences.
 41  	// Callers are responsible for ensuring path is within an expected directory.
 42  	if path == "" || strings.Contains(path, "..") {
 43  		return fmt.Errorf("非法文件路径")
 44  	}
 45  	// 创建 HTTP 客户端
 46  	client := &http.Client{}
 47  
 48  	data := map[string]string{
 49  		"fileUrl": uri,
 50  	}
 51  	jsonData, err := json.Marshal(data)
 52  	// 创建请求并添加 header
 53  	req, err := http.NewRequest("POST", fmt.Sprintf("http://%s/api/v1/app/tasks/%s/downloadFile", server, sessionId), io.NopCloser(bytes.NewBuffer(jsonData)))
 54  	if err != nil {
 55  		return err
 56  	}
 57  	req.Header.Set("Content-Type", "application/json")
 58  	req.Header.Set("X-APIKey", "zhuque")
 59  
 60  	// 发送 POST 请求
 61  	resp, err := client.Do(req)
 62  	if err != nil {
 63  		return err
 64  	}
 65  	defer resp.Body.Close()
 66  
 67  	// 检查 HTTP 状态码
 68  	if resp.StatusCode != http.StatusOK {
 69  		dd, _ := io.ReadAll(resp.Body)
 70  		return fmt.Errorf("下载失败,HTTP 状态码:%d content:%s", resp.StatusCode, string(dd))
 71  	}
 72  
 73  	// 创建文件
 74  	file, err := os.Create(path)
 75  	if err != nil {
 76  		return err
 77  	}
 78  	defer file.Close()
 79  
 80  	// 将响应体复制到文件
 81  	_, err = io.Copy(file, resp.Body)
 82  	if err != nil {
 83  		return err
 84  	}
 85  
 86  	return nil
 87  }
 88  
 89  // UploadFileResponse 上传文件响应结构
 90  type UploadFileResponse struct {
 91  	Status  int    `json:"status"`
 92  	Message string `json:"message"`
 93  	Data    struct {
 94  		FileUrl  string `json:"fileUrl"`
 95  		Filename string `json:"filename"`
 96  	} `json:"data"`
 97  }
 98  
 99  // UploadFile 上传文件到服务器
100  func UploadFile(server, filePath string) (*UploadFileResponse, error) {
101  	// 打开文件
102  	file, err := os.Open(filePath)
103  	if err != nil {
104  		return nil, fmt.Errorf("无法打开文件: %v", err)
105  	}
106  	defer file.Close()
107  
108  	// 创建 multipart writer
109  	var requestBody bytes.Buffer
110  	writer := multipart.NewWriter(&requestBody)
111  
112  	// 创建文件字段
113  	part, err := writer.CreateFormFile("file", filepath.Base(filePath))
114  	if err != nil {
115  		return nil, fmt.Errorf("创建文件字段失败: %v", err)
116  	}
117  
118  	// 将文件内容复制到 part
119  	_, err = io.Copy(part, file)
120  	if err != nil {
121  		return nil, fmt.Errorf("复制文件内容失败: %v", err)
122  	}
123  
124  	// 关闭 writer
125  	err = writer.Close()
126  	if err != nil {
127  		return nil, fmt.Errorf("关闭 writer 失败: %v", err)
128  	}
129  
130  	// 创建 HTTP 请求
131  	req, err := http.NewRequest("POST", fmt.Sprintf("http://%s/api/v1/app/tasks/uploadFile", server), &requestBody)
132  	if err != nil {
133  		return nil, fmt.Errorf("创建请求失败: %v", err)
134  	}
135  
136  	// 设置 Content-Type
137  	req.Header.Set("Content-Type", writer.FormDataContentType())
138  	req.Header.Set("X-APIKey", "zhuque")
139  
140  	// 发送请求
141  	client := &http.Client{}
142  	resp, err := client.Do(req)
143  	if err != nil {
144  		return nil, fmt.Errorf("发送请求失败: %v", err)
145  	}
146  	defer resp.Body.Close()
147  
148  	// 读取响应体
149  	respBody, err := io.ReadAll(resp.Body)
150  	if err != nil {
151  		return nil, fmt.Errorf("读取响应失败: %v", err)
152  	}
153  
154  	// 检查 HTTP 状态码
155  	if resp.StatusCode != http.StatusOK {
156  		return nil, fmt.Errorf("上传失败,HTTP 状态码:%d content:%s", resp.StatusCode, string(respBody))
157  	}
158  
159  	// 解析响应 JSON
160  	var uploadResp UploadFileResponse
161  	err = json.Unmarshal(respBody, &uploadResp)
162  	if err != nil {
163  		return nil, fmt.Errorf("解析响应 JSON 失败: %v", err)
164  	}
165  
166  	return &uploadResp, nil
167  }
168  
// GetEvaluationsDetail fetches the evaluation named name from the knowledge
// API (GET http://{server}/api/v1/knowledge/evaluations/{name}). It returns
// the raw, undecoded JSON of the response's "data" field.
//
// name is appended to the URL path as-is (it is not escaped). A non-200
// status yields an error containing the status code and response body. If
// the response has no "data" field, the returned slice is nil.
func GetEvaluationsDetail(server, name string) ([]byte, error) {
	path := "/api/v1/knowledge/evaluations/" + name
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s%s", server, path), nil)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("X-APIKey", "zhuque")

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// This is a fetch of evaluation details, not an upload.
		return nil, fmt.Errorf("获取评测详情失败,HTTP 状态码:%d content:%s", resp.StatusCode, string(respBody))
	}

	var msg struct {
		Data json.RawMessage `json:"data"`
	}
	if err := json.Unmarshal(respBody, &msg); err != nil {
		return nil, fmt.Errorf("解析响应 JSON 失败: %w", err)
	}
	return msg.Data, nil
}
206  
207  func LoadRemoteFingerPrints(hostname string) ([]parser.FingerPrint, error) {
208  	type msg struct {
209  		Data struct {
210  			FingerPrints []json.RawMessage `json:"items"`
211  			Total        int               `json:"total"`
212  		} `json:"data"`
213  		Message string `json:"message"`
214  	}
215  	// 创建请求并添加 header
216  	req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/api/v1/knowledge/fingerprints?page=1&size=9999", hostname), nil)
217  	if err != nil {
218  		return nil, err
219  	}
220  	req.Header.Set("X-APIKey", "zhuque")
221  
222  	// 发送请求
223  	client := &http.Client{}
224  	resp, err := client.Do(req)
225  	if err != nil {
226  		return nil, err
227  	}
228  	defer resp.Body.Close()
229  	if resp.StatusCode != http.StatusOK {
230  		return nil, fmt.Errorf("http status code: %d", resp.StatusCode)
231  	}
232  	data, err := io.ReadAll(resp.Body)
233  	if err != nil {
234  		return nil, err
235  	}
236  	var m msg
237  	if err := json.Unmarshal(data, &m); err != nil {
238  		return nil, err
239  	}
240  	fps := make([]parser.FingerPrint, 0)
241  	for _, raw := range m.Data.FingerPrints {
242  		fp, err := parser.InitFingerPrintFromData(raw)
243  		if err != nil {
244  			gologger.WithError(err).Fatalf("无法解析指纹模板:%s", string(raw))
245  			continue
246  		}
247  		fps = append(fps, *fp)
248  	}
249  	return fps, nil
250  }
251  
// LoadRemoteVulStruct performs a GET against the fully qualified URL api,
// sending the X-APIKey header. It returns the "data.items" array of the JSON
// response as raw messages, leaving their decoding to the caller.
//
// A non-200 status is reported as "http status code: N". Request, transport,
// read, and JSON decoding errors are returned unchanged. The envelope's
// "total" and "message" fields are decoded but not used.
func LoadRemoteVulStruct(api string) ([]json.RawMessage, error) {
	var envelope struct {
		Data struct {
			Vuls  []json.RawMessage `json:"items"`
			Total int               `json:"total"`
		} `json:"data"`
		Message string `json:"message"`
	}

	request, err := http.NewRequest(http.MethodGet, api, nil)
	if err != nil {
		return nil, err
	}
	request.Header.Set("X-APIKey", "zhuque")

	response, err := (&http.Client{}).Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	// The status is checked before the body is read.
	if code := response.StatusCode; code != http.StatusOK {
		return nil, fmt.Errorf("http status code: %d", code)
	}

	payload, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, err
	}
	if err = json.Unmarshal(payload, &envelope); err != nil {
		return nil, err
	}
	return envelope.Data.Vuls, nil
}