Skip to content

Commit 0e15be3

Browse files
committed
Handle attribution data for trampoline payments
The trampoline node must unwrap the attribution data with its shared secrets and use what's remaining as the attribution data from the next trampoline node.
1 parent d405c1a commit 0e15be3

23 files changed

+179
-147
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,10 @@ object Sphinx extends Logging {
280280
/**
281281
* The downstream failure could not be decrypted.
282282
*
283-
* @param unwrapped encrypted failure packet after unwrapping using our shared secrets.
283+
* @param unwrapped encrypted failure packet after unwrapping using our shared secrets.
284+
* @param attribution_opt attribution data after unwrapping using our shared secrets
284285
*/
285-
case class CannotDecryptFailurePacket(unwrapped: ByteVector)
286+
case class CannotDecryptFailurePacket(unwrapped: ByteVector, attribution_opt: Option[ByteVector])
286287

287288
case class HoldTime(duration: FiniteDuration, remoteNodeId: PublicKey)
288289

@@ -334,16 +335,16 @@ object Sphinx extends Logging {
334335
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
335336
* failure packet otherwise.
336337
*/
337-
def decrypt(packet: ByteVector, attribution_opt: Option[ByteVector], sharedSecrets: Seq[SharedSecret], hopIndex: Int = 0): HtlcFailure = {
338+
def decrypt(packet: ByteVector, attribution_opt: Option[ByteVector], sharedSecrets: Seq[SharedSecret]): HtlcFailure = {
338339
sharedSecrets match {
339-
case Nil => HtlcFailure(Nil, Left(CannotDecryptFailurePacket(packet)))
340+
case Nil => HtlcFailure(Nil, Left(CannotDecryptFailurePacket(packet, attribution_opt)))
340341
case ss :: tail =>
341342
val packet1 = wrap(packet, ss.secret)
342-
val attribution1_opt = attribution_opt.flatMap(Attribution.unwrap(_, packet1, ss.secret, hopIndex))
343+
val attribution1_opt = attribution_opt.flatMap(Attribution.unwrap(_, packet1, ss.secret, sharedSecrets.length))
343344
val um = generateKey("um", ss.secret)
344345
val HtlcFailure(downstreamHoldTimes, failure) = FailureMessageCodecs.failureOnionCodec(Hmac256(um)).decode(packet1.toBitVector) match {
345346
case Attempt.Successful(value) => HtlcFailure(Nil, Right(DecryptedFailurePacket(ss.remoteNodeId, value.value)))
346-
case _ => decrypt(packet1, attribution1_opt.map(_._2), tail, hopIndex + 1)
347+
case _ => decrypt(packet1, attribution1_opt.map(_._2), tail)
347348
}
348349
HtlcFailure(attribution1_opt.map(n => HoldTime(n._1, ss.remoteNodeId) +: downstreamHoldTimes).getOrElse(Nil), failure)
349350
}
@@ -389,11 +390,11 @@ object Sphinx extends Logging {
389390
}))
390391

391392
/**
392-
* Computes the HMACs for the node that is `minNumHop` hops away from us. Hence we only compute `maxNumHops - minNumHop` HMACs.
393+
* Computes the HMACs for the node that is `maxNumHops - remainingHops` hops away from us. Hence we only compute `remainingHops` HMACs.
393394
* HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
394395
*/
395-
private def computeHmacs(mac: Mac32, failurePacket: ByteVector, holdTimes: ByteVector, hmacs: Seq[Seq[ByteVector]], minNumHop: Int): Seq[ByteVector] = {
396-
(minNumHop until maxNumHops).map(i => {
396+
private def computeHmacs(mac: Mac32, failurePacket: ByteVector, holdTimes: ByteVector, hmacs: Seq[Seq[ByteVector]], remainingHops: Int): Seq[ByteVector] = {
397+
((maxNumHops - remainingHops) until maxNumHops).map(i => {
397398
val y = maxNumHops - i
398399
mac.mac(failurePacket ++
399400
holdTimes.take(y * holdTimeLength) ++
@@ -411,38 +412,41 @@ object Sphinx extends Logging {
411412
val previousHmacs = getHmacs(previousAttribution).dropRight(1).map(_.drop(1))
412413
val mac = Hmac256(generateKey("um", sharedSecret))
413414
val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take((maxNumHops - 1) * holdTimeLength)
414-
val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector.empty), holdTimes, previousHmacs, 0) +: previousHmacs
415+
val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector.empty), holdTimes, previousHmacs, maxNumHops) +: previousHmacs
415416
cipher(holdTimes ++ ByteVector.concat(hmacs.map(ByteVector.concat(_))), sharedSecret)
416417
}
417418

418419
/**
419420
* Unwrap one hop of attribution data
420421
* @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
421422
*/
422-
def unwrap(encrypted: ByteVector, failurePacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Option[(FiniteDuration, ByteVector)] = {
423+
def unwrap(encrypted: ByteVector, failurePacket: ByteVector, sharedSecret: ByteVector32, remainingHops: Int): Option[(FiniteDuration, ByteVector)] = {
423424
val bytes = cipher(encrypted, sharedSecret)
424425
val holdTime = uint32.decode(bytes.take(holdTimeLength).bits).require.value.milliseconds
425426
val hmacs = getHmacs(bytes)
426427
val mac = Hmac256(generateKey("um", sharedSecret))
427-
if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1), minNumHop) == hmacs.head.drop(minNumHop)) {
428+
if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1), remainingHops) == hmacs.head.drop(maxNumHops - remainingHops)) {
428429
val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector.low(holdTimeLength) ++ ByteVector.concat((hmacs.drop(1) :+ Seq()).map(s => ByteVector.low(hmacLength) ++ ByteVector.concat(s)))
429430
Some(holdTime, unwrapped)
430431
} else {
431432
None
432433
}
433434
}
434435

436+
case class UnwrappedAttribution(holdTimes: List[HoldTime], remaining_opt: Option[ByteVector])
437+
435438
/**
436439
* Decrypt the hold times from the attribution data of a fulfilled HTLC
437440
*/
438-
def fulfillHoldTimes(attribution: ByteVector, sharedSecrets: Seq[SharedSecret], hopIndex: Int = 0): List[HoldTime] = {
441+
def fulfillHoldTimes(attribution: ByteVector, sharedSecrets: Seq[SharedSecret]): UnwrappedAttribution = {
439442
sharedSecrets match {
440-
case Nil => Nil
443+
case Nil => UnwrappedAttribution(Nil, Some(attribution))
441444
case ss :: tail =>
442-
unwrap(attribution, ByteVector.empty, ss.secret, hopIndex) match {
445+
unwrap(attribution, ByteVector.empty, ss.secret, sharedSecrets.length) match {
443446
case Some((holdTime, nextAttribution)) =>
444-
HoldTime(holdTime, ss.remoteNodeId) :: fulfillHoldTimes(nextAttribution, tail, hopIndex + 1)
445-
case None => Nil
447+
val UnwrappedAttribution(holdTimes, remaining_opt) = fulfillHoldTimes(nextAttribution, tail)
448+
UnwrappedAttribution(HoldTime(holdTime, ss.remoteNodeId) :: holdTimes, remaining_opt)
449+
case None => UnwrappedAttribution(Nil, None)
446450
}
447451
}
448452
}

eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ object FailureSummary {
250250
def apply(f: PaymentFailure): FailureSummary = f match {
251251
case LocalFailure(_, route, t) => FailureSummary(FailureType.LOCAL, t.getMessage, route.map(h => HopSummary(h)).toList, route.headOption.map(_.nodeId))
252252
case RemoteFailure(_, route, e) => FailureSummary(FailureType.REMOTE, e.failureMessage.message, route.map(h => HopSummary(h)).toList, Some(e.originNode))
253-
case UnreadableRemoteFailure(_, route, _, _) => FailureSummary(FailureType.UNREADABLE_REMOTE, "could not decrypt failure onion", route.map(h => HopSummary(h)).toList, None)
253+
case UnreadableRemoteFailure(_, route, _, _, _) => FailureSummary(FailureType.UNREADABLE_REMOTE, "could not decrypt failure onion", route.map(h => HopSummary(h)).toList, None)
254254
}
255255
}
256256

eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,8 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
391391
rs.getByteVector32FromHex("payment_preimage"),
392392
MilliSatoshi(rs.getLong("recipient_amount_msat")),
393393
PublicKey(rs.getByteVectorFromHex("recipient_node_id")),
394-
Seq(part))
394+
Seq(part),
395+
None)
395396
}
396397
sentByParentId + (parentId -> sent)
397398
}.values.toSeq.sortBy(_.timestamp)

eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging {
363363
rs.getByteVector32("payment_preimage"),
364364
MilliSatoshi(rs.getLong("recipient_amount_msat")),
365365
PublicKey(rs.getByteVector("recipient_node_id")),
366-
Seq(part))
366+
Seq(part),
367+
None)
367368
}
368369
sentByParentId + (parentId -> sent)
369370
}.values.toSeq.sortBy(_.timestamp)

eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,16 @@ sealed trait PaymentEvent {
4646
/**
4747
* A payment was successfully sent and fulfilled.
4848
*
49-
* @param id id of the whole payment attempt (if using multi-part, there will be multiple parts, each with
50-
* a different id).
51-
* @param paymentHash payment hash.
52-
* @param paymentPreimage payment preimage (proof of payment).
53-
* @param recipientAmount amount that has been received by the final recipient.
54-
* @param recipientNodeId id of the final recipient.
55-
* @param parts child payments (actual outgoing HTLCs).
49+
* @param id id of the whole payment attempt (if using multi-part, there will be multiple parts,
50+
* each with a different id).
51+
* @param paymentHash payment hash.
52+
* @param paymentPreimage payment preimage (proof of payment).
53+
* @param recipientAmount amount that has been received by the final recipient.
54+
* @param recipientNodeId id of the final recipient.
55+
* @param parts child payments (actual outgoing HTLCs).
56+
* @param remainingAttribution_opt for relayed trampoline payments, the attribution data that needs to be sent upstream
5657
*/
57-
case class PaymentSent(id: UUID, paymentHash: ByteVector32, paymentPreimage: ByteVector32, recipientAmount: MilliSatoshi, recipientNodeId: PublicKey, parts: Seq[PaymentSent.PartialPayment]) extends PaymentEvent {
58+
case class PaymentSent(id: UUID, paymentHash: ByteVector32, paymentPreimage: ByteVector32, recipientAmount: MilliSatoshi, recipientNodeId: PublicKey, parts: Seq[PaymentSent.PartialPayment], remainingAttribution_opt: Option[ByteVector]) extends PaymentEvent {
5859
require(parts.nonEmpty, "must have at least one payment part")
5960
val amountWithFees: MilliSatoshi = parts.map(_.amountWithFees).sum
6061
val feesPaid: MilliSatoshi = amountWithFees - recipientAmount // overall fees for this payment
@@ -151,7 +152,7 @@ case class LocalFailure(amount: MilliSatoshi, route: Seq[Hop], t: Throwable) ext
151152
case class RemoteFailure(amount: MilliSatoshi, route: Seq[Hop], e: Sphinx.DecryptedFailurePacket) extends PaymentFailure
152153

153154
/** A remote node failed the payment but we couldn't decrypt the failure (e.g. a malicious node tampered with the message). */
154-
case class UnreadableRemoteFailure(amount: MilliSatoshi, route: Seq[Hop], failurePacket: ByteVector, holdTimes: Seq[HoldTime]) extends PaymentFailure
155+
case class UnreadableRemoteFailure(amount: MilliSatoshi, route: Seq[Hop], failurePacket: ByteVector, attribution_opt: Option[ByteVector], holdTimes: Seq[HoldTime]) extends PaymentFailure
155156

156157
object PaymentFailure {
157158

@@ -236,7 +237,7 @@ object PaymentFailure {
236237
}
237238
case RemoteFailure(_, hops, Sphinx.DecryptedFailurePacket(nodeId, _)) =>
238239
ignoreNodeOutgoingEdge(nodeId, hops, ignore)
239-
case UnreadableRemoteFailure(_, hops, _, holdTimes) =>
240+
case UnreadableRemoteFailure(_, hops, _, _, holdTimes) =>
240241
// TODO: Once everyone supports attributable errors, we should only exclude two nodes: the last for which we have attribution data and the next one.
241242
// We don't know which node is sending garbage, let's blacklist all nodes except:
242243
// - the nodes that returned attribution data (except the last one)

eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,9 @@ object OutgoingPaymentPacket {
389389
// We drop downstream attribution data and report our own attribution data to the previous trampoline node.
390390
// Note that we could use the downstream attribution data to score downstream nodes.
391391
val trampolinePacket = Sphinx.FailurePacket.wrap(packet, trampolineOnionSecret)
392-
val attribution = Sphinx.Attribution.create(previousAttribution_opt = None, Some(trampolinePacket), holdTime, ss.outerOnionSecret)
393-
(Sphinx.FailurePacket.wrap(trampolinePacket, ss.outerOnionSecret), attribution)
392+
val attributionInner = Sphinx.Attribution.create(previousAttribution_opt, Some(packet), holdTime, trampolineOnionSecret)
393+
val attributionOuter = Sphinx.Attribution.create(Some(attributionInner), Some(trampolinePacket), holdTime, ss.outerOnionSecret)
394+
(Sphinx.FailurePacket.wrap(trampolinePacket, ss.outerOnionSecret), attributionOuter)
394395
case None =>
395396
val attribution = Sphinx.Attribution.create(previousAttribution_opt, Some(packet), holdTime, ss.outerOnionSecret)
396397
(Sphinx.FailurePacket.wrap(packet, ss.outerOnionSecret), attribution)
@@ -402,12 +403,19 @@ object OutgoingPaymentPacket {
402403
(Sphinx.FailurePacket.wrap(packet, ss.outerOnionSecret), attribution)
403404
case FailureReason.LocalTrampolineFailure(failure) =>
404405
// This is a trampoline failure: we try to encrypt it to the node who created the trampoline onion.
405-
val packet = ss.trampolineOnionSecret_opt match {
406-
case Some(trampolineOnionSecret) => Sphinx.FailurePacket.wrap(Sphinx.FailurePacket.create(trampolineOnionSecret, failure), trampolineOnionSecret)
407-
case None => ByteVector.empty // this shouldn't happen, we only generate trampoline failures when there was a trampoline onion
406+
ss.trampolineOnionSecret_opt match {
407+
case Some(trampolineOnionSecret) =>
408+
val packet = Sphinx.FailurePacket.create(trampolineOnionSecret, failure)
409+
val trampolinePacket = Sphinx.FailurePacket.wrap(packet, trampolineOnionSecret)
410+
val attributionInner = Sphinx.Attribution.create(previousAttribution_opt = None, Some(packet), holdTime, trampolineOnionSecret)
411+
val attributionOuter = Sphinx.Attribution.create(Some(attributionInner), Some(trampolinePacket), holdTime, ss.outerOnionSecret)
412+
(Sphinx.FailurePacket.wrap(trampolinePacket, ss.outerOnionSecret), attributionOuter)
413+
414+
case None => // this shouldn't happen, we only generate trampoline failures when there was a trampoline onion
415+
val packet = Sphinx.FailurePacket.create(ss.outerOnionSecret, failure)
416+
val attribution = Sphinx.Attribution.create(previousAttribution_opt = None, Some(packet), holdTime, ss.outerOnionSecret)
417+
(Sphinx.FailurePacket.wrap(packet, ss.outerOnionSecret), attribution)
408418
}
409-
val attribution = Sphinx.Attribution.create(previousAttribution_opt = None, Some(packet), holdTime, ss.outerOnionSecret)
410-
(Sphinx.FailurePacket.wrap(packet, ss.outerOnionSecret), attribution)
411419
}
412420
})
413421
}

eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound}
4242
import fr.acinq.eclair.wire.protocol.PaymentOnion.IntermediatePayload
4343
import fr.acinq.eclair.wire.protocol._
4444
import fr.acinq.eclair.{Alias, CltvExpiry, CltvExpiryDelta, EncodedNodeId, FeatureSupport, Features, InitFeature, InvoiceFeature, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, UInt64, UnknownFeature, nodeFee, randomBytes32}
45+
import scodec.bits.ByteVector
4546

4647
import java.util.UUID
4748
import java.util.concurrent.TimeUnit
@@ -195,7 +196,7 @@ object NodeRelay {
195196
// If we received a failure from the next trampoline node, we won't be able to decrypt it: we should encrypt
196197
// it with our trampoline shared secret and relay it upstream, because only the sender can decrypt it.
197198
// Note that we currently don't process the downstream attribution data, but we could!
198-
failures.collectFirst { case UnreadableRemoteFailure(_, _, packet, _) => FailureReason.EncryptedDownstreamFailure(packet, attribution_opt = None) }
199+
failures.collectFirst { case UnreadableRemoteFailure(_, _, packet, attribution_opt, _) => FailureReason.EncryptedDownstreamFailure(packet, attribution_opt) }
199200
.getOrElse(FailureReason.LocalTrampolineFailure(TemporaryTrampolineFailure()))
200201
case nextPayload: IntermediatePayload.NodeRelay.ToNonTrampoline =>
201202
// The recipient doesn't support trampoline: if we received a failure from them, we forward it upstream.
@@ -416,11 +417,11 @@ class NodeRelay private(nodeParams: NodeParams,
416417
Behaviors.receiveMessagePartial {
417418
rejectExtraHtlcPartialFunction orElse {
418419
// this is the fulfill that arrives from downstream channels
419-
case WrappedPreimageReceived(PreimageReceived(_, paymentPreimage)) =>
420+
case WrappedPreimageReceived(PreimageReceived(_, paymentPreimage, attribution_opt)) =>
420421
if (!fulfilledUpstream) {
421422
// We want to fulfill upstream as soon as we receive the preimage (even if not all HTLCs have fulfilled downstream).
422423
context.log.debug("got preimage from downstream")
423-
fulfillPayment(upstream, paymentPreimage)
424+
fulfillPayment(upstream, paymentPreimage, attribution_opt)
424425
sending(upstream, recipient, walletNodeId_opt, recipientFeatures_opt, nextPayload, startedAt, fulfilledUpstream = true)
425426
} else {
426427
// we don't want to fulfill multiple times
@@ -548,16 +549,15 @@ class NodeRelay private(nodeParams: NodeParams,
548549
upstream.received.foreach(r => rejectHtlc(r.add.id, r.add.channelId, upstream.amountIn, r.receivedAt, Some(failure1)))
549550
}
550551

551-
private def fulfillPayment(upstream: Upstream.Hot.Trampoline, paymentPreimage: ByteVector32): Unit = upstream.received.foreach(r => {
552-
// Note that we currently ignore downstream attribution data, but we could process it here to score downstream nodes.
553-
val cmd = CMD_FULFILL_HTLC(r.add.id, paymentPreimage, downstreamAttribution_opt = None, Some(r.receivedAt), commit = true)
552+
private def fulfillPayment(upstream: Upstream.Hot.Trampoline, paymentPreimage: ByteVector32, downstreamAttribution_opt: Option[ByteVector]): Unit = upstream.received.foreach(r => {
553+
val cmd = CMD_FULFILL_HTLC(r.add.id, paymentPreimage, downstreamAttribution_opt, Some(r.receivedAt), commit = true)
554554
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, r.add.channelId, cmd)
555555
})
556556

557557
private def success(upstream: Upstream.Hot.Trampoline, fulfilledUpstream: Boolean, paymentSent: PaymentSent): Unit = {
558558
// We may have already fulfilled upstream, but we can now emit an accurate relayed event and clean-up resources.
559559
if (!fulfilledUpstream) {
560-
fulfillPayment(upstream, paymentSent.paymentPreimage)
560+
fulfillPayment(upstream, paymentSent.paymentPreimage, paymentSent.remainingAttribution_opt)
561561
}
562562
val incoming = upstream.received.map(r => PaymentRelayed.IncomingPart(r.add.amountMsat, r.add.channelId, r.receivedAt))
563563
val outgoing = paymentSent.parts.map(part => PaymentRelayed.OutgoingPart(part.amountWithFees, part.toChannelId, part.timestamp))

eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/OnTheFlyFunding.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ object OnTheFlyFunding {
106106
case Some(f) => f match {
107107
case f: FailureReason.EncryptedDownstreamFailure =>
108108
Sphinx.FailurePacket.decrypt(f.packet, f.attribution_opt, onionSharedSecrets).failure match {
109-
case Left(Sphinx.CannotDecryptFailurePacket(unwrapped)) =>
109+
case Left(Sphinx.CannotDecryptFailurePacket(unwrapped, attribution_opt)) =>
110110
log.info("received encrypted on-the-fly funding failure")
111111
// If we cannot decrypt the error, it is encrypted for the payer using the trampoline onion secrets.
112112
// We unwrap the outer onion encryption and will relay the error upstream.
113-
FailureReason.EncryptedDownstreamFailure(unwrapped, None)
113+
FailureReason.EncryptedDownstreamFailure(unwrapped, attribution_opt)
114114
case Right(f) =>
115115
log.warning("downstream on-the-fly funding failure: {}", f.failureMessage.message)
116116
// Otherwise, there was an issue with the way we forwarded the payment to the recipient.

0 commit comments

Comments
 (0)