/ common / fingerprints / preload / preload.go
preload.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  // Package preload 漏洞指纹判断golang语言写法
 20  package preload
 21  
 22  import (
 23  	"crypto/sha256"
 24  	"encoding/hex"
 25  	"regexp"
 26  	"strconv"
 27  	"strings"
 28  	"sync"
 29  
 30  	"github.com/Tencent/AI-Infra-Guard/common/fingerprints/parser"
 31  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
 32  	"github.com/Tencent/AI-Infra-Guard/pkg/httpx"
 33  	"github.com/remeh/sizedwaitgroup"
 34  )
 35  
 36  // FingerPrintFunc 指纹识别接口
 37  // 实现此接口可以添加自定义的指纹识别逻辑
 38  type FingerPrintFunc interface {
 39  	Match(httpx *httpx.HTTPX, uri string) bool
 40  	GetVersion(httpx *httpx.HTTPX, uri string) (string, error)
 41  	Name() string
 42  }
 43  
 44  // CollectedFpReqs 返回所有已注册的指纹识别实现
 45  func CollectedFpReqs() []FingerPrintFunc {
 46  	return []FingerPrintFunc{
 47  		Mlflow{},
 48  	}
 49  }
 50  
// FpResult is a single fingerprint match produced by Runner.RunFpReqs.
type FpResult struct {
	// Name is the fingerprint (product) name.
	Name    string `json:"name"`
	// Version is the detected version string; it may be an exact version or
	// a version-range string, and is empty (omitted from JSON) when unknown.
	Version string `json:"version,omitempty"`
	// Type is the fingerprint's "type" metadata value for rule-based
	// fingerprints; it is empty when absent.
	Type    string `json:"type,omitempty"`
}
 57  
