Skip to content

Commit

Permalink
context propagation: pkg/database/config (#3246)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc authored Sep 19, 2024
1 parent b4a2403 commit eeb2801
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 45 deletions.
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/clipapi/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie
return fmt.Errorf("unable to get PAPI permissions: %w", err)
}

lastTimestampStr, err := db.GetConfigItem(apiserver.PapiPullKey)
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)
if err != nil {
lastTimestampStr = ptr.Of("never")
}
Expand Down
28 changes: 14 additions & 14 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
// we receive a list of decisions and links for blocklist and we need to create a list of alerts :
// one alert for "community blocklist"
// one alert per list we're subscribed to
func (a *apic) PullTop(forcePull bool) error {
func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
var err error

// A mutex with TryLock would be a bit simpler
Expand Down Expand Up @@ -655,7 +655,7 @@ func (a *apic) PullTop(forcePull bool) error {

log.Infof("Starting community-blocklist update")

data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup})
data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup})
if err != nil {
return fmt.Errorf("get stream: %w", err)
}
Expand Down Expand Up @@ -700,17 +700,17 @@ func (a *apic) PullTop(forcePull bool) error {
}

// update blocklists
if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil {
if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil {
return fmt.Errorf("while updating blocklists: %w", err)
}

return nil
}

// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error {
addCounters, _ := makeAddAndDeleteCounters()
if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{blocklist},
}, addCounters, forcePull); err != nil {
return fmt.Errorf("while pulling blocklist: %w", err)
Expand Down Expand Up @@ -820,7 +820,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
return false, nil
}

func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
if blocklist.Scope == nil {
log.Warningf("blocklist has no scope")
return nil
Expand Down Expand Up @@ -848,13 +848,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
)

if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
if err != nil {
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
}

decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp)
if err != nil {
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
}
Expand All @@ -869,7 +869,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
return nil
}

err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
if err != nil {
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
Expand All @@ -892,7 +892,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
return nil
}

func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
if links == nil {
return nil
}
Expand All @@ -908,7 +908,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
}

for _, blocklist := range links.Blocklists {
if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil {
if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil {
return err
}
}
Expand All @@ -931,7 +931,7 @@ func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int
}
}

