/ htlcswitch / mock.go
mock.go
   1  package htlcswitch
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"crypto/sha256"
   7  	"encoding/binary"
   8  	"errors"
   9  	"fmt"
  10  	"io"
  11  	"net"
  12  	"path/filepath"
  13  	"sync"
  14  	"sync/atomic"
  15  	"testing"
  16  	"time"
  17  
  18  	"github.com/btcsuite/btcd/btcec/v2"
  19  	"github.com/btcsuite/btcd/btcec/v2/ecdsa"
  20  	"github.com/btcsuite/btcd/btcutil"
  21  	"github.com/btcsuite/btcd/wire"
  22  	sphinx "github.com/lightningnetwork/lightning-onion"
  23  	"github.com/lightningnetwork/lnd/chainntnfs"
  24  	"github.com/lightningnetwork/lnd/channeldb"
  25  	"github.com/lightningnetwork/lnd/clock"
  26  	"github.com/lightningnetwork/lnd/contractcourt"
  27  	"github.com/lightningnetwork/lnd/fn/v2"
  28  	"github.com/lightningnetwork/lnd/graph/db/models"
  29  	"github.com/lightningnetwork/lnd/htlcswitch/hop"
  30  	"github.com/lightningnetwork/lnd/invoices"
  31  	"github.com/lightningnetwork/lnd/lnpeer"
  32  	"github.com/lightningnetwork/lnd/lntest/mock"
  33  	"github.com/lightningnetwork/lnd/lntypes"
  34  	"github.com/lightningnetwork/lnd/lnwallet/chainfee"
  35  	"github.com/lightningnetwork/lnd/lnwire"
  36  	"github.com/lightningnetwork/lnd/ticker"
  37  	"github.com/lightningnetwork/lnd/tlv"
  38  )
  39  
  40  func isAlias(scid lnwire.ShortChannelID) bool {
  41  	return scid.BlockHeight >= 16_000_000 && scid.BlockHeight < 16_250_000
  42  }
  43  
  44  type mockPreimageCache struct {
  45  	sync.Mutex
  46  	preimageMap map[lntypes.Hash]lntypes.Preimage
  47  }
  48  
  49  func newMockPreimageCache() *mockPreimageCache {
  50  	return &mockPreimageCache{
  51  		preimageMap: make(map[lntypes.Hash]lntypes.Preimage),
  52  	}
  53  }
  54  
  55  func (m *mockPreimageCache) LookupPreimage(
  56  	hash lntypes.Hash) (lntypes.Preimage, bool) {
  57  
  58  	m.Lock()
  59  	defer m.Unlock()
  60  
  61  	p, ok := m.preimageMap[hash]
  62  	return p, ok
  63  }
  64  
  65  func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error {
  66  	m.Lock()
  67  	defer m.Unlock()
  68  
  69  	for _, preimage := range preimages {
  70  		m.preimageMap[preimage.Hash()] = preimage
  71  	}
  72  
  73  	return nil
  74  }
  75  
  76  func (m *mockPreimageCache) SubscribeUpdates(
  77  	chanID lnwire.ShortChannelID, htlc *channeldb.HTLC,
  78  	payload *hop.Payload,
  79  	nextHopOnionBlob []byte) (*contractcourt.WitnessSubscription, error) {
  80  
  81  	return nil, nil
  82  }
  83  
  84  // TODO(yy): replace it with chainfee.MockEstimator.
  85  type mockFeeEstimator struct {
  86  	byteFeeIn chan chainfee.SatPerKWeight
  87  	relayFee  chan chainfee.SatPerKWeight
  88  
  89  	quit chan struct{}
  90  }
  91  
  92  func newMockFeeEstimator() *mockFeeEstimator {
  93  	return &mockFeeEstimator{
  94  		byteFeeIn: make(chan chainfee.SatPerKWeight),
  95  		relayFee:  make(chan chainfee.SatPerKWeight),
  96  		quit:      make(chan struct{}),
  97  	}
  98  }
  99  
 100  func (m *mockFeeEstimator) EstimateFeePerKW(
 101  	numBlocks uint32) (chainfee.SatPerKWeight, error) {
 102  
 103  	select {
 104  	case feeRate := <-m.byteFeeIn:
 105  		return feeRate, nil
 106  	case <-m.quit:
 107  		return 0, fmt.Errorf("exiting")
 108  	}
 109  }
 110  
 111  func (m *mockFeeEstimator) RelayFeePerKW() chainfee.SatPerKWeight {
 112  	select {
 113  	case feeRate := <-m.relayFee:
 114  		return feeRate
 115  	case <-m.quit:
 116  		return 0
 117  	}
 118  }
 119  
 120  func (m *mockFeeEstimator) Start() error {
 121  	return nil
 122  }
 123  func (m *mockFeeEstimator) Stop() error {
 124  	close(m.quit)
 125  	return nil
 126  }
 127  
 128  var _ chainfee.Estimator = (*mockFeeEstimator)(nil)
 129  
 130  type mockForwardingLog struct {
 131  	sync.Mutex
 132  
 133  	events map[time.Time]channeldb.ForwardingEvent
 134  }
 135  
 136  func (m *mockForwardingLog) AddForwardingEvents(events []channeldb.ForwardingEvent) error {
 137  	m.Lock()
 138  	defer m.Unlock()
 139  
 140  	for _, event := range events {
 141  		m.events[event.Timestamp] = event
 142  	}
 143  
 144  	return nil
 145  }
 146  
