websocket.go
1 // Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Requirement: Any integration or derivative work must explicitly attribute 16 // Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its 17 // documentation or user interface, as detailed in the NOTICE file. 18 19 package websocket 20 21 import ( 22 "encoding/json" 23 "fmt" 24 "github.com/Tencent/AI-Infra-Guard/common/runner" 25 "github.com/Tencent/AI-Infra-Guard/internal/options" 26 "net/http" 27 "sync" 28 29 "github.com/Tencent/AI-Infra-Guard/internal/gologger" 30 31 "github.com/gorilla/websocket" 32 ) 33 34 var upgrader = websocket.Upgrader{ 35 CheckOrigin: func(r *http.Request) bool { 36 return true // 允许所有来源 37 }, 38 } 39 40 // WSServer WebSocket服务器结构 41 type WSServer struct { 42 broadcast chan []byte 43 mu sync.Mutex 44 options *options.Options 45 } 46 47 // NewWSServer 创建新的WebSocket服务器 48 func NewWSServer(options *options.Options) *WSServer { 49 return &WSServer{ 50 options: options, 51 } 52 } 53 54 // HandleWS 处理WebSocket连接 55 func (s *WSServer) HandleAIInfraWS(w http.ResponseWriter, r *http.Request) { 56 conn, err := upgrader.Upgrade(w, r, nil) 57 if err != nil { 58 gologger.Errorln("升级WebSocket连接失败:", err) 59 return 60 } 61 go s.handleMessages(conn) 62 } 63 64 // SendMessage 发送消息给指定客户端 65 func (s *WSServer) SendMessage(conn *websocket.Conn, msgType string, content interface{}) error { 66 msg := WSMessage{ 67 Type: msgType, 68 Content: content, 69 } 70 data, err := json.Marshal(msg) 71 if err != nil { 72 return err 73 } 74 return conn.WriteMessage(websocket.TextMessage, data) 75 } 76 77 // handleMessages 处理来自客户端的消息 78 func (s *WSServer) handleMessages(conn *websocket.Conn) { 79 for { 80 _, message, err := conn.ReadMessage() 81 if err != nil { 82 break 83 } 84 var scanReq ScanRequest 85 if err := json.Unmarshal(message, &scanReq); err != nil { 86 fmt.Printf("解析消息失败: %v\n", err) 87 continue 88 } 89 resp := Response{ 90 Status: 0, 91 Message: "success", 92 } 93 err = s.SendMessage(conn, WSMsgTypeScanRet, resp) 94 if err != nil { 95 gologger.Errorf("发送消息失败: %v\n", err) 96 continue 97 } 98 // 处理扫描请求 99 go s.handleScanRequest(conn, &scanReq) 100 } 101 } 102 103 // handleScanRequest 处理扫描请求 104 func (s *WSServer) handleScanRequest(conn *websocket.Conn, req *ScanRequest) { 105 // 深拷贝options 106 vv, _ := json.Marshal(s.options) 107 opts := &options.Options{} 108 _ = json.Unmarshal(vv, &opts) 109 110 switch req.ScanType { 111 case "localscan": 112 opts.LocalScan = true 113 case "netscan": 114 opts.Target = req.Target 115 } 116 mu := sync.Mutex{} 117 processFunc := func(data interface{}) { 118 mu.Lock() 119 defer mu.Unlock() 120 switch v := data.(type) { 121 case runner.CallbackScanResult: 122 s.SendMessage(conn, WSMsgTypeScanResult, v) 123 case runner.CallbackProcessInfo: 124 s.SendMessage(conn, WSMsgTypeProcessInfo, v) 125 case runner.CallbackReportInfo: 126 s.SendMessage(conn, WSMsgTypeReportInfo, v) 127 default: 128 gologger.Errorf("processFunc unknown type: %T\n", v) 129 } 130 } 131 opts.SetCallback(processFunc) 132 headers := make([]string, 0) 133 for k, v := range req.Headers { 134 headers = append(headers, k+":"+v) 135 } 136 opts.Headers = headers 137 if req.Lang == "en" { 138 opts.Language = "en" 139 } 140 141 r, err := runner.New(opts) // 创建runner 142 if err != nil { 143 s.SendMessage(conn, WSMsgTypeLog, Log{ 144 Message: "Counld not create runner:" + err.Error(), 145 Level: "error", 146 }) 147 return 148 } 149 defer r.Close() // 关闭runner 150 r.RunEnumeration() // 执行枚举 151 }