Skip to content

Commit

Permalink
temp: more indirection yay
Browse files Browse the repository at this point in the history
  • Loading branch information
ellemouton committed Jun 20, 2024
1 parent 5f00b80 commit 25593ae
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 61 deletions.
53 changes: 53 additions & 0 deletions routing/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,59 @@ type ReadOnlyGraph interface {
Graph
}

type GraphSessionConstructor interface {
NewSession() (GraphSession, error)
}

type GraphSession interface {
Graph() routingGraph
Close()
}

func NewGraphSessionConstructor(graph ReadOnlyGraph,
sourceNode route.Vertex) GraphSessionConstructor {

return &readLockGraphSessConstructor{
sourceNode: sourceNode,
graph: graph,
}
}

type readLockGraphSessConstructor struct {
sourceNode route.Vertex
graph ReadOnlyGraph
}

func (r *readLockGraphSessConstructor) NewSession() (GraphSession, error) {
cachedGraph, err := NewCachedGraph(r.sourceNode, r.graph, true)
if err != nil {
return nil, err
}

return &readLockGraphSession{
cachedGraph: cachedGraph,
}, nil
}

var _ GraphSessionConstructor = (*readLockGraphSessConstructor)(nil)

type readLockGraphSession struct {
cachedGraph *CachedGraph
}

func (r *readLockGraphSession) Graph() routingGraph {
return r.cachedGraph
}

func (r *readLockGraphSession) Close() {
err := r.cachedGraph.Close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}

var _ GraphSession = (*readLockGraphSession)(nil)

