sql_migration_test.go
1 package invoices 2 3 import ( 4 crand "crypto/rand" 5 "database/sql" 6 "math/rand" 7 "sync/atomic" 8 "testing" 9 "time" 10 11 "github.com/lightningnetwork/lnd/clock" 12 "github.com/lightningnetwork/lnd/graph/db/models" 13 "github.com/lightningnetwork/lnd/lntypes" 14 "github.com/lightningnetwork/lnd/lnwire" 15 "github.com/lightningnetwork/lnd/record" 16 "github.com/lightningnetwork/lnd/sqldb" 17 "github.com/stretchr/testify/require" 18 "pgregory.net/rapid" 19 ) 20 21 var ( 22 // testHtlcIDSequence is a global counter for generating unique HTLC 23 // IDs. 24 testHtlcIDSequence uint64 25 ) 26 27 // randomString generates a random string of a given length using rapid. 28 func randomStringRapid(t *rapid.T, length int) string { 29 // Define the character set for the string. 30 const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" //nolint:ll 31 32 // Generate a string by selecting random characters from the charset. 33 runes := make([]rune, length) 34 for i := range runes { 35 // Draw a random index and use it to select a character from the 36 // charset. 37 index := rapid.IntRange(0, len(charset)-1).Draw(t, "charIndex") 38 runes[i] = rune(charset[index]) 39 } 40 41 return string(runes) 42 } 43 44 // randTimeBetween generates a random time between min and max. 45 func randTimeBetween(minTime, maxTime time.Time) time.Time { 46 var timeZones = []*time.Location{ 47 time.UTC, 48 time.FixedZone("EST", -5*3600), 49 time.FixedZone("MST", -7*3600), 50 time.FixedZone("PST", -8*3600), 51 time.FixedZone("CEST", 2*3600), 52 } 53 54 // Ensure max is after min 55 if maxTime.Before(minTime) { 56 minTime, maxTime = maxTime, minTime 57 } 58 59 // Calculate the range in nanoseconds 60 duration := maxTime.Sub(minTime) 61 randDuration := time.Duration(rand.Int63n(duration.Nanoseconds())) 62 63 // Generate the random time 64 randomTime := minTime.Add(randDuration) 65 66 // Assign a random time zone 67 randomTimeZone := timeZones[rand.Intn(len(timeZones))] 68 69 // Return the time in the random time zone 70 return randomTime.In(randomTimeZone) 71 } 72 73 // randTime generates a random time between 2009 and 2140. 74 func randTime() time.Time { 75 minTime := time.Date(2009, 1, 3, 0, 0, 0, 0, time.UTC) 76 maxTime := time.Date(2140, 1, 1, 0, 0, 0, 1000, time.UTC) 77 78 return randTimeBetween(minTime, maxTime) 79 } 80 81 func randInvoiceTime(invoice *Invoice) time.Time { 82 return randTimeBetween( 83 invoice.CreationDate, 84 invoice.CreationDate.Add(invoice.Terms.Expiry), 85 ) 86 } 87 88 // randHTLCRapid generates a random HTLC for an invoice using rapid to randomize 89 // its parameters. 90 func randHTLCRapid(t *rapid.T, invoice *Invoice, amt lnwire.MilliSatoshi) ( 91 models.CircuitKey, *InvoiceHTLC) { 92 93 htlc := &InvoiceHTLC{ 94 Amt: amt, 95 AcceptHeight: rapid.Uint32Range(1, 999).Draw(t, "AcceptHeight"), 96 AcceptTime: randInvoiceTime(invoice), 97 Expiry: rapid.Uint32Range(1, 999).Draw(t, "Expiry"), 98 } 99 100 // Set MPP total amount if MPP feature is enabled in the invoice. 101 if invoice.Terms.Features.HasFeature(lnwire.MPPRequired) { 102 htlc.MppTotalAmt = invoice.Terms.Value 103 } 104 105 // Set the HTLC state and resolve time based on the invoice state. 106 switch invoice.State { 107 case ContractSettled: 108 htlc.State = HtlcStateSettled 109 htlc.ResolveTime = randInvoiceTime(invoice) 110 111 case ContractCanceled: 112 htlc.State = HtlcStateCanceled 113 htlc.ResolveTime = randInvoiceTime(invoice) 114 115 case ContractAccepted: 116 htlc.State = HtlcStateAccepted 117 } 118 119 // Add randomized custom records to the HTLC. 120 htlc.CustomRecords = make(record.CustomSet) 121 numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") 122 for i := 0; i < numRecords; i++ { 123 key := rapid.Uint64Range( 124 record.CustomTypeStart, 1000+record.CustomTypeStart, 125 ).Draw(t, "customRecordKey") 126 value := []byte(randomStringRapid(t, 10)) 127 htlc.CustomRecords[key] = value 128 } 129 130 // Generate a unique HTLC ID and assign it to a channel ID. 131 htlcID := atomic.AddUint64(&testHtlcIDSequence, 1) 132 randChanID := lnwire.NewShortChanIDFromInt(htlcID % 5) 133 134 circuitKey := models.CircuitKey{ 135 ChanID: randChanID, 136 HtlcID: htlcID, 137 } 138 139 return circuitKey, htlc 140 } 141 142 // generateInvoiceHTLCsRapid generates all HTLCs for an invoice, including AMP 143 // HTLCs if applicable, using rapid for randomization of HTLC count and 144 // distribution. 145 func generateInvoiceHTLCsRapid(t *rapid.T, invoice *Invoice) { 146 mpp := invoice.Terms.Features.HasFeature(lnwire.MPPRequired) 147 148 // Use rapid to determine the number of HTLCs based on invoice state and 149 // MPP feature. 150 numHTLCs := 1 151 if invoice.State == ContractOpen { 152 numHTLCs = 0 153 } else if mpp { 154 numHTLCs = rapid.IntRange(1, 10).Draw(t, "numHTLCs") 155 } 156 157 total := invoice.Terms.Value 158 159 // Distribute the total amount across the HTLCs, adding any remainder to 160 // the last HTLC. 161 if numHTLCs > 0 { 162 amt := total / lnwire.MilliSatoshi(numHTLCs) 163 remainder := total - amt*lnwire.MilliSatoshi(numHTLCs) 164 165 for i := 0; i < numHTLCs; i++ { 166 if i == numHTLCs-1 { 167 // Add remainder to the last HTLC. 168 amt += remainder 169 } 170 171 // Generate an HTLC with a random circuit key and add it 172 // to the invoice. 173 circuitKey, htlc := randHTLCRapid(t, invoice, amt) 174 invoice.Htlcs[circuitKey] = htlc 175 } 176 } 177 } 178 179 // generateAMPHtlcsRapid generates AMP HTLCs for an invoice using rapid to 180 // randomize various parameters of the HTLCs in the AMP set. 181 func generateAMPHtlcsRapid(t *rapid.T, invoice *Invoice) { 182 // Randomly determine the number of AMP sets (1 to 5). 183 numSetIDs := rapid.IntRange(1, 5).Draw(t, "numSetIDs") 184 settledIdx := uint64(1) 185 186 for i := 0; i < numSetIDs; i++ { 187 var setID SetID 188 _, err := crand.Read(setID[:]) 189 require.NoError(t, err) 190 191 // Determine the number of HTLCs in this set (1 to 5). 192 numHTLCs := rapid.IntRange(1, 5).Draw(t, "numHTLCs") 193 total := invoice.Terms.Value 194 invoiceKeys := make(map[CircuitKey]struct{}) 195 196 // Calculate the amount per HTLC and account for remainder in 197 // the final HTLC. 198 amt := total / lnwire.MilliSatoshi(numHTLCs) 199 remainder := total - amt*lnwire.MilliSatoshi(numHTLCs) 200 201 var htlcState HtlcState 202 for j := 0; j < numHTLCs; j++ { 203 if j == numHTLCs-1 { 204 amt += remainder 205 } 206 207 // Generate HTLC with randomized parameters. 208 circuitKey, htlc := randHTLCRapid(t, invoice, amt) 209 htlcState = htlc.State 210 211 var ( 212 rootShare, hash [32]byte 213 preimage lntypes.Preimage 214 ) 215 216 // Randomize AMP data fields. 217 _, err := crand.Read(rootShare[:]) 218 require.NoError(t, err) 219 _, err = crand.Read(hash[:]) 220 require.NoError(t, err) 221 _, err = crand.Read(preimage[:]) 222 require.NoError(t, err) 223 224 record := record.NewAMP(rootShare, setID, uint32(j)) 225 226 htlc.AMP = &InvoiceHtlcAMPData{ 227 Record: *record, 228 Hash: hash, 229 Preimage: &preimage, 230 } 231 232 invoice.Htlcs[circuitKey] = htlc 233 invoiceKeys[circuitKey] = struct{}{} 234 } 235 236 ampState := InvoiceStateAMP{ 237 State: htlcState, 238 InvoiceKeys: invoiceKeys, 239 } 240 if htlcState == HtlcStateSettled { 241 ampState.SettleIndex = settledIdx 242 ampState.SettleDate = randInvoiceTime(invoice) 243 settledIdx++ 244 } 245 246 // Set the total amount paid if the AMP set is not canceled. 247 if htlcState != HtlcStateCanceled { 248 ampState.AmtPaid = invoice.Terms.Value 249 } 250 251 invoice.AMPState[setID] = ampState 252 } 253 } 254 255 // TestMigrateSingleInvoiceRapid tests the migration of single invoices with 256 // random data variations using rapid. This test generates a random invoice 257 // configuration and ensures successful migration. 258 // 259 // NOTE: This test may need to be changed if the Invoice or any of the related 260 // types are modified. 261 func TestMigrateSingleInvoiceRapid(t *testing.T) { 262 // Create a shared Postgres instance for efficient testing. 263 pgFixture := sqldb.NewTestPgFixture( 264 t, sqldb.DefaultPostgresFixtureLifetime, 265 ) 266 t.Cleanup(func() { 267 pgFixture.TearDown(t) 268 }) 269 270 makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore { 271 var db *sqldb.BaseDB 272 if sqlite { 273 db = sqldb.NewTestSqliteDB(t).BaseDB 274 } else { 275 db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB 276 } 277 278 executor := sqldb.NewTransactionExecutor( 279 db, func(tx *sql.Tx) SQLInvoiceQueries { 280 return db.WithTx(tx) 281 }, 282 ) 283 284 testClock := clock.NewTestClock(time.Unix(1, 0)) 285 286 return NewSQLStore(executor, testClock) 287 } 288 289 // Define property-based test using rapid. 290 rapid.Check(t, func(rt *rapid.T) { 291 // Randomized feature flags for MPP and AMP. 292 mpp := rapid.Bool().Draw(rt, "mpp") 293 amp := rapid.Bool().Draw(rt, "amp") 294 295 for _, sqlite := range []bool{true, false} { 296 store := makeSQLDB(t, sqlite) 297 testMigrateSingleInvoiceRapid(rt, store, mpp, amp) 298 } 299 }) 300 } 301 302 // testMigrateSingleInvoiceRapid is the primary function for the migration of a 303 // single invoice with random data in a rapid-based test setup. 304 func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool, 305 amp bool) { 306 307 ctxb := t.Context() 308 invoices := make(map[lntypes.Hash]*Invoice) 309 310 for i := 0; i < 100; i++ { 311 invoice := generateTestInvoiceRapid(t, mpp, amp) 312 var hash lntypes.Hash 313 _, err := crand.Read(hash[:]) 314 require.NoError(t, err) 315 316 invoices[hash] = invoice 317 } 318 319 ops := sqldb.WriteTxOpt() 320 err := store.db.ExecTx(ctxb, ops, func(tx SQLInvoiceQueries) error { 321 for hash, invoice := range invoices { 322 err := MigrateSingleInvoice(ctxb, tx, invoice, hash) 323 require.NoError(t, err) 324 } 325 326 return nil 327 }, sqldb.NoOpReset) 328 require.NoError(t, err) 329 330 // Fetch and compare each migrated invoice from the store with the 331 // original. 332 for hash, invoice := range invoices { 333 sqlInvoice, err := store.LookupInvoice( 334 ctxb, InvoiceRefByHash(hash), 335 ) 336 require.NoError(t, err) 337 338 invoice.AddIndex = sqlInvoice.AddIndex 339 340 OverrideInvoiceTimeZone(invoice) 341 OverrideInvoiceTimeZone(&sqlInvoice) 342 343 require.Equal(t, *invoice, sqlInvoice) 344 } 345 } 346 347 // generateTestInvoiceRapid generates a random invoice with variations based on 348 // mpp and amp flags. 349 func generateTestInvoiceRapid(t *rapid.T, mpp bool, amp bool) *Invoice { 350 var preimage lntypes.Preimage 351 _, err := crand.Read(preimage[:]) 352 require.NoError(t, err) 353 354 terms := ContractTerm{ 355 FinalCltvDelta: rapid.Int32Range(1, 1000).Draw( 356 t, "FinalCltvDelta", 357 ), 358 Expiry: time.Duration( 359 rapid.IntRange(1, 4444).Draw(t, "Expiry"), 360 ) * time.Minute, 361 PaymentPreimage: &preimage, 362 Value: lnwire.MilliSatoshi( 363 rapid.Int64Range(1, 9999999).Draw(t, "Value"), 364 ), 365 PaymentAddr: [32]byte{}, 366 Features: lnwire.EmptyFeatureVector(), 367 } 368 369 if amp { 370 terms.Features.Set(lnwire.AMPRequired) 371 } else if mpp { 372 terms.Features.Set(lnwire.MPPRequired) 373 } 374 375 created := randTime() 376 377 const maxContractState = 3 378 state := ContractState( 379 rapid.IntRange(0, maxContractState).Draw(t, "ContractState"), 380 ) 381 var ( 382 settled time.Time 383 settleIndex uint64 384 ) 385 if state == ContractSettled { 386 settled = randTimeBetween(created, created.Add(terms.Expiry)) 387 settleIndex = rapid.Uint64Range(1, 999).Draw(t, "SettleIndex") 388 } 389 390 invoice := &Invoice{ 391 Memo: []byte(randomStringRapid(t, 10)), 392 PaymentRequest: []byte( 393 randomStringRapid(t, MaxPaymentRequestSize), 394 ), 395 CreationDate: created, 396 SettleDate: settled, 397 Terms: terms, 398 AddIndex: 0, 399 SettleIndex: settleIndex, 400 State: state, 401 AMPState: make(map[SetID]InvoiceStateAMP), 402 HodlInvoice: rapid.Bool().Draw(t, "HodlInvoice"), 403 } 404 405 invoice.Htlcs = make(map[models.CircuitKey]*InvoiceHTLC) 406 407 if invoice.IsAMP() { 408 generateAMPHtlcsRapid(t, invoice) 409 } else { 410 generateInvoiceHTLCsRapid(t, invoice) 411 } 412 413 for _, htlc := range invoice.Htlcs { 414 if htlc.State == HtlcStateSettled { 415 invoice.AmtPaid += htlc.Amt 416 } 417 } 418 419 return invoice 420 }