api.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package utils 20 21 import ( 22 "bytes" 23 "encoding/json" 24 "fmt" 25 "io" 26 "mime/multipart" 27 "net/http" 28 "os" 29 "path/filepath" 30 "strings" 31 32 "github.com/Tencent/AI-Infra-Guard/common/fingerprints/parser" 33 "github.com/Tencent/AI-Infra-Guard/internal/gologger" 34 ) 35 36 // DownloadFile 下载文件 37 // path 参数必须由调用方在调用前完成路径安全校验(防止路径穿越), 38 // 本函数仅负责 HTTP 下载写入,不做路径验证。 39 func DownloadFile(server, sessionId, uri, path string) error { 40 // Validate that path is not empty and does not contain path traversal sequences. 41 // Callers are responsible for ensuring path is within an expected directory. 42 if path == "" || strings.Contains(path, "..") { 43 return fmt.Errorf("非法文件路径") 44 } 45 // 创建 HTTP 客户端 46 client := &http.Client{} 47 48 data := map[string]string{ 49 "fileUrl": uri, 50 } 51 jsonData, err := json.Marshal(data) 52 // 创建请求并添加 header 53 req, err := http.NewRequest("POST", fmt.Sprintf("http://%s/api/v1/app/tasks/%s/downloadFile", server, sessionId), io.NopCloser(bytes.NewBuffer(jsonData))) 54 if err != nil { 55 return err 56 } 57 req.Header.Set("Content-Type", "application/json") 58 req.Header.Set("X-APIKey", "zhuque") 59 60 // 发送 POST 请求 61 resp, err := client.Do(req) 62 if err != nil { 63 return err 64 } 65 defer resp.Body.Close() 66 67 // 检查 HTTP 状态码 68 if resp.StatusCode != http.StatusOK { 69 dd, _ := io.ReadAll(resp.Body) 70 return fmt.Errorf("下载失败,HTTP 状态码:%d content:%s", resp.StatusCode, string(dd)) 71 } 72 73 // 创建文件 74 file, err := os.Create(path) 75 if err != nil { 76 return err 77 } 78 defer file.Close() 79 80 // 将响应体复制到文件 81 _, err = io.Copy(file, resp.Body) 82 if err != nil { 83 return err 84 } 85 86 return nil 87 } 88 89 // UploadFileResponse 上传文件响应结构 90 type UploadFileResponse struct { 91 Status int `json:"status"` 92 Message string `json:"message"` 93 Data struct { 94 FileUrl string `json:"fileUrl"` 95 Filename string `json:"filename"` 96 } `json:"data"` 97 } 98 99 // UploadFile 上传文件到服务器 100 func UploadFile(server, filePath string) (*UploadFileResponse, error) { 101 // 打开文件 102 file, err := os.Open(filePath) 103 if err != nil { 104 return nil, fmt.Errorf("无法打开文件: %v", err) 105 } 106 defer file.Close() 107 108 // 创建 multipart writer 109 var requestBody bytes.Buffer 110 writer := multipart.NewWriter(&requestBody) 111 112 // 创建文件字段 113 part, err := writer.CreateFormFile("file", filepath.Base(filePath)) 114 if err != nil { 115 return nil, fmt.Errorf("创建文件字段失败: %v", err) 116 } 117 118 // 将文件内容复制到 part 119 _, err = io.Copy(part, file) 120 if err != nil { 121 return nil, fmt.Errorf("复制文件内容失败: %v", err) 122 } 123 124 // 关闭 writer 125 err = writer.Close() 126 if err != nil { 127 return nil, fmt.Errorf("关闭 writer 失败: %v", err) 128 } 129 130 // 创建 HTTP 请求 131 req, err := http.NewRequest("POST", fmt.Sprintf("http://%s/api/v1/app/tasks/uploadFile", server), &requestBody) 132 if err != nil { 133 return nil, fmt.Errorf("创建请求失败: %v", err) 134 } 135 136 // 设置 Content-Type 137 req.Header.Set("Content-Type", writer.FormDataContentType()) 138 req.Header.Set("X-APIKey", "zhuque") 139 140 // 发送请求 141 client := &http.Client{} 142 resp, err := client.Do(req) 143 if err != nil { 144 return nil, fmt.Errorf("发送请求失败: %v", err) 145 } 146 defer resp.Body.Close() 147 148 // 读取响应体 149 respBody, err := io.ReadAll(resp.Body) 150 if err != nil { 151 return nil, fmt.Errorf("读取响应失败: %v", err) 152 } 153 154 // 检查 HTTP 状态码 155 if resp.StatusCode != http.StatusOK { 156 return nil, fmt.Errorf("上传失败,HTTP 状态码:%d content:%s", resp.StatusCode, string(respBody)) 157 } 158 159 // 解析响应 JSON 160 var uploadResp UploadFileResponse 161 err = json.Unmarshal(respBody, &uploadResp) 162 if err != nil { 163 return nil, fmt.Errorf("解析响应 JSON 失败: %v", err) 164 } 165 166 return &uploadResp, nil 167 } 168 169 func GetEvaluationsDetail(server, name string) ([]byte, error) { 170 path := "/api/v1/knowledge/evaluations/" + name 171 // 创建 HTTP 请求 172 req, err := http.NewRequest("GET", fmt.Sprintf("http://%s%s", server, path), nil) 173 if err != nil { 174 return nil, fmt.Errorf("创建请求失败: %v", err) 175 } 176 req.Header.Set("X-APIKey", "zhuque") 177 178 // 发送请求 179 client := &http.Client{} 180 resp, err := client.Do(req) 181 if err != nil { 182 return nil, fmt.Errorf("发送请求失败: %v", err) 183 } 184 defer resp.Body.Close() 185 186 // 读取响应体 187 respBody, err := io.ReadAll(resp.Body) 188 if err != nil { 189 return nil, fmt.Errorf("读取响应失败: %v", err) 190 } 191 192 // 检查 HTTP 状态码 193 if resp.StatusCode != http.StatusOK { 194 return nil, fmt.Errorf("上传失败,HTTP 状态码:%d content:%s", resp.StatusCode, string(respBody)) 195 } 196 197 var msg struct { 198 Data json.RawMessage `json:"data"` 199 } 200 err = json.Unmarshal(respBody, &msg) 201 if err != nil { 202 return nil, fmt.Errorf("解析响应 JSON 失败: %v", err) 203 } 204 return msg.Data, nil 205 } 206 207 func LoadRemoteFingerPrints(hostname string) ([]parser.FingerPrint, error) { 208 type msg struct { 209 Data struct { 210 FingerPrints []json.RawMessage `json:"items"` 211 Total int `json:"total"` 212 } `json:"data"` 213 Message string `json:"message"` 214 } 215 // 创建请求并添加 header 216 req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/api/v1/knowledge/fingerprints?page=1&size=9999", hostname), nil) 217 if err != nil { 218 return nil, err 219 } 220 req.Header.Set("X-APIKey", "zhuque") 221 222 // 发送请求 223 client := &http.Client{} 224 resp, err := client.Do(req) 225 if err != nil { 226 return nil, err 227 } 228 defer resp.Body.Close() 229 if resp.StatusCode != http.StatusOK { 230 return nil, fmt.Errorf("http status code: %d", resp.StatusCode) 231 } 232 data, err := io.ReadAll(resp.Body) 233 if err != nil { 234 return nil, err 235 } 236 var m msg 237 if err := json.Unmarshal(data, &m); err != nil { 238 return nil, err 239 } 240 fps := make([]parser.FingerPrint, 0) 241 for _, raw := range m.Data.FingerPrints { 242 fp, err := parser.InitFingerPrintFromData(raw) 243 if err != nil { 244 gologger.WithError(err).Fatalf("无法解析指纹模板:%s", string(raw)) 245 continue 246 } 247 fps = append(fps, *fp) 248 } 249 return fps, nil 250 } 251 252 func LoadRemoteVulStruct(api string) ([]json.RawMessage, error) { 253 type msg struct { 254 Data struct { 255 Vuls []json.RawMessage `json:"items"` 256 Total int `json:"total"` 257 } `json:"data"` 258 Message string `json:"message"` 259 } 260 // 创建请求并添加 header 261 req, err := http.NewRequest("GET", api, nil) 262 if err != nil { 263 return nil, err 264 } 265 req.Header.Set("X-APIKey", "zhuque") 266 267 // 发送请求 268 client := &http.Client{} 269 resp, err := client.Do(req) 270 if err != nil { 271 return nil, err 272 } 273 defer resp.Body.Close() 274 if resp.StatusCode != http.StatusOK { 275 return nil, fmt.Errorf("http status code: %d", resp.StatusCode) 276 } 277 data, err := io.ReadAll(resp.Body) 278 if err != nil { 279 return nil, err 280 } 281 var m msg 282 if err := json.Unmarshal(data, &m); err != nil { 283 return nil, err 284 } 285 return m.Data.Vuls, nil 286 }