/ onionmessage / test_utils.go
test_utils.go
  1  package onionmessage
  2  
  3  import (
  4  	"bytes"
  5  	"context"
  6  	"fmt"
  7  	"testing"
  8  
  9  	"github.com/btcsuite/btcd/btcec/v2"
 10  	sphinx "github.com/lightningnetwork/lightning-onion"
 11  	"github.com/lightningnetwork/lnd/lnwire"
 12  	"github.com/lightningnetwork/lnd/record"
 13  	"github.com/lightningnetwork/lnd/routing/route"
 14  	"github.com/lightningnetwork/lnd/tlv"
 15  	"github.com/stretchr/testify/require"
 16  )
 17  
 18  // mockNodeIDResolver implements NodeIDResolver for tests.
 19  type mockNodeIDResolver struct {
 20  	peers map[lnwire.ShortChannelID]*btcec.PublicKey
 21  }
 22  
 23  // addPeer registers a single SCID to pubkey mapping for tests.
 24  func (m *mockNodeIDResolver) addPeer(scid lnwire.ShortChannelID,
 25  	pubKey *btcec.PublicKey) {
 26  
 27  	m.peers[scid] = pubKey
 28  }
 29  
 30  // newMockNodeIDResolver creates a new instance of mockNodeIDResolver.
 31  func newMockNodeIDResolver() *mockNodeIDResolver {
 32  	return &mockNodeIDResolver{
 33  		peers: make(map[lnwire.ShortChannelID]*btcec.PublicKey),
 34  	}
 35  }
 36  
 37  // RemotePubFromSCID resolves a node public key from a short channel ID.
 38  func (m *mockNodeIDResolver) RemotePubFromSCID(_ context.Context,
 39  	scid lnwire.ShortChannelID) (*btcec.PublicKey, error) {
 40  
 41  	if pk, ok := m.peers[scid]; ok {
 42  		return pk, nil
 43  	}
 44  
 45  	return nil, fmt.Errorf("unknown scid: %v", scid)
 46  }
 47  
 48  // EncodeBlindedRouteData encodes BlindedRouteData to bytes for use in test
 49  // hop payloads.
 50  func EncodeBlindedRouteData(t *testing.T,
 51  	data *record.BlindedRouteData) []byte {
 52  
 53  	t.Helper()
 54  
 55  	buf, err := record.EncodeBlindedRouteData(data)
 56  	require.NoError(t, err)
 57  
 58  	return buf
 59  }
 60  
 61  // BuildBlindedPath creates a BlindedPathInfo from a list of HopInfo. This is a
 62  // test helper that wraps sphinx.BuildBlindedPath with a fresh session key.
 63  func BuildBlindedPath(t *testing.T,
 64  	hops []*sphinx.HopInfo) *sphinx.BlindedPathInfo {
 65  
 66  	t.Helper()
 67  
 68  	sessionKey, err := btcec.NewPrivateKey()
 69  	require.NoError(t, err)
 70  
 71  	blindedPath, err := sphinx.BuildBlindedPath(sessionKey, hops)
 72  	require.NoError(t, err)
 73  
 74  	return blindedPath
 75  }
 76  
 77  // ConcatBlindedPaths concatenates two blinded paths. The sender's path points
 78  // TO the introduction node (with NextBlindingOverride), and the receiver's
 79  // path starts AT the introduction node. The concatenated path includes all
 80  // hops from both paths - the sender's last hop instructs forwarding to the
 81  // intro node, and all receiver hops follow.
 82  func ConcatBlindedPaths(t *testing.T, senderPath,
 83  	receiverPath *sphinx.BlindedPathInfo) *sphinx.BlindedPathInfo {
 84  
 85  	t.Helper()
 86  
 87  	// The resulting path uses the sender's session key and introduction
 88  	// point but concatenates all blinded hops.
 89  	concatenated := &sphinx.BlindedPath{
 90  		IntroductionPoint: senderPath.Path.IntroductionPoint,
 91  		BlindingPoint:     senderPath.Path.BlindingPoint,
 92  		BlindedHops: append(
 93  			senderPath.Path.BlindedHops,
 94  			receiverPath.Path.BlindedHops...,
 95  		),
 96  	}
 97  
 98  	return &sphinx.BlindedPathInfo{
 99  		Path:             concatenated,
100  		SessionKey:       senderPath.SessionKey,
101  		LastEphemeralKey: receiverPath.LastEphemeralKey,
102  	}
103  }
104  
105  // BuildOnionMessage builds an onion message from a BlindedPathInfo and returns
106  // the message along with the ciphertexts for each blinded hop (in hop order).
107  // If finalPayloads is nil or empty, no final hop payload data is included.
108  func BuildOnionMessage(t *testing.T, blindedPath *sphinx.BlindedPathInfo,
109  	finalHopTLVs []*lnwire.FinalHopTLV) (*lnwire.OnionMessage,
110  	[][]byte) {
111  
112  	t.Helper()
113  
114  	// Convert the blinded path to a sphinx path and add final payloads.
115  	sphinxPath, err := route.OnionMessageBlindedPathToSphinxPath(
116  		blindedPath.Path, nil, finalHopTLVs,
117  	)
118  	require.NoError(t, err)
119  
120  	onionSessionKey, err := btcec.NewPrivateKey()
121  	require.NoError(t, err)
122  
123  	// Create an onion packet with no associated data.
124  	onionPkt, err := sphinx.NewOnionPacket(
125  		sphinxPath, onionSessionKey, nil,
126  		sphinx.DeterministicPacketFiller,
127  		sphinx.WithMaxPayloadSize(sphinx.MaxRoutingPayloadSize),
128  	)
129  	require.NoError(t, err)
130  
131  	// Encode the onion message packet.
132  	var buf bytes.Buffer
133  	require.NoError(t, onionPkt.Encode(&buf))
134  
135  	onionMsg := &lnwire.OnionMessage{
136  		PathKey:   blindedPath.SessionKey.PubKey(),
137  		OnionBlob: buf.Bytes(),
138  	}
139  
140  	var ctexts [][]byte
141  	for _, bh := range blindedPath.Path.BlindedHops {
142  		ctexts = append(ctexts, bh.CipherText)
143  	}
144  
145  	return onionMsg, ctexts
146  }
147  
148  // PeeledHop captures decrypted state for a single hop when peeling an onion.
149  type PeeledHop struct {
150  	EncryptedData []byte
151  	Payload       *lnwire.OnionMessagePayload
152  	IsFinal       bool
153  }
154  
155  // PeelOnionLayers sequentially processes an onion message, creating a fresh
156  // router for each hop using the provided private keys (one per hop), returning
157  // the encrypted data and decoded payload for each hop until the final hop.
158  func PeelOnionLayers(t *testing.T, privKeys []*btcec.PrivateKey,
159  	msg *lnwire.OnionMessage) []PeeledHop {
160  
161  	t.Helper()
162  
163  	var onionPkt sphinx.OnionPacket
164  	require.NoError(t, onionPkt.Decode(bytes.NewReader(msg.OnionBlob)))
165  
166  	currentPathKey := msg.PathKey
167  	var hops []PeeledHop
168  
169  	for i := 0; ; i++ {
170  		require.Less(t, i, len(privKeys), "more hops than privKeys")
171  
172  		router := sphinx.NewRouter(
173  			&sphinx.PrivKeyECDH{PrivKey: privKeys[i]},
174  			sphinx.NewNoOpReplayLog(),
175  		)
176  		require.NoError(t, router.Start())
177  
178  		processedPkt, err := router.ProcessOnionPacket(
179  			&onionPkt, nil, 10,
180  			sphinx.WithBlindingPoint(currentPathKey),
181  		)
182  		require.NoError(t, err)
183  
184  		payload := lnwire.NewOnionMessagePayload()
185  		_, err = payload.Decode(
186  			bytes.NewReader(processedPkt.Payload.Payload),
187  		)
188  		require.NoError(t, err)
189  
190  		origPayload := *payload
191  		origPayload.EncryptedData = bytes.Clone(payload.EncryptedData)
192  
193  		isFinal := processedPkt.Action == sphinx.ExitNode
194  		hops = append(hops, PeeledHop{
195  			EncryptedData: origPayload.EncryptedData,
196  			Payload:       &origPayload,
197  			IsFinal:       isFinal,
198  		})
199  
200  		if isFinal {
201  			router.Stop()
202  			break
203  		}
204  
205  		decrypted, err := router.DecryptBlindedHopData(
206  			currentPathKey, payload.EncryptedData,
207  		)
208  		require.NoError(t, err)
209  
210  		routeData, err := record.DecodeBlindedRouteData(
211  			bytes.NewReader(decrypted),
212  		)
213  		require.NoError(t, err)
214  
215  		nextPathKey := deriveNextPathKeyForTest(
216  			router, currentPathKey, routeData.NextBlindingOverride,
217  		)
218  		require.NotNil(t, nextPathKey)
219  
220  		router.Stop()
221  
222  		onionPkt = *processedPkt.NextPacket
223  		currentPathKey = nextPathKey
224  	}
225  
226  	return hops
227  }
228  
229  // deriveNextPathKeyForTest derives the next path key using the router and
230  // current path key. If an override is provided, it is used instead.
231  func deriveNextPathKeyForTest(router *sphinx.Router,
232  	currentPathKey *btcec.PublicKey,
233  	override tlv.OptionalRecordT[tlv.TlvType8,
234  		*btcec.PublicKey]) *btcec.PublicKey {
235  
236  	// If an override is provided, use it.
237  	return override.UnwrapOrFunc(func() tlv.RecordT[tlv.TlvType8,
238  		*btcec.PublicKey] {
239  
240  		// Otherwise, derive the next path key using the router.
241  		nextKey, err := router.NextEphemeral(currentPathKey)
242  		if err != nil {
243  			// If the derivation fails, return a zero key.
244  			return override.Zero()
245  		}
246  
247  		return tlv.NewPrimitiveRecord[tlv.TlvType8](nextKey)
248  	}).Val
249  }