/ htlcswitch / failure_test.go
failure_test.go
 1  package htlcswitch
 2  
 3  import (
 4  	"encoding/hex"
 5  	"encoding/json"
 6  	"os"
 7  	"testing"
 8  
 9  	"github.com/btcsuite/btcd/btcec/v2"
10  	sphinx "github.com/lightningnetwork/lightning-onion"
11  	"github.com/lightningnetwork/lnd/lnwire"
12  	"github.com/lightningnetwork/lnd/tlv"
13  	"github.com/stretchr/testify/require"
14  )
15  
16  // TestLongFailureMessage tests that longer failure messages can be interpreted
17  // correctly.
18  func TestLongFailureMessage(t *testing.T) {
19  	t.Parallel()
20  
21  	var testData struct {
22  		SessionKey string   `json:"session_key"`
23  		Path       []string `json:"path"`
24  		Reason     string   `json:"reason"`
25  	}
26  
27  	// Use long 1024-byte test vector from BOLT 04.
28  	testDataBytes, err := os.ReadFile("testdata/long_failure_msg.json")
29  	require.NoError(t, err)
30  	require.NoError(t, json.Unmarshal(testDataBytes, &testData))
31  
32  	sessionKeyBytes, _ := hex.DecodeString(testData.SessionKey)
33  
34  	reason, _ := hex.DecodeString(testData.Reason)
35  
36  	sphinxPath := make([]*btcec.PublicKey, len(testData.Path))
37  	for i, sKey := range testData.Path {
38  		bKey, err := hex.DecodeString(sKey)
39  		require.NoError(t, err)
40  
41  		key, err := btcec.ParsePubKey(bKey)
42  		require.NoError(t, err)
43  
44  		sphinxPath[i] = key
45  	}
46  
47  	sessionKey, _ := btcec.PrivKeyFromBytes(sessionKeyBytes)
48  
49  	circuit := &sphinx.Circuit{
50  		SessionKey:  sessionKey,
51  		PaymentPath: sphinxPath,
52  	}
53  
54  	errorDecryptor := &SphinxErrorDecrypter{
55  		OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit),
56  	}
57  
58  	// Assert that the failure message can still be extracted.
59  	failure, err := errorDecryptor.DecryptError(reason)
60  	require.NoError(t, err)
61  
62  	incorrectDetails, ok := failure.msg.(*lnwire.FailIncorrectDetails)
63  	require.True(t, ok)
64  
65  	var value varBytesRecordProducer
66  
67  	extraData := incorrectDetails.ExtraOpaqueData()
68  	typeMap, err := extraData.ExtractRecords(&value)
69  	require.NoError(t, err)
70  	require.Len(t, typeMap, 1)
71  
72  	expectedValue := make([]byte, 300)
73  	for i := range expectedValue {
74  		expectedValue[i] = 128
75  	}
76  
77  	require.Equal(t, expectedValue, value.data)
78  }
79  
80  type varBytesRecordProducer struct {
81  	data []byte
82  }
83  
84  func (v *varBytesRecordProducer) Record() tlv.Record {
85  	return tlv.MakePrimitiveRecord(34001, &v.data)
86  }