/ htlcswitch / switch_test.go
switch_test.go
   1  package htlcswitch
   2  
   3  import (
   4  	"crypto/rand"
   5  	"crypto/sha256"
   6  	"errors"
   7  	"fmt"
   8  	"io"
   9  	mrand "math/rand"
  10  	"reflect"
  11  	"testing"
  12  	"time"
  13  
  14  	"github.com/btcsuite/btcd/btcutil"
  15  	"github.com/davecgh/go-spew/spew"
  16  	"github.com/lightningnetwork/lnd/chainntnfs"
  17  	"github.com/lightningnetwork/lnd/channeldb"
  18  	"github.com/lightningnetwork/lnd/contractcourt"
  19  	"github.com/lightningnetwork/lnd/fn/v2"
  20  	"github.com/lightningnetwork/lnd/graph/db/models"
  21  	"github.com/lightningnetwork/lnd/htlcswitch/hodl"
  22  	"github.com/lightningnetwork/lnd/htlcswitch/hop"
  23  	"github.com/lightningnetwork/lnd/lntest/mock"
  24  	"github.com/lightningnetwork/lnd/lntest/wait"
  25  	"github.com/lightningnetwork/lnd/lntypes"
  26  	"github.com/lightningnetwork/lnd/lnwallet/chainfee"
  27  	"github.com/lightningnetwork/lnd/lnwire"
  28  	"github.com/lightningnetwork/lnd/ticker"
  29  	"github.com/stretchr/testify/require"
  30  )
  31  
  32  var zeroCircuit = models.CircuitKey{}
  33  var emptyScid = lnwire.ShortChannelID{}
  34  
  35  func genPreimage() ([32]byte, error) {
  36  	var preimage [32]byte
  37  	if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil {
  38  		return preimage, err
  39  	}
  40  	return preimage, nil
  41  }
  42  
  43  // TestSwitchAddDuplicateLink tests that the switch will reject duplicate links
  44  // for live links. It also tests that we can successfully add a link after
  45  // having removed it.
  46  func TestSwitchAddDuplicateLink(t *testing.T) {
  47  	t.Parallel()
  48  
  49  	alicePeer, err := newMockServer(
  50  		t, "alice", testStartingHeight, nil, testDefaultDelta,
  51  	)
  52  	require.NoError(t, err, "unable to create alice server")
  53  
  54  	s, err := initSwitchWithTempDB(t, testStartingHeight)
  55  	require.NoError(t, err, "unable to init switch")
  56  	if err := s.Start(); err != nil {
  57  		t.Fatalf("unable to start switch: %v", err)
  58  	}
  59  	defer s.Stop()
  60  
  61  	chanID1, aliceScid := genID()
  62  
  63  	aliceChannelLink := newMockChannelLink(
  64  		s, chanID1, aliceScid, emptyScid, alicePeer, false, false,
  65  		false, false,
  66  	)
  67  	if err := s.AddLink(aliceChannelLink); err != nil {
  68  		t.Fatalf("unable to add alice link: %v", err)
  69  	}
  70  
  71  	// Alice should have a live link, adding again should fail.
  72  	if err := s.AddLink(aliceChannelLink); err == nil {
  73  		t.Fatalf("adding duplicate link should have failed")
  74  	}
  75  
  76  	// Remove the live link to ensure the indexes are cleared.
  77  	s.RemoveLink(chanID1)
  78  
  79  	// Alice has no links, adding should succeed.
  80  	if err := s.AddLink(aliceChannelLink); err != nil {
  81  		t.Fatalf("unable to add alice link: %v", err)
  82  	}
  83  }
  84  
  85  // TestSwitchHasActiveLink tests the behavior of HasActiveLink, and asserts that
  86  // it only returns true if a link's short channel id has confirmed (meaning the
  87  // channel is no longer pending) and it's EligibleToForward method returns true,
  88  // i.e. it has received ChannelReady from the remote peer.
  89  func TestSwitchHasActiveLink(t *testing.T) {
  90  	t.Parallel()
  91  
  92  	alicePeer, err := newMockServer(
  93  		t, "alice", testStartingHeight, nil, testDefaultDelta,
  94  	)
  95  	require.NoError(t, err, "unable to create alice server")
  96  
  97  	s, err := initSwitchWithTempDB(t, testStartingHeight)
  98  	require.NoError(t, err, "unable to init switch")
  99  	if err := s.Start(); err != nil {
 100  		t.Fatalf("unable to start switch: %v", err)
 101  	}
 102  	defer s.Stop()
 103  
 104  	chanID1, aliceScid := genID()
 105  
 106  	aliceChannelLink := newMockChannelLink(
 107  		s, chanID1, aliceScid, emptyScid, alicePeer, false, false,
 108  		false, false,
 109  	)
 110  	if err := s.AddLink(aliceChannelLink); err != nil {
 111  		t.Fatalf("unable to add alice link: %v", err)
 112  	}
 113  
 114  	// The link has been added, but it's still pending. HasActiveLink should
 115  	// return false since the link has not been added to the linkIndex
 116  	// containing live links.
 117  	if s.HasActiveLink(chanID1) {
 118  		t.Fatalf("link should not be active yet, still pending")
 119  	}
 120  
 121  	// Finally, simulate the link receiving channel_ready by setting its
 122  	// eligibility to true.
 123  	aliceChannelLink.eligible = true
 124  
 125  	// The link should now be reported as active, since EligibleToForward
 126  	// returns true and the link is in the linkIndex.
 127  	if !s.HasActiveLink(chanID1) {
 128  		t.Fatalf("link should not be active now")
 129  	}
 130  }
 131  
 132  // TestSwitchSendPending checks the inability of htlc switch to forward adds
 133  // over pending links.
 134  func TestSwitchSendPending(t *testing.T) {
 135  	t.Parallel()
 136  
 137  	alicePeer, err := newMockServer(
 138  		t, "alice", testStartingHeight, nil, testDefaultDelta,
 139  	)
 140  	require.NoError(t, err, "unable to create alice server")
 141  
 142  	bobPeer, err := newMockServer(
 143  		t, "bob", testStartingHeight, nil, testDefaultDelta,
 144  	)
 145  	require.NoError(t, err, "unable to create bob server")
 146  
 147  	s, err := initSwitchWithTempDB(t, testStartingHeight)
 148  	require.NoError(t, err, "unable to init switch")
 149  	if err := s.Start(); err != nil {
 150  		t.Fatalf("unable to start switch: %v", err)
 151  	}
 152  	defer s.Stop()
 153  
 154  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
 155  
 156  	pendingChanID := lnwire.ShortChannelID{}
 157  
 158  	aliceChannelLink := newMockChannelLink(
 159  		s, chanID1, pendingChanID, emptyScid, alicePeer, false, false,
 160  		false, false,
 161  	)
 162  	if err := s.AddLink(aliceChannelLink); err != nil {
 163  		t.Fatalf("unable to add alice link: %v", err)
 164  	}
 165  
 166  	bobChannelLink := newMockChannelLink(
 167  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
 168  		false,
 169  	)
 170  	if err := s.AddLink(bobChannelLink); err != nil {
 171  		t.Fatalf("unable to add bob link: %v", err)
 172  	}
 173  
 174  	// Create request which should is being forwarded from Bob channel
 175  	// link to Alice channel link.
 176  	preimage, err := genPreimage()
 177  	require.NoError(t, err, "unable to generate preimage")
 178  	rhash := sha256.Sum256(preimage[:])
 179  	packet := &htlcPacket{
 180  		incomingChanID: bobChanID,
 181  		incomingHTLCID: 0,
 182  		outgoingChanID: aliceChanID,
 183  		obfuscator:     NewMockObfuscator(),
 184  		htlc: &lnwire.UpdateAddHTLC{
 185  			PaymentHash: rhash,
 186  			Amount:      1,
 187  		},
 188  	}
 189  
 190  	// Send the ADD packet, this should not be forwarded out to the link
 191  	// since there are no eligible links.
 192  	if err = s.ForwardPackets(nil, packet); err != nil {
 193  		t.Fatal(err)
 194  	}
 195  	select {
 196  	case p := <-bobChannelLink.packets:
 197  		if p.linkFailure != nil {
 198  			err = p.linkFailure
 199  		}
 200  	case <-time.After(time.Second):
 201  		t.Fatal("no timely reply from switch")
 202  	}
 203  	linkErr, ok := err.(*LinkError)
 204  	if !ok {
 205  		t.Fatalf("expected link error, got: %T", err)
 206  	}
 207  	if linkErr.WireMessage().Code() != lnwire.CodeUnknownNextPeer {
 208  		t.Fatalf("expected fail unknown next peer, got: %T",
 209  			linkErr.WireMessage().Code())
 210  	}
 211  
 212  	// No message should be sent, since the packet was failed.
 213  	select {
 214  	case <-aliceChannelLink.packets:
 215  		t.Fatal("expected not to receive message")
 216  	case <-time.After(time.Second):
 217  	}
 218  
 219  	// Since the packet should have been failed, there should be no active
 220  	// circuits.
 221  	if s.circuits.NumOpen() != 0 {
 222  		t.Fatal("wrong amount of circuits")
 223  	}
 224  }
 225  
 226  // TestSwitchForwardMapping checks that the Switch properly consults its maps
 227  // when forwarding packets.
 228  func TestSwitchForwardMapping(t *testing.T) {
 229  	tests := []struct {
 230  		name string
 231  
 232  		// If this is true, then Alice's channel will be private.
 233  		alicePrivate bool
 234  
 235  		// If this is true, then Alice's channel will be a zero-conf
 236  		// channel.
 237  		zeroConf bool
 238  
 239  		// If this is true, then Alice's channel will be an
 240  		// option-scid-alias feature-bit, non-zero-conf channel.
 241  		optionScid bool
 242  
 243  		// If this is true, then an alias will be used for forwarding.
 244  		useAlias bool
 245  
 246  		// This is Alice's channel alias. This may not be set if this
 247  		// is not an option_scid_alias channel (feature bit).
 248  		aliceAlias lnwire.ShortChannelID
 249  
 250  		// This is Alice's confirmed SCID. This may not be set if this
 251  		// is a zero-conf channel before confirmation.
 252  		aliceReal lnwire.ShortChannelID
 253  
 254  		// If this is set, we expect Bob forwarding to Alice to fail.
 255  		expectErr bool
 256  	}{
 257  		{
 258  			name:         "private unconfirmed zero-conf",
 259  			alicePrivate: true,
 260  			zeroConf:     true,
 261  			useAlias:     true,
 262  			aliceAlias: lnwire.ShortChannelID{
 263  				BlockHeight: 16_000_002,
 264  				TxIndex:     2,
 265  				TxPosition:  2,
 266  			},
 267  			aliceReal: lnwire.ShortChannelID{},
 268  			expectErr: false,
 269  		},
 270  		{
 271  			name:         "private confirmed zero-conf",
 272  			alicePrivate: true,
 273  			zeroConf:     true,
 274  			useAlias:     true,
 275  			aliceAlias: lnwire.ShortChannelID{
 276  				BlockHeight: 16_000_003,
 277  				TxIndex:     3,
 278  				TxPosition:  3,
 279  			},
 280  			aliceReal: lnwire.ShortChannelID{
 281  				BlockHeight: 300000,
 282  				TxIndex:     3,
 283  				TxPosition:  3,
 284  			},
 285  			expectErr: false,
 286  		},
 287  		{
 288  			name:         "private confirmed zero-conf failure",
 289  			alicePrivate: true,
 290  			zeroConf:     true,
 291  			useAlias:     false,
 292  			aliceAlias: lnwire.ShortChannelID{
 293  				BlockHeight: 16_000_004,
 294  				TxIndex:     4,
 295  				TxPosition:  4,
 296  			},
 297  			aliceReal: lnwire.ShortChannelID{
 298  				BlockHeight: 300002,
 299  				TxIndex:     4,
 300  				TxPosition:  4,
 301  			},
 302  			expectErr: true,
 303  		},
 304  		{
 305  			name:         "public unconfirmed zero-conf",
 306  			alicePrivate: false,
 307  			zeroConf:     true,
 308  			useAlias:     true,
 309  			aliceAlias: lnwire.ShortChannelID{
 310  				BlockHeight: 16_000_005,
 311  				TxIndex:     5,
 312  				TxPosition:  5,
 313  			},
 314  			aliceReal: lnwire.ShortChannelID{},
 315  			expectErr: false,
 316  		},
 317  		{
 318  			name:         "public confirmed zero-conf w/ alias",
 319  			alicePrivate: false,
 320  			zeroConf:     true,
 321  			useAlias:     true,
 322  			aliceAlias: lnwire.ShortChannelID{
 323  				BlockHeight: 16_000_006,
 324  				TxIndex:     6,
 325  				TxPosition:  6,
 326  			},
 327  			aliceReal: lnwire.ShortChannelID{
 328  				BlockHeight: 500000,
 329  				TxIndex:     6,
 330  				TxPosition:  6,
 331  			},
 332  			expectErr: false,
 333  		},
 334  		{
 335  			name:         "public confirmed zero-conf w/ real",
 336  			alicePrivate: false,
 337  			zeroConf:     true,
 338  			useAlias:     false,
 339  			aliceAlias: lnwire.ShortChannelID{
 340  				BlockHeight: 16_000_007,
 341  				TxIndex:     7,
 342  				TxPosition:  7,
 343  			},
 344  			aliceReal: lnwire.ShortChannelID{
 345  				BlockHeight: 502000,
 346  				TxIndex:     7,
 347  				TxPosition:  7,
 348  			},
 349  			expectErr: false,
 350  		},
 351  		{
 352  			name:         "private non-option channel",
 353  			alicePrivate: true,
 354  			aliceAlias:   lnwire.ShortChannelID{},
 355  			aliceReal: lnwire.ShortChannelID{
 356  				BlockHeight: 505000,
 357  				TxIndex:     8,
 358  				TxPosition:  8,
 359  			},
 360  		},
 361  		{
 362  			name:         "private option channel w/ alias",
 363  			alicePrivate: true,
 364  			optionScid:   true,
 365  			useAlias:     true,
 366  			aliceAlias: lnwire.ShortChannelID{
 367  				BlockHeight: 16_000_015,
 368  				TxIndex:     9,
 369  				TxPosition:  9,
 370  			},
 371  			aliceReal: lnwire.ShortChannelID{
 372  				BlockHeight: 506000,
 373  				TxIndex:     10,
 374  				TxPosition:  10,
 375  			},
 376  			expectErr: false,
 377  		},
 378  		{
 379  			name:         "private option channel failure",
 380  			alicePrivate: true,
 381  			optionScid:   true,
 382  			useAlias:     false,
 383  			aliceAlias: lnwire.ShortChannelID{
 384  				BlockHeight: 16_000_016,
 385  				TxIndex:     16,
 386  				TxPosition:  16,
 387  			},
 388  			aliceReal: lnwire.ShortChannelID{
 389  				BlockHeight: 507000,
 390  				TxIndex:     17,
 391  				TxPosition:  17,
 392  			},
 393  			expectErr: true,
 394  		},
 395  		{
 396  			name:         "public non-option channel",
 397  			alicePrivate: false,
 398  			useAlias:     false,
 399  			aliceAlias:   lnwire.ShortChannelID{},
 400  			aliceReal: lnwire.ShortChannelID{
 401  				BlockHeight: 508000,
 402  				TxIndex:     17,
 403  				TxPosition:  17,
 404  			},
 405  			expectErr: false,
 406  		},
 407  		{
 408  			name:         "public option channel w/ alias",
 409  			alicePrivate: false,
 410  			optionScid:   true,
 411  			useAlias:     true,
 412  			aliceAlias: lnwire.ShortChannelID{
 413  				BlockHeight: 16_000_018,
 414  				TxIndex:     18,
 415  				TxPosition:  18,
 416  			},
 417  			aliceReal: lnwire.ShortChannelID{
 418  				BlockHeight: 509000,
 419  				TxIndex:     19,
 420  				TxPosition:  19,
 421  			},
 422  			expectErr: false,
 423  		},
 424  		{
 425  			name:         "public option channel w/ real",
 426  			alicePrivate: false,
 427  			optionScid:   true,
 428  			useAlias:     false,
 429  			aliceAlias: lnwire.ShortChannelID{
 430  				BlockHeight: 16_000_019,
 431  				TxIndex:     19,
 432  				TxPosition:  19,
 433  			},
 434  			aliceReal: lnwire.ShortChannelID{
 435  				BlockHeight: 510000,
 436  				TxIndex:     20,
 437  				TxPosition:  20,
 438  			},
 439  			expectErr: false,
 440  		},
 441  	}
 442  
 443  	for _, test := range tests {
 444  		test := test
 445  		t.Run(test.name, func(t *testing.T) {
 446  			t.Parallel()
 447  			testSwitchForwardMapping(
 448  				t, test.alicePrivate, test.zeroConf,
 449  				test.useAlias, test.optionScid,
 450  				test.aliceAlias, test.aliceReal,
 451  				test.expectErr,
 452  			)
 453  		})
 454  	}
 455  }
 456  
 457  func testSwitchForwardMapping(t *testing.T, alicePrivate, aliceZeroConf,
 458  	useAlias, optionScid bool, aliceAlias, aliceReal lnwire.ShortChannelID,
 459  	expectErr bool) {
 460  
 461  	alicePeer, err := newMockServer(
 462  		t, "alice", testStartingHeight, nil, testDefaultDelta,
 463  	)
 464  	require.NoError(t, err)
 465  
 466  	bobPeer, err := newMockServer(
 467  		t, "bob", testStartingHeight, nil, testDefaultDelta,
 468  	)
 469  	require.NoError(t, err)
 470  
 471  	s, err := initSwitchWithTempDB(t, testStartingHeight)
 472  	require.NoError(t, err)
 473  	err = s.Start()
 474  	require.NoError(t, err)
 475  	defer func() { _ = s.Stop() }()
 476  
 477  	// Create the lnwire.ChannelIDs that we'll use.
 478  	chanID1, chanID2, _, _ := genIDs()
 479  
 480  	var aliceChannelLink *mockChannelLink
 481  
 482  	if aliceZeroConf {
 483  		aliceChannelLink = newMockChannelLink(
 484  			s, chanID1, aliceAlias, aliceReal, alicePeer, true,
 485  			alicePrivate, true, false,
 486  		)
 487  	} else {
 488  		aliceChannelLink = newMockChannelLink(
 489  			s, chanID1, aliceReal, emptyScid, alicePeer, true,
 490  			alicePrivate, false, optionScid,
 491  		)
 492  
 493  		if optionScid {
 494  			aliceChannelLink.addAlias(aliceAlias)
 495  		}
 496  	}
 497  
 498  	err = s.AddLink(aliceChannelLink)
 499  	require.NoError(t, err)
 500  
 501  	// Bob will just have a non-option_scid_alias channel so no mapping is
 502  	// necessary.
 503  	bobScid := lnwire.ShortChannelID{
 504  		BlockHeight: 501000,
 505  		TxIndex:     200,
 506  		TxPosition:  2,
 507  	}
 508  
 509  	bobChannelLink := newMockChannelLink(
 510  		s, chanID2, bobScid, emptyScid, bobPeer, true, false, false,
 511  		false,
 512  	)
 513  	err = s.AddLink(bobChannelLink)
 514  	require.NoError(t, err)
 515  
 516  	// Generate preimage.
 517  	preimage, err := genPreimage()
 518  	require.NoError(t, err, "unable to generate preimage")
 519  	rhash := sha256.Sum256(preimage[:])
 520  
 521  	// Determine the outgoing SCID to use.
 522  	outgoingSCID := aliceReal
 523  	if useAlias {
 524  		outgoingSCID = aliceAlias
 525  	}
 526  
 527  	packet := &htlcPacket{
 528  		incomingChanID: bobScid,
 529  		incomingHTLCID: 0,
 530  		outgoingChanID: outgoingSCID,
 531  		obfuscator:     NewMockObfuscator(),
 532  		htlc: &lnwire.UpdateAddHTLC{
 533  			PaymentHash: rhash,
 534  			Amount:      1,
 535  		},
 536  	}
 537  	err = s.ForwardPackets(nil, packet)
 538  	require.NoError(t, err)
 539  
 540  	// If we expect a forwarding error, then assert that we receive one.
 541  	// option_scid_alias forwards may fail if forwarding would be a privacy
 542  	// leak.
 543  	if expectErr {
 544  		select {
 545  		case <-bobChannelLink.packets:
 546  		case <-time.After(time.Second * 5):
 547  			t.Fatal("expected a forwarding error")
 548  		}
 549  
 550  		select {
 551  		case <-aliceChannelLink.packets:
 552  			t.Fatal("did not expect a packet")
 553  		case <-time.After(time.Second * 5):
 554  		}
 555  	} else {
 556  		select {
 557  		case <-bobChannelLink.packets:
 558  			t.Fatal("did not expect a forwarding error")
 559  		case <-time.After(time.Second * 5):
 560  		}
 561  
 562  		select {
 563  		case <-aliceChannelLink.packets:
 564  		case <-time.After(time.Second * 5):
 565  			t.Fatal("expected alice to receive packet")
 566  		}
 567  	}
 568  }
 569  
 570  // TestSwitchSendHTLCMapping tests that SendHTLC will properly route packets to
 571  // zero-conf or option-scid-alias (feature-bit) channels if the confirmed SCID
 572  // is used. It also tests that nothing breaks with the mapping change.
 573  func TestSwitchSendHTLCMapping(t *testing.T) {
 574  	tests := []struct {
 575  		name string
 576  
 577  		// If this is true, the channel will be zero-conf.
 578  		zeroConf bool
 579  
 580  		// Denotes whether the channel is option-scid-alias, non
 581  		// zero-conf feature bit.
 582  		optionFeature bool
 583  
 584  		// If this is true, then the alias will be used in the packet.
 585  		useAlias bool
 586  
 587  		// This will be the channel alias if there is a mapping.
 588  		alias lnwire.ShortChannelID
 589  
 590  		// This will be the confirmed SCID if the channel is confirmed.
 591  		real lnwire.ShortChannelID
 592  	}{
 593  		{
 594  			name:          "non-zero-conf real scid w/ option",
 595  			zeroConf:      false,
 596  			optionFeature: true,
 597  			useAlias:      false,
 598  			alias: lnwire.ShortChannelID{
 599  				BlockHeight: 10010,
 600  				TxIndex:     10,
 601  				TxPosition:  10,
 602  			},
 603  			real: lnwire.ShortChannelID{
 604  				BlockHeight: 500000,
 605  				TxIndex:     50,
 606  				TxPosition:  50,
 607  			},
 608  		},
 609  		{
 610  			name:     "non-zero-conf real scid no option",
 611  			zeroConf: false,
 612  			useAlias: false,
 613  			alias:    lnwire.ShortChannelID{},
 614  			real: lnwire.ShortChannelID{
 615  				BlockHeight: 400000,
 616  				TxIndex:     50,
 617  				TxPosition:  50,
 618  			},
 619  		},
 620  		{
 621  			name:     "zero-conf alias scid w/ conf",
 622  			zeroConf: true,
 623  			useAlias: true,
 624  			alias: lnwire.ShortChannelID{
 625  				BlockHeight: 10020,
 626  				TxIndex:     20,
 627  				TxPosition:  20,
 628  			},
 629  			real: lnwire.ShortChannelID{
 630  				BlockHeight: 450000,
 631  				TxIndex:     50,
 632  				TxPosition:  50,
 633  			},
 634  		},
 635  		{
 636  			name:     "zero-conf alias scid no conf",
 637  			zeroConf: true,
 638  			useAlias: true,
 639  			alias: lnwire.ShortChannelID{
 640  				BlockHeight: 10015,
 641  				TxIndex:     25,
 642  				TxPosition:  35,
 643  			},
 644  			real: lnwire.ShortChannelID{},
 645  		},
 646  		{
 647  			name:     "zero-conf real scid",
 648  			zeroConf: true,
 649  			useAlias: false,
 650  			alias: lnwire.ShortChannelID{
 651  				BlockHeight: 10035,
 652  				TxIndex:     35,
 653  				TxPosition:  35,
 654  			},
 655  			real: lnwire.ShortChannelID{
 656  				BlockHeight: 470000,
 657  				TxIndex:     35,
 658  				TxPosition:  45,
 659  			},
 660  		},
 661  	}
 662  
 663  	for _, test := range tests {
 664  		test := test
 665  		t.Run(test.name, func(t *testing.T) {
 666  			t.Parallel()
 667  			testSwitchSendHtlcMapping(
 668  				t, test.zeroConf, test.useAlias, test.alias,
 669  				test.real, test.optionFeature,
 670  			)
 671  		})
 672  	}
 673  }
 674  
 675  func testSwitchSendHtlcMapping(t *testing.T, zeroConf, useAlias bool, alias,
 676  	realScid lnwire.ShortChannelID, optionFeature bool) {
 677  
 678  	peer, err := newMockServer(
 679  		t, "alice", testStartingHeight, nil, testDefaultDelta,
 680  	)
 681  	require.NoError(t, err)
 682  
 683  	s, err := initSwitchWithTempDB(t, testStartingHeight)
 684  	require.NoError(t, err)
 685  	err = s.Start()
 686  	require.NoError(t, err)
 687  	defer func() { _ = s.Stop() }()
 688  
 689  	// Create the lnwire.ChannelID that we'll use.
 690  	chanID, _ := genID()
 691  
 692  	var link *mockChannelLink
 693  
 694  	if zeroConf {
 695  		link = newMockChannelLink(
 696  			s, chanID, alias, realScid, peer, true, false, true,
 697  			false,
 698  		)
 699  	} else {
 700  		link = newMockChannelLink(
 701  			s, chanID, realScid, emptyScid, peer, true, false,
 702  			false, true,
 703  		)
 704  
 705  		if optionFeature {
 706  			link.addAlias(alias)
 707  		}
 708  	}
 709  
 710  	err = s.AddLink(link)
 711  	require.NoError(t, err)
 712  
 713  	// Generate preimage.
 714  	preimage, err := genPreimage()
 715  	require.NoError(t, err)
 716  	rhash := sha256.Sum256(preimage[:])
 717  
 718  	// Determine the outgoing SCID to use.
 719  	outgoingSCID := realScid
 720  	if useAlias {
 721  		outgoingSCID = alias
 722  	}
 723  
 724  	// Send the HTLC and assert that we don't get an error.
 725  	htlc := &lnwire.UpdateAddHTLC{
 726  		PaymentHash: rhash,
 727  		Amount:      1,
 728  	}
 729  
 730  	err = s.SendHTLC(outgoingSCID, 0, htlc)
 731  	require.NoError(t, err)
 732  }
 733  
 734  // TestSwitchUpdateScid verifies that zero-conf and non-zero-conf
 735  // option-scid-alias (feature bit) channels will have the expected entries in
 736  // the aliasToReal and baseIndex maps.
 737  func TestSwitchUpdateScid(t *testing.T) {
 738  	t.Parallel()
 739  
 740  	peer, err := newMockServer(
 741  		t, "alice", testStartingHeight, nil, testDefaultDelta,
 742  	)
 743  	require.NoError(t, err, "unable to create alice server")
 744  
 745  	s, err := initSwitchWithTempDB(t, testStartingHeight)
 746  	require.NoError(t, err)
 747  	err = s.Start()
 748  	require.NoError(t, err)
 749  	defer func() { _ = s.Stop() }()
 750  
 751  	// Create the IDs that we'll use.
 752  	chanID, chanID2, _, _ := genIDs()
 753  
 754  	alias := lnwire.ShortChannelID{
 755  		BlockHeight: 16_000_000,
 756  		TxIndex:     0,
 757  		TxPosition:  0,
 758  	}
 759  	alias2 := alias
 760  	alias2.TxPosition = 1
 761  
 762  	realScid := lnwire.ShortChannelID{
 763  		BlockHeight: 500000,
 764  		TxIndex:     0,
 765  		TxPosition:  0,
 766  	}
 767  
 768  	link := newMockChannelLink(
 769  		s, chanID, alias, emptyScid, peer, true, false, true, false,
 770  	)
 771  	link.addAlias(alias2)
 772  
 773  	err = s.AddLink(link)
 774  	require.NoError(t, err)
 775  
 776  	// Assert that the zero-conf link does not have entries in the
 777  	// aliasToReal map.
 778  	s.indexMtx.RLock()
 779  	_, ok := s.aliasToReal[alias]
 780  	require.False(t, ok)
 781  	_, ok = s.aliasToReal[alias2]
 782  	require.False(t, ok)
 783  
 784  	// Assert that both aliases point to the "base" SCID, which is actually
 785  	// just the first alias.
 786  	baseScid, ok := s.baseIndex[alias]
 787  	require.True(t, ok)
 788  	require.Equal(t, alias, baseScid)
 789  
 790  	baseScid, ok = s.baseIndex[alias2]
 791  	require.True(t, ok)
 792  	require.Equal(t, alias, baseScid)
 793  
 794  	s.indexMtx.RUnlock()
 795  
 796  	// We'll set the mock link's confirmed SCID so that UpdateShortChanID
 797  	// populates aliasToReal and adds an entry to baseIndex.
 798  	link.realScid = realScid
 799  	link.confirmedZC = true
 800  
 801  	err = s.UpdateShortChanID(chanID)
 802  	require.NoError(t, err)
 803  
 804  	// Assert that aliasToReal is populated and there is an entry in
 805  	// baseIndex for realScid.
 806  	s.indexMtx.RLock()
 807  	realMapping, ok := s.aliasToReal[alias]
 808  	require.True(t, ok)
 809  	require.Equal(t, realScid, realMapping)
 810  
 811  	realMapping, ok = s.aliasToReal[alias2]
 812  	require.True(t, ok)
 813  	require.Equal(t, realScid, realMapping)
 814  
 815  	baseScid, ok = s.baseIndex[realScid]
 816  	require.True(t, ok)
 817  	require.Equal(t, alias, baseScid)
 818  
 819  	s.indexMtx.RUnlock()
 820  
 821  	// Now we'll perform the same checks with a non-zero-conf
 822  	// option-scid-alias channel (feature-bit).
 823  	optionReal := lnwire.ShortChannelID{
 824  		BlockHeight: 600000,
 825  		TxIndex:     0,
 826  		TxPosition:  0,
 827  	}
 828  	optionAlias := lnwire.ShortChannelID{
 829  		BlockHeight: 12000,
 830  		TxIndex:     0,
 831  		TxPosition:  0,
 832  	}
 833  	optionAlias2 := optionAlias
 834  	optionAlias2.TxPosition = 1
 835  	link2 := newMockChannelLink(
 836  		s, chanID2, optionReal, emptyScid, peer, true, false, false,
 837  		true,
 838  	)
 839  	link2.addAlias(optionAlias)
 840  	link2.addAlias(optionAlias2)
 841  
 842  	err = s.AddLink(link2)
 843  	require.NoError(t, err)
 844  
 845  	// Assert that the option-scid-alias link does have entries in the
 846  	// aliasToReal and baseIndex maps.
 847  	s.indexMtx.RLock()
 848  	realMapping, ok = s.aliasToReal[optionAlias]
 849  	require.True(t, ok)
 850  	require.Equal(t, optionReal, realMapping)
 851  
 852  	realMapping, ok = s.aliasToReal[optionAlias2]
 853  	require.True(t, ok)
 854  	require.Equal(t, optionReal, realMapping)
 855  
 856  	baseScid, ok = s.baseIndex[optionReal]
 857  	require.True(t, ok)
 858  	require.Equal(t, optionReal, baseScid)
 859  
 860  	baseScid, ok = s.baseIndex[optionAlias]
 861  	require.True(t, ok)
 862  	require.Equal(t, optionReal, baseScid)
 863  
 864  	baseScid, ok = s.baseIndex[optionAlias2]
 865  	require.True(t, ok)
 866  	require.Equal(t, optionReal, baseScid)
 867  
 868  	s.indexMtx.RUnlock()
 869  }
 870  
 871  // TestSwitchForward checks the ability of htlc switch to forward add/settle
 872  // requests.
 873  func TestSwitchForward(t *testing.T) {
 874  	t.Parallel()
 875  
 876  	alicePeer, err := newMockServer(
 877  		t, "alice", testStartingHeight, nil, testDefaultDelta,
 878  	)
 879  	if err != nil {
 880  		t.Fatalf("unable to create alice server: %v", err)
 881  	}
 882  	bobPeer, err := newMockServer(
 883  		t, "bob", testStartingHeight, nil, testDefaultDelta,
 884  	)
 885  	if err != nil {
 886  		t.Fatalf("unable to create bob server: %v", err)
 887  	}
 888  
 889  	s, err := initSwitchWithTempDB(t, testStartingHeight)
 890  	if err != nil {
 891  		t.Fatalf("unable to init switch: %v", err)
 892  	}
 893  	if err := s.Start(); err != nil {
 894  		t.Fatalf("unable to start switch: %v", err)
 895  	}
 896  	defer s.Stop()
 897  
 898  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
 899  
 900  	aliceChannelLink := newMockChannelLink(
 901  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
 902  		false, false,
 903  	)
 904  	bobChannelLink := newMockChannelLink(
 905  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
 906  		false,
 907  	)
 908  	if err := s.AddLink(aliceChannelLink); err != nil {
 909  		t.Fatalf("unable to add alice link: %v", err)
 910  	}
 911  	if err := s.AddLink(bobChannelLink); err != nil {
 912  		t.Fatalf("unable to add bob link: %v", err)
 913  	}
 914  
 915  	// Create request which should be forwarded from Alice channel link to
 916  	// bob channel link.
 917  	preimage, err := genPreimage()
 918  	if err != nil {
 919  		t.Fatalf("unable to generate preimage: %v", err)
 920  	}
 921  	rhash := sha256.Sum256(preimage[:])
 922  	packet := &htlcPacket{
 923  		incomingChanID: aliceChannelLink.ShortChanID(),
 924  		incomingHTLCID: 0,
 925  		outgoingChanID: bobChannelLink.ShortChanID(),
 926  		obfuscator:     NewMockObfuscator(),
 927  		htlc: &lnwire.UpdateAddHTLC{
 928  			PaymentHash: rhash,
 929  			Amount:      1,
 930  		},
 931  	}
 932  
 933  	// Handle the request and checks that bob channel link received it.
 934  	if err := s.ForwardPackets(nil, packet); err != nil {
 935  		t.Fatal(err)
 936  	}
 937  
 938  	select {
 939  	case <-bobChannelLink.packets:
 940  		if err := bobChannelLink.completeCircuit(packet); err != nil {
 941  			t.Fatalf("unable to complete payment circuit: %v", err)
 942  		}
 943  	case <-time.After(time.Second):
 944  		t.Fatal("request was not propagated to destination")
 945  	}
 946  
 947  	if s.circuits.NumOpen() != 1 {
 948  		t.Fatal("wrong amount of circuits")
 949  	}
 950  
 951  	if !s.IsForwardedHTLC(bobChannelLink.ShortChanID(), 0) {
 952  		t.Fatal("htlc should be identified as forwarded")
 953  	}
 954  
 955  	// Create settle request pretending that bob link handled the add htlc
 956  	// request and sent the htlc settle request back. This request should
 957  	// be forwarder back to Alice link.
 958  	packet = &htlcPacket{
 959  		outgoingChanID: bobChannelLink.ShortChanID(),
 960  		outgoingHTLCID: 0,
 961  		amount:         1,
 962  		htlc: &lnwire.UpdateFulfillHTLC{
 963  			PaymentPreimage: preimage,
 964  		},
 965  	}
 966  
 967  	// Handle the request and checks that payment circuit works properly.
 968  	if err := s.ForwardPackets(nil, packet); err != nil {
 969  		t.Fatal(err)
 970  	}
 971  
 972  	select {
 973  	case pkt := <-aliceChannelLink.packets:
 974  		if err := aliceChannelLink.deleteCircuit(pkt); err != nil {
 975  			t.Fatalf("unable to remove circuit: %v", err)
 976  		}
 977  	case <-time.After(time.Second):
 978  		t.Fatal("request was not propagated to channelPoint")
 979  	}
 980  
 981  	if s.circuits.NumOpen() != 0 {
 982  		t.Fatal("wrong amount of circuits")
 983  	}
 984  }
 985  
 986  func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
 987  	t.Parallel()
 988  
 989  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
 990  
 991  	alicePeer, err := newMockServer(
 992  		t, "alice", testStartingHeight, nil, testDefaultDelta,
 993  	)
 994  	if err != nil {
 995  		t.Fatalf("unable to create alice server: %v", err)
 996  	}
 997  	bobPeer, err := newMockServer(
 998  		t, "bob", testStartingHeight, nil, testDefaultDelta,
 999  	)
