Skip to content

Commit d7c020d

Browse files
authored
Refactor attribution helpers and commands (#3125)
* Refactor attribution decryption to prepare for trampoline We update `Sphinx.scala` to shift hmacs differently to make it easier to implement trampoline attribution in follow-up PRs. * Refactor attribution data in `CMD_FAIL_HTLC` and `CMD_FULFILL_HTLC` We group attribution data in a dedicated (optional) class for each command instead of always adding more fields, which requires updating a lot of test code whenever we need to change the contents. We take this opportunity to add the trampoline received at field, and update the pending commands DB.
1 parent 3a9b791 commit d7c020d

31 files changed

+280
-211
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelData.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,12 @@ final case class CMD_ADD_HTLC(replyTo: ActorRef,
221221
origin: Origin.Hot,
222222
commit: Boolean = false) extends HasReplyToCommand with ForbiddenCommandDuringQuiescenceNegotiation with ForbiddenCommandWhenQuiescent
223223

224+
case class FailureAttributionData(htlcReceivedAt: TimestampMilli, trampolineReceivedAt_opt: Option[TimestampMilli])
225+
case class FulfillAttributionData(htlcReceivedAt: TimestampMilli, trampolineReceivedAt_opt: Option[TimestampMilli], downstreamAttribution_opt: Option[ByteVector])
226+
224227
sealed trait HtlcSettlementCommand extends HasOptionalReplyToCommand with ForbiddenCommandDuringQuiescenceNegotiation with ForbiddenCommandWhenQuiescent { def id: Long }
225-
final case class CMD_FULFILL_HTLC(id: Long, r: ByteVector32, downstreamAttribution_opt: Option[ByteVector], htlcReceivedAt_opt: Option[TimestampMilli], commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand
226-
final case class CMD_FAIL_HTLC(id: Long, reason: FailureReason, htlcReceivedAt_opt: Option[TimestampMilli], delay_opt: Option[FiniteDuration] = None, commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand
228+
final case class CMD_FULFILL_HTLC(id: Long, r: ByteVector32, attribution_opt: Option[FulfillAttributionData], commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand
229+
final case class CMD_FAIL_HTLC(id: Long, reason: FailureReason, attribution_opt: Option[FailureAttributionData], delay_opt: Option[FiniteDuration] = None, commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand
227230
final case class CMD_FAIL_MALFORMED_HTLC(id: Long, onionHash: ByteVector32, failureCode: Int, commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HtlcSettlementCommand
228231
final case class CMD_UPDATE_FEE(feeratePerKw: FeeratePerKw, commit: Boolean = false, replyTo_opt: Option[ActorRef] = None) extends HasOptionalReplyToCommand with ForbiddenCommandDuringQuiescenceNegotiation with ForbiddenCommandWhenQuiescent
229232
final case class CMD_SIGN(replyTo_opt: Option[ActorRef] = None) extends HasOptionalReplyToCommand with ForbiddenCommandWhenQuiescent

eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,8 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall
715715
case PostRevocationAction.RejectHtlc(add) =>
716716
log.debug("rejecting incoming htlc {}", add)
717717
// NB: we don't set commit = true, we will sign all updates at once afterwards.
718-
self ! CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(TemporaryChannelFailure(Some(d.channelUpdate))), Some(TimestampMilli.now()), commit = true)
718+
val attribution = FailureAttributionData(htlcReceivedAt = TimestampMilli.now(), trampolineReceivedAt_opt = None)
719+
self ! CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(TemporaryChannelFailure(Some(d.channelUpdate))), Some(attribution), commit = true)
719720
case PostRevocationAction.RelayFailure(result) =>
720721
log.debug("forwarding {} to relayer", result)
721722
relayer ! result
@@ -1660,11 +1661,13 @@ class Channel(val nodeParams: NodeParams, val channelKeys: ChannelKeys, val wall
16601661
case PostRevocationAction.RelayHtlc(add) =>
16611662
// BOLT 2: A sending node SHOULD fail to route any HTLC added after it sent shutdown.
16621663
log.debug("closing in progress: failing {}", add)
1663-
self ! CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(PermanentChannelFailure()), Some(TimestampMilli.now()), commit = true)
1664+
val attribution = FailureAttributionData(htlcReceivedAt = TimestampMilli.now(), trampolineReceivedAt_opt = None)
1665+
self ! CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(PermanentChannelFailure()), Some(attribution), commit = true)
16641666
case PostRevocationAction.RejectHtlc(add) =>
16651667
// BOLT 2: A sending node SHOULD fail to route any HTLC added after it sent shutdown.
16661668
log.debug("closing in progress: rejecting {}", add)
1667-
self ! CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(PermanentChannelFailure()), Some(TimestampMilli.now()), commit = true)
1669+
val attribution = FailureAttributionData(htlcReceivedAt = TimestampMilli.now(), trampolineReceivedAt_opt = None)
1670+
self ! CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(PermanentChannelFailure()), Some(attribution), commit = true)
16681671
case PostRevocationAction.RelayFailure(result) =>
16691672
log.debug("forwarding {} to relayer", result)
16701673
relayer ! result

eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -335,16 +335,16 @@ object Sphinx extends Logging {
335335
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
336336
* failure packet otherwise.
337337
*/
338-
def decrypt(packet: ByteVector, attribution_opt: Option[ByteVector], sharedSecrets: Seq[SharedSecret], hopIndex: Int = 0): HtlcFailure = {
338+
def decrypt(packet: ByteVector, attribution_opt: Option[ByteVector], sharedSecrets: Seq[SharedSecret]): HtlcFailure = {
339339
sharedSecrets match {
340340
case Nil => HtlcFailure(Nil, Left(CannotDecryptFailurePacket(packet, attribution_opt)))
341341
case ss :: tail =>
342342
val packet1 = wrap(packet, ss.secret)
343-
val attribution1_opt = attribution_opt.flatMap(Attribution.unwrap(_, packet1, ss.secret, hopIndex))
343+
val attribution1_opt = attribution_opt.flatMap(Attribution.unwrap(_, packet1, ss.secret, sharedSecrets.length))
344344
val um = generateKey("um", ss.secret)
345345
val HtlcFailure(downstreamHoldTimes, failure) = FailureMessageCodecs.failureOnionCodec(Hmac256(um)).decode(packet1.toBitVector) match {
346346
case Attempt.Successful(value) => HtlcFailure(Nil, Right(DecryptedFailurePacket(ss.remoteNodeId, value.value)))
347-
case _ => decrypt(packet1, attribution1_opt.map(_._2), tail, hopIndex + 1)
347+
case _ => decrypt(packet1, attribution1_opt.map(_._2), tail)
348348
}
349349
HtlcFailure(attribution1_opt.map(n => HoldTime(n._1, ss.remoteNodeId) +: downstreamHoldTimes).getOrElse(Nil), failure)
350350
}
@@ -390,11 +390,11 @@ object Sphinx extends Logging {
390390
}))
391391

392392
/**
393-
* Computes the HMACs for the node that is `minNumHop` hops away from us. Hence we only compute `maxNumHops - minNumHop` HMACs.
393+
* Computes the HMACs for the node that is `maxNumHops - remainingHops` hops away from us. Hence we only compute `remainingHops` HMACs.
394394
* HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
395395
*/
396-
private def computeHmacs(mac: Mac32, failurePacket: ByteVector, holdTimes: ByteVector, hmacs: Seq[Seq[ByteVector]], minNumHop: Int): Seq[ByteVector] = {
397-
(minNumHop until maxNumHops).map(i => {
396+
private def computeHmacs(mac: Mac32, failurePacket: ByteVector, holdTimes: ByteVector, hmacs: Seq[Seq[ByteVector]], remainingHops: Int): Seq[ByteVector] = {
397+
((maxNumHops - remainingHops) until maxNumHops).map(i => {
398398
val y = maxNumHops - i
399399
mac.mac(failurePacket ++
400400
holdTimes.take(y * holdTimeLength) ++
@@ -403,29 +403,30 @@ object Sphinx extends Logging {
403403
}
404404

405405
/**
406-
* Create attribution data to send with the failure packet or with a fulfilled HTLC
406+
* Create attribution data to send when settling an HTLC (in both failure and success cases).
407407
*
408-
* @param failurePacket_opt the failure packet before being wrapped or `None` for fulfilled HTLCs
408+
* @param failurePacket_opt the failure packet before being wrapped or `None` for fulfilled HTLCs.
409409
*/
410410
def create(previousAttribution_opt: Option[ByteVector], failurePacket_opt: Option[ByteVector], holdTime: FiniteDuration, sharedSecret: ByteVector32): ByteVector = {
411411
val previousAttribution = previousAttribution_opt.getOrElse(ByteVector.low(totalLength))
412412
val previousHmacs = getHmacs(previousAttribution).dropRight(1).map(_.drop(1))
413413
val mac = Hmac256(generateKey("um", sharedSecret))
414414
val holdTimes = uint32.encode(holdTime.toMillis / 100).require.bytes ++ previousAttribution.take((maxNumHops - 1) * holdTimeLength)
415-
val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector.empty), holdTimes, previousHmacs, 0) +: previousHmacs
415+
val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector.empty), holdTimes, previousHmacs, maxNumHops) +: previousHmacs
416416
cipher(holdTimes ++ ByteVector.concat(hmacs.map(ByteVector.concat(_))), sharedSecret)
417417
}
418418

419419
/**
420-
* Unwrap one hop of attribution data
421-
* @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
420+
* Unwrap one hop of attribution data.
421+
*
422+
* @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid.
422423
*/
423-
def unwrap(encrypted: ByteVector, failurePacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Option[(FiniteDuration, ByteVector)] = {
424+
def unwrap(encrypted: ByteVector, failurePacket: ByteVector, sharedSecret: ByteVector32, remainingHops: Int): Option[(FiniteDuration, ByteVector)] = {
424425
val bytes = cipher(encrypted, sharedSecret)
425426
val holdTime = (uint32.decode(bytes.take(holdTimeLength).bits).require.value * 100).milliseconds
426427
val hmacs = getHmacs(bytes)
427428
val mac = Hmac256(generateKey("um", sharedSecret))
428-
if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1), minNumHop) == hmacs.head.drop(minNumHop)) {
429+
if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1), remainingHops) == hmacs.head.drop(maxNumHops - remainingHops)) {
429430
val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector.low(holdTimeLength) ++ ByteVector.concat((hmacs.drop(1) :+ Seq()).map(s => ByteVector.low(hmacLength) ++ ByteVector.concat(s)))
430431
Some(holdTime, unwrapped)
431432
} else {
@@ -436,15 +437,15 @@ object Sphinx extends Logging {
436437
case class UnwrappedAttribution(holdTimes: List[HoldTime], remaining_opt: Option[ByteVector])
437438

438439
/**
439-
* Decrypt the hold times from the attribution data of a fulfilled HTLC
440+
* Unwrap many hops of attribution data (e.g. used for fulfilled HTLCs).
440441
*/
441-
def fulfillHoldTimes(attribution: ByteVector, sharedSecrets: Seq[SharedSecret], hopIndex: Int = 0): UnwrappedAttribution = {
442+
def unwrap(attribution: ByteVector, sharedSecrets: Seq[SharedSecret]): UnwrappedAttribution = {
442443
sharedSecrets match {
443444
case Nil => UnwrappedAttribution(Nil, Some(attribution))
444445
case ss :: tail =>
445-
unwrap(attribution, ByteVector.empty, ss.secret, hopIndex) match {
446+
unwrap(attribution, ByteVector.empty, ss.secret, sharedSecrets.length) match {
446447
case Some((holdTime, nextAttribution)) =>
447-
val UnwrappedAttribution(holdTimes, remaining_opt) = fulfillHoldTimes(nextAttribution, tail, hopIndex + 1)
448+
val UnwrappedAttribution(holdTimes, remaining_opt) = unwrap(nextAttribution, tail)
448449
UnwrappedAttribution(HoldTime(holdTime, ss.remoteNodeId) :: holdTimes, remaining_opt)
449450
case None => UnwrappedAttribution(Nil, None)
450451
}

eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package fr.acinq.eclair.payment
1818

1919
import fr.acinq.bitcoin.scalacompat.ByteVector32
2020
import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
21-
import fr.acinq.eclair.channel.{CMD_ADD_HTLC, CMD_FAIL_HTLC, CMD_FULFILL_HTLC, CannotExtractSharedSecret, Origin}
21+
import fr.acinq.eclair.channel._
2222
import fr.acinq.eclair.crypto.Sphinx
2323
import fr.acinq.eclair.payment.send.Recipient
2424
import fr.acinq.eclair.router.Router.Route
@@ -384,7 +384,7 @@ object OutgoingPaymentPacket {
384384
Right(UpdateFailMalformedHtlc(add.channelId, add.id, failure.onionHash, failure.code))
385385
case None =>
386386
// If the htlcReceivedAt was lost (because the node restarted), we use a hold time of 0 which should be ignored by the payer.
387-
val holdTime = cmd.htlcReceivedAt_opt.map(now - _).getOrElse(0 millisecond)
387+
val holdTime = cmd.attribution_opt.map(now - _.htlcReceivedAt).getOrElse(0 millisecond)
388388
buildHtlcFailure(nodeSecret, useAttributableFailures, cmd.reason, add, holdTime).map {
389389
case (encryptedReason, tlvs) => UpdateFailHtlc(add.channelId, cmd.id, encryptedReason, tlvs)
390390
}
@@ -397,8 +397,8 @@ object OutgoingPaymentPacket {
397397
extractSharedSecret(nodeSecret, add) match {
398398
case Left(_) => TlvStream.empty
399399
case Right(sharedSecret) =>
400-
val holdTime = cmd.htlcReceivedAt_opt.map(now - _).getOrElse(0 millisecond)
401-
TlvStream(UpdateFulfillHtlcTlv.AttributionData(Sphinx.Attribution.create(cmd.downstreamAttribution_opt, None, holdTime, sharedSecret)))
400+
val holdTime = cmd.attribution_opt.map(now - _.htlcReceivedAt).getOrElse(0 millisecond)
401+
TlvStream(UpdateFulfillHtlcTlv.AttributionData(Sphinx.Attribution.create(cmd.attribution_opt.flatMap(_.downstreamAttribution_opt), None, holdTime, sharedSecret)))
402402
}
403403
} else {
404404
TlvStream.empty

0 commit comments

Comments
 (0)