/ common / websocket / websocket.go
websocket.go
  1  // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  //
 15  // Requirement: Any integration or derivative work must explicitly attribute
 16  // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
 17  // documentation or user interface, as detailed in the NOTICE file.
 18  
 19  package websocket
 20  
 21  import (
 22  	"encoding/json"
 23  	"fmt"
 24  	"github.com/Tencent/AI-Infra-Guard/common/runner"
 25  	"github.com/Tencent/AI-Infra-Guard/internal/options"
 26  	"net/http"
 27  	"sync"
 28  
 29  	"github.com/Tencent/AI-Infra-Guard/internal/gologger"
 30  
 31  	"github.com/gorilla/websocket"
 32  )
 33  
 34  var upgrader = websocket.Upgrader{
 35  	CheckOrigin: func(r *http.Request) bool {
 36  		return true // 允许所有来源
 37  	},
 38  }
 39  
// WSServer holds the state of the WebSocket server.
type WSServer struct {
	// broadcast is not initialized by NewWSServer, so it is nil unless set elsewhere.
	broadcast chan []byte
	// mu is a server-wide mutex; NOTE(review): confirm what it guards.
	mu        sync.Mutex
	// options is the base configuration. handleScanRequest copies it through
	// a JSON round trip before applying per-request settings.
	options   *options.Options
}
 46  
 47  // NewWSServer 创建新的WebSocket服务器
 48  func NewWSServer(options *options.Options) *WSServer {
 49  	return &WSServer{
 50  		options: options,
 51  	}
 52  }
 53  
 54  // HandleWS 处理WebSocket连接
 55  func (s *WSServer) HandleAIInfraWS(w http.ResponseWriter, r *http.Request) {
 56  	conn, err := upgrader.Upgrade(w, r, nil)
 57  	if err != nil {
 58  		gologger.Errorln("升级WebSocket连接失败:", err)
 59  		return
 60  	}
 61  	go s.handleMessages(conn)
 62  }
 63  
 64  // SendMessage 发送消息给指定客户端
 65  func (s *WSServer) SendMessage(conn *websocket.Conn, msgType string, content interface{}) error {
 66  	msg := WSMessage{
 67  		Type:    msgType,
 68  		Content: content,
 69  	}
 70  	data, err := json.Marshal(msg)
 71  	if err != nil {
 72  		return err
 73  	}
 74  	return conn.WriteMessage(websocket.TextMessage, data)
 75  }
 76  
 77  // handleMessages 处理来自客户端的消息
 78  func (s *WSServer) handleMessages(conn *websocket.Conn) {
 79  	for {
 80  		_, message, err := conn.ReadMessage()
 81  		if err != nil {
 82  			break
 83  		}
 84  		var scanReq ScanRequest
 85  		if err := json.Unmarshal(message, &scanReq); err != nil {
 86  			fmt.Printf("解析消息失败: %v\n", err)
 87  			continue
 88  		}
 89  		resp := Response{
 90  			Status:  0,
 91  			Message: "success",
 92  		}
 93  		err = s.SendMessage(conn, WSMsgTypeScanRet, resp)
 94  		if err != nil {
 95  			gologger.Errorf("发送消息失败: %v\n", err)
 96  			continue
 97  		}
 98  		// 处理扫描请求
 99  		go s.handleScanRequest(conn, &scanReq)
100  	}
101  }
102  
103  // handleScanRequest 处理扫描请求
104  func (s *WSServer) handleScanRequest(conn *websocket.Conn, req *ScanRequest) {
105  	// 深拷贝options
106  	vv, _ := json.Marshal(s.options)
107  	opts := &options.Options{}
108  	_ = json.Unmarshal(vv, &opts)
109  
110  	switch req.ScanType {
111  	case "localscan":
112  		opts.LocalScan = true
113  	case "netscan":
114  		opts.Target = req.Target
115  	}
116  	mu := sync.Mutex{}
117  	processFunc := func(data interface{}) {
118  		mu.Lock()
119  		defer mu.Unlock()
120  		switch v := data.(type) {
121  		case runner.CallbackScanResult:
122  			s.SendMessage(conn, WSMsgTypeScanResult, v)
123  		case runner.CallbackProcessInfo:
124  			s.SendMessage(conn, WSMsgTypeProcessInfo, v)
125  		case runner.CallbackReportInfo:
126  			s.SendMessage(conn, WSMsgTypeReportInfo, v)
127  		default:
128  			gologger.Errorf("processFunc unknown type: %T\n", v)
129  		}
130  	}
131  	opts.SetCallback(processFunc)
132  	headers := make([]string, 0)
133  	for k, v := range req.Headers {
134  		headers = append(headers, k+":"+v)
135  	}
136  	opts.Headers = headers
137  	if req.Lang == "en" {
138  		opts.Language = "en"
139  	}
140  
141  	r, err := runner.New(opts) // 创建runner
142  	if err != nil {
143  		s.SendMessage(conn, WSMsgTypeLog, Log{
144  			Message: "Counld not create runner:" + err.Error(),
145  			Level:   "error",
146  		})
147  		return
148  	}
149  	defer r.Close()    // 关闭runner
150  	r.RunEnumeration() // 执行枚举
151  }