// Graph describes the API necessary for a graph source to have in order to be
// used by the Router for pathfinding.
type Graph interface {
Expand Down
27 changes: 24 additions & 3 deletions routing/integrated_routing_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,

session, err := newPaymentSession(
&payment, c.graph.source.pubkey, getBandwidthHints,
func() (routingGraph, func(), error) {
return c.graph, func() {}, nil
},
&mockGraphSessionConstructor{graph: c.graph},
mc, c.pathFindingCfg,
)
if err != nil {
Expand Down Expand Up @@ -292,6 +290,29 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
return attempts, nil
}

type mockGraphSessionConstructor struct {
graph routingGraph
}

func (m *mockGraphSessionConstructor) NewSession() (GraphSession, error) {
return &mockGraphSession{graph: m.graph}, nil
}

var _ GraphSessionConstructor = (*mockGraphSessionConstructor)(nil)

type mockGraphSession struct {
graph routingGraph
}

func (m *mockGraphSession) Graph() routingGraph {
return m.graph
}

func (m *mockGraphSession) Close() {
}

var _ GraphSession = (*mockGraphSession)(nil)

// getNodeIndex returns the zero-based index of the given node in the route.
func getNodeIndex(route *route.Route, failureSource route.Vertex) *int {
if failureSource == route.SourcePubKey {
Expand Down
34 changes: 18 additions & 16 deletions routing/payment_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ type paymentSession struct {

pathFinder pathFinder

getRoutingGraph func() (routingGraph, func(), error)
graphSessionConstructor GraphSessionConstructor

// pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probability.
Expand All @@ -197,7 +197,7 @@ type paymentSession struct {
// newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment, selfNode route.Vertex,
getBandwidthHints func(routingGraph) (bandwidthHints, error),
getRoutingGraph func() (routingGraph, func(), error),
graphSessionConstructor GraphSessionConstructor,
missionControl MissionController, pathFindingConfig PathFindingConfig) (
*paymentSession, error) {

Expand All @@ -209,16 +209,16 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex,
logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier())

return &paymentSession{
selfNode: selfNode,
additionalEdges: edges,
getBandwidthHints: getBandwidthHints,
payment: p,
pathFinder: findPath,
getRoutingGraph: getRoutingGraph,
pathFindingConfig: pathFindingConfig,
missionControl: missionControl,
minShardAmt: DefaultShardMinAmt,
log: build.NewPrefixLog(logPrefix, log),
selfNode: selfNode,
additionalEdges: edges,
getBandwidthHints: getBandwidthHints,
payment: p,
pathFinder: findPath,
graphSessionConstructor: graphSessionConstructor,
pathFindingConfig: pathFindingConfig,
missionControl: missionControl,
minShardAmt: DefaultShardMinAmt,
log: build.NewPrefixLog(logPrefix, log),
}, nil
}

Expand Down Expand Up @@ -281,12 +281,14 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
}

for {
// Get a routing graph.
routingGraph, cleanup, err := p.getRoutingGraph()
// Start a new routing graph session.
graphSession, err := p.graphSessionConstructor.NewSession()
if err != nil {
return nil, err
}

routingGraph := graphSession.Graph()

// We'll also obtain a set of bandwidthHints from the lower
// layer for each of our outbound channels. This will allow the
// path finding to skip any links that aren't active or just
Expand All @@ -312,8 +314,8 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
maxAmt, p.payment.TimePref, finalHtlcExpiry,
)

// Close routing graph.
cleanup()
// Close routing graph session.
graphSession.Close()

switch {
case err == errNoPathFound:
Expand Down
30 changes: 7 additions & 23 deletions routing/payment_session_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ var _ PaymentSessionSource = (*SessionSource)(nil)
// SessionSource defines a source for the router to retrieve new payment
// sessions.
type SessionSource struct {
// RoutingGraph provides a Graph that can be used for path finding for a
// specific payment. If the NewPathFindingTx method is called to obtain
// a read-only lock on the graph, then the clean-up all-back must be
// called once path-finding is complete.
RoutingGraph ReadOnlyGraph
// GraphSessionConstructor can be used to create a new GraphSession
// which can then be used to interact with a Graph for path finding for
// a specific payment. Close must be called on the GraphSession once
// path-finding is complete.
GraphSessionConstructor GraphSessionConstructor

// SourceNode is the graph's source node.
SourceNode *channeldb.LightningNode
Expand All @@ -46,23 +46,6 @@ type SessionSource struct {
PathFindingConfig PathFindingConfig
}

// getRoutingGraph returns a routing graph and a clean-up function for
// pathfinding.
func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
routingTx, err := NewCachedGraph(
m.SourceNode.PubKeyBytes, m.RoutingGraph, true,
)
if err != nil {
return nil, nil, err
}
return routingTx, func() {
err := routingTx.Close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}, nil
}

// NewPaymentSession creates a new payment session backed by the latest prune
// view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the
Expand All @@ -78,7 +61,8 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) (

session, err := newPaymentSession(
p, m.SourceNode.PubKeyBytes, getBandwidthHints,
m.getRoutingGraph, m.MissionControl, m.PathFindingConfig,
m.GraphSessionConstructor, m.MissionControl,
m.PathFindingConfig,
)
if err != nil {
return nil, err
Expand Down
8 changes: 2 additions & 6 deletions routing/payment_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ func TestUpdateAdditionalEdge(t *testing.T) {
func(routingGraph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
&mockGraphSessionConstructor{graph: &sessionGraph{}},
&MissionControl{},
PathFindingConfig{},
)
Expand Down Expand Up @@ -199,9 +197,7 @@ func TestRequestRoute(t *testing.T) {
func(routingGraph) (bandwidthHints, error) {
return &mockBandwidthHints{}, nil
},
func() (routingGraph, func(), error) {
return &sessionGraph{}, func() {}, nil
},
&mockGraphSessionConstructor{graph: &sessionGraph{}},
&MissionControl{},
PathFindingConfig{},
)
Expand Down
18 changes: 7 additions & 11 deletions routing/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ type Config struct {
SelfNode route.Vertex

// RoutingGraph is a graph source that will be used for pathfinding.
RoutingGraph ReadOnlyGraph
RoutingGraph Graph

// Chain is the router's source to the most up-to-date blockchain data.
// All incoming advertised channels will be checked against the chain
Expand Down Expand Up @@ -319,17 +319,13 @@ type ChannelRouter struct {
// channel graph is a subset of the UTXO set) set, then the router will proceed
// to fully sync to the latest state of the UTXO set.
func New(cfg Config) (*ChannelRouter, error) {
graph, err := NewCachedGraph(
cfg.SelfNode, cfg.RoutingGraph, false,
)
if err != nil {
return nil, err
}

return &ChannelRouter{
cfg: &cfg,
cachedGraph: graph,
quit: make(chan struct{}),
cfg: &cfg,
cachedGraph: &CachedGraph{
graph: cfg.RoutingGraph,
source: cfg.SelfNode,
},
quit: make(chan struct{}),
}, nil
}

Expand Down
7 changes: 6 additions & 1 deletion routing/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,12 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
sourceNode, err := graphInstance.graph.SourceNode()
require.NoError(t, err)
sessionSource := &SessionSource{
RoutingGraph: graphInstance.graph,
GraphSessionConstructor: &mockGraphSessionConstructor{
graph: &CachedGraph{
graph: graphInstance.graph,
source: sourceNode.PubKeyBytes,
},
},
SourceNode: sourceNode,
GetLink: graphInstance.getLink,
PathFindingConfig: pathFindingConfig,
Expand Down
4 changes: 3 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
return nil, fmt.Errorf("error getting source node: %w", err)
}
paymentSessionSource := &routing.SessionSource{
RoutingGraph: chanGraph,
GraphSessionConstructor: routing.NewGraphSessionConstructor(
chanGraph, sourceNode.PubKeyBytes,
),
SourceNode: sourceNode,
MissionControl: s.missionControl,
GetLink: s.htlcSwitch.GetLinkByShortID,
Expand Down

0 comments on commit 25593ae

Please sign in to comment.