From eeb28014c6860a0f50e87ef1488fb641d09edbb9 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Thu, 19 Sep 2024 14:09:35 +0200 Subject: [PATCH] context propagation: pkg/database/config (#3246) --- cmd/crowdsec-cli/clipapi/papi.go | 2 +- pkg/apiserver/apic.go | 28 ++++++++++++++-------------- pkg/apiserver/apic_test.go | 24 +++++++++++++++--------- pkg/apiserver/apiserver.go | 18 ++++++++++-------- pkg/apiserver/papi.go | 8 ++++---- pkg/apiserver/papi_cmd.go | 7 +++++-- pkg/database/config.go | 17 ++++++++++------- 7 files changed, 59 insertions(+), 45 deletions(-) diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index c0f08157f31..b8101a0fb34 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -74,7 +74,7 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie return fmt.Errorf("unable to get PAPI permissions: %w", err) } - lastTimestampStr, err := db.GetConfigItem(apiserver.PapiPullKey) + lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey) if err != nil { lastTimestampStr = ptr.Of("never") } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 3ed2e12ea54..b5384c6cc5c 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -614,7 +614,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio // we receive a list of decisions and links for blocklist and we need to create a list of alerts : // one alert for "community blocklist" // one alert per list we're subscribed to -func (a *apic) PullTop(forcePull bool) error { +func (a *apic) PullTop(ctx context.Context, forcePull bool) error { var err error // A mutex with TryLock would be a bit simpler @@ -655,7 +655,7 @@ func (a *apic) PullTop(forcePull bool) error { log.Infof("Starting community-blocklist update") - data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup}) + data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup}) if err != nil { return fmt.Errorf("get stream: %w", err) } @@ -700,7 +700,7 @@ func (a *apic) PullTop(forcePull bool) error { } // update blocklists - if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil { + if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil { return fmt.Errorf("while updating blocklists: %w", err) } @@ -708,9 +708,9 @@ func (a *apic) PullTop(forcePull bool) error { } // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert -func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error { +func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error { addCounters, _ := makeAddAndDeleteCounters() - if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{ + if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{blocklist}, }, addCounters, forcePull); err != nil { return fmt.Errorf("while pulling blocklist: %w", err) @@ -820,7 +820,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo return false, nil } -func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error { if blocklist.Scope == nil { log.Warningf("blocklist has no scope") return nil @@ -848,13 +848,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap ) if !forcePull { - lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName) if err != nil { return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } } - decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) + decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp) if err != nil { return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) } @@ -869,7 +869,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } - err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) + err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) if err != nil { return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) } @@ -892,7 +892,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } -func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { +func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error { if links == nil { return nil } @@ -908,7 +908,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } for _, blocklist := range links.Blocklists { - if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil { + if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil { return err } } @@ -931,7 +931,7 @@ func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int } } -func (a *apic) Pull() error { +func (a *apic) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/pullFromAPIC") toldOnce := false @@ -955,7 +955,7 @@ func (a *apic) Pull() error { time.Sleep(1 * time.Second) } - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) } @@ -967,7 +967,7 @@ func (a *apic) Pull() error { case <-ticker.C: ticker.Reset(a.pullInterval) - if err := a.PullTop(false); err != nil { + if err := a.PullTop(ctx, false); err != nil { log.Errorf("capi pull top: %s", err) continue } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 182bf18532f..97943b495e5 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -550,6 +550,7 @@ func TestFillAlertsWithDecisions(t *testing.T) { } func TestAPICWhitelists(t *testing.T) { + ctx := context.Background() api := getAPIC(t) // one whitelist on IP, one on CIDR api.whitelists = &csconfig.CapiWhitelist{} @@ -685,7 +686,7 @@ func TestAPICWhitelists(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing @@ -736,6 +737,7 @@ func TestAPICWhitelists(t *testing.T) { } func TestAPICPullTop(t *testing.T) { + ctx := context.Background() api := getAPIC(t) api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). @@ -826,7 +828,7 @@ func TestAPICPullTop(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) @@ -860,6 +862,7 @@ func TestAPICPullTop(t *testing.T) { } func TestAPICPullTopBLCacheFirstCall(t *testing.T) { + ctx := context.Background() // no decision in db, no last modified parameter. api := getAPIC(t) @@ -913,11 +916,11 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) blocklistConfigItemName := "blocklist:blocklist1:last_pull" - lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.NotEqual(t, "", *lastPullTimestamp) @@ -927,14 +930,15 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { return httpmock.NewStringResponse(304, ""), nil }) - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) - secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) } func TestAPICPullTopBLCacheForceCall(t *testing.T) { + ctx := context.Background() api := getAPIC(t) httpmock.Activate() @@ -1005,11 +1009,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop(false) + err = api.PullTop(ctx, false) require.NoError(t, err) } func TestAPICPullBlocklistCall(t *testing.T) { + ctx := context.Background() api := getAPIC(t) httpmock.Activate() @@ -1032,7 +1037,7 @@ func TestAPICPullBlocklistCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullBlocklist(&modelscapi.BlocklistLink{ + err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{ URL: ptr.Of("http://api.crowdsec.net/blocklist1"), Name: ptr.Of("blocklist1"), Scope: ptr.Of("Ip"), @@ -1134,6 +1139,7 @@ func TestAPICPush(t *testing.T) { } func TestAPICPull(t *testing.T) { + ctx := context.Background() api := getAPIC(t) tests := []struct { name string @@ -1204,7 +1210,7 @@ func TestAPICPull(t *testing.T) { go func() { logrus.SetOutput(&buf) - if err := api.Pull(); err != nil { + if err := api.Pull(ctx); err != nil { panic(err) } }() diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 95d18ccb028..6b5d6803be9 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -310,8 +310,8 @@ func (s *APIServer) apicPush() error { return nil } -func (s *APIServer) apicPull() error { - if err := s.apic.Pull(); err != nil { +func (s *APIServer) apicPull(ctx context.Context) error { + if err := s.apic.Pull(ctx); err != nil { log.Errorf("capi pull: %s", err) return err } @@ -319,8 +319,8 @@ func (s *APIServer) apicPull() error { return nil } -func (s *APIServer) papiPull() error { - if err := s.papi.Pull(); err != nil { +func (s *APIServer) papiPull(ctx context.Context) error { + if err := s.papi.Pull(ctx); err != nil { log.Errorf("papi pull: %s", err) return err } @@ -337,16 +337,16 @@ func (s *APIServer) papiSync() error { return nil } -func (s *APIServer) initAPIC() { +func (s *APIServer) initAPIC(ctx context.Context) { s.apic.pushTomb.Go(s.apicPush) - s.apic.pullTomb.Go(s.apicPull) + s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) }) // csConfig.API.Server.ConsoleConfig.ShareCustomScenarios if s.apic.apiClient.IsEnrolled() { if s.consoleConfig.IsPAPIEnabled() { if s.papi.URL != "" { log.Info("Starting PAPI decision receiver") - s.papi.pullTomb.Go(s.papiPull) + s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) }) s.papi.syncTomb.Go(s.papiSync) } else { log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.") @@ -381,8 +381,10 @@ func (s *APIServer) Run(apiReady chan bool) error { TLSConfig: tlsCfg, } + ctx := context.TODO() + if s.apic != nil { - s.initAPIC() + s.initAPIC(ctx) } s.httpServerTomb.Go(func() error { diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 89ad93930a1..7dd6b346aa9 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -230,13 +230,13 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error { } // PullPAPI is the long polling client for real-time decisions from PAPI -func (p *Papi) Pull() error { +func (p *Papi) Pull(ctx context.Context) error { defer trace.CatchPanic("lapi/PullPAPI") p.Logger.Infof("Starting Polling API Pull") lastTimestamp := time.Time{} - lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey) + lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey) if err != nil { p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err) } @@ -248,7 +248,7 @@ func (p *Papi) Pull() error { return fmt.Errorf("failed to serialize last timestamp: %w", err) } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { p.Logger.Errorf("error setting papi pull last key: %s", err) } else { p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime)) @@ -277,7 +277,7 @@ func (p *Papi) Pull() error { continue } - if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil { + if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil { return fmt.Errorf("failed to update last timestamp: %w", err) } diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index a1137161698..943eb4139de 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -1,6 +1,7 @@ package apiserver import ( + "context" "encoding/json" "fmt" "time" @@ -215,17 +216,19 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err) } + ctx := context.TODO() + if forcePullMsg.Blocklist == nil { p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err = p.apic.PullTop(true) + err = p.apic.PullTop(ctx, true) if err != nil { return fmt.Errorf("failed to force pull operation: %w", err) } } else { p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) - err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{ + err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{ Name: &forcePullMsg.Blocklist.Name, URL: &forcePullMsg.Blocklist.Url, Remediation: &forcePullMsg.Blocklist.Remediation, diff --git a/pkg/database/config.go b/pkg/database/config.go index 8c3578ad596..89ccb1e1b28 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -1,17 +1,20 @@ package database import ( + "context" + "github.com/pkg/errors" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) -func (c *Client) GetConfigItem(key string) (*string, error) { - result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX) +func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) { + result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx) if err != nil && ent.IsNotFound(err) { return nil, nil } + if err != nil { return nil, errors.Wrapf(QueryFail, "select config item: %s", err) } @@ -19,16 +22,16 @@ func (c *Client) GetConfigItem(key string) (*string, error) { return &result.Value, nil } -func (c *Client) SetConfigItem(key string, value string) error { - - nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX) - if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create - err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX) +func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error { + nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx) + if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { // not found, create + err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx) if err != nil { return errors.Wrapf(QueryFail, "insert config item: %s", err) } } else if err != nil { return errors.Wrapf(QueryFail, "update config item: %s", err) } + return nil }