msg_router.go
1 package msgmux 2 3 import ( 4 "context" 5 "fmt" 6 "maps" 7 "sync" 8 9 "github.com/btcsuite/btcd/btcec/v2" 10 "github.com/lightningnetwork/lnd/fn/v2" 11 "github.com/lightningnetwork/lnd/lnwire" 12 ) 13 14 var ( 15 // ErrDuplicateEndpoint is returned when an endpoint is registered with 16 // a name that already exists. 17 ErrDuplicateEndpoint = fmt.Errorf("endpoint already registered") 18 19 // ErrUnableToRouteMsg is returned when a message is unable to be 20 // routed to any endpoints. 21 ErrUnableToRouteMsg = fmt.Errorf("unable to route message") 22 ) 23 24 // EndpointName is the name of a given endpoint. This MUST be unique across all 25 // registered endpoints. 26 type EndpointName = string 27 28 // PeerMsg is a wire message that includes the public key of the peer that sent 29 // it. 30 type PeerMsg struct { 31 lnwire.Message 32 33 // PeerPub is the public key of the peer that sent this message. 34 PeerPub btcec.PublicKey 35 } 36 37 // Endpoint is an interface that represents a message endpoint, or the 38 // sub-system that will handle processing an incoming wire message. 39 type Endpoint interface { 40 // Name returns the name of this endpoint. This MUST be unique across 41 // all registered endpoints. 42 Name() EndpointName 43 44 // CanHandle returns true if the target message can be routed to this 45 // endpoint. 46 CanHandle(msg PeerMsg) bool 47 48 // SendMessage handles the target message, and returns true if the 49 // message was able being processed. 50 SendMessage(ctx context.Context, msg PeerMsg) bool 51 } 52 53 // Router is an interface that represents a message router, which is generic 54 // sub-system capable of routing any incoming wire message to a set of 55 // registered endpoints. 56 type Router interface { 57 // RegisterEndpoint registers a new endpoint with the router. If a 58 // duplicate endpoint exists, an error is returned. 59 RegisterEndpoint(Endpoint) error 60 61 // UnregisterEndpoint unregisters the target endpoint from the router. 62 UnregisterEndpoint(EndpointName) error 63 64 // RouteMsg attempts to route the target message to a registered 65 // endpoint. If ANY endpoint could handle the message, then nil is 66 // returned. Otherwise, ErrUnableToRouteMsg is returned. 67 RouteMsg(PeerMsg) error 68 69 // Start starts the peer message router. 70 Start(ctx context.Context) 71 72 // Stop stops the peer message router. 73 Stop() 74 } 75 76 // sendQuery sends a query to the main event loop, and returns the response. 77 func sendQuery[Q any, R any](sendChan chan fn.Req[Q, R], queryArg Q, 78 quit chan struct{}) fn.Result[R] { 79 80 query, respChan := fn.NewReq[Q, R](queryArg) 81 82 if !fn.SendOrQuit(sendChan, query, quit) { 83 return fn.Errf[R]("router shutting down") 84 } 85 86 return fn.NewResult(fn.RecvResp(respChan, nil, quit)) 87 } 88 89 // sendQueryErr is a helper function based on sendQuery that can be used when 90 // the query only needs an error response. 91 func sendQueryErr[Q any](sendChan chan fn.Req[Q, error], queryArg Q, 92 quitChan chan struct{}) error { 93 94 return fn.ElimEither( 95 sendQuery(sendChan, queryArg, quitChan).Either, 96 fn.Iden, fn.Iden, 97 ) 98 } 99 100 // EndpointsMap is a map of all registered endpoints. 101 type EndpointsMap map[EndpointName]Endpoint 102 103 // MultiMsgRouter is a type of message router that is capable of routing new 104 // incoming messages, permitting a message to be routed to multiple registered 105 // endpoints. 106 type MultiMsgRouter struct { 107 startOnce sync.Once 108 stopOnce sync.Once 109 110 // registerChan is the channel that all new endpoints will be sent to. 111 registerChan chan fn.Req[Endpoint, error] 112 113 // unregisterChan is the channel that all endpoints that are to be 114 // removed are sent to. 115 unregisterChan chan fn.Req[EndpointName, error] 116 117 // msgChan is the channel that all messages will be sent to for 118 // processing. 119 msgChan chan fn.Req[PeerMsg, error] 120 121 // endpointsQueries is a channel that all queries to the endpoints map 122 // will be sent to. 123 endpointQueries chan fn.Req[Endpoint, EndpointsMap] 124 125 wg sync.WaitGroup 126 quit chan struct{} 127 } 128 129 // NewMultiMsgRouter creates a new instance of a peer message router. 130 func NewMultiMsgRouter() *MultiMsgRouter { 131 return &MultiMsgRouter{ 132 registerChan: make(chan fn.Req[Endpoint, error]), 133 unregisterChan: make(chan fn.Req[EndpointName, error]), 134 msgChan: make(chan fn.Req[PeerMsg, error]), 135 endpointQueries: make(chan fn.Req[Endpoint, EndpointsMap]), 136 quit: make(chan struct{}), 137 } 138 } 139 140 // Start starts the peer message router. 141 func (p *MultiMsgRouter) Start(ctx context.Context) { 142 log.Infof("Starting Router") 143 144 p.startOnce.Do(func() { 145 p.wg.Add(1) 146 go p.msgRouter(ctx) 147 }) 148 } 149 150 // Stop stops the peer message router. 151 func (p *MultiMsgRouter) Stop() { 152 log.Infof("Stopping Router") 153 154 p.stopOnce.Do(func() { 155 close(p.quit) 156 p.wg.Wait() 157 }) 158 } 159 160 // RegisterEndpoint registers a new endpoint with the router. If a duplicate 161 // endpoint exists, an error is returned. 162 func (p *MultiMsgRouter) RegisterEndpoint(endpoint Endpoint) error { 163 return sendQueryErr(p.registerChan, endpoint, p.quit) 164 } 165 166 // UnregisterEndpoint unregisters the target endpoint from the router. 167 func (p *MultiMsgRouter) UnregisterEndpoint(name EndpointName) error { 168 return sendQueryErr(p.unregisterChan, name, p.quit) 169 } 170 171 // RouteMsg attempts to route the target message to a registered endpoint. If 172 // ANY endpoint could handle the message, then nil is returned. 173 func (p *MultiMsgRouter) RouteMsg(msg PeerMsg) error { 174 return sendQueryErr(p.msgChan, msg, p.quit) 175 } 176 177 // Endpoints returns a list of all registered endpoints. 178 func (p *MultiMsgRouter) endpoints() fn.Result[EndpointsMap] { 179 return sendQuery(p.endpointQueries, nil, p.quit) 180 } 181 182 // msgRouter is the main goroutine that handles all incoming messages. 183 func (p *MultiMsgRouter) msgRouter(ctx context.Context) { 184 defer p.wg.Done() 185 186 // endpoints is a map of all registered endpoints. 187 endpoints := make(map[EndpointName]Endpoint) 188 189 for { 190 select { 191 // A new endpoint was just sent in, so we'll add it to our set 192 // of registered endpoints. 193 case newEndpointMsg := <-p.registerChan: 194 endpoint := newEndpointMsg.Request 195 196 log.Infof("MsgRouter: registering new "+ 197 "Endpoint(%s)", endpoint.Name()) 198 199 // If this endpoint already exists, then we'll return 200 // an error as we require unique names. 201 if _, ok := endpoints[endpoint.Name()]; ok { 202 log.Errorf("MsgRouter: rejecting "+ 203 "duplicate endpoint: %v", 204 endpoint.Name()) 205 206 newEndpointMsg.Resolve(ErrDuplicateEndpoint) 207 208 continue 209 } 210 211 endpoints[endpoint.Name()] = endpoint 212 213 newEndpointMsg.Resolve(nil) 214 215 // A request to unregister an endpoint was just sent in, so 216 // we'll attempt to remove it. 217 case endpointName := <-p.unregisterChan: 218 delete(endpoints, endpointName.Request) 219 220 log.Infof("MsgRouter: unregistering "+ 221 "Endpoint(%s)", endpointName.Request) 222 223 endpointName.Resolve(nil) 224 225 // A new message was just sent in. We'll attempt to route it to 226 // all the endpoints that can handle it. 227 case msgQuery := <-p.msgChan: 228 msg := msgQuery.Request 229 230 // Loop through all the endpoints and send the message 231 // to those that can handle it the message. 232 var couldSend bool 233 for _, endpoint := range endpoints { 234 if endpoint.CanHandle(msg) { 235 log.Tracef("MsgRouter: sending "+ 236 "msg %T to endpoint %s", msg, 237 endpoint.Name()) 238 239 sent := endpoint.SendMessage(ctx, msg) 240 couldSend = couldSend || sent 241 } 242 } 243 244 var err error 245 if !couldSend { 246 log.Tracef("MsgRouter: unable to route "+ 247 "msg %T", msg.Message) 248 249 err = ErrUnableToRouteMsg 250 } 251 252 msgQuery.Resolve(err) 253 254 // A query for the endpoint state just came in, we'll send back 255 // a copy of our current state. 256 case endpointQuery := <-p.endpointQueries: 257 endpointsCopy := make(EndpointsMap, len(endpoints)) 258 maps.Copy(endpointsCopy, endpoints) 259 260 endpointQuery.Resolve(endpointsCopy) 261 262 case <-p.quit: 263 return 264 } 265 } 266 } 267 268 // A compile time check to ensure MultiMsgRouter implements the MsgRouter 269 // interface. 270 var _ Router = (*MultiMsgRouter)(nil)