// mockServer is a test implementation of lnpeer.Peer. SendMessage queues
// messages on the messages channel. The goroutine launched by Start drains
// that channel: it records each message in protocolTrace, runs the
// registered interceptorFuncs, and dispatches the message to the matching
// link of htlcSwitch via readHandler.
type mockServer struct {
	started  int32 // To be used atomically.
	shutdown int32 // To be used atomically.

	// wg tracks the message-processing goroutine started by Start.
	wg sync.WaitGroup

	// quit is closed by Stop to end the message-processing goroutine.
	quit chan struct{}

	// t is used to report failures from the message-processing goroutine.
	t testing.TB

	// name identifies the server in failure messages and in String().
	name string

	// messages is the inbound queue filled by SendMessage.
	messages chan lnwire.Message

	// protocolTraceMtx guards protocolTrace.
	protocolTraceMtx sync.Mutex

	// protocolTrace holds, in order, every message drained from messages.
	protocolTrace []lnwire.Message

	// id is returned by PubKey. newMockServer derives it from the
	// sha256 of name.
	id [33]byte

	// htlcSwitch is the switch whose links receive dispatched messages.
	htlcSwitch *Switch

	registry *mockInvoiceRegistry
	pCache   *mockPreimageCache

	// interceptorFuncs are registered with intersect and consulted for
	// every incoming message.
	interceptorFuncs []messageInterceptor
}

