Skip to content

Commit 519807c

Browse files
authored
Merge pull request #10665 from calvinrzachman/decode-hop-iterator-blinding
htlcswitch: re-add single decode hop iterator
2 parents 778f4f9 + 24f2b24 commit 519807c

4 files changed

Lines changed: 246 additions & 2 deletions

File tree

htlcswitch/hop/iterator.go

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ func (p *OnionProcessor) ReconstructHopIterator(r io.Reader, rHash []byte,
690690

691691
// Attempt to process the Sphinx packet. We include the payment hash of
692692
// the HTLC as it's authenticated within the Sphinx packet itself as
693-
// associated data in order to thwart attempts a replay attacks. In the
693+
// associated data in order to thwart replay attacks. In the
694694
// case of a replay, an attacker is *forced* to use the same payment
695695
// hash twice, thereby losing their money entirely.
696696
sphinxPacket, err := p.router.ReconstructOnionPacket(
@@ -737,6 +737,69 @@ func (r *DecodeHopIteratorResponse) Result() (Iterator, lnwire.FailCode) {
737737
return r.HopIterator, r.FailCode
738738
}
739739

740+
// DecodeHopIterator attempts to decode a valid sphinx packet from the passed
741+
// io.Reader instance using the rHash as the associated data when checking the
742+
// relevant MACs during the decoding process.
743+
func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte,
744+
incomingCltv uint32, incomingAmount lnwire.MilliSatoshi,
745+
blindingPoint lnwire.BlindingPointRecord) (Iterator, lnwire.FailCode) {
746+
747+
onionPkt := &sphinx.OnionPacket{}
748+
if err := onionPkt.Decode(r); err != nil {
749+
switch {
750+
case errors.Is(err, sphinx.ErrInvalidOnionVersion):
751+
return nil, lnwire.CodeInvalidOnionVersion
752+
case errors.Is(err, sphinx.ErrInvalidOnionKey):
753+
return nil, lnwire.CodeInvalidOnionKey
754+
default:
755+
log.Errorf("unable to decode onion packet: %v", err)
756+
return nil, lnwire.CodeInvalidOnionKey
757+
}
758+
}
759+
760+
// If a blinding point was provided in the update_add_htlc message,
761+
// pass it through so the sphinx router can derive the correct shared
762+
// secret for blinded hops.
763+
var opts []sphinx.ProcessOnionOpt
764+
blindingPoint.WhenSome(func(
765+
b tlv.RecordT[lnwire.BlindingPointTlvType,
766+
*btcec.PublicKey]) {
767+
768+
opts = append(opts, sphinx.WithBlindingPoint(b.Val))
769+
})
770+
771+
// Attempt to process the Sphinx packet. We include the payment hash of
772+
// the HTLC as it's authenticated within the Sphinx packet itself as
773+
// associated data in order to thwart replay attacks. In the
774+
// case of a replay, an attacker is *forced* to use the same payment
775+
// hash twice, thereby losing their money entirely.
776+
sphinxPacket, err := p.router.ProcessOnionPacket(
777+
onionPkt, rHash, incomingCltv, opts...,
778+
)
779+
if err != nil {
780+
switch {
781+
case errors.Is(err, sphinx.ErrInvalidOnionVersion):
782+
return nil, lnwire.CodeInvalidOnionVersion
783+
case errors.Is(err, sphinx.ErrInvalidOnionHMAC):
784+
return nil, lnwire.CodeInvalidOnionHmac
785+
case errors.Is(err, sphinx.ErrInvalidOnionKey):
786+
return nil, lnwire.CodeInvalidOnionKey
787+
default:
788+
log.Errorf("unable to process onion packet: %v", err)
789+
return nil, lnwire.CodeInvalidOnionKey
790+
}
791+
}
792+
793+
return makeSphinxHopIterator(p.router, onionPkt, sphinxPacket,
794+
BlindingKit{
795+
Processor: p.router,
796+
UpdateAddBlinding: blindingPoint,
797+
IncomingAmount: incomingAmount,
798+
IncomingCltv: incomingCltv,
799+
}, rHash,
800+
), lnwire.CodeNone
801+
}
802+
740803
// DecodeHopIterators performs batched decoding and validation of incoming
741804
// sphinx packets. For the same `id`, this method will return the same iterators
742805
// and failcodes upon subsequent invocations.

htlcswitch/hop/iterator_test.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,170 @@ func TestSphinxHopIteratorForwardingInstructions(t *testing.T) {
103103
}
104104
}
105105

