diff --git a/routing/graph.go b/routing/graph.go index f0f6efcfa7..f07ed4730b 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -33,6 +33,59 @@ type ReadOnlyGraph interface { Graph } +type GraphSessionConstructor interface { + NewSession() (GraphSession, error) +} + +type GraphSession interface { + Graph() routingGraph + Close() +} + +func NewGraphSessionConstructor(graph ReadOnlyGraph, + sourceNode route.Vertex) GraphSessionConstructor { + + return &readLockGraphSessConstructor{ + sourceNode: sourceNode, + graph: graph, + } +} + +type readLockGraphSessConstructor struct { + sourceNode route.Vertex + graph ReadOnlyGraph +} + +func (r *readLockGraphSessConstructor) NewSession() (GraphSession, error) { + cachedGraph, err := NewCachedGraph(r.sourceNode, r.graph, true) + if err != nil { + return nil, err + } + + return &readLockGraphSession{ + cachedGraph: cachedGraph, + }, nil +} + +var _ GraphSessionConstructor = (*readLockGraphSessConstructor)(nil) + +type readLockGraphSession struct { + cachedGraph *CachedGraph +} + +func (r *readLockGraphSession) Graph() routingGraph { + return r.cachedGraph +} + +func (r *readLockGraphSession) Close() { + err := r.cachedGraph.Close() + if err != nil { + log.Errorf("Error closing db tx: %v", err) + } +} + +var _ GraphSession = (*readLockGraphSession)(nil) + // Graph describes the API necessary for a graph source to have in order to be // used by the Router for pathfinding. type Graph interface { diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 95a5eaf65f..2a92a36b9c 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -201,9 +201,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, session, err := newPaymentSession( &payment, c.graph.source.pubkey, getBandwidthHints, - func() (routingGraph, func(), error) { - return c.graph, func() {}, nil - }, + &mockGraphSessionConstructor{graph: c.graph}, mc, c.pathFindingCfg, ) if err != nil { @@ -292,6 +290,29 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, return attempts, nil } +type mockGraphSessionConstructor struct { + graph routingGraph +} + +func (m *mockGraphSessionConstructor) NewSession() (GraphSession, error) { + return &mockGraphSession{graph: m.graph}, nil +} + +var _ GraphSessionConstructor = (*mockGraphSessionConstructor)(nil) + +type mockGraphSession struct { + graph routingGraph +} + +func (m *mockGraphSession) Graph() routingGraph { + return m.graph +} + +func (m *mockGraphSession) Close() { +} + +var _ GraphSession = (*mockGraphSession)(nil) + // getNodeIndex returns the zero-based index of the given node in the route. func getNodeIndex(route *route.Route, failureSource route.Vertex) *int { if failureSource == route.SourcePubKey { diff --git a/routing/payment_session.go b/routing/payment_session.go index e3a9110ad3..2d49e459ac 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -176,7 +176,7 @@ type paymentSession struct { pathFinder pathFinder - getRoutingGraph func() (routingGraph, func(), error) + graphSessionConstructor GraphSessionConstructor // pathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probability. @@ -197,7 +197,7 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, selfNode route.Vertex, getBandwidthHints func(routingGraph) (bandwidthHints, error), - getRoutingGraph func() (routingGraph, func(), error), + graphSessionConstructor GraphSessionConstructor, missionControl MissionController, pathFindingConfig PathFindingConfig) ( *paymentSession, error) { @@ -209,16 +209,16 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier()) return &paymentSession{ - selfNode: selfNode, - additionalEdges: edges, - getBandwidthHints: getBandwidthHints, - payment: p, - pathFinder: findPath, - getRoutingGraph: getRoutingGraph, - pathFindingConfig: pathFindingConfig, - missionControl: missionControl, - minShardAmt: DefaultShardMinAmt, - log: build.NewPrefixLog(logPrefix, log), + selfNode: selfNode, + additionalEdges: edges, + getBandwidthHints: getBandwidthHints, + payment: p, + pathFinder: findPath, + graphSessionConstructor: graphSessionConstructor, + pathFindingConfig: pathFindingConfig, + missionControl: missionControl, + minShardAmt: DefaultShardMinAmt, + log: build.NewPrefixLog(logPrefix, log), }, nil } @@ -281,12 +281,14 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, } for { - // Get a routing graph. - routingGraph, cleanup, err := p.getRoutingGraph() + // Start a new routing graph session. + graphSession, err := p.graphSessionConstructor.NewSession() if err != nil { return nil, err } + routingGraph := graphSession.Graph() + // We'll also obtain a set of bandwidthHints from the lower // layer for each of our outbound channels. This will allow the // path finding to skip any links that aren't active or just @@ -312,8 +314,8 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, maxAmt, p.payment.TimePref, finalHtlcExpiry, ) - // Close routing graph. - cleanup() + // Close routing graph session. + graphSession.Close() switch { case err == errNoPathFound: diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index e28392baef..1255840416 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -16,11 +16,11 @@ var _ PaymentSessionSource = (*SessionSource)(nil) // SessionSource defines a source for the router to retrieve new payment // sessions. type SessionSource struct { - // RoutingGraph provides a Graph that can be used for path finding for a - // specific payment. If the NewPathFindingTx method is called to obtain - // a read-only lock on the graph, then the clean-up all-back must be - // called once path-finding is complete. - RoutingGraph ReadOnlyGraph + // GraphSessionConstructor can be used to create a new GraphSession + // which can then be used to interact with a Graph for path finding for + // a specific payment. Close must be called on the GraphSession once + // path-finding is complete. + GraphSessionConstructor GraphSessionConstructor // SourceNode is the graph's source node. SourceNode *channeldb.LightningNode @@ -46,23 +46,6 @@ type SessionSource struct { PathFindingConfig PathFindingConfig } -// getRoutingGraph returns a routing graph and a clean-up function for -// pathfinding. -func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { - routingTx, err := NewCachedGraph( - m.SourceNode.PubKeyBytes, m.RoutingGraph, true, - ) - if err != nil { - return nil, nil, err - } - return routingTx, func() { - err := routingTx.Close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }, nil -} - // NewPaymentSession creates a new payment session backed by the latest prune // view from Mission Control. An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the @@ -78,7 +61,8 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( session, err := newPaymentSession( p, m.SourceNode.PubKeyBytes, getBandwidthHints, - m.getRoutingGraph, m.MissionControl, m.PathFindingConfig, + m.GraphSessionConstructor, m.MissionControl, + m.PathFindingConfig, ) if err != nil { return nil, err diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index b7efed5b7c..79a3638e18 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -119,9 +119,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { func(routingGraph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - func() (routingGraph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + &mockGraphSessionConstructor{graph: &sessionGraph{}}, &MissionControl{}, PathFindingConfig{}, ) @@ -199,9 +197,7 @@ func TestRequestRoute(t *testing.T) { func(routingGraph) (bandwidthHints, error) { return &mockBandwidthHints{}, nil }, - func() (routingGraph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + &mockGraphSessionConstructor{graph: &sessionGraph{}}, &MissionControl{}, PathFindingConfig{}, ) diff --git a/routing/router.go b/routing/router.go index 982125146e..96b12c9749 100644 --- a/routing/router.go +++ b/routing/router.go @@ -220,7 +220,7 @@ type Config struct { SelfNode route.Vertex // RoutingGraph is a graph source that will be used for pathfinding. - RoutingGraph ReadOnlyGraph + RoutingGraph Graph // Chain is the router's source to the most up-to-date blockchain data. // All incoming advertised channels will be checked against the chain @@ -319,17 +319,13 @@ type ChannelRouter struct { // channel graph is a subset of the UTXO set) set, then the router will proceed // to fully sync to the latest state of the UTXO set. func New(cfg Config) (*ChannelRouter, error) { - graph, err := NewCachedGraph( - cfg.SelfNode, cfg.RoutingGraph, false, - ) - if err != nil { - return nil, err - } - return &ChannelRouter{ - cfg: &cfg, - cachedGraph: graph, - quit: make(chan struct{}), + cfg: &cfg, + cachedGraph: &CachedGraph{ + graph: cfg.RoutingGraph, + source: cfg.SelfNode, + }, + quit: make(chan struct{}), }, nil } diff --git a/routing/router_test.go b/routing/router_test.go index 931fb48e30..0ab5081c0b 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -132,7 +132,12 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, sourceNode, err := graphInstance.graph.SourceNode() require.NoError(t, err) sessionSource := &SessionSource{ - RoutingGraph: graphInstance.graph, + GraphSessionConstructor: &mockGraphSessionConstructor{ + graph: &CachedGraph{ + graph: graphInstance.graph, + source: sourceNode.PubKeyBytes, + }, + }, SourceNode: sourceNode, GetLink: graphInstance.getLink, PathFindingConfig: pathFindingConfig, diff --git a/server.go b/server.go index bd8249848a..c544ef4f12 100644 --- a/server.go +++ b/server.go @@ -959,7 +959,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, fmt.Errorf("error getting source node: %w", err) } paymentSessionSource := &routing.SessionSource{ - RoutingGraph: chanGraph, + GraphSessionConstructor: routing.NewGraphSessionConstructor( + chanGraph, sourceNode.PubKeyBytes, + ), SourceNode: sourceNode, MissionControl: s.missionControl, GetLink: s.htlcSwitch.GetLinkByShortID,