/ peer / brontide_test.go
brontide_test.go
   1  package peer
   2  
   3  import (
   4  	"bytes"
   5  	"fmt"
   6  	"testing"
   7  	"time"
   8  
   9  	"github.com/btcsuite/btcd/btcec/v2"
  10  	"github.com/btcsuite/btcd/btcutil"
  11  	"github.com/btcsuite/btcd/chaincfg"
  12  	"github.com/btcsuite/btcd/txscript"
  13  	"github.com/btcsuite/btcd/wire"
  14  	"github.com/lightningnetwork/lnd/chainntnfs"
  15  	"github.com/lightningnetwork/lnd/channeldb"
  16  	"github.com/lightningnetwork/lnd/contractcourt"
  17  	"github.com/lightningnetwork/lnd/fn/v2"
  18  	"github.com/lightningnetwork/lnd/htlcswitch"
  19  	"github.com/lightningnetwork/lnd/lntest/wait"
  20  	"github.com/lightningnetwork/lnd/lnwallet"
  21  	"github.com/lightningnetwork/lnd/lnwallet/chancloser"
  22  	"github.com/lightningnetwork/lnd/lnwire"
  23  	"github.com/lightningnetwork/lnd/routing/route"
  24  	"github.com/lightningnetwork/lnd/tlv"
  25  	"github.com/stretchr/testify/mock"
  26  	"github.com/stretchr/testify/require"
  27  )
  28  
  29  var (
  30  	// p2SHAddress is a valid pay to script hash address.
  31  	p2SHAddress = "2NBFNJTktNa7GZusGbDbGKRZTxdK9VVez3n"
  32  
  33  	// p2wshAddress is a valid pay to witness script hash address.
  34  	p2wshAddress = "bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3"
  35  )
  36  
  37  // TestPeerChannelClosureShutdownResponseLinkRemoved tests the shutdown
  38  // response we get if the link for the channel can't be found in the
  39  // switch. This test was added due to a regression.
  40  func TestPeerChannelClosureShutdownResponseLinkRemoved(t *testing.T) {
  41  	t.Parallel()
  42  
  43  	harness, err := createTestPeerWithChannel(t, noUpdate)
  44  	require.NoError(t, err, "unable to create test channels")
  45  
  46  	var (
  47  		alicePeer = harness.peer
  48  		bobChan   = harness.channel
  49  	)
  50  
  51  	chanPoint := bobChan.ChannelPoint()
  52  	chanID := lnwire.NewChanIDFromOutPoint(chanPoint)
  53  
  54  	dummyDeliveryScript := genScript(t, p2wshAddress)
  55  
  56  	// We send a shutdown request to Alice. She will now be the responding
  57  	// node in this shutdown procedure. We first expect Alice to answer
  58  	// this shutdown request with a Shutdown message.
  59  	alicePeer.chanCloseMsgs <- &closeMsg{
  60  		cid: chanID,
  61  		msg: lnwire.NewShutdown(chanID, dummyDeliveryScript),
  62  	}
  63  
  64  	var msg lnwire.Message
  65  	select {
  66  	case outMsg := <-alicePeer.outgoingQueue:
  67  		msg = outMsg.msg
  68  	case <-time.After(timeout):
  69  		t.Fatalf("did not receive shutdown message")
  70  	}
  71  
  72  	shutdownMsg, ok := msg.(*lnwire.Shutdown)
  73  	if !ok {
  74  		t.Fatalf("expected Shutdown message, got %T", msg)
  75  	}
  76  
  77  	require.NotEqualValues(t, shutdownMsg.Address, dummyDeliveryScript)
  78  }
  79  
  80  // TestPeerChannelClosureAcceptFeeResponder tests the shutdown responder's
  81  // behavior if we can agree on the fee immediately.
  82  func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) {
  83  	t.Parallel()
  84  
  85  	harness, err := createTestPeerWithChannel(t, noUpdate)
  86  	require.NoError(t, err, "unable to create test channels")
  87  
  88  	var (
  89  		alicePeer       = harness.peer
  90  		bobChan         = harness.channel
  91  		mockSwitch      = harness.mockSwitch
  92  		broadcastTxChan = harness.publishTx
  93  		notifier        = harness.notifier
  94  	)
  95  
  96  	chanPoint := bobChan.ChannelPoint()
  97  	chanID := lnwire.NewChanIDFromOutPoint(chanPoint)
  98  
  99  	mockLink := newMockUpdateHandler(chanID)
 100  	mockSwitch.links = append(mockSwitch.links, mockLink)
 101  
 102  	dummyDeliveryScript := genScript(t, p2wshAddress)
 103  
 104  	// We send a shutdown request to Alice. She will now be the responding
 105  	// node in this shutdown procedure. We first expect Alice to answer
 106  	// this shutdown request with a Shutdown message.
 107  	alicePeer.chanCloseMsgs <- &closeMsg{
 108  		cid: chanID,
 109  		msg: lnwire.NewShutdown(chanID, dummyDeliveryScript),
 110  	}
 111  
 112  	var msg lnwire.Message
 113  	select {
 114  	case outMsg := <-alicePeer.outgoingQueue:
 115  		msg = outMsg.msg
 116  	case <-time.After(timeout):
 117  		t.Fatalf("did not receive shutdown message")
 118  	}
 119  
 120  	shutdownMsg, ok := msg.(*lnwire.Shutdown)
 121  	if !ok {
 122  		t.Fatalf("expected Shutdown message, got %T", msg)
 123  	}
 124  
 125  	respDeliveryScript := shutdownMsg.Address
 126  	require.NotEqualValues(t, respDeliveryScript, dummyDeliveryScript)
 127  
 128  	// Alice will then send a ClosingSigned message, indicating her proposed
 129  	// closing transaction fee. Alice sends the ClosingSigned message as she is
 130  	// the initiator of the channel.
 131  	select {
 132  	case outMsg := <-alicePeer.outgoingQueue:
 133  		msg = outMsg.msg
 134  	case <-time.After(timeout):
 135  		t.Fatalf("did not receive ClosingSigned message")
 136  	}
 137  
 138  	respClosingSigned, ok := msg.(*lnwire.ClosingSigned)
 139  	if !ok {
 140  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 141  	}
 142  
 143  	// We accept the fee, and send a ClosingSigned with the same fee back,
 144  	// so she knows we agreed.
 145  	aliceFee := respClosingSigned.FeeSatoshis
 146  	bobSig, _, _, err := bobChan.CreateCloseProposal(
 147  		aliceFee, dummyDeliveryScript, respDeliveryScript,
 148  	)
 149  	require.NoError(t, err, "error creating close proposal")
 150  
 151  	parsedSig, err := lnwire.NewSigFromSignature(bobSig)
 152  	require.NoError(t, err, "error parsing signature")
 153  	closingSigned := lnwire.NewClosingSigned(chanID, aliceFee, parsedSig)
 154  	alicePeer.chanCloseMsgs <- &closeMsg{
 155  		cid: chanID,
 156  		msg: closingSigned,
 157  	}
 158  
 159  	// Alice should now see that we agreed on the fee, and should broadcast the
 160  	// closing transaction.
 161  	select {
 162  	case <-broadcastTxChan:
 163  	case <-time.After(timeout):
 164  		t.Fatalf("closing tx not broadcast")
 165  	}
 166  
 167  	// Need to pull the remaining message off of Alice's outgoing queue.
 168  	select {
 169  	case outMsg := <-alicePeer.outgoingQueue:
 170  		msg = outMsg.msg
 171  	case <-time.After(timeout):
 172  		t.Fatalf("did not receive ClosingSigned message")
 173  	}
 174  	if _, ok := msg.(*lnwire.ClosingSigned); !ok {
 175  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 176  	}
 177  
 178  	// Alice should be waiting in a goroutine for a confirmation.
 179  	notifier.ConfChan <- &chainntnfs.TxConfirmation{}
 180  }
 181  
 182  // TestPeerChannelClosureAcceptFeeInitiator tests the shutdown initiator's
 183  // behavior if we can agree on the fee immediately.
 184  func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) {
 185  	t.Parallel()
 186  
 187  	harness, err := createTestPeerWithChannel(t, noUpdate)
 188  	require.NoError(t, err, "unable to create test channels")
 189  
 190  	var (
 191  		bobChan         = harness.channel
 192  		alicePeer       = harness.peer
 193  		mockSwitch      = harness.mockSwitch
 194  		broadcastTxChan = harness.publishTx
 195  		notifier        = harness.notifier
 196  	)
 197  
 198  	chanPoint := bobChan.ChannelPoint()
 199  	chanID := lnwire.NewChanIDFromOutPoint(chanPoint)
 200  	mockLink := newMockUpdateHandler(chanID)
 201  	mockSwitch.links = append(mockSwitch.links, mockLink)
 202  
 203  	dummyDeliveryScript := genScript(t, p2wshAddress)
 204  
 205  	// We make Alice send a shutdown request.
 206  	updateChan := make(chan interface{}, 1)
 207  	errChan := make(chan error, 1)
 208  	closeCommand := &htlcswitch.ChanClose{
 209  		CloseType:      contractcourt.CloseRegular,
 210  		ChanPoint:      &chanPoint,
 211  		Updates:        updateChan,
 212  		TargetFeePerKw: 12500,
 213  		Err:            errChan,
 214  	}
 215  	alicePeer.localCloseChanReqs <- closeCommand
 216  
 217  	// We can now pull a Shutdown message off of Alice's outgoingQueue.
 218  	var msg lnwire.Message
 219  	select {
 220  	case outMsg := <-alicePeer.outgoingQueue:
 221  		msg = outMsg.msg
 222  	case <-time.After(timeout):
 223  		t.Fatalf("did not receive shutdown request")
 224  	}
 225  
 226  	shutdownMsg, ok := msg.(*lnwire.Shutdown)
 227  	if !ok {
 228  		t.Fatalf("expected Shutdown message, got %T", msg)
 229  	}
 230  
 231  	aliceDeliveryScript := shutdownMsg.Address
 232  	require.NotEqualValues(t, aliceDeliveryScript, dummyDeliveryScript)
 233  
 234  	// Bob will respond with his own Shutdown message.
 235  	alicePeer.chanCloseMsgs <- &closeMsg{
 236  		cid: chanID,
 237  		msg: lnwire.NewShutdown(chanID,
 238  			dummyDeliveryScript),
 239  	}
 240  
 241  	// Alice will reply with a ClosingSigned here.
 242  	select {
 243  	case outMsg := <-alicePeer.outgoingQueue:
 244  		msg = outMsg.msg
 245  	case <-time.After(timeout):
 246  		t.Fatalf("did not receive closing signed message")
 247  	}
 248  	closingSignedMsg, ok := msg.(*lnwire.ClosingSigned)
 249  	if !ok {
 250  		t.Fatalf("expected to receive closing signed message, got %T", msg)
 251  	}
 252  
 253  	// Bob should reply with the exact same fee in his next ClosingSigned
 254  	// message.
 255  	bobFee := closingSignedMsg.FeeSatoshis
 256  	bobSig, _, _, err := bobChan.CreateCloseProposal(
 257  		bobFee, dummyDeliveryScript, aliceDeliveryScript,
 258  	)
 259  	require.NoError(t, err, "unable to create close proposal")
 260  	parsedSig, err := lnwire.NewSigFromSignature(bobSig)
 261  	require.NoError(t, err, "unable to parse signature")
 262  
 263  	closingSigned := lnwire.NewClosingSigned(shutdownMsg.ChannelID,
 264  		bobFee, parsedSig)
 265  	alicePeer.chanCloseMsgs <- &closeMsg{
 266  		cid: chanID,
 267  		msg: closingSigned,
 268  	}
 269  
 270  	// Alice should accept Bob's fee, broadcast the cooperative close tx, and
 271  	// send a ClosingSigned message back to Bob.
 272  
 273  	// Alice should now broadcast the closing transaction.
 274  	select {
 275  	case <-broadcastTxChan:
 276  	case <-time.After(timeout):
 277  		t.Fatalf("closing tx not broadcast")
 278  	}
 279  
 280  	// Alice should respond with the ClosingSigned they both agreed upon.
 281  	select {
 282  	case outMsg := <-alicePeer.outgoingQueue:
 283  		msg = outMsg.msg
 284  	case <-time.After(timeout):
 285  		t.Fatalf("did not receive closing signed message")
 286  	}
 287  
 288  	closingSignedMsg, ok = msg.(*lnwire.ClosingSigned)
 289  	if !ok {
 290  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 291  	}
 292  
 293  	if closingSignedMsg.FeeSatoshis != bobFee {
 294  		t.Fatalf("expected ClosingSigned fee to be %v, instead got %v",
 295  			bobFee, closingSignedMsg.FeeSatoshis)
 296  	}
 297  
 298  	// Alice should be waiting on a single confirmation for the coop close tx.
 299  	notifier.ConfChan <- &chainntnfs.TxConfirmation{}
 300  }
 301  
 302  // TestPeerChannelClosureFeeNegotiationsResponder tests the shutdown
 303  // responder's behavior in the case where we must do several rounds of fee
 304  // negotiation before we agree on a fee.
 305  func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) {
 306  	t.Parallel()
 307  
 308  	harness, err := createTestPeerWithChannel(t, noUpdate)
 309  	require.NoError(t, err, "unable to create test channels")
 310  
 311  	var (
 312  		bobChan         = harness.channel
 313  		alicePeer       = harness.peer
 314  		mockSwitch      = harness.mockSwitch
 315  		broadcastTxChan = harness.publishTx
 316  		notifier        = harness.notifier
 317  	)
 318  
 319  	chanPoint := bobChan.ChannelPoint()
 320  	chanID := lnwire.NewChanIDFromOutPoint(chanPoint)
 321  
 322  	mockLink := newMockUpdateHandler(chanID)
 323  	mockSwitch.links = append(mockSwitch.links, mockLink)
 324  
 325  	// Bob sends a shutdown request to Alice. She will now be the responding
 326  	// node in this shutdown procedure. We first expect Alice to answer this
 327  	// Shutdown request with a Shutdown message.
 328  	dummyDeliveryScript := genScript(t, p2wshAddress)
 329  	alicePeer.chanCloseMsgs <- &closeMsg{
 330  		cid: chanID,
 331  		msg: lnwire.NewShutdown(chanID,
 332  			dummyDeliveryScript),
 333  	}
 334  
 335  	var msg lnwire.Message
 336  	select {
 337  	case outMsg := <-alicePeer.outgoingQueue:
 338  		msg = outMsg.msg
 339  	case <-time.After(timeout):
 340  		t.Fatalf("did not receive shutdown message")
 341  	}
 342  
 343  	shutdownMsg, ok := msg.(*lnwire.Shutdown)
 344  	if !ok {
 345  		t.Fatalf("expected Shutdown message, got %T", msg)
 346  	}
 347  
 348  	aliceDeliveryScript := shutdownMsg.Address
 349  	require.NotEqualValues(t, aliceDeliveryScript, dummyDeliveryScript)
 350  
 351  	// As Alice is the channel initiator, she will send her ClosingSigned
 352  	// message.
 353  	select {
 354  	case outMsg := <-alicePeer.outgoingQueue:
 355  		msg = outMsg.msg
 356  	case <-time.After(timeout):
 357  		t.Fatalf("did not receive closing signed message")
 358  	}
 359  
 360  	aliceClosingSigned, ok := msg.(*lnwire.ClosingSigned)
 361  	if !ok {
 362  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 363  	}
 364  
 365  	// Bob doesn't agree with the fee and will send one back that's 2.5x.
 366  	preferredRespFee := aliceClosingSigned.FeeSatoshis
 367  	increasedFee := btcutil.Amount(float64(preferredRespFee) * 2.5)
 368  	bobSig, _, _, err := bobChan.CreateCloseProposal(
 369  		increasedFee, dummyDeliveryScript, aliceDeliveryScript,
 370  	)
 371  	require.NoError(t, err, "error creating close proposal")
 372  
 373  	parsedSig, err := lnwire.NewSigFromSignature(bobSig)
 374  	require.NoError(t, err, "error parsing signature")
 375  	closingSigned := lnwire.NewClosingSigned(chanID, increasedFee, parsedSig)
 376  	alicePeer.chanCloseMsgs <- &closeMsg{
 377  		cid: chanID,
 378  		msg: closingSigned,
 379  	}
 380  
 381  	// Alice will now see the new fee we propose, but with current settings it
 382  	// won't accept it immediately as it differs too much by its ideal fee. We
 383  	// should get a new proposal back, which should have the average fee rate
 384  	// proposed.
 385  	select {
 386  	case outMsg := <-alicePeer.outgoingQueue:
 387  		msg = outMsg.msg
 388  	case <-time.After(timeout):
 389  		t.Fatalf("did not receive closing signed message")
 390  	}
 391  
 392  	aliceClosingSigned, ok = msg.(*lnwire.ClosingSigned)
 393  	if !ok {
 394  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 395  	}
 396  
 397  	// The fee sent by Alice should be less than the fee Bob just sent as Alice
 398  	// should attempt to compromise.
 399  	aliceFee := aliceClosingSigned.FeeSatoshis
 400  	if aliceFee > increasedFee {
 401  		t.Fatalf("new fee should be less than our fee: new=%v, "+
 402  			"prior=%v", aliceFee, increasedFee)
 403  	}
 404  	lastFeeResponder := aliceFee
 405  
 406  	// We try negotiating a 2.1x fee, which should also be rejected.
 407  	increasedFee = btcutil.Amount(float64(preferredRespFee) * 2.1)
 408  	bobSig, _, _, err = bobChan.CreateCloseProposal(
 409  		increasedFee, dummyDeliveryScript, aliceDeliveryScript,
 410  	)
 411  	require.NoError(t, err, "error creating close proposal")
 412  
 413  	parsedSig, err = lnwire.NewSigFromSignature(bobSig)
 414  	require.NoError(t, err, "error parsing signature")
 415  	closingSigned = lnwire.NewClosingSigned(chanID, increasedFee, parsedSig)
 416  	alicePeer.chanCloseMsgs <- &closeMsg{
 417  		cid: chanID,
 418  		msg: closingSigned,
 419  	}
 420  
 421  	// Bob's latest proposal still won't be accepted and Alice should send over
 422  	// a new ClosingSigned message. It should be the average of what Bob and
 423  	// Alice each proposed last time.
 424  	select {
 425  	case outMsg := <-alicePeer.outgoingQueue:
 426  		msg = outMsg.msg
 427  	case <-time.After(timeout):
 428  		t.Fatalf("did not receive closing signed message")
 429  	}
 430  
 431  	aliceClosingSigned, ok = msg.(*lnwire.ClosingSigned)
 432  	if !ok {
 433  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 434  	}
 435  
 436  	// Alice should inch towards Bob's fee, in order to compromise.
 437  	// Additionally, this fee should be less than the fee Bob sent before.
 438  	aliceFee = aliceClosingSigned.FeeSatoshis
 439  	if aliceFee < lastFeeResponder {
 440  		t.Fatalf("new fee should be greater than prior: new=%v, "+
 441  			"prior=%v", aliceFee, lastFeeResponder)
 442  	}
 443  	if aliceFee > increasedFee {
 444  		t.Fatalf("new fee should be less than Bob's fee: new=%v, "+
 445  			"prior=%v", aliceFee, increasedFee)
 446  	}
 447  
 448  	// Finally, Bob will accept the fee by echoing back the same fee that Alice
 449  	// just sent over.
 450  	bobSig, _, _, err = bobChan.CreateCloseProposal(
 451  		aliceFee, dummyDeliveryScript, aliceDeliveryScript,
 452  	)
 453  	require.NoError(t, err, "error creating close proposal")
 454  
 455  	parsedSig, err = lnwire.NewSigFromSignature(bobSig)
 456  	require.NoError(t, err, "error parsing signature")
 457  	closingSigned = lnwire.NewClosingSigned(chanID, aliceFee, parsedSig)
 458  	alicePeer.chanCloseMsgs <- &closeMsg{
 459  		cid: chanID,
 460  		msg: closingSigned,
 461  	}
 462  
 463  	// Alice will now see that Bob agreed on the fee, and broadcast the coop
 464  	// close transaction.
 465  	select {
 466  	case <-broadcastTxChan:
 467  	case <-time.After(timeout):
 468  		t.Fatalf("closing tx not broadcast")
 469  	}
 470  
 471  	// Alice should respond with the ClosingSigned they both agreed upon.
 472  	select {
 473  	case outMsg := <-alicePeer.outgoingQueue:
 474  		msg = outMsg.msg
 475  	case <-time.After(timeout):
 476  		t.Fatalf("did not receive closing signed message")
 477  	}
 478  	if _, ok := msg.(*lnwire.ClosingSigned); !ok {
 479  		t.Fatalf("expected to receive closing signed message, got %T", msg)
 480  	}
 481  
 482  	// Alice should be waiting on a single confirmation for the coop close tx.
 483  	notifier.ConfChan <- &chainntnfs.TxConfirmation{}
 484  }
 485  
 486  // TestPeerChannelClosureFeeNegotiationsInitiator tests the shutdown
 487  // initiator's behavior in the case where we must do several rounds of fee
 488  // negotiation before we agree on a fee.
 489  func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) {
 490  	t.Parallel()
 491  
 492  	harness, err := createTestPeerWithChannel(t, noUpdate)
 493  	require.NoError(t, err, "unable to create test channels")
 494  
 495  	var (
 496  		alicePeer       = harness.peer
 497  		bobChan         = harness.channel
 498  		mockSwitch      = harness.mockSwitch
 499  		broadcastTxChan = harness.publishTx
 500  		notifier        = harness.notifier
 501  	)
 502  
 503  	chanPoint := bobChan.ChannelPoint()
 504  	chanID := lnwire.NewChanIDFromOutPoint(chanPoint)
 505  	mockLink := newMockUpdateHandler(chanID)
 506  	mockSwitch.links = append(mockSwitch.links, mockLink)
 507  
 508  	// We make the initiator send a shutdown request.
 509  	updateChan := make(chan interface{}, 1)
 510  	errChan := make(chan error, 1)
 511  	closeCommand := &htlcswitch.ChanClose{
 512  		CloseType:      contractcourt.CloseRegular,
 513  		ChanPoint:      &chanPoint,
 514  		Updates:        updateChan,
 515  		TargetFeePerKw: 12500,
 516  		Err:            errChan,
 517  	}
 518  
 519  	alicePeer.localCloseChanReqs <- closeCommand
 520  
 521  	// Alice should now send a Shutdown request to Bob.
 522  	var msg lnwire.Message
 523  	select {
 524  	case outMsg := <-alicePeer.outgoingQueue:
 525  		msg = outMsg.msg
 526  	case <-time.After(timeout):
 527  		t.Fatalf("did not receive shutdown request")
 528  	}
 529  
 530  	shutdownMsg, ok := msg.(*lnwire.Shutdown)
 531  	if !ok {
 532  		t.Fatalf("expected Shutdown message, got %T", msg)
 533  	}
 534  
 535  	aliceDeliveryScript := shutdownMsg.Address
 536  
 537  	// Bob will answer the Shutdown message with his own Shutdown.
 538  	dummyDeliveryScript := genScript(t, p2wshAddress)
 539  	respShutdown := lnwire.NewShutdown(chanID, dummyDeliveryScript)
 540  	alicePeer.chanCloseMsgs <- &closeMsg{
 541  		cid: chanID,
 542  		msg: respShutdown,
 543  	}
 544  
 545  	// Alice should now respond with a ClosingSigned message with her ideal
 546  	// fee rate.
 547  	select {
 548  	case outMsg := <-alicePeer.outgoingQueue:
 549  		msg = outMsg.msg
 550  	case <-time.After(timeout):
 551  		t.Fatalf("did not receive closing signed")
 552  	}
 553  	closingSignedMsg, ok := msg.(*lnwire.ClosingSigned)
 554  	if !ok {
 555  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 556  	}
 557  
 558  	idealFeeRate := closingSignedMsg.FeeSatoshis
 559  	lastReceivedFee := idealFeeRate
 560  
 561  	increasedFee := btcutil.Amount(float64(idealFeeRate) * 2.1)
 562  	lastSentFee := increasedFee
 563  
 564  	bobSig, _, _, err := bobChan.CreateCloseProposal(
 565  		increasedFee, dummyDeliveryScript, aliceDeliveryScript,
 566  	)
 567  	require.NoError(t, err, "error creating close proposal")
 568  
 569  	parsedSig, err := lnwire.NewSigFromSignature(bobSig)
 570  	require.NoError(t, err, "unable to parse signature")
 571  
 572  	closingSigned := lnwire.NewClosingSigned(chanID, increasedFee, parsedSig)
 573  	alicePeer.chanCloseMsgs <- &closeMsg{
 574  		cid: chanID,
 575  		msg: closingSigned,
 576  	}
 577  
 578  	// It still won't be accepted, and we should get a new proposal, the
 579  	// average of what we proposed, and what they proposed last time.
 580  	select {
 581  	case outMsg := <-alicePeer.outgoingQueue:
 582  		msg = outMsg.msg
 583  	case <-time.After(timeout):
 584  		t.Fatalf("did not receive closing signed")
 585  	}
 586  	closingSignedMsg, ok = msg.(*lnwire.ClosingSigned)
 587  	if !ok {
 588  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 589  	}
 590  
 591  	aliceFee := closingSignedMsg.FeeSatoshis
 592  	if aliceFee < lastReceivedFee {
 593  		t.Fatalf("new fee should be greater than prior: new=%v, old=%v",
 594  			aliceFee, lastReceivedFee)
 595  	}
 596  	if aliceFee > lastSentFee {
 597  		t.Fatalf("new fee should be less than our fee: new=%v, old=%v",
 598  			aliceFee, lastSentFee)
 599  	}
 600  
 601  	lastReceivedFee = aliceFee
 602  
 603  	// We'll try negotiating a 1.5x fee, which should also be rejected.
 604  	increasedFee = btcutil.Amount(float64(idealFeeRate) * 1.5)
 605  	lastSentFee = increasedFee
 606  
 607  	bobSig, _, _, err = bobChan.CreateCloseProposal(
 608  		increasedFee, dummyDeliveryScript, aliceDeliveryScript,
 609  	)
 610  	require.NoError(t, err, "error creating close proposal")
 611  
 612  	parsedSig, err = lnwire.NewSigFromSignature(bobSig)
 613  	require.NoError(t, err, "error parsing signature")
 614  
 615  	closingSigned = lnwire.NewClosingSigned(chanID, increasedFee, parsedSig)
 616  	alicePeer.chanCloseMsgs <- &closeMsg{
 617  		cid: chanID,
 618  		msg: closingSigned,
 619  	}
 620  
 621  	// Alice won't accept Bob's new proposal, and Bob should receive a new
 622  	// proposal which is the average of what Bob proposed and Alice proposed
 623  	// last time.
 624  	select {
 625  	case outMsg := <-alicePeer.outgoingQueue:
 626  		msg = outMsg.msg
 627  	case <-time.After(timeout):
 628  		t.Fatalf("did not receive closing signed")
 629  	}
 630  	closingSignedMsg, ok = msg.(*lnwire.ClosingSigned)
 631  	if !ok {
 632  		t.Fatalf("expected ClosingSigned message, got %T", msg)
 633  	}
 634  
 635  	aliceFee = closingSignedMsg.FeeSatoshis
 636  	if aliceFee < lastReceivedFee {
 637  		t.Fatalf("new fee should be greater than prior: new=%v, old=%v",
 638  			aliceFee, lastReceivedFee)
 639  	}
 640  	if aliceFee > lastSentFee {
 641  		t.Fatalf("new fee should be less than Bob's fee: new=%v, old=%v",
 642  			aliceFee, lastSentFee)
 643  	}
 644  
 645  	// Bob will now accept their fee by sending back a ClosingSigned message
 646  	// with an identical fee.
 647  	bobSig, _, _, err = bobChan.CreateCloseProposal(
 648  		aliceFee, dummyDeliveryScript, aliceDeliveryScript,
 649  	)
 650  	require.NoError(t, err, "error creating close proposal")
 651  
 652  	parsedSig, err = lnwire.NewSigFromSignature(bobSig)
 653  	require.NoError(t, err, "error parsing signature")
 654  	closingSigned = lnwire.NewClosingSigned(chanID, aliceFee, parsedSig)
 655  	alicePeer.chanCloseMsgs <- &closeMsg{
 656  		cid: chanID,
 657  		msg: closingSigned,
 658  	}
 659  
 660  	// Wait for closing tx to be broadcasted.
 661  	select {
 662  	case <-broadcastTxChan:
 663  	case <-time.After(timeout):
 664  		t.Fatalf("closing tx not broadcast")
 665  	}
 666  
 667  	// Alice should respond with the ClosingSigned they both agreed upon.
 668  	select {
 669  	case outMsg := <-alicePeer.outgoingQueue:
 670  		msg = outMsg.msg
 671  	case <-time.After(timeout):
 672  		t.Fatalf("did not receive closing signed message")
 673  	}
 674  	if _, ok := msg.(*lnwire.ClosingSigned); !ok {
 675  		t.Fatalf("expected to receive closing signed message, got %T", msg)
 676  	}
 677  
 678  	// Alice should be waiting on a single confirmation for the coop close tx.
 679  	notifier.ConfChan <- &chainntnfs.TxConfirmation{}
 680  }
 681  
 682  // TestChooseDeliveryScript tests that chooseDeliveryScript correctly errors
 683  // when upfront and user set scripts that do not match are provided, allows
 684  // matching values and returns appropriate values in the case where one or none
 685  // are set.
 686  func TestChooseDeliveryScript(t *testing.T) {
 687  	// generate non-zero scripts for testing.
 688  	script1 := genScript(t, p2SHAddress)
 689  	script2 := genScript(t, p2wshAddress)
 690  
 691  	tests := []struct {
 692  		name           string
 693  		userScript     lnwire.DeliveryAddress
 694  		shutdownScript lnwire.DeliveryAddress
 695  		expectedScript lnwire.DeliveryAddress
 696  		newAddr        func() ([]byte, error)
 697  		expectedError  error
 698  	}{
 699  		{
 700  			name:           "Both set and equal",
 701  			userScript:     script1,
 702  			shutdownScript: script1,
 703  			expectedScript: script1,
 704  			expectedError:  nil,
 705  		},
 706  		{
 707  			name:           "Both set and not equal",
 708  			userScript:     script1,
 709  			shutdownScript: script2,
 710  			expectedScript: nil,
 711  			expectedError:  chancloser.ErrUpfrontShutdownScriptMismatch,
 712  		},
 713  		{
 714  			name:           "Only upfront script",
 715  			userScript:     nil,
 716  			shutdownScript: script1,
 717  			expectedScript: script1,
 718  			expectedError:  nil,
 719  		},
 720  		{
 721  			name:           "Only user script",
 722  			userScript:     script2,
 723  			shutdownScript: nil,
 724  			expectedScript: script2,
 725  			expectedError:  nil,
 726  		},
 727  		{
 728  			name:           "no script generate new one",
 729  			userScript:     nil,
 730  			shutdownScript: nil,
 731  			expectedScript: script2,
 732  			newAddr: func() ([]byte, error) {
 733  				return script2, nil
 734  			},
 735  			expectedError: nil,
 736  		},
 737  	}
 738  
 739  	for _, test := range tests {
 740  		test := test
 741  
 742  		t.Run(test.name, func(t *testing.T) {
 743  			script, err := chooseDeliveryScript(
 744  				test.shutdownScript, test.userScript,
 745  				test.newAddr,
 746  			)
 747  			if err != test.expectedError {
 748  				t.Fatalf("Expected: %v, got: %v",
 749  					test.expectedError, err)
 750  			}
 751  
 752  			if !bytes.Equal(script, test.expectedScript) {
 753  				t.Fatalf("Expected: %x, got: %x",
 754  					test.expectedScript, script)
 755  			}
 756  		})
 757  	}
 758  }
 759  
 760  // TestCustomShutdownScript tests that the delivery script of a shutdown
 761  // message can be set to a specified address. It checks that setting a close
 762  // script fails for channels which have an upfront shutdown script already set.
 763  func TestCustomShutdownScript(t *testing.T) {
 764  	script := genScript(t, p2SHAddress)
 765  
 766  	// setShutdown is a function which sets the upfront shutdown address for
 767  	// the local channel.
 768  	setShutdown := func(a, b *channeldb.OpenChannel) {
 769  		a.LocalShutdownScript = script
 770  		b.RemoteShutdownScript = script
 771  	}
 772  
 773  	tests := []struct {
 774  		name string
 775  
 776  		// update is a function used to set values on the channel set up for the
 777  		// test. It is used to set values for upfront shutdown addresses.
 778  		update func(a, b *channeldb.OpenChannel)
 779  
 780  		// userCloseScript is the address specified by the user.
 781  		userCloseScript lnwire.DeliveryAddress
 782  
 783  		// expectedScript is the address we expect to be set on the shutdown
 784  		// message.
 785  		expectedScript lnwire.DeliveryAddress
 786  
 787  		// expectedError is the error we expect, if any.
 788  		expectedError error
 789  	}{
 790  		{
 791  			name:            "User set script",
 792  			update:          noUpdate,
 793  			userCloseScript: script,
 794  			expectedScript:  script,
 795  		},
 796  		{
 797  			name:   "No user set script",
 798  			update: noUpdate,
 799  		},
 800  		{
 801  			name:           "Shutdown set, no user script",
 802  			update:         setShutdown,
 803  			expectedScript: script,
 804  		},
 805  		{
 806  			name:            "Shutdown set, user script matches",
 807  			update:          setShutdown,
 808  			userCloseScript: script,
 809  			expectedScript:  script,
 810  		},
 811  		{
 812  			name:            "Shutdown set, user script different",
 813  			update:          setShutdown,
 814  			userCloseScript: []byte("different addr"),
 815  			expectedError:   chancloser.ErrUpfrontShutdownScriptMismatch,
 816  		},
 817  	}
 818  
 819  	for _, test := range tests {
 820  		test := test
 821  
 822  		t.Run(test.name, func(t *testing.T) {
 823  			// Open a channel.
 824  			harness, err := createTestPeerWithChannel(
 825  				t, test.update,
 826  			)
 827  			if err != nil {
 828  				t.Fatalf("unable to create test channels: %v", err)
 829  			}
 830  
 831  			var (
 832  				alicePeer  = harness.peer
 833  				bobChan    = harness.channel
 834  				mockSwitch = harness.mockSwitch
 835  			)
 836  
 837  			chanPoint := bobChan.ChannelPoint()
 838  			chanID := lnwire.NewChanIDFromOutPoint(chanPoint)
 839  			mockLink := newMockUpdateHandler(chanID)
 840  			mockSwitch.links = append(mockSwitch.links, mockLink)
 841  
 842  			// Request initiator to cooperatively close the channel,
 843  			// with a specified delivery address.
 844  			updateChan := make(chan interface{}, 1)
 845  			errChan := make(chan error, 1)
 846  			closeCommand := htlcswitch.ChanClose{
 847  				CloseType:      contractcourt.CloseRegular,
 848  				ChanPoint:      &chanPoint,
 849  				Updates:        updateChan,
 850  				TargetFeePerKw: 12500,
 851  				DeliveryScript: test.userCloseScript,
 852  				Err:            errChan,
 853  			}
 854  
 855  			// Send the close command for the correct channel and check that a
 856  			// shutdown message is sent.
 857  			alicePeer.localCloseChanReqs <- &closeCommand
 858  
 859  			var msg lnwire.Message
 860  			select {
 861  			case outMsg := <-alicePeer.outgoingQueue:
 862  				msg = outMsg.msg
 863  			case <-time.After(timeout):
 864  				t.Fatalf("did not receive shutdown message")
 865  			case err := <-errChan:
 866  				// Fail if we do not expect an error.
 867  				if test.expectedError != nil {
 868  					require.ErrorIs(
 869  						t, err, test.expectedError,
 870  					)
 871  				}
 872  
 873  				// Terminate the test early if have received an error, no
 874  				// further action is expected.
 875  				return
 876  			}
 877  
 878  			// Check that we have received a shutdown message.
 879  			shutdownMsg, ok := msg.(*lnwire.Shutdown)
 880  			if !ok {
 881  				t.Fatalf("expected shutdown message, got %T", msg)
 882  			}
 883  
 884  			// If the test has not specified an expected address, do not check
 885  			// whether the shutdown address matches. This covers the case where
 886  			// we expect shutdown to a random address and cannot match it.
 887  			if len(test.expectedScript) == 0 {
 888  				return
 889  			}
 890  
 891  			// Check that the Shutdown message includes the expected delivery
 892  			// script.
 893  			if !bytes.Equal(test.expectedScript, shutdownMsg.Address) {
 894  				t.Fatalf("expected delivery script: %x, got: %x",
 895  					test.expectedScript, shutdownMsg.Address)
 896  			}
 897  		})
 898  	}
 899  }
 900  
 901  // TestStaticRemoteDowngrade tests that we downgrade our static remote feature
 902  // bit to optional if we have legacy channels with a peer. This ensures that
 903  // we can stay connected to peers that don't support the feature bit that we
 904  // have channels with.
 905  func TestStaticRemoteDowngrade(t *testing.T) {
 906  	t.Parallel()
 907  
 908  	var (
 909  		// We set the same legacy feature bits for all tests, since
 910  		// these are not relevant to our test scenario
 911  		rawLegacy = lnwire.NewRawFeatureVector(
 912  			lnwire.UpfrontShutdownScriptOptional,
 913  		)
 914  		legacy = lnwire.NewFeatureVector(rawLegacy, nil)
 915  
 916  		legacyCombinedOptional = lnwire.NewRawFeatureVector(
 917  			lnwire.UpfrontShutdownScriptOptional,
 918  			lnwire.StaticRemoteKeyOptional,
 919  		)
 920  
 921  		rawFeatureOptional = lnwire.NewRawFeatureVector(
 922  			lnwire.StaticRemoteKeyOptional,
 923  		)
 924  
 925  		featureOptional = lnwire.NewFeatureVector(
 926  			rawFeatureOptional, nil,
 927  		)
 928  
 929  		rawFeatureRequired = lnwire.NewRawFeatureVector(
 930  			lnwire.StaticRemoteKeyRequired,
 931  		)
 932  
 933  		featureRequired = lnwire.NewFeatureVector(
 934  			rawFeatureRequired, nil,
 935  		)
 936  	)
 937  
 938  	tests := []struct {
 939  		name         string
 940  		legacy       bool
 941  		features     *lnwire.FeatureVector
 942  		expectedInit *lnwire.Init
 943  	}{
 944  		{
 945  			name:     "no legacy channel, static optional",
 946  			legacy:   false,
 947  			features: featureOptional,
 948  			expectedInit: &lnwire.Init{
 949  				GlobalFeatures: rawLegacy,
 950  				Features:       rawFeatureOptional,
 951  			},
 952  		},
 953  		{
 954  			name:     "legacy channel, static optional",
 955  			legacy:   true,
 956  			features: featureOptional,
 957  			expectedInit: &lnwire.Init{
 958  				GlobalFeatures: rawLegacy,
 959  				Features:       rawFeatureOptional,
 960  			},
 961  		},
 962  		{
 963  			name:     "no legacy channel, static required",
 964  			legacy:   false,
 965  			features: featureRequired,
 966  			expectedInit: &lnwire.Init{
 967  				GlobalFeatures: rawLegacy,
 968  				Features:       rawFeatureRequired,
 969  			},
 970  		},
 971  
 972  		// In this case we need to flip our required bit to optional,
 973  		// this should also propagate to the legacy set of feature bits
 974  		// so we have proper consistency: a bit isn't set to optional
 975  		// in one field and required in the other.
 976  		{
 977  			name:     "legacy channel, static required",
 978  			legacy:   true,
 979  			features: featureRequired,
 980  			expectedInit: &lnwire.Init{
 981  				GlobalFeatures: legacyCombinedOptional,
 982  				Features:       rawFeatureOptional,
 983  			},
 984  		},
 985  	}
 986  
 987  	for _, test := range tests {
 988  		test := test
 989  
 990  		t.Run(test.name, func(t *testing.T) {
 991  			params := createTestPeer(t)
 992  
 993  			var (
 994  				p         = params.peer
 995  				mockConn  = params.mockConn
 996  				writePool = p.cfg.WritePool
 997  			)
 998  			// Set feature bits.
 999  			p.cfg.LegacyFeatures = legacy
1000  			p.cfg.Features = test.features
1001  
1002  			var b bytes.Buffer
1003  			_, err := lnwire.WriteMessage(&b, test.expectedInit, 0)
1004  			require.NoError(t, err)
1005  
1006  			// Send our init message, assert that we write our
1007  			// expected message and shutdown our write pool.
1008  			require.NoError(t, p.sendInitMsg(test.legacy))
1009  			mockConn.assertWrite(b.Bytes())
1010  			require.NoError(t, writePool.Stop())
1011  		})
1012  	}
1013  }
1014  
1015  // genScript creates a script paying out to the address provided, which must
1016  // be a valid address.
1017  func genScript(t *testing.T, address string) lnwire.DeliveryAddress {
1018  	// Generate an address which can be used for testing.
1019  	deliveryAddr, err := btcutil.DecodeAddress(
1020  		address,
1021  		&chaincfg.TestNet3Params,
1022  	)
1023  	require.NoError(t, err, "invalid delivery address")
1024  
1025  	script, err := txscript.PayToAddrScript(deliveryAddr)
1026  	require.NoError(t, err, "cannot create script")
1027  
1028  	return script
1029  }
1030  
1031  // TestPeerCustomMessage tests custom message exchange between peers.
1032  func TestPeerCustomMessage(t *testing.T) {
1033  	t.Parallel()
1034  
1035  	params := createTestPeer(t)
1036  
1037  	var (
1038  		mockConn           = params.mockConn
1039  		alicePeer          = params.peer
1040  		receivedCustomChan = params.customChan
1041  		remoteKey          = alicePeer.PubKey()
1042  	)
1043  
1044  	// Start peer.
1045  	startPeerDone := startPeer(t, mockConn, alicePeer)
1046  	_, err := fn.RecvOrTimeout(startPeerDone, 2*timeout)
1047  	require.NoError(t, err)
1048  
1049  	// Send a custom message.
1050  	customMsg, err := lnwire.NewCustom(
1051  		lnwire.MessageType(40000), []byte{1, 2, 3},
1052  	)
1053  	require.NoError(t, err)
1054  
1055  	require.NoError(t, alicePeer.SendMessageLazy(false, customMsg))
1056  
1057  	// Verify that it is passed down to the noise layer correctly.
1058  	writtenMsg := <-mockConn.writtenMessages
1059  	require.Equal(t, []byte{0x9c, 0x40, 0x1, 0x2, 0x3}, writtenMsg)
1060  
1061  	// Receive a custom message.
1062  	receivedCustomMsg, err := lnwire.NewCustom(
1063  		lnwire.MessageType(40001), []byte{4, 5, 6},
1064  	)
1065  	require.NoError(t, err)
1066  
1067  	receivedData := []byte{0x9c, 0x41, 0x4, 0x5, 0x6}
1068  	mockConn.readMessages <- receivedData
1069  
1070  	// Verify that it is propagated up to the custom message handler.
1071  	receivedCustom := <-receivedCustomChan
1072  	require.Equal(t, remoteKey, receivedCustom.peer)
1073  	require.Equal(t, receivedCustomMsg, &receivedCustom.msg)
1074  }
1075  
1076  // TestUpdateNextRevocation checks that the method `updateNextRevocation` is
1077  // behave as expected.
1078  func TestUpdateNextRevocation(t *testing.T) {
1079  	t.Parallel()
1080  
1081  	require := require.New(t)
1082  
1083  	harness, err := createTestPeerWithChannel(t, noUpdate)
1084  	require.NoError(err, "unable to create test channels")
1085  
1086  	bobChan := harness.channel
1087  	alicePeer := harness.peer
1088  
1089  	// testChannel is used to test the updateNextRevocation function.
1090  	testChannel := bobChan.State()
1091  
1092  	// Update the next revocation for a known channel should give us no
1093  	// error.
1094  	err = alicePeer.updateNextRevocation(testChannel)
1095  	require.NoError(err, "expected no error")
1096  
1097  	// Test an error is returned when the chanID cannot be found in
1098  	// `activeChannels` map.
1099  	testChannel.FundingOutpoint = wire.OutPoint{Index: 0}
1100  	err = alicePeer.updateNextRevocation(testChannel)
1101  	require.Error(err, "expected an error")
1102  
1103  	// Test an error is returned when the chanID's corresponding channel is
1104  	// nil.
1105  	testChannel.FundingOutpoint = wire.OutPoint{Index: 1}
1106  	chanID := lnwire.NewChanIDFromOutPoint(testChannel.FundingOutpoint)
1107  	alicePeer.activeChannels.Store(chanID, nil)
1108  
1109  	err = alicePeer.updateNextRevocation(testChannel)
1110  	require.Error(err, "expected an error")
1111  
1112  	// TODO(yy): should also test `InitNextRevocation` is called on
1113  	// `lnwallet.LightningWallet` once it's interfaced.
1114  }
1115  
1116  func assertMsgSent(t *testing.T, conn *mockMessageConn,
1117  	msgType lnwire.MessageType) {
1118  
1119  	t.Helper()
1120  
1121  	require := require.New(t)
1122  
1123  	rawMsg, err := fn.RecvOrTimeout(conn.writtenMessages, timeout)
1124  	require.NoError(err)
1125  
1126  	msgReader := bytes.NewReader(rawMsg)
1127  	msg, err := lnwire.ReadMessage(msgReader, 0)
1128  	require.NoError(err)
1129  
1130  	require.Equal(msgType, msg.MsgType())
1131  }
1132  
1133  // TestAlwaysSendChannelUpdate tests that each time we connect to the peer if
1134  // an active channel, we always send the latest channel update.
1135  func TestAlwaysSendChannelUpdate(t *testing.T) {
1136  	require := require.New(t)
1137  
1138  	var channel *channeldb.OpenChannel
1139  	channelIntercept := func(a, b *channeldb.OpenChannel) {
1140  		channel = a
1141  	}
1142  
1143  	harness, err := createTestPeerWithChannel(t, channelIntercept)
1144  	require.NoError(err, "unable to create test channels")
1145  
1146  	// Avoid the need to mock the channel graph by marking the channel
1147  	// borked. Borked channels still get a reestablish message sent on
1148  	// reconnect, while skipping channel graph checks and link creation.
1149  	require.NoError(channel.MarkBorked())
1150  
1151  	// Start the peer, which'll trigger the normal init and start up logic.
1152  	startPeerDone := startPeer(t, harness.mockConn, harness.peer)
1153  	_, err = fn.RecvOrTimeout(startPeerDone, 2*timeout)
1154  	require.NoError(err)
1155  
1156  	// Assert that we eventually send a channel update.
1157  	assertMsgSent(t, harness.mockConn, lnwire.MsgChannelReestablish)
1158  	assertMsgSent(t, harness.mockConn, lnwire.MsgChannelUpdate)
1159  }
1160  
1161  // TODO(yy): add test for `addActiveChannel` and `handleNewActiveChannel` once
1162  // we have interfaced `lnwallet.LightningChannel` and
1163  // `*contractcourt.ChainArbitrator`.
1164  
1165  // TestHandleNewPendingChannel checks the method `handleNewPendingChannel`
1166  // behaves as expected.
1167  func TestHandleNewPendingChannel(t *testing.T) {
1168  	t.Parallel()
1169  
1170  	// Create three channel IDs for testing.
1171  	chanIDActive := lnwire.ChannelID{0}
1172  	chanIDNotExist := lnwire.ChannelID{1}
1173  	chanIDPending := lnwire.ChannelID{2}
1174  
1175  	testCases := []struct {
1176  		name   string
1177  		chanID lnwire.ChannelID
1178  
1179  		// expectChanAdded specifies whether this chanID will be added
1180  		// to the peer's state.
1181  		expectChanAdded bool
1182  	}{
1183  		{
1184  			name:            "noop on active channel",
1185  			chanID:          chanIDActive,
1186  			expectChanAdded: false,
1187  		},
1188  		{
1189  			name:            "noop on pending channel",
1190  			chanID:          chanIDPending,
1191  			expectChanAdded: false,
1192  		},
1193  		{
1194  			name:            "new channel should be added",
1195  			chanID:          chanIDNotExist,
1196  			expectChanAdded: true,
1197  		},
1198  	}
1199  
1200  	for _, tc := range testCases {
1201  		tc := tc
1202  
1203  		// Create a request for testing.
1204  		errChan := make(chan error, 1)
1205  		req := &newChannelMsg{
1206  			channelID: tc.chanID,
1207  			err:       errChan,
1208  		}
1209  
1210  		t.Run(tc.name, func(t *testing.T) {
1211  			t.Parallel()
1212  			require := require.New(t)
1213  
1214  			// Create a test brontide.
1215  			dummyConfig := Config{}
1216  			peer := NewBrontide(dummyConfig)
1217  
1218  			// Create the test state.
1219  			peer.activeChannels.Store(
1220  				chanIDActive, &lnwallet.LightningChannel{},
1221  			)
1222  			peer.activeChannels.Store(chanIDPending, nil)
1223  
1224  			// Assert test state, we should have two channels
1225  			// store, one active and one pending.
1226  			numChans := 2
1227  			require.EqualValues(
1228  				numChans, peer.activeChannels.Len(),
1229  			)
1230  
1231  			// Call the method.
1232  			peer.handleNewPendingChannel(req)
1233  
1234  			// Add one if we expect this channel to be added.
1235  			if tc.expectChanAdded {
1236  				numChans++
1237  			}
1238  
1239  			// Assert the number of channels is correct.
1240  			require.Equal(numChans, peer.activeChannels.Len())
1241  
1242  			// Assert the request's error chan is closed.
1243  			err, ok := <-req.err
1244  			require.False(ok, "expect err chan to be closed")
1245  			require.NoError(err, "expect no error")
1246  		})
1247  	}
1248  }
1249  
1250  // TestHandleRemovePendingChannel checks the method
1251  // `handleRemovePendingChannel` behaves as expected.
1252  func TestHandleRemovePendingChannel(t *testing.T) {
1253  	t.Parallel()
1254  
1255  	// Create three channel IDs for testing.
1256  	chanIDActive := lnwire.ChannelID{0}
1257  	chanIDNotExist := lnwire.ChannelID{1}
1258  	chanIDPending := lnwire.ChannelID{2}
1259  
1260  	testCases := []struct {
1261  		name   string
1262  		chanID lnwire.ChannelID
1263  
1264  		// expectDeleted specifies whether this chanID will be removed
1265  		// from the peer's state.
1266  		expectDeleted bool
1267  	}{
1268  		{
1269  			name:          "noop on active channel",
1270  			chanID:        chanIDActive,
1271  			expectDeleted: false,
1272  		},
1273  		{
1274  			name:          "pending channel should be removed",
1275  			chanID:        chanIDPending,
1276  			expectDeleted: true,
1277  		},
1278  		{
1279  			name:          "noop on non-exist channel",
1280  			chanID:        chanIDNotExist,
1281  			expectDeleted: false,
1282  		},
1283  	}
1284  
1285  	for _, tc := range testCases {
1286  		tc := tc
1287  
1288  		// Create a request for testing.
1289  		errChan := make(chan error, 1)
1290  		req := &newChannelMsg{
1291  			channelID: tc.chanID,
1292  			err:       errChan,
1293  		}
1294  
1295  		// Create a test brontide.
1296  		dummyConfig := Config{}
1297  		peer := NewBrontide(dummyConfig)
1298  
1299  		// Create the test state.
1300  		peer.activeChannels.Store(
1301  			chanIDActive, &lnwallet.LightningChannel{},
1302  		)
1303  		peer.activeChannels.Store(chanIDPending, nil)
1304  
1305  		// Assert test state, we should have two channels store, one
1306  		// active and one pending.
1307  		require.Equal(t, 2, peer.activeChannels.Len())
1308  
1309  		t.Run(tc.name, func(t *testing.T) {
1310  			t.Parallel()
1311  
1312  			require := require.New(t)
1313  
1314  			// Get the number of channels before mutating the
1315  			// state.
1316  			numChans := peer.activeChannels.Len()
1317  
1318  			// Call the method.
1319  			peer.handleRemovePendingChannel(req)
1320  
1321  			// Minus one if we expect this channel to be removed.
1322  			if tc.expectDeleted {
1323  				numChans--
1324  			}
1325  
1326  			// Assert the number of channels is correct.
1327  			require.Equal(numChans, peer.activeChannels.Len())
1328  
1329  			// Assert the request's error chan is closed.
1330  			err, ok := <-req.err
1331  			require.False(ok, "expect err chan to be closed")
1332  			require.NoError(err, "expect no error")
1333  		})
1334  	}
1335  }
1336  
1337  // TestStartupWriteMessageRace checks that no data race occurs when starting up
1338  // a peer with an existing channel, while an outgoing message is queuing. Such
1339  // a race occurred in https://github.com/lightningnetwork/lnd/issues/8184, where
1340  // a channel reestablish message raced with another outgoing message.
1341  //
1342  // Note that races will only be detected with the Go race detector enabled.
1343  func TestStartupWriteMessageRace(t *testing.T) {
1344  	t.Parallel()
1345  
1346  	// Use a callback to extract the channel created by
1347  	// createTestPeerWithChannel, so we can mark it borked below.
1348  	// We can't mark it borked within the callback, since the channel hasn't
1349  	// been saved to the DB yet when the callback executes.
1350  	var channel *channeldb.OpenChannel
1351  	getChannels := func(a, b *channeldb.OpenChannel) {
1352  		channel = a
1353  	}
1354  
1355  	// createTestPeerWithChannel creates a peer and a channel with that
1356  	// peer.
1357  	harness, err := createTestPeerWithChannel(t, getChannels)
1358  	require.NoError(t, err, "unable to create test channel")
1359  
1360  	peer := harness.peer
1361  
1362  	// Avoid the need to mock the channel graph by marking the channel
1363  	// borked. Borked channels still get a reestablish message sent on
1364  	// reconnect, while skipping channel graph checks and link creation.
1365  	require.NoError(t, channel.MarkBorked())
1366  
1367  	// Use a mock conn to detect read/write races on the conn.
1368  	mockConn := newMockConn(t, 2)
1369  	peer.cfg.Conn = mockConn
1370  
1371  	// Send a message while starting the peer. As the peer starts up, it
1372  	// should not trigger a data race between the sending of this message
1373  	// and the sending of the channel reestablish message.
1374  	var sendPingDone = make(chan struct{})
1375  	go func() {
1376  		require.NoError(t, peer.SendMessage(true, lnwire.NewPing(0)))
1377  		close(sendPingDone)
1378  	}()
1379  
1380  	// Start the peer. No data race should occur.
1381  	startPeerDone := startPeer(t, mockConn, peer)
1382  
1383  	// Ensure startup is complete.
1384  	_, err = fn.RecvOrTimeout(startPeerDone, 2*timeout)
1385  	require.NoError(t, err)
1386  
1387  	// Ensure messages were sent during startup.
1388  	<-sendPingDone
1389  	for i := 0; i < 2; i++ {
1390  		select {
1391  		case <-mockConn.writtenMessages:
1392  		default:
1393  			t.Fatalf("Failed to send all messages during startup")
1394  		}
1395  	}
1396  }
1397  
1398  // TestRemovePendingChannel checks that we are able to remove a pending channel
1399  // successfully from the peers channel map. This also makes sure the
1400  // removePendingChannel is initialized so we don't send to a nil channel and
1401  // get stuck.
1402  func TestRemovePendingChannel(t *testing.T) {
1403  	t.Parallel()
1404  
1405  	// createTestPeerWithChannel creates a peer and a channel.
1406  	harness, err := createTestPeerWithChannel(t, noUpdate)
1407  	require.NoError(t, err, "unable to create test channel")
1408  
1409  	peer := harness.peer
1410  
1411  	// Add a pending channel to the peer Alice.
1412  	errChan := make(chan error, 1)
1413  	pendingChanID := lnwire.ChannelID{1}
1414  	req := &newChannelMsg{
1415  		channelID: pendingChanID,
1416  		err:       errChan,
1417  	}
1418  
1419  	select {
1420  	case peer.newPendingChannel <- req:
1421  		// Operation completed successfully
1422  	case <-time.After(timeout):
1423  		t.Fatalf("not able to remove pending channel")
1424  	}
1425  
1426  	// Make sure the channel was added as a pending channel.
1427  	// The peer was already created with one active channel therefore the
1428  	// `activeChannels` had already one channel prior to adding the new one.
1429  	// The `addedChannels` map only tracks new channels in the current life
1430  	// cycle therefore the initial channel is not part of it.
1431  	err = wait.NoError(func() error {
1432  		if peer.activeChannels.Len() == 2 &&
1433  			peer.addedChannels.Len() == 1 {
1434  
1435  			return nil
1436  		}
1437  
1438  		return fmt.Errorf("pending channel not successfully added")
1439  	}, wait.DefaultTimeout)
1440  
1441  	require.NoError(t, err)
1442  
1443  	// Now try to remove it, the errChan needs to be reopened because it was
1444  	// closed during the pending channel registration above.
1445  	errChan = make(chan error, 1)
1446  	req = &newChannelMsg{
1447  		channelID: pendingChanID,
1448  		err:       errChan,
1449  	}
1450  
1451  	select {
1452  	case peer.removePendingChannel <- req:
1453  		// Operation completed successfully
1454  	case <-time.After(timeout):
1455  		t.Fatalf("not able to remove pending channel")
1456  	}
1457  
1458  	// Make sure the pending channel is successfully removed from both
1459  	// channel maps.
1460  	// The initial channel between the peer is still active at this point.
1461  	err = wait.NoError(func() error {
1462  		if peer.activeChannels.Len() == 1 &&
1463  			peer.addedChannels.Len() == 0 {
1464  
1465  			return nil
1466  		}
1467  
1468  		return fmt.Errorf("pending channel not successfully removed")
1469  	}, wait.DefaultTimeout)
1470  
1471  	require.NoError(t, err)
1472  }
1473  
1474  // mockAuxTrafficShaper is a mock implementation of htlcswitch.AuxTrafficShaper
1475  // for testing the createHtlcValidator function.
1476  type mockAuxTrafficShaper struct {
1477  	mock.Mock
1478  }
1479  
1480  // ShouldHandleTraffic returns the configured mock values.
1481  func (m *mockAuxTrafficShaper) ShouldHandleTraffic(
1482  	cid lnwire.ShortChannelID,
1483  	fundingBlob, htlcBlob fn.Option[tlv.Blob]) (bool, error) {
1484  
1485  	args := m.Called(cid, fundingBlob, htlcBlob)
1486  	return args.Bool(0), args.Error(1)
1487  }
1488  
1489  // PaymentBandwidth returns the configured mock values.
1490  func (m *mockAuxTrafficShaper) PaymentBandwidth(fundingBlob, htlcBlob,
1491  	commitmentBlob fn.Option[tlv.Blob], linkBandwidth,
1492  	htlcAmt lnwire.MilliSatoshi, htlcView lnwallet.AuxHtlcView,
1493  	peer route.Vertex) (lnwire.MilliSatoshi, error) {
1494  
1495  	args := m.Called(
1496  		fundingBlob, htlcBlob, commitmentBlob, linkBandwidth,
1497  		htlcAmt, htlcView, peer,
1498  	)
1499  
1500  	bw, _ := args.Get(0).(lnwire.MilliSatoshi)
1501  
1502  	return bw, args.Error(1)
1503  }
1504  
1505  // ProduceHtlcExtraData is part of the AuxTrafficShaper interface.
1506  func (m *mockAuxTrafficShaper) ProduceHtlcExtraData(
1507  	totalAmount lnwire.MilliSatoshi,
1508  	htlcCustomRecords lnwire.CustomRecords,
1509  	peer route.Vertex) (lnwire.MilliSatoshi, lnwire.CustomRecords,
1510  	error) {
1511  
1512  	args := m.Called(totalAmount, htlcCustomRecords, peer)
1513  
1514  	amt, _ := args.Get(0).(lnwire.MilliSatoshi)
1515  	records, _ := args.Get(1).(lnwire.CustomRecords)
1516  
1517  	return amt, records, args.Error(2)
1518  }
1519  
1520  // IsCustomHTLC is part of the AuxTrafficShaper interface.
1521  func (m *mockAuxTrafficShaper) IsCustomHTLC(
1522  	htlcRecords lnwire.CustomRecords) bool {
1523  
1524  	args := m.Called(htlcRecords)
1525  	return args.Bool(0)
1526  }
1527  
1528  // Compile-time check that mockAuxTrafficShaper implements AuxTrafficShaper.
1529  var _ htlcswitch.AuxTrafficShaper = (*mockAuxTrafficShaper)(nil)
1530  
1531  // TestCreateHtlcValidator tests that the HTLC validator created by
1532  // createHtlcValidator respects the ShouldHandleTraffic check. When
1533  // ShouldHandleTraffic returns false, the validator should return nil without
1534  // calling PaymentBandwidth.
1535  func TestCreateHtlcValidator(t *testing.T) {
1536  	t.Parallel()
1537  
1538  	// Create a minimal Brontide with just the identity key set.
1539  	privKey, err := btcec.NewPrivateKey()
1540  	require.NoError(t, err)
1541  
1542  	peer := &Brontide{
1543  		cfg: Config{
1544  			Addr: &lnwire.NetAddress{
1545  				IdentityKey: privKey.PubKey(),
1546  			},
1547  		},
1548  	}
1549  
1550  	// Create a mock channel with minimal required fields.
1551  	dbChan := &channeldb.OpenChannel{
1552  		ShortChannelID: lnwire.NewShortChanIDFromInt(123),
1553  	}
1554  
1555  	anyArg := mock.Anything
1556  
1557  	testCases := []struct {
1558  		name        string
1559  		setupMock   func(*mockAuxTrafficShaper)
1560  		htlcAmount  lnwire.MilliSatoshi
1561  		linkBw      lnwire.MilliSatoshi
1562  		expectError bool
1563  	}{
1564  		{
1565  			name: "non-custom channel skips check",
1566  			setupMock: func(m *mockAuxTrafficShaper) {
1567  				m.On(
1568  					"ShouldHandleTraffic",
1569  					anyArg, anyArg, anyArg,
1570  				).Return(false, nil)
1571  			},
1572  			htlcAmount:  1000,
1573  			linkBw:      5000,
1574  			expectError: false,
1575  		},
1576  		{
1577  			name: "sufficient bandwidth",
1578  			setupMock: func(m *mockAuxTrafficShaper) {
1579  				m.On(
1580  					"ShouldHandleTraffic",
1581  					anyArg, anyArg, anyArg,
1582  				).Return(true, nil)
1583  				m.On(
1584  					"PaymentBandwidth",
1585  					anyArg, anyArg, anyArg,
1586  					anyArg, anyArg, anyArg,
1587  					anyArg,
1588  				).Return(
1589  					lnwire.MilliSatoshi(10000),
1590  					nil,
1591  				)
1592  			},
1593  			htlcAmount:  1000,
1594  			linkBw:      5000,
1595  			expectError: false,
1596  		},
1597  		{
1598  			name: "insufficient bandwidth",
1599  			setupMock: func(m *mockAuxTrafficShaper) {
1600  				m.On(
1601  					"ShouldHandleTraffic",
1602  					anyArg, anyArg, anyArg,
1603  				).Return(true, nil)
1604  				m.On(
1605  					"PaymentBandwidth",
1606  					anyArg, anyArg, anyArg,
1607  					anyArg, anyArg, anyArg,
1608  					anyArg,
1609  				).Return(
1610  					lnwire.MilliSatoshi(500),
1611  					nil,
1612  				)
1613  			},
1614  			htlcAmount:  1000,
1615  			linkBw:      5000,
1616  			expectError: true,
1617  		},
1618  		{
1619  			name: "ShouldHandleTraffic error",
1620  			setupMock: func(m *mockAuxTrafficShaper) {
1621  				m.On(
1622  					"ShouldHandleTraffic",
1623  					anyArg, anyArg, anyArg,
1624  				).Return(
1625  					false,
1626  					fmt.Errorf("shaper error"),
1627  				)
1628  			},
1629  			htlcAmount:  1000,
1630  			linkBw:      5000,
1631  			expectError: true,
1632  		},
1633  		{
1634  			name: "PaymentBandwidth error",
1635  			setupMock: func(m *mockAuxTrafficShaper) {
1636  				m.On(
1637  					"ShouldHandleTraffic",
1638  					anyArg, anyArg, anyArg,
1639  				).Return(true, nil)
1640  				m.On(
1641  					"PaymentBandwidth",
1642  					anyArg, anyArg, anyArg,
1643  					anyArg, anyArg, anyArg,
1644  					anyArg,
1645  				).Return(
1646  					lnwire.MilliSatoshi(0),
1647  					fmt.Errorf("bandwidth error"),
1648  				)
1649  			},
1650  			htlcAmount:  1000,
1651  			linkBw:      5000,
1652  			expectError: true,
1653  		},
1654  	}
1655  
1656  	for _, tc := range testCases {
1657  		t.Run(tc.name, func(t *testing.T) {
1658  			m := &mockAuxTrafficShaper{}
1659  			tc.setupMock(m)
1660  
1661  			validator := peer.createHtlcValidator(
1662  				dbChan, m,
1663  			)
1664  
1665  			err := validator.ValidateHtlc(
1666  				tc.htlcAmount, tc.linkBw,
1667  				nil, lnwallet.AuxHtlcView{},
1668  			)
1669  
1670  			if tc.expectError {
1671  				require.Error(t, err)
1672  			} else {
1673  				require.NoError(t, err)
1674  			}
1675  
1676  			m.AssertExpectations(t)
1677  		})
1678  	}
1679  }