// Runner executes fingerprint detection tasks against a target.
type Runner struct {
	hp  *httpx.HTTPX         // HTTP client used for every probe request
	fps []parser.FingerPrint // rule-based fingerprints evaluated by RunFpReqs
}
 64  
 65  // New 创建新的Runner实例
 66  func New(hp *httpx.HTTPX, fps parser.FingerPrints) *Runner {
 67  	r := &Runner{hp, fps}
 68  	return r
 69  }
 70  
 71  // RunFpReqs 执行指纹识别
 72  // uri: 目标URL
 73  // concurrent: 并发数
 74  // faviconHash: favicon图标的hash值
 75  // 返回识别到的指纹结果列表
 76  func (r *Runner) RunFpReqs(uri string, concurrent int, faviconHash int32) []FpResult {
 77  	wg := sizedwaitgroup.New(concurrent)
 78  	mux := sync.Mutex{}
 79  	ret := make([]FpResult, 0)
 80  	uri = strings.TrimRight(uri, "/")
 81  
 82  	indexCache, _ := r.hp.Get(uri+"/", nil)
 83  
 84  	for _, fp := range r.fps {
 85  		wg.Add()
 86  		go func(fp parser.FingerPrint) {
 87  			defer wg.Done()
 88  			var resp *httpx.Response
 89  			var err error
 90  			for _, req := range fp.Http {
 91  				if req.Path == "/" && req.Method == "GET" {
 92  					resp = indexCache
 93  				} else {
 94  					if req.Method == "POST" {
 95  						resp, err = r.hp.POST(uri+req.Path, req.Data, nil)
 96  					} else {
 97  						resp, err = r.hp.Get(uri+req.Path, nil)
 98  					}
 99  					if err != nil {
100  						gologger.WithError(err).Debugln("请求失败")
101  						continue
102  					}
103  				}
104  				if resp == nil {
105  					continue
106  				}
107  				sum := sha256.Sum256(resp.Data)
108  				respHash := hex.EncodeToString(sum[:])
109  				fpConfig := parser.Config{
110  					Body:   resp.DataStr,
111  					Header: resp.GetHeaderRaw(),
112  					Icon:   faviconHash,
113  					Hash:   respHash,
114  				}
115  
116  				matched := false
117  				if len(req.GetDsl()) == 0 {
118  					matched = true
119  				} else {
120  					for _, dsl := range req.GetDsl() {
121  						if parser.Eval(&fpConfig, dsl) {
122  							matched = true
123  							break
124  						}
125  					}
126  				}
127  
128  				if matched {
129  					name := fp.Info.Name
130  					version := ""
131  					version, err := EvalFpVersion(uri, r.hp, fp)
132  					if err != nil {
133  						gologger.WithError(err).Errorln("获取版本失败")
134  					}
135  					mux.Lock()
136  					type_, ok := fp.Info.Metadata["type"]
137  					if !ok {
138  						type_ = ""
139  					}
140  					ret = append(ret, FpResult{
141  						Name:    name,
142  						Version: version,
143  						Type:    type_,
144  					})
145  					mux.Unlock()
146  				}
147  			}
148  		}(fp)
149  	}
150  	for _, fpReq := range CollectedFpReqs() {
151  		wg.Add()
152  		go func(fpReq FingerPrintFunc) {
153  			defer wg.Done()
154  			if fpReq.Match(r.hp, uri) {
155  				fpresult := FpResult{
156  					Name:    fpReq.Name(),
157  					Version: "",
158  					Type:    "",
159  				}
160  				version, err := fpReq.GetVersion(r.hp, uri)
161  				if err == nil {
162  					fpresult.Version = version
163  				}
164  				mux.Lock()
165  				ret = append(ret, fpresult)
166  				mux.Unlock()
167  			}
168  		}(fpReq)
169  	}
170  	wg.Wait()
171  	ret = r.Deduplication(ret)
172  	return ret
173  }
174  
175  // Deduplication 对指纹识别结果进行去重
176  // 如果存在相同名称的指纹,保留版本号不为空的结果
177  func (r *Runner) Deduplication(results []FpResult) []FpResult {
178  	var ret []FpResult
179  	var dup = make(map[string]string)
180  	for _, result := range results {
181  		_, ok := dup[result.Name]
182  		if !ok {
183  			dup[result.Name] = result.Version
184  			ret = append(ret, result)
185  		} else {
186  			if result.Version != "" && dup[result.Name] != result.Version {
187  				dup[result.Name] = result.Version
188  				// 删除原来
189  				for i, v := range ret {
190  					if v.Name == result.Name {
191  						ret = append(ret[:i], ret[i+1:]...)
192  						break
193  					}
194  				}
195  				ret = append(ret, result)
196  			}
197  		}
198  	}
199  	return ret
200  }
201  
202  // GetFps 获取当前Runner中的所有指纹规则
203  func (r *Runner) GetFps() []parser.FingerPrint {
204  	return r.fps
205  }
206  
207  // EvalFpVersion 获取指定指纹的版本信息
208  // 通过正则表达式从响应中提取版本号
209  func EvalFpVersion(uri string, hp *httpx.HTTPX, fp parser.FingerPrint) (string, error) {
210  	fuzzyRanges := make([]versionRange, 0)
211  
212  	for _, req := range fp.Version {
213  		var (
214  			resp *httpx.Response
215  			err  error
216  		)
217  
218  		switch strings.ToUpper(req.Method) {
219  		case "POST":
220  			resp, err = hp.POST(uri+req.Path, req.Data, nil)
221  		default:
222  			resp, err = hp.Get(uri+req.Path, nil)
223  		}
224  		if err != nil {
225  			gologger.WithError(err).Errorln("请求失败")
226  			continue
227  		}
228  		if resp == nil {
229  			continue
230  		}
231  
232  		sum := sha256.Sum256(resp.Data)
233  		respHash := hex.EncodeToString(sum[:])
234  		fpConfig := &parser.Config{
235  			Body:   resp.DataStr,
236  			Header: resp.GetHeaderRaw(),
237  			Icon:   0,
238  			Hash:   respHash,
239  		}
240  
241  		matched := false
242  		if len(req.GetDsl()) == 0 {
243  			matched = true
244  		} else {
245  			for _, dsl := range req.GetDsl() {
246  				if parser.Eval(fpConfig, dsl) {
247  					matched = true
248  					break
249  				}
250  			}
251  		}
252  		if !matched {
253  			continue
254  		}
255  
256  		if strings.TrimSpace(req.VersionRange) == "" {
257  			version := ""
258  			if req.Extractor.Regex != "" {
259  				compileRegex, err := regexp.Compile("(?i)" + req.Extractor.Regex)
260  				if err != nil {
261  					gologger.WithError(err).Errorln("compile regex error", req.Extractor.Regex)
262  				} else {
263  					index, err := strconv.Atoi(req.Extractor.Group)
264  					if err != nil {
265  						gologger.WithError(err).Errorln("parse part error", req.Extractor.Part)
266  					} else {
267  						body := fpConfig.Body
268  						if req.Extractor.Part == "header" {
269  							body = fpConfig.Header
270  						}
271  						submatches := compileRegex.FindStringSubmatch(body)
272  						if len(submatches) > 0 {
273  							if index < 0 || index >= len(submatches) {
274  								index = len(submatches) - 1
275  							}
276  							version = submatches[index]
277  						}
278  					}
279  				}
280  			}
281  			if version != "" {
282  				return version, nil
283  			}
284  			continue
285  		}
286  
287  		vr, err := parseVersionRange(req.VersionRange)
288  		if err != nil {
289  			gologger.WithError(err).Errorln("parse version range error", req.VersionRange)
290  			continue
291  		}
292  		fuzzyRanges = append(fuzzyRanges, vr)
293  	}
294  
295  	if len(fuzzyRanges) > 0 {
296  		if vr, ok := intersectVersionRanges(fuzzyRanges); ok {
297  			return vr.String(), nil
298  		}
299  	}
300  
301  	return "", nil
302  }