/ msgmux / msg_router.go
msg_router.go
  1  package msgmux
  2  
  3  import (
  4  	"context"
  5  	"fmt"
  6  	"maps"
  7  	"sync"
  8  
  9  	"github.com/btcsuite/btcd/btcec/v2"
 10  	"github.com/lightningnetwork/lnd/fn/v2"
 11  	"github.com/lightningnetwork/lnd/lnwire"
 12  )
 13  
 14  var (
 15  	// ErrDuplicateEndpoint is returned when an endpoint is registered with
 16  	// a name that already exists.
 17  	ErrDuplicateEndpoint = fmt.Errorf("endpoint already registered")
 18  
 19  	// ErrUnableToRouteMsg is returned when a message is unable to be
 20  	// routed to any endpoints.
 21  	ErrUnableToRouteMsg = fmt.Errorf("unable to route message")
 22  )
 23  
 24  // EndpointName is the name of a given endpoint. This MUST be unique across all
 25  // registered endpoints.
 26  type EndpointName = string
 27  
 28  // PeerMsg is a wire message that includes the public key of the peer that sent
 29  // it.
 30  type PeerMsg struct {
 31  	lnwire.Message
 32  
 33  	// PeerPub is the public key of the peer that sent this message.
 34  	PeerPub btcec.PublicKey
 35  }
 36  
 37  // Endpoint is an interface that represents a message endpoint, or the
 38  // sub-system that will handle processing an incoming wire message.
 39  type Endpoint interface {
 40  	// Name returns the name of this endpoint. This MUST be unique across
 41  	// all registered endpoints.
 42  	Name() EndpointName
 43  
 44  	// CanHandle returns true if the target message can be routed to this
 45  	// endpoint.
 46  	CanHandle(msg PeerMsg) bool
 47  
 48  	// SendMessage handles the target message, and returns true if the
 49  	// message was able being processed.
 50  	SendMessage(ctx context.Context, msg PeerMsg) bool
 51  }
 52  
 53  // Router is an interface that represents a message router, which is generic
 54  // sub-system capable of routing any incoming wire message to a set of
 55  // registered endpoints.
 56  type Router interface {
 57  	// RegisterEndpoint registers a new endpoint with the router. If a
 58  	// duplicate endpoint exists, an error is returned.
 59  	RegisterEndpoint(Endpoint) error
 60  
 61  	// UnregisterEndpoint unregisters the target endpoint from the router.
 62  	UnregisterEndpoint(EndpointName) error
 63  
 64  	// RouteMsg attempts to route the target message to a registered
 65  	// endpoint. If ANY endpoint could handle the message, then nil is
 66  	// returned. Otherwise, ErrUnableToRouteMsg is returned.
 67  	RouteMsg(PeerMsg) error
 68  
 69  	// Start starts the peer message router.
 70  	Start(ctx context.Context)
 71  
 72  	// Stop stops the peer message router.
 73  	Stop()
 74  }
 75  
 76  // sendQuery sends a query to the main event loop, and returns the response.
 77  func sendQuery[Q any, R any](sendChan chan fn.Req[Q, R], queryArg Q,
 78  	quit chan struct{}) fn.Result[R] {
 79  
 80  	query, respChan := fn.NewReq[Q, R](queryArg)
 81  
 82  	if !fn.SendOrQuit(sendChan, query, quit) {
 83  		return fn.Errf[R]("router shutting down")
 84  	}
 85  
 86  	return fn.NewResult(fn.RecvResp(respChan, nil, quit))
 87  }
 88  
 89  // sendQueryErr is a helper function based on sendQuery that can be used when
 90  // the query only needs an error response.
 91  func sendQueryErr[Q any](sendChan chan fn.Req[Q, error], queryArg Q,
 92  	quitChan chan struct{}) error {
 93  
 94  	return fn.ElimEither(
 95  		sendQuery(sendChan, queryArg, quitChan).Either,
 96  		fn.Iden, fn.Iden,
 97  	)
 98  }
 99  
