@@ -280,9 +280,10 @@ object Sphinx extends Logging {
280
280
/**
281
281
* The downstream failure could not be decrypted.
282
282
*
283
- * @param unwrapped encrypted failure packet after unwrapping using our shared secrets.
283
+ * @param unwrapped encrypted failure packet after unwrapping using our shared secrets.
284
+ * @param attribution_opt attribution data after unwrapping using our shared secrets
284
285
*/
285
- case class CannotDecryptFailurePacket (unwrapped : ByteVector )
286
+ case class CannotDecryptFailurePacket (unwrapped : ByteVector , attribution_opt : Option [ ByteVector ] )
286
287
287
288
case class HoldTime (duration : FiniteDuration , remoteNodeId : PublicKey )
288
289
@@ -334,16 +335,16 @@ object Sphinx extends Logging {
334
335
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
335
336
* failure packet otherwise.
336
337
*/
337
- def decrypt (packet : ByteVector , attribution_opt : Option [ByteVector ], sharedSecrets : Seq [SharedSecret ], hopIndex : Int = 0 ): HtlcFailure = {
338
+ def decrypt (packet : ByteVector , attribution_opt : Option [ByteVector ], sharedSecrets : Seq [SharedSecret ]): HtlcFailure = {
338
339
sharedSecrets match {
339
- case Nil => HtlcFailure (Nil , Left (CannotDecryptFailurePacket (packet)))
340
+ case Nil => HtlcFailure (Nil , Left (CannotDecryptFailurePacket (packet, attribution_opt )))
340
341
case ss :: tail =>
341
342
val packet1 = wrap(packet, ss.secret)
342
- val attribution1_opt = attribution_opt.flatMap(Attribution .unwrap(_, packet1, ss.secret, hopIndex ))
343
+ val attribution1_opt = attribution_opt.flatMap(Attribution .unwrap(_, packet1, ss.secret, sharedSecrets.length ))
343
344
val um = generateKey(" um" , ss.secret)
344
345
val HtlcFailure (downstreamHoldTimes, failure) = FailureMessageCodecs .failureOnionCodec(Hmac256 (um)).decode(packet1.toBitVector) match {
345
346
case Attempt .Successful (value) => HtlcFailure (Nil , Right (DecryptedFailurePacket (ss.remoteNodeId, value.value)))
346
- case _ => decrypt(packet1, attribution1_opt.map(_._2), tail, hopIndex + 1 )
347
+ case _ => decrypt(packet1, attribution1_opt.map(_._2), tail)
347
348
}
348
349
HtlcFailure (attribution1_opt.map(n => HoldTime (n._1, ss.remoteNodeId) +: downstreamHoldTimes).getOrElse(Nil ), failure)
349
350
}
@@ -389,11 +390,11 @@ object Sphinx extends Logging {
389
390
}))
390
391
391
392
/**
392
- * Computes the HMACs for the node that is `minNumHop ` hops away from us. Hence we only compute `maxNumHops - minNumHop ` HMACs.
393
+ * Computes the HMACs for the node that is `maxNumHops - remainingHops ` hops away from us. Hence we only compute `remainingHops ` HMACs.
393
394
* HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
394
395
*/
395
- private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
396
- (minNumHop until maxNumHops).map(i => {
396
+ private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], remainingHops : Int ): Seq [ByteVector ] = {
397
+ ((maxNumHops - remainingHops) until maxNumHops).map(i => {
397
398
val y = maxNumHops - i
398
399
mac.mac(failurePacket ++
399
400
holdTimes.take(y * holdTimeLength) ++
@@ -411,38 +412,41 @@ object Sphinx extends Logging {
411
412
val previousHmacs = getHmacs(previousAttribution).dropRight(1 ).map(_.drop(1 ))
412
413
val mac = Hmac256 (generateKey(" um" , sharedSecret))
413
414
val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take((maxNumHops - 1 ) * holdTimeLength)
414
- val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector .empty), holdTimes, previousHmacs, 0 ) +: previousHmacs
415
+ val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector .empty), holdTimes, previousHmacs, maxNumHops ) +: previousHmacs
415
416
cipher(holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_))), sharedSecret)
416
417
}
417
418
418
419
/**
419
420
* Unwrap one hop of attribution data
420
421
* @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
421
422
*/
422
- def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Option [(FiniteDuration , ByteVector )] = {
423
+ def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , remainingHops : Int ): Option [(FiniteDuration , ByteVector )] = {
423
424
val bytes = cipher(encrypted, sharedSecret)
424
425
val holdTime = uint32.decode(bytes.take(holdTimeLength).bits).require.value.milliseconds
425
426
val hmacs = getHmacs(bytes)
426
427
val mac = Hmac256 (generateKey(" um" , sharedSecret))
427
- if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), minNumHop ) == hmacs.head.drop(minNumHop )) {
428
+ if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), remainingHops ) == hmacs.head.drop(maxNumHops - remainingHops )) {
428
429
val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector .low(holdTimeLength) ++ ByteVector .concat((hmacs.drop(1 ) :+ Seq ()).map(s => ByteVector .low(hmacLength) ++ ByteVector .concat(s)))
429
430
Some (holdTime, unwrapped)
430
431
} else {
431
432
None
432
433
}
433
434
}
434
435
436
+ case class UnwrappedAttribution (holdTimes : List [HoldTime ], remaining_opt : Option [ByteVector ])
437
+
435
438
/**
436
439
* Decrypt the hold times from the attribution data of a fulfilled HTLC
437
440
*/
438
- def fulfillHoldTimes (attribution : ByteVector , sharedSecrets : Seq [SharedSecret ], hopIndex : Int = 0 ): List [ HoldTime ] = {
441
+ def fulfillHoldTimes (attribution : ByteVector , sharedSecrets : Seq [SharedSecret ]): UnwrappedAttribution = {
439
442
sharedSecrets match {
440
- case Nil => Nil
443
+ case Nil => UnwrappedAttribution ( Nil , Some (attribution))
441
444
case ss :: tail =>
442
- unwrap(attribution, ByteVector .empty, ss.secret, hopIndex ) match {
445
+ unwrap(attribution, ByteVector .empty, ss.secret, sharedSecrets.length ) match {
443
446
case Some ((holdTime, nextAttribution)) =>
444
- HoldTime (holdTime, ss.remoteNodeId) :: fulfillHoldTimes(nextAttribution, tail, hopIndex + 1 )
445
- case None => Nil
447
+ val UnwrappedAttribution (holdTimes, remaining_opt) = fulfillHoldTimes(nextAttribution, tail)
448
+ UnwrappedAttribution (HoldTime (holdTime, ss.remoteNodeId) :: holdTimes, remaining_opt)
449
+ case None => UnwrappedAttribution (Nil , None )
446
450
}
447
451
}
448
452
}
0 commit comments