// htlcswitch/switch_test.go

1 package htlcswitch 2 3 import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "errors" 7 "fmt" 8 "io" 9 mrand "math/rand" 10 "reflect" 11 "testing" 12 "time" 13 14 "github.com/btcsuite/btcd/btcutil" 15 "github.com/davecgh/go-spew/spew" 16 "github.com/lightningnetwork/lnd/chainntnfs" 17 "github.com/lightningnetwork/lnd/channeldb" 18 "github.com/lightningnetwork/lnd/contractcourt" 19 "github.com/lightningnetwork/lnd/fn/v2" 20 "github.com/lightningnetwork/lnd/graph/db/models" 21 "github.com/lightningnetwork/lnd/htlcswitch/hodl" 22 "github.com/lightningnetwork/lnd/htlcswitch/hop" 23 "github.com/lightningnetwork/lnd/lntest/mock" 24 "github.com/lightningnetwork/lnd/lntest/wait" 25 "github.com/lightningnetwork/lnd/lntypes" 26 "github.com/lightningnetwork/lnd/lnwallet/chainfee" 27 "github.com/lightningnetwork/lnd/lnwire" 28 "github.com/lightningnetwork/lnd/ticker" 29 "github.com/stretchr/testify/require" 30 ) 31 32 var zeroCircuit = models.CircuitKey{} 33 var emptyScid = lnwire.ShortChannelID{} 34 35 func genPreimage() ([32]byte, error) { 36 var preimage [32]byte 37 if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { 38 return preimage, err 39 } 40 return preimage, nil 41 } 42 43 // TestSwitchAddDuplicateLink tests that the switch will reject duplicate links 44 // for live links. It also tests that we can successfully add a link after 45 // having removed it. 
func TestSwitchAddDuplicateLink(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	sw, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if startErr := sw.Start(); startErr != nil {
		t.Fatalf("unable to start switch: %v", startErr)
	}
	defer sw.Stop()

	aliceChanID, aliceScid := genID()
	aliceLink := newMockChannelLink(
		sw, aliceChanID, aliceScid, emptyScid, alicePeer,
		false, false, false, false,
	)

	// mustAddAlice registers Alice's link with the switch and aborts the
	// test if the switch rejects it.
	mustAddAlice := func() {
		if addErr := sw.AddLink(aliceLink); addErr != nil {
			t.Fatalf("unable to add alice link: %v", addErr)
		}
	}

	// The first registration must succeed.
	mustAddAlice()

	// Alice now has a live link, so registering it a second time must be
	// rejected.
	if sw.AddLink(aliceLink) == nil {
		t.Fatalf("adding duplicate link should have failed")
	}

	// Removing the live link clears the switch's indexes, after which the
	// very same link can be registered again.
	sw.RemoveLink(aliceChanID)
	mustAddAlice()
}

// TestSwitchHasActiveLink tests the behavior of HasActiveLink, and asserts that
// it only returns true if a link's short channel id has confirmed (meaning the
// channel is no longer pending) and its EligibleToForward method returns true,
// i.e. it has received ChannelReady from the remote peer.
func TestSwitchHasActiveLink(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, aliceScid := genID()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceScid, emptyScid, alicePeer, false, false,
		false, false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}

	// The link has been added, but it has not been marked eligible to
	// forward yet (the only thing that changes before the second check
	// below is the mock's eligible flag), so HasActiveLink should return
	// false.
	if s.HasActiveLink(chanID1) {
		t.Fatalf("link should not be active yet, still pending")
	}

	// Finally, simulate the link receiving channel_ready by setting its
	// eligibility to true.
	aliceChannelLink.eligible = true

	// The link should now be reported as active, since EligibleToForward
	// returns true and the link is registered with the switch.
	if !s.HasActiveLink(chanID1) {
		t.Fatalf("link should be active now")
	}
}

// TestSwitchSendPending checks that the htlc switch refuses to forward adds
// over pending links.
134 func TestSwitchSendPending(t *testing.T) { 135 t.Parallel() 136 137 alicePeer, err := newMockServer( 138 t, "alice", testStartingHeight, nil, testDefaultDelta, 139 ) 140 require.NoError(t, err, "unable to create alice server") 141 142 bobPeer, err := newMockServer( 143 t, "bob", testStartingHeight, nil, testDefaultDelta, 144 ) 145 require.NoError(t, err, "unable to create bob server") 146 147 s, err := initSwitchWithTempDB(t, testStartingHeight) 148 require.NoError(t, err, "unable to init switch") 149 if err := s.Start(); err != nil { 150 t.Fatalf("unable to start switch: %v", err) 151 } 152 defer s.Stop() 153 154 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 155 156 pendingChanID := lnwire.ShortChannelID{} 157 158 aliceChannelLink := newMockChannelLink( 159 s, chanID1, pendingChanID, emptyScid, alicePeer, false, false, 160 false, false, 161 ) 162 if err := s.AddLink(aliceChannelLink); err != nil { 163 t.Fatalf("unable to add alice link: %v", err) 164 } 165 166 bobChannelLink := newMockChannelLink( 167 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 168 false, 169 ) 170 if err := s.AddLink(bobChannelLink); err != nil { 171 t.Fatalf("unable to add bob link: %v", err) 172 } 173 174 // Create request which should is being forwarded from Bob channel 175 // link to Alice channel link. 176 preimage, err := genPreimage() 177 require.NoError(t, err, "unable to generate preimage") 178 rhash := sha256.Sum256(preimage[:]) 179 packet := &htlcPacket{ 180 incomingChanID: bobChanID, 181 incomingHTLCID: 0, 182 outgoingChanID: aliceChanID, 183 obfuscator: NewMockObfuscator(), 184 htlc: &lnwire.UpdateAddHTLC{ 185 PaymentHash: rhash, 186 Amount: 1, 187 }, 188 } 189 190 // Send the ADD packet, this should not be forwarded out to the link 191 // since there are no eligible links. 
192 if err = s.ForwardPackets(nil, packet); err != nil { 193 t.Fatal(err) 194 } 195 select { 196 case p := <-bobChannelLink.packets: 197 if p.linkFailure != nil { 198 err = p.linkFailure 199 } 200 case <-time.After(time.Second): 201 t.Fatal("no timely reply from switch") 202 } 203 linkErr, ok := err.(*LinkError) 204 if !ok { 205 t.Fatalf("expected link error, got: %T", err) 206 } 207 if linkErr.WireMessage().Code() != lnwire.CodeUnknownNextPeer { 208 t.Fatalf("expected fail unknown next peer, got: %T", 209 linkErr.WireMessage().Code()) 210 } 211 212 // No message should be sent, since the packet was failed. 213 select { 214 case <-aliceChannelLink.packets: 215 t.Fatal("expected not to receive message") 216 case <-time.After(time.Second): 217 } 218 219 // Since the packet should have been failed, there should be no active 220 // circuits. 221 if s.circuits.NumOpen() != 0 { 222 t.Fatal("wrong amount of circuits") 223 } 224 } 225 226 // TestSwitchForwardMapping checks that the Switch properly consults its maps 227 // when forwarding packets. 228 func TestSwitchForwardMapping(t *testing.T) { 229 tests := []struct { 230 name string 231 232 // If this is true, then Alice's channel will be private. 233 alicePrivate bool 234 235 // If this is true, then Alice's channel will be a zero-conf 236 // channel. 237 zeroConf bool 238 239 // If this is true, then Alice's channel will be an 240 // option-scid-alias feature-bit, non-zero-conf channel. 241 optionScid bool 242 243 // If this is true, then an alias will be used for forwarding. 244 useAlias bool 245 246 // This is Alice's channel alias. This may not be set if this 247 // is not an option_scid_alias channel (feature bit). 248 aliceAlias lnwire.ShortChannelID 249 250 // This is Alice's confirmed SCID. This may not be set if this 251 // is a zero-conf channel before confirmation. 252 aliceReal lnwire.ShortChannelID 253 254 // If this is set, we expect Bob forwarding to Alice to fail. 
255 expectErr bool 256 }{ 257 { 258 name: "private unconfirmed zero-conf", 259 alicePrivate: true, 260 zeroConf: true, 261 useAlias: true, 262 aliceAlias: lnwire.ShortChannelID{ 263 BlockHeight: 16_000_002, 264 TxIndex: 2, 265 TxPosition: 2, 266 }, 267 aliceReal: lnwire.ShortChannelID{}, 268 expectErr: false, 269 }, 270 { 271 name: "private confirmed zero-conf", 272 alicePrivate: true, 273 zeroConf: true, 274 useAlias: true, 275 aliceAlias: lnwire.ShortChannelID{ 276 BlockHeight: 16_000_003, 277 TxIndex: 3, 278 TxPosition: 3, 279 }, 280 aliceReal: lnwire.ShortChannelID{ 281 BlockHeight: 300000, 282 TxIndex: 3, 283 TxPosition: 3, 284 }, 285 expectErr: false, 286 }, 287 { 288 name: "private confirmed zero-conf failure", 289 alicePrivate: true, 290 zeroConf: true, 291 useAlias: false, 292 aliceAlias: lnwire.ShortChannelID{ 293 BlockHeight: 16_000_004, 294 TxIndex: 4, 295 TxPosition: 4, 296 }, 297 aliceReal: lnwire.ShortChannelID{ 298 BlockHeight: 300002, 299 TxIndex: 4, 300 TxPosition: 4, 301 }, 302 expectErr: true, 303 }, 304 { 305 name: "public unconfirmed zero-conf", 306 alicePrivate: false, 307 zeroConf: true, 308 useAlias: true, 309 aliceAlias: lnwire.ShortChannelID{ 310 BlockHeight: 16_000_005, 311 TxIndex: 5, 312 TxPosition: 5, 313 }, 314 aliceReal: lnwire.ShortChannelID{}, 315 expectErr: false, 316 }, 317 { 318 name: "public confirmed zero-conf w/ alias", 319 alicePrivate: false, 320 zeroConf: true, 321 useAlias: true, 322 aliceAlias: lnwire.ShortChannelID{ 323 BlockHeight: 16_000_006, 324 TxIndex: 6, 325 TxPosition: 6, 326 }, 327 aliceReal: lnwire.ShortChannelID{ 328 BlockHeight: 500000, 329 TxIndex: 6, 330 TxPosition: 6, 331 }, 332 expectErr: false, 333 }, 334 { 335 name: "public confirmed zero-conf w/ real", 336 alicePrivate: false, 337 zeroConf: true, 338 useAlias: false, 339 aliceAlias: lnwire.ShortChannelID{ 340 BlockHeight: 16_000_007, 341 TxIndex: 7, 342 TxPosition: 7, 343 }, 344 aliceReal: lnwire.ShortChannelID{ 345 BlockHeight: 502000, 346 TxIndex: 
7, 347 TxPosition: 7, 348 }, 349 expectErr: false, 350 }, 351 { 352 name: "private non-option channel", 353 alicePrivate: true, 354 aliceAlias: lnwire.ShortChannelID{}, 355 aliceReal: lnwire.ShortChannelID{ 356 BlockHeight: 505000, 357 TxIndex: 8, 358 TxPosition: 8, 359 }, 360 }, 361 { 362 name: "private option channel w/ alias", 363 alicePrivate: true, 364 optionScid: true, 365 useAlias: true, 366 aliceAlias: lnwire.ShortChannelID{ 367 BlockHeight: 16_000_015, 368 TxIndex: 9, 369 TxPosition: 9, 370 }, 371 aliceReal: lnwire.ShortChannelID{ 372 BlockHeight: 506000, 373 TxIndex: 10, 374 TxPosition: 10, 375 }, 376 expectErr: false, 377 }, 378 { 379 name: "private option channel failure", 380 alicePrivate: true, 381 optionScid: true, 382 useAlias: false, 383 aliceAlias: lnwire.ShortChannelID{ 384 BlockHeight: 16_000_016, 385 TxIndex: 16, 386 TxPosition: 16, 387 }, 388 aliceReal: lnwire.ShortChannelID{ 389 BlockHeight: 507000, 390 TxIndex: 17, 391 TxPosition: 17, 392 }, 393 expectErr: true, 394 }, 395 { 396 name: "public non-option channel", 397 alicePrivate: false, 398 useAlias: false, 399 aliceAlias: lnwire.ShortChannelID{}, 400 aliceReal: lnwire.ShortChannelID{ 401 BlockHeight: 508000, 402 TxIndex: 17, 403 TxPosition: 17, 404 }, 405 expectErr: false, 406 }, 407 { 408 name: "public option channel w/ alias", 409 alicePrivate: false, 410 optionScid: true, 411 useAlias: true, 412 aliceAlias: lnwire.ShortChannelID{ 413 BlockHeight: 16_000_018, 414 TxIndex: 18, 415 TxPosition: 18, 416 }, 417 aliceReal: lnwire.ShortChannelID{ 418 BlockHeight: 509000, 419 TxIndex: 19, 420 TxPosition: 19, 421 }, 422 expectErr: false, 423 }, 424 { 425 name: "public option channel w/ real", 426 alicePrivate: false, 427 optionScid: true, 428 useAlias: false, 429 aliceAlias: lnwire.ShortChannelID{ 430 BlockHeight: 16_000_019, 431 TxIndex: 19, 432 TxPosition: 19, 433 }, 434 aliceReal: lnwire.ShortChannelID{ 435 BlockHeight: 510000, 436 TxIndex: 20, 437 TxPosition: 20, 438 }, 439 expectErr: false, 
440 }, 441 } 442 443 for _, test := range tests { 444 test := test 445 t.Run(test.name, func(t *testing.T) { 446 t.Parallel() 447 testSwitchForwardMapping( 448 t, test.alicePrivate, test.zeroConf, 449 test.useAlias, test.optionScid, 450 test.aliceAlias, test.aliceReal, 451 test.expectErr, 452 ) 453 }) 454 } 455 } 456 457 func testSwitchForwardMapping(t *testing.T, alicePrivate, aliceZeroConf, 458 useAlias, optionScid bool, aliceAlias, aliceReal lnwire.ShortChannelID, 459 expectErr bool) { 460 461 alicePeer, err := newMockServer( 462 t, "alice", testStartingHeight, nil, testDefaultDelta, 463 ) 464 require.NoError(t, err) 465 466 bobPeer, err := newMockServer( 467 t, "bob", testStartingHeight, nil, testDefaultDelta, 468 ) 469 require.NoError(t, err) 470 471 s, err := initSwitchWithTempDB(t, testStartingHeight) 472 require.NoError(t, err) 473 err = s.Start() 474 require.NoError(t, err) 475 defer func() { _ = s.Stop() }() 476 477 // Create the lnwire.ChannelIDs that we'll use. 478 chanID1, chanID2, _, _ := genIDs() 479 480 var aliceChannelLink *mockChannelLink 481 482 if aliceZeroConf { 483 aliceChannelLink = newMockChannelLink( 484 s, chanID1, aliceAlias, aliceReal, alicePeer, true, 485 alicePrivate, true, false, 486 ) 487 } else { 488 aliceChannelLink = newMockChannelLink( 489 s, chanID1, aliceReal, emptyScid, alicePeer, true, 490 alicePrivate, false, optionScid, 491 ) 492 493 if optionScid { 494 aliceChannelLink.addAlias(aliceAlias) 495 } 496 } 497 498 err = s.AddLink(aliceChannelLink) 499 require.NoError(t, err) 500 501 // Bob will just have a non-option_scid_alias channel so no mapping is 502 // necessary. 503 bobScid := lnwire.ShortChannelID{ 504 BlockHeight: 501000, 505 TxIndex: 200, 506 TxPosition: 2, 507 } 508 509 bobChannelLink := newMockChannelLink( 510 s, chanID2, bobScid, emptyScid, bobPeer, true, false, false, 511 false, 512 ) 513 err = s.AddLink(bobChannelLink) 514 require.NoError(t, err) 515 516 // Generate preimage. 
517 preimage, err := genPreimage() 518 require.NoError(t, err, "unable to generate preimage") 519 rhash := sha256.Sum256(preimage[:]) 520 521 // Determine the outgoing SCID to use. 522 outgoingSCID := aliceReal 523 if useAlias { 524 outgoingSCID = aliceAlias 525 } 526 527 packet := &htlcPacket{ 528 incomingChanID: bobScid, 529 incomingHTLCID: 0, 530 outgoingChanID: outgoingSCID, 531 obfuscator: NewMockObfuscator(), 532 htlc: &lnwire.UpdateAddHTLC{ 533 PaymentHash: rhash, 534 Amount: 1, 535 }, 536 } 537 err = s.ForwardPackets(nil, packet) 538 require.NoError(t, err) 539 540 // If we expect a forwarding error, then assert that we receive one. 541 // option_scid_alias forwards may fail if forwarding would be a privacy 542 // leak. 543 if expectErr { 544 select { 545 case <-bobChannelLink.packets: 546 case <-time.After(time.Second * 5): 547 t.Fatal("expected a forwarding error") 548 } 549 550 select { 551 case <-aliceChannelLink.packets: 552 t.Fatal("did not expect a packet") 553 case <-time.After(time.Second * 5): 554 } 555 } else { 556 select { 557 case <-bobChannelLink.packets: 558 t.Fatal("did not expect a forwarding error") 559 case <-time.After(time.Second * 5): 560 } 561 562 select { 563 case <-aliceChannelLink.packets: 564 case <-time.After(time.Second * 5): 565 t.Fatal("expected alice to receive packet") 566 } 567 } 568 } 569 570 // TestSwitchSendHTLCMapping tests that SendHTLC will properly route packets to 571 // zero-conf or option-scid-alias (feature-bit) channels if the confirmed SCID 572 // is used. It also tests that nothing breaks with the mapping change. 573 func TestSwitchSendHTLCMapping(t *testing.T) { 574 tests := []struct { 575 name string 576 577 // If this is true, the channel will be zero-conf. 578 zeroConf bool 579 580 // Denotes whether the channel is option-scid-alias, non 581 // zero-conf feature bit. 582 optionFeature bool 583 584 // If this is true, then the alias will be used in the packet. 
585 useAlias bool 586 587 // This will be the channel alias if there is a mapping. 588 alias lnwire.ShortChannelID 589 590 // This will be the confirmed SCID if the channel is confirmed. 591 real lnwire.ShortChannelID 592 }{ 593 { 594 name: "non-zero-conf real scid w/ option", 595 zeroConf: false, 596 optionFeature: true, 597 useAlias: false, 598 alias: lnwire.ShortChannelID{ 599 BlockHeight: 10010, 600 TxIndex: 10, 601 TxPosition: 10, 602 }, 603 real: lnwire.ShortChannelID{ 604 BlockHeight: 500000, 605 TxIndex: 50, 606 TxPosition: 50, 607 }, 608 }, 609 { 610 name: "non-zero-conf real scid no option", 611 zeroConf: false, 612 useAlias: false, 613 alias: lnwire.ShortChannelID{}, 614 real: lnwire.ShortChannelID{ 615 BlockHeight: 400000, 616 TxIndex: 50, 617 TxPosition: 50, 618 }, 619 }, 620 { 621 name: "zero-conf alias scid w/ conf", 622 zeroConf: true, 623 useAlias: true, 624 alias: lnwire.ShortChannelID{ 625 BlockHeight: 10020, 626 TxIndex: 20, 627 TxPosition: 20, 628 }, 629 real: lnwire.ShortChannelID{ 630 BlockHeight: 450000, 631 TxIndex: 50, 632 TxPosition: 50, 633 }, 634 }, 635 { 636 name: "zero-conf alias scid no conf", 637 zeroConf: true, 638 useAlias: true, 639 alias: lnwire.ShortChannelID{ 640 BlockHeight: 10015, 641 TxIndex: 25, 642 TxPosition: 35, 643 }, 644 real: lnwire.ShortChannelID{}, 645 }, 646 { 647 name: "zero-conf real scid", 648 zeroConf: true, 649 useAlias: false, 650 alias: lnwire.ShortChannelID{ 651 BlockHeight: 10035, 652 TxIndex: 35, 653 TxPosition: 35, 654 }, 655 real: lnwire.ShortChannelID{ 656 BlockHeight: 470000, 657 TxIndex: 35, 658 TxPosition: 45, 659 }, 660 }, 661 } 662 663 for _, test := range tests { 664 test := test 665 t.Run(test.name, func(t *testing.T) { 666 t.Parallel() 667 testSwitchSendHtlcMapping( 668 t, test.zeroConf, test.useAlias, test.alias, 669 test.real, test.optionFeature, 670 ) 671 }) 672 } 673 } 674 675 func testSwitchSendHtlcMapping(t *testing.T, zeroConf, useAlias bool, alias, 676 realScid lnwire.ShortChannelID, 
optionFeature bool) { 677 678 peer, err := newMockServer( 679 t, "alice", testStartingHeight, nil, testDefaultDelta, 680 ) 681 require.NoError(t, err) 682 683 s, err := initSwitchWithTempDB(t, testStartingHeight) 684 require.NoError(t, err) 685 err = s.Start() 686 require.NoError(t, err) 687 defer func() { _ = s.Stop() }() 688 689 // Create the lnwire.ChannelID that we'll use. 690 chanID, _ := genID() 691 692 var link *mockChannelLink 693 694 if zeroConf { 695 link = newMockChannelLink( 696 s, chanID, alias, realScid, peer, true, false, true, 697 false, 698 ) 699 } else { 700 link = newMockChannelLink( 701 s, chanID, realScid, emptyScid, peer, true, false, 702 false, true, 703 ) 704 705 if optionFeature { 706 link.addAlias(alias) 707 } 708 } 709 710 err = s.AddLink(link) 711 require.NoError(t, err) 712 713 // Generate preimage. 714 preimage, err := genPreimage() 715 require.NoError(t, err) 716 rhash := sha256.Sum256(preimage[:]) 717 718 // Determine the outgoing SCID to use. 719 outgoingSCID := realScid 720 if useAlias { 721 outgoingSCID = alias 722 } 723 724 // Send the HTLC and assert that we don't get an error. 725 htlc := &lnwire.UpdateAddHTLC{ 726 PaymentHash: rhash, 727 Amount: 1, 728 } 729 730 err = s.SendHTLC(outgoingSCID, 0, htlc) 731 require.NoError(t, err) 732 } 733 734 // TestSwitchUpdateScid verifies that zero-conf and non-zero-conf 735 // option-scid-alias (feature bit) channels will have the expected entries in 736 // the aliasToReal and baseIndex maps. 737 func TestSwitchUpdateScid(t *testing.T) { 738 t.Parallel() 739 740 peer, err := newMockServer( 741 t, "alice", testStartingHeight, nil, testDefaultDelta, 742 ) 743 require.NoError(t, err, "unable to create alice server") 744 745 s, err := initSwitchWithTempDB(t, testStartingHeight) 746 require.NoError(t, err) 747 err = s.Start() 748 require.NoError(t, err) 749 defer func() { _ = s.Stop() }() 750 751 // Create the IDs that we'll use. 
752 chanID, chanID2, _, _ := genIDs() 753 754 alias := lnwire.ShortChannelID{ 755 BlockHeight: 16_000_000, 756 TxIndex: 0, 757 TxPosition: 0, 758 } 759 alias2 := alias 760 alias2.TxPosition = 1 761 762 realScid := lnwire.ShortChannelID{ 763 BlockHeight: 500000, 764 TxIndex: 0, 765 TxPosition: 0, 766 } 767 768 link := newMockChannelLink( 769 s, chanID, alias, emptyScid, peer, true, false, true, false, 770 ) 771 link.addAlias(alias2) 772 773 err = s.AddLink(link) 774 require.NoError(t, err) 775 776 // Assert that the zero-conf link does not have entries in the 777 // aliasToReal map. 778 s.indexMtx.RLock() 779 _, ok := s.aliasToReal[alias] 780 require.False(t, ok) 781 _, ok = s.aliasToReal[alias2] 782 require.False(t, ok) 783 784 // Assert that both aliases point to the "base" SCID, which is actually 785 // just the first alias. 786 baseScid, ok := s.baseIndex[alias] 787 require.True(t, ok) 788 require.Equal(t, alias, baseScid) 789 790 baseScid, ok = s.baseIndex[alias2] 791 require.True(t, ok) 792 require.Equal(t, alias, baseScid) 793 794 s.indexMtx.RUnlock() 795 796 // We'll set the mock link's confirmed SCID so that UpdateShortChanID 797 // populates aliasToReal and adds an entry to baseIndex. 798 link.realScid = realScid 799 link.confirmedZC = true 800 801 err = s.UpdateShortChanID(chanID) 802 require.NoError(t, err) 803 804 // Assert that aliasToReal is populated and there is an entry in 805 // baseIndex for realScid. 806 s.indexMtx.RLock() 807 realMapping, ok := s.aliasToReal[alias] 808 require.True(t, ok) 809 require.Equal(t, realScid, realMapping) 810 811 realMapping, ok = s.aliasToReal[alias2] 812 require.True(t, ok) 813 require.Equal(t, realScid, realMapping) 814 815 baseScid, ok = s.baseIndex[realScid] 816 require.True(t, ok) 817 require.Equal(t, alias, baseScid) 818 819 s.indexMtx.RUnlock() 820 821 // Now we'll perform the same checks with a non-zero-conf 822 // option-scid-alias channel (feature-bit). 
823 optionReal := lnwire.ShortChannelID{ 824 BlockHeight: 600000, 825 TxIndex: 0, 826 TxPosition: 0, 827 } 828 optionAlias := lnwire.ShortChannelID{ 829 BlockHeight: 12000, 830 TxIndex: 0, 831 TxPosition: 0, 832 } 833 optionAlias2 := optionAlias 834 optionAlias2.TxPosition = 1 835 link2 := newMockChannelLink( 836 s, chanID2, optionReal, emptyScid, peer, true, false, false, 837 true, 838 ) 839 link2.addAlias(optionAlias) 840 link2.addAlias(optionAlias2) 841 842 err = s.AddLink(link2) 843 require.NoError(t, err) 844 845 // Assert that the option-scid-alias link does have entries in the 846 // aliasToReal and baseIndex maps. 847 s.indexMtx.RLock() 848 realMapping, ok = s.aliasToReal[optionAlias] 849 require.True(t, ok) 850 require.Equal(t, optionReal, realMapping) 851 852 realMapping, ok = s.aliasToReal[optionAlias2] 853 require.True(t, ok) 854 require.Equal(t, optionReal, realMapping) 855 856 baseScid, ok = s.baseIndex[optionReal] 857 require.True(t, ok) 858 require.Equal(t, optionReal, baseScid) 859 860 baseScid, ok = s.baseIndex[optionAlias] 861 require.True(t, ok) 862 require.Equal(t, optionReal, baseScid) 863 864 baseScid, ok = s.baseIndex[optionAlias2] 865 require.True(t, ok) 866 require.Equal(t, optionReal, baseScid) 867 868 s.indexMtx.RUnlock() 869 } 870 871 // TestSwitchForward checks the ability of htlc switch to forward add/settle 872 // requests. 
func TestSwitchForward(t *testing.T) {
	t.Parallel()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create alice server: %v", err)
	}
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	if err != nil {
		t.Fatalf("unable to create bob server: %v", err)
	}

	s, err := initSwitchWithTempDB(t, testStartingHeight)
	if err != nil {
		t.Fatalf("unable to init switch: %v", err)
	}
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}
	defer s.Stop()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage, err := genPreimage()
	if err != nil {
		t.Fatalf("unable to generate preimage: %v", err)
	}
	rhash := sha256.Sum256(preimage[:])
	packet := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	// Handle the request and check that bob's channel link received it.
	if err := s.ForwardPackets(nil, packet); err != nil {
		t.Fatal(err)
	}

	select {
	case <-bobChannelLink.packets:
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumOpen() != 1 {
		t.Fatal("wrong amount of circuits")
	}

	if !s.IsForwardedHTLC(bobChannelLink.ShortChanID(), 0) {
		t.Fatal("htlc should be identified as forwarded")
	}

	// Create settle request pretending that bob link handled the add htlc
	// request and sent the htlc settle request back. This request should
	// be forwarded back to Alice link.
	packet = &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc: &lnwire.UpdateFulfillHTLC{
			PaymentPreimage: preimage,
		},
	}

	// Handle the request and check that the payment circuit works properly.
	if err := s.ForwardPackets(nil, packet); err != nil {
		t.Fatal(err)
	}

	select {
	case pkt := <-aliceChannelLink.packets:
		if err := aliceChannelLink.deleteCircuit(pkt); err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to channelPoint")
	}

	if s.circuits.NumOpen() != 0 {
		t.Fatal("wrong amount of circuits")
	}
}

