@@ -348,86 +348,102 @@ object Sphinx extends Logging {
348
348
HtlcFailure (attribution1_opt.map(n => HoldTime (n._1, ss.remoteNodeId) +: downstreamHoldTimes).getOrElse(Nil ), failure)
349
349
}
350
350
}
351
+ }
352
+
353
+ /**
354
+ * Attribution data is added to the failure packet and prevents a node from evading responsibility for its failures.
355
+ * Nodes that relay attribution data can prove that they are not the erring node and in case the erring node tries
356
+ * to hide, there will only be at most two nodes that can be the erring node (the last one to send attribution data
357
+ * and the one after it). It also adds timing data for each node on the path.
358
+ * Attribution data can also be added to fulfilled HTLCs to provide timing data and allow choosing fast nodes for
359
+ * future payments.
360
+ * https://github.com/lightning/bolts/pull/1044
361
+ */
362
+ object Attribution {
363
+ val maxNumHops = 20
364
+ val holdTimeLength = 4
365
+ val hmacLength = 4 // HMACs are truncated to 4 bytes to save space
366
+ val totalLength = maxNumHops * holdTimeLength + maxNumHops * (maxNumHops + 1 ) / 2 * hmacLength // = 920
367
+
368
+ private def cipher (bytes : ByteVector , sharedSecret : ByteVector32 ): ByteVector = {
369
+ val key = generateKey(" ammagext" , sharedSecret)
370
+ val stream = generateStream(key, totalLength)
371
+ bytes xor stream
372
+ }
351
373
352
374
/**
353
- * Attribution data is added to the failure packet and prevents a node from evading responsibility for its failures.
354
- * Nodes that relay attribution data can prove that they are not the erring node and in case the erring node tries
355
- * to hide, there will only be at most two nodes that can be the erring node (the last one to send attribution data
356
- * and the one after it).
357
- * It also adds timing data for each node on the path.
358
- * https://github.com/lightning/bolts/pull/1044
375
+ * Get the HMACs from the attribution data.
376
+ * The layout of the attribution data is as follows (using maxNumHops = 3 for conciseness):
377
+ * holdTime(0) ++ holdTime(1) ++ holdTime(2) ++
378
+ * hmacs(0)(0) ++ hmacs(0)(1) ++ hmacs(0)(2) ++
379
+ * hmacs(1)(0) ++ hmacs(1)(1) ++
380
+ * hmacs(2)(0)
381
+ *
382
+ * Where `hmac(i)(j)` is the hmac added by node `i` (counted from the node that built the attribution data),
383
+ * assuming it is `maxNumHops - 1 - i - j` hops away from the erring node.
359
384
*/
360
- object Attribution {
361
- val maxNumHops = 20
362
- val holdTimeLength = 4
363
- val hmacLength = 4 // HMACs are truncated to 4 bytes to save space
364
- val totalLength = maxNumHops * holdTimeLength + maxNumHops * (maxNumHops + 1 ) / 2 * hmacLength // = 920
365
-
366
- private def cipher (bytes : ByteVector , sharedSecret : ByteVector32 ): ByteVector = {
367
- val key = generateKey(" ammagext" , sharedSecret)
368
- val stream = generateStream(key, totalLength)
369
- bytes xor stream
370
- }
385
+ private def getHmacs (bytes : ByteVector ): Seq [Seq [ByteVector ]] =
386
+ (0 until maxNumHops).map(i => (0 until (maxNumHops - i)).map(j => {
387
+ val start = maxNumHops * holdTimeLength + (maxNumHops * i - (i * (i - 1 )) / 2 + j) * hmacLength
388
+ bytes.slice(start, start + hmacLength)
389
+ }))
371
390
372
- /**
373
- * Get the HMACs from the attribution data.
374
- * The layout of the attribution data is as follows (using maxNumHops = 3 for conciseness):
375
- * holdTime(0) ++ holdTime(1) ++ holdTime(2) ++
376
- * hmacs(0)(0) ++ hmacs(0)(1) ++ hmacs(0)(2) ++
377
- * hmacs(1)(0) ++ hmacs(1)(1) ++
378
- * hmacs(2)(0)
379
- *
380
- * Where `hmac(i)(j)` is the hmac added by node `i` (counted from the node that built the attribution data),
381
- * assuming it is `maxNumHops - 1 - i - j` hops away from the erring node.
382
- */
383
- private def getHmacs (bytes : ByteVector ): Seq [Seq [ByteVector ]] =
384
- (0 until maxNumHops).map(i => (0 until (maxNumHops - i)).map(j => {
385
- val start = maxNumHops * holdTimeLength + (maxNumHops * i - (i * (i - 1 )) / 2 + j) * hmacLength
386
- bytes.slice(start, start + hmacLength)
387
- }))
388
-
389
- /**
390
- * Computes the HMACs for the node that is `minNumHop` hops away from us. Hence we only compute `maxNumHops - minNumHop` HMACs.
391
- * HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
392
- */
393
- private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
394
- (minNumHop until maxNumHops).map(i => {
395
- val y = maxNumHops - i
396
- mac.mac(failurePacket ++
397
- holdTimes.take(y * holdTimeLength) ++
398
- ByteVector .concat((0 until y - 1 ).map(j => hmacs(j)(i)))).bytes.take(hmacLength)
399
- })
400
- }
391
+ /**
392
+ * Computes the HMACs for the node that is `minNumHop` hops away from us. Hence we only compute `maxNumHops - minNumHop` HMACs.
393
+ * HMACs are truncated to 4 bytes to save space. An attacker has only one try to guess the HMAC so 4 bytes should be enough.
394
+ */
395
+ private def computeHmacs (mac : Mac32 , failurePacket : ByteVector , holdTimes : ByteVector , hmacs : Seq [Seq [ByteVector ]], minNumHop : Int ): Seq [ByteVector ] = {
396
+ (minNumHop until maxNumHops).map(i => {
397
+ val y = maxNumHops - i
398
+ mac.mac(failurePacket ++
399
+ holdTimes.take(y * holdTimeLength) ++
400
+ ByteVector .concat((0 until y - 1 ).map(j => hmacs(j)(i)))).bytes.take(hmacLength)
401
+ })
402
+ }
403
+
404
+ /**
405
+ * Create attribution data to send with the failure packet or with a fulfilled HTLC
406
+ *
407
+ * @param failurePacket_opt the failure packet before being wrapped or `None` for fulfilled HTLCs
408
+ */
409
+ def create (previousAttribution_opt : Option [ByteVector ], failurePacket_opt : Option [ByteVector ], holdTime : FiniteDuration , sharedSecret : ByteVector32 ): ByteVector = {
410
+ val previousAttribution = previousAttribution_opt.getOrElse(ByteVector .low(totalLength))
411
+ val previousHmacs = getHmacs(previousAttribution).dropRight(1 ).map(_.drop(1 ))
412
+ val mac = Hmac256 (generateKey(" um" , sharedSecret))
413
+ val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take((maxNumHops - 1 ) * holdTimeLength)
414
+ val hmacs = computeHmacs(mac, failurePacket_opt.getOrElse(ByteVector .empty), holdTimes, previousHmacs, 0 ) +: previousHmacs
415
+ cipher(holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_))), sharedSecret)
416
+ }
401
417
402
- /**
403
- * Create attribution data to send with the failure packet
404
- *
405
- * @param failurePacket the failure packet before being wrapped
406
- */
407
- def create (previousAttribution_opt : Option [ByteVector ], failurePacket : ByteVector , holdTime : FiniteDuration , sharedSecret : ByteVector32 ): ByteVector = {
408
- val previousAttribution = previousAttribution_opt.getOrElse(ByteVector .low(totalLength))
409
- val previousHmacs = getHmacs(previousAttribution).dropRight(1 ).map(_.drop(1 ))
410
- val mac = Hmac256 (generateKey(" um" , sharedSecret))
411
- val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take((maxNumHops - 1 ) * holdTimeLength)
412
- val hmacs = computeHmacs(mac, failurePacket, holdTimes, previousHmacs, 0 ) +: previousHmacs
413
- cipher(holdTimes ++ ByteVector .concat(hmacs.map(ByteVector .concat(_))), sharedSecret)
418
+ /**
419
+ * Unwrap one hop of attribution data
420
+ * @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
421
+ */
422
+ def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Option [(FiniteDuration , ByteVector )] = {
423
+ val bytes = cipher(encrypted, sharedSecret)
424
+ val holdTime = uint32.decode(bytes.take(holdTimeLength).bits).require.value.milliseconds
425
+ val hmacs = getHmacs(bytes)
426
+ val mac = Hmac256 (generateKey(" um" , sharedSecret))
427
+ if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), minNumHop) == hmacs.head.drop(minNumHop)) {
428
+ val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector .low(holdTimeLength) ++ ByteVector .concat((hmacs.drop(1 ) :+ Seq ()).map(s => ByteVector .low(hmacLength) ++ ByteVector .concat(s)))
429
+ Some (holdTime, unwrapped)
430
+ } else {
431
+ None
414
432
}
433
+ }
415
434
416
- /**
417
- * Unwrap one hop of attribution data
418
- * @return a pair with the hold time for this hop and the attribution data for the next hop, or None if the attribution data was invalid
419
- */
420
- def unwrap (encrypted : ByteVector , failurePacket : ByteVector , sharedSecret : ByteVector32 , minNumHop : Int ): Option [(FiniteDuration , ByteVector )] = {
421
- val bytes = cipher(encrypted, sharedSecret)
422
- val holdTime = uint32.decode(bytes.take(holdTimeLength).bits).require.value.milliseconds
423
- val hmacs = getHmacs(bytes)
424
- val mac = Hmac256 (generateKey(" um" , sharedSecret))
425
- if (computeHmacs(mac, failurePacket, bytes.take(maxNumHops * holdTimeLength), hmacs.drop(1 ), minNumHop) == hmacs.head.drop(minNumHop)) {
426
- val unwrapped = bytes.slice(holdTimeLength, maxNumHops * holdTimeLength) ++ ByteVector .low(holdTimeLength) ++ ByteVector .concat((hmacs.drop(1 ) :+ Seq ()).map(s => ByteVector .low(hmacLength) ++ ByteVector .concat(s)))
427
- Some (holdTime, unwrapped)
428
- } else {
429
- None
430
- }
435
+ /**
436
+ * Decrypt the hold times from the attribution data of a fulfilled HTLC
437
+ */
438
+ def fulfillHoldTimes (attribution : ByteVector , sharedSecrets : Seq [SharedSecret ], hopIndex : Int = 0 ): List [HoldTime ] = {
439
+ sharedSecrets match {
440
+ case Nil => Nil
441
+ case ss :: tail =>
442
+ unwrap(attribution, ByteVector .empty, ss.secret, hopIndex) match {
443
+ case Some ((holdTime, nextAttribution)) =>
444
+ HoldTime (holdTime, ss.remoteNodeId) :: fulfillHoldTimes(nextAttribution, tail, hopIndex + 1 )
445
+ case None => Nil
446
+ }
431
447
}
432
448
}
433
449
}
0 commit comments