From 261ed8cfbb19fe9c6e59b50e887f76a76b0b8029 Mon Sep 17 00:00:00 2001 From: positiveblue Date: Wed, 14 Jun 2023 17:22:50 -0700 Subject: [PATCH] challenger: custom context for LightningClient methods Whenever we use the LightningClient from an LNC connection we need to add the macaroon to the headers. --- aperture.go | 4 ++-- challenger.go | 18 +++++++++++++++--- challenger_test.go | 3 ++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/aperture.go b/aperture.go index d2a7701..1033926 100644 --- a/aperture.go +++ b/aperture.go @@ -340,7 +340,7 @@ func (a *Aperture) Start(errChan chan error) error { } a.challenger, err = NewLndChallenger( - client, genInvoiceReq, errChan, + client, genInvoiceReq, nodeConn.CtxFunc, errChan, ) if err != nil { return err @@ -369,7 +369,7 @@ func (a *Aperture) Start(errChan chan error) error { } a.challenger, err = NewLndChallenger( - client, genInvoiceReq, errChan, + client, genInvoiceReq, context.Background, errChan, ) if err != nil { return err diff --git a/challenger.go b/challenger.go index 35ac0b1..0de1131 100644 --- a/challenger.go +++ b/challenger.go @@ -41,6 +41,7 @@ type InvoiceClient interface { // payment challenges. type LndChallenger struct { client InvoiceClient + clientCtx func() context.Context genInvoiceReq InvoiceRequestGenerator invoiceStates map[lntypes.Hash]lnrpc.Invoice_InvoiceState @@ -69,8 +70,15 @@ const ( // an lnd backend to create payment challenges. func NewLndChallenger(client InvoiceClient, genInvoiceReq InvoiceRequestGenerator, + ctxFunc func() context.Context, errChan chan<- error) (*LndChallenger, error) { + // Make sure we have a valid context function. This will be called to + // create a new context for each call to the lnd client. + if ctxFunc == nil { + ctxFunc = context.Background + } + if genInvoiceReq == nil { return nil, fmt.Errorf("genInvoiceReq cannot be nil") } @@ -78,6 +86,7 @@ func NewLndChallenger(client InvoiceClient, invoicesMtx := &sync.Mutex{} return &LndChallenger{ client: client, + clientCtx: ctxFunc, genInvoiceReq: genInvoiceReq, invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState), invoicesMtx: invoicesMtx, @@ -103,8 +112,9 @@ func (l *LndChallenger) Start() error { // cache. We need to keep track of all invoices, even quite old ones to // make sure tokens are valid. But to save space we only keep track of // an invoice's state. + ctx := l.clientCtx() invoiceResp, err := l.client.ListInvoices( - context.Background(), &lnrpc.ListInvoiceRequest{ + ctx, &lnrpc.ListInvoiceRequest{ NumMaxInvoices: math.MaxUint64, }, ) @@ -137,7 +147,7 @@ func (l *LndChallenger) Start() error { l.invoicesMtx.Unlock() // We need to be able to cancel any subscription we make. - ctxc, cancel := context.WithCancel(context.Background()) + ctxc, cancel := context.WithCancel(l.clientCtx()) l.invoicesCancel = cancel subscriptionResp, err := l.client.SubscribeInvoices( @@ -261,12 +271,14 @@ func (l *LndChallenger) NewChallenge(price int64) (string, lntypes.Hash, error) log.Errorf("Error generating invoice request: %v", err) return "", lntypes.ZeroHash, err } - ctx := context.Background() + + ctx := l.clientCtx() response, err := l.client.AddInvoice(ctx, invoice) if err != nil { log.Errorf("Error adding invoice: %v", err) return "", lntypes.ZeroHash, err } + paymentHash, err := lntypes.MakeHash(response.RHash) if err != nil { log.Errorf("Error parsing payment hash: %v", err) diff --git a/challenger_test.go b/challenger_test.go index f0bf7e0..f7ab606 100644 --- a/challenger_test.go +++ b/challenger_test.go @@ -102,6 +102,7 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient, chan error) { mainErrChan := make(chan error) return &LndChallenger{ client: mockClient, + clientCtx: context.Background, genInvoiceReq: genInvoiceReq, invoiceStates: make(map[lntypes.Hash]lnrpc.Invoice_InvoiceState), quit: make(chan struct{}), @@ -130,7 +131,7 @@ func TestLndChallenger(t *testing.T) { // First of all, test that the NewLndChallenger doesn't allow a nil // invoice generator function. errChan := make(chan error) - _, err := NewLndChallenger(nil, nil, errChan) + _, err := NewLndChallenger(nil, nil, nil, errChan) require.Error(t, err) // Now mock the lnd backend and create a challenger instance that we can