// TestSwitchForwardFailAfterFullAdd checks the switch's handling of a fail for
// an htlc whose circuit was fully added (opened) before a switch restart. The
// first fail from the outgoing link must be delivered back to the incoming link
// and tear the circuit down. A duplicate of that fail must not arrive at the
// incoming link.
func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb := channeldb.OpenForTesting(t, tempPath)

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
	preimage := [sha256.Size]byte{1}
	rhash := sha256.Sum256(preimage[:])
	ogPacket := &htlcPacket{
		incomingChanID: aliceChannelLink.ShortChanID(),
		incomingHTLCID: 0,
		outgoingChanID: bobChannelLink.ShortChanID(),
		obfuscator:     NewMockObfuscator(),
		htlc: &lnwire.UpdateAddHTLC{
			PaymentHash: rhash,
			Amount:      1,
		},
	}

	if s.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Handle the request and check that bob's channel link received it.
	if err := s.ForwardPackets(nil, ogPacket); err != nil {
		t.Fatal(err)
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Pull the packet from bob's link and complete its circuit, so that the
	// add is fully added (the open-circuit count becomes 1) before the
	// restart.
	select {
	case packet := <-bobChannelLink.packets:
		// Complete the payment circuit and assign the outgoing htlc id
		// before restarting.
		if err := bobChannelLink.completeCircuit(packet); err != nil {
			t.Fatalf("unable to complete payment circuit: %v", err)
		}

	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	if s.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Now we will restart the switch, leaving the forwarding decision for
	// this htlc in the fully-added state.
	if err := s.Stop(); err != nil {
		t.Fatal(err)
	}

	if err := cdb.Close(); err != nil {
		t.Fatal(err)
	}

	cdb2 := channeldb.OpenForTesting(t, tempPath)

	s2, err := initSwitchWithDB(testStartingHeight, cdb2)
	require.NoError(t, err, "unable reinit switch")
	if err := s2.Start(); err != nil {
		t.Fatalf("unable to restart switch: %v", err)
	}

	// Even though we intend to Stop s2 later in the test, it is safe to
	// defer this Stop since its execution is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s2.Stop()

	aliceChannelLink = newMockChannelLink(
		s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink = newMockChannelLink(
		s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s2.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s2.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	if s2.circuits.NumPending() != 1 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 1 {
		t.Fatalf("wrong amount of circuits")
	}

	// Craft a failure message from the remote peer.
	fail := &htlcPacket{
		outgoingChanID: bobChannelLink.ShortChanID(),
		outgoingHTLCID: 0,
		amount:         1,
		htlc:           &lnwire.UpdateFailHTLC{},
	}

	// Send the fail packet from the remote peer through the switch.
	if err := s2.ForwardPackets(nil, fail); err != nil {
		t.Fatal(err)
	}

	// Pull packet from alice's link, as it should have gone through
	// successfully.
	select {
	case pkt := <-aliceChannelLink.packets:
		if err := aliceChannelLink.completeCircuit(pkt); err != nil {
			t.Fatalf("unable to remove circuit: %v", err)
		}
	case <-time.After(time.Second):
		t.Fatal("request was not propagated to destination")
	}

	// Circuit map should be empty now.
	if s2.circuits.NumPending() != 0 {
		t.Fatalf("wrong amount of half circuits")
	}
	if s2.circuits.NumOpen() != 0 {
		t.Fatalf("wrong amount of circuits")
	}

	// Send the fail packet again, which should not arrive at the
	// destination since its circuit has been torn down. This must go
	// through the restarted switch s2. The original switch s has been
	// stopped, and aliceChannelLink is now registered with s2 rather than s,
	// so forwarding through s could never reach it and the check below
	// would pass vacuously.
	if err := s2.ForwardPackets(nil, fail); err != nil {
		t.Fatal(err)
	}
	select {
	case <-aliceChannelLink.packets:
		t.Fatalf("expected duplicate fail to not arrive at the destination")
	case <-time.After(time.Second):
	}
}

func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
	t.Parallel()

	chanID1, chanID2, aliceChanID, bobChanID := genIDs()

	alicePeer, err := newMockServer(
		t, "alice", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create alice server")
	bobPeer, err := newMockServer(
		t, "bob", testStartingHeight, nil, testDefaultDelta,
	)
	require.NoError(t, err, "unable to create bob server")

	tempPath := t.TempDir()

	cdb := channeldb.OpenForTesting(t, tempPath)

	s, err := initSwitchWithDB(testStartingHeight, cdb)
	require.NoError(t, err, "unable to init switch")
	if err := s.Start(); err != nil {
		t.Fatalf("unable to start switch: %v", err)
	}

	// Even though we intend to Stop s later in the test, it is safe to
	// defer this Stop since its execution it is protected by an atomic
	// guard, guaranteeing it executes at most once.
	defer s.Stop()

	aliceChannelLink := newMockChannelLink(
		s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
		false, false,
	)
	bobChannelLink := newMockChannelLink(
		s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false,
		false,
	)
	if err := s.AddLink(aliceChannelLink); err != nil {
		t.Fatalf("unable to add alice link: %v", err)
	}
	if err := s.AddLink(bobChannelLink); err != nil {
		t.Fatalf("unable to add bob link: %v", err)
	}

	// Create request which should be forwarded from Alice channel link to
	// bob channel link.
1220 preimage := [sha256.Size]byte{1} 1221 rhash := sha256.Sum256(preimage[:]) 1222 ogPacket := &htlcPacket{ 1223 incomingChanID: aliceChannelLink.ShortChanID(), 1224 incomingHTLCID: 0, 1225 outgoingChanID: bobChannelLink.ShortChanID(), 1226 obfuscator: NewMockObfuscator(), 1227 htlc: &lnwire.UpdateAddHTLC{ 1228 PaymentHash: rhash, 1229 Amount: 1, 1230 }, 1231 } 1232 1233 if s.circuits.NumPending() != 0 { 1234 t.Fatalf("wrong amount of half circuits") 1235 } 1236 if s.circuits.NumOpen() != 0 { 1237 t.Fatalf("wrong amount of circuits") 1238 } 1239 1240 // Handle the request and checks that bob channel link received it. 1241 if err := s.ForwardPackets(nil, ogPacket); err != nil { 1242 t.Fatal(err) 1243 } 1244 1245 if s.circuits.NumPending() != 1 { 1246 t.Fatalf("wrong amount of half circuits") 1247 } 1248 if s.circuits.NumOpen() != 0 { 1249 t.Fatalf("wrong amount of circuits") 1250 } 1251 1252 // Pull packet from bob's link, but do not perform a full add. 1253 select { 1254 case packet := <-bobChannelLink.packets: 1255 // Complete the payment circuit and assign the outgoing htlc id 1256 // before restarting. 1257 if err := bobChannelLink.completeCircuit(packet); err != nil { 1258 t.Fatalf("unable to complete payment circuit: %v", err) 1259 } 1260 1261 case <-time.After(time.Second): 1262 t.Fatal("request was not propagated to destination") 1263 } 1264 1265 if s.circuits.NumPending() != 1 { 1266 t.Fatalf("wrong amount of half circuits") 1267 } 1268 if s.circuits.NumOpen() != 1 { 1269 t.Fatalf("wrong amount of circuits") 1270 } 1271 1272 // Now we will restart bob, leaving the forwarding decision for this 1273 // htlc is in the half-added state. 
1274 if err := s.Stop(); err != nil { 1275 t.Fatal(err) 1276 } 1277 1278 if err := cdb.Close(); err != nil { 1279 t.Fatal(err) 1280 } 1281 1282 cdb2 := channeldb.OpenForTesting(t, tempPath) 1283 1284 s2, err := initSwitchWithDB(testStartingHeight, cdb2) 1285 require.NoError(t, err, "unable reinit switch") 1286 if err := s2.Start(); err != nil { 1287 t.Fatalf("unable to restart switch: %v", err) 1288 } 1289 1290 // Even though we intend to Stop s2 later in the test, it is safe to 1291 // defer this Stop since its execution it is protected by an atomic 1292 // guard, guaranteeing it executes at most once. 1293 defer s2.Stop() 1294 1295 aliceChannelLink = newMockChannelLink( 1296 s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 1297 false, false, 1298 ) 1299 bobChannelLink = newMockChannelLink( 1300 s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 1301 false, 1302 ) 1303 if err := s2.AddLink(aliceChannelLink); err != nil { 1304 t.Fatalf("unable to add alice link: %v", err) 1305 } 1306 if err := s2.AddLink(bobChannelLink); err != nil { 1307 t.Fatalf("unable to add bob link: %v", err) 1308 } 1309 1310 if s2.circuits.NumPending() != 1 { 1311 t.Fatalf("wrong amount of half circuits") 1312 } 1313 if s2.circuits.NumOpen() != 1 { 1314 t.Fatalf("wrong amount of circuits") 1315 } 1316 1317 // Craft a settle message from the remote peer. 1318 settle := &htlcPacket{ 1319 outgoingChanID: bobChannelLink.ShortChanID(), 1320 outgoingHTLCID: 0, 1321 amount: 1, 1322 htlc: &lnwire.UpdateFulfillHTLC{ 1323 PaymentPreimage: preimage, 1324 }, 1325 } 1326 1327 // Send the settle packet from the remote peer through the switch. 1328 if err := s2.ForwardPackets(nil, settle); err != nil { 1329 t.Fatal(err) 1330 } 1331 1332 // Pull packet from alice's link, as it should have gone through 1333 // successfully. 
1334 select { 1335 case packet := <-aliceChannelLink.packets: 1336 if err := aliceChannelLink.completeCircuit(packet); err != nil { 1337 t.Fatalf("unable to complete circuit with in key=%s: %v", 1338 packet.inKey(), err) 1339 } 1340 case <-time.After(time.Second): 1341 t.Fatal("request was not propagated to destination") 1342 } 1343 1344 // Circuit map should be empty now. 1345 if s2.circuits.NumPending() != 0 { 1346 t.Fatalf("wrong amount of half circuits") 1347 } 1348 if s2.circuits.NumOpen() != 0 { 1349 t.Fatalf("wrong amount of circuits") 1350 } 1351 1352 // Send the settle packet again, which not arrive at destination. 1353 if err := s2.ForwardPackets(nil, settle); err != nil { 1354 t.Fatal(err) 1355 } 1356 select { 1357 case <-bobChannelLink.packets: 1358 t.Fatalf("expected duplicate fail to not arrive at the destination") 1359 case <-time.After(time.Second): 1360 } 1361 } 1362 1363 func TestSwitchForwardDropAfterFullAdd(t *testing.T) { 1364 t.Parallel() 1365 1366 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 1367 1368 alicePeer, err := newMockServer( 1369 t, "alice", testStartingHeight, nil, testDefaultDelta, 1370 ) 1371 require.NoError(t, err, "unable to create alice server") 1372 bobPeer, err := newMockServer( 1373 t, "bob", testStartingHeight, nil, testDefaultDelta, 1374 ) 1375 require.NoError(t, err, "unable to create bob server") 1376 1377 tempPath := t.TempDir() 1378 1379 cdb := channeldb.OpenForTesting(t, tempPath) 1380 1381 s, err := initSwitchWithDB(testStartingHeight, cdb) 1382 require.NoError(t, err, "unable to init switch") 1383 if err := s.Start(); err != nil { 1384 t.Fatalf("unable to start switch: %v", err) 1385 } 1386 1387 // Even though we intend to Stop s later in the test, it is safe to 1388 // defer this Stop since its execution it is protected by an atomic 1389 // guard, guaranteeing it executes at most once. 
1390 defer s.Stop() 1391 1392 aliceChannelLink := newMockChannelLink( 1393 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 1394 false, false, 1395 ) 1396 bobChannelLink := newMockChannelLink( 1397 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 1398 false, 1399 ) 1400 if err := s.AddLink(aliceChannelLink); err != nil { 1401 t.Fatalf("unable to add alice link: %v", err) 1402 } 1403 if err := s.AddLink(bobChannelLink); err != nil { 1404 t.Fatalf("unable to add bob link: %v", err) 1405 } 1406 1407 // Create request which should be forwarded from Alice channel link to 1408 // bob channel link. 1409 preimage := [sha256.Size]byte{1} 1410 rhash := sha256.Sum256(preimage[:]) 1411 ogPacket := &htlcPacket{ 1412 incomingChanID: aliceChannelLink.ShortChanID(), 1413 incomingHTLCID: 0, 1414 outgoingChanID: bobChannelLink.ShortChanID(), 1415 obfuscator: NewMockObfuscator(), 1416 htlc: &lnwire.UpdateAddHTLC{ 1417 PaymentHash: rhash, 1418 Amount: 1, 1419 }, 1420 } 1421 1422 if s.circuits.NumPending() != 0 { 1423 t.Fatalf("wrong amount of half circuits") 1424 } 1425 if s.circuits.NumOpen() != 0 { 1426 t.Fatalf("wrong amount of circuits") 1427 } 1428 1429 // Handle the request and checks that bob channel link received it. 1430 if err := s.ForwardPackets(nil, ogPacket); err != nil { 1431 t.Fatal(err) 1432 } 1433 1434 if s.circuits.NumPending() != 1 { 1435 t.Fatalf("wrong amount of half circuits") 1436 } 1437 if s.circuits.NumOpen() != 0 { 1438 t.Fatalf("wrong amount of half circuits") 1439 } 1440 1441 // Pull packet from bob's link, but do not perform a full add. 1442 select { 1443 case packet := <-bobChannelLink.packets: 1444 // Complete the payment circuit and assign the outgoing htlc id 1445 // before restarting. 
1446 if err := bobChannelLink.completeCircuit(packet); err != nil { 1447 t.Fatalf("unable to complete payment circuit: %v", err) 1448 } 1449 case <-time.After(time.Second): 1450 t.Fatal("request was not propagated to destination") 1451 } 1452 1453 // Now we will restart bob, leaving the forwarding decision for this 1454 // htlc is in the half-added state. 1455 if err := s.Stop(); err != nil { 1456 t.Fatal(err) 1457 } 1458 1459 if err := cdb.Close(); err != nil { 1460 t.Fatal(err) 1461 } 1462 1463 cdb2 := channeldb.OpenForTesting(t, tempPath) 1464 1465 s2, err := initSwitchWithDB(testStartingHeight, cdb2) 1466 require.NoError(t, err, "unable reinit switch") 1467 if err := s2.Start(); err != nil { 1468 t.Fatalf("unable to restart switch: %v", err) 1469 } 1470 1471 // Even though we intend to Stop s2 later in the test, it is safe to 1472 // defer this Stop since its execution it is protected by an atomic 1473 // guard, guaranteeing it executes at most once. 1474 defer s2.Stop() 1475 1476 aliceChannelLink = newMockChannelLink( 1477 s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 1478 false, false, 1479 ) 1480 bobChannelLink = newMockChannelLink( 1481 s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 1482 false, 1483 ) 1484 if err := s2.AddLink(aliceChannelLink); err != nil { 1485 t.Fatalf("unable to add alice link: %v", err) 1486 } 1487 if err := s2.AddLink(bobChannelLink); err != nil { 1488 t.Fatalf("unable to add bob link: %v", err) 1489 } 1490 1491 if s2.circuits.NumPending() != 1 { 1492 t.Fatalf("wrong amount of half circuits") 1493 } 1494 if s2.circuits.NumOpen() != 1 { 1495 t.Fatalf("wrong amount of half circuits") 1496 } 1497 1498 // Resend the failed htlc. The packet will be dropped silently since the 1499 // switch will detect that it has been half added previously. 
1500 if err := s2.ForwardPackets(nil, ogPacket); err != nil { 1501 t.Fatal(err) 1502 } 1503 1504 // After detecting an incomplete forward, the fail packet should have 1505 // been returned to the sender. 1506 select { 1507 case <-aliceChannelLink.packets: 1508 t.Fatal("request should not have returned to source") 1509 case <-bobChannelLink.packets: 1510 t.Fatal("request should not have forwarded to destination") 1511 case <-time.After(time.Second): 1512 } 1513 } 1514 1515 func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { 1516 t.Parallel() 1517 1518 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 1519 1520 alicePeer, err := newMockServer( 1521 t, "alice", testStartingHeight, nil, testDefaultDelta, 1522 ) 1523 require.NoError(t, err, "unable to create alice server") 1524 bobPeer, err := newMockServer( 1525 t, "bob", testStartingHeight, nil, testDefaultDelta, 1526 ) 1527 require.NoError(t, err, "unable to create bob server") 1528 1529 tempPath := t.TempDir() 1530 1531 cdb := channeldb.OpenForTesting(t, tempPath) 1532 1533 s, err := initSwitchWithDB(testStartingHeight, cdb) 1534 require.NoError(t, err, "unable to init switch") 1535 if err := s.Start(); err != nil { 1536 t.Fatalf("unable to start switch: %v", err) 1537 } 1538 1539 // Even though we intend to Stop s later in the test, it is safe to 1540 // defer this Stop since its execution it is protected by an atomic 1541 // guard, guaranteeing it executes at most once. 
1542 defer s.Stop() 1543 1544 aliceChannelLink := newMockChannelLink( 1545 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 1546 false, false, 1547 ) 1548 bobChannelLink := newMockChannelLink( 1549 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 1550 false, 1551 ) 1552 if err := s.AddLink(aliceChannelLink); err != nil { 1553 t.Fatalf("unable to add alice link: %v", err) 1554 } 1555 if err := s.AddLink(bobChannelLink); err != nil { 1556 t.Fatalf("unable to add bob link: %v", err) 1557 } 1558 1559 // Create request which should be forwarded from Alice channel link to 1560 // bob channel link. 1561 preimage := [sha256.Size]byte{1} 1562 rhash := sha256.Sum256(preimage[:]) 1563 ogPacket := &htlcPacket{ 1564 incomingChanID: aliceChannelLink.ShortChanID(), 1565 incomingHTLCID: 0, 1566 outgoingChanID: bobChannelLink.ShortChanID(), 1567 obfuscator: NewMockObfuscator(), 1568 htlc: &lnwire.UpdateAddHTLC{ 1569 PaymentHash: rhash, 1570 Amount: 1, 1571 }, 1572 } 1573 1574 if s.circuits.NumPending() != 0 { 1575 t.Fatalf("wrong amount of half circuits") 1576 } 1577 if s.circuits.NumOpen() != 0 { 1578 t.Fatalf("wrong amount of circuits") 1579 } 1580 1581 // Handle the request and checks that bob channel link received it. 1582 if err := s.ForwardPackets(nil, ogPacket); err != nil { 1583 t.Fatal(err) 1584 } 1585 1586 if s.circuits.NumPending() != 1 { 1587 t.Fatalf("wrong amount of half circuits") 1588 } 1589 if s.circuits.NumOpen() != 0 { 1590 t.Fatalf("wrong amount of half circuits") 1591 } 1592 1593 // Pull packet from bob's link, but do not perform a full add. 1594 select { 1595 case <-bobChannelLink.packets: 1596 case <-time.After(time.Second): 1597 t.Fatal("request was not propagated to destination") 1598 } 1599 1600 // Now we will restart bob, leaving the forwarding decision for this 1601 // htlc is in the half-added state. 
1602 if err := s.Stop(); err != nil { 1603 t.Fatal(err) 1604 } 1605 1606 if err := cdb.Close(); err != nil { 1607 t.Fatal(err) 1608 } 1609 1610 cdb2 := channeldb.OpenForTesting(t, tempPath) 1611 1612 s2, err := initSwitchWithDB(testStartingHeight, cdb2) 1613 require.NoError(t, err, "unable reinit switch") 1614 if err := s2.Start(); err != nil { 1615 t.Fatalf("unable to restart switch: %v", err) 1616 } 1617 1618 // Even though we intend to Stop s2 later in the test, it is safe to 1619 // defer this Stop since its execution it is protected by an atomic 1620 // guard, guaranteeing it executes at most once. 1621 defer s2.Stop() 1622 1623 aliceChannelLink = newMockChannelLink( 1624 s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 1625 false, false, 1626 ) 1627 bobChannelLink = newMockChannelLink( 1628 s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 1629 false, 1630 ) 1631 if err := s2.AddLink(aliceChannelLink); err != nil { 1632 t.Fatalf("unable to add alice link: %v", err) 1633 } 1634 if err := s2.AddLink(bobChannelLink); err != nil { 1635 t.Fatalf("unable to add bob link: %v", err) 1636 } 1637 1638 if s2.circuits.NumPending() != 1 { 1639 t.Fatalf("wrong amount of half circuits") 1640 } 1641 if s2.circuits.NumOpen() != 0 { 1642 t.Fatalf("wrong amount of half circuits") 1643 } 1644 1645 // Resend the failed htlc, it should be returned to alice since the 1646 // switch will detect that it has been half added previously. 1647 err = s2.ForwardPackets(nil, ogPacket) 1648 if err != nil { 1649 t.Fatal(err) 1650 } 1651 1652 // After detecting an incomplete forward, the fail packet should have 1653 // been returned to the sender. 
1654 select { 1655 case pkt := <-aliceChannelLink.packets: 1656 linkErr := pkt.linkFailure 1657 if linkErr.FailureDetail != OutgoingFailureIncompleteForward { 1658 t.Fatalf("expected incomplete forward, got: %v", 1659 linkErr.FailureDetail) 1660 } 1661 case <-time.After(time.Second): 1662 t.Fatal("request was not propagated to destination") 1663 } 1664 } 1665 1666 // TestSwitchForwardCircuitPersistence checks the ability of htlc switch to 1667 // maintain the proper entries in the circuit map in the face of restarts. 1668 func TestSwitchForwardCircuitPersistence(t *testing.T) { 1669 t.Parallel() 1670 1671 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 1672 1673 alicePeer, err := newMockServer( 1674 t, "alice", testStartingHeight, nil, testDefaultDelta, 1675 ) 1676 require.NoError(t, err, "unable to create alice server") 1677 bobPeer, err := newMockServer( 1678 t, "bob", testStartingHeight, nil, testDefaultDelta, 1679 ) 1680 require.NoError(t, err, "unable to create bob server") 1681 1682 tempPath := t.TempDir() 1683 1684 cdb := channeldb.OpenForTesting(t, tempPath) 1685 1686 s, err := initSwitchWithDB(testStartingHeight, cdb) 1687 require.NoError(t, err, "unable to init switch") 1688 if err := s.Start(); err != nil { 1689 t.Fatalf("unable to start switch: %v", err) 1690 } 1691 1692 // Even though we intend to Stop s later in the test, it is safe to 1693 // defer this Stop since its execution it is protected by an atomic 1694 // guard, guaranteeing it executes at most once. 
1695 defer s.Stop() 1696 1697 aliceChannelLink := newMockChannelLink( 1698 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 1699 false, false, 1700 ) 1701 bobChannelLink := newMockChannelLink( 1702 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 1703 false, 1704 ) 1705 if err := s.AddLink(aliceChannelLink); err != nil { 1706 t.Fatalf("unable to add alice link: %v", err) 1707 } 1708 if err := s.AddLink(bobChannelLink); err != nil { 1709 t.Fatalf("unable to add bob link: %v", err) 1710 } 1711 1712 // Create request which should be forwarded from Alice channel link to 1713 // bob channel link. 1714 preimage := [sha256.Size]byte{1} 1715 rhash := sha256.Sum256(preimage[:]) 1716 ogPacket := &htlcPacket{ 1717 incomingChanID: aliceChannelLink.ShortChanID(), 1718 incomingHTLCID: 0, 1719 outgoingChanID: bobChannelLink.ShortChanID(), 1720 obfuscator: NewMockObfuscator(), 1721 htlc: &lnwire.UpdateAddHTLC{ 1722 PaymentHash: rhash, 1723 Amount: 1, 1724 }, 1725 } 1726 1727 if s.circuits.NumPending() != 0 { 1728 t.Fatalf("wrong amount of half circuits") 1729 } 1730 if s.circuits.NumOpen() != 0 { 1731 t.Fatalf("wrong amount of circuits") 1732 } 1733 1734 // Handle the request and checks that bob channel link received it. 1735 if err := s.ForwardPackets(nil, ogPacket); err != nil { 1736 t.Fatal(err) 1737 } 1738 1739 if s.circuits.NumPending() != 1 { 1740 t.Fatalf("wrong amount of half circuits") 1741 } 1742 if s.circuits.NumOpen() != 0 { 1743 t.Fatalf("wrong amount of circuits") 1744 } 1745 1746 // Retrieve packet from outgoing link and cache until after restart. 
1747 var packet *htlcPacket 1748 select { 1749 case packet = <-bobChannelLink.packets: 1750 case <-time.After(time.Second): 1751 t.Fatal("request was not propagated to destination") 1752 } 1753 1754 if err := s.Stop(); err != nil { 1755 t.Fatal(err) 1756 } 1757 1758 if err := cdb.Close(); err != nil { 1759 t.Fatal(err) 1760 } 1761 1762 cdb2 := channeldb.OpenForTesting(t, tempPath) 1763 1764 s2, err := initSwitchWithDB(testStartingHeight, cdb2) 1765 require.NoError(t, err, "unable reinit switch") 1766 if err := s2.Start(); err != nil { 1767 t.Fatalf("unable to restart switch: %v", err) 1768 } 1769 1770 // Even though we intend to Stop s2 later in the test, it is safe to 1771 // defer this Stop since its execution it is protected by an atomic 1772 // guard, guaranteeing it executes at most once. 1773 defer s2.Stop() 1774 1775 aliceChannelLink = newMockChannelLink( 1776 s2, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 1777 false, false, 1778 ) 1779 bobChannelLink = newMockChannelLink( 1780 s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 1781 false, 1782 ) 1783 if err := s2.AddLink(aliceChannelLink); err != nil { 1784 t.Fatalf("unable to add alice link: %v", err) 1785 } 1786 if err := s2.AddLink(bobChannelLink); err != nil { 1787 t.Fatalf("unable to add bob link: %v", err) 1788 } 1789 1790 if s2.circuits.NumPending() != 1 { 1791 t.Fatalf("wrong amount of half circuits") 1792 } 1793 if s2.circuits.NumOpen() != 0 { 1794 t.Fatalf("wrong amount of half circuits") 1795 } 1796 1797 // Now that the switch has restarted, complete the payment circuit. 
1798 if err := bobChannelLink.completeCircuit(packet); err != nil { 1799 t.Fatalf("unable to complete payment circuit: %v", err) 1800 } 1801 1802 if s2.circuits.NumPending() != 1 { 1803 t.Fatalf("wrong amount of half circuits") 1804 } 1805 if s2.circuits.NumOpen() != 1 { 1806 t.Fatal("wrong amount of circuits") 1807 } 1808 1809 // Create settle request pretending that bob link handled the add htlc 1810 // request and sent the htlc settle request back. This request should 1811 // be forwarder back to Alice link. 1812 ogPacket = &htlcPacket{ 1813 outgoingChanID: bobChannelLink.ShortChanID(), 1814 outgoingHTLCID: 0, 1815 amount: 1, 1816 htlc: &lnwire.UpdateFulfillHTLC{ 1817 PaymentPreimage: preimage, 1818 }, 1819 } 1820 1821 // Handle the request and checks that payment circuit works properly. 1822 if err := s2.ForwardPackets(nil, ogPacket); err != nil { 1823 t.Fatal(err) 1824 } 1825 1826 select { 1827 case packet = <-aliceChannelLink.packets: 1828 if err := aliceChannelLink.completeCircuit(packet); err != nil { 1829 t.Fatalf("unable to complete circuit with in key=%s: %v", 1830 packet.inKey(), err) 1831 } 1832 case <-time.After(time.Second): 1833 t.Fatal("request was not propagated to channelPoint") 1834 } 1835 1836 if s2.circuits.NumPending() != 0 { 1837 t.Fatalf("wrong amount of half circuits, want 1, got %d", 1838 s2.circuits.NumPending()) 1839 } 1840 if s2.circuits.NumOpen() != 0 { 1841 t.Fatal("wrong amount of circuits") 1842 } 1843 1844 if err := s2.Stop(); err != nil { 1845 t.Fatal(err) 1846 } 1847 1848 if err := cdb2.Close(); err != nil { 1849 t.Fatal(err) 1850 } 1851 1852 cdb3 := channeldb.OpenForTesting(t, tempPath) 1853 1854 s3, err := initSwitchWithDB(testStartingHeight, cdb3) 1855 require.NoError(t, err, "unable reinit switch") 1856 if err := s3.Start(); err != nil { 1857 t.Fatalf("unable to restart switch: %v", err) 1858 } 1859 defer s3.Stop() 1860 1861 aliceChannelLink = newMockChannelLink( 1862 s3, chanID1, aliceChanID, emptyScid, alicePeer, true, 
false, 1863 false, false, 1864 ) 1865 bobChannelLink = newMockChannelLink( 1866 s3, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 1867 false, 1868 ) 1869 if err := s3.AddLink(aliceChannelLink); err != nil { 1870 t.Fatalf("unable to add alice link: %v", err) 1871 } 1872 if err := s3.AddLink(bobChannelLink); err != nil { 1873 t.Fatalf("unable to add bob link: %v", err) 1874 } 1875 1876 if s3.circuits.NumPending() != 0 { 1877 t.Fatalf("wrong amount of half circuits") 1878 } 1879 if s3.circuits.NumOpen() != 0 { 1880 t.Fatalf("wrong amount of circuits") 1881 } 1882 } 1883 1884 type multiHopFwdTest struct { 1885 name string 1886 eligible1, eligible2 bool 1887 failure1, failure2 *LinkError 1888 expectedReply lnwire.FailCode 1889 } 1890 1891 // TestCircularForwards tests the allowing/disallowing of circular payments 1892 // through the same channel in the case where the switch is configured to allow 1893 // and disallow same channel circular forwards. 1894 func TestCircularForwards(t *testing.T) { 1895 chanID1, aliceChanID := genID() 1896 preimage := [sha256.Size]byte{1} 1897 hash := sha256.Sum256(preimage[:]) 1898 1899 tests := []struct { 1900 name string 1901 allowCircularPayment bool 1902 expectedErr error 1903 }{ 1904 { 1905 name: "circular payment allowed", 1906 allowCircularPayment: true, 1907 expectedErr: nil, 1908 }, 1909 { 1910 name: "circular payment disallowed", 1911 allowCircularPayment: false, 1912 expectedErr: NewDetailedLinkError( 1913 lnwire.NewTemporaryChannelFailure(nil), 1914 OutgoingFailureCircularRoute, 1915 ), 1916 }, 1917 } 1918 1919 for _, test := range tests { 1920 test := test 1921 t.Run(test.name, func(t *testing.T) { 1922 t.Parallel() 1923 1924 alicePeer, err := newMockServer( 1925 t, "alice", testStartingHeight, nil, 1926 testDefaultDelta, 1927 ) 1928 if err != nil { 1929 t.Fatalf("unable to create alice server: %v", 1930 err) 1931 } 1932 1933 s, err := initSwitchWithTempDB(t, testStartingHeight) 1934 if err != nil { 1935 
t.Fatalf("unable to init switch: %v", err) 1936 } 1937 if err := s.Start(); err != nil { 1938 t.Fatalf("unable to start switch: %v", err) 1939 } 1940 defer func() { _ = s.Stop() }() 1941 1942 // Set the switch to allow or disallow circular routes 1943 // according to the test's requirements. 1944 s.cfg.AllowCircularRoute = test.allowCircularPayment 1945 1946 aliceChannelLink := newMockChannelLink( 1947 s, chanID1, aliceChanID, emptyScid, alicePeer, 1948 true, false, false, false, 1949 ) 1950 1951 if err := s.AddLink(aliceChannelLink); err != nil { 1952 t.Fatalf("unable to add alice link: %v", err) 1953 } 1954 1955 // Create a new packet that loops through alice's link 1956 // in a circle. 1957 obfuscator := NewMockObfuscator() 1958 packet := &htlcPacket{ 1959 incomingChanID: aliceChannelLink.ShortChanID(), 1960 outgoingChanID: aliceChannelLink.ShortChanID(), 1961 htlc: &lnwire.UpdateAddHTLC{ 1962 PaymentHash: hash, 1963 Amount: 1, 1964 }, 1965 obfuscator: obfuscator, 1966 } 1967 1968 // Attempt to forward the packet and check for the expected 1969 // error. 1970 if err = s.ForwardPackets(nil, packet); err != nil { 1971 t.Fatal(err) 1972 } 1973 select { 1974 case p := <-aliceChannelLink.packets: 1975 if p.linkFailure != nil { 1976 err = p.linkFailure 1977 } 1978 case <-time.After(time.Second): 1979 t.Fatal("no timely reply from switch") 1980 } 1981 if !reflect.DeepEqual(err, test.expectedErr) { 1982 t.Fatalf("expected: %v, got: %v", 1983 test.expectedErr, err) 1984 } 1985 1986 // Ensure that no circuits were opened. 1987 if s.circuits.NumOpen() > 0 { 1988 t.Fatal("do not expect any open circuits") 1989 } 1990 }) 1991 } 1992 } 1993 1994 // TestCheckCircularForward tests the error returned by checkCircularForward 1995 // in cases where we allow and disallow same channel circular forwards. 
1996 func TestCheckCircularForward(t *testing.T) { 1997 tests := []struct { 1998 name string 1999 2000 // aliasMapping determines whether the test should add an alias 2001 // mapping to Switch alias maps before checkCircularForward. 2002 aliasMapping bool 2003 2004 // allowCircular determines whether we should allow circular 2005 // forwards. 2006 allowCircular bool 2007 2008 // incomingLink is the link that the htlc arrived on. 2009 incomingLink lnwire.ShortChannelID 2010 2011 // outgoingLink is the link that the htlc forward 2012 // is destined to leave on. 2013 outgoingLink lnwire.ShortChannelID 2014 2015 // expectedErr is the error we expect to be returned. 2016 expectedErr *LinkError 2017 }{ 2018 { 2019 name: "not circular, allowed in config", 2020 aliasMapping: false, 2021 allowCircular: true, 2022 incomingLink: lnwire.NewShortChanIDFromInt(123), 2023 outgoingLink: lnwire.NewShortChanIDFromInt(321), 2024 expectedErr: nil, 2025 }, 2026 { 2027 name: "not circular, not allowed in config", 2028 aliasMapping: false, 2029 allowCircular: false, 2030 incomingLink: lnwire.NewShortChanIDFromInt(123), 2031 outgoingLink: lnwire.NewShortChanIDFromInt(321), 2032 expectedErr: nil, 2033 }, 2034 { 2035 name: "circular, allowed in config", 2036 aliasMapping: false, 2037 allowCircular: true, 2038 incomingLink: lnwire.NewShortChanIDFromInt(123), 2039 outgoingLink: lnwire.NewShortChanIDFromInt(123), 2040 expectedErr: nil, 2041 }, 2042 { 2043 name: "circular, not allowed in config", 2044 aliasMapping: false, 2045 allowCircular: false, 2046 incomingLink: lnwire.NewShortChanIDFromInt(123), 2047 outgoingLink: lnwire.NewShortChanIDFromInt(123), 2048 expectedErr: NewDetailedLinkError( 2049 lnwire.NewTemporaryChannelFailure(nil), 2050 OutgoingFailureCircularRoute, 2051 ), 2052 }, 2053 { 2054 name: "circular with map, not allowed", 2055 aliasMapping: true, 2056 allowCircular: false, 2057 incomingLink: lnwire.NewShortChanIDFromInt(1 << 60), 2058 outgoingLink: 
lnwire.NewShortChanIDFromInt(1 << 55), 2059 expectedErr: NewDetailedLinkError( 2060 lnwire.NewTemporaryChannelFailure(nil), 2061 OutgoingFailureCircularRoute, 2062 ), 2063 }, 2064 { 2065 name: "circular with map, not allowed 2", 2066 aliasMapping: true, 2067 allowCircular: false, 2068 incomingLink: lnwire.NewShortChanIDFromInt(1 << 55), 2069 outgoingLink: lnwire.NewShortChanIDFromInt(1 << 60), 2070 expectedErr: NewDetailedLinkError( 2071 lnwire.NewTemporaryChannelFailure(nil), 2072 OutgoingFailureCircularRoute, 2073 ), 2074 }, 2075 { 2076 name: "circular with map, allowed", 2077 aliasMapping: true, 2078 allowCircular: true, 2079 incomingLink: lnwire.NewShortChanIDFromInt(1 << 60), 2080 outgoingLink: lnwire.NewShortChanIDFromInt(1 << 55), 2081 expectedErr: nil, 2082 }, 2083 { 2084 name: "circular with map, allowed 2", 2085 aliasMapping: true, 2086 allowCircular: true, 2087 incomingLink: lnwire.NewShortChanIDFromInt(1 << 55), 2088 outgoingLink: lnwire.NewShortChanIDFromInt(1 << 61), 2089 expectedErr: nil, 2090 }, 2091 { 2092 name: "not circular, both confirmed SCID", 2093 aliasMapping: false, 2094 allowCircular: false, 2095 incomingLink: lnwire.NewShortChanIDFromInt(1 << 60), 2096 outgoingLink: lnwire.NewShortChanIDFromInt(1 << 61), 2097 expectedErr: nil, 2098 }, 2099 } 2100 2101 for _, test := range tests { 2102 test := test 2103 2104 t.Run(test.name, func(t *testing.T) { 2105 t.Parallel() 2106 2107 s, err := initSwitchWithTempDB(t, testStartingHeight) 2108 require.NoError(t, err) 2109 err = s.Start() 2110 require.NoError(t, err) 2111 defer func() { _ = s.Stop() }() 2112 2113 if test.aliasMapping { 2114 // Make the incoming and outgoing point to the 2115 // same base SCID. 
2116 inScid := test.incomingLink 2117 outScid := test.outgoingLink 2118 s.indexMtx.Lock() 2119 s.baseIndex[inScid] = outScid 2120 s.baseIndex[outScid] = outScid 2121 s.indexMtx.Unlock() 2122 } 2123 2124 // Check for a circular forward, the hash passed can 2125 // be nil because it is only used for logging. 2126 err = s.checkCircularForward( 2127 test.incomingLink, test.outgoingLink, 2128 test.allowCircular, lntypes.Hash{}, 2129 ) 2130 if !reflect.DeepEqual(err, test.expectedErr) { 2131 t.Fatalf("expected: %v, got: %v", 2132 test.expectedErr, err) 2133 } 2134 }) 2135 } 2136 } 2137 2138 // TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes 2139 // along, then we won't attempt to forward it down al ink that isn't yet able 2140 // to forward any HTLC's. 2141 func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { 2142 tests := []multiHopFwdTest{ 2143 // None of the channels is eligible. 2144 { 2145 name: "not eligible", 2146 expectedReply: lnwire.CodeUnknownNextPeer, 2147 }, 2148 2149 // Channel one has a policy failure and the other channel isn't 2150 // available. 2151 { 2152 name: "policy fail", 2153 eligible1: true, 2154 failure1: NewLinkError( 2155 lnwire.NewFinalIncorrectCltvExpiry(0), 2156 ), 2157 expectedReply: lnwire.CodeFinalIncorrectCltvExpiry, 2158 }, 2159 2160 // The requested channel is not eligible, but the packet is 2161 // forwarded through the other channel. 2162 { 2163 name: "non-strict success", 2164 eligible2: true, 2165 expectedReply: lnwire.CodeNone, 2166 }, 2167 2168 // The requested channel has insufficient bandwidth and the 2169 // other channel's policy isn't satisfied. 
2170 { 2171 name: "non-strict policy fail", 2172 eligible1: true, 2173 failure1: NewDetailedLinkError( 2174 lnwire.NewTemporaryChannelFailure(nil), 2175 OutgoingFailureInsufficientBalance, 2176 ), 2177 eligible2: true, 2178 failure2: NewLinkError( 2179 lnwire.NewFinalIncorrectCltvExpiry(0), 2180 ), 2181 expectedReply: lnwire.CodeTemporaryChannelFailure, 2182 }, 2183 } 2184 2185 for _, test := range tests { 2186 test := test 2187 t.Run(test.name, func(t *testing.T) { 2188 testSkipIneligibleLinksMultiHopForward(t, &test) 2189 }) 2190 } 2191 } 2192 2193 // testSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes 2194 // along, then we won't attempt to forward it down al ink that isn't yet able 2195 // to forward any HTLC's. 2196 func testSkipIneligibleLinksMultiHopForward(t *testing.T, 2197 testCase *multiHopFwdTest) { 2198 2199 t.Parallel() 2200 2201 var packet *htlcPacket 2202 2203 alicePeer, err := newMockServer( 2204 t, "alice", testStartingHeight, nil, testDefaultDelta, 2205 ) 2206 require.NoError(t, err, "unable to create alice server") 2207 bobPeer, err := newMockServer( 2208 t, "bob", testStartingHeight, nil, testDefaultDelta, 2209 ) 2210 require.NoError(t, err, "unable to create bob server") 2211 2212 s, err := initSwitchWithTempDB(t, testStartingHeight) 2213 require.NoError(t, err, "unable to init switch") 2214 if err := s.Start(); err != nil { 2215 t.Fatalf("unable to start switch: %v", err) 2216 } 2217 defer s.Stop() 2218 2219 chanID1, aliceChanID := genID() 2220 aliceChannelLink := newMockChannelLink( 2221 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 2222 false, false, 2223 ) 2224 2225 // We'll create a link for Bob, but mark the link as unable to forward 2226 // any new outgoing HTLC's. 
2227 chanID2, bobChanID2 := genID() 2228 bobChannelLink1 := newMockChannelLink( 2229 s, chanID2, bobChanID2, emptyScid, bobPeer, testCase.eligible1, 2230 false, false, false, 2231 ) 2232 bobChannelLink1.checkHtlcForwardResult = testCase.failure1 2233 2234 chanID3, bobChanID3 := genID() 2235 bobChannelLink2 := newMockChannelLink( 2236 s, chanID3, bobChanID3, emptyScid, bobPeer, testCase.eligible2, 2237 false, false, false, 2238 ) 2239 bobChannelLink2.checkHtlcForwardResult = testCase.failure2 2240 2241 if err := s.AddLink(aliceChannelLink); err != nil { 2242 t.Fatalf("unable to add alice link: %v", err) 2243 } 2244 if err := s.AddLink(bobChannelLink1); err != nil { 2245 t.Fatalf("unable to add bob link: %v", err) 2246 } 2247 if err := s.AddLink(bobChannelLink2); err != nil { 2248 t.Fatalf("unable to add bob link: %v", err) 2249 } 2250 2251 // Create a new packet that's destined for Bob as an incoming HTLC from 2252 // Alice. 2253 preimage := [sha256.Size]byte{1} 2254 rhash := sha256.Sum256(preimage[:]) 2255 obfuscator := NewMockObfuscator() 2256 packet = &htlcPacket{ 2257 incomingChanID: aliceChannelLink.ShortChanID(), 2258 incomingHTLCID: 0, 2259 outgoingChanID: bobChannelLink1.ShortChanID(), 2260 htlc: &lnwire.UpdateAddHTLC{ 2261 PaymentHash: rhash, 2262 Amount: 1, 2263 }, 2264 obfuscator: obfuscator, 2265 } 2266 2267 // The request to forward should fail as 2268 if err := s.ForwardPackets(nil, packet); err != nil { 2269 t.Fatal(err) 2270 } 2271 2272 // We select from all links and extract the error if exists. 2273 // The packet must be selected but we don't always expect a link error. 
2274 var linkError *LinkError 2275 select { 2276 case p := <-aliceChannelLink.packets: 2277 linkError = p.linkFailure 2278 case p := <-bobChannelLink1.packets: 2279 linkError = p.linkFailure 2280 case p := <-bobChannelLink2.packets: 2281 linkError = p.linkFailure 2282 case <-time.After(time.Second): 2283 t.Fatal("no timely reply from switch") 2284 } 2285 failure := obfuscator.(*mockObfuscator).failure 2286 if testCase.expectedReply == lnwire.CodeNone { 2287 if linkError != nil { 2288 t.Fatalf("forwarding should have succeeded") 2289 } 2290 if failure != nil { 2291 t.Fatalf("unexpected failure %T", failure) 2292 } 2293 } else { 2294 if linkError == nil { 2295 t.Fatalf("forwarding should have failed due to " + 2296 "inactive link") 2297 } 2298 if failure.Code() != testCase.expectedReply { 2299 t.Fatalf("unexpected failure %T", failure) 2300 } 2301 } 2302 2303 if s.circuits.NumOpen() != 0 { 2304 t.Fatal("wrong amount of circuits") 2305 } 2306 } 2307 2308 // TestSkipIneligibleLinksLocalForward ensures that the switch will not attempt 2309 // to forward any HTLC's down a link that isn't yet eligible for forwarding. 2310 func TestSkipIneligibleLinksLocalForward(t *testing.T) { 2311 t.Parallel() 2312 2313 testSkipLinkLocalForward(t, false, nil) 2314 } 2315 2316 // TestSkipPolicyUnsatisfiedLinkLocalForward ensures that the switch will not 2317 // attempt to send locally initiated HTLCs that would violate the channel policy 2318 // down a link. 2319 func TestSkipPolicyUnsatisfiedLinkLocalForward(t *testing.T) { 2320 t.Parallel() 2321 2322 testSkipLinkLocalForward(t, true, lnwire.NewTemporaryChannelFailure(nil)) 2323 } 2324 2325 func testSkipLinkLocalForward(t *testing.T, eligible bool, 2326 policyResult lnwire.FailureMessage) { 2327 2328 // We'll create a single link for this test, marking it as being unable 2329 // to forward form the get go. 
2330 alicePeer, err := newMockServer( 2331 t, "alice", testStartingHeight, nil, testDefaultDelta, 2332 ) 2333 require.NoError(t, err, "unable to create alice server") 2334 2335 s, err := initSwitchWithTempDB(t, testStartingHeight) 2336 require.NoError(t, err, "unable to init switch") 2337 if err := s.Start(); err != nil { 2338 t.Fatalf("unable to start switch: %v", err) 2339 } 2340 defer s.Stop() 2341 2342 chanID1, _, aliceChanID, _ := genIDs() 2343 2344 aliceChannelLink := newMockChannelLink( 2345 s, chanID1, aliceChanID, emptyScid, alicePeer, eligible, false, 2346 false, false, 2347 ) 2348 aliceChannelLink.checkHtlcTransitResult = NewLinkError( 2349 policyResult, 2350 ) 2351 if err := s.AddLink(aliceChannelLink); err != nil { 2352 t.Fatalf("unable to add alice link: %v", err) 2353 } 2354 2355 preimage, err := genPreimage() 2356 require.NoError(t, err, "unable to generate preimage") 2357 rhash := sha256.Sum256(preimage[:]) 2358 addMsg := &lnwire.UpdateAddHTLC{ 2359 PaymentHash: rhash, 2360 Amount: 1, 2361 } 2362 2363 // We'll attempt to send out a new HTLC that has Alice as the first 2364 // outgoing link. This should fail as Alice isn't yet able to forward 2365 // any active HTLC's. 2366 err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg) 2367 if err == nil { 2368 t.Fatalf("local forward should fail due to inactive link") 2369 } 2370 2371 if s.circuits.NumOpen() != 0 { 2372 t.Fatal("wrong amount of circuits") 2373 } 2374 } 2375 2376 // TestSwitchCancel checks that if htlc was rejected we remove unused 2377 // circuits. 
2378 func TestSwitchCancel(t *testing.T) { 2379 t.Parallel() 2380 2381 alicePeer, err := newMockServer( 2382 t, "alice", testStartingHeight, nil, testDefaultDelta, 2383 ) 2384 require.NoError(t, err, "unable to create alice server") 2385 bobPeer, err := newMockServer( 2386 t, "bob", testStartingHeight, nil, testDefaultDelta, 2387 ) 2388 require.NoError(t, err, "unable to create bob server") 2389 2390 s, err := initSwitchWithTempDB(t, testStartingHeight) 2391 require.NoError(t, err, "unable to init switch") 2392 if err := s.Start(); err != nil { 2393 t.Fatalf("unable to start switch: %v", err) 2394 } 2395 defer s.Stop() 2396 2397 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 2398 2399 aliceChannelLink := newMockChannelLink( 2400 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 2401 false, false, 2402 ) 2403 bobChannelLink := newMockChannelLink( 2404 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 2405 false, 2406 ) 2407 if err := s.AddLink(aliceChannelLink); err != nil { 2408 t.Fatalf("unable to add alice link: %v", err) 2409 } 2410 if err := s.AddLink(bobChannelLink); err != nil { 2411 t.Fatalf("unable to add bob link: %v", err) 2412 } 2413 2414 // Create request which should be forwarder from alice channel link 2415 // to bob channel link. 2416 preimage, err := genPreimage() 2417 require.NoError(t, err, "unable to generate preimage") 2418 rhash := sha256.Sum256(preimage[:]) 2419 request := &htlcPacket{ 2420 incomingChanID: aliceChannelLink.ShortChanID(), 2421 incomingHTLCID: 0, 2422 outgoingChanID: bobChannelLink.ShortChanID(), 2423 obfuscator: NewMockObfuscator(), 2424 htlc: &lnwire.UpdateAddHTLC{ 2425 PaymentHash: rhash, 2426 Amount: 1, 2427 }, 2428 } 2429 2430 // Handle the request and checks that bob channel link received it. 
2431 if err := s.ForwardPackets(nil, request); err != nil { 2432 t.Fatal(err) 2433 } 2434 2435 select { 2436 case packet := <-bobChannelLink.packets: 2437 if err := bobChannelLink.completeCircuit(packet); err != nil { 2438 t.Fatalf("unable to complete payment circuit: %v", err) 2439 } 2440 2441 case <-time.After(time.Second): 2442 t.Fatal("request was not propagated to destination") 2443 } 2444 2445 if s.circuits.NumPending() != 1 { 2446 t.Fatalf("wrong amount of half circuits") 2447 } 2448 if s.circuits.NumOpen() != 1 { 2449 t.Fatal("wrong amount of circuits") 2450 } 2451 2452 // Create settle request pretending that bob channel link handled 2453 // the add htlc request and sent the htlc settle request back. This 2454 // request should be forwarder back to alice channel link. 2455 request = &htlcPacket{ 2456 outgoingChanID: bobChannelLink.ShortChanID(), 2457 outgoingHTLCID: 0, 2458 amount: 1, 2459 htlc: &lnwire.UpdateFailHTLC{}, 2460 } 2461 2462 // Handle the request and checks that payment circuit works properly. 2463 if err := s.ForwardPackets(nil, request); err != nil { 2464 t.Fatal(err) 2465 } 2466 2467 select { 2468 case pkt := <-aliceChannelLink.packets: 2469 if err := aliceChannelLink.completeCircuit(pkt); err != nil { 2470 t.Fatalf("unable to remove circuit: %v", err) 2471 } 2472 2473 case <-time.After(time.Second): 2474 t.Fatal("request was not propagated to channelPoint") 2475 } 2476 2477 if s.circuits.NumPending() != 0 { 2478 t.Fatal("wrong amount of circuits") 2479 } 2480 if s.circuits.NumOpen() != 0 { 2481 t.Fatal("wrong amount of circuits") 2482 } 2483 } 2484 2485 // TestSwitchAddSamePayment tests that we send the payment with the same 2486 // payment hash. 
2487 func TestSwitchAddSamePayment(t *testing.T) { 2488 t.Parallel() 2489 2490 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 2491 2492 alicePeer, err := newMockServer( 2493 t, "alice", testStartingHeight, nil, testDefaultDelta, 2494 ) 2495 require.NoError(t, err, "unable to create alice server") 2496 bobPeer, err := newMockServer( 2497 t, "bob", testStartingHeight, nil, testDefaultDelta, 2498 ) 2499 require.NoError(t, err, "unable to create bob server") 2500 2501 s, err := initSwitchWithTempDB(t, testStartingHeight) 2502 require.NoError(t, err, "unable to init switch") 2503 if err := s.Start(); err != nil { 2504 t.Fatalf("unable to start switch: %v", err) 2505 } 2506 defer s.Stop() 2507 2508 aliceChannelLink := newMockChannelLink( 2509 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 2510 false, false, 2511 ) 2512 bobChannelLink := newMockChannelLink( 2513 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 2514 false, 2515 ) 2516 if err := s.AddLink(aliceChannelLink); err != nil { 2517 t.Fatalf("unable to add alice link: %v", err) 2518 } 2519 if err := s.AddLink(bobChannelLink); err != nil { 2520 t.Fatalf("unable to add bob link: %v", err) 2521 } 2522 2523 // Create request which should be forwarder from alice channel link 2524 // to bob channel link. 2525 preimage, err := genPreimage() 2526 require.NoError(t, err, "unable to generate preimage") 2527 rhash := sha256.Sum256(preimage[:]) 2528 request := &htlcPacket{ 2529 incomingChanID: aliceChannelLink.ShortChanID(), 2530 incomingHTLCID: 0, 2531 outgoingChanID: bobChannelLink.ShortChanID(), 2532 obfuscator: NewMockObfuscator(), 2533 htlc: &lnwire.UpdateAddHTLC{ 2534 PaymentHash: rhash, 2535 Amount: 1, 2536 }, 2537 } 2538 2539 // Handle the request and checks that bob channel link received it. 
2540 if err := s.ForwardPackets(nil, request); err != nil { 2541 t.Fatal(err) 2542 } 2543 2544 select { 2545 case packet := <-bobChannelLink.packets: 2546 if err := bobChannelLink.completeCircuit(packet); err != nil { 2547 t.Fatalf("unable to complete payment circuit: %v", err) 2548 } 2549 2550 case <-time.After(time.Second): 2551 t.Fatal("request was not propagated to destination") 2552 } 2553 2554 if s.circuits.NumOpen() != 1 { 2555 t.Fatal("wrong amount of circuits") 2556 } 2557 2558 request = &htlcPacket{ 2559 incomingChanID: aliceChannelLink.ShortChanID(), 2560 incomingHTLCID: 1, 2561 outgoingChanID: bobChannelLink.ShortChanID(), 2562 obfuscator: NewMockObfuscator(), 2563 htlc: &lnwire.UpdateAddHTLC{ 2564 PaymentHash: rhash, 2565 Amount: 1, 2566 }, 2567 } 2568 2569 // Handle the request and checks that bob channel link received it. 2570 if err := s.ForwardPackets(nil, request); err != nil { 2571 t.Fatal(err) 2572 } 2573 2574 select { 2575 case packet := <-bobChannelLink.packets: 2576 if err := bobChannelLink.completeCircuit(packet); err != nil { 2577 t.Fatalf("unable to complete payment circuit: %v", err) 2578 } 2579 2580 case <-time.After(time.Second): 2581 t.Fatal("request was not propagated to destination") 2582 } 2583 2584 if s.circuits.NumOpen() != 2 { 2585 t.Fatal("wrong amount of circuits") 2586 } 2587 2588 // Create settle request pretending that bob channel link handled 2589 // the add htlc request and sent the htlc settle request back. This 2590 // request should be forwarder back to alice channel link. 2591 request = &htlcPacket{ 2592 outgoingChanID: bobChannelLink.ShortChanID(), 2593 outgoingHTLCID: 0, 2594 amount: 1, 2595 htlc: &lnwire.UpdateFailHTLC{}, 2596 } 2597 2598 // Handle the request and checks that payment circuit works properly. 
2599 if err := s.ForwardPackets(nil, request); err != nil { 2600 t.Fatal(err) 2601 } 2602 2603 select { 2604 case pkt := <-aliceChannelLink.packets: 2605 if err := aliceChannelLink.completeCircuit(pkt); err != nil { 2606 t.Fatalf("unable to remove circuit: %v", err) 2607 } 2608 2609 case <-time.After(time.Second): 2610 t.Fatal("request was not propagated to channelPoint") 2611 } 2612 2613 if s.circuits.NumOpen() != 1 { 2614 t.Fatal("wrong amount of circuits") 2615 } 2616 2617 request = &htlcPacket{ 2618 outgoingChanID: bobChannelLink.ShortChanID(), 2619 outgoingHTLCID: 1, 2620 amount: 1, 2621 htlc: &lnwire.UpdateFailHTLC{}, 2622 } 2623 2624 // Handle the request and checks that payment circuit works properly. 2625 if err := s.ForwardPackets(nil, request); err != nil { 2626 t.Fatal(err) 2627 } 2628 2629 select { 2630 case pkt := <-aliceChannelLink.packets: 2631 if err := aliceChannelLink.completeCircuit(pkt); err != nil { 2632 t.Fatalf("unable to remove circuit: %v", err) 2633 } 2634 2635 case <-time.After(time.Second): 2636 t.Fatal("request was not propagated to channelPoint") 2637 } 2638 2639 if s.circuits.NumOpen() != 0 { 2640 t.Fatal("wrong amount of circuits") 2641 } 2642 } 2643 2644 // TestSwitchSendPayment tests ability of htlc switch to respond to the 2645 // users when response is came back from channel link. 
2646 func TestSwitchSendPayment(t *testing.T) { 2647 t.Parallel() 2648 2649 alicePeer, err := newMockServer( 2650 t, "alice", testStartingHeight, nil, testDefaultDelta, 2651 ) 2652 require.NoError(t, err, "unable to create alice server") 2653 2654 s, err := initSwitchWithTempDB(t, testStartingHeight) 2655 require.NoError(t, err, "unable to init switch") 2656 if err := s.Start(); err != nil { 2657 t.Fatalf("unable to start switch: %v", err) 2658 } 2659 defer s.Stop() 2660 2661 chanID1, _, aliceChanID, _ := genIDs() 2662 2663 aliceChannelLink := newMockChannelLink( 2664 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 2665 false, false, 2666 ) 2667 if err := s.AddLink(aliceChannelLink); err != nil { 2668 t.Fatalf("unable to add link: %v", err) 2669 } 2670 2671 // Create request which should be forwarder from alice channel link 2672 // to bob channel link. 2673 preimage, err := genPreimage() 2674 require.NoError(t, err, "unable to generate preimage") 2675 rhash := sha256.Sum256(preimage[:]) 2676 update := &lnwire.UpdateAddHTLC{ 2677 PaymentHash: rhash, 2678 Amount: 1, 2679 } 2680 paymentID := uint64(123) 2681 2682 // First check that the switch will correctly respond that this payment 2683 // ID is unknown. 2684 _, err = s.GetAttemptResult( 2685 paymentID, rhash, newMockDeobfuscator(), 2686 ) 2687 if err != ErrPaymentIDNotFound { 2688 t.Fatalf("expected ErrPaymentIDNotFound, got %v", err) 2689 } 2690 2691 // Handle the request and checks that bob channel link received it. 
2692 errChan := make(chan error) 2693 go func() { 2694 err := s.SendHTLC( 2695 aliceChannelLink.ShortChanID(), paymentID, update, 2696 ) 2697 if err != nil { 2698 errChan <- err 2699 return 2700 } 2701 2702 resultChan, err := s.GetAttemptResult( 2703 paymentID, rhash, newMockDeobfuscator(), 2704 ) 2705 if err != nil { 2706 errChan <- err 2707 return 2708 } 2709 2710 result, ok := <-resultChan 2711 if !ok { 2712 errChan <- fmt.Errorf("shutting down") 2713 } 2714 2715 if result.Error != nil { 2716 errChan <- result.Error 2717 return 2718 } 2719 2720 errChan <- nil 2721 }() 2722 2723 select { 2724 case packet := <-aliceChannelLink.packets: 2725 if err := aliceChannelLink.completeCircuit(packet); err != nil { 2726 t.Fatalf("unable to complete payment circuit: %v", err) 2727 } 2728 2729 case err := <-errChan: 2730 if err != nil { 2731 t.Fatalf("unable to send payment: %v", err) 2732 } 2733 case <-time.After(time.Second): 2734 t.Fatal("request was not propagated to destination") 2735 } 2736 2737 if s.circuits.NumOpen() != 1 { 2738 t.Fatal("wrong amount of circuits") 2739 } 2740 2741 // Create fail request pretending that bob channel link handled 2742 // the add htlc request with error and sent the htlc fail request 2743 // back. This request should be forwarded back to alice channel link. 
2744 obfuscator := NewMockObfuscator() 2745 failure := lnwire.NewFailIncorrectDetails(update.Amount, 100) 2746 reason, err := obfuscator.EncryptFirstHop(failure) 2747 require.NoError(t, err, "unable obfuscate failure") 2748 2749 if s.IsForwardedHTLC(aliceChannelLink.ShortChanID(), update.ID) { 2750 t.Fatal("htlc should be identified as not forwarded") 2751 } 2752 packet := &htlcPacket{ 2753 outgoingChanID: aliceChannelLink.ShortChanID(), 2754 outgoingHTLCID: 0, 2755 amount: 1, 2756 htlc: &lnwire.UpdateFailHTLC{ 2757 Reason: reason, 2758 }, 2759 } 2760 2761 if err := s.ForwardPackets(nil, packet); err != nil { 2762 t.Fatalf("can't forward htlc packet: %v", err) 2763 } 2764 2765 select { 2766 case err := <-errChan: 2767 assertFailureCode( 2768 t, err, lnwire.CodeIncorrectOrUnknownPaymentDetails, 2769 ) 2770 case <-time.After(time.Second): 2771 t.Fatal("err wasn't received") 2772 } 2773 } 2774 2775 // TestLocalPaymentNoForwardingEvents tests that if we send a series of locally 2776 // initiated payments, then they aren't reflected in the forwarding log. 2777 func TestLocalPaymentNoForwardingEvents(t *testing.T) { 2778 t.Parallel() 2779 2780 // First, we'll create our traditional three hop network. We'll only be 2781 // interacting with and asserting the state of the first end point for 2782 // this test. 2783 channels, _, err := createClusterChannels( 2784 t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5, 2785 ) 2786 require.NoError(t, err, "unable to create channel") 2787 2788 n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, 2789 channels.bobToCarol, channels.carolToBob, testStartingHeight) 2790 if err := n.start(); err != nil { 2791 t.Fatalf("unable to start three hop network: %v", err) 2792 } 2793 2794 // We'll now craft and send a payment from Alice to Bob. 
2795 amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) 2796 htlcAmt, totalTimelock, hops := generateHops( 2797 amount, testStartingHeight, n.firstBobChannelLink, 2798 ) 2799 2800 // With the payment crafted, we'll send it from Alice to Bob. We'll 2801 // wait for Alice to receive the preimage for the payment before 2802 // proceeding. 2803 receiver := n.bobServer 2804 firstHop := n.firstBobChannelLink.ShortChanID() 2805 _, err = makePayment( 2806 n.aliceServer, receiver, firstHop, hops, amount, htlcAmt, 2807 totalTimelock, 2808 ).Wait(30 * time.Second) 2809 require.NoError(t, err, "unable to make the payment") 2810 2811 // At this point, we'll forcibly stop the three hop network. Doing 2812 // this will cause any pending forwarding events to be flushed by the 2813 // various switches in the network. 2814 n.stop() 2815 2816 // With all the switches stopped, we'll fetch Alice's mock forwarding 2817 // event log. 2818 log, ok := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog) 2819 if !ok { 2820 t.Fatalf("mockForwardingLog assertion failed") 2821 } 2822 log.Lock() 2823 defer log.Unlock() 2824 2825 // If we examine the memory of the forwarding log, then it should be 2826 // blank. 2827 if len(log.events) != 0 { 2828 t.Fatalf("log should have no events, instead has: %v", 2829 spew.Sdump(log.events)) 2830 } 2831 } 2832 2833 // TestMultiHopPaymentForwardingEvents tests that if we send a series of 2834 // multi-hop payments via Alice->Bob->Carol. Then Bob properly logs forwarding 2835 // events, while Alice and Carol don't. 2836 func TestMultiHopPaymentForwardingEvents(t *testing.T) { 2837 t.Parallel() 2838 2839 // First, we'll create our traditional three hop network. 
2840 channels, _, err := createClusterChannels( 2841 t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5, 2842 ) 2843 require.NoError(t, err, "unable to create channel") 2844 2845 n := newThreeHopNetwork(t, channels.aliceToBob, channels.bobToAlice, 2846 channels.bobToCarol, channels.carolToBob, testStartingHeight) 2847 if err := n.start(); err != nil { 2848 t.Fatalf("unable to start three hop network: %v", err) 2849 } 2850 2851 // We'll make now 10 payments, of 100k satoshis each from Alice to 2852 // Carol via Bob. 2853 const numPayments = 10 2854 finalAmt := lnwire.NewMSatFromSatoshis(100000) 2855 htlcAmt, totalTimelock, hops := generateHops( 2856 finalAmt, testStartingHeight, n.firstBobChannelLink, 2857 n.carolChannelLink, 2858 ) 2859 firstHop := n.firstBobChannelLink.ShortChanID() 2860 for i := 0; i < numPayments/2; i++ { 2861 _, err := makePayment( 2862 n.aliceServer, n.carolServer, firstHop, hops, finalAmt, 2863 htlcAmt, totalTimelock, 2864 ).Wait(30 * time.Second) 2865 if err != nil { 2866 t.Fatalf("unable to send payment: %v", err) 2867 } 2868 } 2869 2870 bobLog, ok := n.bobServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog) 2871 if !ok { 2872 t.Fatalf("mockForwardingLog assertion failed") 2873 } 2874 2875 // After sending 5 of the payments, trigger the forwarding ticker, to 2876 // make sure the events are properly flushed. 2877 bobTicker, ok := n.bobServer.htlcSwitch.cfg.FwdEventTicker.(*ticker.Force) 2878 if !ok { 2879 t.Fatalf("mockTicker assertion failed") 2880 } 2881 2882 // We'll trigger the ticker, and wait for the events to appear in Bob's 2883 // forwarding log. 2884 timeout := time.After(15 * time.Second) 2885 for { 2886 select { 2887 case bobTicker.Force <- time.Now(): 2888 case <-time.After(1 * time.Second): 2889 t.Fatalf("unable to force tick") 2890 } 2891 2892 // If all 5 events is found in Bob's log, we can break out and 2893 // continue the test. 
2894 bobLog.Lock() 2895 if len(bobLog.events) == 5 { 2896 bobLog.Unlock() 2897 break 2898 } 2899 bobLog.Unlock() 2900 2901 // Otherwise wait a little bit before checking again. 2902 select { 2903 case <-time.After(50 * time.Millisecond): 2904 case <-timeout: 2905 bobLog.Lock() 2906 defer bobLog.Unlock() 2907 t.Fatalf("expected 5 events in event log, instead "+ 2908 "found: %v", spew.Sdump(bobLog.events)) 2909 } 2910 } 2911 2912 // Send the remaining payments. 2913 for i := numPayments / 2; i < numPayments; i++ { 2914 _, err := makePayment( 2915 n.aliceServer, n.carolServer, firstHop, hops, finalAmt, 2916 htlcAmt, totalTimelock, 2917 ).Wait(30 * time.Second) 2918 if err != nil { 2919 t.Fatalf("unable to send payment: %v", err) 2920 } 2921 } 2922 2923 // With all 10 payments sent. We'll now manually stop each of the 2924 // switches so we can examine their end state. 2925 n.stop() 2926 2927 // Alice and Carol shouldn't have any recorded forwarding events, as 2928 // they were the source and the sink for these payment flows. 2929 aliceLog, ok := n.aliceServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog) 2930 if !ok { 2931 t.Fatalf("mockForwardingLog assertion failed") 2932 } 2933 aliceLog.Lock() 2934 defer aliceLog.Unlock() 2935 if len(aliceLog.events) != 0 { 2936 t.Fatalf("log should have no events, instead has: %v", 2937 spew.Sdump(aliceLog.events)) 2938 } 2939 2940 carolLog, ok := n.carolServer.htlcSwitch.cfg.FwdingLog.(*mockForwardingLog) 2941 if !ok { 2942 t.Fatalf("mockForwardingLog assertion failed") 2943 } 2944 carolLog.Lock() 2945 defer carolLog.Unlock() 2946 if len(carolLog.events) != 0 { 2947 t.Fatalf("log should have no events, instead has: %v", 2948 spew.Sdump(carolLog.events)) 2949 } 2950 2951 // Bob on the other hand, should have 10 events. 
2952 bobLog.Lock() 2953 defer bobLog.Unlock() 2954 if len(bobLog.events) != 10 { 2955 t.Fatalf("log should have 10 events, instead has: %v", 2956 spew.Sdump(bobLog.events)) 2957 } 2958 2959 // Each of the 10 events should have had all fields set properly. 2960 for _, event := range bobLog.events { 2961 // The incoming and outgoing channels should properly be set for 2962 // the event. 2963 if event.IncomingChanID != n.aliceChannelLink.ShortChanID() { 2964 t.Fatalf("chan id mismatch: expected %v, got %v", 2965 event.IncomingChanID, 2966 n.aliceChannelLink.ShortChanID()) 2967 } 2968 if event.OutgoingChanID != n.carolChannelLink.ShortChanID() { 2969 t.Fatalf("chan id mismatch: expected %v, got %v", 2970 event.OutgoingChanID, 2971 n.carolChannelLink.ShortChanID()) 2972 } 2973 2974 // Additionally, the incoming and outgoing amounts should also 2975 // be properly set. 2976 if event.AmtIn != htlcAmt { 2977 t.Fatalf("incoming amt mismatch: expected %v, got %v", 2978 event.AmtIn, htlcAmt) 2979 } 2980 if event.AmtOut != finalAmt { 2981 t.Fatalf("outgoing amt mismatch: expected %v, got %v", 2982 event.AmtOut, finalAmt) 2983 } 2984 } 2985 } 2986 2987 // TestUpdateFailMalformedHTLCErrorConversion tests that we're able to properly 2988 // convert malformed HTLC errors that originate at the direct link, as well as 2989 // during multi-hop HTLC forwarding. 2990 func TestUpdateFailMalformedHTLCErrorConversion(t *testing.T) { 2991 t.Parallel() 2992 2993 // First, we'll create our traditional three hop network. 
2994 channels, _, err := createClusterChannels( 2995 t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5, 2996 ) 2997 require.NoError(t, err, "unable to create channel") 2998 2999 n := newThreeHopNetwork( 3000 t, channels.aliceToBob, channels.bobToAlice, 3001 channels.bobToCarol, channels.carolToBob, testStartingHeight, 3002 ) 3003 if err := n.start(); err != nil { 3004 t.Fatalf("unable to start three hop network: %v", err) 3005 } 3006 3007 assertPaymentFailure := func(t *testing.T) { 3008 // With the decoder modified, we'll now attempt to send a 3009 // payment from Alice to carol. 3010 finalAmt := lnwire.NewMSatFromSatoshis(100000) 3011 htlcAmt, totalTimelock, hops := generateHops( 3012 finalAmt, testStartingHeight, n.firstBobChannelLink, 3013 n.carolChannelLink, 3014 ) 3015 firstHop := n.firstBobChannelLink.ShortChanID() 3016 _, err = makePayment( 3017 n.aliceServer, n.carolServer, firstHop, hops, finalAmt, 3018 htlcAmt, totalTimelock, 3019 ).Wait(30 * time.Second) 3020 3021 // The payment should fail as Carol is unable to decode the 3022 // onion blob sent to her. 3023 if err == nil { 3024 t.Fatalf("unable to send payment: %v", err) 3025 } 3026 3027 routingErr := err.(ClearTextError) 3028 failureMsg := routingErr.WireMessage() 3029 if _, ok := failureMsg.(*lnwire.FailInvalidOnionKey); !ok { 3030 t.Fatalf("expected onion failure instead got: %v", 3031 routingErr.WireMessage()) 3032 } 3033 } 3034 3035 t.Run("multi-hop error conversion", func(t *testing.T) { 3036 // Now that we have our network up, we'll modify the hop 3037 // iterator for the Bob <-> Carol channel to fail to decode in 3038 // order to simulate either a replay attack or an issue 3039 // decoding the onion. 3040 n.carolOnionDecoder.decodeFail = true 3041 3042 assertPaymentFailure(t) 3043 }) 3044 3045 t.Run("direct channel error conversion", func(t *testing.T) { 3046 // Similar to the above test case, we'll now make the Alice <-> 3047 // Bob link always fail to decode an onion. 
This differs from 3048 // the above test case in that there's no encryption on the 3049 // error at all since Alice will directly receive a 3050 // UpdateFailMalformedHTLC message. 3051 n.bobOnionDecoder.decodeFail = true 3052 3053 assertPaymentFailure(t) 3054 }) 3055 } 3056 3057 // TestSwitchGetAttemptResult tests that the switch interacts as expected with 3058 // the circuit map and network result store when looking up the result of a 3059 // payment ID. This is important for not to lose results under concurrent 3060 // lookup and receiving results. 3061 func TestSwitchGetAttemptResult(t *testing.T) { 3062 t.Parallel() 3063 3064 const paymentID = 123 3065 var preimg lntypes.Preimage 3066 preimg[0] = 3 3067 3068 s, err := initSwitchWithTempDB(t, testStartingHeight) 3069 require.NoError(t, err, "unable to init switch") 3070 if err := s.Start(); err != nil { 3071 t.Fatalf("unable to start switch: %v", err) 3072 } 3073 defer s.Stop() 3074 3075 lookup := make(chan *PaymentCircuit, 1) 3076 s.circuits = &mockCircuitMap{ 3077 lookup: lookup, 3078 } 3079 3080 // If the payment circuit is not found in the circuit map, the payment 3081 // result must be found in the store if available. Since we haven't 3082 // added anything to the store yet, ErrPaymentIDNotFound should be 3083 // returned. 3084 lookup <- nil 3085 _, err = s.GetAttemptResult( 3086 paymentID, lntypes.Hash{}, newMockDeobfuscator(), 3087 ) 3088 if err != ErrPaymentIDNotFound { 3089 t.Fatalf("expected ErrPaymentIDNotFound, got %v", err) 3090 } 3091 3092 // Next let the lookup find the circuit in the circuit map. It should 3093 // subscribe to payment results, and return the result when available. 3094 lookup <- &PaymentCircuit{} 3095 resultChan, err := s.GetAttemptResult( 3096 paymentID, lntypes.Hash{}, newMockDeobfuscator(), 3097 ) 3098 require.NoError(t, err, "unable to get payment result") 3099 3100 // Add the result to the store. 
3101 n := &networkResult{ 3102 msg: &lnwire.UpdateFulfillHTLC{ 3103 PaymentPreimage: preimg, 3104 }, 3105 unencrypted: true, 3106 isResolution: true, 3107 } 3108 3109 err = s.networkResults.storeResult(paymentID, n) 3110 require.NoError(t, err, "unable to store result") 3111 3112 // The result should be available. 3113 select { 3114 case res, ok := <-resultChan: 3115 if !ok { 3116 t.Fatalf("channel was closed") 3117 } 3118 3119 if res.Error != nil { 3120 t.Fatalf("got unexpected error result") 3121 } 3122 3123 if res.Preimage != preimg { 3124 t.Fatalf("expected preimg %v, got %v", 3125 preimg, res.Preimage) 3126 } 3127 3128 case <-time.After(1 * time.Second): 3129 t.Fatalf("result not received") 3130 } 3131 3132 // As a final test, try to get the result again. Now that is no longer 3133 // in the circuit map, it should be immediately available from the 3134 // store. 3135 lookup <- nil 3136 resultChan, err = s.GetAttemptResult( 3137 paymentID, lntypes.Hash{}, newMockDeobfuscator(), 3138 ) 3139 require.NoError(t, err, "unable to get payment result") 3140 3141 select { 3142 case res, ok := <-resultChan: 3143 if !ok { 3144 t.Fatalf("channel was closed") 3145 } 3146 3147 if res.Error != nil { 3148 t.Fatalf("got unexpected error result") 3149 } 3150 3151 if res.Preimage != preimg { 3152 t.Fatalf("expected preimg %v, got %v", 3153 preimg, res.Preimage) 3154 } 3155 3156 case <-time.After(1 * time.Second): 3157 t.Fatalf("result not received") 3158 } 3159 } 3160 3161 // TestInvalidFailure tests that the switch returns an unreadable failure error 3162 // if the failure cannot be decrypted. 
3163 func TestInvalidFailure(t *testing.T) { 3164 t.Parallel() 3165 3166 alicePeer, err := newMockServer( 3167 t, "alice", testStartingHeight, nil, testDefaultDelta, 3168 ) 3169 require.NoError(t, err, "unable to create alice server") 3170 3171 s, err := initSwitchWithTempDB(t, testStartingHeight) 3172 require.NoError(t, err, "unable to init switch") 3173 if err := s.Start(); err != nil { 3174 t.Fatalf("unable to start switch: %v", err) 3175 } 3176 defer s.Stop() 3177 3178 chanID1, _, aliceChanID, _ := genIDs() 3179 3180 // Set up a mock channel link. 3181 aliceChannelLink := newMockChannelLink( 3182 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 3183 false, false, 3184 ) 3185 if err := s.AddLink(aliceChannelLink); err != nil { 3186 t.Fatalf("unable to add link: %v", err) 3187 } 3188 3189 // Create a request which should be forwarded to the mock channel link. 3190 preimage, err := genPreimage() 3191 require.NoError(t, err, "unable to generate preimage") 3192 rhash := sha256.Sum256(preimage[:]) 3193 update := &lnwire.UpdateAddHTLC{ 3194 PaymentHash: rhash, 3195 Amount: 1, 3196 } 3197 3198 paymentID := uint64(123) 3199 3200 // Send the request. 3201 err = s.SendHTLC( 3202 aliceChannelLink.ShortChanID(), paymentID, update, 3203 ) 3204 require.NoError(t, err, "unable to send payment") 3205 3206 // Catch the packet and complete the circuit so that the switch is ready 3207 // for a response. 3208 select { 3209 case packet := <-aliceChannelLink.packets: 3210 if err := aliceChannelLink.completeCircuit(packet); err != nil { 3211 t.Fatalf("unable to complete payment circuit: %v", err) 3212 } 3213 3214 case <-time.After(time.Second): 3215 t.Fatal("request was not propagated to destination") 3216 } 3217 3218 // Send response packet with an unreadable failure message to the 3219 // switch. The reason failed is not relevant, because we mock the 3220 // decryption. 
3221 packet := &htlcPacket{ 3222 outgoingChanID: aliceChannelLink.ShortChanID(), 3223 outgoingHTLCID: 0, 3224 amount: 1, 3225 htlc: &lnwire.UpdateFailHTLC{ 3226 Reason: []byte{1, 2, 3}, 3227 }, 3228 } 3229 3230 if err := s.ForwardPackets(nil, packet); err != nil { 3231 t.Fatalf("can't forward htlc packet: %v", err) 3232 } 3233 3234 // Get payment result from switch. We expect an unreadable failure 3235 // message error. 3236 deobfuscator := SphinxErrorDecrypter{ 3237 OnionErrorDecrypter: &mockOnionErrorDecryptor{ 3238 err: ErrUnreadableFailureMessage, 3239 }, 3240 } 3241 3242 resultChan, err := s.GetAttemptResult( 3243 paymentID, rhash, &deobfuscator, 3244 ) 3245 if err != nil { 3246 t.Fatal(err) 3247 } 3248 3249 select { 3250 case result := <-resultChan: 3251 if result.Error != ErrUnreadableFailureMessage { 3252 t.Fatal("expected unreadable failure message") 3253 } 3254 3255 case <-time.After(time.Second): 3256 t.Fatal("err wasn't received") 3257 } 3258 3259 // Modify the decryption to simulate that decryption went alright, but 3260 // the failure cannot be decoded. 
3261 deobfuscator = SphinxErrorDecrypter{ 3262 OnionErrorDecrypter: &mockOnionErrorDecryptor{ 3263 sourceIdx: 2, 3264 message: []byte{200}, 3265 }, 3266 } 3267 3268 resultChan, err = s.GetAttemptResult( 3269 paymentID, rhash, &deobfuscator, 3270 ) 3271 if err != nil { 3272 t.Fatal(err) 3273 } 3274 3275 select { 3276 case result := <-resultChan: 3277 rtErr, ok := result.Error.(ClearTextError) 3278 if !ok { 3279 t.Fatal("expected ClearTextError") 3280 } 3281 source, ok := rtErr.(*ForwardingError) 3282 if !ok { 3283 t.Fatalf("expected forwarding error, got: %T", rtErr) 3284 } 3285 if source.FailureSourceIdx != 2 { 3286 t.Fatal("unexpected error source index") 3287 } 3288 if rtErr.WireMessage() != nil { 3289 t.Fatal("expected empty failure message") 3290 } 3291 3292 case <-time.After(time.Second): 3293 t.Fatal("err wasn't received") 3294 } 3295 } 3296 3297 // htlcNotifierEvents is a function that generates a set of expected htlc 3298 // notifier evetns for each node in a three hop network with the dynamic 3299 // values provided. These functions take dynamic values so that changes to 3300 // external systems (such as our default timelock delta) do not break 3301 // these tests. 3302 type htlcNotifierEvents func(channels *clusterChannels, htlcID uint64, 3303 ts time.Time, htlc *lnwire.UpdateAddHTLC, 3304 hops []*hop.Payload, 3305 preimage *lntypes.Preimage) ([]interface{}, []interface{}, []interface{}) 3306 3307 // TestHtlcNotifier tests the notifying of htlc events that are routed over a 3308 // three hop network. It sets up an Alice -> Bob -> Carol network and routes 3309 // payments from Alice -> Carol to test events from the perspective of a 3310 // sending (Alice), forwarding (Bob) and receiving (Carol) node. Test cases 3311 // are present for saduccessful and failed payments. 3312 func TestHtlcNotifier(t *testing.T) { 3313 tests := []struct { 3314 name string 3315 3316 // Options is a set of options to apply to the three hop 3317 // network's servers. 
3318 options []serverOption 3319 3320 // expectedEvents is a function which returns an expected set 3321 // of events for the test. 3322 expectedEvents htlcNotifierEvents 3323 3324 // iterations is the number of times we will send a payment, 3325 // this is used to send more than one payment to force non- 3326 // zero htlc indexes to make sure we aren't just checking 3327 // default values. 3328 iterations int 3329 }{ 3330 { 3331 name: "successful three hop payment", 3332 options: nil, 3333 expectedEvents: func(channels *clusterChannels, 3334 htlcID uint64, ts time.Time, 3335 htlc *lnwire.UpdateAddHTLC, 3336 hops []*hop.Payload, 3337 preimage *lntypes.Preimage) ([]interface{}, 3338 []interface{}, []interface{}) { 3339 3340 return getThreeHopEvents( 3341 channels, htlcID, ts, htlc, hops, nil, preimage, 3342 ) 3343 }, 3344 iterations: 2, 3345 }, 3346 { 3347 name: "failed at forwarding link", 3348 // Set a functional option which disables bob as a 3349 // forwarding node to force a payment error. 3350 options: []serverOption{ 3351 serverOptionRejectHtlc(false, true, false), 3352 }, 3353 expectedEvents: func(channels *clusterChannels, 3354 htlcID uint64, ts time.Time, 3355 htlc *lnwire.UpdateAddHTLC, 3356 hops []*hop.Payload, 3357 preimage *lntypes.Preimage) ([]interface{}, 3358 []interface{}, []interface{}) { 3359 3360 return getThreeHopEvents( 3361 channels, htlcID, ts, htlc, hops, 3362 &LinkError{ 3363 msg: &lnwire.FailChannelDisabled{}, 3364 FailureDetail: OutgoingFailureForwardsDisabled, 3365 }, 3366 preimage, 3367 ) 3368 }, 3369 iterations: 1, 3370 }, 3371 } 3372 3373 for _, test := range tests { 3374 test := test 3375 3376 t.Run(test.name, func(t *testing.T) { 3377 testHtcNotifier( 3378 t, test.options, test.iterations, 3379 test.expectedEvents, 3380 ) 3381 }) 3382 } 3383 } 3384 3385 // testHtcNotifier runs a htlc notifier test. 
3386 func testHtcNotifier(t *testing.T, testOpts []serverOption, iterations int, 3387 getEvents htlcNotifierEvents) { 3388 3389 t.Parallel() 3390 3391 // First, we'll create our traditional three hop 3392 // network. 3393 channels, _, err := createClusterChannels( 3394 t, btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5, 3395 ) 3396 require.NoError(t, err, "unable to create channel") 3397 3398 // Mock time so that all events are reported with a static timestamp. 3399 now := time.Now() 3400 mockTime := func() time.Time { 3401 return now 3402 } 3403 3404 // Create htlc notifiers for each server in the three hop network and 3405 // start them. 3406 aliceNotifier := NewHtlcNotifier(mockTime) 3407 if err := aliceNotifier.Start(); err != nil { 3408 t.Fatalf("could not start alice notifier") 3409 } 3410 t.Cleanup(func() { 3411 if err := aliceNotifier.Stop(); err != nil { 3412 t.Fatalf("failed to stop alice notifier: %v", err) 3413 } 3414 }) 3415 3416 bobNotifier := NewHtlcNotifier(mockTime) 3417 if err := bobNotifier.Start(); err != nil { 3418 t.Fatalf("could not start bob notifier") 3419 } 3420 t.Cleanup(func() { 3421 if err := bobNotifier.Stop(); err != nil { 3422 t.Fatalf("failed to stop bob notifier: %v", err) 3423 } 3424 }) 3425 3426 carolNotifier := NewHtlcNotifier(mockTime) 3427 if err := carolNotifier.Start(); err != nil { 3428 t.Fatalf("could not start carol notifier") 3429 } 3430 t.Cleanup(func() { 3431 if err := carolNotifier.Stop(); err != nil { 3432 t.Fatalf("failed to stop carol notifier: %v", err) 3433 } 3434 }) 3435 3436 // Create a notifier server option which will set our htlc notifiers 3437 // for the three hop network. 3438 notifierOption := serverOptionWithHtlcNotifier( 3439 aliceNotifier, bobNotifier, carolNotifier, 3440 ) 3441 3442 // Add the htlcNotifier option to any other options 3443 // set in the test. 
3444 options := append(testOpts, notifierOption) // nolint:gocritic 3445 3446 n := newThreeHopNetwork( 3447 t, channels.aliceToBob, 3448 channels.bobToAlice, channels.bobToCarol, 3449 channels.carolToBob, testStartingHeight, 3450 options..., 3451 ) 3452 if err := n.start(); err != nil { 3453 t.Fatalf("unable to start three hop "+ 3454 "network: %v", err) 3455 } 3456 t.Cleanup(n.stop) 3457 3458 // Before we forward anything, subscribe to htlc events 3459 // from each notifier. 3460 aliceEvents, err := aliceNotifier.SubscribeHtlcEvents() 3461 if err != nil { 3462 t.Fatalf("could not subscribe to alice's"+ 3463 " events: %v", err) 3464 } 3465 t.Cleanup(aliceEvents.Cancel) 3466 3467 bobEvents, err := bobNotifier.SubscribeHtlcEvents() 3468 if err != nil { 3469 t.Fatalf("could not subscribe to bob's"+ 3470 " events: %v", err) 3471 } 3472 t.Cleanup(bobEvents.Cancel) 3473 3474 carolEvents, err := carolNotifier.SubscribeHtlcEvents() 3475 if err != nil { 3476 t.Fatalf("could not subscribe to carol's"+ 3477 " events: %v", err) 3478 } 3479 t.Cleanup(carolEvents.Cancel) 3480 3481 // Send multiple payments, as specified by the test to test incrementing 3482 // of htlc ids. 3483 for i := 0; i < iterations; i++ { 3484 // We'll start off by making a payment from 3485 // Alice -> Bob -> Carol. The preimage, generated 3486 // by Carol's Invoice is expected in the Settle events 3487 htlc, hops, preimage := n.sendThreeHopPayment(t) 3488 3489 alice, bob, carol := getEvents( 3490 channels, uint64(i), now, htlc, hops, preimage, 3491 ) 3492 3493 checkHtlcEvents(t, aliceEvents.Updates(), alice) 3494 checkHtlcEvents(t, bobEvents.Updates(), bob) 3495 checkHtlcEvents(t, carolEvents.Updates(), carol) 3496 } 3497 } 3498 3499 // checkHtlcEvents checks that a subscription has the set of htlc events 3500 // we expect it to have. 
3501 func checkHtlcEvents(t *testing.T, events <-chan interface{}, 3502 expectedEvents []interface{}) { 3503 3504 t.Helper() 3505 3506 for _, expected := range expectedEvents { 3507 select { 3508 case event := <-events: 3509 if !reflect.DeepEqual(event, expected) { 3510 t.Fatalf("expected %v, got: %v", expected, 3511 event) 3512 } 3513 3514 case <-time.After(5 * time.Second): 3515 t.Fatalf("expected event: %v", expected) 3516 } 3517 } 3518 3519 // Check that there are no unexpected events following. 3520 select { 3521 case event := <-events: 3522 t.Fatalf("unexpected event: %v", event) 3523 default: 3524 } 3525 } 3526 3527 // sendThreeHopPayment is a helper function which sends a payment over 3528 // Alice -> Bob -> Carol in a three hop network and returns Alice's first htlc 3529 // and the remainder of the hops. 3530 func (n *threeHopNetwork) sendThreeHopPayment(t *testing.T) (*lnwire.UpdateAddHTLC, 3531 []*hop.Payload, *lntypes.Preimage) { 3532 3533 amount := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) 3534 3535 htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight, 3536 n.firstBobChannelLink, n.carolChannelLink) 3537 blob, err := generateRoute(hops...) 3538 if err != nil { 3539 t.Fatal(err) 3540 } 3541 invoice, htlc, pid, err := generatePayment( 3542 amount, htlcAmt, totalTimelock, blob, 3543 ) 3544 if err != nil { 3545 t.Fatal(err) 3546 } 3547 3548 err = n.carolServer.registry.AddInvoice( 3549 t.Context(), *invoice, htlc.PaymentHash, 3550 ) 3551 require.NoError(t, err, "unable to add invoice in carol registry") 3552 3553 if err := n.aliceServer.htlcSwitch.SendHTLC( 3554 n.firstBobChannelLink.ShortChanID(), pid, htlc, 3555 ); err != nil { 3556 t.Fatalf("could not send htlc") 3557 } 3558 3559 return htlc, hops, invoice.Terms.PaymentPreimage 3560 } 3561 3562 // getThreeHopEvents gets the set of htlc events that we expect for a payment 3563 // from Alice -> Bob -> Carol. 
If a non-nil link error is provided, the set 3564 // of events will fail on Bob's outgoing link. 3565 func getThreeHopEvents(channels *clusterChannels, htlcID uint64, 3566 ts time.Time, htlc *lnwire.UpdateAddHTLC, hops []*hop.Payload, 3567 linkError *LinkError, 3568 preimage *lntypes.Preimage) ([]interface{}, []interface{}, []interface{}) { 3569 3570 aliceKey := HtlcKey{ 3571 IncomingCircuit: zeroCircuit, 3572 OutgoingCircuit: models.CircuitKey{ 3573 ChanID: channels.aliceToBob.ShortChanID(), 3574 HtlcID: htlcID, 3575 }, 3576 } 3577 3578 // Alice always needs a forwarding event because she initiates the 3579 // send. 3580 aliceEvents := []interface{}{ 3581 &ForwardingEvent{ 3582 HtlcKey: aliceKey, 3583 HtlcInfo: HtlcInfo{ 3584 OutgoingTimeLock: htlc.Expiry, 3585 OutgoingAmt: htlc.Amount, 3586 }, 3587 HtlcEventType: HtlcEventTypeSend, 3588 Timestamp: ts, 3589 }, 3590 } 3591 3592 bobKey := HtlcKey{ 3593 IncomingCircuit: models.CircuitKey{ 3594 ChanID: channels.bobToAlice.ShortChanID(), 3595 HtlcID: htlcID, 3596 }, 3597 OutgoingCircuit: models.CircuitKey{ 3598 ChanID: channels.bobToCarol.ShortChanID(), 3599 HtlcID: htlcID, 3600 }, 3601 } 3602 3603 bobInfo := HtlcInfo{ 3604 IncomingTimeLock: htlc.Expiry, 3605 IncomingAmt: htlc.Amount, 3606 OutgoingTimeLock: hops[1].FwdInfo.OutgoingCTLV, 3607 OutgoingAmt: hops[1].FwdInfo.AmountToForward, 3608 } 3609 3610 // If we expect the payment to fail, we add failures for alice and 3611 // bob, and no events for carol because the payment never reaches her. 
3612 if linkError != nil { 3613 aliceEvents = append(aliceEvents, 3614 &ForwardingFailEvent{ 3615 HtlcKey: aliceKey, 3616 HtlcEventType: HtlcEventTypeSend, 3617 Timestamp: ts, 3618 }, 3619 ) 3620 3621 bobEvents := []interface{}{ 3622 &LinkFailEvent{ 3623 HtlcKey: bobKey, 3624 HtlcInfo: bobInfo, 3625 HtlcEventType: HtlcEventTypeForward, 3626 LinkError: linkError, 3627 Incoming: false, 3628 Timestamp: ts, 3629 }, 3630 &FinalHtlcEvent{ 3631 CircuitKey: bobKey.IncomingCircuit, 3632 Settled: false, 3633 Offchain: true, 3634 Timestamp: ts, 3635 }, 3636 } 3637 3638 return aliceEvents, bobEvents, nil 3639 } 3640 3641 // If we want to get events for a successful payment, we add a settle 3642 // for alice, a forward and settle for bob and a receive settle for 3643 // carol. 3644 aliceEvents = append( 3645 aliceEvents, 3646 &SettleEvent{ 3647 HtlcKey: aliceKey, 3648 Preimage: *preimage, 3649 HtlcEventType: HtlcEventTypeSend, 3650 Timestamp: ts, 3651 }, 3652 ) 3653 3654 bobEvents := []interface{}{ 3655 &ForwardingEvent{ 3656 HtlcKey: bobKey, 3657 HtlcInfo: bobInfo, 3658 HtlcEventType: HtlcEventTypeForward, 3659 Timestamp: ts, 3660 }, 3661 &SettleEvent{ 3662 HtlcKey: bobKey, 3663 Preimage: *preimage, 3664 HtlcEventType: HtlcEventTypeForward, 3665 Timestamp: ts, 3666 }, 3667 &FinalHtlcEvent{ 3668 CircuitKey: bobKey.IncomingCircuit, 3669 Settled: true, 3670 Offchain: true, 3671 Timestamp: ts, 3672 }, 3673 } 3674 3675 carolEvents := []interface{}{ 3676 &SettleEvent{ 3677 HtlcKey: HtlcKey{ 3678 IncomingCircuit: models.CircuitKey{ 3679 ChanID: channels.carolToBob.ShortChanID(), 3680 HtlcID: htlcID, 3681 }, 3682 OutgoingCircuit: zeroCircuit, 3683 }, 3684 Preimage: *preimage, 3685 HtlcEventType: HtlcEventTypeReceive, 3686 Timestamp: ts, 3687 }, &FinalHtlcEvent{ 3688 CircuitKey: models.CircuitKey{ 3689 ChanID: channels.carolToBob.ShortChanID(), 3690 HtlcID: htlcID, 3691 }, 3692 Settled: true, 3693 Offchain: true, 3694 Timestamp: ts, 3695 }, 3696 } 3697 3698 return aliceEvents, 
bobEvents, carolEvents 3699 } 3700 3701 type mockForwardInterceptor struct { 3702 t *testing.T 3703 3704 interceptedChan chan InterceptedPacket 3705 } 3706 3707 func (m *mockForwardInterceptor) InterceptForwardHtlc( 3708 intercepted InterceptedPacket) error { 3709 3710 m.interceptedChan <- intercepted 3711 3712 return nil 3713 } 3714 3715 func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket { 3716 m.t.Helper() 3717 3718 select { 3719 case p := <-m.interceptedChan: 3720 return p 3721 3722 case <-time.After(time.Second): 3723 require.Fail(m.t, "timeout") 3724 3725 return InterceptedPacket{} 3726 } 3727 } 3728 3729 func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) { 3730 if s.circuits.NumPending() != pending { 3731 t.Fatalf("wrong amount of half circuits, expected %v but "+ 3732 "got %v", pending, s.circuits.NumPending()) 3733 } 3734 if s.circuits.NumOpen() != opened { 3735 t.Fatalf("wrong amount of circuits, expected %v but got %v", 3736 opened, s.circuits.NumOpen()) 3737 } 3738 } 3739 3740 func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink, 3741 expectReceive bool) *htlcPacket { 3742 3743 // Pull packet from targetLink link. 
3744 select { 3745 case packet := <-targetLink.packets: 3746 if !expectReceive { 3747 t.Fatal("forward was intercepted, shouldn't land at bob link") 3748 } else if err := targetLink.completeCircuit(packet); err != nil { 3749 t.Fatalf("unable to complete payment circuit: %v", err) 3750 } 3751 3752 return packet 3753 3754 case <-time.After(time.Second): 3755 if expectReceive { 3756 t.Fatal("request was not propagated to destination") 3757 } 3758 } 3759 3760 return nil 3761 } 3762 3763 func assertOutgoingLinkReceiveIntercepted(t *testing.T, 3764 targetLink *mockChannelLink) { 3765 3766 t.Helper() 3767 3768 select { 3769 case <-targetLink.packets: 3770 case <-time.After(time.Second): 3771 t.Fatal("request was not propagated to destination") 3772 } 3773 } 3774 3775 type interceptableSwitchTestContext struct { 3776 t *testing.T 3777 3778 preimage [sha256.Size]byte 3779 rhash [32]byte 3780 onionBlob [1366]byte 3781 incomingHtlcID uint64 3782 cltvRejectDelta uint32 3783 cltvInterceptDelta uint32 3784 3785 forwardInterceptor *mockForwardInterceptor 3786 aliceChannelLink *mockChannelLink 3787 bobChannelLink *mockChannelLink 3788 s *Switch 3789 } 3790 3791 func newInterceptableSwitchTestContext( 3792 t *testing.T) *interceptableSwitchTestContext { 3793 3794 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 3795 3796 alicePeer, err := newMockServer( 3797 t, "alice", testStartingHeight, nil, testDefaultDelta, 3798 ) 3799 require.NoError(t, err, "unable to create alice server") 3800 bobPeer, err := newMockServer( 3801 t, "bob", testStartingHeight, nil, testDefaultDelta, 3802 ) 3803 require.NoError(t, err, "unable to create bob server") 3804 3805 tempPath := t.TempDir() 3806 3807 cdb := channeldb.OpenForTesting(t, tempPath) 3808 3809 s, err := initSwitchWithDB(testStartingHeight, cdb) 3810 require.NoError(t, err, "unable to init switch") 3811 if err := s.Start(); err != nil { 3812 t.Fatalf("unable to start switch: %v", err) 3813 } 3814 3815 aliceChannelLink := 
newMockChannelLink( 3816 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 3817 false, false, 3818 ) 3819 bobChannelLink := newMockChannelLink( 3820 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 3821 false, 3822 ) 3823 if err := s.AddLink(aliceChannelLink); err != nil { 3824 t.Fatalf("unable to add alice link: %v", err) 3825 } 3826 if err := s.AddLink(bobChannelLink); err != nil { 3827 t.Fatalf("unable to add bob link: %v", err) 3828 } 3829 3830 preimage := [sha256.Size]byte{1} 3831 3832 ctx := &interceptableSwitchTestContext{ 3833 t: t, 3834 preimage: preimage, 3835 rhash: sha256.Sum256(preimage[:]), 3836 onionBlob: [1366]byte{4, 5, 6}, 3837 incomingHtlcID: uint64(0), 3838 cltvRejectDelta: 10, 3839 cltvInterceptDelta: 13, 3840 forwardInterceptor: &mockForwardInterceptor{ 3841 t: t, 3842 interceptedChan: make(chan InterceptedPacket), 3843 }, 3844 aliceChannelLink: aliceChannelLink, 3845 bobChannelLink: bobChannelLink, 3846 s: s, 3847 } 3848 3849 return ctx 3850 } 3851 3852 func (c *interceptableSwitchTestContext) createTestPacket() *htlcPacket { 3853 c.incomingHtlcID++ 3854 3855 return &htlcPacket{ 3856 incomingChanID: c.aliceChannelLink.ShortChanID(), 3857 incomingHTLCID: c.incomingHtlcID, 3858 incomingTimeout: testStartingHeight + c.cltvInterceptDelta + 1, 3859 outgoingChanID: c.bobChannelLink.ShortChanID(), 3860 obfuscator: NewMockObfuscator(), 3861 htlc: &lnwire.UpdateAddHTLC{ 3862 PaymentHash: c.rhash, 3863 Amount: 1, 3864 OnionBlob: c.onionBlob, 3865 }, 3866 } 3867 } 3868 3869 func (c *interceptableSwitchTestContext) finish() { 3870 if err := c.s.Stop(); err != nil { 3871 c.t.Fatal(err) 3872 } 3873 } 3874 3875 func (c *interceptableSwitchTestContext) createSettlePacket( 3876 outgoingHTLCID uint64) *htlcPacket { 3877 3878 return &htlcPacket{ 3879 outgoingChanID: c.bobChannelLink.ShortChanID(), 3880 outgoingHTLCID: outgoingHTLCID, 3881 amount: 1, 3882 htlc: &lnwire.UpdateFulfillHTLC{ 3883 PaymentPreimage: c.preimage, 3884 }, 3885 } 
3886 } 3887 3888 func TestSwitchHoldForward(t *testing.T) { 3889 t.Parallel() 3890 3891 c := newInterceptableSwitchTestContext(t) 3892 defer c.finish() 3893 3894 notifier := &mock.ChainNotifier{ 3895 EpochChan: make(chan *chainntnfs.BlockEpoch, 1), 3896 } 3897 notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight} 3898 3899 switchForwardInterceptor, err := NewInterceptableSwitch( 3900 &InterceptableSwitchConfig{ 3901 Switch: c.s, 3902 CltvRejectDelta: c.cltvRejectDelta, 3903 CltvInterceptDelta: c.cltvInterceptDelta, 3904 Notifier: notifier, 3905 }, 3906 ) 3907 require.NoError(t, err) 3908 require.NoError(t, switchForwardInterceptor.Start()) 3909 3910 switchForwardInterceptor.SetInterceptor(c.forwardInterceptor.InterceptForwardHtlc) 3911 linkQuit := make(chan struct{}) 3912 3913 // Test a forward that expires too soon. 3914 packet := c.createTestPacket() 3915 packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1 3916 3917 err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet) 3918 require.NoError(t, err, "can't forward htlc packet") 3919 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 3920 assertOutgoingLinkReceiveIntercepted(t, c.aliceChannelLink) 3921 assertNumCircuits(t, c.s, 0, 0) 3922 3923 // Test a forward that expires too soon and can't be failed. 3924 packet = c.createTestPacket() 3925 packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1 3926 3927 // Simulate an error during the composition of the failure message. 
3928 currentCallback := c.s.cfg.FetchLastChannelUpdate 3929 c.s.cfg.FetchLastChannelUpdate = func( 3930 lnwire.ShortChannelID) (*lnwire.ChannelUpdate1, error) { 3931 3932 return nil, errors.New("cannot fetch update") 3933 } 3934 3935 err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet) 3936 require.NoError(t, err, "can't forward htlc packet") 3937 receivedPkt := assertOutgoingLinkReceive(t, c.bobChannelLink, true) 3938 assertNumCircuits(t, c.s, 1, 1) 3939 3940 require.NoError(t, switchForwardInterceptor.ForwardPackets( 3941 linkQuit, false, 3942 c.createSettlePacket(receivedPkt.outgoingHTLCID), 3943 )) 3944 3945 assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 3946 assertNumCircuits(t, c.s, 0, 0) 3947 3948 c.s.cfg.FetchLastChannelUpdate = currentCallback 3949 3950 // Test resume a hold forward. 3951 assertNumCircuits(t, c.s, 0, 0) 3952 err = switchForwardInterceptor.ForwardPackets( 3953 linkQuit, false, c.createTestPacket(), 3954 ) 3955 require.NoError(t, err) 3956 3957 assertNumCircuits(t, c.s, 0, 0) 3958 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 3959 3960 require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ 3961 Action: FwdActionResume, 3962 Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, 3963 })) 3964 receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true) 3965 assertNumCircuits(t, c.s, 1, 1) 3966 3967 // settling the htlc to close the circuit. 3968 err = switchForwardInterceptor.ForwardPackets( 3969 linkQuit, false, 3970 c.createSettlePacket(receivedPkt.outgoingHTLCID), 3971 ) 3972 require.NoError(t, err) 3973 3974 assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 3975 assertNumCircuits(t, c.s, 0, 0) 3976 3977 // Test resume a hold forward after disconnection. 3978 require.NoError(t, switchForwardInterceptor.ForwardPackets( 3979 linkQuit, false, c.createTestPacket(), 3980 )) 3981 3982 // Wait until the packet is offered to the interceptor. 
3983 _ = c.forwardInterceptor.getIntercepted() 3984 3985 // No forward expected yet. 3986 assertNumCircuits(t, c.s, 0, 0) 3987 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 3988 3989 // Disconnect should resume the forwarding. 3990 switchForwardInterceptor.SetInterceptor(nil) 3991 3992 receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true) 3993 assertNumCircuits(t, c.s, 1, 1) 3994 3995 // Settle the htlc to close the circuit. 3996 require.NoError(t, switchForwardInterceptor.ForwardPackets( 3997 linkQuit, false, 3998 c.createSettlePacket(receivedPkt.outgoingHTLCID), 3999 )) 4000 4001 assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 4002 assertNumCircuits(t, c.s, 0, 0) 4003 4004 // Test failing a hold forward 4005 switchForwardInterceptor.SetInterceptor( 4006 c.forwardInterceptor.InterceptForwardHtlc, 4007 ) 4008 4009 require.NoError(t, switchForwardInterceptor.ForwardPackets( 4010 linkQuit, false, c.createTestPacket(), 4011 )) 4012 assertNumCircuits(t, c.s, 0, 0) 4013 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4014 4015 require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ 4016 Action: FwdActionFail, 4017 Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, 4018 FailureCode: lnwire.CodeTemporaryChannelFailure, 4019 })) 4020 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4021 assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 4022 assertNumCircuits(t, c.s, 0, 0) 4023 4024 // Test failing a hold forward with a failure message. 
4025 require.NoError(t, 4026 switchForwardInterceptor.ForwardPackets( 4027 linkQuit, false, c.createTestPacket(), 4028 ), 4029 ) 4030 assertNumCircuits(t, c.s, 0, 0) 4031 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4032 4033 reason := lnwire.OpaqueReason([]byte{1, 2, 3}) 4034 require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ 4035 Action: FwdActionFail, 4036 Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, 4037 FailureMessage: reason, 4038 })) 4039 4040 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4041 packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 4042 4043 require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason) 4044 4045 assertNumCircuits(t, c.s, 0, 0) 4046 4047 // Test failing a hold forward with a malformed htlc failure. 4048 err = switchForwardInterceptor.ForwardPackets( 4049 linkQuit, false, c.createTestPacket(), 4050 ) 4051 require.NoError(t, err) 4052 4053 assertNumCircuits(t, c.s, 0, 0) 4054 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4055 4056 code := lnwire.CodeInvalidOnionKey 4057 require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ 4058 Action: FwdActionFail, 4059 Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, 4060 FailureCode: code, 4061 })) 4062 4063 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4064 packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 4065 failPacket := packet.htlc.(*lnwire.UpdateFailHTLC) 4066 4067 shaOnionBlob := sha256.Sum256(c.onionBlob[:]) 4068 expectedFailure := &lnwire.FailInvalidOnionKey{ 4069 OnionSHA256: shaOnionBlob, 4070 } 4071 4072 fwdErr, err := newMockDeobfuscator().DecryptError(failPacket.Reason) 4073 require.NoError(t, err) 4074 require.Equal(t, expectedFailure, fwdErr.WireMessage()) 4075 4076 assertNumCircuits(t, c.s, 0, 0) 4077 4078 // Test settling a hold forward 4079 require.NoError(t, switchForwardInterceptor.ForwardPackets( 4080 linkQuit, false, c.createTestPacket(), 4081 )) 
4082 assertNumCircuits(t, c.s, 0, 0) 4083 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4084 4085 require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ 4086 Key: c.forwardInterceptor.getIntercepted().IncomingCircuit, 4087 Action: FwdActionSettle, 4088 Preimage: c.preimage, 4089 })) 4090 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4091 assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 4092 assertNumCircuits(t, c.s, 0, 0) 4093 4094 require.NoError(t, switchForwardInterceptor.Stop()) 4095 4096 // Test always-on interception. 4097 notifier = &mock.ChainNotifier{ 4098 EpochChan: make(chan *chainntnfs.BlockEpoch, 1), 4099 } 4100 notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight} 4101 4102 switchForwardInterceptor, err = NewInterceptableSwitch( 4103 &InterceptableSwitchConfig{ 4104 Switch: c.s, 4105 CltvRejectDelta: c.cltvRejectDelta, 4106 CltvInterceptDelta: c.cltvInterceptDelta, 4107 RequireInterceptor: true, 4108 Notifier: notifier, 4109 }, 4110 ) 4111 require.NoError(t, err) 4112 require.NoError(t, switchForwardInterceptor.Start()) 4113 4114 // Forward a fresh packet. It is expected to be failed immediately, 4115 // because there is no interceptor registered. 4116 require.NoError(t, switchForwardInterceptor.ForwardPackets( 4117 linkQuit, false, c.createTestPacket(), 4118 )) 4119 4120 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4121 assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 4122 assertNumCircuits(t, c.s, 0, 0) 4123 4124 // Forward a replayed packet. It is expected to be held until the 4125 // interceptor connects. To continue the test, it needs to be ran in a 4126 // goroutine. 4127 errChan := make(chan error) 4128 go func() { 4129 errChan <- switchForwardInterceptor.ForwardPackets( 4130 linkQuit, true, c.createTestPacket(), 4131 ) 4132 }() 4133 4134 // Assert that nothing is forward to the switch. 
4135 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4136 assertNumCircuits(t, c.s, 0, 0) 4137 4138 // Register an interceptor. 4139 switchForwardInterceptor.SetInterceptor( 4140 c.forwardInterceptor.InterceptForwardHtlc, 4141 ) 4142 4143 // Expect the ForwardPackets call to unblock. 4144 require.NoError(t, <-errChan) 4145 4146 // Now expect the queued packet to come through. 4147 c.forwardInterceptor.getIntercepted() 4148 4149 // Disconnect and reconnect interceptor. 4150 switchForwardInterceptor.SetInterceptor(nil) 4151 switchForwardInterceptor.SetInterceptor( 4152 c.forwardInterceptor.InterceptForwardHtlc, 4153 ) 4154 4155 // A replay of the held packet is expected. 4156 intercepted := c.forwardInterceptor.getIntercepted() 4157 4158 // Settle the packet. 4159 require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{ 4160 Key: intercepted.IncomingCircuit, 4161 Action: FwdActionSettle, 4162 Preimage: c.preimage, 4163 })) 4164 assertOutgoingLinkReceive(t, c.bobChannelLink, false) 4165 assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 4166 assertNumCircuits(t, c.s, 0, 0) 4167 4168 require.NoError(t, switchForwardInterceptor.Stop()) 4169 4170 select { 4171 case <-c.forwardInterceptor.interceptedChan: 4172 require.Fail(t, "unexpected interception") 4173 4174 default: 4175 } 4176 } 4177 4178 func TestInterceptableSwitchWatchDog(t *testing.T) { 4179 t.Parallel() 4180 4181 c := newInterceptableSwitchTestContext(t) 4182 defer c.finish() 4183 4184 // Start interceptable switch. 
4185 notifier := &mock.ChainNotifier{ 4186 EpochChan: make(chan *chainntnfs.BlockEpoch, 1), 4187 } 4188 notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight} 4189 4190 switchForwardInterceptor, err := NewInterceptableSwitch( 4191 &InterceptableSwitchConfig{ 4192 Switch: c.s, 4193 CltvRejectDelta: c.cltvRejectDelta, 4194 CltvInterceptDelta: c.cltvInterceptDelta, 4195 Notifier: notifier, 4196 }, 4197 ) 4198 require.NoError(t, err) 4199 require.NoError(t, switchForwardInterceptor.Start()) 4200 4201 // Set interceptor. 4202 switchForwardInterceptor.SetInterceptor( 4203 c.forwardInterceptor.InterceptForwardHtlc, 4204 ) 4205 4206 // Receive a packet. 4207 linkQuit := make(chan struct{}) 4208 4209 packet := c.createTestPacket() 4210 4211 err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet) 4212 require.NoError(t, err, "can't forward htlc packet") 4213 4214 // Intercept the packet. 4215 intercepted := c.forwardInterceptor.getIntercepted() 4216 4217 require.Equal(t, 4218 int32(packet.incomingTimeout-c.cltvRejectDelta), 4219 intercepted.AutoFailHeight, 4220 ) 4221 4222 // Htlc expires before a resolution from the interceptor. 4223 notifier.EpochChan <- &chainntnfs.BlockEpoch{ 4224 Height: int32(packet.incomingTimeout) - 4225 int32(c.cltvRejectDelta), 4226 } 4227 4228 // Expect the htlc to be failed back. 4229 assertOutgoingLinkReceive(t, c.aliceChannelLink, true) 4230 4231 // It is too late now to resolve. Expect an error. 4232 require.Error(t, switchForwardInterceptor.Resolve(&FwdResolution{ 4233 Action: FwdActionSettle, 4234 Key: intercepted.IncomingCircuit, 4235 Preimage: c.preimage, 4236 })) 4237 } 4238 4239 // TestSwitchDustForwarding tests that the switch properly fails HTLC's which 4240 // have incoming or outgoing links that breach their fee thresholds. 
4241 func TestSwitchDustForwarding(t *testing.T) { 4242 t.Parallel() 4243 4244 // We'll create a three-hop network: 4245 // - Alice has a dust limit of 200sats with Bob 4246 // - Bob has a dust limit of 800sats with Alice 4247 // - Bob has a dust limit of 200sats with Carol 4248 // - Carol has a dust limit of 800sats with Bob 4249 channels, _, err := createClusterChannels( 4250 t, btcutil.SatoshiPerBitcoin, btcutil.SatoshiPerBitcoin, 4251 ) 4252 require.NoError(t, err) 4253 4254 n := newThreeHopNetwork( 4255 t, channels.aliceToBob, channels.bobToAlice, 4256 channels.bobToCarol, channels.carolToBob, testStartingHeight, 4257 ) 4258 err = n.start() 4259 require.NoError(t, err) 4260 4261 // We'll also put Alice and Bob into hodl.ExitSettle mode, such that 4262 // they won't settle incoming exit-hop HTLC's automatically. 4263 n.aliceChannelLink.cfg.HodlMask = hodl.ExitSettle.Mask() 4264 n.firstBobChannelLink.cfg.HodlMask = hodl.ExitSettle.Mask() 4265 4266 // We'll test that once the default threshold is exceeded on the 4267 // Alice -> Bob channel, either side's calls to SendHTLC will fail. 4268 numHTLCs := maxInflightHtlcs 4269 aliceAttemptID, bobAttemptID := numHTLCs, numHTLCs 4270 amt := lnwire.NewMSatFromSatoshis(700) 4271 aliceBobFirstHop := n.aliceChannelLink.ShortChanID() 4272 4273 // We decreased the max number of inflight HTLCs therefore we also need 4274 // do decrease the max fee exposure. 4275 maxFeeExposure := lnwire.NewMSatFromSatoshis(74500) 4276 n.aliceChannelLink.cfg.MaxFeeExposure = maxFeeExposure 4277 n.firstBobChannelLink.cfg.MaxFeeExposure = maxFeeExposure 4278 4279 // Alice will send 50 HTLC's of 700sats. Bob will also send 50 HTLC's 4280 // of 700sats. 4281 sendDustHtlcs(t, n, true, amt, aliceBobFirstHop, numHTLCs) 4282 sendDustHtlcs(t, n, false, amt, aliceBobFirstHop, numHTLCs) 4283 4284 // Generate the parameters needed for Bob to send another dust HTLC. 
4285 _, timelock, hops := generateHops( 4286 amt, testStartingHeight, n.aliceChannelLink, 4287 ) 4288 4289 blob, err := generateRoute(hops...) 4290 require.NoError(t, err) 4291 4292 // Assert that if Bob sends a dust HTLC it will fail. 4293 failingPreimage := lntypes.Preimage{0, 0, 3} 4294 failingHash := failingPreimage.Hash() 4295 failingHtlc := &lnwire.UpdateAddHTLC{ 4296 PaymentHash: failingHash, 4297 Amount: amt, 4298 Expiry: timelock, 4299 OnionBlob: blob, 4300 } 4301 4302 // This is the expected dust without taking the commitfee into account. 4303 expectedDust := maxInflightHtlcs * 2 * amt 4304 4305 assertAlmostDust := func(link *channelLink, mbox MailBox, 4306 whoseCommit lntypes.ChannelParty) { 4307 4308 err := wait.NoError(func() error { 4309 linkDust := link.getDustSum( 4310 whoseCommit, fn.None[chainfee.SatPerKWeight](), 4311 ) 4312 localMailDust, remoteMailDust := mbox.DustPackets() 4313 4314 totalDust := linkDust 4315 if whoseCommit.IsRemote() { 4316 totalDust += remoteMailDust 4317 } else { 4318 totalDust += localMailDust 4319 } 4320 4321 if totalDust == expectedDust { 4322 return nil 4323 } 4324 4325 return fmt.Errorf("got totalDust=%v, expectedDust=%v", 4326 totalDust, expectedDust) 4327 }, 15*time.Second) 4328 require.NoError(t, err, "timeout checking dust") 4329 } 4330 4331 // Wait until Bob is almost at the fee threshold. 4332 bobMbox := n.bobServer.htlcSwitch.mailOrchestrator.GetOrCreateMailBox( 4333 n.firstBobChannelLink.ChanID(), 4334 n.firstBobChannelLink.ShortChanID(), 4335 ) 4336 assertAlmostDust(n.firstBobChannelLink, bobMbox, lntypes.Local) 4337 4338 // Sending one more HTLC should fail. SendHTLC won't error, but the 4339 // HTLC should be failed backwards. When sending we only check for the 4340 // dust amount without the commitment fee. 
When the HTLC is added to the 4341 // commitment state (link) we also take into account the commitment fee 4342 // and with a fee of 6000 sat/kw and a commitment size of 724 (non 4343 // anchor channel) we are overexposed in fees (maxFeeExposure) that's 4344 // why the HTLC is failed back. 4345 err = n.bobServer.htlcSwitch.SendHTLC( 4346 aliceBobFirstHop, uint64(bobAttemptID), failingHtlc, 4347 ) 4348 require.Nil(t, err) 4349 4350 // Use the network result store to ensure the HTLC was failed 4351 // backwards. 4352 bobResultChan, err := n.bobServer.htlcSwitch.GetAttemptResult( 4353 uint64(bobAttemptID), failingHash, newMockDeobfuscator(), 4354 ) 4355 require.NoError(t, err) 4356 4357 result, ok := <-bobResultChan 4358 require.True(t, ok) 4359 assertFailureCode( 4360 t, result.Error, lnwire.CodeTemporaryChannelFailure, 4361 ) 4362 4363 bobAttemptID++ 4364 4365 // Generate the parameters needed for bob to send a non-dust HTLC. 4366 nondustAmt := lnwire.NewMSatFromSatoshis(10_000) 4367 _, _, hops = generateHops( 4368 nondustAmt, testStartingHeight, n.aliceChannelLink, 4369 ) 4370 4371 blob, err = generateRoute(hops...) 4372 require.NoError(t, err) 4373 4374 // Now attempt to send an HTLC above Bob's dust limit. Even though this 4375 // is not a dust HTLC, it should fail because the increase in weight 4376 // pushes us over the threshold. 4377 nondustPreimage := lntypes.Preimage{0, 0, 4} 4378 nondustHash := nondustPreimage.Hash() 4379 nondustHtlc := &lnwire.UpdateAddHTLC{ 4380 PaymentHash: nondustHash, 4381 Amount: nondustAmt, 4382 Expiry: timelock, 4383 OnionBlob: blob, 4384 } 4385 4386 err = n.bobServer.htlcSwitch.SendHTLC( 4387 aliceBobFirstHop, uint64(bobAttemptID), nondustHtlc, 4388 ) 4389 require.NoError(t, err) 4390 assertAlmostDust(n.firstBobChannelLink, bobMbox, lntypes.Local) 4391 4392 // Check that the HTLC failed. 
4393 bobResultChan, err = n.bobServer.htlcSwitch.GetAttemptResult( 4394 uint64(bobAttemptID), nondustHash, newMockDeobfuscator(), 4395 ) 4396 require.NoError(t, err) 4397 4398 result, ok = <-bobResultChan 4399 require.True(t, ok) 4400 assertFailureCode( 4401 t, result.Error, lnwire.CodeTemporaryChannelFailure, 4402 ) 4403 4404 // Introduce Carol into the mix and assert that sending a multi-hop 4405 // dust HTLC to Alice will fail. Bob should fail back the HTLC with a 4406 // temporary channel failure. 4407 carolAmt, carolTimelock, carolHops := generateHops( 4408 amt, testStartingHeight, n.secondBobChannelLink, 4409 n.aliceChannelLink, 4410 ) 4411 4412 carolBlob, err := generateRoute(carolHops...) 4413 require.NoError(t, err) 4414 4415 carolPreimage := lntypes.Preimage{0, 0, 5} 4416 carolHash := carolPreimage.Hash() 4417 carolHtlc := &lnwire.UpdateAddHTLC{ 4418 PaymentHash: carolHash, 4419 Amount: carolAmt, 4420 Expiry: carolTimelock, 4421 OnionBlob: carolBlob, 4422 } 4423 4424 // Initialize Carol's attempt ID. 4425 carolAttemptID := 0 4426 4427 err = n.carolServer.htlcSwitch.SendHTLC( 4428 n.carolChannelLink.ShortChanID(), uint64(carolAttemptID), 4429 carolHtlc, 4430 ) 4431 require.NoError(t, err) 4432 4433 carolResultChan, err := n.carolServer.htlcSwitch.GetAttemptResult( 4434 uint64(carolAttemptID), carolHash, newMockDeobfuscator(), 4435 ) 4436 require.NoError(t, err) 4437 4438 result, ok = <-carolResultChan 4439 require.True(t, ok) 4440 assertFailureCode( 4441 t, result.Error, lnwire.CodeTemporaryChannelFailure, 4442 ) 4443 4444 // Send an HTLC from Alice to Carol and assert that it gets failed. 4445 htlcAmt, totalTimelock, aliceHops := generateHops( 4446 amt, testStartingHeight, n.firstBobChannelLink, 4447 n.carolChannelLink, 4448 ) 4449 4450 blob, err = generateRoute(aliceHops...) 
4451 require.NoError(t, err) 4452 4453 aliceMultihopPreimage := lntypes.Preimage{0, 0, 6} 4454 aliceMultihopHash := aliceMultihopPreimage.Hash() 4455 aliceMultihopHtlc := &lnwire.UpdateAddHTLC{ 4456 PaymentHash: aliceMultihopHash, 4457 Amount: htlcAmt, 4458 Expiry: totalTimelock, 4459 OnionBlob: blob, 4460 } 4461 4462 // Wait until Alice's expected dust for the remote commitment is just 4463 // under the fee threshold. 4464 aliceOrch := n.aliceServer.htlcSwitch.mailOrchestrator 4465 aliceMbox := aliceOrch.GetOrCreateMailBox( 4466 n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(), 4467 ) 4468 assertAlmostDust(n.aliceChannelLink, aliceMbox, lntypes.Remote) 4469 4470 err = n.aliceServer.htlcSwitch.SendHTLC( 4471 n.aliceChannelLink.ShortChanID(), uint64(aliceAttemptID), 4472 aliceMultihopHtlc, 4473 ) 4474 require.Nil(t, err) 4475 4476 aliceResultChan, err := n.aliceServer.htlcSwitch.GetAttemptResult( 4477 uint64(aliceAttemptID), aliceMultihopHash, 4478 newMockDeobfuscator(), 4479 ) 4480 require.NoError(t, err) 4481 4482 result, ok = <-aliceResultChan 4483 require.True(t, ok) 4484 assertFailureCode( 4485 t, result.Error, lnwire.CodeTemporaryChannelFailure, 4486 ) 4487 4488 // Check that there are numHTLCs circuits open for both Alice and Bob. 4489 require.Equal(t, numHTLCs, n.aliceServer.htlcSwitch.circuits.NumOpen()) 4490 require.Equal(t, numHTLCs, n.bobServer.htlcSwitch.circuits.NumOpen()) 4491 } 4492 4493 // sendDustHtlcs is a helper function used to send many dust HTLC's to test the 4494 // Switch's channel-max-fee-exposure logic. It takes a boolean denoting whether 4495 // or not Alice is the sender. 4496 func sendDustHtlcs(t *testing.T, n *threeHopNetwork, alice bool, 4497 amt lnwire.MilliSatoshi, sid lnwire.ShortChannelID, numHTLCs int) { 4498 4499 t.Helper() 4500 4501 // Extract the destination into a variable. If alice is the sender, the 4502 // destination is Bob. 
4503 destLink := n.aliceChannelLink 4504 if alice { 4505 destLink = n.firstBobChannelLink 4506 } 4507 4508 // Create hops that will be used in the onion payload. 4509 htlcAmt, totalTimelock, hops := generateHops( 4510 amt, testStartingHeight, destLink, 4511 ) 4512 4513 // Convert the hops to a blob that will be put in the Add message. 4514 blob, err := generateRoute(hops...) 4515 require.NoError(t, err) 4516 4517 // Create a slice to store the preimages. 4518 preimages := make([]lntypes.Preimage, numHTLCs) 4519 4520 // Initialize the attempt ID used in SendHTLC calls. 4521 attemptID := uint64(0) 4522 4523 // Deterministically generate preimages. Avoid the all-zeroes preimage 4524 // because that will be rejected by the database. We'll use a different 4525 // third byte for Alice and Bob. 4526 endByte := byte(2) 4527 if alice { 4528 endByte = byte(3) 4529 } 4530 4531 for i := 0; i < numHTLCs; i++ { 4532 preimages[i] = lntypes.Preimage{byte(i >> 8), byte(i), endByte} 4533 } 4534 4535 sendingSwitch := n.bobServer.htlcSwitch 4536 if alice { 4537 sendingSwitch = n.aliceServer.htlcSwitch 4538 } 4539 4540 // Call SendHTLC in a loop for numHTLCs. 4541 for i := 0; i < numHTLCs; i++ { 4542 // Construct the htlc packet. 4543 hash := preimages[i].Hash() 4544 4545 htlc := &lnwire.UpdateAddHTLC{ 4546 PaymentHash: hash, 4547 Amount: htlcAmt, 4548 Expiry: totalTimelock, 4549 OnionBlob: blob, 4550 } 4551 4552 for { 4553 // It may be the case that the fee threshold is hit 4554 // before all numHTLCs*2 HTLC's are sent due to double 4555 // counting. Get around this by continuing to send 4556 // until successful. 4557 err = sendingSwitch.SendHTLC(sid, attemptID, htlc) 4558 if err == nil { 4559 break 4560 } 4561 } 4562 4563 attemptID++ 4564 } 4565 } 4566 4567 // TestSwitchMailboxDust tests that the switch takes into account the mailbox 4568 // dust when evaluating the fee threshold. 
The mockChannelLink does not have 4569 // channel state, so this only tests the switch-mailbox interaction. 4570 func TestSwitchMailboxDust(t *testing.T) { 4571 t.Parallel() 4572 4573 alicePeer, err := newMockServer( 4574 t, "alice", testStartingHeight, nil, testDefaultDelta, 4575 ) 4576 require.NoError(t, err) 4577 4578 bobPeer, err := newMockServer( 4579 t, "bob", testStartingHeight, nil, testDefaultDelta, 4580 ) 4581 require.NoError(t, err) 4582 4583 carolPeer, err := newMockServer( 4584 t, "carol", testStartingHeight, nil, testDefaultDelta, 4585 ) 4586 require.NoError(t, err) 4587 4588 s, err := initSwitchWithTempDB(t, testStartingHeight) 4589 require.NoError(t, err) 4590 err = s.Start() 4591 require.NoError(t, err) 4592 defer func() { 4593 _ = s.Stop() 4594 }() 4595 4596 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 4597 4598 chanID3, carolChanID := genID() 4599 4600 aliceLink := newMockChannelLink( 4601 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 4602 false, false, 4603 ) 4604 err = s.AddLink(aliceLink) 4605 require.NoError(t, err) 4606 4607 bobLink := newMockChannelLink( 4608 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 4609 false, 4610 ) 4611 err = s.AddLink(bobLink) 4612 require.NoError(t, err) 4613 4614 carolLink := newMockChannelLink( 4615 s, chanID3, carolChanID, emptyScid, carolPeer, true, false, 4616 false, false, 4617 ) 4618 err = s.AddLink(carolLink) 4619 require.NoError(t, err) 4620 4621 // mockChannelLink sets the local and remote dust limits of the mailbox 4622 // to 400 satoshis and the feerate to 0. We'll fill the mailbox up with 4623 // dust packets and assert that calls to SendHTLC will fail. 4624 preimage, err := genPreimage() 4625 require.NoError(t, err) 4626 rhash := sha256.Sum256(preimage[:]) 4627 amt := lnwire.NewMSatFromSatoshis(350) 4628 addMsg := &lnwire.UpdateAddHTLC{ 4629 PaymentHash: rhash, 4630 Amount: amt, 4631 ChanID: chanID1, 4632 } 4633 4634 // Initialize the carolHTLCID. 
4635 var carolHTLCID uint64 4636 4637 // It will take aliceCount HTLC's of 350sats to fill up Alice's mailbox 4638 // to the point where another would put Alice over the fee threshold. 4639 aliceCount := 1428 4640 4641 mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID1, aliceChanID) 4642 4643 for i := 0; i < aliceCount; i++ { 4644 alicePkt := &htlcPacket{ 4645 incomingChanID: carolChanID, 4646 incomingHTLCID: carolHTLCID, 4647 outgoingChanID: aliceChanID, 4648 obfuscator: NewMockObfuscator(), 4649 incomingAmount: amt, 4650 amount: amt, 4651 htlc: addMsg, 4652 } 4653 4654 err = mailbox.AddPacket(alicePkt) 4655 require.NoError(t, err) 4656 4657 carolHTLCID++ 4658 } 4659 4660 // Sending one more HTLC to Alice should result in the fee threshold 4661 // being breached. 4662 err = s.SendHTLC(aliceChanID, 0, addMsg) 4663 require.ErrorIs(t, err, errFeeExposureExceeded) 4664 4665 // We'll now call ForwardPackets from Bob to ensure that the mailbox 4666 // sum is also accounted for in the forwarding case. 4667 packet := &htlcPacket{ 4668 incomingChanID: bobChanID, 4669 incomingHTLCID: 0, 4670 outgoingChanID: aliceChanID, 4671 obfuscator: NewMockObfuscator(), 4672 incomingAmount: amt, 4673 amount: amt, 4674 htlc: &lnwire.UpdateAddHTLC{ 4675 PaymentHash: rhash, 4676 Amount: amt, 4677 ChanID: chanID1, 4678 }, 4679 } 4680 4681 err = s.ForwardPackets(nil, packet) 4682 require.NoError(t, err) 4683 4684 // Bob should receive a failure from the switch. 4685 select { 4686 case p := <-bobLink.packets: 4687 require.NotEmpty(t, p.linkFailure) 4688 assertFailureCode( 4689 t, p.linkFailure, lnwire.CodeTemporaryChannelFailure, 4690 ) 4691 4692 case <-time.After(5 * time.Second): 4693 t.Fatal("no timely reply from switch") 4694 } 4695 } 4696 4697 // TestSwitchResolution checks the ability of the switch to persist and handle 4698 // resolution messages. 
4699 func TestSwitchResolution(t *testing.T) { 4700 t.Parallel() 4701 4702 alicePeer, err := newMockServer( 4703 t, "alice", testStartingHeight, nil, testDefaultDelta, 4704 ) 4705 require.NoError(t, err) 4706 4707 bobPeer, err := newMockServer( 4708 t, "bob", testStartingHeight, nil, testDefaultDelta, 4709 ) 4710 require.NoError(t, err) 4711 4712 s, err := initSwitchWithTempDB(t, testStartingHeight) 4713 require.NoError(t, err) 4714 4715 // Even though we intend to Stop s later in the test, it is safe to 4716 // defer this Stop since its execution it is protected by an atomic 4717 // guard, guaranteeing it executes at most once. 4718 t.Cleanup(func() { var _ = s.Stop() }) 4719 4720 err = s.Start() 4721 require.NoError(t, err) 4722 4723 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 4724 4725 aliceChannelLink := newMockChannelLink( 4726 s, chanID1, aliceChanID, emptyScid, alicePeer, true, false, 4727 false, false, 4728 ) 4729 bobChannelLink := newMockChannelLink( 4730 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 4731 false, 4732 ) 4733 err = s.AddLink(aliceChannelLink) 4734 require.NoError(t, err) 4735 err = s.AddLink(bobChannelLink) 4736 require.NoError(t, err) 4737 4738 // Create an add htlcPacket that Alice will send to Bob. 4739 preimage, err := genPreimage() 4740 require.NoError(t, err) 4741 4742 rhash := sha256.Sum256(preimage[:]) 4743 packet := &htlcPacket{ 4744 incomingChanID: aliceChannelLink.ShortChanID(), 4745 incomingHTLCID: 0, 4746 outgoingChanID: bobChannelLink.ShortChanID(), 4747 obfuscator: NewMockObfuscator(), 4748 htlc: &lnwire.UpdateAddHTLC{ 4749 PaymentHash: rhash, 4750 Amount: 1, 4751 }, 4752 } 4753 4754 err = s.ForwardPackets(nil, packet) 4755 require.NoError(t, err) 4756 4757 // Bob will receive the packet and open the circuit. 
4758 select { 4759 case <-bobChannelLink.packets: 4760 err = bobChannelLink.completeCircuit(packet) 4761 require.NoError(t, err) 4762 case <-time.After(time.Second): 4763 t.Fatal("request was not propagated to destination") 4764 } 4765 4766 // Check that only one circuit is open. 4767 require.Equal(t, 1, s.circuits.NumOpen()) 4768 4769 // We'll send a settle resolution to Switch that should go to Alice. 4770 settleResMsg := contractcourt.ResolutionMsg{ 4771 SourceChan: bobChanID, 4772 HtlcIndex: 0, 4773 PreImage: &preimage, 4774 } 4775 4776 // Before the resolution is sent, remove alice's link so we can assert 4777 // that the resolution is actually stored. Otherwise, it would be 4778 // deleted shortly after being sent. 4779 s.RemoveLink(chanID1) 4780 4781 // Send the resolution message. 4782 err = s.ProcessContractResolution(settleResMsg) 4783 require.NoError(t, err) 4784 4785 // Assert that the resolution store contains the settle reoslution. 4786 resMsgs, err := s.resMsgStore.fetchAllResolutionMsg() 4787 require.NoError(t, err) 4788 4789 require.Equal(t, 1, len(resMsgs)) 4790 require.Equal(t, settleResMsg.SourceChan, resMsgs[0].SourceChan) 4791 require.Equal(t, settleResMsg.HtlcIndex, resMsgs[0].HtlcIndex) 4792 require.Nil(t, resMsgs[0].Failure) 4793 require.Equal(t, preimage, *resMsgs[0].PreImage) 4794 4795 // Now we'll restart Alice's link and delete the circuit. 4796 err = s.AddLink(aliceChannelLink) 4797 require.NoError(t, err) 4798 4799 // Alice will receive the packet and open the circuit. 4800 select { 4801 case alicePkt := <-aliceChannelLink.packets: 4802 err = aliceChannelLink.completeCircuit(alicePkt) 4803 require.NoError(t, err) 4804 case <-time.After(time.Second): 4805 t.Fatal("request was not propagated to destination") 4806 } 4807 4808 // Assert that there are no more circuits. 4809 require.Equal(t, 0, s.circuits.NumOpen()) 4810 4811 // We'll restart the Switch and assert that Alice does not receive 4812 // another packet. 
4813 switchDB := s.cfg.DB.(*channeldb.DB) 4814 err = s.Stop() 4815 require.NoError(t, err) 4816 4817 s, err = initSwitchWithDB(testStartingHeight, switchDB) 4818 require.NoError(t, err) 4819 4820 err = s.Start() 4821 require.NoError(t, err) 4822 defer func() { 4823 _ = s.Stop() 4824 }() 4825 4826 err = s.AddLink(aliceChannelLink) 4827 require.NoError(t, err) 4828 err = s.AddLink(bobChannelLink) 4829 require.NoError(t, err) 4830 4831 // Alice should not receive a packet since the Switch should have 4832 // deleted the resolution message since the circuit was closed. 4833 select { 4834 case alicePkt := <-aliceChannelLink.packets: 4835 t.Fatalf("received erroneous packet: %v", alicePkt) 4836 case <-time.After(time.Second * 5): 4837 } 4838 4839 // Check that the resolution message no longer exists in the store. 4840 resMsgs, err = s.resMsgStore.fetchAllResolutionMsg() 4841 require.NoError(t, err) 4842 require.Equal(t, 0, len(resMsgs)) 4843 } 4844 4845 // TestSwitchForwardFailAlias tests that if ForwardPackets returns a failure 4846 // before actually forwarding, the ChannelUpdate uses the SCID from the 4847 // incoming channel and does not leak private information like the UTXO. 4848 func TestSwitchForwardFailAlias(t *testing.T) { 4849 tests := []struct { 4850 name string 4851 4852 // Whether or not Alice will be a zero-conf channel or an 4853 // option-scid-alias channel (feature-bit). 
4854 zeroConf bool 4855 }{ 4856 { 4857 name: "option-scid-alias forwarding failure", 4858 zeroConf: false, 4859 }, 4860 { 4861 name: "zero-conf forwarding failure", 4862 zeroConf: true, 4863 }, 4864 } 4865 4866 for _, test := range tests { 4867 test := test 4868 4869 t.Run(test.name, func(t *testing.T) { 4870 testSwitchForwardFailAlias(t, test.zeroConf) 4871 }) 4872 } 4873 } 4874 4875 func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { 4876 t.Parallel() 4877 4878 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 4879 4880 alicePeer, err := newMockServer( 4881 t, "alice", testStartingHeight, nil, testDefaultDelta, 4882 ) 4883 require.NoError(t, err) 4884 4885 bobPeer, err := newMockServer( 4886 t, "bob", testStartingHeight, nil, testDefaultDelta, 4887 ) 4888 require.NoError(t, err) 4889 4890 tempPath := t.TempDir() 4891 4892 cdb := channeldb.OpenForTesting(t, tempPath) 4893 4894 s, err := initSwitchWithDB(testStartingHeight, cdb) 4895 require.NoError(t, err) 4896 4897 err = s.Start() 4898 require.NoError(t, err) 4899 4900 // Make Alice's channel zero-conf or option-scid-alias (feature bit). 4901 aliceAlias := lnwire.ShortChannelID{ 4902 BlockHeight: 16_000_000, 4903 TxIndex: 5, 4904 TxPosition: 5, 4905 } 4906 4907 var aliceLink *mockChannelLink 4908 if zeroConf { 4909 aliceLink = newMockChannelLink( 4910 s, chanID1, aliceAlias, aliceChanID, alicePeer, true, 4911 true, true, false, 4912 ) 4913 } else { 4914 aliceLink = newMockChannelLink( 4915 s, chanID1, aliceChanID, emptyScid, alicePeer, true, 4916 true, false, true, 4917 ) 4918 aliceLink.addAlias(aliceAlias) 4919 } 4920 err = s.AddLink(aliceLink) 4921 require.NoError(t, err) 4922 4923 bobLink := newMockChannelLink( 4924 s, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 4925 false, 4926 ) 4927 err = s.AddLink(bobLink) 4928 require.NoError(t, err) 4929 4930 // Create a packet that will be sent from Alice to Bob via the switch. 
4931 preimage := [sha256.Size]byte{1} 4932 rhash := sha256.Sum256(preimage[:]) 4933 ogPacket := &htlcPacket{ 4934 incomingChanID: aliceLink.ShortChanID(), 4935 incomingHTLCID: 0, 4936 outgoingChanID: bobLink.ShortChanID(), 4937 obfuscator: NewMockObfuscator(), 4938 htlc: &lnwire.UpdateAddHTLC{ 4939 PaymentHash: rhash, 4940 Amount: 1, 4941 }, 4942 } 4943 4944 // Forward the packet and check that Bob's channel link received it. 4945 err = s.ForwardPackets(nil, ogPacket) 4946 require.NoError(t, err) 4947 4948 // Assert that the circuits are in the expected state. 4949 require.Equal(t, 1, s.circuits.NumPending()) 4950 require.Equal(t, 0, s.circuits.NumOpen()) 4951 4952 // Pull packet from Bob's link, and do nothing with it. 4953 select { 4954 case <-bobLink.packets: 4955 case <-s.quit: 4956 t.Fatal("switch shutting down, failed to forward packet") 4957 } 4958 4959 // Now we will restart the Switch to trigger the LoadedFromDisk logic. 4960 err = s.Stop() 4961 require.NoError(t, err) 4962 4963 err = cdb.Close() 4964 require.NoError(t, err) 4965 4966 cdb2 := channeldb.OpenForTesting(t, tempPath) 4967 4968 s2, err := initSwitchWithDB(testStartingHeight, cdb2) 4969 require.NoError(t, err) 4970 4971 err = s2.Start() 4972 require.NoError(t, err) 4973 4974 defer func() { 4975 _ = s2.Stop() 4976 }() 4977 4978 var aliceLink2 *mockChannelLink 4979 if zeroConf { 4980 aliceLink2 = newMockChannelLink( 4981 s2, chanID1, aliceAlias, aliceChanID, alicePeer, true, 4982 true, true, false, 4983 ) 4984 } else { 4985 aliceLink2 = newMockChannelLink( 4986 s2, chanID1, aliceChanID, emptyScid, alicePeer, true, 4987 true, false, true, 4988 ) 4989 aliceLink2.addAlias(aliceAlias) 4990 } 4991 err = s2.AddLink(aliceLink2) 4992 require.NoError(t, err) 4993 4994 bobLink2 := newMockChannelLink( 4995 s2, chanID2, bobChanID, emptyScid, bobPeer, true, false, false, 4996 false, 4997 ) 4998 err = s2.AddLink(bobLink2) 4999 require.NoError(t, err) 5000 5001 // Reforward the ogPacket and wait for Alice to 
receive a failure 5002 // packet. 5003 err = s2.ForwardPackets(nil, ogPacket) 5004 require.NoError(t, err) 5005 5006 select { 5007 case failPacket := <-aliceLink2.packets: 5008 // Assert that the failPacket does not leak UTXO information. 5009 // This means checking that aliceChanID was not returned. 5010 msg := failPacket.linkFailure.msg 5011 failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) 5012 require.True(t, ok) 5013 require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID) 5014 case <-s2.quit: 5015 t.Fatal("switch shutting down, failed to forward packet") 5016 } 5017 } 5018 5019 // TestSwitchAliasFailAdd tests that the mailbox does not leak UTXO information 5020 // when failing back an HTLC due to the 5-second timeout. This is tested in the 5021 // switch rather than the mailbox because the mailbox tests do not have the 5022 // proper context (e.g. the Switch's failAliasUpdate function). The caveat here 5023 // is that if the private UTXO is already known, it is fine to send a failure 5024 // back. This tests option-scid-alias (feature-bit) and zero-conf channels. 5025 func TestSwitchAliasFailAdd(t *testing.T) { 5026 tests := []struct { 5027 name string 5028 5029 // Denotes whether the opened channel will be zero-conf. 5030 zeroConf bool 5031 5032 // Denotes whether the opened channel will be private. 5033 private bool 5034 5035 // Denotes whether an alias was used during forwarding. 
5036 useAlias bool 5037 }{ 5038 { 5039 name: "public zero-conf using alias", 5040 zeroConf: true, 5041 private: false, 5042 useAlias: true, 5043 }, 5044 { 5045 name: "public zero-conf using real", 5046 zeroConf: true, 5047 private: false, 5048 useAlias: true, 5049 }, 5050 { 5051 name: "private zero-conf using alias", 5052 zeroConf: true, 5053 private: true, 5054 useAlias: true, 5055 }, 5056 { 5057 name: "public option-scid-alias using alias", 5058 zeroConf: false, 5059 private: false, 5060 useAlias: true, 5061 }, 5062 { 5063 name: "public option-scid-alias using real", 5064 zeroConf: false, 5065 private: false, 5066 useAlias: false, 5067 }, 5068 { 5069 name: "private option-scid-alias using alias", 5070 zeroConf: false, 5071 private: true, 5072 useAlias: true, 5073 }, 5074 } 5075 5076 for _, test := range tests { 5077 test := test 5078 5079 t.Run(test.name, func(t *testing.T) { 5080 testSwitchAliasFailAdd( 5081 t, test.zeroConf, test.private, test.useAlias, 5082 ) 5083 }) 5084 } 5085 } 5086 5087 func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { 5088 t.Parallel() 5089 5090 chanID1, chanID2, aliceChanID, bobChanID := genIDs() 5091 5092 alicePeer, err := newMockServer( 5093 t, "alice", testStartingHeight, nil, testDefaultDelta, 5094 ) 5095 require.NoError(t, err) 5096 5097 bobPeer, err := newMockServer( 5098 t, "bob", testStartingHeight, nil, testDefaultDelta, 5099 ) 5100 require.NoError(t, err) 5101 5102 tempPath := t.TempDir() 5103 5104 cdb := channeldb.OpenForTesting(t, tempPath) 5105 5106 s, err := initSwitchWithDB(testStartingHeight, cdb) 5107 require.NoError(t, err) 5108 5109 // Change the mailOrchestrator's expiry to a second. 5110 s.mailOrchestrator.cfg.expiry = time.Second 5111 5112 err = s.Start() 5113 require.NoError(t, err) 5114 5115 defer func() { 5116 _ = s.Stop() 5117 }() 5118 5119 // Make Alice's channel zero-conf or option-scid-alias (feature bit). 
5120 aliceAlias := lnwire.ShortChannelID{ 5121 BlockHeight: 16_000_000, 5122 TxIndex: 5, 5123 TxPosition: 5, 5124 } 5125 aliceAlias2 := aliceAlias 5126 aliceAlias2.TxPosition = 6 5127 5128 var aliceLink *mockChannelLink 5129 if zeroConf { 5130 aliceLink = newMockChannelLink( 5131 s, chanID1, aliceAlias, aliceChanID, alicePeer, true, 5132 private, true, false, 5133 ) 5134 aliceLink.addAlias(aliceAlias2) 5135 } else { 5136 aliceLink = newMockChannelLink( 5137 s, chanID1, aliceChanID, emptyScid, alicePeer, true, 5138 private, false, true, 5139 ) 5140 aliceLink.addAlias(aliceAlias) 5141 aliceLink.addAlias(aliceAlias2) 5142 } 5143 err = s.AddLink(aliceLink) 5144 require.NoError(t, err) 5145 5146 bobLink := newMockChannelLink( 5147 s, chanID2, bobChanID, emptyScid, bobPeer, true, true, false, 5148 false, 5149 ) 5150 err = s.AddLink(bobLink) 5151 require.NoError(t, err) 5152 5153 // Create a packet that Bob will send to Alice via ForwardPackets. 5154 preimage := [sha256.Size]byte{1} 5155 rhash := sha256.Sum256(preimage[:]) 5156 ogPacket := &htlcPacket{ 5157 incomingChanID: bobLink.ShortChanID(), 5158 incomingHTLCID: 0, 5159 obfuscator: NewMockObfuscator(), 5160 htlc: &lnwire.UpdateAddHTLC{ 5161 PaymentHash: rhash, 5162 Amount: 1, 5163 }, 5164 } 5165 5166 // Determine which outgoingChanID to set based on the useAlias boolean. 5167 outgoingChanID := aliceChanID 5168 if useAlias { 5169 // Choose randomly from the 2 possible aliases. 5170 aliases := aliceLink.getAliases() 5171 idx := mrand.Intn(len(aliases)) 5172 5173 outgoingChanID = aliases[idx] 5174 } 5175 5176 ogPacket.outgoingChanID = outgoingChanID 5177 5178 // Forward the packet so Alice's mailbox fails it backwards. 5179 err = s.ForwardPackets(nil, ogPacket) 5180 require.NoError(t, err) 5181 5182 // Assert that the circuits are in the expected state. 5183 require.Equal(t, 1, s.circuits.NumPending()) 5184 require.Equal(t, 0, s.circuits.NumOpen()) 5185 5186 // Wait to receive the packet from Bob's mailbox. 
5187 select { 5188 case failPacket := <-bobLink.packets: 5189 // Assert that failPacket returns the expected SCID in the 5190 // ChannelUpdate. 5191 msg := failPacket.linkFailure.msg 5192 failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) 5193 require.True(t, ok) 5194 require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) 5195 case <-s.quit: 5196 t.Fatal("switch shutting down, failed to receive fail packet") 5197 } 5198 } 5199 5200 // TestSwitchHandlePacketForwardAlias checks that handlePacketForward (which 5201 // calls CheckHtlcForward) does not leak the UTXO in a failure message for 5202 // alias channels. This test requires us to have a REAL link, which we also 5203 // must modify in order to test it properly (e.g. making it a private channel). 5204 // This doesn't lead to good code, but short of refactoring the link-generation 5205 // code there is not a good alternative. 5206 func TestSwitchHandlePacketForward(t *testing.T) { 5207 tests := []struct { 5208 name string 5209 5210 // Denotes whether or not the channel will be zero-conf. 5211 zeroConf bool 5212 5213 // Denotes whether or not the channel will have negotiated the 5214 // option-scid-alias feature-bit and is not zero-conf. 5215 optionFeature bool 5216 5217 // Denotes whether or not the channel will be private. 5218 private bool 5219 5220 // Denotes whether or not the alias will be used for 5221 // forwarding. 
5222 useAlias bool 5223 }{ 5224 { 5225 name: "public zero-conf using alias", 5226 zeroConf: true, 5227 private: false, 5228 useAlias: true, 5229 }, 5230 { 5231 name: "public zero-conf using real", 5232 zeroConf: true, 5233 private: false, 5234 useAlias: false, 5235 }, 5236 { 5237 name: "private zero-conf using alias", 5238 zeroConf: true, 5239 private: true, 5240 useAlias: true, 5241 }, 5242 { 5243 name: "public option-scid-alias using alias", 5244 zeroConf: false, 5245 optionFeature: true, 5246 private: false, 5247 useAlias: true, 5248 }, 5249 { 5250 name: "public option-scid-alias using real", 5251 zeroConf: false, 5252 optionFeature: true, 5253 private: false, 5254 useAlias: false, 5255 }, 5256 { 5257 name: "private option-scid-alias using alias", 5258 zeroConf: false, 5259 optionFeature: true, 5260 private: true, 5261 useAlias: true, 5262 }, 5263 } 5264 5265 for _, test := range tests { 5266 test := test 5267 5268 t.Run(test.name, func(t *testing.T) { 5269 testSwitchHandlePacketForward( 5270 t, test.zeroConf, test.private, test.useAlias, 5271 test.optionFeature, 5272 ) 5273 }) 5274 } 5275 } 5276 5277 func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, 5278 useAlias, optionFeature bool) { 5279 5280 t.Parallel() 5281 5282 // Create a link for Alice that we'll add to the switch. 5283 harness, err := 5284 newSingleLinkTestHarness(t, btcutil.SatoshiPerBitcoin, 0) 5285 require.NoError(t, err) 5286 5287 aliceLink := harness.aliceLink 5288 5289 s, err := initSwitchWithTempDB(t, testStartingHeight) 5290 if err != nil { 5291 t.Fatalf("unable to init switch: %v", err) 5292 } 5293 if err := s.Start(); err != nil { 5294 t.Fatalf("unable to start switch: %v", err) 5295 } 5296 defer func() { 5297 _ = s.Stop() 5298 }() 5299 5300 // Change Alice's ShortChanID and OtherShortChanID here. 
5301 aliceAlias := lnwire.ShortChannelID{ 5302 BlockHeight: 16_000_000, 5303 TxIndex: 5, 5304 TxPosition: 5, 5305 } 5306 aliceAlias2 := aliceAlias 5307 aliceAlias2.TxPosition = 6 5308 5309 aliceChannelLink := aliceLink.(*channelLink) 5310 aliceChannelState := aliceChannelLink.channel.State() 5311 5312 // Set the link's GetAliases function. 5313 aliceChannelLink.cfg.GetAliases = func( 5314 base lnwire.ShortChannelID) []lnwire.ShortChannelID { 5315 5316 return []lnwire.ShortChannelID{aliceAlias, aliceAlias2} 5317 } 5318 5319 if !private { 5320 // Change the channel to public depending on the test. 5321 aliceChannelState.ChannelFlags = lnwire.FFAnnounceChannel 5322 } 5323 5324 // If this is an option-scid-alias feature-bit non-zero-conf channel, 5325 // we'll mark the channel as such. 5326 if optionFeature { 5327 aliceChannelState.ChanType |= channeldb.ScidAliasFeatureBit 5328 } 5329 5330 // This is the ShortChannelID field in the OpenChannel struct. 5331 aliceScid := aliceLink.ShortChanID() 5332 if zeroConf { 5333 // Store the alias in the shortChanID field and mark the real 5334 // scid in the database. 5335 err = aliceChannelState.MarkRealScid(aliceScid) 5336 require.NoError(t, err) 5337 5338 aliceChannelState.ChanType |= channeldb.ZeroConfBit 5339 } 5340 5341 err = s.AddLink(aliceLink) 5342 require.NoError(t, err) 5343 5344 // Add a mockChannelLink for Bob. 
5345 bobChanID, bobScid := genID() 5346 bobPeer, err := newMockServer( 5347 t, "bob", testStartingHeight, nil, testDefaultDelta, 5348 ) 5349 require.NoError(t, err) 5350 5351 bobLink := newMockChannelLink( 5352 s, bobChanID, bobScid, emptyScid, bobPeer, true, false, false, 5353 false, 5354 ) 5355 err = s.AddLink(bobLink) 5356 require.NoError(t, err) 5357 5358 preimage := [sha256.Size]byte{1} 5359 rhash := sha256.Sum256(preimage[:]) 5360 ogPacket := &htlcPacket{ 5361 incomingChanID: bobLink.ShortChanID(), 5362 incomingHTLCID: 0, 5363 incomingAmount: 1000, 5364 obfuscator: NewMockObfuscator(), 5365 htlc: &lnwire.UpdateAddHTLC{ 5366 PaymentHash: rhash, 5367 Amount: 1, 5368 }, 5369 } 5370 5371 // Determine which outgoingChanID to set based on the useAlias bool. 5372 outgoingChanID := aliceScid 5373 if useAlias { 5374 // Choose from the possible aliases. 5375 aliases := aliceLink.getAliases() 5376 idx := mrand.Intn(len(aliases)) 5377 5378 outgoingChanID = aliases[idx] 5379 } 5380 5381 ogPacket.outgoingChanID = outgoingChanID 5382 5383 // Forward the packet to Alice and she should fail it back with an 5384 // AmountBelowMinimum FailureMessage. 5385 err = s.ForwardPackets(nil, ogPacket) 5386 require.NoError(t, err) 5387 5388 select { 5389 case failPacket := <-bobLink.packets: 5390 // Assert that failPacket returns the expected ChannelUpdate. 5391 msg := failPacket.linkFailure.msg 5392 failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) 5393 require.True(t, ok) 5394 require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) 5395 case <-s.quit: 5396 t.Fatal("switch shutting down, failed to receive failure") 5397 } 5398 } 5399 5400 // TestSwitchAliasInterceptFail tests that when the InterceptableSwitch fails 5401 // an incoming HTLC, it does not leak the on-chain UTXO for option-scid-alias 5402 // (feature bit) or zero-conf channels. 
5403 func TestSwitchAliasInterceptFail(t *testing.T) { 5404 tests := []struct { 5405 name string 5406 5407 // Denotes whether or not the incoming channel is a zero-conf 5408 // channel or an option-scid-alias channel instead (feature 5409 // bit). 5410 zeroConf bool 5411 }{ 5412 { 5413 name: "option-scid-alias", 5414 zeroConf: false, 5415 }, 5416 { 5417 name: "zero-conf", 5418 zeroConf: true, 5419 }, 5420 } 5421 5422 for _, test := range tests { 5423 test := test 5424 5425 t.Run(test.name, func(t *testing.T) { 5426 testSwitchAliasInterceptFail(t, test.zeroConf) 5427 }) 5428 } 5429 } 5430 5431 func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { 5432 t.Parallel() 5433 5434 chanID, aliceScid := genID() 5435 5436 alicePeer, err := newMockServer( 5437 t, "alice", testStartingHeight, nil, testDefaultDelta, 5438 ) 5439 require.NoError(t, err) 5440 5441 tempPath := t.TempDir() 5442 5443 cdb := channeldb.OpenForTesting(t, tempPath) 5444 5445 s, err := initSwitchWithDB(testStartingHeight, cdb) 5446 require.NoError(t, err) 5447 5448 err = s.Start() 5449 require.NoError(t, err) 5450 5451 defer func() { 5452 _ = s.Stop() 5453 }() 5454 5455 // Make Alice's alias here. 5456 aliceAlias := lnwire.ShortChannelID{ 5457 BlockHeight: 16_000_000, 5458 TxIndex: 5, 5459 TxPosition: 5, 5460 } 5461 aliceAlias2 := aliceAlias 5462 aliceAlias2.TxPosition = 6 5463 5464 var aliceLink *mockChannelLink 5465 if zeroConf { 5466 aliceLink = newMockChannelLink( 5467 s, chanID, aliceAlias, aliceScid, alicePeer, true, 5468 true, true, false, 5469 ) 5470 aliceLink.addAlias(aliceAlias2) 5471 } else { 5472 aliceLink = newMockChannelLink( 5473 s, chanID, aliceScid, emptyScid, alicePeer, true, 5474 true, false, true, 5475 ) 5476 aliceLink.addAlias(aliceAlias) 5477 aliceLink.addAlias(aliceAlias2) 5478 } 5479 err = s.AddLink(aliceLink) 5480 require.NoError(t, err) 5481 5482 // Now we'll create the packet that will be sent from the Alice link. 
5483 preimage := [sha256.Size]byte{1} 5484 rhash := sha256.Sum256(preimage[:]) 5485 ogPacket := &htlcPacket{ 5486 incomingChanID: aliceLink.ShortChanID(), 5487 incomingTimeout: 1000, 5488 incomingHTLCID: 0, 5489 outgoingChanID: lnwire.ShortChannelID{}, 5490 obfuscator: NewMockObfuscator(), 5491 htlc: &lnwire.UpdateAddHTLC{ 5492 PaymentHash: rhash, 5493 Amount: 1, 5494 }, 5495 } 5496 5497 // Now setup the interceptable switch so that we can reject this 5498 // packet. 5499 forwardInterceptor := &mockForwardInterceptor{ 5500 t: t, 5501 interceptedChan: make(chan InterceptedPacket), 5502 } 5503 5504 notifier := &mock.ChainNotifier{ 5505 EpochChan: make(chan *chainntnfs.BlockEpoch, 1), 5506 } 5507 notifier.EpochChan <- &chainntnfs.BlockEpoch{Height: testStartingHeight} 5508 5509 interceptSwitch, err := NewInterceptableSwitch( 5510 &InterceptableSwitchConfig{ 5511 Switch: s, 5512 Notifier: notifier, 5513 CltvRejectDelta: 10, 5514 CltvInterceptDelta: 13, 5515 }, 5516 ) 5517 require.NoError(t, err) 5518 require.NoError(t, interceptSwitch.Start()) 5519 interceptSwitch.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) 5520 5521 err = interceptSwitch.ForwardPackets(nil, false, ogPacket) 5522 require.NoError(t, err) 5523 5524 inCircuit := forwardInterceptor.getIntercepted().IncomingCircuit 5525 require.NoError(t, interceptSwitch.resolve(&FwdResolution{ 5526 Action: FwdActionFail, 5527 Key: inCircuit, 5528 FailureCode: lnwire.CodeTemporaryChannelFailure, 5529 })) 5530 5531 select { 5532 case failPacket := <-aliceLink.packets: 5533 // Assert that failPacket returns the expected ChannelUpdate. 
5534 failHtlc, ok := failPacket.htlc.(*lnwire.UpdateFailHTLC) 5535 require.True(t, ok) 5536 5537 fwdErr, err := newMockDeobfuscator().DecryptError( 5538 failHtlc.Reason, 5539 ) 5540 require.NoError(t, err) 5541 5542 failure := fwdErr.WireMessage() 5543 5544 failureMsg, ok := failure.(*lnwire.FailTemporaryChannelFailure) 5545 require.True(t, ok) 5546 5547 failScid := failureMsg.Update.ShortChannelID 5548 isAlias := failScid == aliceAlias || failScid == aliceAlias2 5549 require.True(t, isAlias) 5550 5551 case <-s.quit: 5552 t.Fatalf("switch shutting down, failed to receive failure") 5553 } 5554 5555 require.NoError(t, interceptSwitch.Stop()) 5556 }