1000  	require.NoError(t, err, "unable to create bob server")
1001  
1002  	tempPath := t.TempDir()
1003  
1004  	cdb := channeldb.OpenForTesting(t, tempPath)
1005  
1006  	s, err := initSwitchWithDB(testStartingHeight, cdb)
1007  	require.NoError(t, err, "unable to init switch")
1008  	if err := s.Start(); err != nil {
1009  		t.Fatalf("unable to start switch: %v", err)
1010  	}
1011  
1012  	// Even though we intend to Stop s later in the test, it is safe to
1013  	// defer this Stop since its execution it is protected by an atomic
1014  	// guard, guaranteeing it executes at most once.
1015  	defer s.Stop()
1016  
1017  	aliceChannelLink := newMockChannelLink(
1018  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1019  		false, false,
1020  	)
1021  	bobChannelLink := newMockChannelLink(
1022  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1023  		false,
1024  	)
1025  	if err := s.AddLink(aliceChannelLink); err != nil {
1026  		t.Fatalf("unable to add alice link: %v", err)
1027  	}
1028  	if err := s.AddLink(bobChannelLink); err != nil {
1029  		t.Fatalf("unable to add bob link: %v", err)
1030  	}
1031  
1032  	// Create request which should be forwarded from Alice channel link to
1033  	// bob channel link.
1034  	preimage := [sha256.Size]byte{1}
1035  	rhash := sha256.Sum256(preimage[:])
1036  	ogPacket := &htlcPacket{
1037  		incomingChanID: aliceChannelLink.ShortChanID(),
1038  		incomingHTLCID: 0,
1039  		outgoingChanID: bobChannelLink.ShortChanID(),
1040  		obfuscator:     NewMockObfuscator(),
1041  		htlc: &lnwire.UpdateAddHTLC{
1042  			PaymentHash: rhash,
1043  			Amount:      1,
1044  		},
1045  	}
1046  
1047  	if s.circuits.NumPending() != 0 {
1048  		t.Fatalf("wrong amount of half circuits")
1049  	}
1050  	if s.circuits.NumOpen() != 0 {
1051  		t.Fatalf("wrong amount of circuits")
1052  	}
1053  
1054  	// Handle the request and checks that bob channel link received it.
1055  	if err := s.ForwardPackets(nil, ogPacket); err != nil {
1056  		t.Fatal(err)
1057  	}
1058  
1059  	if s.circuits.NumPending() != 1 {
1060  		t.Fatalf("wrong amount of half circuits")
1061  	}
1062  	if s.circuits.NumOpen() != 0 {
1063  		t.Fatalf("wrong amount of circuits")
1064  	}
1065  
1066  	// Pull packet from bob's link, but do not perform a full add.
1067  	select {
1068  	case packet := <-bobChannelLink.packets:
1069  		// Complete the payment circuit and assign the outgoing htlc id
1070  		// before restarting.
1071  		if err := bobChannelLink.completeCircuit(packet); err != nil {
1072  			t.Fatalf("unable to complete payment circuit: %v", err)
1073  		}
1074  
1075  	case <-time.After(time.Second):
1076  		t.Fatal("request was not propagated to destination")
1077  	}
1078  
1079  	if s.circuits.NumPending() != 1 {
1080  		t.Fatalf("wrong amount of half circuits")
1081  	}
1082  	if s.circuits.NumOpen() != 1 {
1083  		t.Fatalf("wrong amount of circuits")
1084  	}
1085  
1086  	// Now we will restart bob, leaving the forwarding decision for this
1087  	// htlc is in the half-added state.
1088  	if err := s.Stop(); err != nil {
1089  		t.Fatal(err)
1090  	}
1091  
1092  	if err := cdb.Close(); err != nil {
1093  		t.Fatal(err)
1094  	}
1095  
1096  	cdb2 := channeldb.OpenForTesting(t, tempPath)
1097  
1098  	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
1099  	require.NoError(t, err, "unable reinit switch")
1100  	if err := s2.Start(); err != nil {
1101  		t.Fatalf("unable to restart switch: %v", err)
1102  	}
1103  
1104  	// Even though we intend to Stop s2 later in the test, it is safe to
1105  	// defer this Stop since its execution it is protected by an atomic
1106  	// guard, guaranteeing it executes at most once.
1107  	defer s2.Stop()
1108  
1109  	aliceChannelLink = newMockChannelLink(
1110  		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1111  		false, false,
1112  	)
1113  	bobChannelLink = newMockChannelLink(
1114  		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1115  		false,
1116  	)
1117  	if err := s2.AddLink(aliceChannelLink); err != nil {
1118  		t.Fatalf("unable to add alice link: %v", err)
1119  	}
1120  	if err := s2.AddLink(bobChannelLink); err != nil {
1121  		t.Fatalf("unable to add bob link: %v", err)
1122  	}
1123  
1124  	if s2.circuits.NumPending() != 1 {
1125  		t.Fatalf("wrong amount of half circuits")
1126  	}
1127  	if s2.circuits.NumOpen() != 1 {
1128  		t.Fatalf("wrong amount of circuits")
1129  	}
1130  
1131  	// Craft a failure message from the remote peer.
1132  	fail := &htlcPacket{
1133  		outgoingChanID: bobChannelLink.ShortChanID(),
1134  		outgoingHTLCID: 0,
1135  		amount:         1,
1136  		htlc:           &lnwire.UpdateFailHTLC{},
1137  	}
1138  
1139  	// Send the fail packet from the remote peer through the switch.
1140  	if err := s2.ForwardPackets(nil, fail); err != nil {
1141  		t.Fatal(err)
1142  	}
1143  
1144  	// Pull packet from alice's link, as it should have gone through
1145  	// successfully.
1146  	select {
1147  	case pkt := <-aliceChannelLink.packets:
1148  		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
1149  			t.Fatalf("unable to remove circuit: %v", err)
1150  		}
1151  	case <-time.After(time.Second):
1152  		t.Fatal("request was not propagated to destination")
1153  	}
1154  
1155  	// Circuit map should be empty now.
1156  	if s2.circuits.NumPending() != 0 {
1157  		t.Fatalf("wrong amount of half circuits")
1158  	}
1159  	if s2.circuits.NumOpen() != 0 {
1160  		t.Fatalf("wrong amount of circuits")
1161  	}
1162  
1163  	// Send the fail packet from the remote peer through the switch.
1164  	if err := s.ForwardPackets(nil, fail); err != nil {
1165  		t.Fatal(err)
1166  	}
1167  	select {
1168  	case <-aliceChannelLink.packets:
1169  		t.Fatalf("expected duplicate fail to not arrive at the destination")
1170  	case <-time.After(time.Second):
1171  	}
1172  }
1173  
1174  func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
1175  	t.Parallel()
1176  
1177  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
1178  
1179  	alicePeer, err := newMockServer(
1180  		t, "alice", testStartingHeight, nil, testDefaultDelta,
1181  	)
1182  	require.NoError(t, err, "unable to create alice server")
1183  	bobPeer, err := newMockServer(
1184  		t, "bob", testStartingHeight, nil, testDefaultDelta,
1185  	)
1186  	require.NoError(t, err, "unable to create bob server")
1187  
1188  	tempPath := t.TempDir()
1189  
1190  	cdb := channeldb.OpenForTesting(t, tempPath)
1191  
1192  	s, err := initSwitchWithDB(testStartingHeight, cdb)
1193  	require.NoError(t, err, "unable to init switch")
1194  	if err := s.Start(); err != nil {
1195  		t.Fatalf("unable to start switch: %v", err)
1196  	}
1197  
1198  	// Even though we intend to Stop s later in the test, it is safe to
1199  	// defer this Stop since its execution it is protected by an atomic
1200  	// guard, guaranteeing it executes at most once.
1201  	defer s.Stop()
1202  
1203  	aliceChannelLink := newMockChannelLink(
1204  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1205  		false, false,
1206  	)
1207  	bobChannelLink := newMockChannelLink(
1208  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1209  		false,
1210  	)
1211  	if err := s.AddLink(aliceChannelLink); err != nil {
1212  		t.Fatalf("unable to add alice link: %v", err)
1213  	}
1214  	if err := s.AddLink(bobChannelLink); err != nil {
1215  		t.Fatalf("unable to add bob link: %v", err)
1216  	}
1217  
1218  	// Create request which should be forwarded from Alice channel link to
1219  	// bob channel link.
1220  	preimage := [sha256.Size]byte{1}
1221  	rhash := sha256.Sum256(preimage[:])
1222  	ogPacket := &htlcPacket{
1223  		incomingChanID: aliceChannelLink.ShortChanID(),
1224  		incomingHTLCID: 0,
1225  		outgoingChanID: bobChannelLink.ShortChanID(),
1226  		obfuscator:     NewMockObfuscator(),
1227  		htlc: &lnwire.UpdateAddHTLC{
1228  			PaymentHash: rhash,
1229  			Amount:      1,
1230  		},
1231  	}
1232  
1233  	if s.circuits.NumPending() != 0 {
1234  		t.Fatalf("wrong amount of half circuits")
1235  	}
1236  	if s.circuits.NumOpen() != 0 {
1237  		t.Fatalf("wrong amount of circuits")
1238  	}
1239  
1240  	// Handle the request and checks that bob channel link received it.
1241  	if err := s.ForwardPackets(nil, ogPacket); err != nil {
1242  		t.Fatal(err)
1243  	}
1244  
1245  	if s.circuits.NumPending() != 1 {
1246  		t.Fatalf("wrong amount of half circuits")
1247  	}
1248  	if s.circuits.NumOpen() != 0 {
1249  		t.Fatalf("wrong amount of circuits")
1250  	}
1251  
1252  	// Pull packet from bob's link, but do not perform a full add.
1253  	select {
1254  	case packet := <-bobChannelLink.packets:
1255  		// Complete the payment circuit and assign the outgoing htlc id
1256  		// before restarting.
1257  		if err := bobChannelLink.completeCircuit(packet); err != nil {
1258  			t.Fatalf("unable to complete payment circuit: %v", err)
1259  		}
1260  
1261  	case <-time.After(time.Second):
1262  		t.Fatal("request was not propagated to destination")
1263  	}
1264  
1265  	if s.circuits.NumPending() != 1 {
1266  		t.Fatalf("wrong amount of half circuits")
1267  	}
1268  	if s.circuits.NumOpen() != 1 {
1269  		t.Fatalf("wrong amount of circuits")
1270  	}
1271  
1272  	// Now we will restart bob, leaving the forwarding decision for this
1273  	// htlc is in the half-added state.
1274  	if err := s.Stop(); err != nil {
1275  		t.Fatal(err)
1276  	}
1277  
1278  	if err := cdb.Close(); err != nil {
1279  		t.Fatal(err)
1280  	}
1281  
1282  	cdb2 := channeldb.OpenForTesting(t, tempPath)
1283  
1284  	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
1285  	require.NoError(t, err, "unable reinit switch")
1286  	if err := s2.Start(); err != nil {
1287  		t.Fatalf("unable to restart switch: %v", err)
1288  	}
1289  
1290  	// Even though we intend to Stop s2 later in the test, it is safe to
1291  	// defer this Stop since its execution it is protected by an atomic
1292  	// guard, guaranteeing it executes at most once.
1293  	defer s2.Stop()
1294  
1295  	aliceChannelLink = newMockChannelLink(
1296  		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1297  		false, false,
1298  	)
1299  	bobChannelLink = newMockChannelLink(
1300  		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1301  		false,
1302  	)
1303  	if err := s2.AddLink(aliceChannelLink); err != nil {
1304  		t.Fatalf("unable to add alice link: %v", err)
1305  	}
1306  	if err := s2.AddLink(bobChannelLink); err != nil {
1307  		t.Fatalf("unable to add bob link: %v", err)
1308  	}
1309  
1310  	if s2.circuits.NumPending() != 1 {
1311  		t.Fatalf("wrong amount of half circuits")
1312  	}
1313  	if s2.circuits.NumOpen() != 1 {
1314  		t.Fatalf("wrong amount of circuits")
1315  	}
1316  
1317  	// Craft a settle message from the remote peer.
1318  	settle := &htlcPacket{
1319  		outgoingChanID: bobChannelLink.ShortChanID(),
1320  		outgoingHTLCID: 0,
1321  		amount:         1,
1322  		htlc: &lnwire.UpdateFulfillHTLC{
1323  			PaymentPreimage: preimage,
1324  		},
1325  	}
1326  
1327  	// Send the settle packet from the remote peer through the switch.
1328  	if err := s2.ForwardPackets(nil, settle); err != nil {
1329  		t.Fatal(err)
1330  	}
1331  
1332  	// Pull packet from alice's link, as it should have gone through
1333  	// successfully.
1334  	select {
1335  	case packet := <-aliceChannelLink.packets:
1336  		if err := aliceChannelLink.completeCircuit(packet); err != nil {
1337  			t.Fatalf("unable to complete circuit with in key=%s: %v",
1338  				packet.inKey(), err)
1339  		}
1340  	case <-time.After(time.Second):
1341  		t.Fatal("request was not propagated to destination")
1342  	}
1343  
1344  	// Circuit map should be empty now.
1345  	if s2.circuits.NumPending() != 0 {
1346  		t.Fatalf("wrong amount of half circuits")
1347  	}
1348  	if s2.circuits.NumOpen() != 0 {
1349  		t.Fatalf("wrong amount of circuits")
1350  	}
1351  
1352  	// Send the settle packet again, which not arrive at destination.
1353  	if err := s2.ForwardPackets(nil, settle); err != nil {
1354  		t.Fatal(err)
1355  	}
1356  	select {
1357  	case <-bobChannelLink.packets:
1358  		t.Fatalf("expected duplicate fail to not arrive at the destination")
1359  	case <-time.After(time.Second):
1360  	}
1361  }
1362  
1363  func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
1364  	t.Parallel()
1365  
1366  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
1367  
1368  	alicePeer, err := newMockServer(
1369  		t, "alice", testStartingHeight, nil, testDefaultDelta,
1370  	)
1371  	require.NoError(t, err, "unable to create alice server")
1372  	bobPeer, err := newMockServer(
1373  		t, "bob", testStartingHeight, nil, testDefaultDelta,
1374  	)
1375  	require.NoError(t, err, "unable to create bob server")
1376  
1377  	tempPath := t.TempDir()
1378  
1379  	cdb := channeldb.OpenForTesting(t, tempPath)
1380  
1381  	s, err := initSwitchWithDB(testStartingHeight, cdb)
1382  	require.NoError(t, err, "unable to init switch")
1383  	if err := s.Start(); err != nil {
1384  		t.Fatalf("unable to start switch: %v", err)
1385  	}
1386  
1387  	// Even though we intend to Stop s later in the test, it is safe to
1388  	// defer this Stop since its execution it is protected by an atomic
1389  	// guard, guaranteeing it executes at most once.
1390  	defer s.Stop()
1391  
1392  	aliceChannelLink := newMockChannelLink(
1393  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1394  		false, false,
1395  	)
1396  	bobChannelLink := newMockChannelLink(
1397  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1398  		false,
1399  	)
1400  	if err := s.AddLink(aliceChannelLink); err != nil {
1401  		t.Fatalf("unable to add alice link: %v", err)
1402  	}
1403  	if err := s.AddLink(bobChannelLink); err != nil {
1404  		t.Fatalf("unable to add bob link: %v", err)
1405  	}
1406  
1407  	// Create request which should be forwarded from Alice channel link to
1408  	// bob channel link.
1409  	preimage := [sha256.Size]byte{1}
1410  	rhash := sha256.Sum256(preimage[:])
1411  	ogPacket := &htlcPacket{
1412  		incomingChanID: aliceChannelLink.ShortChanID(),
1413  		incomingHTLCID: 0,
1414  		outgoingChanID: bobChannelLink.ShortChanID(),
1415  		obfuscator:     NewMockObfuscator(),
1416  		htlc: &lnwire.UpdateAddHTLC{
1417  			PaymentHash: rhash,
1418  			Amount:      1,
1419  		},
1420  	}
1421  
1422  	if s.circuits.NumPending() != 0 {
1423  		t.Fatalf("wrong amount of half circuits")
1424  	}
1425  	if s.circuits.NumOpen() != 0 {
1426  		t.Fatalf("wrong amount of circuits")
1427  	}
1428  
1429  	// Handle the request and checks that bob channel link received it.
1430  	if err := s.ForwardPackets(nil, ogPacket); err != nil {
1431  		t.Fatal(err)
1432  	}
1433  
1434  	if s.circuits.NumPending() != 1 {
1435  		t.Fatalf("wrong amount of half circuits")
1436  	}
1437  	if s.circuits.NumOpen() != 0 {
1438  		t.Fatalf("wrong amount of half circuits")
1439  	}
1440  
1441  	// Pull packet from bob's link, but do not perform a full add.
1442  	select {
1443  	case packet := <-bobChannelLink.packets:
1444  		// Complete the payment circuit and assign the outgoing htlc id
1445  		// before restarting.
1446  		if err := bobChannelLink.completeCircuit(packet); err != nil {
1447  			t.Fatalf("unable to complete payment circuit: %v", err)
1448  		}
1449  	case <-time.After(time.Second):
1450  		t.Fatal("request was not propagated to destination")
1451  	}
1452  
1453  	// Now we will restart bob, leaving the forwarding decision for this
1454  	// htlc is in the half-added state.
1455  	if err := s.Stop(); err != nil {
1456  		t.Fatal(err)
1457  	}
1458  
1459  	if err := cdb.Close(); err != nil {
1460  		t.Fatal(err)
1461  	}
1462  
1463  	cdb2 := channeldb.OpenForTesting(t, tempPath)
1464  
1465  	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
1466  	require.NoError(t, err, "unable reinit switch")
1467  	if err := s2.Start(); err != nil {
1468  		t.Fatalf("unable to restart switch: %v", err)
1469  	}
1470  
1471  	// Even though we intend to Stop s2 later in the test, it is safe to
1472  	// defer this Stop since its execution it is protected by an atomic
1473  	// guard, guaranteeing it executes at most once.
1474  	defer s2.Stop()
1475  
1476  	aliceChannelLink = newMockChannelLink(
1477  		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1478  		false, false,
1479  	)
1480  	bobChannelLink = newMockChannelLink(
1481  		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1482  		false,
1483  	)
1484  	if err := s2.AddLink(aliceChannelLink); err != nil {
1485  		t.Fatalf("unable to add alice link: %v", err)
1486  	}
1487  	if err := s2.AddLink(bobChannelLink); err != nil {
1488  		t.Fatalf("unable to add bob link: %v", err)
1489  	}
1490  
1491  	if s2.circuits.NumPending() != 1 {
1492  		t.Fatalf("wrong amount of half circuits")
1493  	}
1494  	if s2.circuits.NumOpen() != 1 {
1495  		t.Fatalf("wrong amount of half circuits")
1496  	}
1497  
1498  	// Resend the failed htlc. The packet will be dropped silently since the
1499  	// switch will detect that it has been half added previously.
1500  	if err := s2.ForwardPackets(nil, ogPacket); err != nil {
1501  		t.Fatal(err)
1502  	}
1503  
1504  	// After detecting an incomplete forward, the fail packet should have
1505  	// been returned to the sender.
1506  	select {
1507  	case <-aliceChannelLink.packets:
1508  		t.Fatal("request should not have returned to source")
1509  	case <-bobChannelLink.packets:
1510  		t.Fatal("request should not have forwarded to destination")
1511  	case <-time.After(time.Second):
1512  	}
1513  }
1514  
1515  func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
1516  	t.Parallel()
1517  
1518  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
1519  
1520  	alicePeer, err := newMockServer(
1521  		t, "alice", testStartingHeight, nil, testDefaultDelta,
1522  	)
1523  	require.NoError(t, err, "unable to create alice server")
1524  	bobPeer, err := newMockServer(
1525  		t, "bob", testStartingHeight, nil, testDefaultDelta,
1526  	)
1527  	require.NoError(t, err, "unable to create bob server")
1528  
1529  	tempPath := t.TempDir()
1530  
1531  	cdb := channeldb.OpenForTesting(t, tempPath)
1532  
1533  	s, err := initSwitchWithDB(testStartingHeight, cdb)
1534  	require.NoError(t, err, "unable to init switch")
1535  	if err := s.Start(); err != nil {
1536  		t.Fatalf("unable to start switch: %v", err)
1537  	}
1538  
1539  	// Even though we intend to Stop s later in the test, it is safe to
1540  	// defer this Stop since its execution it is protected by an atomic
1541  	// guard, guaranteeing it executes at most once.
1542  	defer s.Stop()
1543  
1544  	aliceChannelLink := newMockChannelLink(
1545  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1546  		false, false,
1547  	)
1548  	bobChannelLink := newMockChannelLink(
1549  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1550  		false,
1551  	)
1552  	if err := s.AddLink(aliceChannelLink); err != nil {
1553  		t.Fatalf("unable to add alice link: %v", err)
1554  	}
1555  	if err := s.AddLink(bobChannelLink); err != nil {
1556  		t.Fatalf("unable to add bob link: %v", err)
1557  	}
1558  
1559  	// Create request which should be forwarded from Alice channel link to
1560  	// bob channel link.
1561  	preimage := [sha256.Size]byte{1}
1562  	rhash := sha256.Sum256(preimage[:])
1563  	ogPacket := &htlcPacket{
1564  		incomingChanID: aliceChannelLink.ShortChanID(),
1565  		incomingHTLCID: 0,
1566  		outgoingChanID: bobChannelLink.ShortChanID(),
1567  		obfuscator:     NewMockObfuscator(),
1568  		htlc: &lnwire.UpdateAddHTLC{
1569  			PaymentHash: rhash,
1570  			Amount:      1,
1571  		},
1572  	}
1573  
1574  	if s.circuits.NumPending() != 0 {
1575  		t.Fatalf("wrong amount of half circuits")
1576  	}
1577  	if s.circuits.NumOpen() != 0 {
1578  		t.Fatalf("wrong amount of circuits")
1579  	}
1580  
1581  	// Handle the request and checks that bob channel link received it.
1582  	if err := s.ForwardPackets(nil, ogPacket); err != nil {
1583  		t.Fatal(err)
1584  	}
1585  
1586  	if s.circuits.NumPending() != 1 {
1587  		t.Fatalf("wrong amount of half circuits")
1588  	}
1589  	if s.circuits.NumOpen() != 0 {
1590  		t.Fatalf("wrong amount of half circuits")
1591  	}
1592  
1593  	// Pull packet from bob's link, but do not perform a full add.
1594  	select {
1595  	case <-bobChannelLink.packets:
1596  	case <-time.After(time.Second):
1597  		t.Fatal("request was not propagated to destination")
1598  	}
1599  
1600  	// Now we will restart bob, leaving the forwarding decision for this
1601  	// htlc is in the half-added state.
1602  	if err := s.Stop(); err != nil {
1603  		t.Fatal(err)
1604  	}
1605  
1606  	if err := cdb.Close(); err != nil {
1607  		t.Fatal(err)
1608  	}
1609  
1610  	cdb2 := channeldb.OpenForTesting(t, tempPath)
1611  
1612  	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
1613  	require.NoError(t, err, "unable reinit switch")
1614  	if err := s2.Start(); err != nil {
1615  		t.Fatalf("unable to restart switch: %v", err)
1616  	}
1617  
1618  	// Even though we intend to Stop s2 later in the test, it is safe to
1619  	// defer this Stop since its execution it is protected by an atomic
1620  	// guard, guaranteeing it executes at most once.
1621  	defer s2.Stop()
1622  
1623  	aliceChannelLink = newMockChannelLink(
1624  		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1625  		false, false,
1626  	)
1627  	bobChannelLink = newMockChannelLink(
1628  		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1629  		false,
1630  	)
1631  	if err := s2.AddLink(aliceChannelLink); err != nil {
1632  		t.Fatalf("unable to add alice link: %v", err)
1633  	}
1634  	if err := s2.AddLink(bobChannelLink); err != nil {
1635  		t.Fatalf("unable to add bob link: %v", err)
1636  	}
1637  
1638  	if s2.circuits.NumPending() != 1 {
1639  		t.Fatalf("wrong amount of half circuits")
1640  	}
1641  	if s2.circuits.NumOpen() != 0 {
1642  		t.Fatalf("wrong amount of half circuits")
1643  	}
1644  
1645  	// Resend the failed htlc, it should be returned to alice since the
1646  	// switch will detect that it has been half added previously.
1647  	err = s2.ForwardPackets(nil, ogPacket)
1648  	if err != nil {
1649  		t.Fatal(err)
1650  	}
1651  
1652  	// After detecting an incomplete forward, the fail packet should have
1653  	// been returned to the sender.
1654  	select {
1655  	case pkt := <-aliceChannelLink.packets:
1656  		linkErr := pkt.linkFailure
1657  		if linkErr.FailureDetail != OutgoingFailureIncompleteForward {
1658  			t.Fatalf("expected incomplete forward, got: %v",
1659  				linkErr.FailureDetail)
1660  		}
1661  	case <-time.After(time.Second):
1662  		t.Fatal("request was not propagated to destination")
1663  	}
1664  }
1665  
1666  // TestSwitchForwardCircuitPersistence checks the ability of htlc switch to
1667  // maintain the proper entries in the circuit map in the face of restarts.
1668  func TestSwitchForwardCircuitPersistence(t *testing.T) {
1669  	t.Parallel()
1670  
1671  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
1672  
1673  	alicePeer, err := newMockServer(
1674  		t, "alice", testStartingHeight, nil, testDefaultDelta,
1675  	)
1676  	require.NoError(t, err, "unable to create alice server")
1677  	bobPeer, err := newMockServer(
1678  		t, "bob", testStartingHeight, nil, testDefaultDelta,
1679  	)
1680  	require.NoError(t, err, "unable to create bob server")
1681  
1682  	tempPath := t.TempDir()
1683  
1684  	cdb := channeldb.OpenForTesting(t, tempPath)
1685  
1686  	s, err := initSwitchWithDB(testStartingHeight, cdb)
1687  	require.NoError(t, err, "unable to init switch")
1688  	if err := s.Start(); err != nil {
1689  		t.Fatalf("unable to start switch: %v", err)
1690  	}
1691  
1692  	// Even though we intend to Stop s later in the test, it is safe to
1693  	// defer this Stop since its execution it is protected by an atomic
1694  	// guard, guaranteeing it executes at most once.
1695  	defer s.Stop()
1696  
1697  	aliceChannelLink := newMockChannelLink(
1698  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1699  		false, false,
1700  	)
1701  	bobChannelLink := newMockChannelLink(
1702  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1703  		false,
1704  	)
1705  	if err := s.AddLink(aliceChannelLink); err != nil {
1706  		t.Fatalf("unable to add alice link: %v", err)
1707  	}
1708  	if err := s.AddLink(bobChannelLink); err != nil {
1709  		t.Fatalf("unable to add bob link: %v", err)
1710  	}
1711  
1712  	// Create request which should be forwarded from Alice channel link to
1713  	// bob channel link.
1714  	preimage := [sha256.Size]byte{1}
1715  	rhash := sha256.Sum256(preimage[:])
1716  	ogPacket := &htlcPacket{
1717  		incomingChanID: aliceChannelLink.ShortChanID(),
1718  		incomingHTLCID: 0,
1719  		outgoingChanID: bobChannelLink.ShortChanID(),
1720  		obfuscator:     NewMockObfuscator(),
1721  		htlc: &lnwire.UpdateAddHTLC{
1722  			PaymentHash: rhash,
1723  			Amount:      1,
1724  		},
1725  	}
1726  
1727  	if s.circuits.NumPending() != 0 {
1728  		t.Fatalf("wrong amount of half circuits")
1729  	}
1730  	if s.circuits.NumOpen() != 0 {
1731  		t.Fatalf("wrong amount of circuits")
1732  	}
1733  
1734  	// Handle the request and checks that bob channel link received it.
1735  	if err := s.ForwardPackets(nil, ogPacket); err != nil {
1736  		t.Fatal(err)
1737  	}
1738  
1739  	if s.circuits.NumPending() != 1 {
1740  		t.Fatalf("wrong amount of half circuits")
1741  	}
1742  	if s.circuits.NumOpen() != 0 {
1743  		t.Fatalf("wrong amount of circuits")
1744  	}
1745  
1746  	// Retrieve packet from outgoing link and cache until after restart.
1747  	var packet *htlcPacket
1748  	select {
1749  	case packet = <-bobChannelLink.packets:
1750  	case <-time.After(time.Second):
1751  		t.Fatal("request was not propagated to destination")
1752  	}
1753  
1754  	if err := s.Stop(); err != nil {
1755  		t.Fatal(err)
1756  	}
1757  
1758  	if err := cdb.Close(); err != nil {
1759  		t.Fatal(err)
1760  	}
1761  
1762  	cdb2 := channeldb.OpenForTesting(t, tempPath)
1763  
1764  	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
1765  	require.NoError(t, err, "unable reinit switch")
1766  	if err := s2.Start(); err != nil {
1767  		t.Fatalf("unable to restart switch: %v", err)
1768  	}
1769  
1770  	// Even though we intend to Stop s2 later in the test, it is safe to
1771  	// defer this Stop since its execution it is protected by an atomic
1772  	// guard, guaranteeing it executes at most once.
1773  	defer s2.Stop()
1774  
1775  	aliceChannelLink = newMockChannelLink(
1776  		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1777  		false, false,
1778  	)
1779  	bobChannelLink = newMockChannelLink(
1780  		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1781  		false,
1782  	)
1783  	if err := s2.AddLink(aliceChannelLink); err != nil {
1784  		t.Fatalf("unable to add alice link: %v", err)
1785  	}
1786  	if err := s2.AddLink(bobChannelLink); err != nil {
1787  		t.Fatalf("unable to add bob link: %v", err)
1788  	}
1789  
1790  	if s2.circuits.NumPending() != 1 {
1791  		t.Fatalf("wrong amount of half circuits")
1792  	}
1793  	if s2.circuits.NumOpen() != 0 {
1794  		t.Fatalf("wrong amount of half circuits")
1795  	}
1796  
1797  	// Now that the switch has restarted, complete the payment circuit.
1798  	if err := bobChannelLink.completeCircuit(packet); err != nil {
1799  		t.Fatalf("unable to complete payment circuit: %v", err)
1800  	}
1801  
1802  	if s2.circuits.NumPending() != 1 {
1803  		t.Fatalf("wrong amount of half circuits")
1804  	}
1805  	if s2.circuits.NumOpen() != 1 {
1806  		t.Fatal("wrong amount of circuits")
1807  	}
1808  
1809  	// Create settle request pretending that bob link handled the add htlc
1810  	// request and sent the htlc settle request back. This request should
1811  	// be forwarder back to Alice link.
1812  	ogPacket = &htlcPacket{
1813  		outgoingChanID: bobChannelLink.ShortChanID(),
1814  		outgoingHTLCID: 0,
1815  		amount:         1,
1816  		htlc: &lnwire.UpdateFulfillHTLC{
1817  			PaymentPreimage: preimage,
1818  		},
1819  	}
1820  
1821  	// Handle the request and checks that payment circuit works properly.
1822  	if err := s2.ForwardPackets(nil, ogPacket); err != nil {
1823  		t.Fatal(err)
1824  	}
1825  
1826  	select {
1827  	case packet = <-aliceChannelLink.packets:
1828  		if err := aliceChannelLink.completeCircuit(packet); err != nil {
1829  			t.Fatalf("unable to complete circuit with in key=%s: %v",
1830  				packet.inKey(), err)
1831  		}
1832  	case <-time.After(time.Second):
1833  		t.Fatal("request was not propagated to channelPoint")
1834  	}
1835  
1836  	if s2.circuits.NumPending() != 0 {
1837  		t.Fatalf("wrong amount of half circuits, want 1, got %d",
1838  			s2.circuits.NumPending())
1839  	}
1840  	if s2.circuits.NumOpen() != 0 {
1841  		t.Fatal("wrong amount of circuits")
1842  	}
1843  
1844  	if err := s2.Stop(); err != nil {
1845  		t.Fatal(err)
1846  	}
1847  
1848  	if err := cdb2.Close(); err != nil {
1849  		t.Fatal(err)
1850  	}
1851  
1852  	cdb3 := channeldb.OpenForTesting(t, tempPath)
1853  
1854  	s3, err := initSwitchWithDB(testStartingHeight, cdb3)
1855  	require.NoError(t, err, "unable reinit switch")
1856  	if err := s3.Start(); err != nil {
1857  		t.Fatalf("unable to restart switch: %v", err)
1858  	}
1859  	defer s3.Stop()
1860  
1861  	aliceChannelLink = newMockChannelLink(
1862  		s3, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
1863  		false, false,
1864  	)
1865  	bobChannelLink = newMockChannelLink(
1866  		s3, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
1867  		false,
1868  	)
1869  	if err := s3.AddLink(aliceChannelLink); err != nil {
1870  		t.Fatalf("unable to add alice link: %v", err)
1871  	}
1872  	if err := s3.AddLink(bobChannelLink); err != nil {
1873  		t.Fatalf("unable to add bob link: %v", err)
1874  	}
1875  
1876  	if s3.circuits.NumPending() != 0 {
1877  		t.Fatalf("wrong amount of half circuits")
1878  	}
1879  	if s3.circuits.NumOpen() != 0 {
1880  		t.Fatalf("wrong amount of circuits")
1881  	}
1882  }
1883  
1884  type multiHopFwdTest struct {
1885  	name                 string
1886  	eligible1, eligible2 bool
1887  	failure1, failure2   *LinkError
1888  	expectedReply        lnwire.FailCode
1889  }
1890  
1891  // TestCircularForwards tests the allowing/disallowing of circular payments
1892  // through the same channel in the case where the switch is configured to allow
1893  // and disallow same channel circular forwards.
1894  func TestCircularForwards(t *testing.T) {
1895  	chanID1, aliceChanID := genID()
1896  	preimage := [sha256.Size]byte{1}
1897  	hash := sha256.Sum256(preimage[:])
1898  
1899  	tests := []struct {
1900  		name                 string
1901  		allowCircularPayment bool
1902  		expectedErr          error
1903  	}{
1904  		{
1905  			name:                 "circular payment allowed",
1906  			allowCircularPayment: true,
1907  			expectedErr:          nil,
1908  		},
1909  		{
1910  			name:                 "circular payment disallowed",
1911  			allowCircularPayment: false,
1912  			expectedErr: NewDetailedLinkError(
1913  				lnwire.NewTemporaryChannelFailure(nil),
1914  				OutgoingFailureCircularRoute,
1915  			),
1916  		},
1917  	}
1918  
1919  	for _, test := range tests {
1920  		test := test
1921  		t.Run(test.name, func(t *testing.T) {
1922  			t.Parallel()
1923  
1924  			alicePeer, err := newMockServer(
1925  				t, "alice", testStartingHeight, nil,
1926  				testDefaultDelta,
1927  			)
1928  			if err != nil {
1929  				t.Fatalf("unable to create alice server: %v",
1930  					err)
1931  			}
1932  
1933  			s, err := initSwitchWithTempDB(t, testStartingHeight)
1934  			if err != nil {
1935  				t.Fatalf("unable to init switch: %v", err)
1936  			}
1937  			if err := s.Start(); err != nil {
1938  				t.Fatalf("unable to start switch: %v", err)
1939  			}
1940  			defer func() { _ = s.Stop() }()
1941  
1942  			// Set the switch to allow or disallow circular routes
1943  			// according to the test's requirements.
1944  			s.cfg.AllowCircularRoute = test.allowCircularPayment
1945  
1946  			aliceChannelLink := newMockChannelLink(
1947  				s, chanID1, aliceChanID, emptyScid, alicePeer,
1948  				true, false, false, false,
1949  			)
1950  
1951  			if err := s.AddLink(aliceChannelLink); err != nil {
1952  				t.Fatalf("unable to add alice link: %v", err)
1953  			}
1954  
1955  			// Create a new packet that loops through alice's link
1956  			// in a circle.
1957  			obfuscator := NewMockObfuscator()
1958  			packet := &htlcPacket{
1959  				incomingChanID: aliceChannelLink.ShortChanID(),
1960  				outgoingChanID: aliceChannelLink.ShortChanID(),
1961  				htlc: &lnwire.UpdateAddHTLC{
1962  					PaymentHash: hash,
1963  					Amount:      1,
1964  				},
1965  				obfuscator: obfuscator,
1966  			}
1967  
1968  			// Attempt to forward the packet and check for the expected
1969  			// error.
1970  			if err = s.ForwardPackets(nil, packet); err != nil {
1971  				t.Fatal(err)
1972  			}
1973  			select {
1974  			case p := <-aliceChannelLink.packets:
1975  				if p.linkFailure != nil {
1976  					err = p.linkFailure
1977  				}
1978  			case <-time.After(time.Second):
1979  				t.Fatal("no timely reply from switch")
1980  			}
1981  			if !reflect.DeepEqual(err, test.expectedErr) {
1982  				t.Fatalf("expected: %v, got: %v",
1983  					test.expectedErr, err)
1984  			}
1985  
1986  			// Ensure that no circuits were opened.
1987  			if s.circuits.NumOpen() > 0 {
1988  				t.Fatal("do not expect any open circuits")
1989  			}
1990  		})
1991  	}
1992  }
1993  
1994  // TestCheckCircularForward tests the error returned by checkCircularForward
1995  // in cases where we allow and disallow same channel circular forwards.
1996  func TestCheckCircularForward(t *testing.T) {
1997  	tests := []struct {
1998  		name string
1999  
2000  		// aliasMapping determines whether the test should add an alias
2001  		// mapping to Switch alias maps before checkCircularForward.
2002  		aliasMapping bool
2003  
2004  		// allowCircular determines whether we should allow circular
2005  		// forwards.
2006  		allowCircular bool
2007  
2008  		// incomingLink is the link that the htlc arrived on.
2009  		incomingLink lnwire.ShortChannelID
2010  
2011  		// outgoingLink is the link that the htlc forward
2012  		// is destined to leave on.
2013  		outgoingLink lnwire.ShortChannelID
2014  
2015  		// expectedErr is the error we expect to be returned.
2016  		expectedErr *LinkError
2017  	}{
2018  		{
2019  			name:          "not circular, allowed in config",
2020  			aliasMapping:  false,
2021  			allowCircular: true,
2022  			incomingLink:  lnwire.NewShortChanIDFromInt(123),
2023  			outgoingLink:  lnwire.NewShortChanIDFromInt(321),
2024  			expectedErr:   nil,
2025  		},
2026  		{
2027  			name:          "not circular, not allowed in config",
2028  			aliasMapping:  false,
2029  			allowCircular: false,
2030  			incomingLink:  lnwire.NewShortChanIDFromInt(123),
2031  			outgoingLink:  lnwire.NewShortChanIDFromInt(321),
2032  			expectedErr:   nil,
2033  		},
2034  		{
2035  			name:          "circular, allowed in config",
2036  			aliasMapping:  false,
2037  			allowCircular: true,
2038  			incomingLink:  lnwire.NewShortChanIDFromInt(123),
2039  			outgoingLink:  lnwire.NewShortChanIDFromInt(123),
2040  			expectedErr:   nil,
2041  		},
2042  		{
2043  			name:          "circular, not allowed in config",
2044  			aliasMapping:  false,
2045  			allowCircular: false,
2046  			incomingLink:  lnwire.NewShortChanIDFromInt(123),
2047  			outgoingLink:  lnwire.NewShortChanIDFromInt(123),
2048  			expectedErr: NewDetailedLinkError(
2049  				lnwire.NewTemporaryChannelFailure(nil),
2050  				OutgoingFailureCircularRoute,
2051  			),
2052  		},
2053  		{
2054  			name:          "circular with map, not allowed",
2055  			aliasMapping:  true,
2056  			allowCircular: false,
2057  			incomingLink:  lnwire.NewShortChanIDFromInt(1 << 60),
2058  			outgoingLink:  lnwire.NewShortChanIDFromInt(1 << 55),
2059  			expectedErr: NewDetailedLinkError(
2060  				lnwire.NewTemporaryChannelFailure(nil),
2061  				OutgoingFailureCircularRoute,
2062  			),
2063  		},
2064  		{
2065  			name:          "circular with map, not allowed 2",
2066  			aliasMapping:  true,
2067  			allowCircular: false,
2068  			incomingLink:  lnwire.NewShortChanIDFromInt(1 << 55),
2069  			outgoingLink:  lnwire.NewShortChanIDFromInt(1 << 60),
2070  			expectedErr: NewDetailedLinkError(
2071  				lnwire.NewTemporaryChannelFailure(nil),
2072  				OutgoingFailureCircularRoute,
2073  			),
2074  		},
2075  		{
2076  			name:          "circular with map, allowed",
2077  			aliasMapping:  true,
2078  			allowCircular: true,
2079  			incomingLink:  lnwire.NewShortChanIDFromInt(1 << 60),
2080  			outgoingLink:  lnwire.NewShortChanIDFromInt(1 << 55),
2081  			expectedErr:   nil,
2082  		},
2083  		{
2084  			name:          "circular with map, allowed 2",
2085  			aliasMapping:  true,
2086  			allowCircular: true,
2087  			incomingLink:  lnwire.NewShortChanIDFromInt(1 << 55),
2088  			outgoingLink:  lnwire.NewShortChanIDFromInt(1 << 61),
2089  			expectedErr:   nil,
2090  		},
2091  		{
2092  			name:          "not circular, both confirmed SCID",
2093  			aliasMapping:  false,
2094  			allowCircular: false,
2095  			incomingLink:  lnwire.NewShortChanIDFromInt(1 << 60),
2096  			outgoingLink:  lnwire.NewShortChanIDFromInt(1 << 61),
2097  			expectedErr:   nil,
2098  		},
2099  	}
2100  
2101  	for _, test := range tests {
2102  		test := test
2103  
2104  		t.Run(test.name, func(t *testing.T) {
2105  			t.Parallel()
2106  
2107  			s, err := initSwitchWithTempDB(t, testStartingHeight)
2108  			require.NoError(t, err)
2109  			err = s.Start()
2110  			require.NoError(t, err)
2111  			defer func() { _ = s.Stop() }()
2112  
2113  			if test.aliasMapping {
2114  				// Make the incoming and outgoing point to the
2115  				// same base SCID.
2116  				inScid := test.incomingLink
2117  				outScid := test.outgoingLink
2118  				s.indexMtx.Lock()
2119  				s.baseIndex[inScid] = outScid
2120  				s.baseIndex[outScid] = outScid
2121  				s.indexMtx.Unlock()
2122  			}
2123  
2124  			// Check for a circular forward, the hash passed can
2125  			// be nil because it is only used for logging.
2126  			err = s.checkCircularForward(
2127  				test.incomingLink, test.outgoingLink,
2128  				test.allowCircular, lntypes.Hash{},
2129  			)
2130  			if !reflect.DeepEqual(err, test.expectedErr) {
2131  				t.Fatalf("expected: %v, got: %v",
2132  					test.expectedErr, err)
2133  			}
2134  		})
2135  	}
2136  }
2137  
2138  // TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
2139  // along, then we won't attempt to forward it down al ink that isn't yet able
2140  // to forward any HTLC's.
2141  func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
2142  	tests := []multiHopFwdTest{
2143  		// None of the channels is eligible.
2144  		{
2145  			name:          "not eligible",
2146  			expectedReply: lnwire.CodeUnknownNextPeer,
2147  		},
2148  
2149  		// Channel one has a policy failure and the other channel isn't
2150  		// available.
2151  		{
2152  			name:      "policy fail",
2153  			eligible1: true,
2154  			failure1: NewLinkError(
2155  				lnwire.NewFinalIncorrectCltvExpiry(0),
2156  			),
2157  			expectedReply: lnwire.CodeFinalIncorrectCltvExpiry,
2158  		},
2159  
2160  		// The requested channel is not eligible, but the packet is
2161  		// forwarded through the other channel.
2162  		{
2163  			name:          "non-strict success",
2164  			eligible2:     true,
2165  			expectedReply: lnwire.CodeNone,
2166  		},
2167  
2168  		// The requested channel has insufficient bandwidth and the
2169  		// other channel's policy isn't satisfied.
2170  		{
2171  			name:      "non-strict policy fail",
2172  			eligible1: true,
2173  			failure1: NewDetailedLinkError(
2174  				lnwire.NewTemporaryChannelFailure(nil),
2175  				OutgoingFailureInsufficientBalance,
2176  			),
2177  			eligible2: true,
2178  			failure2: NewLinkError(
2179  				lnwire.NewFinalIncorrectCltvExpiry(0),
2180  			),
2181  			expectedReply: lnwire.CodeTemporaryChannelFailure,
2182  		},
2183  	}
2184  
2185  	for _, test := range tests {
2186  		test := test
2187  		t.Run(test.name, func(t *testing.T) {
2188  			testSkipIneligibleLinksMultiHopForward(t, &test)
2189  		})
2190  	}
2191  }
2192  
2193  // testSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
2194  // along, then we won't attempt to forward it down al ink that isn't yet able
2195  // to forward any HTLC's.
2196  func testSkipIneligibleLinksMultiHopForward(t *testing.T,
2197  	testCase *multiHopFwdTest) {
2198  
2199  	t.Parallel()
2200  
2201  	var packet *htlcPacket
2202  
2203  	alicePeer, err := newMockServer(
2204  		t, "alice", testStartingHeight, nil, testDefaultDelta,
2205  	)
2206  	require.NoError(t, err, "unable to create alice server")
2207  	bobPeer, err := newMockServer(
2208  		t, "bob", testStartingHeight, nil, testDefaultDelta,
2209  	)
2210  	require.NoError(t, err, "unable to create bob server")
2211  
2212  	s, err := initSwitchWithTempDB(t, testStartingHeight)
2213  	require.NoError(t, err, "unable to init switch")
2214  	if err := s.Start(); err != nil {
2215  		t.Fatalf("unable to start switch: %v", err)
2216  	}
2217  	defer s.Stop()
2218  
2219  	chanID1, aliceChanID := genID()
2220  	aliceChannelLink := newMockChannelLink(
2221  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
2222  		false, false,
2223  	)
2224  
2225  	// We'll create a link for Bob, but mark the link as unable to forward
2226  	// any new outgoing HTLC's.
2227  	chanID2, bobChanID2 := genID()
2228  	bobChannelLink1 := newMockChannelLink(
2229  		s, chanID2, bobChanID2, emptyScid, bobPeer, testCase.eligible1,
2230  		false, false, false,
2231  	)
2232  	bobChannelLink1.checkHtlcForwardResult = testCase.failure1
2233  
2234  	chanID3, bobChanID3 := genID()
2235  	bobChannelLink2 := newMockChannelLink(
2236  		s, chanID3, bobChanID3, emptyScid, bobPeer, testCase.eligible2,
2237  		false, false, false,
2238  	)
2239  	bobChannelLink2.checkHtlcForwardResult = testCase.failure2
2240  
2241  	if err := s.AddLink(aliceChannelLink); err != nil {
2242  		t.Fatalf("unable to add alice link: %v", err)
2243  	}
2244  	if err := s.AddLink(bobChannelLink1); err != nil {
2245  		t.Fatalf("unable to add bob link: %v", err)
2246  	}
2247  	if err := s.AddLink(bobChannelLink2); err != nil {
2248  		t.Fatalf("unable to add bob link: %v", err)
2249  	}
2250  
2251  	// Create a new packet that's destined for Bob as an incoming HTLC from
2252  	// Alice.
2253  	preimage := [sha256.Size]byte{1}
2254  	rhash := sha256.Sum256(preimage[:])
2255  	obfuscator := NewMockObfuscator()
2256  	packet = &htlcPacket{
2257  		incomingChanID: aliceChannelLink.ShortChanID(),
2258  		incomingHTLCID: 0,
2259  		outgoingChanID: bobChannelLink1.ShortChanID(),
2260  		htlc: &lnwire.UpdateAddHTLC{
2261  			PaymentHash: rhash,
2262  			Amount:      1,
2263  		},
2264  		obfuscator: obfuscator,
2265  	}
2266  
2267  	// The request to forward should fail as
2268  	if err := s.ForwardPackets(nil, packet); err != nil {
2269  		t.Fatal(err)
2270  	}
2271  
2272  	// We select from all links and extract the error if exists.
2273  	// The packet must be selected but we don't always expect a link error.
2274  	var linkError *LinkError
2275  	select {
2276  	case p := <-aliceChannelLink.packets:
2277  		linkError = p.linkFailure
2278  	case p := <-bobChannelLink1.packets:
2279  		linkError = p.linkFailure
2280  	case p := <-bobChannelLink2.packets:
2281  		linkError = p.linkFailure
2282  	case <-time.After(time.Second):
2283  		t.Fatal("no timely reply from switch")
2284  	}
2285  	failure := obfuscator.(*mockObfuscator).failure
2286  	if testCase.expectedReply == lnwire.CodeNone {
2287  		if linkError != nil {
2288  			t.Fatalf("forwarding should have succeeded")
2289  		}
2290  		if failure != nil {
2291  			t.Fatalf("unexpected failure %T", failure)
2292  		}
2293  	} else {
2294  		if linkError == nil {
2295  			t.Fatalf("forwarding should have failed due to " +
2296  				"inactive link")
2297  		}
2298  		if failure.Code() != testCase.expectedReply {
2299  			t.Fatalf("unexpected failure %T", failure)
2300  		}
2301  	}
2302  
2303  	if s.circuits.NumOpen() != 0 {
2304  		t.Fatal("wrong amount of circuits")
2305  	}
2306  }
2307  
2308  // TestSkipIneligibleLinksLocalForward ensures that the switch will not attempt
2309  // to forward any HTLC's down a link that isn't yet eligible for forwarding.
2310  func TestSkipIneligibleLinksLocalForward(t *testing.T) {
2311  	t.Parallel()
2312  
2313  	testSkipLinkLocalForward(t, false, nil)
2314  }
2315  
2316  // TestSkipPolicyUnsatisfiedLinkLocalForward ensures that the switch will not
2317  // attempt to send locally initiated HTLCs that would violate the channel policy
2318  // down a link.
2319  func TestSkipPolicyUnsatisfiedLinkLocalForward(t *testing.T) {
2320  	t.Parallel()
2321  
2322  	testSkipLinkLocalForward(t, true, lnwire.NewTemporaryChannelFailure(nil))
2323  }
2324  
2325  func testSkipLinkLocalForward(t *testing.T, eligible bool,
2326  	policyResult lnwire.FailureMessage) {
2327  
2328  	// We'll create a single link for this test, marking it as being unable
2329  	// to forward form the get go.
2330  	alicePeer, err := newMockServer(
2331  		t, "alice", testStartingHeight, nil, testDefaultDelta,
2332  	)
2333  	require.NoError(t, err, "unable to create alice server")
2334  
2335  	s, err := initSwitchWithTempDB(t, testStartingHeight)
2336  	require.NoError(t, err, "unable to init switch")
2337  	if err := s.Start(); err != nil {
2338  		t.Fatalf("unable to start switch: %v", err)
2339  	}
2340  	defer s.Stop()
2341  
2342  	chanID1, _, aliceChanID, _ := genIDs()
2343  
2344  	aliceChannelLink := newMockChannelLink(
2345  		s, chanID1, aliceChanID, emptyScid, alicePeer, eligible, false,
2346  		false, false,
2347  	)
2348  	aliceChannelLink.checkHtlcTransitResult = NewLinkError(
2349  		policyResult,
2350  	)
2351  	if err := s.AddLink(aliceChannelLink); err != nil {
2352  		t.Fatalf("unable to add alice link: %v", err)
2353  	}
2354  
2355  	preimage, err := genPreimage()
2356  	require.NoError(t, err, "unable to generate preimage")
2357  	rhash := sha256.Sum256(preimage[:])
2358  	addMsg := &lnwire.UpdateAddHTLC{
2359  		PaymentHash: rhash,
2360  		Amount:      1,
2361  	}
2362  
2363  	// We'll attempt to send out a new HTLC that has Alice as the first
2364  	// outgoing link. This should fail as Alice isn't yet able to forward
2365  	// any active HTLC's.
2366  	err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg)
2367  	if err == nil {
2368  		t.Fatalf("local forward should fail due to inactive link")
2369  	}
2370  
2371  	if s.circuits.NumOpen() != 0 {
2372  		t.Fatal("wrong amount of circuits")
2373  	}
2374  }
2375  
2376  // TestSwitchCancel checks that if htlc was rejected we remove unused
2377  // circuits.
2378  func TestSwitchCancel(t *testing.T) {
2379  	t.Parallel()
2380  
2381  	alicePeer, err := newMockServer(
2382  		t, "alice", testStartingHeight, nil, testDefaultDelta,
2383  	)
2384  	require.NoError(t, err, "unable to create alice server")
2385  	bobPeer, err := newMockServer(
2386  		t, "bob", testStartingHeight, nil, testDefaultDelta,
2387  	)
2388  	require.NoError(t, err, "unable to create bob server")
2389  
2390  	s, err := initSwitchWithTempDB(t, testStartingHeight)
2391  	require.NoError(t, err, "unable to init switch")
2392  	if err := s.Start(); err != nil {
2393  		t.Fatalf("unable to start switch: %v", err)
2394  	}
2395  	defer s.Stop()
2396  
2397  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
2398  
2399  	aliceChannelLink := newMockChannelLink(
2400  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
2401  		false, false,
2402  	)
2403  	bobChannelLink := newMockChannelLink(
2404  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
2405  		false,
2406  	)
2407  	if err := s.AddLink(aliceChannelLink); err != nil {
2408  		t.Fatalf("unable to add alice link: %v", err)
2409  	}
2410  	if err := s.AddLink(bobChannelLink); err != nil {
2411  		t.Fatalf("unable to add bob link: %v", err)
2412  	}
2413  
2414  	// Create request which should be forwarder from alice channel link
2415  	// to bob channel link.
2416  	preimage, err := genPreimage()
2417  	require.NoError(t, err, "unable to generate preimage")
2418  	rhash := sha256.Sum256(preimage[:])
2419  	request := &htlcPacket{
2420  		incomingChanID: aliceChannelLink.ShortChanID(),
2421  		incomingHTLCID: 0,
2422  		outgoingChanID: bobChannelLink.ShortChanID(),
2423  		obfuscator:     NewMockObfuscator(),
2424  		htlc: &lnwire.UpdateAddHTLC{
2425  			PaymentHash: rhash,
2426  			Amount:      1,
2427  		},
2428  	}
2429  
2430  	// Handle the request and checks that bob channel link received it.
2431  	if err := s.ForwardPackets(nil, request); err != nil {
2432  		t.Fatal(err)
2433  	}
2434  
2435  	select {
2436  	case packet := <-bobChannelLink.packets:
2437  		if err := bobChannelLink.completeCircuit(packet); err != nil {
2438  			t.Fatalf("unable to complete payment circuit: %v", err)
2439  		}
2440  
2441  	case <-time.After(time.Second):
2442  		t.Fatal("request was not propagated to destination")
2443  	}
2444  
2445  	if s.circuits.NumPending() != 1 {
2446  		t.Fatalf("wrong amount of half circuits")
2447  	}
2448  	if s.circuits.NumOpen() != 1 {
2449  		t.Fatal("wrong amount of circuits")
2450  	}
2451  
2452  	// Create settle request pretending that bob channel link handled
2453  	// the add htlc request and sent the htlc settle request back. This
2454  	// request should be forwarder back to alice channel link.
2455  	request = &htlcPacket{
2456  		outgoingChanID: bobChannelLink.ShortChanID(),
2457  		outgoingHTLCID: 0,
2458  		amount:         1,
2459  		htlc:           &lnwire.UpdateFailHTLC{},
2460  	}
2461  
2462  	// Handle the request and checks that payment circuit works properly.
2463  	if err := s.ForwardPackets(nil, request); err != nil {
2464  		t.Fatal(err)
2465  	}
2466  
2467  	select {
2468  	case pkt := <-aliceChannelLink.packets:
2469  		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
2470  			t.Fatalf("unable to remove circuit: %v", err)
2471  		}
2472  
2473  	case <-time.After(time.Second):
2474  		t.Fatal("request was not propagated to channelPoint")
2475  	}
2476  
2477  	if s.circuits.NumPending() != 0 {
2478  		t.Fatal("wrong amount of circuits")
2479  	}
2480  	if s.circuits.NumOpen() != 0 {
2481  		t.Fatal("wrong amount of circuits")
2482  	}
2483  }
2484  
2485  // TestSwitchAddSamePayment tests that we send the payment with the same
2486  // payment hash.
2487  func TestSwitchAddSamePayment(t *testing.T) {
2488  	t.Parallel()
2489  
2490  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
2491  
2492  	alicePeer, err := newMockServer(
2493  		t, "alice", testStartingHeight, nil, testDefaultDelta,
2494  	)
2495  	require.NoError(t, err, "unable to create alice server")
2496  	bobPeer, err := newMockServer(
2497  		t, "bob", testStartingHeight, nil, testDefaultDelta,
2498  	)
2499  	require.NoError(t, err, "unable to create bob server")
2500  
2501  	s, err := initSwitchWithTempDB(t, testStartingHeight)
2502  	require.NoError(t, err, "unable to init switch")
2503  	if err := s.Start(); err != nil {
2504  		t.Fatalf("unable to start switch: %v", err)
2505  	}
2506  	defer s.Stop()
2507  
2508  	aliceChannelLink := newMockChannelLink(
2509  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
2510  		false, false,
2511  	)
2512  	bobChannelLink := newMockChannelLink(
2513  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
2514  		false,
2515  	)
2516  	if err := s.AddLink(aliceChannelLink); err != nil {
2517  		t.Fatalf("unable to add alice link: %v", err)
2518  	}
2519  	if err := s.AddLink(bobChannelLink); err != nil {
2520  		t.Fatalf("unable to add bob link: %v", err)
2521  	}
2522  
2523  	// Create request which should be forwarder from alice channel link
2524  	// to bob channel link.
2525  	preimage, err := genPreimage()
2526  	require.NoError(t, err, "unable to generate preimage")
2527  	rhash := sha256.Sum256(preimage[:])
2528  	request := &htlcPacket{
2529  		incomingChanID: aliceChannelLink.ShortChanID(),
2530  		incomingHTLCID: 0,
2531  		outgoingChanID: bobChannelLink.ShortChanID(),
2532  		obfuscator:     NewMockObfuscator(),
2533  		htlc: &lnwire.UpdateAddHTLC{
2534  			PaymentHash: rhash,
2535  			Amount:      1,
2536  		},
2537  	}
2538  
2539  	// Handle the request and checks that bob channel link received it.
2540  	if err := s.ForwardPackets(nil, request); err != nil {
2541  		t.Fatal(err)
2542  	}
2543  
2544  	select {
2545  	case packet := <-bobChannelLink.packets:
2546  		if err := bobChannelLink.completeCircuit(packet); err != nil {
2547  			t.Fatalf("unable to complete payment circuit: %v", err)
2548  		}
2549  
2550  	case <-time.After(time.Second):
2551  		t.Fatal("request was not propagated to destination")
2552  	}
2553  
2554  	if s.circuits.NumOpen() != 1 {
2555  		t.Fatal("wrong amount of circuits")
2556  	}
2557  
2558  	request = &htlcPacket{
2559  		incomingChanID: aliceChannelLink.ShortChanID(),
2560  		incomingHTLCID: 1,
2561  		outgoingChanID: bobChannelLink.ShortChanID(),
2562  		obfuscator:     NewMockObfuscator(),
2563  		htlc: &lnwire.UpdateAddHTLC{
2564  			PaymentHash: rhash,
2565  			Amount:      1,
2566  		},
2567  	}
2568  
2569  	// Handle the request and checks that bob channel link received it.
2570  	if err := s.ForwardPackets(nil, request); err != nil {
2571  		t.Fatal(err)
2572  	}
2573  
2574  	select {
2575  	case packet := <-bobChannelLink.packets:
2576  		if err := bobChannelLink.completeCircuit(packet); err != nil {
2577  			t.Fatalf("unable to complete payment circuit: %v", err)
2578  		}
2579  
2580  	case <-time.After(time.Second):
2581  		t.Fatal("request was not propagated to destination")
2582  	}
2583  
2584  	if s.circuits.NumOpen() != 2 {
2585  		t.Fatal("wrong amount of circuits")
2586  	}
2587  
2588  	// Create settle request pretending that bob channel link handled
2589  	// the add htlc request and sent the htlc settle request back. This
2590  	// request should be forwarder back to alice channel link.
2591  	request = &htlcPacket{
2592  		outgoingChanID: bobChannelLink.ShortChanID(),
2593  		outgoingHTLCID: 0,
2594  		amount:         1,
2595  		htlc:           &lnwire.UpdateFailHTLC{},
2596  	}
2597  
2598  	// Handle the request and checks that payment circuit works properly.
2599  	if err := s.ForwardPackets(nil, request); err != nil {
2600  		t.Fatal(err)
2601  	}
2602  
2603  	select {
2604  	case pkt := <-aliceChannelLink.packets:
2605  		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
2606  			t.Fatalf("unable to remove circuit: %v", err)
2607  		}
2608  
2609  	case <-time.After(time.Second):
2610  		t.Fatal("request was not propagated to channelPoint")
2611  	}
2612  
2613  	if s.circuits.NumOpen() != 1 {
2614  		t.Fatal("wrong amount of circuits")
2615  	}
2616  
2617  	request = &htlcPacket{
2618  		outgoingChanID: bobChannelLink.ShortChanID(),
2619  		outgoingHTLCID: 1,
2620  		amount:         1,
2621  		htlc:           &lnwire.UpdateFailHTLC{},
2622  	}
2623  
2624  	// Handle the request and checks that payment circuit works properly.
2625  	if err := s.ForwardPackets(nil, request); err != nil {
2626  		t.Fatal(err)
2627  	}
2628  
2629  	select {
2630  	case pkt := <-aliceChannelLink.packets:
2631  		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
2632  			t.Fatalf("unable to remove circuit: %v", err)
2633  		}
2634  
2635  	case <-time.After(time.Second):
2636  		t.Fatal("request was not propagated to channelPoint")
2637  	}
2638  
2639  	if s.circuits.NumOpen() != 0 {
2640  		t.Fatal("wrong amount of circuits")
2641  	}
2642  }
2643  
2644  // TestSwitchSendPayment tests ability of htlc switch to respond to the
2645  // users when response is came back from channel link.
2646  func TestSwitchSendPayment(t *testing.T) {
2647  	t.Parallel()
2648  
2649  	alicePeer, err := newMockServer(
2650  		t, "alice", testStartingHeight, nil, testDefaultDelta,
2651  	)
2652  	require.NoError(t, err, "unable to create alice server")
2653  
2654  	s, err := initSwitchWithTempDB(t, testStartingHeight)
2655  	require.NoError(t, err, "unable to init switch")
2656  	if err := s.Start(); err != nil {
2657  		t.Fatalf("unable to start switch: %v", err)
2658  	}
2659  	defer s.Stop()
2660  
2661  	chanID1, _, aliceChanID, _ := genIDs()
2662  
2663  	aliceChannelLink := newMockChannelLink(
2664  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
2665  		false, false,
2666  	)
2667  	if err := s.AddLink(aliceChannelLink); err != nil {
2668  		t.Fatalf("unable to add link: %v", err)
2669  	}
2670  
2671  	// Create request which should be forwarder from alice channel link
2672  	// to bob channel link.
2673  	preimage, err := genPreimage()
2674  	require.NoError(t, err, "unable to generate preimage")
2675  	rhash := sha256.Sum256(preimage[:])
2676  	update := &lnwire.UpdateAddHTLC{
2677  		PaymentHash: rhash,
2678  		Amount:      1,
2679  	}
2680  	paymentID := uint64(123)
2681  
2682  	// First check that the switch will correctly respond that this payment
2683  	// ID is unknown.
2684  	_, err = s.GetAttemptResult(
2685  		paymentID, rhash, newMockDeobfuscator(),
2686  	)
2687  	if err != ErrPaymentIDNotFound {
2688  		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
2689  	}
2690  
2691  	// Handle the request and checks that bob channel link received it.
2692  	errChan := make(chan error)
2693  	go func() {
2694  		err := s.SendHTLC(
2695  			aliceChannelLink.ShortChanID(), paymentID, update,
2696  		)
2697  		if err != nil {
2698  			errChan <- err
2699  			return
2700  		}
2701  
2702  		resultChan, err := s.GetAttemptResult(
2703  			paymentID, rhash, newMockDeobfuscator(),
2704  		)
2705  		if err != nil {
2706  			errChan <- err
2707  			return
2708  		}
2709  
2710  		result, ok := <-resultChan
2711  		if !ok {
2712  			errChan <- fmt.Errorf("shutting down")
2713  		}
2714  
2715  		if result.Error != nil {
2716  			errChan <- result.Error
2717  			return
2718  		}
2719  
2720  		errChan <- nil
2721  	}()
2722  
2723  	select {
2724  	case packet := <-aliceChannelLink.packets:
2725  		if err := aliceChannelLink.completeCircuit(packet); err != nil {
2726  			t.Fatalf("unable to complete payment circuit: %v", err)
2727  		}
2728  
2729  	case err := <-errChan:
2730  		if err != nil {
2731  			t.Fatalf("unable to send payment: %v", err)
2732  		}
2733  	case <-time.After(time.Second):
2734  		t.Fatal("request was not propagated to destination")
2735  	}
2736  
2737  	if s.circuits.NumOpen() != 1 {
2738  		t.Fatal("wrong amount of circuits")
2739  	}
2740  
2741  	// Create fail request pretending that bob channel link handled
2742  	// the add htlc request with error and sent the htlc fail request
2743  	// back. This request should be forwarded back to alice channel link.
2744  	obfuscator := NewMockObfuscator()
2745  	failure := lnwire.NewFailIncorrectDetails(update.Amount, 100)
2746  	reason, err := obfuscator.EncryptFirstHop(failure)
2747  	require.NoError(t, err, "unable obfuscate failure")
2748  
2749  	if s.IsForwardedHTLC(aliceChannelLink.ShortChanID(), update.ID) {
2750  		t.Fatal("htlc should be identified as not forwarded")
2751  	}
2752  	packet := &htlcPacket{
2753  		outgoingChanID: aliceChannelLink.ShortChanID(),
2754  		outgoingHTLCID: 0,
2755  		amount:         1,
2756  		htlc: &lnwire.UpdateFailHTLC{
2757  			Reason: reason,
2758  		},
2759  	}
2760  
2761  	if err := s.ForwardPackets(nil, packet); err != nil {
2762  		t.Fatalf("can't forward htlc packet: %v", err)
2763  	}
2764  
2765  	select {
2766  	case err := <-errChan:
2767  		assertFailureCode(
2768  			t, err, lnwire.CodeIncorrectOrUnknownPaymentDetails,
2769  		)
2770  	case <-time.After(time.Second):
2771  		t.Fatal("err wasn't received")
2772  	}
2773  }
2774  
2775  // TestLocalPaymentNoForwardingEvents tests that if we send a series of locally
2776  // initiated payments, then they aren't reflected in the forwarding log.
2777  func TestLocalPaymentNoForwardingEvents(t *testing.T) {
2778  	t.Parallel()
2779  
2780  	// First, we'll create our traditional three hop network. We'll only be
2781  	// interacting with and asserting the state of the first end point for
2782  	// this test.
2783  	channels, _, err := createClusterChannels(
2784  		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
2785  	)
2786  	require.NoError(t, err, "unable to create channel")
2787  
2788  	n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
2789  		channels.bobToCarol, channels.carolToBob, testStartingHeight)
2790  	if err := n.start(); err != nil {
2791  		t.Fatalf("unable to start three hop network: %v", err)
2792  	}
2793  
2794  	// We'll now craft and send a payment from Alice to Bob.
2795  	amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
2796  	htlcAmt, totalTimelock, hops := generateHops(
2797  		amount, testStartingHeight, n.firstBobChannelLink,
2798  	)
2799  
2800  	// With the payment crafted, we'll send it from Alice to Bob. We'll
2801  	// wait for Alice to receive the preimage for the payment before
2802  	// proceeding.
2803  	receiver := n.bobServer
2804  	firstHop := n.firstBobChannelLink.ShortChanID()
2805  	_, err = makePayment(
2806  		n.aliceServer, receiver, firstHop, hops, amount, htlcAmt,
2807  		totalTimelock,
2808  	).Wait(30 * time.Second)
2809  	require.NoError(t, err, "unable to make the payment")
2810  
2811  	// At this point, we'll forcibly stop the three hop network. Doing
2812  	// this will cause any pending forwarding events to be flushed by the
2813  	// various switches in the network.
2814  	n.stop()
2815  
2816  	// With all the switches stopped, we'll fetch Alice's mock forwarding
2817  	// event log.
2818  	log, ok := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
2819  	if !ok {
2820  		t.Fatalf("mockForwardingLog assertion failed")
2821  	}
2822  	log.Lock()
2823  	defer log.Unlock()
2824  
2825  	// If we examine the memory of the forwarding log, then it should be
2826  	// blank.
2827  	if len(log.events) != 0 {
2828  		t.Fatalf("log should have no events, instead has: %v",
2829  			spew.Sdump(log.events))
2830  	}
2831  }
2832  
2833  // TestMultiHopPaymentForwardingEvents tests that if we send a series of
2834  // multi-hop payments via Alice->Bob->Carol. Then Bob properly logs forwarding
2835  // events, while Alice and Carol don't.
2836  func TestMultiHopPaymentForwardingEvents(t *testing.T) {
2837  	t.Parallel()
2838  
2839  	// First, we'll create our traditional three hop network.
2840  	channels, _, err := createClusterChannels(
2841  		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
2842  	)
2843  	require.NoError(t, err, "unable to create channel")
2844  
2845  	n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice,
2846  		channels.bobToCarol, channels.carolToBob, testStartingHeight)
2847  	if err := n.start(); err != nil {
2848  		t.Fatalf("unable to start three hop network: %v", err)
2849  	}
2850  
2851  	// We'll make now 10 payments, of 100k satoshis each from Alice to
2852  	// Carol via Bob.
2853  	const numPayments = 10
2854  	finalAmt := lnwire.NewMSatFromSatoshis(100000)
2855  	htlcAmt, totalTimelock, hops := generateHops(
2856  		finalAmt, testStartingHeight, n.firstBobChannelLink,
2857  		n.carolChannelLink,
2858  	)
2859  	firstHop := n.firstBobChannelLink.ShortChanID()
2860  	for i := 0; i < numPayments/2; i++ {
2861  		_, err := makePayment(
2862  			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
2863  			htlcAmt, totalTimelock,
2864  		).Wait(30 * time.Second)
2865  		if err != nil {
2866  			t.Fatalf("unable to send payment: %v", err)
2867  		}
2868  	}
2869  
2870  	bobLog, ok := n.bobServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
2871  	if !ok {
2872  		t.Fatalf("mockForwardingLog assertion failed")
2873  	}
2874  
2875  	// After sending 5 of the payments, trigger the forwarding ticker, to
2876  	// make sure the events are properly flushed.
2877  	bobTicker, ok := n.bobServer.htlcSwitch.cfg.FwdEventTicker.(*ticker.Force)
2878  	if !ok {
2879  		t.Fatalf("mockTicker assertion failed")
2880  	}
2881  
2882  	// We'll trigger the ticker, and wait for the events to appear in Bob's
2883  	// forwarding log.
2884  	timeout := time.After(15 * time.Second)
2885  	for {
2886  		select {
2887  		case bobTicker.Force <- time.Now():
2888  		case <-time.After(1 * time.Second):
2889  			t.Fatalf("unable to force tick")
2890  		}
2891  
2892  		// If all 5 events is found in Bob's log, we can break out and
2893  		// continue the test.
2894  		bobLog.Lock()
2895  		if len(bobLog.events) == 5 {
2896  			bobLog.Unlock()
2897  			break
2898  		}
2899  		bobLog.Unlock()
2900  
2901  		// Otherwise wait a little bit before checking again.
2902  		select {
2903  		case <-time.After(50 * time.Millisecond):
2904  		case <-timeout:
2905  			bobLog.Lock()
2906  			defer bobLog.Unlock()
2907  			t.Fatalf("expected 5 events in event log, instead "+
2908  				"found: %v", spew.Sdump(bobLog.events))
2909  		}
2910  	}
2911  
2912  	// Send the remaining payments.
2913  	for i := numPayments / 2; i < numPayments; i++ {
2914  		_, err := makePayment(
2915  			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
2916  			htlcAmt, totalTimelock,
2917  		).Wait(30 * time.Second)
2918  		if err != nil {
2919  			t.Fatalf("unable to send payment: %v", err)
2920  		}
2921  	}
2922  
2923  	// With all 10 payments sent. We'll now manually stop each of the
2924  	// switches so we can examine their end state.
2925  	n.stop()
2926  
2927  	// Alice and Carol shouldn't have any recorded forwarding events, as
2928  	// they were the source and the sink for these payment flows.
2929  	aliceLog, ok := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
2930  	if !ok {
2931  		t.Fatalf("mockForwardingLog assertion failed")
2932  	}
2933  	aliceLog.Lock()
2934  	defer aliceLog.Unlock()
2935  	if len(aliceLog.events) != 0 {
2936  		t.Fatalf("log should have no events, instead has: %v",
2937  			spew.Sdump(aliceLog.events))
2938  	}
2939  
2940  	carolLog, ok := n.carolServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog)
2941  	if !ok {
2942  		t.Fatalf("mockForwardingLog assertion failed")
2943  	}
2944  	carolLog.Lock()
2945  	defer carolLog.Unlock()
2946  	if len(carolLog.events) != 0 {
2947  		t.Fatalf("log should have no events, instead has: %v",
2948  			spew.Sdump(carolLog.events))
2949  	}
2950  
2951  	// Bob on the other hand, should have 10 events.
2952  	bobLog.Lock()
2953  	defer bobLog.Unlock()
2954  	if len(bobLog.events) != 10 {
2955  		t.Fatalf("log should have 10 events, instead has: %v",
2956  			spew.Sdump(bobLog.events))
2957  	}
2958  
2959  	// Each of the 10 events should have had all fields set properly.
2960  	for _, event := range bobLog.events {
2961  		// The incoming and outgoing channels should properly be set for
2962  		// the event.
2963  		if event.IncomingChanID != n.aliceChannelLink.ShortChanID() {
2964  			t.Fatalf("chan id mismatch: expected %v, got %v",
2965  				event.IncomingChanID,
2966  				n.aliceChannelLink.ShortChanID())
2967  		}
2968  		if event.OutgoingChanID != n.carolChannelLink.ShortChanID() {
2969  			t.Fatalf("chan id mismatch: expected %v, got %v",
2970  				event.OutgoingChanID,
2971  				n.carolChannelLink.ShortChanID())
2972  		}
2973  
2974  		// Additionally, the incoming and outgoing amounts should also
2975  		// be properly set.
2976  		if event.AmtIn != htlcAmt {
2977  			t.Fatalf("incoming amt mismatch: expected %v, got %v",
2978  				event.AmtIn, htlcAmt)
2979  		}
2980  		if event.AmtOut != finalAmt {
2981  			t.Fatalf("outgoing amt mismatch: expected %v, got %v",
2982  				event.AmtOut, finalAmt)
2983  		}
2984  	}
2985  }
2986  
2987  // TestUpdateFailMalformedHTLCErrorConversion tests that we're able to properly
2988  // convert malformed HTLC errors that originate at the direct link, as well as
2989  // during multi-hop HTLC forwarding.
2990  func TestUpdateFailMalformedHTLCErrorConversion(t *testing.T) {
2991  	t.Parallel()
2992  
2993  	// First, we'll create our traditional three hop network.
2994  	channels, _, err := createClusterChannels(
2995  		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
2996  	)
2997  	require.NoError(t, err, "unable to create channel")
2998  
2999  	n := newThreeHopNetwork(
3000  		t, channels.aliceToBob, channels.bobToAlice,
3001  		channels.bobToCarol, channels.carolToBob, testStartingHeight,
3002  	)
3003  	if err := n.start(); err != nil {
3004  		t.Fatalf("unable to start three hop network: %v", err)
3005  	}
3006  
3007  	assertPaymentFailure := func(t *testing.T) {
3008  		// With the decoder modified, we'll now attempt to send a
3009  		// payment from Alice to carol.
3010  		finalAmt := lnwire.NewMSatFromSatoshis(100000)
3011  		htlcAmt, totalTimelock, hops := generateHops(
3012  			finalAmt, testStartingHeight, n.firstBobChannelLink,
3013  			n.carolChannelLink,
3014  		)
3015  		firstHop := n.firstBobChannelLink.ShortChanID()
3016  		_, err = makePayment(
3017  			n.aliceServer, n.carolServer, firstHop, hops, finalAmt,
3018  			htlcAmt, totalTimelock,
3019  		).Wait(30 * time.Second)
3020  
3021  		// The payment should fail as Carol is unable to decode the
3022  		// onion blob sent to her.
3023  		if err == nil {
3024  			t.Fatalf("unable to send payment: %v", err)
3025  		}
3026  
3027  		routingErr := err.(ClearTextError)
3028  		failureMsg := routingErr.WireMessage()
3029  		if _, ok := failureMsg.(*lnwire.FailInvalidOnionKey); !ok {
3030  			t.Fatalf("expected onion failure instead got: %v",
3031  				routingErr.WireMessage())
3032  		}
3033  	}
3034  
3035  	t.Run("multi-hop error conversion", func(t *testing.T) {
3036  		// Now that we have our network up, we'll modify the hop
3037  		// iterator for the Bob <-> Carol channel to fail to decode in
3038  		// order to simulate either a replay attack or an issue
3039  		// decoding the onion.
3040  		n.carolOnionDecoder.decodeFail = true
3041  
3042  		assertPaymentFailure(t)
3043  	})
3044  
3045  	t.Run("direct channel error conversion", func(t *testing.T) {
3046  		// Similar to the above test case, we'll now make the Alice <->
3047  		// Bob link always fail to decode an onion. This differs from
3048  		// the above test case in that there's no encryption on the
3049  		// error at all since Alice will directly receive a
3050  		// UpdateFailMalformedHTLC message.
3051  		n.bobOnionDecoder.decodeFail = true
3052  
3053  		assertPaymentFailure(t)
3054  	})
3055  }
3056  
3057  // TestSwitchGetAttemptResult tests that the switch interacts as expected with
3058  // the circuit map and network result store when looking up the result of a
3059  // payment ID. This is important for not to lose results under concurrent
3060  // lookup and receiving results.
3061  func TestSwitchGetAttemptResult(t *testing.T) {
3062  	t.Parallel()
3063  
3064  	const paymentID = 123
3065  	var preimg lntypes.Preimage
3066  	preimg[0] = 3
3067  
3068  	s, err := initSwitchWithTempDB(t, testStartingHeight)
3069  	require.NoError(t, err, "unable to init switch")
3070  	if err := s.Start(); err != nil {
3071  		t.Fatalf("unable to start switch: %v", err)
3072  	}
3073  	defer s.Stop()
3074  
3075  	lookup := make(chan *PaymentCircuit, 1)
3076  	s.circuits = &mockCircuitMap{
3077  		lookup: lookup,
3078  	}
3079  
3080  	// If the payment circuit is not found in the circuit map, the payment
3081  	// result must be found in the store if available. Since we haven't
3082  	// added anything to the store yet, ErrPaymentIDNotFound should be
3083  	// returned.
3084  	lookup <- nil
3085  	_, err = s.GetAttemptResult(
3086  		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
3087  	)
3088  	if err != ErrPaymentIDNotFound {
3089  		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
3090  	}
3091  
3092  	// Next let the lookup find the circuit in the circuit map. It should
3093  	// subscribe to payment results, and return the result when available.
3094  	lookup <- &PaymentCircuit{}
3095  	resultChan, err := s.GetAttemptResult(
3096  		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
3097  	)
3098  	require.NoError(t, err, "unable to get payment result")
3099  
3100  	// Add the result to the store.
3101  	n := &networkResult{
3102  		msg: &lnwire.UpdateFulfillHTLC{
3103  			PaymentPreimage: preimg,
3104  		},
3105  		unencrypted:  true,
3106  		isResolution: true,
3107  	}
3108  
3109  	err = s.networkResults.storeResult(paymentID, n)
3110  	require.NoError(t, err, "unable to store result")
3111  
3112  	// The result should be available.
3113  	select {
3114  	case res, ok := <-resultChan:
3115  		if !ok {
3116  			t.Fatalf("channel was closed")
3117  		}
3118  
3119  		if res.Error != nil {
3120  			t.Fatalf("got unexpected error result")
3121  		}
3122  
3123  		if res.Preimage != preimg {
3124  			t.Fatalf("expected preimg %v, got %v",
3125  				preimg, res.Preimage)
3126  		}
3127  
3128  	case <-time.After(1 * time.Second):
3129  		t.Fatalf("result not received")
3130  	}
3131  
3132  	// As a final test, try to get the result again. Now that is no longer
3133  	// in the circuit map, it should be immediately available from the
3134  	// store.
3135  	lookup <- nil
3136  	resultChan, err = s.GetAttemptResult(
3137  		paymentID, lntypes.Hash{}, newMockDeobfuscator(),
3138  	)
3139  	require.NoError(t, err, "unable to get payment result")
3140  
3141  	select {
3142  	case res, ok := <-resultChan:
3143  		if !ok {
3144  			t.Fatalf("channel was closed")
3145  		}
3146  
3147  		if res.Error != nil {
3148  			t.Fatalf("got unexpected error result")
3149  		}
3150  
3151  		if res.Preimage != preimg {
3152  			t.Fatalf("expected preimg %v, got %v",
3153  				preimg, res.Preimage)
3154  		}
3155  
3156  	case <-time.After(1 * time.Second):
3157  		t.Fatalf("result not received")
3158  	}
3159  }
3160  
3161  // TestInvalidFailure tests that the switch returns an unreadable failure error
3162  // if the failure cannot be decrypted.
3163  func TestInvalidFailure(t *testing.T) {
3164  	t.Parallel()
3165  
3166  	alicePeer, err := newMockServer(
3167  		t, "alice", testStartingHeight, nil, testDefaultDelta,
3168  	)
3169  	require.NoError(t, err, "unable to create alice server")
3170  
3171  	s, err := initSwitchWithTempDB(t, testStartingHeight)
3172  	require.NoError(t, err, "unable to init switch")
3173  	if err := s.Start(); err != nil {
3174  		t.Fatalf("unable to start switch: %v", err)
3175  	}
3176  	defer s.Stop()
3177  
3178  	chanID1, _, aliceChanID, _ := genIDs()
3179  
3180  	// Set up a mock channel link.
3181  	aliceChannelLink := newMockChannelLink(
3182  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
3183  		false, false,
3184  	)
3185  	if err := s.AddLink(aliceChannelLink); err != nil {
3186  		t.Fatalf("unable to add link: %v", err)
3187  	}
3188  
3189  	// Create a request which should be forwarded to the mock channel link.
3190  	preimage, err := genPreimage()
3191  	require.NoError(t, err, "unable to generate preimage")
3192  	rhash := sha256.Sum256(preimage[:])
3193  	update := &lnwire.UpdateAddHTLC{
3194  		PaymentHash: rhash,
3195  		Amount:      1,
3196  	}
3197  
3198  	paymentID := uint64(123)
3199  
3200  	// Send the request.
3201  	err = s.SendHTLC(
3202  		aliceChannelLink.ShortChanID(), paymentID, update,
3203  	)
3204  	require.NoError(t, err, "unable to send payment")
3205  
3206  	// Catch the packet and complete the circuit so that the switch is ready
3207  	// for a response.
3208  	select {
3209  	case packet := <-aliceChannelLink.packets:
3210  		if err := aliceChannelLink.completeCircuit(packet); err != nil {
3211  			t.Fatalf("unable to complete payment circuit: %v", err)
3212  		}
3213  
3214  	case <-time.After(time.Second):
3215  		t.Fatal("request was not propagated to destination")
3216  	}
3217  
3218  	// Send response packet with an unreadable failure message to the
3219  	// switch. The reason failed is not relevant, because we mock the
3220  	// decryption.
3221  	packet := &htlcPacket{
3222  		outgoingChanID: aliceChannelLink.ShortChanID(),
3223  		outgoingHTLCID: 0,
3224  		amount:         1,
3225  		htlc: &lnwire.UpdateFailHTLC{
3226  			Reason: []byte{1, 2, 3},
3227  		},
3228  	}
3229  
3230  	if err := s.ForwardPackets(nil, packet); err != nil {
3231  		t.Fatalf("can't forward htlc packet: %v", err)
3232  	}
3233  
3234  	// Get payment result from switch. We expect an unreadable failure
3235  	// message error.
3236  	deobfuscator := SphinxErrorDecrypter{
3237  		OnionErrorDecrypter: &mockOnionErrorDecryptor{
3238  			err: ErrUnreadableFailureMessage,
3239  		},
3240  	}
3241  
3242  	resultChan, err := s.GetAttemptResult(
3243  		paymentID, rhash, &deobfuscator,
3244  	)
3245  	if err != nil {
3246  		t.Fatal(err)
3247  	}
3248  
3249  	select {
3250  	case result := <-resultChan:
3251  		if result.Error != ErrUnreadableFailureMessage {
3252  			t.Fatal("expected unreadable failure message")
3253  		}
3254  
3255  	case <-time.After(time.Second):
3256  		t.Fatal("err wasn't received")
3257  	}
3258  
3259  	// Modify the decryption to simulate that decryption went alright, but
3260  	// the failure cannot be decoded.
3261  	deobfuscator = SphinxErrorDecrypter{
3262  		OnionErrorDecrypter: &mockOnionErrorDecryptor{
3263  			sourceIdx: 2,
3264  			message:   []byte{200},
3265  		},
3266  	}
3267  
3268  	resultChan, err = s.GetAttemptResult(
3269  		paymentID, rhash, &deobfuscator,
3270  	)
3271  	if err != nil {
3272  		t.Fatal(err)
3273  	}
3274  
3275  	select {
3276  	case result := <-resultChan:
3277  		rtErr, ok := result.Error.(ClearTextError)
3278  		if !ok {
3279  			t.Fatal("expected ClearTextError")
3280  		}
3281  		source, ok := rtErr.(*ForwardingError)
3282  		if !ok {
3283  			t.Fatalf("expected forwarding error, got: %T", rtErr)
3284  		}
3285  		if source.FailureSourceIdx != 2 {
3286  			t.Fatal("unexpected error source index")
3287  		}
3288  		if rtErr.WireMessage() != nil {
3289  			t.Fatal("expected empty failure message")
3290  		}
3291  
3292  	case <-time.After(time.Second):
3293  		t.Fatal("err wasn't received")
3294  	}
3295  }
3296  
3297  // htlcNotifierEvents is a function that generates a set of expected htlc
3298  // notifier evetns for each node in a three hop network with the dynamic
3299  // values provided. These functions take dynamic values so that changes to
3300  // external systems (such as our default timelock delta) do not break
3301  // these tests.
3302  type htlcNotifierEvents func(channels *clusterChannels, htlcID uint64,
3303  	ts time.Time, htlc *lnwire.UpdateAddHTLC,
3304  	hops []*hop.Payload,
3305  	preimage *lntypes.Preimage) ([]interface{}, []interface{}, []interface{})
3306  
3307  // TestHtlcNotifier tests the notifying of htlc events that are routed over a
3308  // three hop network. It sets up an Alice -> Bob -> Carol network and routes
3309  // payments from Alice -> Carol to test events from the perspective of a
3310  // sending (Alice), forwarding (Bob) and receiving (Carol) node. Test cases
3311  // are present for saduccessful and failed payments.
3312  func TestHtlcNotifier(t *testing.T) {
3313  	tests := []struct {
3314  		name string
3315  
3316  		// Options is a set of options to apply to the three hop
3317  		// network's servers.
3318  		options []serverOption
3319  
3320  		// expectedEvents is a function which returns an expected set
3321  		// of events for the test.
3322  		expectedEvents htlcNotifierEvents
3323  
3324  		// iterations is the number of times we will send a payment,
3325  		// this is used to send more than one payment to force non-
3326  		// zero htlc indexes to make sure we aren't just checking
3327  		// default values.
3328  		iterations int
3329  	}{
3330  		{
3331  			name:    "successful three hop payment",
3332  			options: nil,
3333  			expectedEvents: func(channels *clusterChannels,
3334  				htlcID uint64, ts time.Time,
3335  				htlc *lnwire.UpdateAddHTLC,
3336  				hops []*hop.Payload,
3337  				preimage *lntypes.Preimage) ([]interface{},
3338  				[]interface{}, []interface{}) {
3339  
3340  				return getThreeHopEvents(
3341  					channels, htlcID, ts, htlc, hops, nil, preimage,
3342  				)
3343  			},
3344  			iterations: 2,
3345  		},
3346  		{
3347  			name: "failed at forwarding link",
3348  			// Set a functional option which disables bob as a
3349  			// forwarding node to force a payment error.
3350  			options: []serverOption{
3351  				serverOptionRejectHtlc(false, true, false),
3352  			},
3353  			expectedEvents: func(channels *clusterChannels,
3354  				htlcID uint64, ts time.Time,
3355  				htlc *lnwire.UpdateAddHTLC,
3356  				hops []*hop.Payload,
3357  				preimage *lntypes.Preimage) ([]interface{},
3358  				[]interface{}, []interface{}) {
3359  
3360  				return getThreeHopEvents(
3361  					channels, htlcID, ts, htlc, hops,
3362  					&LinkError{
3363  						msg:           &lnwire.FailChannelDisabled{},
3364  						FailureDetail: OutgoingFailureForwardsDisabled,
3365  					},
3366  					preimage,
3367  				)
3368  			},
3369  			iterations: 1,
3370  		},
3371  	}
3372  
3373  	for _, test := range tests {
3374  		test := test
3375  
3376  		t.Run(test.name, func(t *testing.T) {
3377  			testHtcNotifier(
3378  				t, test.options, test.iterations,
3379  				test.expectedEvents,
3380  			)
3381  		})
3382  	}
3383  }
3384  
3385  // testHtcNotifier runs a htlc notifier test.
3386  func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int,
3387  	getEvents htlcNotifierEvents) {
3388  
3389  	t.Parallel()
3390  
3391  	// First, we'll create our traditional three hop
3392  	// network.
3393  	channels, _, err := createClusterChannels(
3394  		t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5,
3395  	)
3396  	require.NoError(t, err, "unable to create channel")
3397  
3398  	// Mock time so that all events are reported with a static timestamp.
3399  	now := time.Now()
3400  	mockTime := func() time.Time {
3401  		return now
3402  	}
3403  
3404  	// Create htlc notifiers for each server in the three hop network and
3405  	// start them.
3406  	aliceNotifier := NewHtlcNotifier(mockTime)
3407  	if err := aliceNotifier.Start(); err != nil {
3408  		t.Fatalf("could not start alice notifier")
3409  	}
3410  	t.Cleanup(func() {
3411  		if err := aliceNotifier.Stop(); err != nil {
3412  			t.Fatalf("failed to stop alice notifier: %v", err)
3413  		}
3414  	})
3415  
3416  	bobNotifier := NewHtlcNotifier(mockTime)
3417  	if err := bobNotifier.Start(); err != nil {
3418  		t.Fatalf("could not start bob notifier")
3419  	}
3420  	t.Cleanup(func() {
3421  		if err := bobNotifier.Stop(); err != nil {
3422  			t.Fatalf("failed to stop bob notifier: %v", err)
3423  		}
3424  	})
3425  
3426  	carolNotifier := NewHtlcNotifier(mockTime)
3427  	if err := carolNotifier.Start(); err != nil {
3428  		t.Fatalf("could not start carol notifier")
3429  	}
3430  	t.Cleanup(func() {
3431  		if err := carolNotifier.Stop(); err != nil {
3432  			t.Fatalf("failed to stop carol notifier: %v", err)
3433  		}
3434  	})
3435  
3436  	// Create a notifier server option which will set our htlc notifiers
3437  	// for the three hop network.
3438  	notifierOption := serverOptionWithHtlcNotifier(
3439  		aliceNotifier, bobNotifier, carolNotifier,
3440  	)
3441  
3442  	// Add the htlcNotifier option to any other options
3443  	// set in the test.
3444  	options := append(testOpts, notifierOption) // nolint:gocritic
3445  
3446  	n := newThreeHopNetwork(
3447  		t, channels.aliceToBob,
3448  		channels.bobToAlice, channels.bobToCarol,
3449  		channels.carolToBob, testStartingHeight,
3450  		options...,
3451  	)
3452  	if err := n.start(); err != nil {
3453  		t.Fatalf("unable to start three hop "+
3454  			"network: %v", err)
3455  	}
3456  	t.Cleanup(n.stop)
3457  
3458  	// Before we forward anything, subscribe to htlc events
3459  	// from each notifier.
3460  	aliceEvents, err := aliceNotifier.SubscribeHtlcEvents()
3461  	if err != nil {
3462  		t.Fatalf("could not subscribe to alice's"+
3463  			" events: %v", err)
3464  	}
3465  	t.Cleanup(aliceEvents.Cancel)
3466  
3467  	bobEvents, err := bobNotifier.SubscribeHtlcEvents()
3468  	if err != nil {
3469  		t.Fatalf("could not subscribe to bob's"+
3470  			" events: %v", err)
3471  	}
3472  	t.Cleanup(bobEvents.Cancel)
3473  
3474  	carolEvents, err := carolNotifier.SubscribeHtlcEvents()
3475  	if err != nil {
3476  		t.Fatalf("could not subscribe to carol's"+
3477  			" events: %v", err)
3478  	}
3479  	t.Cleanup(carolEvents.Cancel)
3480  
3481  	// Send multiple payments, as specified by the test to test incrementing
3482  	// of htlc ids.
3483  	for i := 0; i < iterations; i++ {
3484  		// We'll start off by making a payment from
3485  		// Alice -> Bob -> Carol. The preimage, generated
3486  		// by Carol's Invoice is expected in the Settle events
3487  		htlc, hops, preimage := n.sendThreeHopPayment(t)
3488  
3489  		alice, bob, carol := getEvents(
3490  			channels, uint64(i), now, htlc, hops, preimage,
3491  		)
3492  
3493  		checkHtlcEvents(t, aliceEvents.Updates(), alice)
3494  		checkHtlcEvents(t, bobEvents.Updates(), bob)
3495  		checkHtlcEvents(t, carolEvents.Updates(), carol)
3496  	}
3497  }
3498  
// checkHtlcEvents asserts that events yields exactly expectedEvents, in
// order and each within five seconds, compared with reflect.DeepEqual.
// Afterwards it fails the test if another event is already waiting on the
// channel.
func checkHtlcEvents(t *testing.T, events <-chan interface{},
	expectedEvents []interface{}) {

	t.Helper()

	const perEventTimeout = 5 * time.Second

	for i := range expectedEvents {
		want := expectedEvents[i]

		var got interface{}
		select {
		case got = <-events:
		case <-time.After(perEventTimeout):
			t.Fatalf("expected event: %v", want)
		}

		if !reflect.DeepEqual(got, want) {
			t.Fatalf("expected %v, got: %v", want, got)
		}
	}

	// Anything still queued at this point is an event we did not expect.
	select {
	case extra := <-events:
		t.Fatalf("unexpected event: %v", extra)
	default:
	}
}
3526  
3527  // sendThreeHopPayment is a helper function which sends a payment over
3528  // Alice -> Bob -> Carol in a three hop network and returns Alice's first htlc
3529  // and the remainder of the hops.
3530  func (n *threeHopNetwork) sendThreeHopPayment(t *testing.T) (*lnwire.UpdateAddHTLC,
3531  	[]*hop.Payload, *lntypes.Preimage) {
3532  
3533  	amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin)
3534  
3535  	htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight,
3536  		n.firstBobChannelLink, n.carolChannelLink)
3537  	blob, err := generateRoute(hops...)
3538  	if err != nil {
3539  		t.Fatal(err)
3540  	}
3541  	invoice, htlc, pid, err := generatePayment(
3542  		amount, htlcAmt, totalTimelock, blob,
3543  	)
3544  	if err != nil {
3545  		t.Fatal(err)
3546  	}
3547  
3548  	err = n.carolServer.registry.AddInvoice(
3549  		t.Context(), *invoice, htlc.PaymentHash,
3550  	)
3551  	require.NoError(t, err, "unable to add invoice in carol registry")
3552  
3553  	if err := n.aliceServer.htlcSwitch.SendHTLC(
3554  		n.firstBobChannelLink.ShortChanID(), pid, htlc,
3555  	); err != nil {
3556  		t.Fatalf("could not send htlc")
3557  	}
3558  
3559  	return htlc, hops, invoice.Terms.PaymentPreimage
3560  }
3561  
3562  // getThreeHopEvents gets the set of htlc events that we expect for a payment
3563  // from Alice -> Bob -> Carol. If a non-nil link error is provided, the set
3564  // of events will fail on Bob's outgoing link.
3565  func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
3566  	ts time.Time, htlc *lnwire.UpdateAddHTLC, hops []*hop.Payload,
3567  	linkError *LinkError,
3568  	preimage *lntypes.Preimage) ([]interface{}, []interface{}, []interface{}) {
3569  
3570  	aliceKey := HtlcKey{
3571  		IncomingCircuit: zeroCircuit,
3572  		OutgoingCircuit: models.CircuitKey{
3573  			ChanID: channels.aliceToBob.ShortChanID(),
3574  			HtlcID: htlcID,
3575  		},
3576  	}
3577  
3578  	// Alice always needs a forwarding event because she initiates the
3579  	// send.
3580  	aliceEvents := []interface{}{
3581  		&ForwardingEvent{
3582  			HtlcKey: aliceKey,
3583  			HtlcInfo: HtlcInfo{
3584  				OutgoingTimeLock: htlc.Expiry,
3585  				OutgoingAmt:      htlc.Amount,
3586  			},
3587  			HtlcEventType: HtlcEventTypeSend,
3588  			Timestamp:     ts,
3589  		},
3590  	}
3591  
3592  	bobKey := HtlcKey{
3593  		IncomingCircuit: models.CircuitKey{
3594  			ChanID: channels.bobToAlice.ShortChanID(),
3595  			HtlcID: htlcID,
3596  		},
3597  		OutgoingCircuit: models.CircuitKey{
3598  			ChanID: channels.bobToCarol.ShortChanID(),
3599  			HtlcID: htlcID,
3600  		},
3601  	}
3602  
3603  	bobInfo := HtlcInfo{
3604  		IncomingTimeLock: htlc.Expiry,
3605  		IncomingAmt:      htlc.Amount,
3606  		OutgoingTimeLock: hops[1].FwdInfo.OutgoingCTLV,
3607  		OutgoingAmt:      hops[1].FwdInfo.AmountToForward,
3608  	}
3609  
3610  	// If we expect the payment to fail, we add failures for alice and
3611  	// bob, and no events for carol because the payment never reaches her.
3612  	if linkError != nil {
3613  		aliceEvents = append(aliceEvents,
3614  			&ForwardingFailEvent{
3615  				HtlcKey:       aliceKey,
3616  				HtlcEventType: HtlcEventTypeSend,
3617  				Timestamp:     ts,
3618  			},
3619  		)
3620  
3621  		bobEvents := []interface{}{
3622  			&LinkFailEvent{
3623  				HtlcKey:       bobKey,
3624  				HtlcInfo:      bobInfo,
3625  				HtlcEventType: HtlcEventTypeForward,
3626  				LinkError:     linkError,
3627  				Incoming:      false,
3628  				Timestamp:     ts,
3629  			},
3630  			&FinalHtlcEvent{
3631  				CircuitKey: bobKey.IncomingCircuit,
3632  				Settled:    false,
3633  				Offchain:   true,
3634  				Timestamp:  ts,
3635  			},
3636  		}
3637  
3638  		return aliceEvents, bobEvents, nil
3639  	}
3640  
3641  	// If we want to get events for a successful payment, we add a settle
3642  	// for alice, a forward and settle for bob and a receive settle for
3643  	// carol.
3644  	aliceEvents = append(
3645  		aliceEvents,
3646  		&SettleEvent{
3647  			HtlcKey:       aliceKey,
3648  			Preimage:      *preimage,
3649  			HtlcEventType: HtlcEventTypeSend,
3650  			Timestamp:     ts,
3651  		},
3652  	)
3653  
3654  	bobEvents := []interface{}{
3655  		&ForwardingEvent{
3656  			HtlcKey:       bobKey,
3657  			HtlcInfo:      bobInfo,
3658  			HtlcEventType: HtlcEventTypeForward,
3659  			Timestamp:     ts,
3660  		},
3661  		&SettleEvent{
3662  			HtlcKey:       bobKey,
3663  			Preimage:      *preimage,
3664  			HtlcEventType: HtlcEventTypeForward,
3665  			Timestamp:     ts,
3666  		},
3667  		&FinalHtlcEvent{
3668  			CircuitKey: bobKey.IncomingCircuit,
3669  			Settled:    true,
3670  			Offchain:   true,
3671  			Timestamp:  ts,
3672  		},
3673  	}
3674  
3675  	carolEvents := []interface{}{
3676  		&SettleEvent{
3677  			HtlcKey: HtlcKey{
3678  				IncomingCircuit: models.CircuitKey{
3679  					ChanID: channels.carolToBob.ShortChanID(),
3680  					HtlcID: htlcID,
3681  				},
3682  				OutgoingCircuit: zeroCircuit,
3683  			},
3684  			Preimage:      *preimage,
3685  			HtlcEventType: HtlcEventTypeReceive,
3686  			Timestamp:     ts,
3687  		}, &FinalHtlcEvent{
3688  			CircuitKey: models.CircuitKey{
3689  				ChanID: channels.carolToBob.ShortChanID(),
3690  				HtlcID: htlcID,
3691  			},
3692  			Settled:   true,
3693  			Offchain:  true,
3694  			Timestamp: ts,
3695  		},
3696  	}
3697  
3698  	return aliceEvents, bobEvents, carolEvents
3699  }
3700  
3701  type mockForwardInterceptor struct {
3702  	t *testing.T
3703  
3704  	interceptedChan chan InterceptedPacket
3705  }
3706  
3707  func (m *mockForwardInterceptor) InterceptForwardHtlc(
3708  	intercepted InterceptedPacket) error {
3709  
3710  	m.interceptedChan <- intercepted
3711  
3712  	return nil
3713  }
3714  
3715  func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket {
3716  	m.t.Helper()
3717  
3718  	select {
3719  	case p := <-m.interceptedChan:
3720  		return p
3721  
3722  	case <-time.After(time.Second):
3723  		require.Fail(m.t, "timeout")
3724  
3725  		return InterceptedPacket{}
3726  	}
3727  }
3728  
3729  func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
3730  	if s.circuits.NumPending() != pending {
3731  		t.Fatalf("wrong amount of half circuits, expected %v but "+
3732  			"got %v", pending, s.circuits.NumPending())
3733  	}
3734  	if s.circuits.NumOpen() != opened {
3735  		t.Fatalf("wrong amount of circuits, expected %v but got %v",
3736  			opened, s.circuits.NumOpen())
3737  	}
3738  }
3739  
3740  func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
3741  	expectReceive bool) *htlcPacket {
3742  
3743  	// Pull packet from targetLink link.
3744  	select {
3745  	case packet := <-targetLink.packets:
3746  		if !expectReceive {
3747  			t.Fatal("forward was intercepted, shouldn't land at bob link")
3748  		} else if err := targetLink.completeCircuit(packet); err != nil {
3749  			t.Fatalf("unable to complete payment circuit: %v", err)
3750  		}
3751  
3752  		return packet
3753  
3754  	case <-time.After(time.Second):
3755  		if expectReceive {
3756  			t.Fatal("request was not propagated to destination")
3757  		}
3758  	}
3759  
3760  	return nil
3761  }
3762  
3763  func assertOutgoingLinkReceiveIntercepted(t *testing.T,
3764  	targetLink *mockChannelLink) {
3765  
3766  	t.Helper()
3767  
3768  	select {
3769  	case <-targetLink.packets:
3770  	case <-time.After(time.Second):
3771  		t.Fatal("request was not propagated to destination")
3772  	}
3773  }
3774  
3775  type interceptableSwitchTestContext struct {
3776  	t *testing.T
3777  
3778  	preimage           [sha256.Size]byte
3779  	rhash              [32]byte
3780  	onionBlob          [1366]byte
3781  	incomingHtlcID     uint64
3782  	cltvRejectDelta    uint32
3783  	cltvInterceptDelta uint32
3784  
3785  	forwardInterceptor *mockForwardInterceptor
3786  	aliceChannelLink   *mockChannelLink
3787  	bobChannelLink     *mockChannelLink
3788  	s                  *Switch
3789  }
3790  
3791  func newInterceptableSwitchTestContext(
3792  	t *testing.T) *interceptableSwitchTestContext {
3793  
3794  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
3795  
3796  	alicePeer, err := newMockServer(
3797  		t, "alice", testStartingHeight, nil, testDefaultDelta,
3798  	)
3799  	require.NoError(t, err, "unable to create alice server")
3800  	bobPeer, err := newMockServer(
3801  		t, "bob", testStartingHeight, nil, testDefaultDelta,
3802  	)
3803  	require.NoError(t, err, "unable to create bob server")
3804  
3805  	tempPath := t.TempDir()
3806  
3807  	cdb := channeldb.OpenForTesting(t, tempPath)
3808  
3809  	s, err := initSwitchWithDB(testStartingHeight, cdb)
3810  	require.NoError(t, err, "unable to init switch")
3811  	if err := s.Start(); err != nil {
3812  		t.Fatalf("unable to start switch: %v", err)
3813  	}
3814  
3815  	aliceChannelLink := newMockChannelLink(
3816  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
3817  		false, false,
3818  	)
3819  	bobChannelLink := newMockChannelLink(
3820  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
3821  		false,
3822  	)
3823  	if err := s.AddLink(aliceChannelLink); err != nil {
3824  		t.Fatalf("unable to add alice link: %v", err)
3825  	}
3826  	if err := s.AddLink(bobChannelLink); err != nil {
3827  		t.Fatalf("unable to add bob link: %v", err)
3828  	}
3829  
3830  	preimage := [sha256.Size]byte{1}
3831  
3832  	ctx := &interceptableSwitchTestContext{
3833  		t:                  t,
3834  		preimage:           preimage,
3835  		rhash:              sha256.Sum256(preimage[:]),
3836  		onionBlob:          [1366]byte{4, 5, 6},
3837  		incomingHtlcID:     uint64(0),
3838  		cltvRejectDelta:    10,
3839  		cltvInterceptDelta: 13,
3840  		forwardInterceptor: &mockForwardInterceptor{
3841  			t:               t,
3842  			interceptedChan: make(chan InterceptedPacket),
3843  		},
3844  		aliceChannelLink: aliceChannelLink,
3845  		bobChannelLink:   bobChannelLink,
3846  		s:                s,
3847  	}
3848  
3849  	return ctx
3850  }
3851  
3852  func (c *interceptableSwitchTestContext) createTestPacket() *htlcPacket {
3853  	c.incomingHtlcID++
3854  
3855  	return &htlcPacket{
3856  		incomingChanID:  c.aliceChannelLink.ShortChanID(),
3857  		incomingHTLCID:  c.incomingHtlcID,
3858  		incomingTimeout: testStartingHeight + c.cltvInterceptDelta + 1,
3859  		outgoingChanID:  c.bobChannelLink.ShortChanID(),
3860  		obfuscator:      NewMockObfuscator(),
3861  		htlc: &lnwire.UpdateAddHTLC{
3862  			PaymentHash: c.rhash,
3863  			Amount:      1,
3864  			OnionBlob:   c.onionBlob,
3865  		},
3866  	}
3867  }
3868  
3869  func (c *interceptableSwitchTestContext) finish() {
3870  	if err := c.s.Stop(); err != nil {
3871  		c.t.Fatal(err)
3872  	}
3873  }
3874  
3875  func (c *interceptableSwitchTestContext) createSettlePacket(
3876  	outgoingHTLCID uint64) *htlcPacket {
3877  
3878  	return &htlcPacket{
3879  		outgoingChanID: c.bobChannelLink.ShortChanID(),
3880  		outgoingHTLCID: outgoingHTLCID,
3881  		amount:         1,
3882  		htlc: &lnwire.UpdateFulfillHTLC{
3883  			PaymentPreimage: c.preimage,
3884  		},
3885  	}
3886  }
3887  
3888  func TestSwitchHoldForward(t *testing.T) {
3889  	t.Parallel()
3890  
3891  	c := newInterceptableSwitchTestContext(t)
3892  	defer c.finish()
3893  
3894  	notifier := &mock.ChainNotifier{
3895  		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
3896  	}
3897  	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
3898  
3899  	switchForwardInterceptor, err := NewInterceptableSwitch(
3900  		&InterceptableSwitchConfig{
3901  			Switch:             c.s,
3902  			CltvRejectDelta:    c.cltvRejectDelta,
3903  			CltvInterceptDelta: c.cltvInterceptDelta,
3904  			Notifier:           notifier,
3905  		},
3906  	)
3907  	require.NoError(t, err)
3908  	require.NoError(t, switchForwardInterceptor.Start())
3909  
3910  	switchForwardInterceptor.SetInterceptor(c.forwardInterceptor.InterceptForwardHtlc)
3911  	linkQuit := make(chan struct{})
3912  
3913  	// Test a forward that expires too soon.
3914  	packet := c.createTestPacket()
3915  	packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1
3916  
3917  	err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
3918  	require.NoError(t, err, "can't forward htlc packet")
3919  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
3920  	assertOutgoingLinkReceiveIntercepted(t, c.aliceChannelLink)
3921  	assertNumCircuits(t, c.s, 0, 0)
3922  
3923  	// Test a forward that expires too soon and can't be failed.
3924  	packet = c.createTestPacket()
3925  	packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1
3926  
3927  	// Simulate an error during the composition of the failure message.
3928  	currentCallback := c.s.cfg.FetchLastChannelUpdate
3929  	c.s.cfg.FetchLastChannelUpdate = func(
3930  		lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) {
3931  
3932  		return nil, errors.New("cannot fetch update")
3933  	}
3934  
3935  	err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
3936  	require.NoError(t, err, "can't forward htlc packet")
3937  	receivedPkt := assertOutgoingLinkReceive(t, c.bobChannelLink, true)
3938  	assertNumCircuits(t, c.s, 1, 1)
3939  
3940  	require.NoError(t, switchForwardInterceptor.ForwardPackets(
3941  		linkQuit, false,
3942  		c.createSettlePacket(receivedPkt.outgoingHTLCID),
3943  	))
3944  
3945  	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
3946  	assertNumCircuits(t, c.s, 0, 0)
3947  
3948  	c.s.cfg.FetchLastChannelUpdate = currentCallback
3949  
3950  	// Test resume a hold forward.
3951  	assertNumCircuits(t, c.s, 0, 0)
3952  	err = switchForwardInterceptor.ForwardPackets(
3953  		linkQuit, false, c.createTestPacket(),
3954  	)
3955  	require.NoError(t, err)
3956  
3957  	assertNumCircuits(t, c.s, 0, 0)
3958  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
3959  
3960  	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
3961  		Action: FwdActionResume,
3962  		Key:    c.forwardInterceptor.getIntercepted().IncomingCircuit,
3963  	}))
3964  	receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true)
3965  	assertNumCircuits(t, c.s, 1, 1)
3966  
3967  	// settling the htlc to close the circuit.
3968  	err = switchForwardInterceptor.ForwardPackets(
3969  		linkQuit, false,
3970  		c.createSettlePacket(receivedPkt.outgoingHTLCID),
3971  	)
3972  	require.NoError(t, err)
3973  
3974  	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
3975  	assertNumCircuits(t, c.s, 0, 0)
3976  
3977  	// Test resume a hold forward after disconnection.
3978  	require.NoError(t, switchForwardInterceptor.ForwardPackets(
3979  		linkQuit, false, c.createTestPacket(),
3980  	))
3981  
3982  	// Wait until the packet is offered to the interceptor.
3983  	_ = c.forwardInterceptor.getIntercepted()
3984  
3985  	// No forward expected yet.
3986  	assertNumCircuits(t, c.s, 0, 0)
3987  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
3988  
3989  	// Disconnect should resume the forwarding.
3990  	switchForwardInterceptor.SetInterceptor(nil)
3991  
3992  	receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true)
3993  	assertNumCircuits(t, c.s, 1, 1)
3994  
3995  	// Settle the htlc to close the circuit.
3996  	require.NoError(t, switchForwardInterceptor.ForwardPackets(
3997  		linkQuit, false,
3998  		c.createSettlePacket(receivedPkt.outgoingHTLCID),
3999  	))
4000  
4001  	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
4002  	assertNumCircuits(t, c.s, 0, 0)
4003  
4004  	// Test failing a hold forward
4005  	switchForwardInterceptor.SetInterceptor(
4006  		c.forwardInterceptor.InterceptForwardHtlc,
4007  	)
4008  
4009  	require.NoError(t, switchForwardInterceptor.ForwardPackets(
4010  		linkQuit, false, c.createTestPacket(),
4011  	))
4012  	assertNumCircuits(t, c.s, 0, 0)
4013  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4014  
4015  	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
4016  		Action:      FwdActionFail,
4017  		Key:         c.forwardInterceptor.getIntercepted().IncomingCircuit,
4018  		FailureCode: lnwire.CodeTemporaryChannelFailure,
4019  	}))
4020  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4021  	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
4022  	assertNumCircuits(t, c.s, 0, 0)
4023  
4024  	// Test failing a hold forward with a failure message.
4025  	require.NoError(t,
4026  		switchForwardInterceptor.ForwardPackets(
4027  			linkQuit, false, c.createTestPacket(),
4028  		),
4029  	)
4030  	assertNumCircuits(t, c.s, 0, 0)
4031  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4032  
4033  	reason := lnwire.OpaqueReason([]byte{1, 2, 3})
4034  	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
4035  		Action:         FwdActionFail,
4036  		Key:            c.forwardInterceptor.getIntercepted().IncomingCircuit,
4037  		FailureMessage: reason,
4038  	}))
4039  
4040  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4041  	packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
4042  
4043  	require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason)
4044  
4045  	assertNumCircuits(t, c.s, 0, 0)
4046  
4047  	// Test failing a hold forward with a malformed htlc failure.
4048  	err = switchForwardInterceptor.ForwardPackets(
4049  		linkQuit, false, c.createTestPacket(),
4050  	)
4051  	require.NoError(t, err)
4052  
4053  	assertNumCircuits(t, c.s, 0, 0)
4054  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4055  
4056  	code := lnwire.CodeInvalidOnionKey
4057  	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
4058  		Action:      FwdActionFail,
4059  		Key:         c.forwardInterceptor.getIntercepted().IncomingCircuit,
4060  		FailureCode: code,
4061  	}))
4062  
4063  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4064  	packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
4065  	failPacket := packet.htlc.(*lnwire.UpdateFailHTLC)
4066  
4067  	shaOnionBlob := sha256.Sum256(c.onionBlob[:])
4068  	expectedFailure := &lnwire.FailInvalidOnionKey{
4069  		OnionSHA256: shaOnionBlob,
4070  	}
4071  
4072  	fwdErr, err := newMockDeobfuscator().DecryptError(failPacket.Reason)
4073  	require.NoError(t, err)
4074  	require.Equal(t, expectedFailure, fwdErr.WireMessage())
4075  
4076  	assertNumCircuits(t, c.s, 0, 0)
4077  
4078  	// Test settling a hold forward
4079  	require.NoError(t, switchForwardInterceptor.ForwardPackets(
4080  		linkQuit, false, c.createTestPacket(),
4081  	))
4082  	assertNumCircuits(t, c.s, 0, 0)
4083  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4084  
4085  	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
4086  		Key:      c.forwardInterceptor.getIntercepted().IncomingCircuit,
4087  		Action:   FwdActionSettle,
4088  		Preimage: c.preimage,
4089  	}))
4090  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4091  	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
4092  	assertNumCircuits(t, c.s, 0, 0)
4093  
4094  	require.NoError(t, switchForwardInterceptor.Stop())
4095  
4096  	// Test always-on interception.
4097  	notifier = &mock.ChainNotifier{
4098  		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
4099  	}
4100  	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
4101  
4102  	switchForwardInterceptor, err = NewInterceptableSwitch(
4103  		&InterceptableSwitchConfig{
4104  			Switch:             c.s,
4105  			CltvRejectDelta:    c.cltvRejectDelta,
4106  			CltvInterceptDelta: c.cltvInterceptDelta,
4107  			RequireInterceptor: true,
4108  			Notifier:           notifier,
4109  		},
4110  	)
4111  	require.NoError(t, err)
4112  	require.NoError(t, switchForwardInterceptor.Start())
4113  
4114  	// Forward a fresh packet. It is expected to be failed immediately,
4115  	// because there is no interceptor registered.
4116  	require.NoError(t, switchForwardInterceptor.ForwardPackets(
4117  		linkQuit, false, c.createTestPacket(),
4118  	))
4119  
4120  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4121  	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
4122  	assertNumCircuits(t, c.s, 0, 0)
4123  
4124  	// Forward a replayed packet. It is expected to be held until the
4125  	// interceptor connects. To continue the test, it needs to be ran in a
4126  	// goroutine.
4127  	errChan := make(chan error)
4128  	go func() {
4129  		errChan <- switchForwardInterceptor.ForwardPackets(
4130  			linkQuit, true, c.createTestPacket(),
4131  		)
4132  	}()
4133  
4134  	// Assert that nothing is forward to the switch.
4135  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4136  	assertNumCircuits(t, c.s, 0, 0)
4137  
4138  	// Register an interceptor.
4139  	switchForwardInterceptor.SetInterceptor(
4140  		c.forwardInterceptor.InterceptForwardHtlc,
4141  	)
4142  
4143  	// Expect the ForwardPackets call to unblock.
4144  	require.NoError(t, <-errChan)
4145  
4146  	// Now expect the queued packet to come through.
4147  	c.forwardInterceptor.getIntercepted()
4148  
4149  	// Disconnect and reconnect interceptor.
4150  	switchForwardInterceptor.SetInterceptor(nil)
4151  	switchForwardInterceptor.SetInterceptor(
4152  		c.forwardInterceptor.InterceptForwardHtlc,
4153  	)
4154  
4155  	// A replay of the held packet is expected.
4156  	intercepted := c.forwardInterceptor.getIntercepted()
4157  
4158  	// Settle the packet.
4159  	require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
4160  		Key:      intercepted.IncomingCircuit,
4161  		Action:   FwdActionSettle,
4162  		Preimage: c.preimage,
4163  	}))
4164  	assertOutgoingLinkReceive(t, c.bobChannelLink, false)
4165  	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
4166  	assertNumCircuits(t, c.s, 0, 0)
4167  
4168  	require.NoError(t, switchForwardInterceptor.Stop())
4169  
4170  	select {
4171  	case <-c.forwardInterceptor.interceptedChan:
4172  		require.Fail(t, "unexpected interception")
4173  
4174  	default:
4175  	}
4176  }
4177  
4178  func TestInterceptableSwitchWatchDog(t *testing.T) {
4179  	t.Parallel()
4180  
4181  	c := newInterceptableSwitchTestContext(t)
4182  	defer c.finish()
4183  
4184  	// Start interceptable switch.
4185  	notifier := &mock.ChainNotifier{
4186  		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
4187  	}
4188  	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
4189  
4190  	switchForwardInterceptor, err := NewInterceptableSwitch(
4191  		&InterceptableSwitchConfig{
4192  			Switch:             c.s,
4193  			CltvRejectDelta:    c.cltvRejectDelta,
4194  			CltvInterceptDelta: c.cltvInterceptDelta,
4195  			Notifier:           notifier,
4196  		},
4197  	)
4198  	require.NoError(t, err)
4199  	require.NoError(t, switchForwardInterceptor.Start())
4200  
4201  	// Set interceptor.
4202  	switchForwardInterceptor.SetInterceptor(
4203  		c.forwardInterceptor.InterceptForwardHtlc,
4204  	)
4205  
4206  	// Receive a packet.
4207  	linkQuit := make(chan struct{})
4208  
4209  	packet := c.createTestPacket()
4210  
4211  	err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
4212  	require.NoError(t, err, "can't forward htlc packet")
4213  
4214  	// Intercept the packet.
4215  	intercepted := c.forwardInterceptor.getIntercepted()
4216  
4217  	require.Equal(t,
4218  		int32(packet.incomingTimeout-c.cltvRejectDelta),
4219  		intercepted.AutoFailHeight,
4220  	)
4221  
4222  	// Htlc expires before a resolution from the interceptor.
4223  	notifier.EpochChan <- &chainntnfs.BlockEpoch{
4224  		Height: int32(packet.incomingTimeout) -
4225  			int32(c.cltvRejectDelta),
4226  	}
4227  
4228  	// Expect the htlc to be failed back.
4229  	assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
4230  
4231  	// It is too late now to resolve. Expect an error.
4232  	require.Error(t, switchForwardInterceptor.Resolve(&FwdResolution{
4233  		Action:   FwdActionSettle,
4234  		Key:      intercepted.IncomingCircuit,
4235  		Preimage: c.preimage,
4236  	}))
4237  }
4238  
4239  // TestSwitchDustForwarding tests that the switch properly fails HTLC's which
4240  // have incoming or outgoing links that breach their fee thresholds.
4241  func TestSwitchDustForwarding(t *testing.T) {
4242  	t.Parallel()
4243  
4244  	// We'll create a three-hop network:
4245  	// - Alice has a dust limit of 200sats with Bob
4246  	// - Bob has a dust limit of 800sats with Alice
4247  	// - Bob has a dust limit of 200sats with Carol
4248  	// - Carol has a dust limit of 800sats with Bob
4249  	channels, _, err := createClusterChannels(
4250  		t, btcutil.SatoshiPerBitcoin, btcutil.SatoshiPerBitcoin,
4251  	)
4252  	require.NoError(t, err)
4253  
4254  	n := newThreeHopNetwork(
4255  		t, channels.aliceToBob, channels.bobToAlice,
4256  		channels.bobToCarol, channels.carolToBob, testStartingHeight,
4257  	)
4258  	err = n.start()
4259  	require.NoError(t, err)
4260  
4261  	// We'll also put Alice and Bob into hodl.ExitSettle mode, such that
4262  	// they won't settle incoming exit-hop HTLC's automatically.
4263  	n.aliceChannelLink.cfg.HodlMask = hodl.ExitSettle.Mask()
4264  	n.firstBobChannelLink.cfg.HodlMask = hodl.ExitSettle.Mask()
4265  
4266  	// We'll test that once the default threshold is exceeded on the
4267  	// Alice -> Bob channel, either side's calls to SendHTLC will fail.
4268  	numHTLCs := maxInflightHtlcs
4269  	aliceAttemptID, bobAttemptID := numHTLCs, numHTLCs
4270  	amt := lnwire.NewMSatFromSatoshis(700)
4271  	aliceBobFirstHop := n.aliceChannelLink.ShortChanID()
4272  
4273  	// We decreased the max number of inflight HTLCs therefore we also need
4274  	// do decrease the max fee exposure.
4275  	maxFeeExposure := lnwire.NewMSatFromSatoshis(74500)
4276  	n.aliceChannelLink.cfg.MaxFeeExposure = maxFeeExposure
4277  	n.firstBobChannelLink.cfg.MaxFeeExposure = maxFeeExposure
4278  
4279  	// Alice will send 50 HTLC's of 700sats. Bob will also send 50 HTLC's
4280  	// of 700sats.
4281  	sendDustHtlcs(t, n, true, amt, aliceBobFirstHop, numHTLCs)
4282  	sendDustHtlcs(t, n, false, amt, aliceBobFirstHop, numHTLCs)
4283  
4284  	// Generate the parameters needed for Bob to send another dust HTLC.
4285  	_, timelock, hops := generateHops(
4286  		amt, testStartingHeight, n.aliceChannelLink,
4287  	)
4288  
4289  	blob, err := generateRoute(hops...)
4290  	require.NoError(t, err)
4291  
4292  	// Assert that if Bob sends a dust HTLC it will fail.
4293  	failingPreimage := lntypes.Preimage{0, 0, 3}
4294  	failingHash := failingPreimage.Hash()
4295  	failingHtlc := &lnwire.UpdateAddHTLC{
4296  		PaymentHash: failingHash,
4297  		Amount:      amt,
4298  		Expiry:      timelock,
4299  		OnionBlob:   blob,
4300  	}
4301  
4302  	// This is the expected dust without taking the commitfee into account.
4303  	expectedDust := maxInflightHtlcs * 2 * amt
4304  
4305  	assertAlmostDust := func(link *channelLink, mbox MailBox,
4306  		whoseCommit lntypes.ChannelParty) {
4307  
4308  		err := wait.NoError(func() error {
4309  			linkDust := link.getDustSum(
4310  				whoseCommit, fn.None[chainfee.SatPerKWeight](),
4311  			)
4312  			localMailDust, remoteMailDust := mbox.DustPackets()
4313  
4314  			totalDust := linkDust
4315  			if whoseCommit.IsRemote() {
4316  				totalDust += remoteMailDust
4317  			} else {
4318  				totalDust += localMailDust
4319  			}
4320  
4321  			if totalDust == expectedDust {
4322  				return nil
4323  			}
4324  
4325  			return fmt.Errorf("got totalDust=%v, expectedDust=%v",
4326  				totalDust, expectedDust)
4327  		}, 15*time.Second)
4328  		require.NoError(t, err, "timeout checking dust")
4329  	}
4330  
4331  	// Wait until Bob is almost at the fee threshold.
4332  	bobMbox := n.bobServer.htlcSwitch.mailOrchestrator.GetOrCreateMailBox(
4333  		n.firstBobChannelLink.ChanID(),
4334  		n.firstBobChannelLink.ShortChanID(),
4335  	)
4336  	assertAlmostDust(n.firstBobChannelLink, bobMbox, lntypes.Local)
4337  
4338  	// Sending one more HTLC should fail. SendHTLC won't error, but the
4339  	// HTLC should be failed backwards. When sending we only check for the
4340  	// dust amount without the commitment fee. When the HTLC is added to the
4341  	// commitment state (link) we also take into account the commitment fee
4342  	// and with a fee of 6000 sat/kw and a commitment size of 724 (non
4343  	// anchor channel) we are overexposed in fees (maxFeeExposure) that's
4344  	// why the HTLC is failed back.
4345  	err = n.bobServer.htlcSwitch.SendHTLC(
4346  		aliceBobFirstHop, uint64(bobAttemptID), failingHtlc,
4347  	)
4348  	require.Nil(t, err)
4349  
4350  	// Use the network result store to ensure the HTLC was failed
4351  	// backwards.
4352  	bobResultChan, err := n.bobServer.htlcSwitch.GetAttemptResult(
4353  		uint64(bobAttemptID), failingHash, newMockDeobfuscator(),
4354  	)
4355  	require.NoError(t, err)
4356  
4357  	result, ok := <-bobResultChan
4358  	require.True(t, ok)
4359  	assertFailureCode(
4360  		t, result.Error, lnwire.CodeTemporaryChannelFailure,
4361  	)
4362  
4363  	bobAttemptID++
4364  
4365  	// Generate the parameters needed for bob to send a non-dust HTLC.
4366  	nondustAmt := lnwire.NewMSatFromSatoshis(10_000)
4367  	_, _, hops = generateHops(
4368  		nondustAmt, testStartingHeight, n.aliceChannelLink,
4369  	)
4370  
4371  	blob, err = generateRoute(hops...)
4372  	require.NoError(t, err)
4373  
4374  	// Now attempt to send an HTLC above Bob's dust limit. Even though this
4375  	// is not a dust HTLC, it should fail because the increase in weight
4376  	// pushes us over the threshold.
4377  	nondustPreimage := lntypes.Preimage{0, 0, 4}
4378  	nondustHash := nondustPreimage.Hash()
4379  	nondustHtlc := &lnwire.UpdateAddHTLC{
4380  		PaymentHash: nondustHash,
4381  		Amount:      nondustAmt,
4382  		Expiry:      timelock,
4383  		OnionBlob:   blob,
4384  	}
4385  
4386  	err = n.bobServer.htlcSwitch.SendHTLC(
4387  		aliceBobFirstHop, uint64(bobAttemptID), nondustHtlc,
4388  	)
4389  	require.NoError(t, err)
4390  	assertAlmostDust(n.firstBobChannelLink, bobMbox, lntypes.Local)
4391  
4392  	// Check that the HTLC failed.
4393  	bobResultChan, err = n.bobServer.htlcSwitch.GetAttemptResult(
4394  		uint64(bobAttemptID), nondustHash, newMockDeobfuscator(),
4395  	)
4396  	require.NoError(t, err)
4397  
4398  	result, ok = <-bobResultChan
4399  	require.True(t, ok)
4400  	assertFailureCode(
4401  		t, result.Error, lnwire.CodeTemporaryChannelFailure,
4402  	)
4403  
4404  	// Introduce Carol into the mix and assert that sending a multi-hop
4405  	// dust HTLC to Alice will fail. Bob should fail back the HTLC with a
4406  	// temporary channel failure.
4407  	carolAmt, carolTimelock, carolHops := generateHops(
4408  		amt, testStartingHeight, n.secondBobChannelLink,
4409  		n.aliceChannelLink,
4410  	)
4411  
4412  	carolBlob, err := generateRoute(carolHops...)
4413  	require.NoError(t, err)
4414  
4415  	carolPreimage := lntypes.Preimage{0, 0, 5}
4416  	carolHash := carolPreimage.Hash()
4417  	carolHtlc := &lnwire.UpdateAddHTLC{
4418  		PaymentHash: carolHash,
4419  		Amount:      carolAmt,
4420  		Expiry:      carolTimelock,
4421  		OnionBlob:   carolBlob,
4422  	}
4423  
4424  	// Initialize Carol's attempt ID.
4425  	carolAttemptID := 0
4426  
4427  	err = n.carolServer.htlcSwitch.SendHTLC(
4428  		n.carolChannelLink.ShortChanID(), uint64(carolAttemptID),
4429  		carolHtlc,
4430  	)
4431  	require.NoError(t, err)
4432  
4433  	carolResultChan, err := n.carolServer.htlcSwitch.GetAttemptResult(
4434  		uint64(carolAttemptID), carolHash, newMockDeobfuscator(),
4435  	)
4436  	require.NoError(t, err)
4437  
4438  	result, ok = <-carolResultChan
4439  	require.True(t, ok)
4440  	assertFailureCode(
4441  		t, result.Error, lnwire.CodeTemporaryChannelFailure,
4442  	)
4443  
4444  	// Send an HTLC from Alice to Carol and assert that it gets failed.
4445  	htlcAmt, totalTimelock, aliceHops := generateHops(
4446  		amt, testStartingHeight, n.firstBobChannelLink,
4447  		n.carolChannelLink,
4448  	)
4449  
4450  	blob, err = generateRoute(aliceHops...)
4451  	require.NoError(t, err)
4452  
4453  	aliceMultihopPreimage := lntypes.Preimage{0, 0, 6}
4454  	aliceMultihopHash := aliceMultihopPreimage.Hash()
4455  	aliceMultihopHtlc := &lnwire.UpdateAddHTLC{
4456  		PaymentHash: aliceMultihopHash,
4457  		Amount:      htlcAmt,
4458  		Expiry:      totalTimelock,
4459  		OnionBlob:   blob,
4460  	}
4461  
4462  	// Wait until Alice's expected dust for the remote commitment is just
4463  	// under the fee threshold.
4464  	aliceOrch := n.aliceServer.htlcSwitch.mailOrchestrator
4465  	aliceMbox := aliceOrch.GetOrCreateMailBox(
4466  		n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(),
4467  	)
4468  	assertAlmostDust(n.aliceChannelLink, aliceMbox, lntypes.Remote)
4469  
4470  	err = n.aliceServer.htlcSwitch.SendHTLC(
4471  		n.aliceChannelLink.ShortChanID(), uint64(aliceAttemptID),
4472  		aliceMultihopHtlc,
4473  	)
4474  	require.Nil(t, err)
4475  
4476  	aliceResultChan, err := n.aliceServer.htlcSwitch.GetAttemptResult(
4477  		uint64(aliceAttemptID), aliceMultihopHash,
4478  		newMockDeobfuscator(),
4479  	)
4480  	require.NoError(t, err)
4481  
4482  	result, ok = <-aliceResultChan
4483  	require.True(t, ok)
4484  	assertFailureCode(
4485  		t, result.Error, lnwire.CodeTemporaryChannelFailure,
4486  	)
4487  
4488  	// Check that there are numHTLCs circuits open for both Alice and Bob.
4489  	require.Equal(t, numHTLCs, n.aliceServer.htlcSwitch.circuits.NumOpen())
4490  	require.Equal(t, numHTLCs, n.bobServer.htlcSwitch.circuits.NumOpen())
4491  }
4492  
4493  // sendDustHtlcs is a helper function used to send many dust HTLC's to test the
4494  // Switch's channel-max-fee-exposure logic. It takes a boolean denoting whether
4495  // or not Alice is the sender.
4496  func sendDustHtlcs(t *testing.T, n *threeHopNetwork, alice bool,
4497  	amt lnwire.MilliSatoshi, sid lnwire.ShortChannelID, numHTLCs int) {
4498  
4499  	t.Helper()
4500  
4501  	// Extract the destination into a variable. If alice is the sender, the
4502  	// destination is Bob.
4503  	destLink := n.aliceChannelLink
4504  	if alice {
4505  		destLink = n.firstBobChannelLink
4506  	}
4507  
4508  	// Create hops that will be used in the onion payload.
4509  	htlcAmt, totalTimelock, hops := generateHops(
4510  		amt, testStartingHeight, destLink,
4511  	)
4512  
4513  	// Convert the hops to a blob that will be put in the Add message.
4514  	blob, err := generateRoute(hops...)
4515  	require.NoError(t, err)
4516  
4517  	// Create a slice to store the preimages.
4518  	preimages := make([]lntypes.Preimage, numHTLCs)
4519  
4520  	// Initialize the attempt ID used in SendHTLC calls.
4521  	attemptID := uint64(0)
4522  
4523  	// Deterministically generate preimages. Avoid the all-zeroes preimage
4524  	// because that will be rejected by the database. We'll use a different
4525  	// third byte for Alice and Bob.
4526  	endByte := byte(2)
4527  	if alice {
4528  		endByte = byte(3)
4529  	}
4530  
4531  	for i := 0; i < numHTLCs; i++ {
4532  		preimages[i] = lntypes.Preimage{byte(i >> 8), byte(i), endByte}
4533  	}
4534  
4535  	sendingSwitch := n.bobServer.htlcSwitch
4536  	if alice {
4537  		sendingSwitch = n.aliceServer.htlcSwitch
4538  	}
4539  
4540  	// Call SendHTLC in a loop for numHTLCs.
4541  	for i := 0; i < numHTLCs; i++ {
4542  		// Construct the htlc packet.
4543  		hash := preimages[i].Hash()
4544  
4545  		htlc := &lnwire.UpdateAddHTLC{
4546  			PaymentHash: hash,
4547  			Amount:      htlcAmt,
4548  			Expiry:      totalTimelock,
4549  			OnionBlob:   blob,
4550  		}
4551  
4552  		for {
4553  			// It may be the case that the fee threshold is hit
4554  			// before all numHTLCs*2 HTLC's are sent due to double
4555  			// counting. Get around this by continuing to send
4556  			// until successful.
4557  			err = sendingSwitch.SendHTLC(sid, attemptID, htlc)
4558  			if err == nil {
4559  				break
4560  			}
4561  		}
4562  
4563  		attemptID++
4564  	}
4565  }
4566  
4567  // TestSwitchMailboxDust tests that the switch takes into account the mailbox
4568  // dust when evaluating the fee threshold. The mockChannelLink does not have
4569  // channel state, so this only tests the switch-mailbox interaction.
4570  func TestSwitchMailboxDust(t *testing.T) {
4571  	t.Parallel()
4572  
4573  	alicePeer, err := newMockServer(
4574  		t, "alice", testStartingHeight, nil, testDefaultDelta,
4575  	)
4576  	require.NoError(t, err)
4577  
4578  	bobPeer, err := newMockServer(
4579  		t, "bob", testStartingHeight, nil, testDefaultDelta,
4580  	)
4581  	require.NoError(t, err)
4582  
4583  	carolPeer, err := newMockServer(
4584  		t, "carol", testStartingHeight, nil, testDefaultDelta,
4585  	)
4586  	require.NoError(t, err)
4587  
4588  	s, err := initSwitchWithTempDB(t, testStartingHeight)
4589  	require.NoError(t, err)
4590  	err = s.Start()
4591  	require.NoError(t, err)
4592  	defer func() {
4593  		_ = s.Stop()
4594  	}()
4595  
4596  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
4597  
4598  	chanID3, carolChanID := genID()
4599  
4600  	aliceLink := newMockChannelLink(
4601  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
4602  		false, false,
4603  	)
4604  	err = s.AddLink(aliceLink)
4605  	require.NoError(t, err)
4606  
4607  	bobLink := newMockChannelLink(
4608  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
4609  		false,
4610  	)
4611  	err = s.AddLink(bobLink)
4612  	require.NoError(t, err)
4613  
4614  	carolLink := newMockChannelLink(
4615  		s, chanID3, carolChanID, emptyScid, carolPeer, true, false,
4616  		false, false,
4617  	)
4618  	err = s.AddLink(carolLink)
4619  	require.NoError(t, err)
4620  
4621  	// mockChannelLink sets the local and remote dust limits of the mailbox
4622  	// to 400 satoshis and the feerate to 0. We'll fill the mailbox up with
4623  	// dust packets and assert that calls to SendHTLC will fail.
4624  	preimage, err := genPreimage()
4625  	require.NoError(t, err)
4626  	rhash := sha256.Sum256(preimage[:])
4627  	amt := lnwire.NewMSatFromSatoshis(350)
4628  	addMsg := &lnwire.UpdateAddHTLC{
4629  		PaymentHash: rhash,
4630  		Amount:      amt,
4631  		ChanID:      chanID1,
4632  	}
4633  
4634  	// Initialize the carolHTLCID.
4635  	var carolHTLCID uint64
4636  
4637  	// It will take aliceCount HTLC's of 350sats to fill up Alice's mailbox
4638  	// to the point where another would put Alice over the fee threshold.
4639  	aliceCount := 1428
4640  
4641  	mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID1, aliceChanID)
4642  
4643  	for i := 0; i < aliceCount; i++ {
4644  		alicePkt := &htlcPacket{
4645  			incomingChanID: carolChanID,
4646  			incomingHTLCID: carolHTLCID,
4647  			outgoingChanID: aliceChanID,
4648  			obfuscator:     NewMockObfuscator(),
4649  			incomingAmount: amt,
4650  			amount:         amt,
4651  			htlc:           addMsg,
4652  		}
4653  
4654  		err = mailbox.AddPacket(alicePkt)
4655  		require.NoError(t, err)
4656  
4657  		carolHTLCID++
4658  	}
4659  
4660  	// Sending one more HTLC to Alice should result in the fee threshold
4661  	// being breached.
4662  	err = s.SendHTLC(aliceChanID, 0, addMsg)
4663  	require.ErrorIs(t, err, errFeeExposureExceeded)
4664  
4665  	// We'll now call ForwardPackets from Bob to ensure that the mailbox
4666  	// sum is also accounted for in the forwarding case.
4667  	packet := &htlcPacket{
4668  		incomingChanID: bobChanID,
4669  		incomingHTLCID: 0,
4670  		outgoingChanID: aliceChanID,
4671  		obfuscator:     NewMockObfuscator(),
4672  		incomingAmount: amt,
4673  		amount:         amt,
4674  		htlc: &lnwire.UpdateAddHTLC{
4675  			PaymentHash: rhash,
4676  			Amount:      amt,
4677  			ChanID:      chanID1,
4678  		},
4679  	}
4680  
4681  	err = s.ForwardPackets(nil, packet)
4682  	require.NoError(t, err)
4683  
4684  	// Bob should receive a failure from the switch.
4685  	select {
4686  	case p := <-bobLink.packets:
4687  		require.NotEmpty(t, p.linkFailure)
4688  		assertFailureCode(
4689  			t, p.linkFailure, lnwire.CodeTemporaryChannelFailure,
4690  		)
4691  
4692  	case <-time.After(5 * time.Second):
4693  		t.Fatal("no timely reply from switch")
4694  	}
4695  }
4696  
4697  // TestSwitchResolution checks the ability of the switch to persist and handle
4698  // resolution messages.
4699  func TestSwitchResolution(t *testing.T) {
4700  	t.Parallel()
4701  
4702  	alicePeer, err := newMockServer(
4703  		t, "alice", testStartingHeight, nil, testDefaultDelta,
4704  	)
4705  	require.NoError(t, err)
4706  
4707  	bobPeer, err := newMockServer(
4708  		t, "bob", testStartingHeight, nil, testDefaultDelta,
4709  	)
4710  	require.NoError(t, err)
4711  
4712  	s, err := initSwitchWithTempDB(t, testStartingHeight)
4713  	require.NoError(t, err)
4714  
4715  	// Even though we intend to Stop s later in the test, it is safe to
4716  	// defer this Stop since its execution it is protected by an atomic
4717  	// guard, guaranteeing it executes at most once.
4718  	t.Cleanup(func() { var _ = s.Stop() })
4719  
4720  	err = s.Start()
4721  	require.NoError(t, err)
4722  
4723  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
4724  
4725  	aliceChannelLink := newMockChannelLink(
4726  		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
4727  		false, false,
4728  	)
4729  	bobChannelLink := newMockChannelLink(
4730  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
4731  		false,
4732  	)
4733  	err = s.AddLink(aliceChannelLink)
4734  	require.NoError(t, err)
4735  	err = s.AddLink(bobChannelLink)
4736  	require.NoError(t, err)
4737  
4738  	// Create an add htlcPacket that Alice will send to Bob.
4739  	preimage, err := genPreimage()
4740  	require.NoError(t, err)
4741  
4742  	rhash := sha256.Sum256(preimage[:])
4743  	packet := &htlcPacket{
4744  		incomingChanID: aliceChannelLink.ShortChanID(),
4745  		incomingHTLCID: 0,
4746  		outgoingChanID: bobChannelLink.ShortChanID(),
4747  		obfuscator:     NewMockObfuscator(),
4748  		htlc: &lnwire.UpdateAddHTLC{
4749  			PaymentHash: rhash,
4750  			Amount:      1,
4751  		},
4752  	}
4753  
4754  	err = s.ForwardPackets(nil, packet)
4755  	require.NoError(t, err)
4756  
4757  	// Bob will receive the packet and open the circuit.
4758  	select {
4759  	case <-bobChannelLink.packets:
4760  		err = bobChannelLink.completeCircuit(packet)
4761  		require.NoError(t, err)
4762  	case <-time.After(time.Second):
4763  		t.Fatal("request was not propagated to destination")
4764  	}
4765  
4766  	// Check that only one circuit is open.
4767  	require.Equal(t, 1, s.circuits.NumOpen())
4768  
4769  	// We'll send a settle resolution to Switch that should go to Alice.
4770  	settleResMsg := contractcourt.ResolutionMsg{
4771  		SourceChan: bobChanID,
4772  		HtlcIndex:  0,
4773  		PreImage:   &preimage,
4774  	}
4775  
4776  	// Before the resolution is sent, remove alice's link so we can assert
4777  	// that the resolution is actually stored. Otherwise, it would be
4778  	// deleted shortly after being sent.
4779  	s.RemoveLink(chanID1)
4780  
4781  	// Send the resolution message.
4782  	err = s.ProcessContractResolution(settleResMsg)
4783  	require.NoError(t, err)
4784  
4785  	// Assert that the resolution store contains the settle reoslution.
4786  	resMsgs, err := s.resMsgStore.fetchAllResolutionMsg()
4787  	require.NoError(t, err)
4788  
4789  	require.Equal(t, 1, len(resMsgs))
4790  	require.Equal(t, settleResMsg.SourceChan, resMsgs[0].SourceChan)
4791  	require.Equal(t, settleResMsg.HtlcIndex, resMsgs[0].HtlcIndex)
4792  	require.Nil(t, resMsgs[0].Failure)
4793  	require.Equal(t, preimage, *resMsgs[0].PreImage)
4794  
4795  	// Now we'll restart Alice's link and delete the circuit.
4796  	err = s.AddLink(aliceChannelLink)
4797  	require.NoError(t, err)
4798  
4799  	// Alice will receive the packet and open the circuit.
4800  	select {
4801  	case alicePkt := <-aliceChannelLink.packets:
4802  		err = aliceChannelLink.completeCircuit(alicePkt)
4803  		require.NoError(t, err)
4804  	case <-time.After(time.Second):
4805  		t.Fatal("request was not propagated to destination")
4806  	}
4807  
4808  	// Assert that there are no more circuits.
4809  	require.Equal(t, 0, s.circuits.NumOpen())
4810  
4811  	// We'll restart the Switch and assert that Alice does not receive
4812  	// another packet.
4813  	switchDB := s.cfg.DB.(*channeldb.DB)
4814  	err = s.Stop()
4815  	require.NoError(t, err)
4816  
4817  	s, err = initSwitchWithDB(testStartingHeight, switchDB)
4818  	require.NoError(t, err)
4819  
4820  	err = s.Start()
4821  	require.NoError(t, err)
4822  	defer func() {
4823  		_ = s.Stop()
4824  	}()
4825  
4826  	err = s.AddLink(aliceChannelLink)
4827  	require.NoError(t, err)
4828  	err = s.AddLink(bobChannelLink)
4829  	require.NoError(t, err)
4830  
4831  	// Alice should not receive a packet since the Switch should have
4832  	// deleted the resolution message since the circuit was closed.
4833  	select {
4834  	case alicePkt := <-aliceChannelLink.packets:
4835  		t.Fatalf("received erroneous packet: %v", alicePkt)
4836  	case <-time.After(time.Second * 5):
4837  	}
4838  
4839  	// Check that the resolution message no longer exists in the store.
4840  	resMsgs, err = s.resMsgStore.fetchAllResolutionMsg()
4841  	require.NoError(t, err)
4842  	require.Equal(t, 0, len(resMsgs))
4843  }
4844  
4845  // TestSwitchForwardFailAlias tests that if ForwardPackets returns a failure
4846  // before actually forwarding, the ChannelUpdate uses the SCID from the
4847  // incoming channel and does not leak private information like the UTXO.
4848  func TestSwitchForwardFailAlias(t *testing.T) {
4849  	tests := []struct {
4850  		name string
4851  
4852  		// Whether or not Alice will be a zero-conf channel or an
4853  		// option-scid-alias channel (feature-bit).
4854  		zeroConf bool
4855  	}{
4856  		{
4857  			name:     "option-scid-alias forwarding failure",
4858  			zeroConf: false,
4859  		},
4860  		{
4861  			name:     "zero-conf forwarding failure",
4862  			zeroConf: true,
4863  		},
4864  	}
4865  
4866  	for _, test := range tests {
4867  		test := test
4868  
4869  		t.Run(test.name, func(t *testing.T) {
4870  			testSwitchForwardFailAlias(t, test.zeroConf)
4871  		})
4872  	}
4873  }
4874  
4875  func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) {
4876  	t.Parallel()
4877  
4878  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
4879  
4880  	alicePeer, err := newMockServer(
4881  		t, "alice", testStartingHeight, nil, testDefaultDelta,
4882  	)
4883  	require.NoError(t, err)
4884  
4885  	bobPeer, err := newMockServer(
4886  		t, "bob", testStartingHeight, nil, testDefaultDelta,
4887  	)
4888  	require.NoError(t, err)
4889  
4890  	tempPath := t.TempDir()
4891  
4892  	cdb := channeldb.OpenForTesting(t, tempPath)
4893  
4894  	s, err := initSwitchWithDB(testStartingHeight, cdb)
4895  	require.NoError(t, err)
4896  
4897  	err = s.Start()
4898  	require.NoError(t, err)
4899  
4900  	// Make Alice's channel zero-conf or option-scid-alias (feature bit).
4901  	aliceAlias := lnwire.ShortChannelID{
4902  		BlockHeight: 16_000_000,
4903  		TxIndex:     5,
4904  		TxPosition:  5,
4905  	}
4906  
4907  	var aliceLink *mockChannelLink
4908  	if zeroConf {
4909  		aliceLink = newMockChannelLink(
4910  			s, chanID1, aliceAlias, aliceChanID, alicePeer, true,
4911  			true, true, false,
4912  		)
4913  	} else {
4914  		aliceLink = newMockChannelLink(
4915  			s, chanID1, aliceChanID, emptyScid, alicePeer, true,
4916  			true, false, true,
4917  		)
4918  		aliceLink.addAlias(aliceAlias)
4919  	}
4920  	err = s.AddLink(aliceLink)
4921  	require.NoError(t, err)
4922  
4923  	bobLink := newMockChannelLink(
4924  		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
4925  		false,
4926  	)
4927  	err = s.AddLink(bobLink)
4928  	require.NoError(t, err)
4929  
4930  	// Create a packet that will be sent from Alice to Bob via the switch.
4931  	preimage := [sha256.Size]byte{1}
4932  	rhash := sha256.Sum256(preimage[:])
4933  	ogPacket := &htlcPacket{
4934  		incomingChanID: aliceLink.ShortChanID(),
4935  		incomingHTLCID: 0,
4936  		outgoingChanID: bobLink.ShortChanID(),
4937  		obfuscator:     NewMockObfuscator(),
4938  		htlc: &lnwire.UpdateAddHTLC{
4939  			PaymentHash: rhash,
4940  			Amount:      1,
4941  		},
4942  	}
4943  
4944  	// Forward the packet and check that Bob's channel link received it.
4945  	err = s.ForwardPackets(nil, ogPacket)
4946  	require.NoError(t, err)
4947  
4948  	// Assert that the circuits are in the expected state.
4949  	require.Equal(t, 1, s.circuits.NumPending())
4950  	require.Equal(t, 0, s.circuits.NumOpen())
4951  
4952  	// Pull packet from Bob's link, and do nothing with it.
4953  	select {
4954  	case <-bobLink.packets:
4955  	case <-s.quit:
4956  		t.Fatal("switch shutting down, failed to forward packet")
4957  	}
4958  
4959  	// Now we will restart the Switch to trigger the LoadedFromDisk logic.
4960  	err = s.Stop()
4961  	require.NoError(t, err)
4962  
4963  	err = cdb.Close()
4964  	require.NoError(t, err)
4965  
4966  	cdb2 := channeldb.OpenForTesting(t, tempPath)
4967  
4968  	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
4969  	require.NoError(t, err)
4970  
4971  	err = s2.Start()
4972  	require.NoError(t, err)
4973  
4974  	defer func() {
4975  		_ = s2.Stop()
4976  	}()
4977  
4978  	var aliceLink2 *mockChannelLink
4979  	if zeroConf {
4980  		aliceLink2 = newMockChannelLink(
4981  			s2, chanID1, aliceAlias, aliceChanID, alicePeer, true,
4982  			true, true, false,
4983  		)
4984  	} else {
4985  		aliceLink2 = newMockChannelLink(
4986  			s2, chanID1, aliceChanID, emptyScid, alicePeer, true,
4987  			true, false, true,
4988  		)
4989  		aliceLink2.addAlias(aliceAlias)
4990  	}
4991  	err = s2.AddLink(aliceLink2)
4992  	require.NoError(t, err)
4993  
4994  	bobLink2 := newMockChannelLink(
4995  		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
4996  		false,
4997  	)
4998  	err = s2.AddLink(bobLink2)
4999  	require.NoError(t, err)
5000  
5001  	// Reforward the ogPacket and wait for Alice to receive a failure
5002  	// packet.
5003  	err = s2.ForwardPackets(nil, ogPacket)
5004  	require.NoError(t, err)
5005  
5006  	select {
5007  	case failPacket := <-aliceLink2.packets:
5008  		// Assert that the failPacket does not leak UTXO information.
5009  		// This means checking that aliceChanID was not returned.
5010  		msg := failPacket.linkFailure.msg
5011  		failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure)
5012  		require.True(t, ok)
5013  		require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID)
5014  	case <-s2.quit:
5015  		t.Fatal("switch shutting down, failed to forward packet")
5016  	}
5017  }
5018  
5019  // TestSwitchAliasFailAdd tests that the mailbox does not leak UTXO information
5020  // when failing back an HTLC due to the 5-second timeout. This is tested in the
5021  // switch rather than the mailbox because the mailbox tests do not have the
5022  // proper context (e.g. the Switch's failAliasUpdate function). The caveat here
5023  // is that if the private UTXO is already known, it is fine to send a failure
5024  // back. This tests option-scid-alias (feature-bit) and zero-conf channels.
5025  func TestSwitchAliasFailAdd(t *testing.T) {
5026  	tests := []struct {
5027  		name string
5028  
5029  		// Denotes whether the opened channel will be zero-conf.
5030  		zeroConf bool
5031  
5032  		// Denotes whether the opened channel will be private.
5033  		private bool
5034  
5035  		// Denotes whether an alias was used during forwarding.
5036  		useAlias bool
5037  	}{
5038  		{
5039  			name:     "public zero-conf using alias",
5040  			zeroConf: true,
5041  			private:  false,
5042  			useAlias: true,
5043  		},
5044  		{
5045  			name:     "public zero-conf using real",
5046  			zeroConf: true,
5047  			private:  false,
5048  			useAlias: true,
5049  		},
5050  		{
5051  			name:     "private zero-conf using alias",
5052  			zeroConf: true,
5053  			private:  true,
5054  			useAlias: true,
5055  		},
5056  		{
5057  			name:     "public option-scid-alias using alias",
5058  			zeroConf: false,
5059  			private:  false,
5060  			useAlias: true,
5061  		},
5062  		{
5063  			name:     "public option-scid-alias using real",
5064  			zeroConf: false,
5065  			private:  false,
5066  			useAlias: false,
5067  		},
5068  		{
5069  			name:     "private option-scid-alias using alias",
5070  			zeroConf: false,
5071  			private:  true,
5072  			useAlias: true,
5073  		},
5074  	}
5075  
5076  	for _, test := range tests {
5077  		test := test
5078  
5079  		t.Run(test.name, func(t *testing.T) {
5080  			testSwitchAliasFailAdd(
5081  				t, test.zeroConf, test.private, test.useAlias,
5082  			)
5083  		})
5084  	}
5085  }
5086  
5087  func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) {
5088  	t.Parallel()
5089  
5090  	chanID1, chanID2, aliceChanID, bobChanID := genIDs()
5091  
5092  	alicePeer, err := newMockServer(
5093  		t, "alice", testStartingHeight, nil, testDefaultDelta,
5094  	)
5095  	require.NoError(t, err)
5096  
5097  	bobPeer, err := newMockServer(
5098  		t, "bob", testStartingHeight, nil, testDefaultDelta,
5099  	)
5100  	require.NoError(t, err)
5101  
5102  	tempPath := t.TempDir()
5103  
5104  	cdb := channeldb.OpenForTesting(t, tempPath)
5105  
5106  	s, err := initSwitchWithDB(testStartingHeight, cdb)
5107  	require.NoError(t, err)
5108  
5109  	// Change the mailOrchestrator's expiry to a second.
5110  	s.mailOrchestrator.cfg.expiry = time.Second
5111  
5112  	err = s.Start()
5113  	require.NoError(t, err)
5114  
5115  	defer func() {
5116  		_ = s.Stop()
5117  	}()
5118  
5119  	// Make Alice's channel zero-conf or option-scid-alias (feature bit).
5120  	aliceAlias := lnwire.ShortChannelID{
5121  		BlockHeight: 16_000_000,
5122  		TxIndex:     5,
5123  		TxPosition:  5,
5124  	}
5125  	aliceAlias2 := aliceAlias
5126  	aliceAlias2.TxPosition = 6
5127  
5128  	var aliceLink *mockChannelLink
5129  	if zeroConf {
5130  		aliceLink = newMockChannelLink(
5131  			s, chanID1, aliceAlias, aliceChanID, alicePeer, true,
5132  			private, true, false,
5133  		)
5134  		aliceLink.addAlias(aliceAlias2)
5135  	} else {
5136  		aliceLink = newMockChannelLink(
5137  			s, chanID1, aliceChanID, emptyScid, alicePeer, true,
5138  			private, false, true,
5139  		)
5140  		aliceLink.addAlias(aliceAlias)
5141  		aliceLink.addAlias(aliceAlias2)
5142  	}
5143  	err = s.AddLink(aliceLink)
5144  	require.NoError(t, err)
5145  
5146  	bobLink := newMockChannelLink(
5147  		s, chanID2, bobChanID, emptyScid, bobPeer, true, true, false,
5148  		false,
5149  	)
5150  	err = s.AddLink(bobLink)
5151  	require.NoError(t, err)
5152  
5153  	// Create a packet that Bob will send to Alice via ForwardPackets.
5154  	preimage := [sha256.Size]byte{1}
5155  	rhash := sha256.Sum256(preimage[:])
5156  	ogPacket := &htlcPacket{
5157  		incomingChanID: bobLink.ShortChanID(),
5158  		incomingHTLCID: 0,
5159  		obfuscator:     NewMockObfuscator(),
5160  		htlc: &lnwire.UpdateAddHTLC{
5161  			PaymentHash: rhash,
5162  			Amount:      1,
5163  		},
5164  	}
5165  
5166  	// Determine which outgoingChanID to set based on the useAlias boolean.
5167  	outgoingChanID := aliceChanID
5168  	if useAlias {
5169  		// Choose randomly from the 2 possible aliases.
5170  		aliases := aliceLink.getAliases()
5171  		idx := mrand.Intn(len(aliases))
5172  
5173  		outgoingChanID = aliases[idx]
5174  	}
5175  
5176  	ogPacket.outgoingChanID = outgoingChanID
5177  
5178  	// Forward the packet so Alice's mailbox fails it backwards.
5179  	err = s.ForwardPackets(nil, ogPacket)
5180  	require.NoError(t, err)
5181  
5182  	// Assert that the circuits are in the expected state.
5183  	require.Equal(t, 1, s.circuits.NumPending())
5184  	require.Equal(t, 0, s.circuits.NumOpen())
5185  
5186  	// Wait to receive the packet from Bob's mailbox.
5187  	select {
5188  	case failPacket := <-bobLink.packets:
5189  		// Assert that failPacket returns the expected SCID in the
5190  		// ChannelUpdate.
5191  		msg := failPacket.linkFailure.msg
5192  		failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure)
5193  		require.True(t, ok)
5194  		require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID)
5195  	case <-s.quit:
5196  		t.Fatal("switch shutting down, failed to receive fail packet")
5197  	}
5198  }
5199  
5200  // TestSwitchHandlePacketForwardAlias checks that handlePacketForward (which
5201  // calls CheckHtlcForward) does not leak the UTXO in a failure message for
5202  // alias channels. This test requires us to have a REAL link, which we also
5203  // must modify in order to test it properly (e.g. making it a private channel).
5204  // This doesn't lead to good code, but short of refactoring the link-generation
5205  // code there is not a good alternative.
5206  func TestSwitchHandlePacketForward(t *testing.T) {
5207  	tests := []struct {
5208  		name string
5209  
5210  		// Denotes whether or not the channel will be zero-conf.
5211  		zeroConf bool
5212  
5213  		// Denotes whether or not the channel will have negotiated the
5214  		// option-scid-alias feature-bit and is not zero-conf.
5215  		optionFeature bool
5216  
5217  		// Denotes whether or not the channel will be private.
5218  		private bool
5219  
5220  		// Denotes whether or not the alias will be used for
5221  		// forwarding.
5222  		useAlias bool
5223  	}{
5224  		{
5225  			name:     "public zero-conf using alias",
5226  			zeroConf: true,
5227  			private:  false,
5228  			useAlias: true,
5229  		},
5230  		{
5231  			name:     "public zero-conf using real",
5232  			zeroConf: true,
5233  			private:  false,
5234  			useAlias: false,
5235  		},
5236  		{
5237  			name:     "private zero-conf using alias",
5238  			zeroConf: true,
5239  			private:  true,
5240  			useAlias: true,
5241  		},
5242  		{
5243  			name:          "public option-scid-alias using alias",
5244  			zeroConf:      false,
5245  			optionFeature: true,
5246  			private:       false,
5247  			useAlias:      true,
5248  		},
5249  		{
5250  			name:          "public option-scid-alias using real",
5251  			zeroConf:      false,
5252  			optionFeature: true,
5253  			private:       false,
5254  			useAlias:      false,
5255  		},
5256  		{
5257  			name:          "private option-scid-alias using alias",
5258  			zeroConf:      false,
5259  			optionFeature: true,
5260  			private:       true,
5261  			useAlias:      true,
5262  		},
5263  	}
5264  
5265  	for _, test := range tests {
5266  		test := test
5267  
5268  		t.Run(test.name, func(t *testing.T) {
5269  			testSwitchHandlePacketForward(
5270  				t, test.zeroConf, test.private, test.useAlias,
5271  				test.optionFeature,
5272  			)
5273  		})
5274  	}
5275  }
5276  
5277  func testSwitchHandlePacketForward(t *testing.T, zeroConf, private,
5278  	useAlias, optionFeature bool) {
5279  
5280  	t.Parallel()
5281  
5282  	// Create a link for Alice that we'll add to the switch.
5283  	harness, err :=
5284  		newSingleLinkTestHarness(t, btcutil.SatoshiPerBitcoin, 0)
5285  	require.NoError(t, err)
5286  
5287  	aliceLink := harness.aliceLink
5288  
5289  	s, err := initSwitchWithTempDB(t, testStartingHeight)
5290  	if err != nil {
5291  		t.Fatalf("unable to init switch: %v", err)
5292  	}
5293  	if err := s.Start(); err != nil {
5294  		t.Fatalf("unable to start switch: %v", err)
5295  	}
5296  	defer func() {
5297  		_ = s.Stop()
5298  	}()
5299  
5300  	// Change Alice's ShortChanID and OtherShortChanID here.
5301  	aliceAlias := lnwire.ShortChannelID{
5302  		BlockHeight: 16_000_000,
5303  		TxIndex:     5,
5304  		TxPosition:  5,
5305  	}
5306  	aliceAlias2 := aliceAlias
5307  	aliceAlias2.TxPosition = 6
5308  
5309  	aliceChannelLink := aliceLink.(*channelLink)
5310  	aliceChannelState := aliceChannelLink.channel.State()
5311  
5312  	// Set the link's GetAliases function.
5313  	aliceChannelLink.cfg.GetAliases = func(
5314  		base lnwire.ShortChannelID) []lnwire.ShortChannelID {
5315  
5316  		return []lnwire.ShortChannelID{aliceAlias, aliceAlias2}
5317  	}
5318  
5319  	if !private {
5320  		// Change the channel to public depending on the test.
5321  		aliceChannelState.ChannelFlags = lnwire.FFAnnounceChannel
5322  	}
5323  
5324  	// If this is an option-scid-alias feature-bit non-zero-conf channel,
5325  	// we'll mark the channel as such.
5326  	if optionFeature {
5327  		aliceChannelState.ChanType |= channeldb.ScidAliasFeatureBit
5328  	}
5329  
5330  	// This is the ShortChannelID field in the OpenChannel struct.
5331  	aliceScid := aliceLink.ShortChanID()
5332  	if zeroConf {
5333  		// Store the alias in the shortChanID field and mark the real
5334  		// scid in the database.
5335  		err = aliceChannelState.MarkRealScid(aliceScid)
5336  		require.NoError(t, err)
5337  
5338  		aliceChannelState.ChanType |= channeldb.ZeroConfBit
5339  	}
5340  
5341  	err = s.AddLink(aliceLink)
5342  	require.NoError(t, err)
5343  
5344  	// Add a mockChannelLink for Bob.
5345  	bobChanID, bobScid := genID()
5346  	bobPeer, err := newMockServer(
5347  		t, "bob", testStartingHeight, nil, testDefaultDelta,
5348  	)
5349  	require.NoError(t, err)
5350  
5351  	bobLink := newMockChannelLink(
5352  		s, bobChanID, bobScid, emptyScid, bobPeer, true, false, false,
5353  		false,
5354  	)
5355  	err = s.AddLink(bobLink)
5356  	require.NoError(t, err)
5357  
5358  	preimage := [sha256.Size]byte{1}
5359  	rhash := sha256.Sum256(preimage[:])
5360  	ogPacket := &htlcPacket{
5361  		incomingChanID: bobLink.ShortChanID(),
5362  		incomingHTLCID: 0,
5363  		incomingAmount: 1000,
5364  		obfuscator:     NewMockObfuscator(),
5365  		htlc: &lnwire.UpdateAddHTLC{
5366  			PaymentHash: rhash,
5367  			Amount:      1,
5368  		},
5369  	}
5370  
5371  	// Determine which outgoingChanID to set based on the useAlias bool.
5372  	outgoingChanID := aliceScid
5373  	if useAlias {
5374  		// Choose from the possible aliases.
5375  		aliases := aliceLink.getAliases()
5376  		idx := mrand.Intn(len(aliases))
5377  
5378  		outgoingChanID = aliases[idx]
5379  	}
5380  
5381  	ogPacket.outgoingChanID = outgoingChanID
5382  
5383  	// Forward the packet to Alice and she should fail it back with an
5384  	// AmountBelowMinimum FailureMessage.
5385  	err = s.ForwardPackets(nil, ogPacket)
5386  	require.NoError(t, err)
5387  
5388  	select {
5389  	case failPacket := <-bobLink.packets:
5390  		// Assert that failPacket returns the expected ChannelUpdate.
5391  		msg := failPacket.linkFailure.msg
5392  		failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum)
5393  		require.True(t, ok)
5394  		require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID)
5395  	case <-s.quit:
5396  		t.Fatal("switch shutting down, failed to receive failure")
5397  	}
5398  }
5399  
5400  // TestSwitchAliasInterceptFail tests that when the InterceptableSwitch fails
5401  // an incoming HTLC, it does not leak the on-chain UTXO for option-scid-alias
5402  // (feature bit) or zero-conf channels.
5403  func TestSwitchAliasInterceptFail(t *testing.T) {
5404  	tests := []struct {
5405  		name string
5406  
5407  		// Denotes whether or not the incoming channel is a zero-conf
5408  		// channel or an option-scid-alias channel instead (feature
5409  		// bit).
5410  		zeroConf bool
5411  	}{
5412  		{
5413  			name:     "option-scid-alias",
5414  			zeroConf: false,
5415  		},
5416  		{
5417  			name:     "zero-conf",
5418  			zeroConf: true,
5419  		},
5420  	}
5421  
5422  	for _, test := range tests {
5423  		test := test
5424  
5425  		t.Run(test.name, func(t *testing.T) {
5426  			testSwitchAliasInterceptFail(t, test.zeroConf)
5427  		})
5428  	}
5429  }
5430  
5431  func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) {
5432  	t.Parallel()
5433  
5434  	chanID, aliceScid := genID()
5435  
5436  	alicePeer, err := newMockServer(
5437  		t, "alice", testStartingHeight, nil, testDefaultDelta,
5438  	)
5439  	require.NoError(t, err)
5440  
5441  	tempPath := t.TempDir()
5442  
5443  	cdb := channeldb.OpenForTesting(t, tempPath)
5444  
5445  	s, err := initSwitchWithDB(testStartingHeight, cdb)
5446  	require.NoError(t, err)
5447  
5448  	err = s.Start()
5449  	require.NoError(t, err)
5450  
5451  	defer func() {
5452  		_ = s.Stop()
5453  	}()
5454  
5455  	// Make Alice's alias here.
5456  	aliceAlias := lnwire.ShortChannelID{
5457  		BlockHeight: 16_000_000,
5458  		TxIndex:     5,
5459  		TxPosition:  5,
5460  	}
5461  	aliceAlias2 := aliceAlias
5462  	aliceAlias2.TxPosition = 6
5463  
5464  	var aliceLink *mockChannelLink
5465  	if zeroConf {
5466  		aliceLink = newMockChannelLink(
5467  			s, chanID, aliceAlias, aliceScid, alicePeer, true,
5468  			true, true, false,
5469  		)
5470  		aliceLink.addAlias(aliceAlias2)
5471  	} else {
5472  		aliceLink = newMockChannelLink(
5473  			s, chanID, aliceScid, emptyScid, alicePeer, true,
5474  			true, false, true,
5475  		)
5476  		aliceLink.addAlias(aliceAlias)
5477  		aliceLink.addAlias(aliceAlias2)
5478  	}
5479  	err = s.AddLink(aliceLink)
5480  	require.NoError(t, err)
5481  
5482  	// Now we'll create the packet that will be sent from the Alice link.
5483  	preimage := [sha256.Size]byte{1}
5484  	rhash := sha256.Sum256(preimage[:])
5485  	ogPacket := &htlcPacket{
5486  		incomingChanID:  aliceLink.ShortChanID(),
5487  		incomingTimeout: 1000,
5488  		incomingHTLCID:  0,
5489  		outgoingChanID:  lnwire.ShortChannelID{},
5490  		obfuscator:      NewMockObfuscator(),
5491  		htlc: &lnwire.UpdateAddHTLC{
5492  			PaymentHash: rhash,
5493  			Amount:      1,
5494  		},
5495  	}
5496  
5497  	// Now setup the interceptable switch so that we can reject this
5498  	// packet.
5499  	forwardInterceptor := &mockForwardInterceptor{
5500  		t:               t,
5501  		interceptedChan: make(chan InterceptedPacket),
5502  	}
5503  
5504  	notifier := &mock.ChainNotifier{
5505  		EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
5506  	}
5507  	notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight}
5508  
5509  	interceptSwitch, err := NewInterceptableSwitch(
5510  		&InterceptableSwitchConfig{
5511  			Switch:             s,
5512  			Notifier:           notifier,
5513  			CltvRejectDelta:    10,
5514  			CltvInterceptDelta: 13,
5515  		},
5516  	)
5517  	require.NoError(t, err)
5518  	require.NoError(t, interceptSwitch.Start())
5519  	interceptSwitch.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
5520  
5521  	err = interceptSwitch.ForwardPackets(nil, false, ogPacket)
5522  	require.NoError(t, err)
5523  
5524  	inCircuit := forwardInterceptor.getIntercepted().IncomingCircuit
5525  	require.NoError(t, interceptSwitch.resolve(&FwdResolution{
5526  		Action:      FwdActionFail,
5527  		Key:         inCircuit,
5528  		FailureCode: lnwire.CodeTemporaryChannelFailure,
5529  	}))
5530  
5531  	select {
5532  	case failPacket := <-aliceLink.packets:
5533  		// Assert that failPacket returns the expected ChannelUpdate.
5534  		failHtlc, ok := failPacket.htlc.(*lnwire.UpdateFailHTLC)
5535  		require.True(t, ok)
5536  
5537  		fwdErr, err := newMockDeobfuscator().DecryptError(
5538  			failHtlc.Reason,
5539  		)
5540  		require.NoError(t, err)
5541  
5542  		failure := fwdErr.WireMessage()
5543  
5544  		failureMsg, ok := failure.(*lnwire.FailTemporaryChannelFailure)
5545  		require.True(t, ok)
5546  
5547  		failScid := failureMsg.Update.ShortChannelID
5548  		isAlias := failScid == aliceAlias || failScid == aliceAlias2
5549  		require.True(t, isAlias)
5550  
5551  	case <-s.quit:
5552  		t.Fatalf("switch shutting down, failed to receive failure")
5553  	}
5554  
5555  	require.NoError(t, interceptSwitch.Stop())
5556  }