var _ lnpeer.Peer = (*mockServer)(nil)
 169  
 170  func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) {
 171  	signAliasUpdate := func(u *lnwire.ChannelUpdate1) (*ecdsa.Signature,
 172  		error) {
 173  
 174  		return testSig, nil
 175  	}
 176  
 177  	cfg := Config{
 178  		DB:                   db,
 179  		FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
 180  		FetchAllChannels:     db.ChannelStateDB().FetchAllChannels,
 181  		FetchClosedChannels:  db.ChannelStateDB().FetchClosedChannels,
 182  		SwitchPackager:       channeldb.NewSwitchPackager(),
 183  		FwdingLog: &mockForwardingLog{
 184  			events: make(map[time.Time]channeldb.ForwardingEvent),
 185  		},
 186  		FetchLastChannelUpdate: func(scid lnwire.ShortChannelID) (
 187  			*lnwire.ChannelUpdate1, error) {
 188  
 189  			return &lnwire.ChannelUpdate1{
 190  				ShortChannelID: scid,
 191  			}, nil
 192  		},
 193  		Notifier: &mock.ChainNotifier{
 194  			SpendChan: make(chan *chainntnfs.SpendDetail),
 195  			EpochChan: make(chan *chainntnfs.BlockEpoch),
 196  			ConfChan:  make(chan *chainntnfs.TxConfirmation),
 197  		},
 198  		FwdEventTicker: ticker.NewForce(
 199  			DefaultFwdEventInterval,
 200  		),
 201  		LogEventTicker:         ticker.NewForce(DefaultLogInterval),
 202  		AckEventTicker:         ticker.NewForce(DefaultAckInterval),
 203  		HtlcNotifier:           &mockHTLCNotifier{},
 204  		Clock:                  clock.NewDefaultClock(),
 205  		MailboxDeliveryTimeout: time.Hour,
 206  		MaxFeeExposure:         DefaultMaxFeeExposure,
 207  		SignAliasUpdate:        signAliasUpdate,
 208  		IsAlias:                isAlias,
 209  	}
 210  
 211  	return New(cfg, startingHeight)
 212  }
 213  
 214  func initSwitchWithTempDB(t testing.TB, startingHeight uint32) (*Switch,
 215  	error) {
 216  
 217  	tempPath := filepath.Join(t.TempDir(), "switchdb")
 218  	db := channeldb.OpenForTesting(t, tempPath)
 219  
 220  	s, err := initSwitchWithDB(startingHeight, db)
 221  	if err != nil {
 222  		return nil, err
 223  	}
 224  
 225  	return s, nil
 226  }
 227  
 228  func newMockServer(t testing.TB, name string, startingHeight uint32,
 229  	db *channeldb.DB, defaultDelta uint32) (*mockServer, error) {
 230  
 231  	var id [33]byte
 232  	h := sha256.Sum256([]byte(name))
 233  	copy(id[:], h[:])
 234  
 235  	pCache := newMockPreimageCache()
 236  
 237  	var (
 238  		htlcSwitch *Switch
 239  		err        error
 240  	)
 241  	if db == nil {
 242  		htlcSwitch, err = initSwitchWithTempDB(t, startingHeight)
 243  	} else {
 244  		htlcSwitch, err = initSwitchWithDB(startingHeight, db)
 245  	}
 246  	if err != nil {
 247  		return nil, err
 248  	}
 249  
 250  	t.Cleanup(func() { _ = htlcSwitch.Stop() })
 251  
 252  	registry := newMockRegistry(t)
 253  
 254  	return &mockServer{
 255  		t:                t,
 256  		id:               id,
 257  		name:             name,
 258  		messages:         make(chan lnwire.Message, 3000),
 259  		quit:             make(chan struct{}),
 260  		registry:         registry,
 261  		htlcSwitch:       htlcSwitch,
 262  		pCache:           pCache,
 263  		interceptorFuncs: make([]messageInterceptor, 0),
 264  	}, nil
 265  }
 266  
 267  func (s *mockServer) Start() error {
 268  	if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
 269  		return errors.New("mock server already started")
 270  	}
 271  
 272  	if err := s.htlcSwitch.Start(); err != nil {
 273  		return err
 274  	}
 275  
 276  	s.wg.Add(1)
 277  	go func() {
 278  		defer s.wg.Done()
 279  
 280  		defer func() {
 281  			s.htlcSwitch.Stop()
 282  		}()
 283  
 284  		for {
 285  			select {
 286  			case msg := <-s.messages:
 287  				s.protocolTraceMtx.Lock()
 288  				s.protocolTrace = append(s.protocolTrace, msg)
 289  				s.protocolTraceMtx.Unlock()
 290  
 291  				var shouldSkip bool
 292  
 293  				for _, interceptor := range s.interceptorFuncs {
 294  					skip, err := interceptor(msg)
 295  					if err != nil {
 296  						s.t.Fatalf("%v: error in the "+
 297  							"interceptor: %v", s.name, err)
 298  						return
 299  					}
 300  					shouldSkip = shouldSkip || skip
 301  				}
 302  
 303  				if shouldSkip {
 304  					continue
 305  				}
 306  
 307  				if err := s.readHandler(msg); err != nil {
 308  					s.t.Fatal(err)
 309  					return
 310  				}
 311  			case <-s.quit:
 312  				return
 313  			}
 314  		}
 315  	}()
 316  
 317  	return nil
 318  }
 319  
 320  func (s *mockServer) QuitSignal() <-chan struct{} {
 321  	return s.quit
 322  }
 323  
 324  // mockHopIterator represents the test version of hop iterator which instead
 325  // of encrypting the path in onion blob just stores the path as a list of hops.
 326  type mockHopIterator struct {
 327  	hops []*hop.Payload
 328  }
 329  
 330  func newMockHopIterator(hops ...*hop.Payload) hop.Iterator {
 331  	return &mockHopIterator{hops: hops}
 332  }
 333  
 334  func (r *mockHopIterator) HopPayload() (*hop.Payload, hop.RouteRole, error) {
 335  	h := r.hops[0]
 336  	r.hops = r.hops[1:]
 337  	return h, hop.RouteRoleCleartext, nil
 338  }
 339  
 340  func (r *mockHopIterator) ExtraOnionBlob() []byte {
 341  	return nil
 342  }
 343  
 344  func (r *mockHopIterator) ExtractErrorEncrypter(
 345  	extracter hop.ErrorEncrypterExtracter, _ bool) (hop.ErrorEncrypter,
 346  	lnwire.FailCode) {
 347  
 348  	return extracter(nil)
 349  }
 350  
 351  func (r *mockHopIterator) EncodeNextHop(w io.Writer) error {
 352  	var hopLength [4]byte
 353  	binary.BigEndian.PutUint32(hopLength[:], uint32(len(r.hops)))
 354  
 355  	if _, err := w.Write(hopLength[:]); err != nil {
 356  		return err
 357  	}
 358  
 359  	for _, hop := range r.hops {
 360  		fwdInfo := hop.ForwardingInfo()
 361  		if err := encodeFwdInfo(w, &fwdInfo); err != nil {
 362  			return err
 363  		}
 364  	}
 365  
 366  	return nil
 367  }
 368  
 369  func encodeFwdInfo(w io.Writer, f *hop.ForwardingInfo) error {
 370  	if err := binary.Write(w, binary.BigEndian, f.NextHop); err != nil {
 371  		return err
 372  	}
 373  
 374  	if err := binary.Write(w, binary.BigEndian, f.AmountToForward); err != nil {
 375  		return err
 376  	}
 377  
 378  	if err := binary.Write(w, binary.BigEndian, f.OutgoingCTLV); err != nil {
 379  		return err
 380  	}
 381  
 382  	return nil
 383  }
 384  
 385  var _ hop.Iterator = (*mockHopIterator)(nil)
 386  
 387  // mockObfuscator mock implementation of the failure obfuscator which only
 388  // encodes the failure and do not makes any onion obfuscation.
 389  type mockObfuscator struct {
 390  	ogPacket *sphinx.OnionPacket
 391  	failure  lnwire.FailureMessage
 392  }
 393  
 394  // NewMockObfuscator initializes a dummy mockObfuscator used for testing.
 395  func NewMockObfuscator() hop.ErrorEncrypter {
 396  	return &mockObfuscator{}
 397  }
 398  
 399  func (o *mockObfuscator) OnionPacket() *sphinx.OnionPacket {
 400  	return o.ogPacket
 401  }
 402  
 403  func (o *mockObfuscator) Type() hop.EncrypterType {
 404  	return hop.EncrypterTypeMock
 405  }
 406  
 407  func (o *mockObfuscator) Encode(w io.Writer) error {
 408  	return nil
 409  }
 410  
 411  func (o *mockObfuscator) Decode(r io.Reader) error {
 412  	return nil
 413  }
 414  
 415  func (o *mockObfuscator) Reextract(
 416  	extracter hop.ErrorEncrypterExtracter) error {
 417  
 418  	return nil
 419  }
 420  
 421  var fakeHmac = []byte("hmachmachmachmachmachmachmachmac")
 422  
 423  func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) (
 424  	lnwire.OpaqueReason, error) {
 425  
 426  	o.failure = failure
 427  
 428  	var b bytes.Buffer
 429  	b.Write(fakeHmac)
 430  
 431  	if err := lnwire.EncodeFailure(&b, failure, 0); err != nil {
 432  		return nil, err
 433  	}
 434  	return b.Bytes(), nil
 435  }
 436  
 437  func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason {
 438  	return reason
 439  }
 440  
 441  func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason {
 442  	var b bytes.Buffer
 443  	b.Write(fakeHmac)
 444  
 445  	b.Write(reason)
 446  
 447  	return b.Bytes()
 448  }
 449  
 450  // mockDeobfuscator mock implementation of the failure deobfuscator which
 451  // only decodes the failure do not makes any onion obfuscation.
 452  type mockDeobfuscator struct{}
 453  
 454  func newMockDeobfuscator() ErrorDecrypter {
 455  	return &mockDeobfuscator{}
 456  }
 457  
 458  func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) (
 459  	*ForwardingError, error) {
 460  
 461  	if !bytes.Equal(reason[:32], fakeHmac) {
 462  		return nil, errors.New("fake decryption error")
 463  	}
 464  	reason = reason[32:]
 465  
 466  	r := bytes.NewReader(reason)
 467  	failure, err := lnwire.DecodeFailure(r, 0)
 468  	if err != nil {
 469  		return nil, err
 470  	}
 471  
 472  	return NewForwardingError(failure, 1), nil
 473  }
 474  
 475  var _ ErrorDecrypter = (*mockDeobfuscator)(nil)
 476  
 477  // mockIteratorDecoder test version of hop iterator decoder which decodes the
 478  // encoded array of hops.
 479  type mockIteratorDecoder struct {
 480  	mu sync.RWMutex
 481  
 482  	responses map[[32]byte][]hop.DecodeHopIteratorResponse
 483  
 484  	decodeFail bool
 485  }
 486  
 487  func newMockIteratorDecoder() *mockIteratorDecoder {
 488  	return &mockIteratorDecoder{
 489  		responses: make(map[[32]byte][]hop.DecodeHopIteratorResponse),
 490  	}
 491  }
 492  
 493  func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte,
 494  	cltv uint32) (hop.Iterator, lnwire.FailCode) {
 495  
 496  	var b [4]byte
 497  	_, err := r.Read(b[:])
 498  	if err != nil {
 499  		return nil, lnwire.CodeTemporaryChannelFailure
 500  	}
 501  	hopLength := binary.BigEndian.Uint32(b[:])
 502  
 503  	hops := make([]*hop.Payload, hopLength)
 504  	for i := uint32(0); i < hopLength; i++ {
 505  		var f hop.ForwardingInfo
 506  		if err := decodeFwdInfo(r, &f); err != nil {
 507  			return nil, lnwire.CodeTemporaryChannelFailure
 508  		}
 509  
 510  		var nextHopBytes [8]byte
 511  		binary.BigEndian.PutUint64(nextHopBytes[:], f.NextHop.ToUint64())
 512  
 513  		hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
 514  			Realm:         [1]byte{}, // hop.BitcoinNetwork
 515  			NextAddress:   nextHopBytes,
 516  			ForwardAmount: uint64(f.AmountToForward),
 517  			OutgoingCltv:  f.OutgoingCTLV,
 518  		})
 519  	}
 520  
 521  	return newMockHopIterator(hops...), lnwire.CodeNone
 522  }
 523  
 524  func (p *mockIteratorDecoder) DecodeHopIterators(id []byte,
 525  	reqs []hop.DecodeHopIteratorRequest, _ bool) (
 526  	[]hop.DecodeHopIteratorResponse, error) {
 527  
 528  	idHash := sha256.Sum256(id)
 529  
 530  	p.mu.RLock()
 531  	if resps, ok := p.responses[idHash]; ok {
 532  		p.mu.RUnlock()
 533  		return resps, nil
 534  	}
 535  	p.mu.RUnlock()
 536  
 537  	batchSize := len(reqs)
 538  
 539  	resps := make([]hop.DecodeHopIteratorResponse, 0, batchSize)
 540  	for _, req := range reqs {
 541  		iterator, failcode := p.DecodeHopIterator(
 542  			req.OnionReader, req.RHash, req.IncomingCltv,
 543  		)
 544  
 545  		if p.decodeFail {
 546  			failcode = lnwire.CodeTemporaryChannelFailure
 547  		}
 548  
 549  		resp := hop.DecodeHopIteratorResponse{
 550  			HopIterator: iterator,
 551  			FailCode:    failcode,
 552  		}
 553  		resps = append(resps, resp)
 554  	}
 555  
 556  	p.mu.Lock()
 557  	p.responses[idHash] = resps
 558  	p.mu.Unlock()
 559  
 560  	return resps, nil
 561  }
 562  
 563  func decodeFwdInfo(r io.Reader, f *hop.ForwardingInfo) error {
 564  	if err := binary.Read(r, binary.BigEndian, &f.NextHop); err != nil {
 565  		return err
 566  	}
 567  
 568  	if err := binary.Read(r, binary.BigEndian, &f.AmountToForward); err != nil {
 569  		return err
 570  	}
 571  
 572  	if err := binary.Read(r, binary.BigEndian, &f.OutgoingCTLV); err != nil {
 573  		return err
 574  	}
 575  
 576  	return nil
 577  }
 578  
 579  // messageInterceptor is function that handles the incoming peer messages and
 580  // may decide should the peer skip the message or not.
 581  type messageInterceptor func(m lnwire.Message) (bool, error)
 582  
 583  // Record is used to set the function which will be triggered when new
 584  // lnwire message was received.
 585  func (s *mockServer) intersect(f messageInterceptor) {
 586  	s.interceptorFuncs = append(s.interceptorFuncs, f)
 587  }
 588  
 589  func (s *mockServer) SendMessage(sync bool, msgs ...lnwire.Message) error {
 590  
 591  	for _, msg := range msgs {
 592  		select {
 593  		case s.messages <- msg:
 594  		case <-s.quit:
 595  			return errors.New("server is stopped")
 596  		}
 597  	}
 598  
 599  	return nil
 600  }
 601  
 602  func (s *mockServer) SendMessageLazy(sync bool, msgs ...lnwire.Message) error {
 603  	panic("not implemented")
 604  }
 605  
 606  func (s *mockServer) readHandler(message lnwire.Message) error {
 607  	var targetChan lnwire.ChannelID
 608  
 609  	switch msg := message.(type) {
 610  	case *lnwire.UpdateAddHTLC:
 611  		targetChan = msg.ChanID
 612  	case *lnwire.UpdateFulfillHTLC:
 613  		targetChan = msg.ChanID
 614  	case *lnwire.UpdateFailHTLC:
 615  		targetChan = msg.ChanID
 616  	case *lnwire.UpdateFailMalformedHTLC:
 617  		targetChan = msg.ChanID
 618  	case *lnwire.RevokeAndAck:
 619  		targetChan = msg.ChanID
 620  	case *lnwire.CommitSig:
 621  		targetChan = msg.ChanID
 622  	case *lnwire.ChannelReady:
 623  		// Ignore
 624  		return nil
 625  	case *lnwire.ChannelReestablish:
 626  		targetChan = msg.ChanID
 627  	case *lnwire.UpdateFee:
 628  		targetChan = msg.ChanID
 629  	case *lnwire.Stfu:
 630  		targetChan = msg.ChanID
 631  	default:
 632  		return fmt.Errorf("unknown message type: %T", msg)
 633  	}
 634  
 635  	// Dispatch the commitment update message to the proper channel link
 636  	// dedicated to this channel. If the link is not found, we will discard
 637  	// the message.
 638  	link, err := s.htlcSwitch.GetLink(targetChan)
 639  	if err != nil {
 640  		return nil
 641  	}
 642  
 643  	// Create goroutine for this, in order to be able to properly stop
 644  	// the server when handler stacked (server unavailable)
 645  	link.HandleChannelUpdate(message)
 646  
 647  	return nil
 648  }
 649  
 650  func (s *mockServer) PubKey() [33]byte {
 651  	return s.id
 652  }
 653  
 654  func (s *mockServer) IdentityKey() *btcec.PublicKey {
 655  	pubkey, _ := btcec.ParsePubKey(s.id[:])
 656  	return pubkey
 657  }
 658  
 659  func (s *mockServer) Address() net.Addr {
 660  	return nil
 661  }
 662  
 663  func (s *mockServer) AddNewChannel(channel *lnpeer.NewChannel,
 664  	cancel <-chan struct{}) error {
 665  
 666  	return nil
 667  }
 668  
 669  func (s *mockServer) AddPendingChannel(_ lnwire.ChannelID,
 670  	cancel <-chan struct{}) error {
 671  
 672  	return nil
 673  }
 674  
 675  func (s *mockServer) RemovePendingChannel(_ lnwire.ChannelID) error {
 676  	return nil
 677  }
 678  
 679  func (s *mockServer) WipeChannel(*wire.OutPoint) {}
 680  
 681  func (s *mockServer) LocalFeatures() *lnwire.FeatureVector {
 682  	return nil
 683  }
 684  
 685  func (s *mockServer) RemoteFeatures() *lnwire.FeatureVector {
 686  	return nil
 687  }
 688  
 689  func (s *mockServer) Disconnect(err error) {}
 690  
 691  func (s *mockServer) Stop() error {
 692  	if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) {
 693  		return nil
 694  	}
 695  
 696  	close(s.quit)
 697  	s.wg.Wait()
 698  
 699  	return nil
 700  }
 701  
 702  func (s *mockServer) String() string {
 703  	return s.name
 704  }
 705  