100  // EndpointsMap is a map of all registered endpoints.
101  type EndpointsMap map[EndpointName]Endpoint
102  
103  // MultiMsgRouter is a type of message router that is capable of routing new
104  // incoming messages, permitting a message to be routed to multiple registered
105  // endpoints.
106  type MultiMsgRouter struct {
107  	startOnce sync.Once
108  	stopOnce  sync.Once
109  
110  	// registerChan is the channel that all new endpoints will be sent to.
111  	registerChan chan fn.Req[Endpoint, error]
112  
113  	// unregisterChan is the channel that all endpoints that are to be
114  	// removed are sent to.
115  	unregisterChan chan fn.Req[EndpointName, error]
116  
117  	// msgChan is the channel that all messages will be sent to for
118  	// processing.
119  	msgChan chan fn.Req[PeerMsg, error]
120  
121  	// endpointsQueries is a channel that all queries to the endpoints map
122  	// will be sent to.
123  	endpointQueries chan fn.Req[Endpoint, EndpointsMap]
124  
125  	wg   sync.WaitGroup
126  	quit chan struct{}
127  }
128  
129  // NewMultiMsgRouter creates a new instance of a peer message router.
130  func NewMultiMsgRouter() *MultiMsgRouter {
131  	return &MultiMsgRouter{
132  		registerChan:    make(chan fn.Req[Endpoint, error]),
133  		unregisterChan:  make(chan fn.Req[EndpointName, error]),
134  		msgChan:         make(chan fn.Req[PeerMsg, error]),
135  		endpointQueries: make(chan fn.Req[Endpoint, EndpointsMap]),
136  		quit:            make(chan struct{}),
137  	}
138  }
139  
140  // Start starts the peer message router.
141  func (p *MultiMsgRouter) Start(ctx context.Context) {
142  	log.Infof("Starting Router")
143  
144  	p.startOnce.Do(func() {
145  		p.wg.Add(1)
146  		go p.msgRouter(ctx)
147  	})
148  }
149  
150  // Stop stops the peer message router.
151  func (p *MultiMsgRouter) Stop() {
152  	log.Infof("Stopping Router")
153  
154  	p.stopOnce.Do(func() {
155  		close(p.quit)
156  		p.wg.Wait()
157  	})
158  }
159  
160  // RegisterEndpoint registers a new endpoint with the router. If a duplicate
161  // endpoint exists, an error is returned.
162  func (p *MultiMsgRouter) RegisterEndpoint(endpoint Endpoint) error {
163  	return sendQueryErr(p.registerChan, endpoint, p.quit)
164  }
165  
166  // UnregisterEndpoint unregisters the target endpoint from the router.
167  func (p *MultiMsgRouter) UnregisterEndpoint(name EndpointName) error {
168  	return sendQueryErr(p.unregisterChan, name, p.quit)
169  }
170  
171  // RouteMsg attempts to route the target message to a registered endpoint. If
172  // ANY endpoint could handle the message, then nil is returned.
173  func (p *MultiMsgRouter) RouteMsg(msg PeerMsg) error {
174  	return sendQueryErr(p.msgChan, msg, p.quit)
175  }
176  
177  // Endpoints returns a list of all registered endpoints.
178  func (p *MultiMsgRouter) endpoints() fn.Result[EndpointsMap] {
179  	return sendQuery(p.endpointQueries, nil, p.quit)
180  }
181  
182  // msgRouter is the main goroutine that handles all incoming messages.
183  func (p *MultiMsgRouter) msgRouter(ctx context.Context) {
184  	defer p.wg.Done()
185  
186  	// endpoints is a map of all registered endpoints.
187  	endpoints := make(map[EndpointName]Endpoint)
188  
189  	for {
190  		select {
191  		// A new endpoint was just sent in, so we'll add it to our set
192  		// of registered endpoints.
193  		case newEndpointMsg := <-p.registerChan:
194  			endpoint := newEndpointMsg.Request
195  
196  			log.Infof("MsgRouter: registering new "+
197  				"Endpoint(%s)", endpoint.Name())
198  
199  			// If this endpoint already exists, then we'll return
200  			// an error as we require unique names.
201  			if _, ok := endpoints[endpoint.Name()]; ok {
202  				log.Errorf("MsgRouter: rejecting "+
203  					"duplicate endpoint: %v",
204  					endpoint.Name())
205  
206  				newEndpointMsg.Resolve(ErrDuplicateEndpoint)
207  
208  				continue
209  			}
210  
211  			endpoints[endpoint.Name()] = endpoint
212  
213  			newEndpointMsg.Resolve(nil)
214  
215  		// A request to unregister an endpoint was just sent in, so
216  		// we'll attempt to remove it.
217  		case endpointName := <-p.unregisterChan:
218  			delete(endpoints, endpointName.Request)
219  
220  			log.Infof("MsgRouter: unregistering "+
221  				"Endpoint(%s)", endpointName.Request)
222  
223  			endpointName.Resolve(nil)
224  
225  		// A new message was just sent in. We'll attempt to route it to
226  		// all the endpoints that can handle it.
227  		case msgQuery := <-p.msgChan:
228  			msg := msgQuery.Request
229  
230  			// Loop through all the endpoints and send the message
231  			// to those that can handle it the message.
232  			var couldSend bool
233  			for _, endpoint := range endpoints {
234  				if endpoint.CanHandle(msg) {
235  					log.Tracef("MsgRouter: sending "+
236  						"msg %T to endpoint %s", msg,
237  						endpoint.Name())
238  
239  					sent := endpoint.SendMessage(ctx, msg)
240  					couldSend = couldSend || sent
241  				}
242  			}
243  
244  			var err error
245  			if !couldSend {
246  				log.Tracef("MsgRouter: unable to route "+
247  					"msg %T", msg.Message)
248  
249  				err = ErrUnableToRouteMsg
250  			}
251  
252  			msgQuery.Resolve(err)
253  
254  		// A query for the endpoint state just came in, we'll send back
255  		// a copy of our current state.
256  		case endpointQuery := <-p.endpointQueries:
257  			endpointsCopy := make(EndpointsMap, len(endpoints))
258  			maps.Copy(endpointsCopy, endpoints)
259  
260  			endpointQuery.Resolve(endpointsCopy)
261  
262  		case <-p.quit:
263  			return
264  		}
265  	}
266  }
267  
268  // A compile time check to ensure MultiMsgRouter implements the MsgRouter
269  // interface.
270  var _ Router = (*MultiMsgRouter)(nil)