Skip to content

Commit

Permalink
challenger: custom context for LightningClient methods
Browse files Browse the repository at this point in the history
Whenever we use the LightningClient from an LNC connection we need to
add the macaroon to the headers.
  • Loading branch information
positiveblue committed Jun 27, 2023
1 parent 5d6f1d1 commit 8bea271
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
4 changes: 2 additions & 2 deletions aperture.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func (a *Aperture) Start(errChan chan error) error {
}

a.challenger, err = NewLndChallenger(
client, genInvoiceReq, errChan,
client, genInvoiceReq, nodeConn.CtxFunc, errChan,
)
if err != nil {
return err
Expand Down Expand Up @@ -369,7 +369,7 @@ func (a *Aperture) Start(errChan chan error) error {
}

a.challenger, err = NewLndChallenger(
client, genInvoiceReq, errChan,
client, genInvoiceReq, context.Background, errChan,
)
if err != nil {
return err
Expand Down
18 changes: 15 additions & 3 deletions challenger.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type InvoiceClient interface {
// payment challenges.
type LndChallenger struct {
client InvoiceClient
clientCtx func() context.Context
genInvoiceReq InvoiceRequestGenerator

invoiceStates map[lntypes.Hash]lnrpc.Invoice_InvoiceState
Expand Down Expand Up @@ -69,15 +70,23 @@ const (
// an lnd backend to create payment challenges.
func NewLndChallenger(client InvoiceClient,
genInvoiceReq InvoiceRequestGenerator,
ctxFunc func() context.Context,
errChan chan<- error) (*LndChallenger, error) {

// Make sure we have a valid context function. This will be called to
// create a new context for each call to the lnd client.
if ctxFunc == nil {
ctxFunc = context.Background
}

if genInvoiceReq == nil {
return nil, fmt.Errorf("genInvoiceReq cannot be nil")
}

invoicesMtx := &sync.Mutex{}
return &LndChallenger{
client: client,
clientCtx: ctxFunc,
genInvoiceReq: genInvoiceReq,
invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState),
invoicesMtx: invoicesMtx,
Expand All @@ -103,8 +112,9 @@ func (l *LndChallenger) Start() error {
// cache. We need to keep track of all invoices, even quite old ones to
// make sure tokens are valid. But to save space we only keep track of
// an invoice's state.
ctx := l.clientCtx()
invoiceResp, err := l.client.ListInvoices(
context.Background(), &lnrpc.ListInvoiceRequest{
ctx, &lnrpc.ListInvoiceRequest{
NumMaxInvoices: math.MaxUint64,
},
)
Expand Down Expand Up @@ -137,7 +147,7 @@ func (l *LndChallenger) Start() error {
l.invoicesMtx.Unlock()

// We need to be able to cancel any subscription we make.
ctxc, cancel := context.WithCancel(context.Background())
ctxc, cancel := context.WithCancel(l.clientCtx())
l.invoicesCancel = cancel

subscriptionResp, err := l.client.SubscribeInvoices(
Expand Down Expand Up @@ -261,12 +271,14 @@ func (l *LndChallenger) NewChallenge(price int64) (string, lntypes.Hash, error)
log.Errorf("Error generating invoice request: %v", err)
return "", lntypes.ZeroHash, err
}
ctx := context.Background()

ctx := l.clientCtx()
response, err := l.client.AddInvoice(ctx, invoice)
if err != nil {
log.Errorf("Error adding invoice: %v", err)
return "", lntypes.ZeroHash, err
}

paymentHash, err := lntypes.MakeHash(response.RHash)
if err != nil {
log.Errorf("Error parsing payment hash: %v", err)
Expand Down
3 changes: 2 additions & 1 deletion challenger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient, chan error) {
mainErrChan := make(chan error)
return &LndChallenger{
client: mockClient,
clientCtx: context.Background,
genInvoiceReq: genInvoiceReq,
invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState),
quit: make(chan struct{}),
Expand Down Expand Up @@ -130,7 +131,7 @@ func TestLndChallenger(t *testing.T) {
// First of all, test that the NewLndChallenger doesn't allow a nil
// invoice generator function.
errChan := make(chan error)
_, err := NewLndChallenger(nil, nil, errChan)
_, err := NewLndChallenger(nil, nil, nil, errChan)
require.Error(t, err)

// Now mock the lnd backend and create a challenger instance that we can
Expand Down

0 comments on commit 8bea271

Please sign in to comment.