Skip to content

Commit

Permalink
context propagation: pass ctx to UpdateScenario() (#3258)
Browse files Browse the repository at this point in the history
* context propagation: pass ctx to UpdateScenario()

* context propagation: SendMetrics, SendUsageMetrics, plugin config
  • Loading branch information
mmetc authored Oct 2, 2024
1 parent 897613e commit 27451a5
Show file tree
Hide file tree
Showing 11 changed files with 33 additions and 33 deletions.
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/clicapi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func queryCAPIStatus(ctx context.Context, hub *cwhub.Hub, credURL string, login
// I don't believe papi is neede to check enrollement
// PapiURL: papiURL,
VersionPrefix: "v3",
UpdateScenario: func() ([]string, error) {
UpdateScenario: func(_ context.Context) ([]string, error) {
return itemsForAPI, nil
},
})
Expand Down
2 changes: 1 addition & 1 deletion cmd/crowdsec/lapiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func AuthenticatedLAPIClient(credentials csconfig.ApiCredentialsCfg, hub *cwhub.
URL: apiURL,
PapiURL: papiURL,
VersionPrefix: "v1",
UpdateScenario: func() ([]string, error) {
UpdateScenario: func(_ context.Context) ([]string, error) {
return itemsForAPI, nil
},
})
Expand Down
7 changes: 5 additions & 2 deletions pkg/apiclient/auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package apiclient

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -30,15 +31,17 @@ type JWTTransport struct {
// Transport is the underlying HTTP transport to use when making requests.
// It will default to http.DefaultTransport if nil.
Transport http.RoundTripper
UpdateScenario func() ([]string, error)
UpdateScenario func(context.Context) ([]string, error)
refreshTokenMutex sync.Mutex
}

func (t *JWTTransport) refreshJwtToken() error {
var err error

ctx := context.TODO()

if t.UpdateScenario != nil {
t.Scenarios, err = t.UpdateScenario()
t.Scenarios, err = t.UpdateScenario(ctx)
if err != nil {
return fmt.Errorf("can't update scenario list: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/apiclient/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apiclient

import (
"context"
"net/url"

"github.com/go-openapi/strfmt"
Expand All @@ -15,5 +16,5 @@ type Config struct {
VersionPrefix string
UserAgent string
RegistrationToken string
UpdateScenario func() ([]string, error)
UpdateScenario func(context.Context) ([]string, error)
}
10 changes: 4 additions & 6 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,9 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration {
return ret
}

func (a *apic) FetchScenariosListFromDB() ([]string, error) {
func (a *apic) FetchScenariosListFromDB(ctx context.Context) ([]string, error) {
scenarios := make([]string, 0)

ctx := context.TODO()

machines, err := a.dbClient.ListMachines(ctx)
if err != nil {
return nil, fmt.Errorf("while listing machines: %w", err)
Expand Down Expand Up @@ -214,7 +212,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient
return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err)
}

ret.scenarioList, err = ret.FetchScenariosListFromDB()
ret.scenarioList, err = ret.FetchScenariosListFromDB(ctx)
if err != nil {
return nil, fmt.Errorf("while fetching scenarios from db: %w", err)
}
Expand All @@ -234,7 +232,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient

// The watcher will be authenticated by the RoundTripper the first time it will call CAPI
// Explicit authentication will provoke a useless supplementary call to CAPI
scenarios, err := ret.FetchScenariosListFromDB()
scenarios, err := ret.FetchScenariosListFromDB(ctx)
if err != nil {
return ret, fmt.Errorf("get scenario in db: %w", err)
}
Expand Down Expand Up @@ -944,7 +942,7 @@ func (a *apic) Pull(ctx context.Context) error {
toldOnce := false

for {
scenario, err := a.FetchScenariosListFromDB()
scenario, err := a.FetchScenariosListFromDB(ctx)
if err != nil {
log.Errorf("unable to fetch scenarios from db: %s", err)
}
Expand Down
12 changes: 4 additions & 8 deletions pkg/apiserver/apic_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,9 @@ func (a *apic) fetchMachineIDs(ctx context.Context) ([]string, error) {
// Metrics are sent at start, then at the randomized metricsIntervalFirst,
// then at regular metricsInterval. If a change is detected in the list
// of machines, the next metrics are sent immediately.
func (a *apic) SendMetrics(stop chan (bool)) {
func (a *apic) SendMetrics(ctx context.Context, stop chan (bool)) {
defer trace.CatchPanic("lapi/metricsToAPIC")

ctx := context.TODO()

// verify the list of machines every <checkInt> interval
const checkInt = 20 * time.Second

Expand Down Expand Up @@ -321,7 +319,7 @@ func (a *apic) SendMetrics(stop chan (bool)) {
if metrics != nil {
log.Info("capi metrics: sending")

_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
_, _, err = a.apiClient.Metrics.Add(ctx, metrics)
if err != nil {
log.Errorf("capi metrics: failed: %s", err)
}
Expand All @@ -339,11 +337,9 @@ func (a *apic) SendMetrics(stop chan (bool)) {
}
}

func (a *apic) SendUsageMetrics() {
func (a *apic) SendUsageMetrics(ctx context.Context) {
defer trace.CatchPanic("lapi/usageMetricsToAPIC")

ctx := context.TODO()

firstRun := true

log.Debugf("Start sending usage metrics to CrowdSec Central API (interval: %s once, then %s)", a.usageMetricsIntervalFirst, a.usageMetricsInterval)
Expand All @@ -368,7 +364,7 @@ func (a *apic) SendUsageMetrics() {
continue
}

_, resp, err := a.apiClient.UsageMetrics.Add(context.Background(), metrics)
_, resp, err := a.apiClient.UsageMetrics.Add(ctx, metrics)
if err != nil {
log.Errorf("unable to send usage metrics: %s", err)

Expand Down
12 changes: 7 additions & 5 deletions pkg/apiserver/apic_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
)

func TestAPICSendMetrics(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
duration time.Duration
Expand All @@ -34,24 +36,24 @@ func TestAPICSendMetrics(t *testing.T) {
metricsInterval: time.Millisecond * 20,
expectedCalls: 5,
setUp: func(api *apic) {
api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
api.dbClient.Ent.Machine.Delete().ExecX(ctx)
api.dbClient.Ent.Machine.Create().
SetMachineId("1234").
SetPassword(testPassword.String()).
SetIpAddress("1.2.3.4").
SetScenarios("crowdsecurity/test").
SetLastPush(time.Time{}).
SetUpdatedAt(time.Time{}).
ExecX(context.Background())
ExecX(ctx)

api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
api.dbClient.Ent.Bouncer.Delete().ExecX(ctx)
api.dbClient.Ent.Bouncer.Create().
SetIPAddress("1.2.3.6").
SetName("someBouncer").
SetAPIKey("foobar").
SetRevoked(false).
SetLastPull(time.Time{}).
ExecX(context.Background())
ExecX(ctx)
},
},
}
Expand Down Expand Up @@ -86,7 +88,7 @@ func TestAPICSendMetrics(t *testing.T) {

httpmock.ZeroCallCounters()

go api.SendMetrics(stop)
go api.SendMetrics(ctx, stop)

time.Sleep(tc.duration)
stop <- true
Expand Down
6 changes: 4 additions & 2 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
}

func TestAPICFetchScenariosListFromDB(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
machineIDsWithScenarios map[string]string
Expand Down Expand Up @@ -174,10 +176,10 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
SetPassword(testPassword.String()).
SetIpAddress("1.2.3.4").
SetScenarios(scenarios).
ExecX(context.Background())
ExecX(ctx)
}

scenarios, err := api.FetchScenariosListFromDB()
scenarios, err := api.FetchScenariosListFromDB(ctx)
require.NoError(t, err)

for machineID := range tc.machineIDsWithScenarios {
Expand Down
4 changes: 2 additions & 2 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,12 @@ func (s *APIServer) initAPIC(ctx context.Context) {
}

s.apic.metricsTomb.Go(func() error {
s.apic.SendMetrics(make(chan bool))
s.apic.SendMetrics(ctx, make(chan bool))
return nil
})

s.apic.metricsTomb.Go(func() error {
s.apic.SendUsageMetrics()
s.apic.SendUsageMetrics(ctx)
return nil
})
}
Expand Down
4 changes: 1 addition & 3 deletions pkg/csplugin/notifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ func (m *GRPCClient) Notify(ctx context.Context, notification *protobufs.Notific
}

func (m *GRPCClient) Configure(ctx context.Context, config *protobufs.Config) (*protobufs.Empty, error) {
_, err := m.client.Configure(
context.Background(), config,
)
_, err := m.client.Configure(ctx, config)
return &protobufs.Empty{}, err
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/protobufs/plugin_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ type NotifierPlugin struct {
type GRPCClient struct{ client NotifierClient }

func (m *GRPCClient) Notify(ctx context.Context, notification *Notification) (*Empty, error) {
_, err := m.client.Notify(context.Background(), notification)
_, err := m.client.Notify(ctx, notification)
return &Empty{}, err
}

func (m *GRPCClient) Configure(ctx context.Context, config *Config) (*Empty, error) {
_, err := m.client.Configure(context.Background(), config)
_, err := m.client.Configure(ctx, config)
return &Empty{}, err
}

Expand Down

0 comments on commit 27451a5

Please sign in to comment.