Skip to content

Commit

Permalink
Various refactoring for trampoline blinded paths (#2952)
Browse files Browse the repository at this point in the history
* Fix offer description documentation

And remove the `currency` fields as we have no short-term plans to
support currency conversion in `eclair`.

* Relax `payment_constraints` requirement in final blinded payload

We don't always need to include a `payment_constraints` field for
ourselves: it's fine to accept payment that don't contain one as
long as we created the `encrypted_recipient_data`, which we can
verify using the `path_id`. We were too restrictive for no good
reason.

* Allow omitting `total_amount` in blinded payments

If the `total_amount` field isn't provided, we can safely default to
using the `amount`, which saves space in the onion. Note that we keep
always encoding it in the outgoing payments we send, we're simply more
permissive when receiving payments.

* Refactor `decryptEncryptedRecipientData`

We extract a helper method for decrypting encrypted recipient data
which will be used when decrypting trampoline blinded paths.

* Use relay methods in `PaymentOnion.IntermediatePayload.NodeRelay`

In order to support blinded trampoline payments, we won't have access to
a direct `amount_to_forward` field, but will use a `payment_relay` TLV
instead, which only allows calculating the outgoing amount from the
incoming amount (same thing for the expiry).

We refactor this to simplify the diff when introducing blinded trampoline
payments.
  • Loading branch information
t-bast authored Nov 29, 2024
1 parent 0d2d380 commit 304290d
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ object Sphinx extends Logging {
* @param sessionKey this node's session key.
* @param publicKeys public keys of each node on the route, starting from the introduction point.
* @param payloads payloads that should be encrypted for each node on the route.
* @return a blinded route and the blinding tweak of the last node.
* @return a blinded route and the path key for the last node.
*/
def create(sessionKey: PrivateKey, publicKeys: Seq[PublicKey], payloads: Seq[ByteVector]): BlindedRouteDetails = {
require(publicKeys.length == payloads.length, "a payload must be provided for each node in the blinded path")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,25 @@ object IncomingPaymentPacket {
if (add.pathKey_opt.isDefined && payload.get[OnionPaymentPayloadTlv.PathKey].isDefined) {
Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
} else {
add.pathKey_opt.orElse(payload.get[OnionPaymentPayloadTlv.PathKey].map(_.publicKey)) match {
case Some(pathKey) => RouteBlindingEncryptedDataCodecs.decode(privateKey, pathKey, encryptedRecipientData) match {
case Left(_) =>
// There are two possibilities in this case:
// - the path key is invalid: the sender or the previous node is buggy or malicious
// - the encrypted data is invalid: the sender, the previous node or the recipient must be buggy or malicious
Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
case Right(decoded) => Right(DecodedEncryptedRecipientData(decoded.tlvs, decoded.nextPathKey))
}
case None =>
// The sender is trying to use route blinding, but we didn't receive the path key used to derive
// the decryption key. The sender or the previous peer is buggy or malicious.
val pathKey_opt = add.pathKey_opt.orElse(payload.get[OnionPaymentPayloadTlv.PathKey].map(_.publicKey))
decryptEncryptedRecipientData(add, privateKey, pathKey_opt, encryptedRecipientData)
}
}

private def decryptEncryptedRecipientData(add: UpdateAddHtlc, privateKey: PrivateKey, pathKey_opt: Option[PublicKey], encryptedRecipientData: ByteVector): Either[FailureMessage, DecodedEncryptedRecipientData] = {
pathKey_opt match {
case Some(pathKey) => RouteBlindingEncryptedDataCodecs.decode(privateKey, pathKey, encryptedRecipientData) match {
case Left(_) =>
// There are two possibilities in this case:
// - the path key is invalid: the sender or the previous node is buggy or malicious
// - the encrypted data is invalid: the sender, the previous node or the recipient must be buggy or malicious
Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
case Right(decoded) => Right(DecodedEncryptedRecipientData(decoded.tlvs, decoded.nextPathKey))
}
case None =>
// The sender is trying to use route blinding, but we didn't receive the path key used to derive
// the decryption key. The sender or the previous peer is buggy or malicious.
Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
}
}

Expand Down Expand Up @@ -213,8 +218,8 @@ object IncomingPaymentPacket {

private def validateBlindedFinalPayload(add: UpdateAddHtlc, payload: TlvStream[OnionPaymentPayloadTlv], blindedPayload: TlvStream[RouteBlindingEncryptedDataTlv]): Either[FailureMessage, FinalPacket] = {
FinalPayload.Blinded.validate(payload, blindedPayload).left.map(_.failureMessage).flatMap {
case payload if add.amountMsat < payload.paymentConstraints.minAmount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
case payload if add.cltvExpiry > payload.paymentConstraints.maxCltvExpiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
case payload if payload.paymentConstraints_opt.exists(c => add.amountMsat < c.minAmount) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
case payload if payload.paymentConstraints_opt.exists(c => c.maxCltvExpiry < add.cltvExpiry) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
case payload if !Features.areCompatible(Features.empty, payload.allowedFeatures) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
case payload if add.amountMsat < payload.amount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
case payload if add.cltvExpiry < payload.expiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,21 @@ object NodeRelay {
}
}

private def outgoingAmount(upstream: Upstream.Hot.Trampoline, payloadOut: IntermediatePayload.NodeRelay): MilliSatoshi = payloadOut.outgoingAmount(upstream.amountIn)

private def outgoingExpiry(upstream: Upstream.Hot.Trampoline, payloadOut: IntermediatePayload.NodeRelay): CltvExpiry = payloadOut.outgoingExpiry(upstream.expiryIn)

private def validateRelay(nodeParams: NodeParams, upstream: Upstream.Hot.Trampoline, payloadOut: IntermediatePayload.NodeRelay): Option[FailureMessage] = {
val fee = nodeFee(nodeParams.relayParams.minTrampolineFees, payloadOut.amountToForward)
if (upstream.amountIn - payloadOut.amountToForward < fee) {
val amountOut = outgoingAmount(upstream, payloadOut)
val expiryOut = outgoingExpiry(upstream, payloadOut)
val fee = nodeFee(nodeParams.relayParams.minTrampolineFees, amountOut)
if (upstream.amountIn - amountOut < fee) {
Some(TrampolineFeeInsufficient())
} else if (upstream.expiryIn - payloadOut.outgoingCltv < nodeParams.channelConf.expiryDelta) {
} else if (upstream.expiryIn - expiryOut < nodeParams.channelConf.expiryDelta) {
Some(TrampolineExpiryTooSoon())
} else if (payloadOut.outgoingCltv <= CltvExpiry(nodeParams.currentBlockHeight)) {
} else if (expiryOut <= CltvExpiry(nodeParams.currentBlockHeight)) {
Some(TrampolineExpiryTooSoon())
} else if (payloadOut.amountToForward <= MilliSatoshi(0)) {
} else if (amountOut <= MilliSatoshi(0)) {
Some(InvalidOnionPayload(UInt64(2), 0))
} else {
None
Expand Down Expand Up @@ -174,8 +180,9 @@ object NodeRelay {
* should return upstream.
*/
private def translateError(nodeParams: NodeParams, failures: Seq[PaymentFailure], upstream: Upstream.Hot.Trampoline, nextPayload: IntermediatePayload.NodeRelay): Option[FailureMessage] = {
val amountOut = outgoingAmount(upstream, nextPayload)
val routeNotFound = failures.collectFirst { case f@LocalFailure(_, _, RouteNotFound) => f }.nonEmpty
val routingFeeHigh = upstream.amountIn - nextPayload.amountToForward >= nodeFee(nodeParams.relayParams.minTrampolineFees, nextPayload.amountToForward) * 5
val routingFeeHigh = upstream.amountIn - amountOut >= nodeFee(nodeParams.relayParams.minTrampolineFees, amountOut) * 5
failures match {
case Nil => None
case LocalFailure(_, _, BalanceTooLow) :: Nil if routingFeeHigh =>
Expand Down Expand Up @@ -320,12 +327,14 @@ class NodeRelay private(nodeParams: NodeParams,

/** Relay the payment to the next identified node: this is similar to sending an outgoing payment. */
private def relay(upstream: Upstream.Hot.Trampoline, recipient: Recipient, walletNodeId_opt: Option[PublicKey], recipientFeatures_opt: Option[Features[InitFeature]], payloadOut: IntermediatePayload.NodeRelay, packetOut_opt: Option[OnionRoutingPacket]): Behavior[Command] = {
context.log.debug("relaying trampoline payment (amountIn={} expiryIn={} amountOut={} expiryOut={} isWallet={})", upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv, walletNodeId_opt.isDefined)
val amountOut = outgoingAmount(upstream, payloadOut)
val expiryOut = outgoingExpiry(upstream, payloadOut)
context.log.debug("relaying trampoline payment (amountIn={} expiryIn={} amountOut={} expiryOut={} isWallet={})", upstream.amountIn, upstream.expiryIn, amountOut, expiryOut, walletNodeId_opt.isDefined)
val confidence = (upstream.received.map(_.add.endorsement).min + 0.5) / 8
// We only make one try when it's a direct payment to a wallet.
val maxPaymentAttempts = if (walletNodeId_opt.isDefined) 1 else nodeParams.maxPaymentAttempts
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, recipient.nodeId, upstream, None, None, storeInDb = false, publishEvent = false, recordPathFindingMetrics = true, confidence)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, amountOut, expiryOut)
// If the next node is using trampoline, we assume that they support MPP.
val useMultiPart = recipient.features.hasFeature(Features.BasicMultiPartPayment) || packetOut_opt.nonEmpty
val payFsmAdapters = {
Expand Down Expand Up @@ -393,6 +402,8 @@ class NodeRelay private(nodeParams: NodeParams,

/** We couldn't forward the payment, but the next node may accept on-the-fly funding. */
private def attemptOnTheFlyFunding(upstream: Upstream.Hot.Trampoline, walletNodeId: PublicKey, recipient: Recipient, nextPayload: IntermediatePayload.NodeRelay, failures: Seq[PaymentFailure], startedAt: TimestampMilli): Behavior[Command] = {
val amountOut = outgoingAmount(upstream, nextPayload)
val expiryOut = outgoingExpiry(upstream, nextPayload)
// We create a payment onion, using a dummy channel hop between our node and the wallet node.
val dummyEdge = Invoice.ExtraEdge(nodeParams.nodeId, walletNodeId, Alias(0), 0 msat, 0, CltvExpiryDelta(0), 1 msat, None)
val dummyHop = ChannelHop(Alias(0), nodeParams.nodeId, walletNodeId, HopRelayParams.FromHint(dummyEdge))
Expand All @@ -401,7 +412,7 @@ class NodeRelay private(nodeParams: NodeParams,
case _: SpontaneousRecipient => None
case r: BlindedRecipient => r.blindedHops.headOption
}
val dummyRoute = Route(nextPayload.amountToForward, Seq(dummyHop), finalHop_opt)
val dummyRoute = Route(amountOut, Seq(dummyHop), finalHop_opt)
OutgoingPaymentPacket.buildOutgoingPayment(Origin.Hot(ActorRef.noSender, upstream), paymentHash, dummyRoute, recipient, 1.0) match {
case Left(f) =>
context.log.warn("could not create payment onion for on-the-fly funding: {}", f.getMessage)
Expand All @@ -411,7 +422,7 @@ class NodeRelay private(nodeParams: NodeParams,
case Right(nextPacket) =>
val forwardNodeIdFailureAdapter = context.messageAdapter[Register.ForwardNodeIdFailure[Peer.ProposeOnTheFlyFunding]](_ => WrappedOnTheFlyFundingResponse(Peer.ProposeOnTheFlyFundingResponse.NotAvailable("peer not found")))
val onTheFlyFundingResponseAdapter = context.messageAdapter[Peer.ProposeOnTheFlyFundingResponse](WrappedOnTheFlyFundingResponse)
val cmd = Peer.ProposeOnTheFlyFunding(onTheFlyFundingResponseAdapter, nextPayload.amountToForward, paymentHash, nextPayload.outgoingCltv, nextPacket.cmd.onion, nextPacket.cmd.nextPathKey_opt, upstream)
val cmd = Peer.ProposeOnTheFlyFunding(onTheFlyFundingResponseAdapter, amountOut, paymentHash, expiryOut, nextPacket.cmd.onion, nextPacket.cmd.nextPathKey_opt, upstream)
register ! Register.ForwardNodeId(forwardNodeIdFailureAdapter, walletNodeId, cmd)
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ object OfferTypes {
// Invoice request TLVs are in the range [0, 159] or [1000000000, 2999999999].
tlv.tag <= UInt64(159) || (tlv.tag >= UInt64(1000000000) && tlv.tag <= UInt64(2999999999L))

def filterOfferFields(tlvs: TlvStream[InvoiceRequestTlv]): TlvStream[OfferTlv] =
private def filterOfferFields(tlvs: TlvStream[InvoiceRequestTlv]): TlvStream[OfferTlv] =
TlvStream[OfferTlv](tlvs.records.collect { case tlv: OfferTlv => tlv }, tlvs.unknown.filter(isOfferTlv))

def filterInvoiceRequestFields(tlvs: TlvStream[InvoiceTlv]): TlvStream[InvoiceRequestTlv] =
Expand All @@ -238,11 +238,7 @@ object OfferTypes {
case class Offer(records: TlvStream[OfferTlv]) {
val chains: Seq[BlockHash] = records.get[OfferChains].map(_.chains).getOrElse(Seq(Block.LivenetGenesisBlock.hash))
val metadata: Option[ByteVector] = records.get[OfferMetadata].map(_.data)
val currency: Option[String] = records.get[OfferCurrency].map(_.iso4217)
val amount: Option[MilliSatoshi] = currency match {
case Some(_) => None // TODO: add exchange rates
case None => records.get[OfferAmount].map(_.amount)
}
val amount: Option[MilliSatoshi] = records.get[OfferAmount].map(_.amount)
val description: Option[String] = records.get[OfferDescription].map(_.description)
val features: Features[Bolt12Feature] = records.get[OfferFeatures].map(_.features.bolt12Features()).getOrElse(Features.empty)
val expiry: Option[TimestampSecond] = records.get[OfferAbsoluteExpiry].map(_.absoluteExpiry)
Expand All @@ -267,11 +263,11 @@ object OfferTypes {
val hrp = "lno"

/**
* @param amount_opt amount if it can be determined at offer creation time.
* @param description description of the offer.
* @param nodeId the nodeId to use for this offer, which should be different from our public nodeId if we're hiding behind a blinded route.
* @param features invoice features.
* @param chain chain on which the offer is valid.
* @param amount_opt amount if it can be determined at offer creation time.
* @param description_opt description of the offer (optional if the offer doesn't include an amount).
* @param nodeId the nodeId to use for this offer, which should be different from our public nodeId if we're hiding behind a blinded route.
* @param features invoice features.
* @param chain chain on which the offer is valid.
*/
def apply(amount_opt: Option[MilliSatoshi],
description_opt: Option[String],
Expand Down Expand Up @@ -312,6 +308,8 @@ object OfferTypes {
def validate(records: TlvStream[OfferTlv]): Either[InvalidTlvPayload, Offer] = {
if (records.get[OfferDescription].isEmpty && records.get[OfferAmount].nonEmpty) return Left(MissingRequiredTlv(UInt64(10)))
if (records.get[OfferNodeId].isEmpty && records.get[OfferPaths].forall(_.paths.isEmpty)) return Left(MissingRequiredTlv(UInt64(22)))
// Currency conversion isn't supported yet.
if (records.get[OfferCurrency].nonEmpty) return Left(ForbiddenTlv(UInt64(6)))
if (records.unknown.exists(!isOfferTlv(_))) return Left(ForbiddenTlv(records.unknown.find(!isOfferTlv(_)).get.tag))
Right(Offer(records))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,23 @@ object PaymentOnion {
}

sealed trait NodeRelay extends IntermediatePayload {
val amountToForward = records.get[AmountToForward].get.amount
val outgoingCltv = records.get[OutgoingCltv].get.cltv
// @formatter:off
def outgoingAmount(incomingAmount: MilliSatoshi): MilliSatoshi
def outgoingExpiry(incomingCltv: CltvExpiry): CltvExpiry
// @formatter:on
}

object NodeRelay {
case class Standard(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay {
val amountToForward = records.get[AmountToForward].get.amount
val outgoingCltv = records.get[OutgoingCltv].get.cltv
val outgoingNodeId = records.get[OutgoingNodeId].get.nodeId
val isAsyncPayment: Boolean = records.get[AsyncPayment].isDefined

// @formatter:off
override def outgoingAmount(incomingAmount: MilliSatoshi): MilliSatoshi = amountToForward
override def outgoingExpiry(incomingCltv: CltvExpiry): CltvExpiry = outgoingCltv
// @formatter:on
}

object Standard {
Expand All @@ -321,6 +330,8 @@ object PaymentOnion {

/** We relay to a payment recipient that doesn't support trampoline, which exposes its identity. */
case class ToNonTrampoline(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay {
val amountToForward = records.get[AmountToForward].get.amount
val outgoingCltv = records.get[OutgoingCltv].get.cltv
val outgoingNodeId = records.get[OutgoingNodeId].get.nodeId
val totalAmount = records.get[PaymentData].map(_.totalAmount match {
case MilliSatoshi(0) => amountToForward
Expand All @@ -330,6 +341,11 @@ object PaymentOnion {
val paymentMetadata = records.get[PaymentMetadata].map(_.data)
val invoiceFeatures = records.get[InvoiceFeatures].map(_.features).getOrElse(ByteVector.empty)
val invoiceRoutingInfo = records.get[InvoiceRoutingInfo].map(_.extraHops).get

// @formatter:off
override def outgoingAmount(incomingAmount: MilliSatoshi): MilliSatoshi = amountToForward
override def outgoingExpiry(incomingCltv: CltvExpiry): CltvExpiry = outgoingCltv
// @formatter:on
}

object ToNonTrampoline {
Expand Down Expand Up @@ -360,8 +376,15 @@ object PaymentOnion {

/** We relay to a payment recipient that doesn't support trampoline, but hides its identity using blinded paths. */
case class ToBlindedPaths(records: TlvStream[OnionPaymentPayloadTlv]) extends NodeRelay {
val amountToForward = records.get[AmountToForward].get.amount
val outgoingCltv = records.get[OutgoingCltv].get.cltv
val outgoingBlindedPaths = records.get[OutgoingBlindedPaths].get.paths
val invoiceFeatures = records.get[InvoiceFeatures].get.features

// @formatter:off
override def outgoingAmount(incomingAmount: MilliSatoshi): MilliSatoshi = amountToForward
override def outgoingExpiry(incomingCltv: CltvExpiry): CltvExpiry = outgoingCltv
// @formatter:on
}

object ToBlindedPaths {
Expand Down Expand Up @@ -449,11 +472,11 @@ object PaymentOnion {
*/
case class Blinded(records: TlvStream[OnionPaymentPayloadTlv], blindedRecords: TlvStream[RouteBlindingEncryptedDataTlv]) extends FinalPayload {
override val amount = records.get[AmountToForward].get.amount
override val totalAmount = records.get[TotalAmount].get.totalAmount
override val totalAmount = records.get[TotalAmount].map(_.totalAmount).getOrElse(amount)
override val expiry = records.get[OutgoingCltv].get.cltv
val pathKey_opt: Option[PublicKey] = records.get[PathKey].map(_.publicKey)
val pathId = blindedRecords.get[RouteBlindingEncryptedDataTlv.PathId].get.data
val paymentConstraints = blindedRecords.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get
val paymentConstraints_opt = blindedRecords.get[RouteBlindingEncryptedDataTlv.PaymentConstraints]
val allowedFeatures = blindedRecords.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty)
}

Expand All @@ -462,7 +485,6 @@ object PaymentOnion {
if (records.get[AmountToForward].isEmpty) return Left(MissingRequiredTlv(UInt64(2)))
if (records.get[OutgoingCltv].isEmpty) return Left(MissingRequiredTlv(UInt64(4)))
if (records.get[EncryptedRecipientData].isEmpty) return Left(MissingRequiredTlv(UInt64(10)))
if (records.get[TotalAmount].isEmpty) return Left(MissingRequiredTlv(UInt64(18)))
// Bolt 4: MUST return an error if the payload contains other tlv fields than `encrypted_recipient_data`, `current_path_key`, `amt_to_forward`, `outgoing_cltv_value` and `total_amount_msat`.
if (records.unknown.nonEmpty) return Left(ForbiddenTlv(records.unknown.head.tag))
records.records.find {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ object BlindedRouteData {

def validPaymentRecipientData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, TlvStream[RouteBlindingEncryptedDataTlv]] = {
if (records.get[PathId].isEmpty) return Left(MissingRequiredTlv(UInt64(6)))
if (records.get[PaymentConstraints].isEmpty) return Left(MissingRequiredTlv(UInt64(12)))
Right(records)
}

Expand Down
Loading

0 comments on commit 304290d

Please sign in to comment.