/ invoices / sql_migration.go
sql_migration.go
  1  package invoices
  2  
  3  import (
  4  	"bytes"
  5  	"context"
  6  	"encoding/binary"
  7  	"errors"
  8  	"fmt"
  9  	"strconv"
 10  	"time"
 11  
 12  	"github.com/lightningnetwork/lnd/graph/db/models"
 13  	"github.com/lightningnetwork/lnd/kvdb"
 14  	"github.com/lightningnetwork/lnd/lntypes"
 15  	"github.com/lightningnetwork/lnd/sqldb"
 16  	"github.com/lightningnetwork/lnd/sqldb/sqlc"
 17  	"golang.org/x/time/rate"
 18  )
 19  
 20  var (
 21  	// invoiceBucket is the name of the bucket within the database that
 22  	// stores all data related to invoices no matter their final state.
 23  	// Within the invoice bucket, each invoice is keyed by its invoice ID
 24  	// which is a monotonically increasing uint32.
 25  	invoiceBucket = []byte("invoices")
 26  
 27  	// invoiceIndexBucket  is the name of the sub-bucket within the
 28  	// invoiceBucket which indexes all invoices by their payment hash. The
 29  	// payment hash is the sha256 of the invoice's payment preimage. This
 30  	// index is used to detect duplicates, and also to provide a fast path
 31  	// for looking up incoming HTLCs to determine if we're able to settle
 32  	// them fully.
 33  	//
 34  	// maps: payHash => invoiceKey
 35  	invoiceIndexBucket = []byte("paymenthashes")
 36  
 37  	// numInvoicesKey is the name of key which houses the auto-incrementing
 38  	// invoice ID which is essentially used as a primary key. With each
 39  	// invoice inserted, the primary key is incremented by one. This key is
 40  	// stored within the invoiceIndexBucket. Within the invoiceBucket
 41  	// invoices are uniquely identified by the invoice ID.
 42  	numInvoicesKey = []byte("nik")
 43  
 44  	// addIndexBucket is an index bucket that we'll use to create a
 45  	// monotonically increasing set of add indexes. Each time we add a new
 46  	// invoice, this sequence number will be incremented and then populated
 47  	// within the new invoice.
 48  	//
 49  	// In addition to this sequence number, we map:
 50  	//
 51  	//   addIndexNo => invoiceKey
 52  	addIndexBucket = []byte("invoice-add-index")
 53  )
 54  
 55  // createInvoiceHashIndex generates a hash index that contains payment hashes
 56  // for each invoice in the database. Retrieving the payment hash for certain
 57  // invoices, such as those created for spontaneous AMP payments, can be
 58  // challenging because the hash is not directly derivable from the invoice's
 59  // parameters and is stored separately in the `paymenthashes` bucket. This
 60  // bucket maps payment hashes to invoice keys, but for migration purposes, we
 61  // need the ability to query in the reverse direction. This function establishes
 62  // a new index in the SQL database that maps each invoice key to its
 63  // corresponding payment hash.
 64  func createInvoiceHashIndex(ctx context.Context, db kvdb.Backend,
 65  	tx *sqlc.Queries) error {
 66  
 67  	return db.View(func(kvTx kvdb.RTx) error {
 68  		invoices := kvTx.ReadBucket(invoiceBucket)
 69  		if invoices == nil {
 70  			return ErrNoInvoicesCreated
 71  		}
 72  
 73  		invoiceIndex := invoices.NestedReadBucket(
 74  			invoiceIndexBucket,
 75  		)
 76  		if invoiceIndex == nil {
 77  			return ErrNoInvoicesCreated
 78  		}
 79  
 80  		addIndex := invoices.NestedReadBucket(addIndexBucket)
 81  		if addIndex == nil {
 82  			return ErrNoInvoicesCreated
 83  		}
 84  
 85  		// First, iterate over all elements in the add index bucket and
 86  		// insert the add index value for the corresponding invoice key
 87  		// in the payment_hashes table.
 88  		err := addIndex.ForEach(func(k, v []byte) error {
 89  			// The key is the add index, and the value is
 90  			// the invoice key.
 91  			addIndexNo := binary.BigEndian.Uint64(k)
 92  			invoiceKey := binary.BigEndian.Uint32(v)
 93  
 94  			return tx.InsertKVInvoiceKeyAndAddIndex(ctx,
 95  				sqlc.InsertKVInvoiceKeyAndAddIndexParams{
 96  					ID:       int64(invoiceKey),
 97  					AddIndex: int64(addIndexNo),
 98  				},
 99  			)
100  		})
101  		if err != nil {
102  			return err
103  		}
104  
105  		// Next, iterate over all hashes in the invoice index bucket and
106  		// set the hash to the corresponding the invoice key in the
107  		// payment_hashes table.
108  		return invoiceIndex.ForEach(func(k, v []byte) error {
109  			// Skip the special numInvoicesKey as that does
110  			// not point to a valid invoice.
111  			if bytes.Equal(k, numInvoicesKey) {
112  				return nil
113  			}
114  
115  			// The key is the payment hash, and the value
116  			// is the invoice key.
117  			if len(k) != lntypes.HashSize {
118  				return fmt.Errorf("invalid payment "+
119  					"hash length: expected %v, "+
120  					"got %v", lntypes.HashSize,
121  					len(k))
122  			}
123  
124  			invoiceKey := binary.BigEndian.Uint32(v)
125  
126  			return tx.SetKVInvoicePaymentHash(ctx,
127  				sqlc.SetKVInvoicePaymentHashParams{
128  					ID:   int64(invoiceKey),
129  					Hash: k,
130  				},
131  			)
132  		})
133  	}, func() {})
134  }
135  
136  // toInsertMigratedInvoiceParams creates the parameters for inserting a migrated
137  // invoice into the SQL database. The parameters are derived from the original
138  // invoice insert parameters.
139  func toInsertMigratedInvoiceParams(
140  	params sqlc.InsertInvoiceParams) sqlc.InsertMigratedInvoiceParams {
141  
142  	return sqlc.InsertMigratedInvoiceParams{
143  		Hash:               params.Hash,
144  		Preimage:           params.Preimage,
145  		Memo:               params.Memo,
146  		AmountMsat:         params.AmountMsat,
147  		CltvDelta:          params.CltvDelta,
148  		Expiry:             params.Expiry,
149  		PaymentAddr:        params.PaymentAddr,
150  		PaymentRequest:     params.PaymentRequest,
151  		PaymentRequestHash: params.PaymentRequestHash,
152  		State:              params.State,
153  		AmountPaidMsat:     params.AmountPaidMsat,
154  		IsAmp:              params.IsAmp,
155  		IsHodl:             params.IsHodl,
156  		IsKeysend:          params.IsKeysend,
157  		CreatedAt:          params.CreatedAt,
158  	}
159  }
160  
161  // MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note
162  // that perfect equality between the old and new schemas is not achievable, as
163  // the invoice's add index cannot be mapped directly to its ID due to SQL’s
164  // auto-incrementing primary key. The ID returned from the insert will instead
165  // serve as the add index in the new schema.
166  func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries,
167  	invoice *Invoice, paymentHash lntypes.Hash) error {
168  
169  	insertInvoiceParams, err := makeInsertInvoiceParams(
170  		invoice, paymentHash,
171  	)
172  	if err != nil {
173  		return err
174  	}
175  
176  	// Convert the insert invoice parameters to the migrated invoice insert
177  	// parameters.
178  	insertMigratedInvoiceParams := toInsertMigratedInvoiceParams(
179  		insertInvoiceParams,
180  	)
181  
182  	// If the invoice is settled, we'll also set the timestamp and the index
183  	// at which it was settled.
184  	if invoice.State == ContractSettled {
185  		if invoice.SettleIndex == 0 {
186  			return fmt.Errorf("settled invoice %s missing settle "+
187  				"index", paymentHash)
188  		}
189  
190  		if invoice.SettleDate.IsZero() {
191  			return fmt.Errorf("settled invoice %s missing settle "+
192  				"date", paymentHash)
193  		}
194  
195  		insertMigratedInvoiceParams.SettleIndex = sqldb.SQLInt64(
196  			invoice.SettleIndex,
197  		)
198  		insertMigratedInvoiceParams.SettledAt = sqldb.SQLTime(
199  			invoice.SettleDate.UTC(),
200  		)
201  	}
202  
203  	// First we need to insert the invoice itself so we can use the "add
204  	// index" which in this case is the auto incrementing primary key that
205  	// is returned from the insert.
206  	invoiceID, err := tx.InsertMigratedInvoice(
207  		ctx, insertMigratedInvoiceParams,
208  	)
209  	if err != nil {
210  		return fmt.Errorf("unable to insert invoice: %w", err)
211  	}
212  
213  	// Insert the invoice's features.
214  	for feature := range invoice.Terms.Features.Features() {
215  		params := sqlc.InsertInvoiceFeatureParams{
216  			InvoiceID: invoiceID,
217  			Feature:   int32(feature),
218  		}
219  
220  		err := tx.InsertInvoiceFeature(ctx, params)
221  		if err != nil {
222  			return fmt.Errorf("unable to insert invoice "+
223  				"feature(%v): %w", feature, err)
224  		}
225  	}
226  
227  	sqlHtlcIDs := make(map[models.CircuitKey]int64)
228  
229  	// Now insert the HTLCs of the invoice. We'll also keep track of the SQL
230  	// ID of each HTLC so we can use it when inserting the AMP sub invoices.
231  	for circuitKey, htlc := range invoice.Htlcs {
232  		htlcParams := sqlc.InsertInvoiceHTLCParams{
233  			HtlcID: int64(circuitKey.HtlcID),
234  			ChanID: strconv.FormatUint(
235  				circuitKey.ChanID.ToUint64(), 10,
236  			),
237  			AmountMsat:   int64(htlc.Amt),
238  			AcceptHeight: int32(htlc.AcceptHeight),
239  			AcceptTime:   htlc.AcceptTime.UTC(),
240  			ExpiryHeight: int32(htlc.Expiry),
241  			State:        int16(htlc.State),
242  			InvoiceID:    invoiceID,
243  		}
244  
245  		// Leave the MPP amount as NULL if the MPP total amount is zero.
246  		if htlc.MppTotalAmt != 0 {
247  			htlcParams.TotalMppMsat = sqldb.SQLInt64(
248  				int64(htlc.MppTotalAmt),
249  			)
250  		}
251  
252  		// Leave the resolve time as NULL if the HTLC is not resolved.
253  		if !htlc.ResolveTime.IsZero() {
254  			htlcParams.ResolveTime = sqldb.SQLTime(
255  				htlc.ResolveTime.UTC(),
256  			)
257  		}
258  
259  		sqlID, err := tx.InsertInvoiceHTLC(ctx, htlcParams)
260  		if err != nil {
261  			return fmt.Errorf("unable to insert invoice htlc: %w",
262  				err)
263  		}
264  
265  		sqlHtlcIDs[circuitKey] = sqlID
266  
267  		// Store custom records.
268  		for key, value := range htlc.CustomRecords {
269  			err = tx.InsertInvoiceHTLCCustomRecord(
270  				ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
271  					Key:    int64(key),
272  					Value:  value,
273  					HtlcID: sqlID,
274  				},
275  			)
276  			if err != nil {
277  				return err
278  			}
279  		}
280  	}
281  
282  	if !invoice.IsAMP() {
283  		return nil
284  	}
285  
286  	for setID, ampState := range invoice.AMPState {
287  		// Find the earliest HTLC of the AMP invoice, which will
288  		// be used as the creation date of this sub invoice.
289  		var createdAt time.Time
290  		for circuitKey := range ampState.InvoiceKeys {
291  			htlc := invoice.Htlcs[circuitKey]
292  			if createdAt.IsZero() {
293  				createdAt = htlc.AcceptTime.UTC()
294  				continue
295  			}
296  
297  			if createdAt.After(htlc.AcceptTime) {
298  				createdAt = htlc.AcceptTime.UTC()
299  			}
300  		}
301  
302  		params := sqlc.InsertAMPSubInvoiceParams{
303  			SetID:     setID[:],
304  			State:     int16(ampState.State),
305  			CreatedAt: createdAt,
306  			InvoiceID: invoiceID,
307  		}
308  
309  		if ampState.SettleIndex != 0 {
310  			if ampState.SettleDate.IsZero() {
311  				return fmt.Errorf("settled AMP sub invoice %x "+
312  					"missing settle date", setID)
313  			}
314  
315  			params.SettledAt = sqldb.SQLTime(
316  				ampState.SettleDate.UTC(),
317  			)
318  
319  			params.SettleIndex = sqldb.SQLInt64(
320  				ampState.SettleIndex,
321  			)
322  		}
323  
324  		err := tx.InsertAMPSubInvoice(ctx, params)
325  		if err != nil {
326  			return fmt.Errorf("unable to insert AMP sub invoice: "+
327  				"%w", err)
328  		}
329  
330  		// Now we can add the AMP HTLCs to the database.
331  		for circuitKey := range ampState.InvoiceKeys {
332  			htlc := invoice.Htlcs[circuitKey]
333  			rootShare := htlc.AMP.Record.RootShare()
334  
335  			sqlHtlcID, ok := sqlHtlcIDs[circuitKey]
336  			if !ok {
337  				return fmt.Errorf("missing htlc for AMP htlc: "+
338  					"%v", circuitKey)
339  			}
340  
341  			params := sqlc.InsertAMPSubInvoiceHTLCParams{
342  				InvoiceID:  invoiceID,
343  				SetID:      setID[:],
344  				HtlcID:     sqlHtlcID,
345  				RootShare:  rootShare[:],
346  				ChildIndex: int64(htlc.AMP.Record.ChildIndex()),
347  				Hash:       htlc.AMP.Hash[:],
348  			}
349  
350  			if htlc.AMP.Preimage != nil {
351  				params.Preimage = htlc.AMP.Preimage[:]
352  			}
353  
354  			err = tx.InsertAMPSubInvoiceHTLC(ctx, params)
355  			if err != nil {
356  				return fmt.Errorf("unable to insert AMP sub "+
357  					"invoice: %w", err)
358  			}
359  		}
360  	}
361  
362  	return nil
363  }
364  
365  // OverrideInvoiceTimeZone overrides the time zone of the invoice to the local
366  // time zone and chops off the nanosecond part for comparison. This is needed
367  // because KV database stores times as-is which as an unwanted side effect would
368  // fail migration due to time comparison expecting both the original and
369  // migrated invoices to be in the same local time zone and in microsecond
370  // precision. Note that PostgreSQL stores times in microsecond precision while
371  // SQLite can store times in nanosecond precision if using TEXT storage class.
372  func OverrideInvoiceTimeZone(invoice *Invoice) {
373  	fixTime := func(t time.Time) time.Time {
374  		return t.In(time.Local).Truncate(time.Microsecond)
375  	}
376  
377  	invoice.CreationDate = fixTime(invoice.CreationDate)
378  
379  	if !invoice.SettleDate.IsZero() {
380  		invoice.SettleDate = fixTime(invoice.SettleDate)
381  	}
382  
383  	if invoice.IsAMP() {
384  		for setID, ampState := range invoice.AMPState {
385  			if ampState.SettleDate.IsZero() {
386  				continue
387  			}
388  
389  			ampState.SettleDate = fixTime(ampState.SettleDate)
390  			invoice.AMPState[setID] = ampState
391  		}
392  	}
393  
394  	for _, htlc := range invoice.Htlcs {
395  		if !htlc.AcceptTime.IsZero() {
396  			htlc.AcceptTime = fixTime(htlc.AcceptTime)
397  		}
398  
399  		if !htlc.ResolveTime.IsZero() {
400  			htlc.ResolveTime = fixTime(htlc.ResolveTime)
401  		}
402  	}
403  }
404  
405  // MigrateInvoicesToSQL runs the migration of all invoices from the KV database
406  // to the SQL database. The migration is done in a single transaction to ensure
407  // that all invoices are migrated or none at all. This function can be run
408  // multiple times without causing any issues as it will check if the migration
409  // has already been performed.
410  func MigrateInvoicesToSQL(ctx context.Context, db kvdb.Backend,
411  	kvStore InvoiceDB, tx *sqlc.Queries, batchSize int) error {
412  
413  	log.Infof("Starting migration of invoices from KV to SQL")
414  
415  	offset := uint64(0)
416  	t0 := time.Now()
417  
418  	// Create the hash index which we will use to look up invoice
419  	// payment hashes by their add index during migration.
420  	err := createInvoiceHashIndex(ctx, db, tx)
421  	if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
422  		log.Errorf("Unable to create invoice hash index: %v",
423  			err)
424  
425  		return err
426  	}
427  	log.Debugf("Created SQL invoice hash index in %v", time.Since(t0))
428  
429  	s := rate.Sometimes{
430  		Interval: 30 * time.Second,
431  	}
432  
433  	t0 = time.Now()
434  	chunk := 0
435  	total := 0
436  
437  	// Now we can start migrating the invoices. We'll do this in
438  	// batches to reduce memory usage.
439  	for {
440  		query := InvoiceQuery{
441  			IndexOffset:    offset,
442  			NumMaxInvoices: uint64(batchSize),
443  		}
444  
445  		queryResult, err := kvStore.QueryInvoices(ctx, query)
446  		if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
447  			return fmt.Errorf("unable to query invoices: %w", err)
448  		}
449  
450  		if len(queryResult.Invoices) == 0 {
451  			log.Infof("All invoices migrated. Total: %d", total)
452  			break
453  		}
454  
455  		err = migrateInvoices(ctx, tx, queryResult.Invoices)
456  		if err != nil {
457  			return err
458  		}
459  
460  		offset = queryResult.LastIndexOffset
461  		resultCnt := len(queryResult.Invoices)
462  		total += resultCnt
463  		chunk += resultCnt
464  
465  		s.Do(func() {
466  			elapsed := time.Since(t0).Seconds()
467  			ratePerSec := float64(chunk) / elapsed
468  			log.Debugf("Migrated %d invoices (%.2f invoices/sec)",
469  				total, ratePerSec)
470  
471  			t0 = time.Now()
472  			chunk = 0
473  		})
474  	}
475  
476  	// Clean up the hash index as it's no longer needed.
477  	err = tx.ClearKVInvoiceHashIndex(ctx)
478  	if err != nil {
479  		return fmt.Errorf("unable to clear invoice hash "+
480  			"index: %w", err)
481  	}
482  
483  	log.Infof("Migration of %d invoices from KV to SQL completed", total)
484  
485  	return nil
486  }
487  
488  func migrateInvoices(ctx context.Context, tx *sqlc.Queries,
489  	invoices []Invoice) error {
490  
491  	for i, invoice := range invoices {
492  		var paymentHash lntypes.Hash
493  		if invoice.Terms.PaymentPreimage != nil {
494  			paymentHash = invoice.Terms.PaymentPreimage.Hash()
495  		} else {
496  			paymentHashBytes, err :=
497  				tx.GetKVInvoicePaymentHashByAddIndex(
498  					ctx, int64(invoice.AddIndex),
499  				)
500  			if err != nil {
501  				// This would be an unexpected inconsistency
502  				// in the kv database. We can't do much here
503  				// so we'll notify the user and continue.
504  				log.Warnf("Cannot migrate invoice, unable to "+
505  					"fetch payment hash (add_index=%v): %v",
506  					invoice.AddIndex, err)
507  
508  				continue
509  			}
510  
511  			copy(paymentHash[:], paymentHashBytes)
512  		}
513  
514  		err := MigrateSingleInvoice(ctx, tx, &invoices[i], paymentHash)
515  		if err != nil {
516  			return fmt.Errorf("unable to migrate invoice(%v): %w",
517  				paymentHash, err)
518  		}
519  
520  		migratedInvoice, err := fetchInvoice(
521  			ctx, tx, InvoiceRefByHash(paymentHash),
522  		)
523  		if err != nil {
524  			return fmt.Errorf("unable to fetch migrated "+
525  				"invoice(%v): %w", paymentHash, err)
526  		}
527  
528  		// Override the time zone for comparison. Note that we need to
529  		// override both invoices as the original invoice is coming from
530  		// KV database, it was stored as a binary serialized Go
531  		// time.Time value which has nanosecond precision but might have
532  		// been created in a different time zone. The migrated invoice
533  		// is stored in SQL in UTC and selected in the local time zone,
534  		// however in PostgreSQL it has microsecond precision while in
535  		// SQLite it has nanosecond precision if using TEXT storage
536  		// class.
537  		OverrideInvoiceTimeZone(&invoice)
538  		OverrideInvoiceTimeZone(migratedInvoice)
539  
540  		// Override the add index before checking for equality.
541  		migratedInvoice.AddIndex = invoice.AddIndex
542  
543  		err = sqldb.CompareRecords(invoice, *migratedInvoice, "invoice")
544  		if err != nil {
545  			return err
546  		}
547  	}
548  
549  	return nil
550  }