func (a *apic) Pull() error {
func (a *apic) Pull(ctx context.Context) error {
defer trace.CatchPanic("lapi/pullFromAPIC")

toldOnce := false
Expand All @@ -955,7 +955,7 @@ func (a *apic) Pull() error {
time.Sleep(1 * time.Second)
}

if err := a.PullTop(false); err != nil {
if err := a.PullTop(ctx, false); err != nil {
log.Errorf("capi pull top: %s", err)
}

Expand All @@ -967,7 +967,7 @@ func (a *apic) Pull() error {
case <-ticker.C:
ticker.Reset(a.pullInterval)

if err := a.PullTop(false); err != nil {
if err := a.PullTop(ctx, false); err != nil {
log.Errorf("capi pull top: %s", err)
continue
}
Expand Down
24 changes: 15 additions & 9 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
}

func TestAPICWhitelists(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
// one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{}
Expand Down Expand Up @@ -685,7 +686,7 @@ func TestAPICWhitelists(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
Expand Down Expand Up @@ -736,6 +737,7 @@ func TestAPICWhitelists(t *testing.T) {
}

func TestAPICPullTop(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
api.dbClient.Ent.Decision.Create().
SetOrigin(types.CAPIOrigin).
Expand Down Expand Up @@ -826,7 +828,7 @@ func TestAPICPullTop(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5)
Expand Down Expand Up @@ -860,6 +862,7 @@ func TestAPICPullTop(t *testing.T) {
}

func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
ctx := context.Background()
// no decision in db, no last modified parameter.
api := getAPIC(t)

Expand Down Expand Up @@ -913,11 +916,11 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

blocklistConfigItemName := "blocklist:blocklist1:last_pull"
lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
require.NoError(t, err)
assert.NotEqual(t, "", *lastPullTimestamp)

Expand All @@ -927,14 +930,15 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
return httpmock.NewStringResponse(304, ""), nil
})

err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
require.NoError(t, err)
assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp)
}

func TestAPICPullTopBLCacheForceCall(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)

httpmock.Activate()
Expand Down Expand Up @@ -1005,11 +1009,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)
}

func TestAPICPullBlocklistCall(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)

httpmock.Activate()
Expand All @@ -1032,7 +1037,7 @@ func TestAPICPullBlocklistCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullBlocklist(&modelscapi.BlocklistLink{
err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{
URL: ptr.Of("http://api.crowdsec.net/blocklist1"),
Name: ptr.Of("blocklist1"),
Scope: ptr.Of("Ip"),
Expand Down Expand Up @@ -1134,6 +1139,7 @@ func TestAPICPush(t *testing.T) {
}

func TestAPICPull(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
tests := []struct {
name string
Expand Down Expand Up @@ -1204,7 +1210,7 @@ func TestAPICPull(t *testing.T) {
go func() {
logrus.SetOutput(&buf)

if err := api.Pull(); err != nil {
if err := api.Pull(ctx); err != nil {
panic(err)
}
}()
Expand Down
18 changes: 10 additions & 8 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,17 @@ func (s *APIServer) apicPush() error {
return nil
}

func (s *APIServer) apicPull() error {
if err := s.apic.Pull(); err != nil {
func (s *APIServer) apicPull(ctx context.Context) error {
if err := s.apic.Pull(ctx); err != nil {
log.Errorf("capi pull: %s", err)
return err
}

return nil
}

func (s *APIServer) papiPull() error {
if err := s.papi.Pull(); err != nil {
func (s *APIServer) papiPull(ctx context.Context) error {
if err := s.papi.Pull(ctx); err != nil {
log.Errorf("papi pull: %s", err)
return err
}
Expand All @@ -337,16 +337,16 @@ func (s *APIServer) papiSync() error {
return nil
}

func (s *APIServer) initAPIC() {
func (s *APIServer) initAPIC(ctx context.Context) {
s.apic.pushTomb.Go(s.apicPush)
s.apic.pullTomb.Go(s.apicPull)
s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) })

// csConfig.API.Server.ConsoleConfig.ShareCustomScenarios
if s.apic.apiClient.IsEnrolled() {
if s.consoleConfig.IsPAPIEnabled() {
if s.papi.URL != "" {
log.Info("Starting PAPI decision receiver")
s.papi.pullTomb.Go(s.papiPull)
s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) })
s.papi.syncTomb.Go(s.papiSync)
} else {
log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.")
Expand Down Expand Up @@ -381,8 +381,10 @@ func (s *APIServer) Run(apiReady chan bool) error {
TLSConfig: tlsCfg,
}

ctx := context.TODO()

if s.apic != nil {
s.initAPIC()
s.initAPIC(ctx)
}

s.httpServerTomb.Go(func() error {
Expand Down
8 changes: 4 additions & 4 deletions pkg/apiserver/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error {
}

// PullPAPI is the long polling client for real-time decisions from PAPI
func (p *Papi) Pull() error {
func (p *Papi) Pull(ctx context.Context) error {
defer trace.CatchPanic("lapi/PullPAPI")
p.Logger.Infof("Starting Polling API Pull")

lastTimestamp := time.Time{}

lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey)
lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey)
if err != nil {
p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err)
}
Expand All @@ -248,7 +248,7 @@ func (p *Papi) Pull() error {
return fmt.Errorf("failed to serialize last timestamp: %w", err)
}

if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {
p.Logger.Errorf("error setting papi pull last key: %s", err)
} else {
p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime))
Expand Down Expand Up @@ -277,7 +277,7 @@ func (p *Papi) Pull() error {
continue
}

if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {
return fmt.Errorf("failed to update last timestamp: %w", err)
}

Expand Down
7 changes: 5 additions & 2 deletions pkg/apiserver/papi_cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apiserver

import (
"context"
"encoding/json"
"fmt"
"time"
Expand Down Expand Up @@ -215,17 +216,19 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
}

ctx := context.TODO()

if forcePullMsg.Blocklist == nil {
p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")

err = p.apic.PullTop(true)
err = p.apic.PullTop(ctx, true)
if err != nil {
return fmt.Errorf("failed to force pull operation: %w", err)
}
} else {
p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name)

err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{
err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{
Name: &forcePullMsg.Blocklist.Name,
URL: &forcePullMsg.Blocklist.Url,
Remediation: &forcePullMsg.Blocklist.Remediation,
Expand Down
Loading

0 comments on commit eeb2801

Please sign in to comment.