/ onionmessage / test_utils.go
test_utils.go
1 package onionmessage 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "testing" 8 9 "github.com/btcsuite/btcd/btcec/v2" 10 sphinx "github.com/lightningnetwork/lightning-onion" 11 "github.com/lightningnetwork/lnd/lnwire" 12 "github.com/lightningnetwork/lnd/record" 13 "github.com/lightningnetwork/lnd/routing/route" 14 "github.com/lightningnetwork/lnd/tlv" 15 "github.com/stretchr/testify/require" 16 ) 17 18 // mockNodeIDResolver implements NodeIDResolver for tests. 19 type mockNodeIDResolver struct { 20 peers map[lnwire.ShortChannelID]*btcec.PublicKey 21 } 22 23 // addPeer registers a single SCID to pubkey mapping for tests. 24 func (m *mockNodeIDResolver) addPeer(scid lnwire.ShortChannelID, 25 pubKey *btcec.PublicKey) { 26 27 m.peers[scid] = pubKey 28 } 29 30 // newMockNodeIDResolver creates a new instance of mockNodeIDResolver. 31 func newMockNodeIDResolver() *mockNodeIDResolver { 32 return &mockNodeIDResolver{ 33 peers: make(map[lnwire.ShortChannelID]*btcec.PublicKey), 34 } 35 } 36 37 // RemotePubFromSCID resolves a node public key from a short channel ID. 38 func (m *mockNodeIDResolver) RemotePubFromSCID(_ context.Context, 39 scid lnwire.ShortChannelID) (*btcec.PublicKey, error) { 40 41 if pk, ok := m.peers[scid]; ok { 42 return pk, nil 43 } 44 45 return nil, fmt.Errorf("unknown scid: %v", scid) 46 } 47 48 // EncodeBlindedRouteData encodes BlindedRouteData to bytes for use in test 49 // hop payloads. 50 func EncodeBlindedRouteData(t *testing.T, 51 data *record.BlindedRouteData) []byte { 52 53 t.Helper() 54 55 buf, err := record.EncodeBlindedRouteData(data) 56 require.NoError(t, err) 57 58 return buf 59 } 60 61 // BuildBlindedPath creates a BlindedPathInfo from a list of HopInfo. This is a 62 // test helper that wraps sphinx.BuildBlindedPath with a fresh session key. 63 func BuildBlindedPath(t *testing.T, 64 hops []*sphinx.HopInfo) *sphinx.BlindedPathInfo { 65 66 t.Helper() 67 68 sessionKey, err := btcec.NewPrivateKey() 69 require.NoError(t, err) 70 71 blindedPath, err := sphinx.BuildBlindedPath(sessionKey, hops) 72 require.NoError(t, err) 73 74 return blindedPath 75 } 76 77 // ConcatBlindedPaths concatenates two blinded paths. The sender's path points 78 // TO the introduction node (with NextBlindingOverride), and the receiver's 79 // path starts AT the introduction node. The concatenated path includes all 80 // hops from both paths - the sender's last hop instructs forwarding to the 81 // intro node, and all receiver hops follow. 82 func ConcatBlindedPaths(t *testing.T, senderPath, 83 receiverPath *sphinx.BlindedPathInfo) *sphinx.BlindedPathInfo { 84 85 t.Helper() 86 87 // The resulting path uses the sender's session key and introduction 88 // point but concatenates all blinded hops. 89 concatenated := &sphinx.BlindedPath{ 90 IntroductionPoint: senderPath.Path.IntroductionPoint, 91 BlindingPoint: senderPath.Path.BlindingPoint, 92 BlindedHops: append( 93 senderPath.Path.BlindedHops, 94 receiverPath.Path.BlindedHops..., 95 ), 96 } 97 98 return &sphinx.BlindedPathInfo{ 99 Path: concatenated, 100 SessionKey: senderPath.SessionKey, 101 LastEphemeralKey: receiverPath.LastEphemeralKey, 102 } 103 } 104 105 // BuildOnionMessage builds an onion message from a BlindedPathInfo and returns 106 // the message along with the ciphertexts for each blinded hop (in hop order). 107 // If finalPayloads is nil or empty, no final hop payload data is included. 108 func BuildOnionMessage(t *testing.T, blindedPath *sphinx.BlindedPathInfo, 109 finalHopTLVs []*lnwire.FinalHopTLV) (*lnwire.OnionMessage, 110 [][]byte) { 111 112 t.Helper() 113 114 // Convert the blinded path to a sphinx path and add final payloads. 115 sphinxPath, err := route.OnionMessageBlindedPathToSphinxPath( 116 blindedPath.Path, nil, finalHopTLVs, 117 ) 118 require.NoError(t, err) 119 120 onionSessionKey, err := btcec.NewPrivateKey() 121 require.NoError(t, err) 122 123 // Create an onion packet with no associated data. 124 onionPkt, err := sphinx.NewOnionPacket( 125 sphinxPath, onionSessionKey, nil, 126 sphinx.DeterministicPacketFiller, 127 sphinx.WithMaxPayloadSize(sphinx.MaxRoutingPayloadSize), 128 ) 129 require.NoError(t, err) 130 131 // Encode the onion message packet. 132 var buf bytes.Buffer 133 require.NoError(t, onionPkt.Encode(&buf)) 134 135 onionMsg := &lnwire.OnionMessage{ 136 PathKey: blindedPath.SessionKey.PubKey(), 137 OnionBlob: buf.Bytes(), 138 } 139 140 var ctexts [][]byte 141 for _, bh := range blindedPath.Path.BlindedHops { 142 ctexts = append(ctexts, bh.CipherText) 143 } 144 145 return onionMsg, ctexts 146 } 147 148 // PeeledHop captures decrypted state for a single hop when peeling an onion. 149 type PeeledHop struct { 150 EncryptedData []byte 151 Payload *lnwire.OnionMessagePayload 152 IsFinal bool 153 } 154 155 // PeelOnionLayers sequentially processes an onion message, creating a fresh 156 // router for each hop using the provided private keys (one per hop), returning 157 // the encrypted data and decoded payload for each hop until the final hop. 158 func PeelOnionLayers(t *testing.T, privKeys []*btcec.PrivateKey, 159 msg *lnwire.OnionMessage) []PeeledHop { 160 161 t.Helper() 162 163 var onionPkt sphinx.OnionPacket 164 require.NoError(t, onionPkt.Decode(bytes.NewReader(msg.OnionBlob))) 165 166 currentPathKey := msg.PathKey 167 var hops []PeeledHop 168 169 for i := 0; ; i++ { 170 require.Less(t, i, len(privKeys), "more hops than privKeys") 171 172 router := sphinx.NewRouter( 173 &sphinx.PrivKeyECDH{PrivKey: privKeys[i]}, 174 sphinx.NewNoOpReplayLog(), 175 ) 176 require.NoError(t, router.Start()) 177 178 processedPkt, err := router.ProcessOnionPacket( 179 &onionPkt, nil, 10, 180 sphinx.WithBlindingPoint(currentPathKey), 181 ) 182 require.NoError(t, err) 183 184 payload := lnwire.NewOnionMessagePayload() 185 _, err = payload.Decode( 186 bytes.NewReader(processedPkt.Payload.Payload), 187 ) 188 require.NoError(t, err) 189 190 origPayload := *payload 191 origPayload.EncryptedData = bytes.Clone(payload.EncryptedData) 192 193 isFinal := processedPkt.Action == sphinx.ExitNode 194 hops = append(hops, PeeledHop{ 195 EncryptedData: origPayload.EncryptedData, 196 Payload: &origPayload, 197 IsFinal: isFinal, 198 }) 199 200 if isFinal { 201 router.Stop() 202 break 203 } 204 205 decrypted, err := router.DecryptBlindedHopData( 206 currentPathKey, payload.EncryptedData, 207 ) 208 require.NoError(t, err) 209 210 routeData, err := record.DecodeBlindedRouteData( 211 bytes.NewReader(decrypted), 212 ) 213 require.NoError(t, err) 214 215 nextPathKey := deriveNextPathKeyForTest( 216 router, currentPathKey, routeData.NextBlindingOverride, 217 ) 218 require.NotNil(t, nextPathKey) 219 220 router.Stop() 221 222 onionPkt = *processedPkt.NextPacket 223 currentPathKey = nextPathKey 224 } 225 226 return hops 227 } 228 229 // deriveNextPathKeyForTest derives the next path key using the router and 230 // current path key. If an override is provided, it is used instead. 231 func deriveNextPathKeyForTest(router *sphinx.Router, 232 currentPathKey *btcec.PublicKey, 233 override tlv.OptionalRecordT[tlv.TlvType8, 234 *btcec.PublicKey]) *btcec.PublicKey { 235 236 // If an override is provided, use it. 237 return override.UnwrapOrFunc(func() tlv.RecordT[tlv.TlvType8, 238 *btcec.PublicKey] { 239 240 // Otherwise, derive the next path key using the router. 241 nextKey, err := router.NextEphemeral(currentPathKey) 242 if err != nil { 243 // If the derivation fails, return a zero key. 244 return override.Zero() 245 } 246 247 return tlv.NewPrimitiveRecord[tlv.TlvType8](nextKey) 248 }).Val 249 }