// mockChannelLink is a test ChannelLink. It performs no channel state
// transitions; its methods mostly return the fixed values stored in the
// fields below.
type mockChannelLink struct {
	// htlcSwitch provides the circuit map used by completeCircuit and
	// deleteCircuit.
	htlcSwitch *Switch

	// shortChanID is returned by ShortChanID; setLiveShortChanID
	// replaces it.
	shortChanID lnwire.ShortChannelID

	// Only used for zero-conf channels.
	// realScid is returned by confirmedScid.
	realScid lnwire.ShortChannelID

	// aliases is returned by getAliases and extended by addAlias.
	aliases []lnwire.ShortChannelID

	chanID lnwire.ChannelID

	// peer supplies the value returned by PeerPubKey.
	peer lnpeer.Peer

	// mailBox is set by AttachMailBox.
	mailBox MailBox

	// packets is the mailbox's packet out-box, captured by AttachMailBox.
	packets chan *htlcPacket

	// eligible is returned by EligibleToForward; UpdateShortChanID sets it
	// to true.
	eligible bool

	// unadvertised is returned by IsUnadvertised.
	unadvertised bool

	// zeroConf is returned by isZeroConf.
	zeroConf bool

	// optionFeature is returned by negotiatedAliasFeature.
	optionFeature bool

	// htlcID is the next outgoing HTLC ID that completeCircuit assigns.
	htlcID uint64

	// checkHtlcTransitResult is returned by CheckHtlcTransit.
	checkHtlcTransitResult *LinkError

	// checkHtlcForwardResult is returned by CheckHtlcForward.
	checkHtlcForwardResult *LinkError

	// failAliasUpdate is set by attachFailAliasUpdate.
	failAliasUpdate func(sid lnwire.ShortChannelID,
		incoming bool) *lnwire.ChannelUpdate1

	// confirmedZC is returned by zeroConfConfirmed.
	confirmedZC bool
}
 743  
 744  // completeCircuit is a helper method for adding the finalized payment circuit
 745  // to the switch's circuit map. In testing, this should be executed after
 746  // receiving an htlc from the downstream packets channel.
 747  func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error {
 748  	switch htlc := pkt.htlc.(type) {
 749  	case *lnwire.UpdateAddHTLC:
 750  		pkt.outgoingChanID = f.shortChanID
 751  		pkt.outgoingHTLCID = f.htlcID
 752  		htlc.ID = f.htlcID
 753  
 754  		keystone := Keystone{pkt.inKey(), pkt.outKey()}
 755  		err := f.htlcSwitch.circuits.OpenCircuits(keystone)
 756  		if err != nil {
 757  			return err
 758  		}
 759  
 760  		f.htlcID++
 761  
 762  	case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC:
 763  		if pkt.circuit != nil {
 764  			err := f.htlcSwitch.teardownCircuit(pkt)
 765  			if err != nil {
 766  				return err
 767  			}
 768  		}
 769  	}
 770  
 771  	f.mailBox.AckPacket(pkt.inKey())
 772  
 773  	return nil
 774  }
 775  
 776  func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error {
 777  	return f.htlcSwitch.circuits.DeleteCircuits(pkt.inKey())
 778  }
 779  
 780  func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID,
 781  	shortChanID, realScid lnwire.ShortChannelID, peer lnpeer.Peer,
 782  	eligible, unadvertised, zeroConf, optionFeature bool,
 783  ) *mockChannelLink {
 784  
 785  	aliases := make([]lnwire.ShortChannelID, 0)
 786  	var realConfirmed bool
 787  
 788  	if zeroConf {
 789  		aliases = append(aliases, shortChanID)
 790  	}
 791  
 792  	if realScid != hop.Source {
 793  		realConfirmed = true
 794  	}
 795  
 796  	return &mockChannelLink{
 797  		htlcSwitch:    htlcSwitch,
 798  		chanID:        chanID,
 799  		shortChanID:   shortChanID,
 800  		realScid:      realScid,
 801  		peer:          peer,
 802  		eligible:      eligible,
 803  		unadvertised:  unadvertised,
 804  		zeroConf:      zeroConf,
 805  		optionFeature: optionFeature,
 806  		aliases:       aliases,
 807  		confirmedZC:   realConfirmed,
 808  	}
 809  }
 810  
 811  // addAlias is not part of any interface method.
 812  func (f *mockChannelLink) addAlias(alias lnwire.ShortChannelID) {
 813  	f.aliases = append(f.aliases, alias)
 814  }
 815  
 816  func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error {
 817  	f.mailBox.AddPacket(pkt)
 818  	return nil
 819  }
 820  
 821  func (f *mockChannelLink) getDustSum(whoseCommit lntypes.ChannelParty,
 822  	dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi {
 823  
 824  	return 0
 825  }
 826  
 827  func (f *mockChannelLink) getFeeRate() chainfee.SatPerKWeight {
 828  	return 0
 829  }
 830  
 831  func (f *mockChannelLink) getDustClosure() dustClosure {
 832  	dustLimit := btcutil.Amount(400)
 833  	return dustHelper(
 834  		channeldb.SingleFunderTweaklessBit, dustLimit, dustLimit,
 835  	)
 836  }
 837  
 838  func (f *mockChannelLink) getCommitFee(remote bool) btcutil.Amount {
 839  	return 0
 840  }
 841  
 842  func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
 843  }
 844  
 845  func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) {
 846  }
 847  func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
 848  	lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32,
 849  	lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError {
 850  
 851  	return f.checkHtlcForwardResult
 852  }
 853  
 854  func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte,
 855  	amt lnwire.MilliSatoshi, timeout uint32,
 856  	heightNow uint32, _ lnwire.CustomRecords) *LinkError {
 857  
 858  	return f.checkHtlcTransitResult
 859  }
 860  
 861  func (f *mockChannelLink) Stats() (
 862  	uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) {
 863  
 864  	return 0, 0, 0
 865  }
 866  
 867  func (f *mockChannelLink) AttachMailBox(mailBox MailBox) {
 868  	f.mailBox = mailBox
 869  	f.packets = mailBox.PacketOutBox()
 870  	mailBox.SetDustClosure(f.getDustClosure())
 871  }
 872  
 873  func (f *mockChannelLink) attachFailAliasUpdate(closure func(
 874  	sid lnwire.ShortChannelID, incoming bool) *lnwire.ChannelUpdate1) {
 875  
 876  	f.failAliasUpdate = closure
 877  }
 878  
 879  func (f *mockChannelLink) getAliases() []lnwire.ShortChannelID {
 880  	return f.aliases
 881  }
 882  
 883  func (f *mockChannelLink) isZeroConf() bool {
 884  	return f.zeroConf
 885  }
 886  
 887  func (f *mockChannelLink) negotiatedAliasFeature() bool {
 888  	return f.optionFeature
 889  }
 890  
 891  func (f *mockChannelLink) confirmedScid() lnwire.ShortChannelID {
 892  	return f.realScid
 893  }
 894  
 895  func (f *mockChannelLink) zeroConfConfirmed() bool {
 896  	return f.confirmedZC
 897  }
 898  
 899  func (f *mockChannelLink) Start() error {
 900  	f.mailBox.ResetMessages()
 901  	f.mailBox.ResetPackets()
 902  	return nil
 903  }
 904  
 905  func (f *mockChannelLink) ChanID() lnwire.ChannelID {
 906  	return f.chanID
 907  }
 908  
 909  func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID {
 910  	return f.shortChanID
 911  }
 912  
 913  func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi {
 914  	return 99999999
 915  }
 916  
 917  func (f *mockChannelLink) PeerPubKey() [33]byte {
 918  	return f.peer.PubKey()
 919  }
 920  
 921  func (f *mockChannelLink) ChannelPoint() wire.OutPoint {
 922  	return wire.OutPoint{}
 923  }
 924  
 925  func (f *mockChannelLink) Stop()                                        {}
 926  func (f *mockChannelLink) EligibleToForward() bool                      { return f.eligible }
 927  func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliSatoshi) error { return nil }
 928  func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid }
 929  func (f *mockChannelLink) IsUnadvertised() bool                         { return f.unadvertised }
 930  func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) {
 931  	f.eligible = true
 932  	return f.shortChanID, nil
 933  }
 934  
 935  func (f *mockChannelLink) EnableAdds(linkDirection LinkDirection) bool {
 936  	// TODO(proofofkeags): Implement
 937  	return true
 938  }
 939  
 940  func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) bool {
 941  	// TODO(proofofkeags): Implement
 942  	return true
 943  }
 944  func (f *mockChannelLink) IsFlushing(linkDirection LinkDirection) bool {
 945  	// TODO(proofofkeags): Implement
 946  	return false
 947  }
 948  func (f *mockChannelLink) OnFlushedOnce(func()) {
 949  	// TODO(proofofkeags): Implement
 950  }
 951  func (f *mockChannelLink) OnCommitOnce(LinkDirection, func()) {
 952  	// TODO(proofofkeags): Implement
 953  }
 954  func (f *mockChannelLink) InitStfu() <-chan fn.Result[lntypes.ChannelParty] {
 955  	// TODO(proofofkeags): Implement
 956  	c := make(chan fn.Result[lntypes.ChannelParty], 1)
 957  
 958  	c <- fn.Errf[lntypes.ChannelParty]("InitStfu not implemented")
 959  
 960  	return c
 961  }
 962  
 963  func (f *mockChannelLink) FundingCustomBlob() fn.Option[tlv.Blob] {
 964  	return fn.None[tlv.Blob]()
 965  }
 966  
 967  func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] {
 968  	return fn.None[tlv.Blob]()
 969  }
 970  
 971  // AuxBandwidth returns the bandwidth that can be used for a channel,
 972  // expressed in milli-satoshi. This might be different from the regular
 973  // BTC bandwidth for custom channels. This will always return fn.None()
 974  // for a regular (non-custom) channel.
 975  func (f *mockChannelLink) AuxBandwidth(lnwire.MilliSatoshi,
 976  	lnwire.ShortChannelID,
 977  	fn.Option[tlv.Blob], AuxTrafficShaper) fn.Result[OptionalBandwidth] {
 978  
 979  	return fn.Ok(OptionalBandwidth{})
 980  }
 981  
 982  var _ ChannelLink = (*mockChannelLink)(nil)
 983  
 984  const testInvoiceCltvExpiry = 6
 985  
 986  type mockInvoiceRegistry struct {
 987  	settleChan chan lntypes.Hash
 988  
 989  	registry *invoices.InvoiceRegistry
 990  }
 991  
 992  type mockChainNotifier struct {
 993  	chainntnfs.ChainNotifier
 994  }
 995  
 996  // RegisterBlockEpochNtfn mocks a successful call to register block
 997  // notifications.
 998  func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) (
 999  	*chainntnfs.BlockEpochEvent, error) {
1000  
1001  	return &chainntnfs.BlockEpochEvent{
1002  		Cancel: func() {},
1003  	}, nil
1004  }
1005  
1006  func newMockRegistry(t testing.TB) *mockInvoiceRegistry {
1007  	cdb := channeldb.OpenForTesting(t, t.TempDir())
1008  
1009  	modifierMock := &invoices.MockHtlcModifier{}
1010  	registry := invoices.NewRegistry(
1011  		cdb,
1012  		invoices.NewInvoiceExpiryWatcher(
1013  			clock.NewDefaultClock(), 0, 0, nil,
1014  			&mockChainNotifier{},
1015  		),
1016  		&invoices.RegistryConfig{
1017  			FinalCltvRejectDelta: 5,
1018  			HtlcInterceptor:      modifierMock,
1019  		},
1020  	)
1021  	registry.Start()
1022  
1023  	return &mockInvoiceRegistry{
1024  		registry: registry,
1025  	}
1026  }
1027  
1028  func (i *mockInvoiceRegistry) LookupInvoice(ctx context.Context,
1029  	rHash lntypes.Hash) (invoices.Invoice, error) {
1030  
1031  	return i.registry.LookupInvoice(ctx, rHash)
1032  }
1033  
1034  func (i *mockInvoiceRegistry) SettleHodlInvoice(
1035  	ctx context.Context, preimage lntypes.Preimage) error {
1036  
1037  	return i.registry.SettleHodlInvoice(ctx, preimage)
1038  }
1039  
1040  func (i *mockInvoiceRegistry) NotifyExitHopHtlc(rhash lntypes.Hash,
1041  	amt lnwire.MilliSatoshi, expiry uint32, currentHeight int32,
1042  	circuitKey models.CircuitKey, hodlChan chan<- interface{},
1043  	wireCustomRecords lnwire.CustomRecords,
1044  	payload invoices.Payload) (invoices.HtlcResolution, error) {
1045  
1046  	event, err := i.registry.NotifyExitHopHtlc(
1047  		rhash, amt, expiry, currentHeight, circuitKey,
1048  		hodlChan, wireCustomRecords, payload,
1049  	)
1050  	if err != nil {
1051  		return nil, err
1052  	}
1053  	if i.settleChan != nil {
1054  		i.settleChan <- rhash
1055  	}
1056  
1057  	return event, nil
1058  }
1059  
1060  func (i *mockInvoiceRegistry) CancelInvoice(ctx context.Context,
1061  	payHash lntypes.Hash) error {
1062  
1063  	return i.registry.CancelInvoice(ctx, payHash)
1064  }
1065  
1066  func (i *mockInvoiceRegistry) AddInvoice(ctx context.Context,
1067  	invoice invoices.Invoice, paymentHash lntypes.Hash) error {
1068  
1069  	_, err := i.registry.AddInvoice(ctx, &invoice, paymentHash)
1070  	return err
1071  }
1072  
1073  func (i *mockInvoiceRegistry) HodlUnsubscribeAll(
1074  	subscriber chan<- interface{}) {
1075  
1076  	i.registry.HodlUnsubscribeAll(subscriber)
1077  }
1078  
1079  var _ InvoiceDatabase = (*mockInvoiceRegistry)(nil)
1080  
1081  type mockCircuitMap struct {
1082  	lookup chan *PaymentCircuit
1083  }
1084  
1085  var _ CircuitMap = (*mockCircuitMap)(nil)
1086  
1087  func (m *mockCircuitMap) OpenCircuits(...Keystone) error {
1088  	return nil
1089  }
1090  
1091  func (m *mockCircuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
1092  	start uint64) error {
1093  	return nil
1094  }
1095  
1096  func (m *mockCircuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
1097  	return nil
1098  }
1099  
1100  func (m *mockCircuitMap) CommitCircuits(
1101  	circuit ...*PaymentCircuit) (*CircuitFwdActions, error) {
1102  
1103  	return nil, nil
1104  }
1105  
1106  func (m *mockCircuitMap) CloseCircuit(outKey CircuitKey) (*PaymentCircuit,
1107  	error) {
1108  	return nil, nil
1109  }
1110  
1111  func (m *mockCircuitMap) FailCircuit(inKey CircuitKey) (*PaymentCircuit,
1112  	error) {
1113  	return nil, nil
1114  }
1115  
1116  func (m *mockCircuitMap) LookupCircuit(inKey CircuitKey) *PaymentCircuit {
1117  	return <-m.lookup
1118  }
1119  
1120  func (m *mockCircuitMap) LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit {
1121  	return nil
1122  }
1123  
1124  func (m *mockCircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit {
1125  	return nil
1126  }
1127  
1128  func (m *mockCircuitMap) NumPending() int {
1129  	return 0
1130  }
1131  
1132  func (m *mockCircuitMap) NumOpen() int {
1133  	return 0
1134  }
1135  
1136  type mockOnionErrorDecryptor struct {
1137  	sourceIdx int
1138  	message   []byte
1139  	err       error
1140  }
1141  
1142  func (m *mockOnionErrorDecryptor) DecryptError(encryptedData []byte) (
1143  	*sphinx.DecryptedError, error) {
1144  
1145  	return &sphinx.DecryptedError{
1146  		SenderIdx: m.sourceIdx,
1147  		Message:   m.message,
1148  	}, m.err
1149  }
1150  
1151  var _ htlcNotifier = (*mockHTLCNotifier)(nil)
1152  
1153  type mockHTLCNotifier struct {
1154  	htlcNotifier //nolint:unused
1155  }
1156  
1157  func (h *mockHTLCNotifier) NotifyForwardingEvent(key HtlcKey, info HtlcInfo,
1158  	eventType HtlcEventType) {
1159  
1160  }
1161  
1162  func (h *mockHTLCNotifier) NotifyLinkFailEvent(key HtlcKey, info HtlcInfo,
1163  	eventType HtlcEventType, linkErr *LinkError,
1164  	incoming bool) {
1165  
1166  }
1167  
1168  func (h *mockHTLCNotifier) NotifyForwardingFailEvent(key HtlcKey,
1169  	eventType HtlcEventType) {
1170  
1171  }
1172  
1173  func (h *mockHTLCNotifier) NotifySettleEvent(key HtlcKey,
1174  	preimage lntypes.Preimage, eventType HtlcEventType) {
1175  
1176  }
1177  
1178  func (h *mockHTLCNotifier) NotifyFinalHtlcEvent(key models.CircuitKey,
1179  	info channeldb.FinalHtlcInfo) {
1180  
1181  }