@@ -335,16 +335,16 @@ object Sphinx extends Logging {
335
335
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
336
336
* failure packet otherwise.
337
337
*/
338
- def decrypt (packet : ByteVector , attribution_opt : Option [ByteVector ], sharedSecrets : Seq [SharedSecret ], hopIndex : Int = 0 ): HtlcFailure = {
338
+ def decrypt (packet : ByteVector , attribution_opt : Option [ByteVector ], sharedSecrets : Seq [SharedSecret ]): HtlcFailure = {
339
339
sharedSecrets match {
340
340
case Nil => HtlcFailure (Nil , Left (CannotDecryptFailurePacket (packet, attribution_opt)))
341
341
case ss :: tail =>
342
342
val packet1 = wrap(packet, ss.secret)
343
- val attribution1_opt = attribution_opt.flatMap(Attribution .unwrap(_, packet1, ss.secret, hopIndex ))
343
+ val attribution1_opt = attribution_opt.flatMap(Attribution .unwrap(_, packet1, ss.secret, sharedSecrets.length ))
344
344
val um = generateKey(" um" , ss.secret)
345
345
val HtlcFailure (downstreamHoldTimes, failure) = FailureMessageCodecs .failureOnionCodec(Hmac256 (um)).decode(packet1.toBitVector) match {
346
346
case Attempt .Successful (value) => HtlcFailure (Nil , Right (DecryptedFailurePacket (ss.remoteNodeId, value.value)))
347
- case _ => decrypt(packet1, attribution1_opt.map(_._2), tail, hopIndex + 1 )
347
+ case _ => decrypt(packet1, attribution1_opt.map(_._2), tail)
348
348
}
349
349
HtlcFailure (attribution1_opt.map(n => HoldTime (n._1, ss.remoteNodeId) +: downstreamHoldTimes).getOrElse(Nil ), failure)
350
350
}
@@ -390,11 +390,11 @@ object Sphinx extends Logging {
390
390
}))
391
391
392
392
/**
393
- * Computes the HMACs for the node that is `minNumHop ` hops away from us. Hence we only compute `maxNumHops - minNumHop ` HMACs.
393
+ * Computes the HMACs for the node that is `maxNumHops - remainingHops ` hops away from us. Hence we only compute `remainingHops ` HMACs.
394
394
* HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
395
395
*/
396
- private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
397
- (minNumHop until maxNumHops).map(i => {
396
+ private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], remainingHops : Int ): Seq [ByteVector ] = {
397
+ ((maxNumHops - remainingHops) until maxNumHops).map(i => {
398
398
val y = maxNumHops - i
399
399
mac.mac(failurePacket ++
400
400
holdTimes.take(y * holdTimeLength) ++
@@ -403,29 +403,30 @@ object Sphinx extends Logging {
403
403
}
404
404
405
405
/**
406
- * Create attribution data to send with the failure packet or with a fulfilled HTLC
406
+ * Create attribution data to send when settling an HTLC (in both failure and success cases).
407
407
*
408
- * @param failurePacket_opt the failure packet before being wrapped or `None` for fulfilled HTLCs
408
+ * @param failurePacket_opt the failure packet before being wrapped or `None` for fulfilled HTLCs.
409
409
*/
410
410
def create (previousAttribution_opt : Option [ByteVector ], failurePacket_opt : Option [ByteVector ], holdTime : FiniteDuration , sharedSecret : ByteVector32 ): ByteVector = {
411
411
val previousAttribution = previousAttribution_opt.getOrElse(ByteVector .low(totalLength))
412
412
val previousHmacs = getHmacs(previousAttribution).dropRight(1 ).map(_.drop(1 ))
413
413
val mac = Hmac256 (generateKey(" um" , sharedSecret))
414
414
val holdTimes = uint32.encode(holdTime.toMillis / 100 ).require.bytes ++ previousAttribution.take((maxNumHops - 1 ) * holdTimeLength)
415
- val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector .empty), holdTimes, previousHmacs, 0 ) +: previousHmacs
415
+ val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector .empty), holdTimes, previousHmacs, maxNumHops ) +: previousHmacs
416
416
cipher(holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_))), sharedSecret)
417
417
}
418
418
419
419
/**
420
- * Unwrap one hop of attribution data
421
- * @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
420
+ * Unwrap one hop of attribution data.
421
+ *
422
+ * @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid.
422
423
*/
423
- def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Option [(FiniteDuration , ByteVector )] = {
424
+ def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , remainingHops : Int ): Option [(FiniteDuration , ByteVector )] = {
424
425
val bytes = cipher(encrypted, sharedSecret)
425
426
val holdTime = (uint32.decode(bytes.take(holdTimeLength).bits).require.value * 100 ).milliseconds
426
427
val hmacs = getHmacs(bytes)
427
428
val mac = Hmac256 (generateKey(" um" , sharedSecret))
428
- if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), minNumHop ) == hmacs.head.drop(minNumHop )) {
429
+ if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), remainingHops ) == hmacs.head.drop(maxNumHops - remainingHops )) {
429
430
val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector .low(holdTimeLength) ++ ByteVector .concat((hmacs.drop(1 ) :+ Seq ()).map(s => ByteVector .low(hmacLength) ++ ByteVector .concat(s)))
430
431
Some (holdTime, unwrapped)
431
432
} else {
@@ -436,15 +437,15 @@ object Sphinx extends Logging {
436
437
case class UnwrappedAttribution (holdTimes : List [HoldTime ], remaining_opt : Option [ByteVector ])
437
438
438
439
/**
439
- * Decrypt the hold times from the attribution data of a fulfilled HTLC
440
+ * Unwrap many hops of attribution data (e.g. used for fulfilled HTLCs).
440
441
*/
441
- def fulfillHoldTimes (attribution : ByteVector , sharedSecrets : Seq [SharedSecret ], hopIndex : Int = 0 ): UnwrappedAttribution = {
442
+ def unwrap (attribution : ByteVector , sharedSecrets : Seq [SharedSecret ]): UnwrappedAttribution = {
442
443
sharedSecrets match {
443
444
case Nil => UnwrappedAttribution (Nil , Some (attribution))
444
445
case ss :: tail =>
445
- unwrap(attribution, ByteVector .empty, ss.secret, hopIndex ) match {
446
+ unwrap(attribution, ByteVector .empty, ss.secret, sharedSecrets.length ) match {
446
447
case Some ((holdTime, nextAttribution)) =>
447
- val UnwrappedAttribution (holdTimes, remaining_opt) = fulfillHoldTimes (nextAttribution, tail, hopIndex + 1 )
448
+ val UnwrappedAttribution (holdTimes, remaining_opt) = unwrap (nextAttribution, tail)
448
449
UnwrappedAttribution (HoldTime (holdTime, ss.remoteNodeId) :: holdTimes, remaining_opt)
449
450
case None => UnwrappedAttribution (Nil , None )
450
451
}
0 commit comments