/ peer / test_utils.go
test_utils.go
  1  package peer
  2  
  3  import (
  4  	"bytes"
  5  	crand "crypto/rand"
  6  	"encoding/binary"
  7  	"io"
  8  	"math/rand"
  9  	"net"
 10  	"sync/atomic"
 11  	"testing"
 12  	"time"
 13  
 14  	"github.com/btcsuite/btcd/btcec/v2"
 15  	"github.com/btcsuite/btcd/btcutil"
 16  	"github.com/btcsuite/btcd/chaincfg/chainhash"
 17  	"github.com/btcsuite/btcd/wire"
 18  	"github.com/lightningnetwork/lnd/chainntnfs"
 19  	"github.com/lightningnetwork/lnd/channeldb"
 20  	"github.com/lightningnetwork/lnd/channelnotifier"
 21  	"github.com/lightningnetwork/lnd/fn/v2"
 22  	graphdb "github.com/lightningnetwork/lnd/graph/db"
 23  	"github.com/lightningnetwork/lnd/htlcswitch"
 24  	"github.com/lightningnetwork/lnd/input"
 25  	"github.com/lightningnetwork/lnd/keychain"
 26  	"github.com/lightningnetwork/lnd/lntest/channels"
 27  	"github.com/lightningnetwork/lnd/lntest/mock"
 28  	"github.com/lightningnetwork/lnd/lntypes"
 29  	"github.com/lightningnetwork/lnd/lnwallet"
 30  	"github.com/lightningnetwork/lnd/lnwallet/chainfee"
 31  	"github.com/lightningnetwork/lnd/lnwire"
 32  	"github.com/lightningnetwork/lnd/netann"
 33  	"github.com/lightningnetwork/lnd/pool"
 34  	"github.com/lightningnetwork/lnd/queue"
 35  	"github.com/lightningnetwork/lnd/shachain"
 36  	"github.com/stretchr/testify/require"
 37  )
 38  
 39  const (
 40  	broadcastHeight = 100
 41  
 42  	// timeout is a timeout value to use for tests which need to wait for
 43  	// a return value on a channel.
 44  	timeout = time.Second * 5
 45  
 46  	// testCltvRejectDelta is the minimum delta between expiry and current
 47  	// height below which htlcs are rejected.
 48  	testCltvRejectDelta = 13
 49  )
 50  
 51  var (
 52  	testKeyLoc = keychain.KeyLocator{Family: keychain.KeyFamilyNodeKey}
 53  )
 54  
 55  // noUpdate is a function which can be used as a parameter in
 56  // createTestPeerWithChannel to call the setup code with no custom values on
 57  // the channels set up.
 58  var noUpdate = func(a, b *channeldb.OpenChannel) {}
 59  
 60  type peerTestCtx struct {
 61  	peer          *Brontide
 62  	channel       *lnwallet.LightningChannel
 63  	notifier      *mock.ChainNotifier
 64  	publishTx     <-chan *wire.MsgTx
 65  	mockSwitch    *mockMessageSwitch
 66  	db            *channeldb.DB
 67  	privKey       *btcec.PrivateKey
 68  	mockConn      *mockMessageConn
 69  	customChan    chan *customMsg
 70  	chanStatusMgr *netann.ChanStatusManager
 71  }
 72  
 73  // createTestPeerWithChannel creates a channel between two nodes, and returns a
 74  // peer for one of the nodes, together with the channel seen from both nodes.
 75  // It takes an updateChan function which can be used to modify the default
 76  // values on the channel states for each peer.
 77  func createTestPeerWithChannel(t *testing.T, updateChan func(a,
 78  	b *channeldb.OpenChannel)) (*peerTestCtx, error) {
 79  
 80  	params := createTestPeer(t)
 81  
 82  	var (
 83  		publishTx     = params.publishTx
 84  		mockSwitch    = params.mockSwitch
 85  		alicePeer     = params.peer
 86  		notifier      = params.notifier
 87  		aliceKeyPriv  = params.privKey
 88  		dbAlice       = params.db
 89  		chanStatusMgr = params.chanStatusMgr
 90  	)
 91  
 92  	err := chanStatusMgr.Start()
 93  	require.NoError(t, err)
 94  	t.Cleanup(func() {
 95  		require.NoError(t, chanStatusMgr.Stop())
 96  	})
 97  
 98  	aliceKeyPub := alicePeer.IdentityKey()
 99  	estimator := alicePeer.cfg.FeeEstimator
100  
101  	channelCapacity := btcutil.Amount(10 * 1e8)
102  	channelBal := channelCapacity / 2
103  	aliceDustLimit := btcutil.Amount(200)
104  	bobDustLimit := btcutil.Amount(1300)
105  	csvTimeoutAlice := uint32(5)
106  	csvTimeoutBob := uint32(4)
107  	isAliceInitiator := true
108  
109  	prevOut := &wire.OutPoint{
110  		Hash:  channels.TestHdSeed,
111  		Index: 0,
112  	}
113  	fundingTxIn := wire.NewTxIn(prevOut, nil, nil)
114  
115  	bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(
116  		channels.BobsPrivKey,
117  	)
118  
119  	aliceCfg := channeldb.ChannelConfig{
120  		ChannelStateBounds: channeldb.ChannelStateBounds{
121  			MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()),
122  			ChanReserve:      btcutil.Amount(rand.Int63()),
123  			MinHTLC:          lnwire.MilliSatoshi(rand.Int63()),
124  			MaxAcceptedHtlcs: uint16(rand.Int31()),
125  		},
126  		CommitmentParams: channeldb.CommitmentParams{
127  			DustLimit: aliceDustLimit,
128  			CsvDelay:  uint16(csvTimeoutAlice),
129  		},
130  		MultiSigKey: keychain.KeyDescriptor{
131  			PubKey: aliceKeyPub,
132  		},
133  		RevocationBasePoint: keychain.KeyDescriptor{
134  			PubKey: aliceKeyPub,
135  		},
136  		PaymentBasePoint: keychain.KeyDescriptor{
137  			PubKey: aliceKeyPub,
138  		},
139  		DelayBasePoint: keychain.KeyDescriptor{
140  			PubKey: aliceKeyPub,
141  		},
142  		HtlcBasePoint: keychain.KeyDescriptor{
143  			PubKey: aliceKeyPub,
144  		},
145  	}
146  	bobCfg := channeldb.ChannelConfig{
147  		ChannelStateBounds: channeldb.ChannelStateBounds{
148  			MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()),
149  			ChanReserve:      btcutil.Amount(rand.Int63()),
150  			MinHTLC:          lnwire.MilliSatoshi(rand.Int63()),
151  			MaxAcceptedHtlcs: uint16(rand.Int31()),
152  		},
153  		CommitmentParams: channeldb.CommitmentParams{
154  			DustLimit: bobDustLimit,
155  			CsvDelay:  uint16(csvTimeoutBob),
156  		},
157  		MultiSigKey: keychain.KeyDescriptor{
158  			PubKey: bobKeyPub,
159  		},
160  		RevocationBasePoint: keychain.KeyDescriptor{
161  			PubKey: bobKeyPub,
162  		},
163  		PaymentBasePoint: keychain.KeyDescriptor{
164  			PubKey: bobKeyPub,
165  		},
166  		DelayBasePoint: keychain.KeyDescriptor{
167  			PubKey: bobKeyPub,
168  		},
169  		HtlcBasePoint: keychain.KeyDescriptor{
170  			PubKey: bobKeyPub,
171  		},
172  	}
173  
174  	bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize())
175  	if err != nil {
176  		return nil, err
177  	}
178  	bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot)
179  	bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
180  	if err != nil {
181  		return nil, err
182  	}
183  	bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
184  
185  	aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize())
186  	if err != nil {
187  		return nil, err
188  	}
189  	alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot)
190  	aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
191  	if err != nil {
192  		return nil, err
193  	}
194  	aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
195  
196  	aliceCommitTx, bobCommitTx, err := lnwallet.CreateCommitmentTxns(
197  		channelBal, channelBal, &aliceCfg, &bobCfg, aliceCommitPoint,
198  		bobCommitPoint, *fundingTxIn, channeldb.SingleFunderTweaklessBit,
199  		isAliceInitiator, 0,
200  	)
201  	if err != nil {
202  		return nil, err
203  	}
204  
205  	dbBob := channeldb.OpenForTesting(t, t.TempDir())
206  
207  	feePerKw, err := estimator.EstimateFeePerKW(1)
208  	if err != nil {
209  		return nil, err
210  	}
211  
212  	// TODO(roasbeef): need to factor in commit fee?
213  	aliceCommit := channeldb.ChannelCommitment{
214  		CommitHeight:  0,
215  		LocalBalance:  lnwire.NewMSatFromSatoshis(channelBal),
216  		RemoteBalance: lnwire.NewMSatFromSatoshis(channelBal),
217  		FeePerKw:      btcutil.Amount(feePerKw),
218  		CommitFee:     feePerKw.FeeForWeight(input.CommitWeight),
219  		CommitTx:      aliceCommitTx,
220  		CommitSig:     bytes.Repeat([]byte{1}, 71),
221  	}
222  	bobCommit := channeldb.ChannelCommitment{
223  		CommitHeight:  0,
224  		LocalBalance:  lnwire.NewMSatFromSatoshis(channelBal),
225  		RemoteBalance: lnwire.NewMSatFromSatoshis(channelBal),
226  		FeePerKw:      btcutil.Amount(feePerKw),
227  		CommitFee:     feePerKw.FeeForWeight(input.CommitWeight),
228  		CommitTx:      bobCommitTx,
229  		CommitSig:     bytes.Repeat([]byte{1}, 71),
230  	}
231  
232  	var chanIDBytes [8]byte
233  	if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil {
234  		return nil, err
235  	}
236  
237  	shortChanID := lnwire.NewShortChanIDFromInt(
238  		binary.BigEndian.Uint64(chanIDBytes[:]),
239  	)
240  
241  	aliceChannelState := &channeldb.OpenChannel{
242  		LocalChanCfg:            aliceCfg,
243  		RemoteChanCfg:           bobCfg,
244  		IdentityPub:             aliceKeyPub,
245  		FundingOutpoint:         *prevOut,
246  		ShortChannelID:          shortChanID,
247  		ChanType:                channeldb.SingleFunderTweaklessBit,
248  		IsInitiator:             isAliceInitiator,
249  		Capacity:                channelCapacity,
250  		RemoteCurrentRevocation: bobCommitPoint,
251  		RevocationProducer:      alicePreimageProducer,
252  		RevocationStore:         shachain.NewRevocationStore(),
253  		LocalCommitment:         aliceCommit,
254  		RemoteCommitment:        aliceCommit,
255  		Db:                      dbAlice.ChannelStateDB(),
256  		Packager:                channeldb.NewChannelPackager(shortChanID),
257  		FundingTxn:              channels.TestFundingTx,
258  	}
259  	bobChannelState := &channeldb.OpenChannel{
260  		LocalChanCfg:            bobCfg,
261  		RemoteChanCfg:           aliceCfg,
262  		IdentityPub:             bobKeyPub,
263  		FundingOutpoint:         *prevOut,
264  		ChanType:                channeldb.SingleFunderTweaklessBit,
265  		IsInitiator:             !isAliceInitiator,
266  		Capacity:                channelCapacity,
267  		RemoteCurrentRevocation: aliceCommitPoint,
268  		RevocationProducer:      bobPreimageProducer,
269  		RevocationStore:         shachain.NewRevocationStore(),
270  		LocalCommitment:         bobCommit,
271  		RemoteCommitment:        bobCommit,
272  		Db:                      dbBob.ChannelStateDB(),
273  		Packager:                channeldb.NewChannelPackager(shortChanID),
274  	}
275  
276  	// Set custom values on the channel states.
277  	updateChan(aliceChannelState, bobChannelState)
278  
279  	aliceAddr := alicePeer.cfg.Addr.Address
280  	if err := aliceChannelState.SyncPending(aliceAddr, 0); err != nil {
281  		return nil, err
282  	}
283  
284  	bobAddr := &net.TCPAddr{
285  		IP:   net.ParseIP("127.0.0.1"),
286  		Port: 18556,
287  	}
288  
289  	if err := bobChannelState.SyncPending(bobAddr, 0); err != nil {
290  		return nil, err
291  	}
292  
293  	aliceSigner := input.NewMockSigner(
294  		[]*btcec.PrivateKey{aliceKeyPriv}, nil,
295  	)
296  	bobSigner := input.NewMockSigner(
297  		[]*btcec.PrivateKey{bobKeyPriv}, nil,
298  	)
299  
300  	alicePool := lnwallet.NewSigPool(1, aliceSigner)
301  	channelAlice, err := lnwallet.NewLightningChannel(
302  		aliceSigner, aliceChannelState, alicePool,
303  		lnwallet.WithLeafStore(&lnwallet.MockAuxLeafStore{}),
304  		lnwallet.WithAuxSigner(lnwallet.NewAuxSignerMock(
305  			lnwallet.EmptyMockJobHandler,
306  		)),
307  	)
308  	if err != nil {
309  		return nil, err
310  	}
311  	_ = alicePool.Start()
312  	t.Cleanup(func() {
313  		require.NoError(t, alicePool.Stop())
314  	})
315  
316  	bobPool := lnwallet.NewSigPool(1, bobSigner)
317  	channelBob, err := lnwallet.NewLightningChannel(
318  		bobSigner, bobChannelState, bobPool,
319  		lnwallet.WithLeafStore(&lnwallet.MockAuxLeafStore{}),
320  		lnwallet.WithAuxSigner(lnwallet.NewAuxSignerMock(
321  			lnwallet.EmptyMockJobHandler,
322  		)),
323  	)
324  	if err != nil {
325  		return nil, err
326  	}
327  	_ = bobPool.Start()
328  	t.Cleanup(func() {
329  		require.NoError(t, bobPool.Stop())
330  	})
331  
332  	alicePeer.remoteFeatures = lnwire.NewFeatureVector(
333  		nil, lnwire.Features,
334  	)
335  
336  	chanID := lnwire.NewChanIDFromOutPoint(channelAlice.ChannelPoint())
337  	alicePeer.activeChannels.Store(chanID, channelAlice)
338  
339  	alicePeer.cg.WgAdd(1)
340  	go alicePeer.channelManager()
341  
342  	return &peerTestCtx{
343  		peer:       alicePeer,
344  		channel:    channelBob,
345  		notifier:   notifier,
346  		publishTx:  publishTx,
347  		mockSwitch: mockSwitch,
348  		mockConn:   params.mockConn,
349  	}, nil
350  }
351  
352  // mockMessageSwitch is a mock implementation of the messageSwitch interface
353  // used for testing without relying on a *htlcswitch.Switch in unit tests.
354  type mockMessageSwitch struct {
355  	links []htlcswitch.ChannelUpdateHandler
356  }
357  
358  // BestHeight currently returns a dummy value.
359  func (m *mockMessageSwitch) BestHeight() uint32 {
360  	return 0
361  }
362  
363  // CircuitModifier currently returns a dummy value.
364  func (m *mockMessageSwitch) CircuitModifier() htlcswitch.CircuitModifier {
365  	return nil
366  }
367  
368  // RemoveLink currently does nothing.
369  func (m *mockMessageSwitch) RemoveLink(cid lnwire.ChannelID) {}
370  
371  // CreateAndAddLink currently returns a dummy value.
372  func (m *mockMessageSwitch) CreateAndAddLink(cfg htlcswitch.ChannelLinkConfig,
373  	lnChan *lnwallet.LightningChannel) error {
374  
375  	return nil
376  }
377  
378  // GetLinksByInterface returns the active links.
379  func (m *mockMessageSwitch) GetLinksByInterface(pub [33]byte) (
380  	[]htlcswitch.ChannelUpdateHandler, error) {
381  
382  	return m.links, nil
383  }
384  
385  // mockUpdateHandler is a mock implementation of the ChannelUpdateHandler
386  // interface. It is used in mockMessageSwitch's GetLinksByInterface method.
387  type mockUpdateHandler struct {
388  	cid                  lnwire.ChannelID
389  	isOutgoingAddBlocked atomic.Bool
390  	isIncomingAddBlocked atomic.Bool
391  }
392  
393  // newMockUpdateHandler creates a new mockUpdateHandler.
394  func newMockUpdateHandler(cid lnwire.ChannelID) *mockUpdateHandler {
395  	return &mockUpdateHandler{
396  		cid: cid,
397  	}
398  }
399  
400  // HandleChannelUpdate currently does nothing.
401  func (m *mockUpdateHandler) HandleChannelUpdate(msg lnwire.Message) {}
402  
403  // ChanID returns the mockUpdateHandler's cid.
404  func (m *mockUpdateHandler) ChanID() lnwire.ChannelID { return m.cid }
405  
406  // Bandwidth currently returns a dummy value.
407  func (m *mockUpdateHandler) Bandwidth() lnwire.MilliSatoshi { return 0 }
408  
409  // EligibleToForward currently returns a dummy value.
410  func (m *mockUpdateHandler) EligibleToForward() bool { return false }
411  
412  // MayAddOutgoingHtlc currently returns nil.
413  func (m *mockUpdateHandler) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil }
414  
415  type mockMessageConn struct {
416  	t *testing.T
417  
418  	// MessageConn embeds our interface so that the mock does not need to
419  	// implement every function. The mock will panic if an unspecified function
420  	// is called.
421  	MessageConn
422  
423  	// writtenMessages is a channel that our mock pushes written messages into.
424  	writtenMessages chan []byte
425  
426  	readMessages   chan []byte
427  	curReadMessage []byte
428  
429  	// writeRaceDetectingCounter is incremented on any function call
430  	// associated with writing to the connection. The race detector will
431  	// trigger on this counter if a data race exists.
432  	writeRaceDetectingCounter int
433  
434  	// readRaceDetectingCounter is incremented on any function call
435  	// associated with reading from the connection. The race detector will
436  	// trigger on this counter if a data race exists.
437  	readRaceDetectingCounter int
438  }
439  
440  func (m *mockUpdateHandler) EnableAdds(dir htlcswitch.LinkDirection) bool {
441  	if dir == htlcswitch.Outgoing {
442  		return m.isOutgoingAddBlocked.Swap(false)
443  	}
444  
445  	return m.isIncomingAddBlocked.Swap(false)
446  }
447  
448  func (m *mockUpdateHandler) DisableAdds(dir htlcswitch.LinkDirection) bool {
449  	if dir == htlcswitch.Outgoing {
450  		return !m.isOutgoingAddBlocked.Swap(true)
451  	}
452  
453  	return !m.isIncomingAddBlocked.Swap(true)
454  }
455  
456  func (m *mockUpdateHandler) IsFlushing(dir htlcswitch.LinkDirection) bool {
457  	switch dir {
458  	case htlcswitch.Outgoing:
459  		return m.isOutgoingAddBlocked.Load()
460  	case htlcswitch.Incoming:
461  		return m.isIncomingAddBlocked.Load()
462  	}
463  
464  	return false
465  }
466  
467  func (m *mockUpdateHandler) OnFlushedOnce(hook func()) {
468  	hook()
469  }
470  func (m *mockUpdateHandler) OnCommitOnce(
471  	_ htlcswitch.LinkDirection, hook func(),
472  ) {
473  
474  	hook()
475  }
476  func (m *mockUpdateHandler) InitStfu() <-chan fn.Result[lntypes.ChannelParty] {
477  	// TODO(proofofkeags): Implement
478  	c := make(chan fn.Result[lntypes.ChannelParty], 1)
479  
480  	c <- fn.Errf[lntypes.ChannelParty]("InitStfu not yet implemented")
481  
482  	return c
483  }
484  
485  func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
486  	return &mockMessageConn{
487  		t:               t,
488  		writtenMessages: make(chan []byte, expectedMessages),
489  		readMessages:    make(chan []byte, 1),
490  	}
491  }
492  
493  // SetWriteDeadline mocks setting write deadline for our conn.
494  func (m *mockMessageConn) SetWriteDeadline(time.Time) error {
495  	m.writeRaceDetectingCounter++
496  	return nil
497  }
498  
499  // Flush mocks a message conn flush.
500  func (m *mockMessageConn) Flush() (int, error) {
501  	m.writeRaceDetectingCounter++
502  	return 0, nil
503  }
504  
505  // WriteMessage mocks sending of a message on our connection. It will push
506  // the bytes sent into the mock's writtenMessages channel.
507  func (m *mockMessageConn) WriteMessage(msg []byte) error {
508  	m.writeRaceDetectingCounter++
509  
510  	msgCopy := make([]byte, len(msg))
511  	copy(msgCopy, msg)
512  
513  	select {
514  	case m.writtenMessages <- msgCopy:
515  	case <-time.After(timeout):
516  		m.t.Fatalf("timeout sending message: %v", msgCopy)
517  	}
518  
519  	return nil
520  }
521  
522  // assertWrite asserts that our mock as had WriteMessage called with the byte
523  // slice we expect.
524  func (m *mockMessageConn) assertWrite(expected []byte) {
525  	select {
526  	case actual := <-m.writtenMessages:
527  		require.Equal(m.t, expected, actual)
528  
529  	case <-time.After(timeout):
530  		m.t.Fatalf("timeout waiting for write: %v", expected)
531  	}
532  }
533  
534  func (m *mockMessageConn) SetReadDeadline(t time.Time) error {
535  	m.readRaceDetectingCounter++
536  	return nil
537  }
538  
539  func (m *mockMessageConn) ReadNextHeader() (uint32, error) {
540  	m.readRaceDetectingCounter++
541  	m.curReadMessage = <-m.readMessages
542  	return uint32(len(m.curReadMessage)), nil
543  }
544  
545  func (m *mockMessageConn) ReadNextBody(buf []byte) ([]byte, error) {
546  	m.readRaceDetectingCounter++
547  	return m.curReadMessage, nil
548  }
549  
550  func (m *mockMessageConn) RemoteAddr() net.Addr {
551  	return nil
552  }
553  
554  func (m *mockMessageConn) LocalAddr() net.Addr {
555  	return nil
556  }
557  
558  func (m *mockMessageConn) Close() error {
559  	return nil
560  }
561  
562  // createTestPeer creates a new peer for testing and returns a context struct
563  // containing necessary handles and mock objects for conducting tests on peer
564  // functionalities.
565  func createTestPeer(t *testing.T) *peerTestCtx {
566  	nodeKeyLocator := keychain.KeyLocator{
567  		Family: keychain.KeyFamilyNodeKey,
568  	}
569  
570  	aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(
571  		channels.AlicesPrivKey,
572  	)
573  
574  	aliceKeySigner := keychain.NewPrivKeyMessageSigner(
575  		aliceKeyPriv, nodeKeyLocator,
576  	)
577  
578  	aliceAddr := &net.TCPAddr{
579  		IP:   net.ParseIP("127.0.0.1"),
580  		Port: 18555,
581  	}
582  	cfgAddr := &lnwire.NetAddress{
583  		IdentityKey: aliceKeyPub,
584  		Address:     aliceAddr,
585  		ChainNet:    wire.SimNet,
586  	}
587  
588  	errBuffer, err := queue.NewCircularBuffer(ErrorBufferSize)
589  	require.NoError(t, err)
590  
591  	chainIO := &mock.ChainIO{
592  		BestHeight: broadcastHeight,
593  	}
594  
595  	publishTx := make(chan *wire.MsgTx)
596  	wallet := &lnwallet.LightningWallet{
597  		WalletController: &mock.WalletController{
598  			RootKey:               aliceKeyPriv,
599  			PublishedTransactions: publishTx,
600  		},
601  	}
602  
603  	const chanActiveTimeout = time.Minute
604  
605  	dbAliceGraph := graphdb.MakeTestGraph(t)
606  	require.NoError(t, dbAliceGraph.Start())
607  	t.Cleanup(func() {
608  		require.NoError(t, dbAliceGraph.Stop())
609  	})
610  
611  	dbAliceChannel := channeldb.OpenForTesting(t, t.TempDir())
612  
613  	nodeSignerAlice := netann.NewNodeSigner(aliceKeySigner)
614  
615  	chanStatusMgr, err := netann.NewChanStatusManager(&netann.
616  		ChanStatusConfig{
617  		ChanStatusSampleInterval: 30 * time.Second,
618  		ChanEnableTimeout:        chanActiveTimeout,
619  		ChanDisableTimeout:       2 * time.Minute,
620  		DB:                       dbAliceChannel.ChannelStateDB(),
621  		Graph:                    dbAliceGraph,
622  		MessageSigner:            nodeSignerAlice,
623  		OurPubKey:                aliceKeyPub,
624  		OurKeyLoc:                testKeyLoc,
625  		IsChannelActive: func(lnwire.ChannelID) bool {
626  			return true
627  		},
628  		ApplyChannelUpdate: func(*lnwire.ChannelUpdate1,
629  			*wire.OutPoint, bool) error {
630  
631  			return nil
632  		},
633  	})
634  	require.NoError(t, err)
635  
636  	interceptableSwitchNotifier := &mock.ChainNotifier{
637  		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
638  	}
639  	interceptableSwitchNotifier.EpochChan <- &chainntnfs.BlockEpoch{
640  		Height: 1,
641  	}
642  
643  	interceptableSwitch, err := htlcswitch.NewInterceptableSwitch(
644  		&htlcswitch.InterceptableSwitchConfig{
645  			CltvRejectDelta:    testCltvRejectDelta,
646  			CltvInterceptDelta: testCltvRejectDelta + 3,
647  			Notifier:           interceptableSwitchNotifier,
648  		},
649  	)
650  	require.NoError(t, err)
651  
652  	// TODO(yy): create interface for lnwallet.LightningChannel so we can
653  	// easily mock it without the following setups.
654  	notifier := &mock.ChainNotifier{
655  		SpendChan: make(chan *chainntnfs.SpendDetail),
656  		EpochChan: make(chan *chainntnfs.BlockEpoch),
657  		ConfChan:  make(chan *chainntnfs.TxConfirmation),
658  	}
659  
660  	mockSwitch := &mockMessageSwitch{}
661  
662  	// TODO(yy): change ChannelNotifier to be an interface.
663  	channelNotifier := channelnotifier.New(dbAliceChannel.ChannelStateDB())
664  	require.NoError(t, channelNotifier.Start())
665  	t.Cleanup(func() {
666  		require.NoError(t, channelNotifier.Stop(),
667  			"stop channel notifier failed")
668  	})
669  
670  	writeBufferPool := pool.NewWriteBuffer(
671  		pool.DefaultWriteBufferGCInterval,
672  		pool.DefaultWriteBufferExpiryInterval,
673  	)
674  
675  	writePool := pool.NewWrite(
676  		writeBufferPool, 1, timeout,
677  	)
678  	require.NoError(t, writePool.Start())
679  
680  	readBufferPool := pool.NewReadBuffer(
681  		pool.DefaultReadBufferGCInterval,
682  		pool.DefaultReadBufferExpiryInterval,
683  	)
684  
685  	readPool := pool.NewRead(
686  		readBufferPool, 1, timeout,
687  	)
688  	require.NoError(t, readPool.Start())
689  
690  	mockConn := newMockConn(t, 1)
691  
692  	receivedCustomChan := make(chan *customMsg)
693  
694  	var pubKey [33]byte
695  	copy(pubKey[:], aliceKeyPub.SerializeCompressed())
696  
697  	// We have to have a valid server key for brontide to start up properly.
698  	serverKey, err := btcec.NewPrivateKey()
699  	require.NoError(t, err)
700  
701  	var serverKeyArr [33]byte
702  	copy(serverKeyArr[:], serverKey.PubKey().SerializeCompressed())
703  
704  	estimator := chainfee.NewStaticEstimator(12500, 0)
705  
706  	cfg := &Config{
707  		Addr:              cfgAddr,
708  		PubKeyBytes:       pubKey,
709  		ServerPubKey:      serverKeyArr,
710  		ErrorBuffer:       errBuffer,
711  		ChainIO:           chainIO,
712  		Switch:            mockSwitch,
713  		ChanActiveTimeout: chanActiveTimeout,
714  		InterceptSwitch:   interceptableSwitch,
715  		ChannelDB:         dbAliceChannel.ChannelStateDB(),
716  		FeeEstimator:      estimator,
717  		Wallet:            wallet,
718  		ChainNotifier:     notifier,
719  		ChanStatusMgr:     chanStatusMgr,
720  		Features: lnwire.NewFeatureVector(
721  			nil, lnwire.Features,
722  		),
723  		DisconnectPeer: func(b *btcec.PublicKey) error {
724  			return nil
725  		},
726  		ChannelNotifier:               channelNotifier,
727  		PrunePersistentPeerConnection: func([33]byte) {},
728  		LegacyFeatures:                lnwire.EmptyFeatureVector(),
729  		WritePool:                     writePool,
730  		ReadPool:                      readPool,
731  		Conn:                          mockConn,
732  		HandleCustomMessage: func(
733  			peer [33]byte, msg *lnwire.Custom) error {
734  
735  			receivedCustomChan <- &customMsg{
736  				peer: peer,
737  				msg:  *msg,
738  			}
739  
740  			return nil
741  		},
742  		PongBuf: make([]byte, lnwire.MaxPongBytes),
743  		FetchLastChanUpdate: func(chanID lnwire.ShortChannelID,
744  		) (*lnwire.ChannelUpdate1, error) {
745  
746  			return &lnwire.ChannelUpdate1{}, nil
747  		},
748  	}
749  
750  	alicePeer := NewBrontide(*cfg)
751  
752  	return &peerTestCtx{
753  		publishTx:     publishTx,
754  		mockSwitch:    mockSwitch,
755  		peer:          alicePeer,
756  		notifier:      notifier,
757  		db:            dbAliceChannel,
758  		privKey:       aliceKeyPriv,
759  		mockConn:      mockConn,
760  		customChan:    receivedCustomChan,
761  		chanStatusMgr: chanStatusMgr,
762  	}
763  }
764  
765  // startPeer invokes the `Start` method on the specified peer and handles any
766  // initial startup messages for testing.
767  func startPeer(t *testing.T, mockConn *mockMessageConn,
768  	peer *Brontide) <-chan struct{} {
769  
770  	// Start the peer in a goroutine so that we can handle and test for
771  	// startup messages. Successfully sending and receiving init message,
772  	// indicates a successful startup.
773  	done := make(chan struct{})
774  	go func() {
775  		require.NoError(t, peer.Start())
776  		close(done)
777  	}()
778  
779  	// Receive the init message that should be the first message received on
780  	// startup.
781  	rawMsg, err := fn.RecvOrTimeout[[]byte](
782  		mockConn.writtenMessages, timeout,
783  	)
784  	require.NoError(t, err)
785  
786  	msgReader := bytes.NewReader(rawMsg)
787  	nextMsg, err := lnwire.ReadMessage(msgReader, 0)
788  	require.NoError(t, err)
789  
790  	_, ok := nextMsg.(*lnwire.Init)
791  	require.True(t, ok)
792  
793  	// Write the reply for the init message to complete the startup.
794  	initReplyMsg := lnwire.NewInitMessage(
795  		lnwire.NewRawFeatureVector(
796  			lnwire.DataLossProtectRequired,
797  			lnwire.GossipQueriesOptional,
798  		),
799  		lnwire.NewRawFeatureVector(),
800  	)
801  
802  	var b bytes.Buffer
803  	_, err = lnwire.WriteMessage(&b, initReplyMsg, 0)
804  	require.NoError(t, err)
805  
806  	ok = fn.SendOrQuit[[]byte, struct{}](
807  		mockConn.readMessages, b.Bytes(), make(chan struct{}),
808  	)
809  	require.True(t, ok)
810  
811  	return done
812  }