/ invoices / sql_migration_test.go
sql_migration_test.go
  1  package invoices
  2  
  3  import (
  4  	crand "crypto/rand"
  5  	"database/sql"
  6  	"math/rand"
  7  	"sync/atomic"
  8  	"testing"
  9  	"time"
 10  
 11  	"github.com/lightningnetwork/lnd/clock"
 12  	"github.com/lightningnetwork/lnd/graph/db/models"
 13  	"github.com/lightningnetwork/lnd/lntypes"
 14  	"github.com/lightningnetwork/lnd/lnwire"
 15  	"github.com/lightningnetwork/lnd/record"
 16  	"github.com/lightningnetwork/lnd/sqldb"
 17  	"github.com/stretchr/testify/require"
 18  	"pgregory.net/rapid"
 19  )
 20  
 21  var (
 22  	// testHtlcIDSequence is a global counter for generating unique HTLC
 23  	// IDs.
 24  	testHtlcIDSequence uint64
 25  )
 26  
 27  // randomString generates a random string of a given length using rapid.
 28  func randomStringRapid(t *rapid.T, length int) string {
 29  	// Define the character set for the string.
 30  	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" //nolint:ll
 31  
 32  	// Generate a string by selecting random characters from the charset.
 33  	runes := make([]rune, length)
 34  	for i := range runes {
 35  		// Draw a random index and use it to select a character from the
 36  		// charset.
 37  		index := rapid.IntRange(0, len(charset)-1).Draw(t, "charIndex")
 38  		runes[i] = rune(charset[index])
 39  	}
 40  
 41  	return string(runes)
 42  }
 43  
 44  // randTimeBetween generates a random time between min and max.
 45  func randTimeBetween(minTime, maxTime time.Time) time.Time {
 46  	var timeZones = []*time.Location{
 47  		time.UTC,
 48  		time.FixedZone("EST", -5*3600),
 49  		time.FixedZone("MST", -7*3600),
 50  		time.FixedZone("PST", -8*3600),
 51  		time.FixedZone("CEST", 2*3600),
 52  	}
 53  
 54  	// Ensure max is after min
 55  	if maxTime.Before(minTime) {
 56  		minTime, maxTime = maxTime, minTime
 57  	}
 58  
 59  	// Calculate the range in nanoseconds
 60  	duration := maxTime.Sub(minTime)
 61  	randDuration := time.Duration(rand.Int63n(duration.Nanoseconds()))
 62  
 63  	// Generate the random time
 64  	randomTime := minTime.Add(randDuration)
 65  
 66  	// Assign a random time zone
 67  	randomTimeZone := timeZones[rand.Intn(len(timeZones))]
 68  
 69  	// Return the time in the random time zone
 70  	return randomTime.In(randomTimeZone)
 71  }
 72  
 73  // randTime generates a random time between 2009 and 2140.
 74  func randTime() time.Time {
 75  	minTime := time.Date(2009, 1, 3, 0, 0, 0, 0, time.UTC)
 76  	maxTime := time.Date(2140, 1, 1, 0, 0, 0, 1000, time.UTC)
 77  
 78  	return randTimeBetween(minTime, maxTime)
 79  }
 80  
 81  func randInvoiceTime(invoice *Invoice) time.Time {
 82  	return randTimeBetween(
 83  		invoice.CreationDate,
 84  		invoice.CreationDate.Add(invoice.Terms.Expiry),
 85  	)
 86  }
 87  
 88  // randHTLCRapid generates a random HTLC for an invoice using rapid to randomize
 89  // its parameters.
 90  func randHTLCRapid(t *rapid.T, invoice *Invoice, amt lnwire.MilliSatoshi) (
 91  	models.CircuitKey, *InvoiceHTLC) {
 92  
 93  	htlc := &InvoiceHTLC{
 94  		Amt:          amt,
 95  		AcceptHeight: rapid.Uint32Range(1, 999).Draw(t, "AcceptHeight"),
 96  		AcceptTime:   randInvoiceTime(invoice),
 97  		Expiry:       rapid.Uint32Range(1, 999).Draw(t, "Expiry"),
 98  	}
 99  
100  	// Set MPP total amount if MPP feature is enabled in the invoice.
101  	if invoice.Terms.Features.HasFeature(lnwire.MPPRequired) {
102  		htlc.MppTotalAmt = invoice.Terms.Value
103  	}
104  
105  	// Set the HTLC state and resolve time based on the invoice state.
106  	switch invoice.State {
107  	case ContractSettled:
108  		htlc.State = HtlcStateSettled
109  		htlc.ResolveTime = randInvoiceTime(invoice)
110  
111  	case ContractCanceled:
112  		htlc.State = HtlcStateCanceled
113  		htlc.ResolveTime = randInvoiceTime(invoice)
114  
115  	case ContractAccepted:
116  		htlc.State = HtlcStateAccepted
117  	}
118  
119  	// Add randomized custom records to the HTLC.
120  	htlc.CustomRecords = make(record.CustomSet)
121  	numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords")
122  	for i := 0; i < numRecords; i++ {
123  		key := rapid.Uint64Range(
124  			record.CustomTypeStart, 1000+record.CustomTypeStart,
125  		).Draw(t, "customRecordKey")
126  		value := []byte(randomStringRapid(t, 10))
127  		htlc.CustomRecords[key] = value
128  	}
129  
130  	// Generate a unique HTLC ID and assign it to a channel ID.
131  	htlcID := atomic.AddUint64(&testHtlcIDSequence, 1)
132  	randChanID := lnwire.NewShortChanIDFromInt(htlcID % 5)
133  
134  	circuitKey := models.CircuitKey{
135  		ChanID: randChanID,
136  		HtlcID: htlcID,
137  	}
138  
139  	return circuitKey, htlc
140  }
141  
142  // generateInvoiceHTLCsRapid generates all HTLCs for an invoice, including AMP
143  // HTLCs if applicable, using rapid for randomization of HTLC count and
144  // distribution.
145  func generateInvoiceHTLCsRapid(t *rapid.T, invoice *Invoice) {
146  	mpp := invoice.Terms.Features.HasFeature(lnwire.MPPRequired)
147  
148  	// Use rapid to determine the number of HTLCs based on invoice state and
149  	// MPP feature.
150  	numHTLCs := 1
151  	if invoice.State == ContractOpen {
152  		numHTLCs = 0
153  	} else if mpp {
154  		numHTLCs = rapid.IntRange(1, 10).Draw(t, "numHTLCs")
155  	}
156  
157  	total := invoice.Terms.Value
158  
159  	// Distribute the total amount across the HTLCs, adding any remainder to
160  	// the last HTLC.
161  	if numHTLCs > 0 {
162  		amt := total / lnwire.MilliSatoshi(numHTLCs)
163  		remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
164  
165  		for i := 0; i < numHTLCs; i++ {
166  			if i == numHTLCs-1 {
167  				// Add remainder to the last HTLC.
168  				amt += remainder
169  			}
170  
171  			// Generate an HTLC with a random circuit key and add it
172  			// to the invoice.
173  			circuitKey, htlc := randHTLCRapid(t, invoice, amt)
174  			invoice.Htlcs[circuitKey] = htlc
175  		}
176  	}
177  }
178  
179  // generateAMPHtlcsRapid generates AMP HTLCs for an invoice using rapid to
180  // randomize various parameters of the HTLCs in the AMP set.
181  func generateAMPHtlcsRapid(t *rapid.T, invoice *Invoice) {
182  	// Randomly determine the number of AMP sets (1 to 5).
183  	numSetIDs := rapid.IntRange(1, 5).Draw(t, "numSetIDs")
184  	settledIdx := uint64(1)
185  
186  	for i := 0; i < numSetIDs; i++ {
187  		var setID SetID
188  		_, err := crand.Read(setID[:])
189  		require.NoError(t, err)
190  
191  		// Determine the number of HTLCs in this set (1 to 5).
192  		numHTLCs := rapid.IntRange(1, 5).Draw(t, "numHTLCs")
193  		total := invoice.Terms.Value
194  		invoiceKeys := make(map[CircuitKey]struct{})
195  
196  		// Calculate the amount per HTLC and account for remainder in
197  		// the final HTLC.
198  		amt := total / lnwire.MilliSatoshi(numHTLCs)
199  		remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
200  
201  		var htlcState HtlcState
202  		for j := 0; j < numHTLCs; j++ {
203  			if j == numHTLCs-1 {
204  				amt += remainder
205  			}
206  
207  			// Generate HTLC with randomized parameters.
208  			circuitKey, htlc := randHTLCRapid(t, invoice, amt)
209  			htlcState = htlc.State
210  
211  			var (
212  				rootShare, hash [32]byte
213  				preimage        lntypes.Preimage
214  			)
215  
216  			// Randomize AMP data fields.
217  			_, err := crand.Read(rootShare[:])
218  			require.NoError(t, err)
219  			_, err = crand.Read(hash[:])
220  			require.NoError(t, err)
221  			_, err = crand.Read(preimage[:])
222  			require.NoError(t, err)
223  
224  			record := record.NewAMP(rootShare, setID, uint32(j))
225  
226  			htlc.AMP = &InvoiceHtlcAMPData{
227  				Record:   *record,
228  				Hash:     hash,
229  				Preimage: &preimage,
230  			}
231  
232  			invoice.Htlcs[circuitKey] = htlc
233  			invoiceKeys[circuitKey] = struct{}{}
234  		}
235  
236  		ampState := InvoiceStateAMP{
237  			State:       htlcState,
238  			InvoiceKeys: invoiceKeys,
239  		}
240  		if htlcState == HtlcStateSettled {
241  			ampState.SettleIndex = settledIdx
242  			ampState.SettleDate = randInvoiceTime(invoice)
243  			settledIdx++
244  		}
245  
246  		// Set the total amount paid if the AMP set is not canceled.
247  		if htlcState != HtlcStateCanceled {
248  			ampState.AmtPaid = invoice.Terms.Value
249  		}
250  
251  		invoice.AMPState[setID] = ampState
252  	}
253  }
254  
255  // TestMigrateSingleInvoiceRapid tests the migration of single invoices with
256  // random data variations using rapid. This test generates a random invoice
257  // configuration and ensures successful migration.
258  //
259  // NOTE: This test may need to be changed if the Invoice or any of the related
260  // types are modified.
261  func TestMigrateSingleInvoiceRapid(t *testing.T) {
262  	// Create a shared Postgres instance for efficient testing.
263  	pgFixture := sqldb.NewTestPgFixture(
264  		t, sqldb.DefaultPostgresFixtureLifetime,
265  	)
266  	t.Cleanup(func() {
267  		pgFixture.TearDown(t)
268  	})
269  
270  	makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore {
271  		var db *sqldb.BaseDB
272  		if sqlite {
273  			db = sqldb.NewTestSqliteDB(t).BaseDB
274  		} else {
275  			db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
276  		}
277  
278  		executor := sqldb.NewTransactionExecutor(
279  			db, func(tx *sql.Tx) SQLInvoiceQueries {
280  				return db.WithTx(tx)
281  			},
282  		)
283  
284  		testClock := clock.NewTestClock(time.Unix(1, 0))
285  
286  		return NewSQLStore(executor, testClock)
287  	}
288  
289  	// Define property-based test using rapid.
290  	rapid.Check(t, func(rt *rapid.T) {
291  		// Randomized feature flags for MPP and AMP.
292  		mpp := rapid.Bool().Draw(rt, "mpp")
293  		amp := rapid.Bool().Draw(rt, "amp")
294  
295  		for _, sqlite := range []bool{true, false} {
296  			store := makeSQLDB(t, sqlite)
297  			testMigrateSingleInvoiceRapid(rt, store, mpp, amp)
298  		}
299  	})
300  }
301  
302  // testMigrateSingleInvoiceRapid is the primary function for the migration of a
303  // single invoice with random data in a rapid-based test setup.
304  func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool,
305  	amp bool) {
306  
307  	ctxb := t.Context()
308  	invoices := make(map[lntypes.Hash]*Invoice)
309  
310  	for i := 0; i < 100; i++ {
311  		invoice := generateTestInvoiceRapid(t, mpp, amp)
312  		var hash lntypes.Hash
313  		_, err := crand.Read(hash[:])
314  		require.NoError(t, err)
315  
316  		invoices[hash] = invoice
317  	}
318  
319  	ops := sqldb.WriteTxOpt()
320  	err := store.db.ExecTx(ctxb, ops, func(tx SQLInvoiceQueries) error {
321  		for hash, invoice := range invoices {
322  			err := MigrateSingleInvoice(ctxb, tx, invoice, hash)
323  			require.NoError(t, err)
324  		}
325  
326  		return nil
327  	}, sqldb.NoOpReset)
328  	require.NoError(t, err)
329  
330  	// Fetch and compare each migrated invoice from the store with the
331  	// original.
332  	for hash, invoice := range invoices {
333  		sqlInvoice, err := store.LookupInvoice(
334  			ctxb, InvoiceRefByHash(hash),
335  		)
336  		require.NoError(t, err)
337  
338  		invoice.AddIndex = sqlInvoice.AddIndex
339  
340  		OverrideInvoiceTimeZone(invoice)
341  		OverrideInvoiceTimeZone(&sqlInvoice)
342  
343  		require.Equal(t, *invoice, sqlInvoice)
344  	}
345  }
346  
347  // generateTestInvoiceRapid generates a random invoice with variations based on
348  // mpp and amp flags.
349  func generateTestInvoiceRapid(t *rapid.T, mpp bool, amp bool) *Invoice {
350  	var preimage lntypes.Preimage
351  	_, err := crand.Read(preimage[:])
352  	require.NoError(t, err)
353  
354  	terms := ContractTerm{
355  		FinalCltvDelta: rapid.Int32Range(1, 1000).Draw(
356  			t, "FinalCltvDelta",
357  		),
358  		Expiry: time.Duration(
359  			rapid.IntRange(1, 4444).Draw(t, "Expiry"),
360  		) * time.Minute,
361  		PaymentPreimage: &preimage,
362  		Value: lnwire.MilliSatoshi(
363  			rapid.Int64Range(1, 9999999).Draw(t, "Value"),
364  		),
365  		PaymentAddr: [32]byte{},
366  		Features:    lnwire.EmptyFeatureVector(),
367  	}
368  
369  	if amp {
370  		terms.Features.Set(lnwire.AMPRequired)
371  	} else if mpp {
372  		terms.Features.Set(lnwire.MPPRequired)
373  	}
374  
375  	created := randTime()
376  
377  	const maxContractState = 3
378  	state := ContractState(
379  		rapid.IntRange(0, maxContractState).Draw(t, "ContractState"),
380  	)
381  	var (
382  		settled     time.Time
383  		settleIndex uint64
384  	)
385  	if state == ContractSettled {
386  		settled = randTimeBetween(created, created.Add(terms.Expiry))
387  		settleIndex = rapid.Uint64Range(1, 999).Draw(t, "SettleIndex")
388  	}
389  
390  	invoice := &Invoice{
391  		Memo: []byte(randomStringRapid(t, 10)),
392  		PaymentRequest: []byte(
393  			randomStringRapid(t, MaxPaymentRequestSize),
394  		),
395  		CreationDate: created,
396  		SettleDate:   settled,
397  		Terms:        terms,
398  		AddIndex:     0,
399  		SettleIndex:  settleIndex,
400  		State:        state,
401  		AMPState:     make(map[SetID]InvoiceStateAMP),
402  		HodlInvoice:  rapid.Bool().Draw(t, "HodlInvoice"),
403  	}
404  
405  	invoice.Htlcs = make(map[models.CircuitKey]*InvoiceHTLC)
406  
407  	if invoice.IsAMP() {
408  		generateAMPHtlcsRapid(t, invoice)
409  	} else {
410  		generateInvoiceHTLCsRapid(t, invoice)
411  	}
412  
413  	for _, htlc := range invoice.Htlcs {
414  		if htlc.State == HtlcStateSettled {
415  			invoice.AmtPaid += htlc.Amt
416  		}
417  	}
418  
419  	return invoice
420  }