// File: htlcswitch/mock.go

1 package htlcswitch 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/sha256" 7 "encoding/binary" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "path/filepath" 13 "sync" 14 "sync/atomic" 15 "testing" 16 "time" 17 18 "github.com/btcsuite/btcd/btcec/v2" 19 "github.com/btcsuite/btcd/btcec/v2/ecdsa" 20 "github.com/btcsuite/btcd/btcutil" 21 "github.com/btcsuite/btcd/wire" 22 sphinx "github.com/lightningnetwork/lightning-onion" 23 "github.com/lightningnetwork/lnd/chainntnfs" 24 "github.com/lightningnetwork/lnd/channeldb" 25 "github.com/lightningnetwork/lnd/clock" 26 "github.com/lightningnetwork/lnd/contractcourt" 27 "github.com/lightningnetwork/lnd/fn/v2" 28 "github.com/lightningnetwork/lnd/graph/db/models" 29 "github.com/lightningnetwork/lnd/htlcswitch/hop" 30 "github.com/lightningnetwork/lnd/invoices" 31 "github.com/lightningnetwork/lnd/lnpeer" 32 "github.com/lightningnetwork/lnd/lntest/mock" 33 "github.com/lightningnetwork/lnd/lntypes" 34 "github.com/lightningnetwork/lnd/lnwallet/chainfee" 35 "github.com/lightningnetwork/lnd/lnwire" 36 "github.com/lightningnetwork/lnd/ticker" 37 "github.com/lightningnetwork/lnd/tlv" 38 ) 39 40 func isAlias(scid lnwire.ShortChannelID) bool { 41 return scid.BlockHeight >= 16_000_000 && scid.BlockHeight < 16_250_000 42 } 43 44 type mockPreimageCache struct { 45 sync.Mutex 46 preimageMap map[lntypes.Hash]lntypes.Preimage 47 } 48 49 func newMockPreimageCache() *mockPreimageCache { 50 return &mockPreimageCache{ 51 preimageMap: make(map[lntypes.Hash]lntypes.Preimage), 52 } 53 } 54 55 func (m *mockPreimageCache) LookupPreimage( 56 hash lntypes.Hash) (lntypes.Preimage, bool) { 57 58 m.Lock() 59 defer m.Unlock() 60 61 p, ok := m.preimageMap[hash] 62 return p, ok 63 } 64 65 func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error { 66 m.Lock() 67 defer m.Unlock() 68 69 for _, preimage := range preimages { 70 m.preimageMap[preimage.Hash()] = preimage 71 } 72 73 return nil 74 } 75 76 func (m *mockPreimageCache) SubscribeUpdates( 77 
chanID lnwire.ShortChannelID, htlc *channeldb.HTLC, 78 payload *hop.Payload, 79 nextHopOnionBlob []byte) (*contractcourt.WitnessSubscription, error) { 80 81 return nil, nil 82 } 83 84 // TODO(yy): replace it with chainfee.MockEstimator. 85 type mockFeeEstimator struct { 86 byteFeeIn chan chainfee.SatPerKWeight 87 relayFee chan chainfee.SatPerKWeight 88 89 quit chan struct{} 90 } 91 92 func newMockFeeEstimator() *mockFeeEstimator { 93 return &mockFeeEstimator{ 94 byteFeeIn: make(chan chainfee.SatPerKWeight), 95 relayFee: make(chan chainfee.SatPerKWeight), 96 quit: make(chan struct{}), 97 } 98 } 99 100 func (m *mockFeeEstimator) EstimateFeePerKW( 101 numBlocks uint32) (chainfee.SatPerKWeight, error) { 102 103 select { 104 case feeRate := <-m.byteFeeIn: 105 return feeRate, nil 106 case <-m.quit: 107 return 0, fmt.Errorf("exiting") 108 } 109 } 110 111 func (m *mockFeeEstimator) RelayFeePerKW() chainfee.SatPerKWeight { 112 select { 113 case feeRate := <-m.relayFee: 114 return feeRate 115 case <-m.quit: 116 return 0 117 } 118 } 119 120 func (m *mockFeeEstimator) Start() error { 121 return nil 122 } 123 func (m *mockFeeEstimator) Stop() error { 124 close(m.quit) 125 return nil 126 } 127 128 var _ chainfee.Estimator = (*mockFeeEstimator)(nil) 129 130 type mockForwardingLog struct { 131 sync.Mutex 132 133 events map[time.Time]channeldb.ForwardingEvent 134 } 135 136 func (m *mockForwardingLog) AddForwardingEvents(events []channeldb.ForwardingEvent) error { 137 m.Lock() 138 defer m.Unlock() 139 140 for _, event := range events { 141 m.events[event.Timestamp] = event 142 } 143 144 return nil 145 } 146 147 type mockServer struct { 148 started int32 // To be used atomically. 149 shutdown int32 // To be used atomically. 
150 wg sync.WaitGroup 151 quit chan struct{} 152 153 t testing.TB 154 155 name string 156 messages chan lnwire.Message 157 protocolTraceMtx sync.Mutex 158 protocolTrace []lnwire.Message 159 160 id [33]byte 161 htlcSwitch *Switch 162 163 registry *mockInvoiceRegistry 164 pCache *mockPreimageCache 165 interceptorFuncs []messageInterceptor 166 } 167 168 var _ lnpeer.Peer = (*mockServer)(nil) 169 170 func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { 171 signAliasUpdate := func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature, 172 error) { 173 174 return testSig, nil 175 } 176 177 cfg := Config{ 178 DB: db, 179 FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, 180 FetchAllChannels: db.ChannelStateDB().FetchAllChannels, 181 FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, 182 SwitchPackager: channeldb.NewSwitchPackager(), 183 FwdingLog: &mockForwardingLog{ 184 events: make(map[time.Time]channeldb.ForwardingEvent), 185 }, 186 FetchLastChannelUpdate: func(scid lnwire.ShortChannelID) ( 187 *lnwire.ChannelUpdate1, error) { 188 189 return &lnwire.ChannelUpdate1{ 190 ShortChannelID: scid, 191 }, nil 192 }, 193 Notifier: &mock.ChainNotifier{ 194 SpendChan: make(chan *chainntnfs.SpendDetail), 195 EpochChan: make(chan *chainntnfs.BlockEpoch), 196 ConfChan: make(chan *chainntnfs.TxConfirmation), 197 }, 198 FwdEventTicker: ticker.NewForce( 199 DefaultFwdEventInterval, 200 ), 201 LogEventTicker: ticker.NewForce(DefaultLogInterval), 202 AckEventTicker: ticker.NewForce(DefaultAckInterval), 203 HtlcNotifier: &mockHTLCNotifier{}, 204 Clock: clock.NewDefaultClock(), 205 MailboxDeliveryTimeout: time.Hour, 206 MaxFeeExposure: DefaultMaxFeeExposure, 207 SignAliasUpdate: signAliasUpdate, 208 IsAlias: isAlias, 209 } 210 211 return New(cfg, startingHeight) 212 } 213 214 func initSwitchWithTempDB(t testing.TB, startingHeight uint32) (*Switch, 215 error) { 216 217 tempPath := filepath.Join(t.TempDir(), "switchdb") 218 db := 
channeldb.OpenForTesting(t, tempPath) 219 220 s, err := initSwitchWithDB(startingHeight, db) 221 if err != nil { 222 return nil, err 223 } 224 225 return s, nil 226 } 227 228 func newMockServer(t testing.TB, name string, startingHeight uint32, 229 db *channeldb.DB, defaultDelta uint32) (*mockServer, error) { 230 231 var id [33]byte 232 h := sha256.Sum256([]byte(name)) 233 copy(id[:], h[:]) 234 235 pCache := newMockPreimageCache() 236 237 var ( 238 htlcSwitch *Switch 239 err error 240 ) 241 if db == nil { 242 htlcSwitch, err = initSwitchWithTempDB(t, startingHeight) 243 } else { 244 htlcSwitch, err = initSwitchWithDB(startingHeight, db) 245 } 246 if err != nil { 247 return nil, err 248 } 249 250 t.Cleanup(func() { _ = htlcSwitch.Stop() }) 251 252 registry := newMockRegistry(t) 253 254 return &mockServer{ 255 t: t, 256 id: id, 257 name: name, 258 messages: make(chan lnwire.Message, 3000), 259 quit: make(chan struct{}), 260 registry: registry, 261 htlcSwitch: htlcSwitch, 262 pCache: pCache, 263 interceptorFuncs: make([]messageInterceptor, 0), 264 }, nil 265 } 266 267 func (s *mockServer) Start() error { 268 if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { 269 return errors.New("mock server already started") 270 } 271 272 if err := s.htlcSwitch.Start(); err != nil { 273 return err 274 } 275 276 s.wg.Add(1) 277 go func() { 278 defer s.wg.Done() 279 280 defer func() { 281 s.htlcSwitch.Stop() 282 }() 283 284 for { 285 select { 286 case msg := <-s.messages: 287 s.protocolTraceMtx.Lock() 288 s.protocolTrace = append(s.protocolTrace, msg) 289 s.protocolTraceMtx.Unlock() 290 291 var shouldSkip bool 292 293 for _, interceptor := range s.interceptorFuncs { 294 skip, err := interceptor(msg) 295 if err != nil { 296 s.t.Fatalf("%v: error in the "+ 297 "interceptor: %v", s.name, err) 298 return 299 } 300 shouldSkip = shouldSkip || skip 301 } 302 303 if shouldSkip { 304 continue 305 } 306 307 if err := s.readHandler(msg); err != nil { 308 s.t.Fatal(err) 309 return 310 } 311 case 
<-s.quit: 312 return 313 } 314 } 315 }() 316 317 return nil 318 } 319 320 func (s *mockServer) QuitSignal() <-chan struct{} { 321 return s.quit 322 } 323 324 // mockHopIterator represents the test version of hop iterator which instead 325 // of encrypting the path in onion blob just stores the path as a list of hops. 326 type mockHopIterator struct { 327 hops []*hop.Payload 328 } 329 330 func newMockHopIterator(hops ...*hop.Payload) hop.Iterator { 331 return &mockHopIterator{hops: hops} 332 } 333 334 func (r *mockHopIterator) HopPayload() (*hop.Payload, hop.RouteRole, error) { 335 h := r.hops[0] 336 r.hops = r.hops[1:] 337 return h, hop.RouteRoleCleartext, nil 338 } 339 340 func (r *mockHopIterator) ExtraOnionBlob() []byte { 341 return nil 342 } 343 344 func (r *mockHopIterator) ExtractErrorEncrypter( 345 extracter hop.ErrorEncrypterExtracter, _ bool) (hop.ErrorEncrypter, 346 lnwire.FailCode) { 347 348 return extracter(nil) 349 } 350 351 func (r *mockHopIterator) EncodeNextHop(w io.Writer) error { 352 var hopLength [4]byte 353 binary.BigEndian.PutUint32(hopLength[:], uint32(len(r.hops))) 354 355 if _, err := w.Write(hopLength[:]); err != nil { 356 return err 357 } 358 359 for _, hop := range r.hops { 360 fwdInfo := hop.ForwardingInfo() 361 if err := encodeFwdInfo(w, &fwdInfo); err != nil { 362 return err 363 } 364 } 365 366 return nil 367 } 368 369 func encodeFwdInfo(w io.Writer, f *hop.ForwardingInfo) error { 370 if err := binary.Write(w, binary.BigEndian, f.NextHop); err != nil { 371 return err 372 } 373 374 if err := binary.Write(w, binary.BigEndian, f.AmountToForward); err != nil { 375 return err 376 } 377 378 if err := binary.Write(w, binary.BigEndian, f.OutgoingCTLV); err != nil { 379 return err 380 } 381 382 return nil 383 } 384 385 var _ hop.Iterator = (*mockHopIterator)(nil) 386 387 // mockObfuscator mock implementation of the failure obfuscator which only 388 // encodes the failure and do not makes any onion obfuscation. 
389 type mockObfuscator struct { 390 ogPacket *sphinx.OnionPacket 391 failure lnwire.FailureMessage 392 } 393 394 // NewMockObfuscator initializes a dummy mockObfuscator used for testing. 395 func NewMockObfuscator() hop.ErrorEncrypter { 396 return &mockObfuscator{} 397 } 398 399 func (o *mockObfuscator) OnionPacket() *sphinx.OnionPacket { 400 return o.ogPacket 401 } 402 403 func (o *mockObfuscator) Type() hop.EncrypterType { 404 return hop.EncrypterTypeMock 405 } 406 407 func (o *mockObfuscator) Encode(w io.Writer) error { 408 return nil 409 } 410 411 func (o *mockObfuscator) Decode(r io.Reader) error { 412 return nil 413 } 414 415 func (o *mockObfuscator) Reextract( 416 extracter hop.ErrorEncrypterExtracter) error { 417 418 return nil 419 } 420 421 var fakeHmac = []byte("hmachmachmachmachmachmachmachmac") 422 423 func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( 424 lnwire.OpaqueReason, error) { 425 426 o.failure = failure 427 428 var b bytes.Buffer 429 b.Write(fakeHmac) 430 431 if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { 432 return nil, err 433 } 434 return b.Bytes(), nil 435 } 436 437 func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason { 438 return reason 439 } 440 441 func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason { 442 var b bytes.Buffer 443 b.Write(fakeHmac) 444 445 b.Write(reason) 446 447 return b.Bytes() 448 } 449 450 // mockDeobfuscator mock implementation of the failure deobfuscator which 451 // only decodes the failure do not makes any onion obfuscation. 
452 type mockDeobfuscator struct{} 453 454 func newMockDeobfuscator() ErrorDecrypter { 455 return &mockDeobfuscator{} 456 } 457 458 func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) ( 459 *ForwardingError, error) { 460 461 if !bytes.Equal(reason[:32], fakeHmac) { 462 return nil, errors.New("fake decryption error") 463 } 464 reason = reason[32:] 465 466 r := bytes.NewReader(reason) 467 failure, err := lnwire.DecodeFailure(r, 0) 468 if err != nil { 469 return nil, err 470 } 471 472 return NewForwardingError(failure, 1), nil 473 } 474 475 var _ ErrorDecrypter = (*mockDeobfuscator)(nil) 476 477 // mockIteratorDecoder test version of hop iterator decoder which decodes the 478 // encoded array of hops. 479 type mockIteratorDecoder struct { 480 mu sync.RWMutex 481 482 responses map[[32]byte][]hop.DecodeHopIteratorResponse 483 484 decodeFail bool 485 } 486 487 func newMockIteratorDecoder() *mockIteratorDecoder { 488 return &mockIteratorDecoder{ 489 responses: make(map[[32]byte][]hop.DecodeHopIteratorResponse), 490 } 491 } 492 493 func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte, 494 cltv uint32) (hop.Iterator, lnwire.FailCode) { 495 496 var b [4]byte 497 _, err := r.Read(b[:]) 498 if err != nil { 499 return nil, lnwire.CodeTemporaryChannelFailure 500 } 501 hopLength := binary.BigEndian.Uint32(b[:]) 502 503 hops := make([]*hop.Payload, hopLength) 504 for i := uint32(0); i < hopLength; i++ { 505 var f hop.ForwardingInfo 506 if err := decodeFwdInfo(r, &f); err != nil { 507 return nil, lnwire.CodeTemporaryChannelFailure 508 } 509 510 var nextHopBytes [8]byte 511 binary.BigEndian.PutUint64(nextHopBytes[:], f.NextHop.ToUint64()) 512 513 hops[i] = hop.NewLegacyPayload(&sphinx.HopData{ 514 Realm: [1]byte{}, // hop.BitcoinNetwork 515 NextAddress: nextHopBytes, 516 ForwardAmount: uint64(f.AmountToForward), 517 OutgoingCltv: f.OutgoingCTLV, 518 }) 519 } 520 521 return newMockHopIterator(hops...), lnwire.CodeNone 522 } 523 524 func (p 
*mockIteratorDecoder) DecodeHopIterators(id []byte, 525 reqs []hop.DecodeHopIteratorRequest, _ bool) ( 526 []hop.DecodeHopIteratorResponse, error) { 527 528 idHash := sha256.Sum256(id) 529 530 p.mu.RLock() 531 if resps, ok := p.responses[idHash]; ok { 532 p.mu.RUnlock() 533 return resps, nil 534 } 535 p.mu.RUnlock() 536 537 batchSize := len(reqs) 538 539 resps := make([]hop.DecodeHopIteratorResponse, 0, batchSize) 540 for _, req := range reqs { 541 iterator, failcode := p.DecodeHopIterator( 542 req.OnionReader, req.RHash, req.IncomingCltv, 543 ) 544 545 if p.decodeFail { 546 failcode = lnwire.CodeTemporaryChannelFailure 547 } 548 549 resp := hop.DecodeHopIteratorResponse{ 550 HopIterator: iterator, 551 FailCode: failcode, 552 } 553 resps = append(resps, resp) 554 } 555 556 p.mu.Lock() 557 p.responses[idHash] = resps 558 p.mu.Unlock() 559 560 return resps, nil 561 } 562 563 func decodeFwdInfo(r io.Reader, f *hop.ForwardingInfo) error { 564 if err := binary.Read(r, binary.BigEndian, &f.NextHop); err != nil { 565 return err 566 } 567 568 if err := binary.Read(r, binary.BigEndian, &f.AmountToForward); err != nil { 569 return err 570 } 571 572 if err := binary.Read(r, binary.BigEndian, &f.OutgoingCTLV); err != nil { 573 return err 574 } 575 576 return nil 577 } 578 579 // messageInterceptor is function that handles the incoming peer messages and 580 // may decide should the peer skip the message or not. 581 type messageInterceptor func(m lnwire.Message) (bool, error) 582 583 // Record is used to set the function which will be triggered when new 584 // lnwire message was received. 
585 func (s *mockServer) intersect(f messageInterceptor) { 586 s.interceptorFuncs = append(s.interceptorFuncs, f) 587 } 588 589 func (s *mockServer) SendMessage(sync bool, msgs ...lnwire.Message) error { 590 591 for _, msg := range msgs { 592 select { 593 case s.messages <- msg: 594 case <-s.quit: 595 return errors.New("server is stopped") 596 } 597 } 598 599 return nil 600 } 601 602 func (s *mockServer) SendMessageLazy(sync bool, msgs ...lnwire.Message) error { 603 panic("not implemented") 604 } 605 606 func (s *mockServer) readHandler(message lnwire.Message) error { 607 var targetChan lnwire.ChannelID 608 609 switch msg := message.(type) { 610 case *lnwire.UpdateAddHTLC: 611 targetChan = msg.ChanID 612 case *lnwire.UpdateFulfillHTLC: 613 targetChan = msg.ChanID 614 case *lnwire.UpdateFailHTLC: 615 targetChan = msg.ChanID 616 case *lnwire.UpdateFailMalformedHTLC: 617 targetChan = msg.ChanID 618 case *lnwire.RevokeAndAck: 619 targetChan = msg.ChanID 620 case *lnwire.CommitSig: 621 targetChan = msg.ChanID 622 case *lnwire.ChannelReady: 623 // Ignore 624 return nil 625 case *lnwire.ChannelReestablish: 626 targetChan = msg.ChanID 627 case *lnwire.UpdateFee: 628 targetChan = msg.ChanID 629 case *lnwire.Stfu: 630 targetChan = msg.ChanID 631 default: 632 return fmt.Errorf("unknown message type: %T", msg) 633 } 634 635 // Dispatch the commitment update message to the proper channel link 636 // dedicated to this channel. If the link is not found, we will discard 637 // the message. 
638 link, err := s.htlcSwitch.GetLink(targetChan) 639 if err != nil { 640 return nil 641 } 642 643 // Create goroutine for this, in order to be able to properly stop 644 // the server when handler stacked (server unavailable) 645 link.HandleChannelUpdate(message) 646 647 return nil 648 } 649 650 func (s *mockServer) PubKey() [33]byte { 651 return s.id 652 } 653 654 func (s *mockServer) IdentityKey() *btcec.PublicKey { 655 pubkey, _ := btcec.ParsePubKey(s.id[:]) 656 return pubkey 657 } 658 659 func (s *mockServer) Address() net.Addr { 660 return nil 661 } 662 663 func (s *mockServer) AddNewChannel(channel *lnpeer.NewChannel, 664 cancel <-chan struct{}) error { 665 666 return nil 667 } 668 669 func (s *mockServer) AddPendingChannel(_ lnwire.ChannelID, 670 cancel <-chan struct{}) error { 671 672 return nil 673 } 674 675 func (s *mockServer) RemovePendingChannel(_ lnwire.ChannelID) error { 676 return nil 677 } 678 679 func (s *mockServer) WipeChannel(*wire.OutPoint) {} 680 681 func (s *mockServer) LocalFeatures() *lnwire.FeatureVector { 682 return nil 683 } 684 685 func (s *mockServer) RemoteFeatures() *lnwire.FeatureVector { 686 return nil 687 } 688 689 func (s *mockServer) Disconnect(err error) {} 690 691 func (s *mockServer) Stop() error { 692 if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { 693 return nil 694 } 695 696 close(s.quit) 697 s.wg.Wait() 698 699 return nil 700 } 701 702 func (s *mockServer) String() string { 703 return s.name 704 } 705 706 type mockChannelLink struct { 707 htlcSwitch *Switch 708 709 shortChanID lnwire.ShortChannelID 710 711 // Only used for zero-conf channels. 
712 realScid lnwire.ShortChannelID 713 714 aliases []lnwire.ShortChannelID 715 716 chanID lnwire.ChannelID 717 718 peer lnpeer.Peer 719 720 mailBox MailBox 721 722 packets chan *htlcPacket 723 724 eligible bool 725 726 unadvertised bool 727 728 zeroConf bool 729 730 optionFeature bool 731 732 htlcID uint64 733 734 checkHtlcTransitResult *LinkError 735 736 checkHtlcForwardResult *LinkError 737 738 failAliasUpdate func(sid lnwire.ShortChannelID, 739 incoming bool) *lnwire.ChannelUpdate1 740 741 confirmedZC bool 742 } 743 744 // completeCircuit is a helper method for adding the finalized payment circuit 745 // to the switch's circuit map. In testing, this should be executed after 746 // receiving an htlc from the downstream packets channel. 747 func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error { 748 switch htlc := pkt.htlc.(type) { 749 case *lnwire.UpdateAddHTLC: 750 pkt.outgoingChanID = f.shortChanID 751 pkt.outgoingHTLCID = f.htlcID 752 htlc.ID = f.htlcID 753 754 keystone := Keystone{pkt.inKey(), pkt.outKey()} 755 err := f.htlcSwitch.circuits.OpenCircuits(keystone) 756 if err != nil { 757 return err 758 } 759 760 f.htlcID++ 761 762 case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC: 763 if pkt.circuit != nil { 764 err := f.htlcSwitch.teardownCircuit(pkt) 765 if err != nil { 766 return err 767 } 768 } 769 } 770 771 f.mailBox.AckPacket(pkt.inKey()) 772 773 return nil 774 } 775 776 func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error { 777 return f.htlcSwitch.circuits.DeleteCircuits(pkt.inKey()) 778 } 779 780 func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID, 781 shortChanID, realScid lnwire.ShortChannelID, peer lnpeer.Peer, 782 eligible, unadvertised, zeroConf, optionFeature bool, 783 ) *mockChannelLink { 784 785 aliases := make([]lnwire.ShortChannelID, 0) 786 var realConfirmed bool 787 788 if zeroConf { 789 aliases = append(aliases, shortChanID) 790 } 791 792 if realScid != hop.Source { 793 realConfirmed = true 794 } 795 
796 return &mockChannelLink{ 797 htlcSwitch: htlcSwitch, 798 chanID: chanID, 799 shortChanID: shortChanID, 800 realScid: realScid, 801 peer: peer, 802 eligible: eligible, 803 unadvertised: unadvertised, 804 zeroConf: zeroConf, 805 optionFeature: optionFeature, 806 aliases: aliases, 807 confirmedZC: realConfirmed, 808 } 809 } 810 811 // addAlias is not part of any interface method. 812 func (f *mockChannelLink) addAlias(alias lnwire.ShortChannelID) { 813 f.aliases = append(f.aliases, alias) 814 } 815 816 func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error { 817 f.mailBox.AddPacket(pkt) 818 return nil 819 } 820 821 func (f *mockChannelLink) getDustSum(whoseCommit lntypes.ChannelParty, 822 dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { 823 824 return 0 825 } 826 827 func (f *mockChannelLink) getFeeRate() chainfee.SatPerKWeight { 828 return 0 829 } 830 831 func (f *mockChannelLink) getDustClosure() dustClosure { 832 dustLimit := btcutil.Amount(400) 833 return dustHelper( 834 channeldb.SingleFunderTweaklessBit, dustLimit, dustLimit, 835 ) 836 } 837 838 func (f *mockChannelLink) getCommitFee(remote bool) btcutil.Amount { 839 return 0 840 } 841 842 func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) { 843 } 844 845 func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) { 846 } 847 func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi, 848 lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32, 849 lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError { 850 851 return f.checkHtlcForwardResult 852 } 853 854 func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte, 855 amt lnwire.MilliSatoshi, timeout uint32, 856 heightNow uint32, _ lnwire.CustomRecords) *LinkError { 857 858 return f.checkHtlcTransitResult 859 } 860 861 func (f *mockChannelLink) Stats() ( 862 uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) { 863 864 return 0, 0, 0 865 } 866 867 func (f *mockChannelLink) 
AttachMailBox(mailBox MailBox) { 868 f.mailBox = mailBox 869 f.packets = mailBox.PacketOutBox() 870 mailBox.SetDustClosure(f.getDustClosure()) 871 } 872 873 func (f *mockChannelLink) attachFailAliasUpdate(closure func( 874 sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) { 875 876 f.failAliasUpdate = closure 877 } 878 879 func (f *mockChannelLink) getAliases() []lnwire.ShortChannelID { 880 return f.aliases 881 } 882 883 func (f *mockChannelLink) isZeroConf() bool { 884 return f.zeroConf 885 } 886 887 func (f *mockChannelLink) negotiatedAliasFeature() bool { 888 return f.optionFeature 889 } 890 891 func (f *mockChannelLink) confirmedScid() lnwire.ShortChannelID { 892 return f.realScid 893 } 894 895 func (f *mockChannelLink) zeroConfConfirmed() bool { 896 return f.confirmedZC 897 } 898 899 func (f *mockChannelLink) Start() error { 900 f.mailBox.ResetMessages() 901 f.mailBox.ResetPackets() 902 return nil 903 } 904 905 func (f *mockChannelLink) ChanID() lnwire.ChannelID { 906 return f.chanID 907 } 908 909 func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID { 910 return f.shortChanID 911 } 912 913 func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi { 914 return 99999999 915 } 916 917 func (f *mockChannelLink) PeerPubKey() [33]byte { 918 return f.peer.PubKey() 919 } 920 921 func (f *mockChannelLink) ChannelPoint() wire.OutPoint { 922 return wire.OutPoint{} 923 } 924 925 func (f *mockChannelLink) Stop() {} 926 func (f *mockChannelLink) EligibleToForward() bool { return f.eligible } 927 func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil } 928 func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid } 929 func (f *mockChannelLink) IsUnadvertised() bool { return f.unadvertised } 930 func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { 931 f.eligible = true 932 return f.shortChanID, nil 933 } 934 935 func (f *mockChannelLink) 
EnableAdds(linkDirection LinkDirection) bool { 936 // TODO(proofofkeags): Implement 937 return true 938 } 939 940 func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) bool { 941 // TODO(proofofkeags): Implement 942 return true 943 } 944 func (f *mockChannelLink) IsFlushing(linkDirection LinkDirection) bool { 945 // TODO(proofofkeags): Implement 946 return false 947 } 948 func (f *mockChannelLink) OnFlushedOnce(func()) { 949 // TODO(proofofkeags): Implement 950 } 951 func (f *mockChannelLink) OnCommitOnce(LinkDirection, func()) { 952 // TODO(proofofkeags): Implement 953 } 954 func (f *mockChannelLink) InitStfu() <-chan fn.Result[lntypes.ChannelParty] { 955 // TODO(proofofkeags): Implement 956 c := make(chan fn.Result[lntypes.ChannelParty], 1) 957 958 c <- fn.Errf[lntypes.ChannelParty]("InitStfu not implemented") 959 960 return c 961 } 962 963 func (f *mockChannelLink) FundingCustomBlob() fn.Option[tlv.Blob] { 964 return fn.None[tlv.Blob]() 965 } 966 967 func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { 968 return fn.None[tlv.Blob]() 969 } 970 971 // AuxBandwidth returns the bandwidth that can be used for a channel, 972 // expressed in milli-satoshi. This might be different from the regular 973 // BTC bandwidth for custom channels. This will always return fn.None() 974 // for a regular (non-custom) channel. 975 func (f *mockChannelLink) AuxBandwidth(lnwire.MilliSatoshi, 976 lnwire.ShortChannelID, 977 fn.Option[tlv.Blob], AuxTrafficShaper) fn.Result[OptionalBandwidth] { 978 979 return fn.Ok(OptionalBandwidth{}) 980 } 981 982 var _ ChannelLink = (*mockChannelLink)(nil) 983 984 const testInvoiceCltvExpiry = 6 985 986 type mockInvoiceRegistry struct { 987 settleChan chan lntypes.Hash 988 989 registry *invoices.InvoiceRegistry 990 } 991 992 type mockChainNotifier struct { 993 chainntnfs.ChainNotifier 994 } 995 996 // RegisterBlockEpochNtfn mocks a successful call to register block 997 // notifications. 
998 func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( 999 *chainntnfs.BlockEpochEvent, error) { 1000 1001 return &chainntnfs.BlockEpochEvent{ 1002 Cancel: func() {}, 1003 }, nil 1004 } 1005 1006 func newMockRegistry(t testing.TB) *mockInvoiceRegistry { 1007 cdb := channeldb.OpenForTesting(t, t.TempDir()) 1008 1009 modifierMock := &invoices.MockHtlcModifier{} 1010 registry := invoices.NewRegistry( 1011 cdb, 1012 invoices.NewInvoiceExpiryWatcher( 1013 clock.NewDefaultClock(), 0, 0, nil, 1014 &mockChainNotifier{}, 1015 ), 1016 &invoices.RegistryConfig{ 1017 FinalCltvRejectDelta: 5, 1018 HtlcInterceptor: modifierMock, 1019 }, 1020 ) 1021 registry.Start() 1022 1023 return &mockInvoiceRegistry{ 1024 registry: registry, 1025 } 1026 } 1027 1028 func (i *mockInvoiceRegistry) LookupInvoice(ctx context.Context, 1029 rHash lntypes.Hash) (invoices.Invoice, error) { 1030 1031 return i.registry.LookupInvoice(ctx, rHash) 1032 } 1033 1034 func (i *mockInvoiceRegistry) SettleHodlInvoice( 1035 ctx context.Context, preimage lntypes.Preimage) error { 1036 1037 return i.registry.SettleHodlInvoice(ctx, preimage) 1038 } 1039 1040 func (i *mockInvoiceRegistry) NotifyExitHopHtlc(rhash lntypes.Hash, 1041 amt lnwire.MilliSatoshi, expiry uint32, currentHeight int32, 1042 circuitKey models.CircuitKey, hodlChan chan<- interface{}, 1043 wireCustomRecords lnwire.CustomRecords, 1044 payload invoices.Payload) (invoices.HtlcResolution, error) { 1045 1046 event, err := i.registry.NotifyExitHopHtlc( 1047 rhash, amt, expiry, currentHeight, circuitKey, 1048 hodlChan, wireCustomRecords, payload, 1049 ) 1050 if err != nil { 1051 return nil, err 1052 } 1053 if i.settleChan != nil { 1054 i.settleChan <- rhash 1055 } 1056 1057 return event, nil 1058 } 1059 1060 func (i *mockInvoiceRegistry) CancelInvoice(ctx context.Context, 1061 payHash lntypes.Hash) error { 1062 1063 return i.registry.CancelInvoice(ctx, payHash) 1064 } 1065 1066 func (i *mockInvoiceRegistry) AddInvoice(ctx 
context.Context, 1067 invoice invoices.Invoice, paymentHash lntypes.Hash) error { 1068 1069 _, err := i.registry.AddInvoice(ctx, &invoice, paymentHash) 1070 return err 1071 } 1072 1073 func (i *mockInvoiceRegistry) HodlUnsubscribeAll( 1074 subscriber chan<- interface{}) { 1075 1076 i.registry.HodlUnsubscribeAll(subscriber) 1077 } 1078 1079 var _ InvoiceDatabase = (*mockInvoiceRegistry)(nil) 1080 1081 type mockCircuitMap struct { 1082 lookup chan *PaymentCircuit 1083 } 1084 1085 var _ CircuitMap = (*mockCircuitMap)(nil) 1086 1087 func (m *mockCircuitMap) OpenCircuits(...Keystone) error { 1088 return nil 1089 } 1090 1091 func (m *mockCircuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID, 1092 start uint64) error { 1093 return nil 1094 } 1095 1096 func (m *mockCircuitMap) DeleteCircuits(inKeys ...CircuitKey) error { 1097 return nil 1098 } 1099 1100 func (m *mockCircuitMap) CommitCircuits( 1101 circuit ...*PaymentCircuit) (*CircuitFwdActions, error) { 1102 1103 return nil, nil 1104 } 1105 1106 func (m *mockCircuitMap) CloseCircuit(outKey CircuitKey) (*PaymentCircuit, 1107 error) { 1108 return nil, nil 1109 } 1110 1111 func (m *mockCircuitMap) FailCircuit(inKey CircuitKey) (*PaymentCircuit, 1112 error) { 1113 return nil, nil 1114 } 1115 1116 func (m *mockCircuitMap) LookupCircuit(inKey CircuitKey) *PaymentCircuit { 1117 return <-m.lookup 1118 } 1119 1120 func (m *mockCircuitMap) LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit { 1121 return nil 1122 } 1123 1124 func (m *mockCircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit { 1125 return nil 1126 } 1127 1128 func (m *mockCircuitMap) NumPending() int { 1129 return 0 1130 } 1131 1132 func (m *mockCircuitMap) NumOpen() int { 1133 return 0 1134 } 1135 1136 type mockOnionErrorDecryptor struct { 1137 sourceIdx int 1138 message []byte 1139 err error 1140 } 1141 1142 func (m *mockOnionErrorDecryptor) DecryptError(encryptedData []byte) ( 1143 *sphinx.DecryptedError, error) { 1144 1145 return 
&sphinx.DecryptedError{ 1146 SenderIdx: m.sourceIdx, 1147 Message: m.message, 1148 }, m.err 1149 } 1150 1151 var _ htlcNotifier = (*mockHTLCNotifier)(nil) 1152 1153 type mockHTLCNotifier struct { 1154 htlcNotifier //nolint:unused 1155 } 1156 1157 func (h *mockHTLCNotifier) NotifyForwardingEvent(key HtlcKey, info HtlcInfo, 1158 eventType HtlcEventType) { 1159 1160 } 1161 1162 func (h *mockHTLCNotifier) NotifyLinkFailEvent(key HtlcKey, info HtlcInfo, 1163 eventType HtlcEventType, linkErr *LinkError, 1164 incoming bool) { 1165 1166 } 1167 1168 func (h *mockHTLCNotifier) NotifyForwardingFailEvent(key HtlcKey, 1169 eventType HtlcEventType) { 1170 1171 } 1172 1173 func (h *mockHTLCNotifier) NotifySettleEvent(key HtlcKey, 1174 preimage lntypes.Preimage, eventType HtlcEventType) { 1175 1176 } 1177 1178 func (h *mockHTLCNotifier) NotifyFinalHtlcEvent(key models.CircuitKey, 1179 info channeldb.FinalHtlcInfo) { 1180 1181 }