Skip to content

Commit

Permalink
Payment lifecycle refactor (ACINQ#1414)
Browse files Browse the repository at this point in the history
* Extract faulty channels selection from PaymentLifecycle

Move the logic of figuring out which channels/nodes should be ignored
when retrying after a payment failure out of the PaymentLifecycle.

We can figure this out by looking only at the generated `PaymentFailure`,
and the multi-part logic could leverage these helpers.

* Refactor RouteResponse

It was useless to return `ignoreNodes` and `ignoreChannels`: it is rather
the responsibility of the caller (PaymentLifecycle) to store and update
these sets.

Preparing for the MPP move inside the router, we introduce a Route class
and let RouteResponse return a collection of Routes.

This creates some ugliness in PaymentLifecycle because of the `routePrefix`,
but this is just temporary: the `routePrefix` "hack" will be removed soon.
  • Loading branch information
t-bast authored May 12, 2020
1 parent ba4cca2 commit c4d0604
Show file tree
Hide file tree
Showing 19 changed files with 294 additions and 210 deletions.
2 changes: 1 addition & 1 deletion eclair-core/src/main/scala/fr/acinq/eclair/PimpKamon.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object KamonExt {
*/
def failSpan(span: Span, failure: PaymentFailure) = {
failure match {
case LocalFailure(t) => span.fail("local failure", t)
case LocalFailure(_, t) => span.fail("local failure", t)
case RemoteFailure(_, e) => span.fail(s"remote failure: origin=${e.originNode} error=${e.failureMessage}")
case UnreadableRemoteFailure(_) => span.fail("unreadable remote failure")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import fr.acinq.eclair.payment._
import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop}
import fr.acinq.eclair.{MilliSatoshi, ShortChannelId}

import scala.compat.Platform

trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable

trait IncomingPaymentsDb {
Expand Down Expand Up @@ -209,7 +207,7 @@ object FailureType extends Enumeration {

object FailureSummary {
def apply(f: PaymentFailure): FailureSummary = f match {
case LocalFailure(t) => FailureSummary(FailureType.LOCAL, t.getMessage, Nil)
case LocalFailure(route, t) => FailureSummary(FailureType.LOCAL, t.getMessage, route.map(h => HopSummary(h)).toList)
case RemoteFailure(route, e) => FailureSummary(FailureType.REMOTE, e.failureMessage.message, route.map(h => HopSummary(h)).toList)
case UnreadableRemoteFailure(route) => FailureSummary(FailureType.UNREADABLE_REMOTE, "could not decrypt failure onion", route.map(h => HopSummary(h)).toList)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ object Monitoring {
}

def apply(pf: PaymentFailure): String = pf match {
case LocalFailure(t) => t.getClass.getSimpleName
case LocalFailure(_, t) => t.getClass.getSimpleName
case RemoteFailure(_, e) => e.failureMessage.getClass.getSimpleName
case UnreadableRemoteFailure(_) => "UnreadableRemoteFailure"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.MilliSatoshi
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.router.Router.Hop

import scala.compat.Platform
import fr.acinq.eclair.router.Announcements
import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Hop}
import fr.acinq.eclair.wire.Node

/**
* Created by PM on 01/02/2017.
Expand Down Expand Up @@ -112,10 +112,12 @@ object PaymentReceived {

case class PaymentSettlingOnChain(id: UUID, amount: MilliSatoshi, paymentHash: ByteVector32, timestamp: Long = System.currentTimeMillis) extends PaymentEvent

sealed trait PaymentFailure
sealed trait PaymentFailure {
def route: Seq[Hop]
}

/** A failure happened locally, preventing the payment from being sent (e.g. no route found). */
case class LocalFailure(t: Throwable) extends PaymentFailure
case class LocalFailure(route: Seq[Hop], t: Throwable) extends PaymentFailure

/** A remote node failed the payment and we were able to decrypt the onion failure packet. */
case class RemoteFailure(route: Seq[Hop], e: Sphinx.DecryptedFailurePacket) extends PaymentFailure
Expand All @@ -142,10 +144,10 @@ object PaymentFailure {
*/
def transformForUser(failures: Seq[PaymentFailure]): Seq[PaymentFailure] = {
failures.map {
case LocalFailure(AddHtlcFailed(_, _, t, _, _, _)) => LocalFailure(t) // we're interested in the error which caused the add-htlc to fail
case LocalFailure(hops, AddHtlcFailed(_, _, t, _, _, _)) => LocalFailure(hops, t) // we're interested in the error which caused the add-htlc to fail
case other => other
} match {
case previousFailures :+ LocalFailure(RouteNotFound) if previousFailures.nonEmpty => previousFailures
case previousFailures :+ LocalFailure(_, RouteNotFound) if previousFailures.nonEmpty => previousFailures
case other => other
}
}
Expand All @@ -159,4 +161,44 @@ object PaymentFailure {
.collectFirst { case RemoteFailure(_, Sphinx.DecryptedFailurePacket(origin, u: Update)) if origin == nodeId => u.update }
.isDefined

/**
 * Update the set of nodes and channels to ignore in retries depending on the failure we received.
 *
 * The input sets are never shrunk: each case either returns them unchanged or adds entries to one of them.
 *
 * @param failure        the payment failure to analyze.
 * @param ignoreNodes    nodes that are already ignored.
 * @param ignoreChannels channels that are already ignored.
 * @return the updated (nodes, channels) sets to ignore in the next attempt.
 */
def updateIgnored(failure: PaymentFailure, ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc]): (Set[PublicKey], Set[ChannelDesc]) = failure match {
  // NB: the order of the cases matters, the first one that matches wins.
  // We use lastOption instead of last: a failure carrying an empty route must not throw here, it simply cannot
  // come from "the last hop's next node" and falls through to the cases below.
  case RemoteFailure(hops, Sphinx.DecryptedFailurePacket(nodeId, _)) if hops.lastOption.exists(lastHop => nodeId == lastHop.nextNodeId) =>
    // The failure came from the final recipient: the payment should be aborted without penalizing anyone in the route.
    (ignoreNodes, ignoreChannels)
  case RemoteFailure(_, Sphinx.DecryptedFailurePacket(nodeId, _: Node)) =>
    // Node-level failure: ignore the node that reported it.
    (ignoreNodes + nodeId, ignoreChannels)
  case RemoteFailure(_, Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Update)) =>
    if (Announcements.checkSig(failureMessage.update, nodeId)) {
      // We were using an outdated channel update, we should retry with the new one and nobody should be penalized.
      (ignoreNodes, ignoreChannels)
    } else {
      // This node is fishy, it gave us a bad signature, so let's filter it out.
      (ignoreNodes + nodeId, ignoreChannels)
    }
  case RemoteFailure(hops, Sphinx.DecryptedFailurePacket(nodeId, _)) =>
    // Let's ignore the channel outgoing from nodeId (only channel hops are considered).
    hops.collectFirst {
      case hop: ChannelHop if hop.nodeId == nodeId => ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId)
    } match {
      case Some(faultyChannel) => (ignoreNodes, ignoreChannels + faultyChannel)
      case None => (ignoreNodes, ignoreChannels)
    }
  case UnreadableRemoteFailure(hops) =>
    // We don't know which node is sending garbage, let's blacklist all nodes except the one we are directly connected to and the final recipient.
    val blacklist = hops.map(_.nextNodeId).drop(1).dropRight(1)
    (ignoreNodes ++ blacklist, ignoreChannels)
  case LocalFailure(hops, _) => hops.headOption match {
    case Some(hop: ChannelHop) =>
      // The failure happened locally: ignore the first channel of the route (if any).
      val faultyChannel = ChannelDesc(hop.lastUpdate.shortChannelId, hop.nodeId, hop.nextNodeId)
      (ignoreNodes, ignoreChannels + faultyChannel)
    case _ => (ignoreNodes, ignoreChannels)
  }
}

/**
 * Update the set of nodes and channels to ignore in retries depending on the failures we received.
 * Failures are processed in order, each one starting from the sets produced by the previous one.
 */
def updateIgnored(failures: Seq[PaymentFailure], ignoreNodes: Set[PublicKey], ignoreChannels: Set[ChannelDesc]): (Set[PublicKey], Set[ChannelDesc]) = {
  val initial = (ignoreNodes, ignoreChannels)
  failures.foldLeft(initial) { (ignored, failure) =>
    val (nodes, channels) = ignored
    updateIgnored(failure, nodes, channels)
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class NodeRelayer(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, c

object NodeRelayer {

def props(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, commandBuffer: ActorRef, register: ActorRef) = Props(classOf[NodeRelayer], nodeParams, relayer, router, commandBuffer, register)
def props(nodeParams: NodeParams, relayer: ActorRef, router: ActorRef, commandBuffer: ActorRef, register: ActorRef) = Props(new NodeRelayer(nodeParams, relayer, router, commandBuffer, register))

/**
* We start by aggregating an incoming HTLC set. Once we received the whole set, we will compute a route to the next
Expand Down Expand Up @@ -260,13 +260,13 @@ object NodeRelayer {
*/
private def translateError(failures: Seq[PaymentFailure], outgoingNodeId: PublicKey): Option[FailureMessage] = {
def tooManyRouteNotFound(failures: Seq[PaymentFailure]): Boolean = {
val routeNotFoundCount = failures.count(_ == LocalFailure(RouteNotFound))
val routeNotFoundCount = failures.collect { case f@LocalFailure(_, RouteNotFound) => f }.length
routeNotFoundCount > failures.length / 2
}

failures match {
case Nil => None
case LocalFailure(PaymentError.BalanceTooLow) :: Nil => Some(TemporaryNodeFailure) // we don't have enough outgoing liquidity at the moment
case LocalFailure(_, PaymentError.BalanceTooLow) :: Nil => Some(TemporaryNodeFailure) // we don't have enough outgoing liquidity at the moment
case _ if tooManyRouteNotFound(failures) => Some(TrampolineFeeInsufficient) // if we couldn't find routes, it's likely that the fee/cltv was insufficient
case _ =>
// Otherwise, we try to find a downstream error that we could decrypt.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import fr.acinq.eclair.payment._
import fr.acinq.eclair.payment.relay.Relayer.{GetOutgoingChannels, OutgoingChannel, OutgoingChannels}
import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
import fr.acinq.eclair.payment.send.PaymentLifecycle.SendPayment
import fr.acinq.eclair.router.Router.{ChannelHop, GetNetworkStats, GetNetworkStatsResponse, RouteParams, TickComputeNetworkStats}
import fr.acinq.eclair.router.Router._
import fr.acinq.eclair.router._
import fr.acinq.eclair.wire._
import fr.acinq.eclair.{CltvExpiry, FSMDiagnosticActorLogging, Logs, LongToBtcAmount, MilliSatoshi, NodeParams, ShortChannelId, ToMilliSatoshiConversion}
Expand All @@ -41,7 +41,6 @@ import kamon.context.Context
import scodec.bits.ByteVector

import scala.annotation.tailrec
import scala.compat.Platform
import scala.util.Random

/**
Expand Down Expand Up @@ -96,8 +95,8 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
val (remaining, payments) = splitPayment(nodeParams, d.request.totalAmount, channels, d.networkStats, d.request, randomize = false)
if (remaining > 0.msat) {
log.warning(s"cannot send ${d.request.totalAmount} with our current balance")
Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(PaymentError.BalanceTooLow)))
goto(PAYMENT_ABORTED) using PaymentAborted(d.sender, d.request, LocalFailure(PaymentError.BalanceTooLow) :: Nil, Set.empty)
Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(Nil, PaymentError.BalanceTooLow)))
goto(PAYMENT_ABORTED) using PaymentAborted(d.sender, d.request, LocalFailure(Nil, PaymentError.BalanceTooLow) :: Nil, Set.empty)
} else {
val pending = setFees(d.request.routeParams, payments, payments.size)
Kamon.runWithContextEntry(parentPaymentIdKey, cfg.parentId) {
Expand Down Expand Up @@ -155,8 +154,8 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
val (remaining, payments) = splitPayment(nodeParams, d.toSend, filteredChannels, d.networkStats, d.request, randomize = true) // we randomize channel selection when we retry
if (remaining > 0.msat) {
log.warning(s"cannot send ${d.toSend} with our current balance")
Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(PaymentError.BalanceTooLow)))
goto(PAYMENT_ABORTED) using PaymentAborted(d.sender, d.request, d.failures :+ LocalFailure(PaymentError.BalanceTooLow), d.pending.keySet)
Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(LocalFailure(Nil, PaymentError.BalanceTooLow)))
goto(PAYMENT_ABORTED) using PaymentAborted(d.sender, d.request, d.failures :+ LocalFailure(Nil, PaymentError.BalanceTooLow), d.pending.keySet)
} else {
val pending = setFees(d.request.routeParams, payments, payments.size + d.pending.size)
pending.foreach { case (childId, payment) => spawnChildPaymentFsm(childId) ! payment }
Expand Down Expand Up @@ -270,7 +269,7 @@ class MultiPartPaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig,
if (isFromFinalRecipient) {
Some(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures, d.pending.keySet - pf.id))
} else if (d.remainingAttempts == 0) {
val failure = LocalFailure(PaymentError.RetryExhausted)
val failure = LocalFailure(Nil, PaymentError.RetryExhausted)
Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(failure))
Some(PaymentAborted(d.sender, d.request, d.failures ++ pf.failures :+ failure, d.pending.keySet - pf.id))
} else {
Expand Down Expand Up @@ -390,7 +389,7 @@ object MultiPartPaymentLifecycle {

/** If the payment failed immediately with a RouteNotFound, the channel we selected should be ignored in retries. */
private def shouldBlacklistChannel(pf: PaymentFailed): Boolean = pf.failures match {
case LocalFailure(RouteNotFound) :: Nil => true
case LocalFailure(_, RouteNotFound) :: Nil => true
case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import fr.acinq.eclair.payment._
import fr.acinq.eclair.payment.send.MultiPartPaymentLifecycle.SendMultiPartPayment
import fr.acinq.eclair.payment.send.PaymentError._
import fr.acinq.eclair.payment.send.PaymentLifecycle.{SendPayment, SendPaymentToRoute}
import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop, RouteParams}
import fr.acinq.eclair.router.Router.{Hop, NodeHop, Route, RouteParams}
import fr.acinq.eclair.wire.Onion.FinalLegacyPayload
import fr.acinq.eclair.wire._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, LongToBtcAmount, MilliSatoshi, NodeParams, randomBytes32}
Expand All @@ -50,13 +50,13 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR
val finalExpiry = r.finalExpiry(nodeParams.currentBlockHeight)
r.paymentRequest match {
case Some(invoice) if !invoice.features.supported =>
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(UnsupportedFeatures(invoice.features.bitmask)) :: Nil)
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(Nil, UnsupportedFeatures(invoice.features.bitmask)) :: Nil)
case Some(invoice) if invoice.features.allowMultiPart && Features.hasFeature(nodeParams.features, Features.BasicMultiPartPayment) =>
invoice.paymentSecret match {
case Some(paymentSecret) =>
spawnMultiPartPaymentFsm(paymentCfg) forward SendMultiPartPayment(paymentSecret, r.recipientNodeId, r.recipientAmount, finalExpiry, r.maxAttempts, r.assistedRoutes, r.routeParams, userCustomTlvs = r.userCustomTlvs)
case None =>
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(PaymentSecretMissing) :: Nil)
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(Nil, PaymentSecretMissing) :: Nil)
}
case _ =>
val paymentSecret = r.paymentRequest.flatMap(_.paymentSecret)
Expand All @@ -69,9 +69,9 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR
sender ! paymentId
r.trampolineAttempts match {
case Nil =>
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(TrampolineFeesMissing) :: Nil)
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(Nil, TrampolineFeesMissing) :: Nil)
case _ if !r.paymentRequest.features.allowTrampoline && r.paymentRequest.amount.isEmpty =>
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(TrampolineLegacyAmountLessInvoice) :: Nil)
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(Nil, TrampolineLegacyAmountLessInvoice) :: Nil)
case (trampolineFees, trampolineExpiryDelta) :: remainingAttempts =>
log.info(s"sending trampoline payment with trampoline fees=$trampolineFees and expiry delta=$trampolineExpiryDelta")
sendTrampolinePayment(paymentId, r, trampolineFees, trampolineExpiryDelta)
Expand Down Expand Up @@ -121,7 +121,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR
case None => payFsm forward SendPaymentToRoute(r.route, FinalLegacyPayload(r.recipientAmount, finalExpiry), r.paymentRequest.routingInfo)
}
case _ =>
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(TrampolineMultiNodeNotSupported) :: Nil)
sender ! PaymentFailed(paymentId, r.paymentHash, LocalFailure(Nil, TrampolineMultiNodeNotSupported) :: Nil)
}
}

Expand Down Expand Up @@ -160,7 +160,7 @@ class PaymentInitiator(nodeParams: NodeParams, router: ActorRef, relayer: ActorR

object PaymentInitiator {

def props(nodeParams: NodeParams, router: ActorRef, relayer: ActorRef, register: ActorRef) = Props(classOf[PaymentInitiator], nodeParams, router, relayer, register)
def props(nodeParams: NodeParams, router: ActorRef, relayer: ActorRef, register: ActorRef) = Props(new PaymentInitiator(nodeParams, router, relayer, register))

case class PendingPayment(sender: ActorRef, remainingAttempts: Seq[(MilliSatoshi, CltvExpiryDelta)], r: SendTrampolinePaymentRequest)

Expand Down Expand Up @@ -310,7 +310,7 @@ object PaymentInitiator {
storeInDb: Boolean, // e.g. for trampoline we don't want to store in the DB when we're relaying payments
publishEvent: Boolean,
additionalHops: Seq[NodeHop]) {
def fullRoute(hops: Seq[ChannelHop]): Seq[Hop] = hops ++ additionalHops
def fullRoute(route: Route): Seq[Hop] = route.hops ++ additionalHops

def createPaymentSent(preimage: ByteVector32, parts: Seq[PaymentSent.PartialPayment]) = PaymentSent(parentId, paymentHash, preimage, recipientAmount, recipientNodeId, parts)
}
Expand Down
Loading

0 comments on commit c4d0604

Please sign in to comment.