106+
// TestDecodeHopIterator tests that DecodeHopIterator can successfully process
107+
// a real onion packet constructed by the sphinx library and return a valid hop
108+
// iterator with the correct forwarding information. It also tests various error
109+
// cases such as truncated packets and corrupted HMACs.
110+
func TestDecodeHopIterator(t *testing.T) {
111+
t.Parallel()
112+
113+
// Generate a fresh private key for our onion processor (the
114+
// "receiving" node).
115+
receiverPrivKey, err := btcec.NewPrivateKey()
116+
require.NoError(t, err)
117+
118+
sphinxRouter := sphinx.NewRouter(
119+
&sphinx.PrivKeyECDH{PrivKey: receiverPrivKey},
120+
sphinx.NewNoOpReplayLog(),
121+
)
122+
require.NoError(t, sphinxRouter.Start())
123+
defer sphinxRouter.Stop()
124+
125+
processor := NewOnionProcessor(sphinxRouter)
126+
127+
// Session key used by the "sender" to construct the onion.
128+
sessionKey, err := btcec.NewPrivateKey()
129+
require.NoError(t, err)
130+
131+
// Build a TLV payload for the final hop with amount and CLTV.
132+
var (
133+
fwdAmt uint64 = 500_000
134+
outgoingCltv uint32 = 144
135+
incomingCltv uint32 = 200
136+
incomingAmt lnwire.MilliSatoshi = 600_000
137+
noBlinding lnwire.BlindingPointRecord
138+
)
139+
var payloadBuf bytes.Buffer
140+
tlvRecords := []tlv.Record{
141+
record.NewAmtToFwdRecord(&fwdAmt),
142+
record.NewLockTimeRecord(&outgoingCltv),
143+
}
144+
tlvStream, err := tlv.NewStream(tlvRecords...)
145+
require.NoError(t, err)
146+
require.NoError(t, tlvStream.Encode(&payloadBuf))
147+
148+
// Build a one-hop payment path to the receiver.
149+
var path sphinx.PaymentPath
150+
path[0] = sphinx.OnionHop{
151+
NodePub: *receiverPrivKey.PubKey(),
152+
HopPayload: sphinx.HopPayload{
153+
Type: sphinx.PayloadTLV,
154+
Payload: payloadBuf.Bytes(),
155+
},
156+
}
157+
158+
// Create the onion packet.
159+
rHash := [32]byte{0xaa, 0xbb, 0xcc}
160+
onionPkt, err := sphinx.NewOnionPacket(
161+
&path, sessionKey, rHash[:],
162+
sphinx.DeterministicPacketFiller,
163+
)
164+
require.NoError(t, err)
165+
166+
// serializeOnion is a helper that encodes an onion packet to bytes.
167+
serializeOnion := func(pkt *sphinx.OnionPacket) []byte {
168+
var buf bytes.Buffer
169+
require.NoError(t, pkt.Encode(&buf))
170+
return buf.Bytes()
171+
}
172+
173+
validOnionBytes := serializeOnion(onionPkt)
174+
175+
tests := []struct {
176+
name string
177+
onionBytes []byte
178+
rHash []byte
179+
expectedFail lnwire.FailCode
180+
checkPayload bool
181+
}{
182+
{
183+
name: "valid onion",
184+
onionBytes: validOnionBytes,
185+
rHash: rHash[:],
186+
expectedFail: lnwire.CodeNone,
187+
checkPayload: true,
188+
},
189+
{
190+
name: "truncated packet",
191+
onionBytes: validOnionBytes[:10],
192+
rHash: rHash[:],
193+
expectedFail: lnwire.CodeInvalidOnionKey,
194+
},
195+
{
196+
name: "empty reader",
197+
onionBytes: []byte{},
198+
rHash: rHash[:],
199+
expectedFail: lnwire.CodeInvalidOnionKey,
200+
},
201+
{
202+
name: "corrupted HMAC",
203+
onionBytes: func() []byte {
204+
corrupted := make([]byte, len(validOnionBytes))
205+
copy(corrupted, validOnionBytes)
206+
// Flip a byte in the HMAC (last 32 bytes of
207+
// the packet).
208+
corrupted[len(corrupted)-1] ^= 0xff
209+
210+
return corrupted
211+
}(),
212+
rHash: rHash[:],
213+
expectedFail: lnwire.CodeInvalidOnionHmac,
214+
},
215+
{
216+
name: "wrong payment hash",
217+
onionBytes: validOnionBytes,
218+
rHash: bytes.Repeat([]byte{0xff}, 32),
219+
expectedFail: lnwire.CodeInvalidOnionHmac,
220+
},
221+
{
222+
name: "invalid version byte",
223+
onionBytes: func() []byte {
224+
corrupted := make([]byte, len(validOnionBytes))
225+
copy(corrupted, validOnionBytes)
226+
// Set an invalid version (first byte).
227+
corrupted[0] = 0xff
228+
229+
return corrupted
230+
}(),
231+
rHash: rHash[:],
232+
expectedFail: lnwire.CodeInvalidOnionVersion,
233+
},
234+
}
235+
236+
for _, tc := range tests {
237+
t.Run(tc.name, func(t *testing.T) {
238+
t.Parallel()
239+
240+
reader := bytes.NewReader(tc.onionBytes)
241+
iterator, failCode := processor.DecodeHopIterator(
242+
reader, tc.rHash, incomingCltv, incomingAmt,
243+
noBlinding,
244+
)
245+
246+
require.Equal(t, tc.expectedFail, failCode)
247+
248+
if !tc.checkPayload {
249+
return
250+
}
251+
252+
require.NotNil(t, iterator)
253+
254+
payload, role, err := iterator.HopPayload()
255+
require.NoError(t, err)
256+
require.Equal(t, RouteRoleCleartext, role)
257+
258+
fwdInfo := payload.ForwardingInfo()
259+
require.Equal(
260+
t, lnwire.MilliSatoshi(fwdAmt),
261+
fwdInfo.AmountToForward,
262+
)
263+
require.Equal(
264+
t, outgoingCltv, fwdInfo.OutgoingCTLV,
265+
)
266+
})
267+
}
268+
}
269+
106270
// TestForwardingAmountCalc tests calculation of forwarding amounts from the
107271
// hop's forwarding parameters.
108272
func TestForwardingAmountCalc(t *testing.T) {

htlcswitch/hop/log.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,28 @@ package hop
22

33
import (
44
"github.com/btcsuite/btclog/v2"
5+
"github.com/lightningnetwork/lnd/build"
56
)
67

8+
// Subsystem defines the logging sub system name of this package.
9+
const Subsystem = "HOPS"
10+
711
// log is a logger that is initialized with no output filters. This
812
// means the package will not perform any logging by default until the caller
913
// requests it.
1014
var log btclog.Logger
1115

16+
// The default amount of logging is none.
17+
func init() {
18+
UseLogger(build.NewSubLogger(Subsystem, nil))
19+
}
20+
21+
// DisableLog disables all library log output. Logging output is disabled
22+
// by default until UseLogger is called.
23+
func DisableLog() {
24+
UseLogger(btclog.Disabled)
25+
}
26+
1227
// UseLogger uses a specified Logger to output package logging info. This
1328
// function is called from the parent package htlcswitch logger initialization.
1429
func UseLogger(logger btclog.Logger) {

htlcswitch/mock.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,8 @@ func newMockIteratorDecoder() *mockIteratorDecoder {
491491
}
492492

493493
func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte,
494-
cltv uint32) (hop.Iterator, lnwire.FailCode) {
494+
cltv uint32, _ lnwire.MilliSatoshi,
495+
_ lnwire.BlindingPointRecord) (hop.Iterator, lnwire.FailCode) {
495496

496497
var b [4]byte
497498
_, err := r.Read(b[:])
@@ -540,6 +541,7 @@ func (p *mockIteratorDecoder) DecodeHopIterators(id []byte,
540541
for _, req := range reqs {
541542
iterator, failcode := p.DecodeHopIterator(
542543
req.OnionReader, req.RHash, req.IncomingCltv,
544+
req.IncomingAmount, req.BlindingPoint,
543545
)
544546

545547
if p.decodeFail {

0 commit comments

Comments
 (0)