Skip to content

Commit

Permalink
context propagation: pkg/database/config
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 19, 2024
1 parent b4a2403 commit 1579556
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 44 deletions.
2 changes: 1 addition & 1 deletion cmd/crowdsec-cli/clipapi/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie
return fmt.Errorf("unable to get PAPI permissions: %w", err)
}

lastTimestampStr, err := db.GetConfigItem(apiserver.PapiPullKey)
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)

Check warning on line 77 in cmd/crowdsec-cli/clipapi/papi.go

View check run for this annotation

Codecov / codecov/patch

cmd/crowdsec-cli/clipapi/papi.go#L77

Added line #L77 was not covered by tests
if err != nil {
lastTimestampStr = ptr.Of("never")
}
Expand Down
28 changes: 14 additions & 14 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
// we receive a list of decisions and links for blocklist and we need to create a list of alerts :
// one alert for "community blocklist"
// one alert per list we're subscribed to
func (a *apic) PullTop(forcePull bool) error {
func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
var err error

// A mutex with TryLock would be a bit simpler
Expand Down Expand Up @@ -655,7 +655,7 @@ func (a *apic) PullTop(forcePull bool) error {

log.Infof("Starting community-blocklist update")

data, _, err := a.apiClient.Decisions.GetStreamV3(context.Background(), apiclient.DecisionsStreamOpts{Startup: a.startup})
data, _, err := a.apiClient.Decisions.GetStreamV3(ctx, apiclient.DecisionsStreamOpts{Startup: a.startup})
if err != nil {
return fmt.Errorf("get stream: %w", err)
}
Expand Down Expand Up @@ -700,17 +700,17 @@ func (a *apic) PullTop(forcePull bool) error {
}

// update blocklists
if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil {
if err := a.UpdateBlocklists(ctx, data.Links, addCounters, forcePull); err != nil {
return fmt.Errorf("while updating blocklists: %w", err)
}

return nil
}

// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
func (a *apic) PullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, forcePull bool) error {
addCounters, _ := makeAddAndDeleteCounters()
if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
if err := a.UpdateBlocklists(ctx, &modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{blocklist},
}, addCounters, forcePull); err != nil {
return fmt.Errorf("while pulling blocklist: %w", err)
Expand Down Expand Up @@ -820,7 +820,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
return false, nil
}

func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
if blocklist.Scope == nil {
log.Warningf("blocklist has no scope")
return nil
Expand Down Expand Up @@ -848,13 +848,13 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
)

if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
lastPullTimestamp, err = a.dbClient.GetConfigItem(ctx, blocklistConfigItemName)

Check warning on line 851 in pkg/apiserver/apic.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/apic.go#L851

Added line #L851 was not covered by tests
if err != nil {
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
}

decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(ctx, blocklist, lastPullTimestamp)
if err != nil {
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
}
Expand All @@ -869,7 +869,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
return nil
}

err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
err = a.dbClient.SetConfigItem(ctx, blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
if err != nil {
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
}
Expand All @@ -892,7 +892,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
return nil
}

func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
func (a *apic) UpdateBlocklists(ctx context.Context, links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
if links == nil {
return nil
}
Expand All @@ -908,7 +908,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
}

for _, blocklist := range links.Blocklists {
if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil {
if err := a.updateBlocklist(ctx, defaultClient, blocklist, addCounters, forcePull); err != nil {
return err
}
}
Expand All @@ -931,7 +931,7 @@ func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int
}
}

func (a *apic) Pull() error {
func (a *apic) Pull(ctx context.Context) error {
defer trace.CatchPanic("lapi/pullFromAPIC")

toldOnce := false
Expand All @@ -955,7 +955,7 @@ func (a *apic) Pull() error {
time.Sleep(1 * time.Second)
}

if err := a.PullTop(false); err != nil {
if err := a.PullTop(ctx, false); err != nil {
log.Errorf("capi pull top: %s", err)
}

Expand All @@ -967,7 +967,7 @@ func (a *apic) Pull() error {
case <-ticker.C:
ticker.Reset(a.pullInterval)

if err := a.PullTop(false); err != nil {
if err := a.PullTop(ctx, false); err != nil {
log.Errorf("capi pull top: %s", err)
continue
}
Expand Down
24 changes: 15 additions & 9 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
}

func TestAPICWhitelists(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
// one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{}
Expand Down Expand Up @@ -685,7 +686,7 @@ func TestAPICWhitelists(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
Expand Down Expand Up @@ -736,6 +737,7 @@ func TestAPICWhitelists(t *testing.T) {
}

func TestAPICPullTop(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
api.dbClient.Ent.Decision.Create().
SetOrigin(types.CAPIOrigin).
Expand Down Expand Up @@ -826,7 +828,7 @@ func TestAPICPullTop(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5)
Expand Down Expand Up @@ -860,6 +862,7 @@ func TestAPICPullTop(t *testing.T) {
}

func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
ctx := context.Background()
// no decision in db, no last modified parameter.
api := getAPIC(t)

Expand Down Expand Up @@ -913,11 +916,11 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)

blocklistConfigItemName := "blocklist:blocklist1:last_pull"
lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
require.NoError(t, err)
assert.NotEqual(t, "", *lastPullTimestamp)

Expand All @@ -927,14 +930,15 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
return httpmock.NewStringResponse(304, ""), nil
})

err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
require.NoError(t, err)
assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp)
}

func TestAPICPullTopBLCacheForceCall(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)

httpmock.Activate()
Expand Down Expand Up @@ -1005,11 +1009,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullTop(false)
err = api.PullTop(ctx, false)
require.NoError(t, err)
}

func TestAPICPullBlocklistCall(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)

httpmock.Activate()
Expand All @@ -1032,7 +1037,7 @@ func TestAPICPullBlocklistCall(t *testing.T) {
require.NoError(t, err)

api.apiClient = apic
err = api.PullBlocklist(&modelscapi.BlocklistLink{
err = api.PullBlocklist(ctx, &modelscapi.BlocklistLink{
URL: ptr.Of("http://api.crowdsec.net/blocklist1"),
Name: ptr.Of("blocklist1"),
Scope: ptr.Of("Ip"),
Expand Down Expand Up @@ -1134,6 +1139,7 @@ func TestAPICPush(t *testing.T) {
}

func TestAPICPull(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
tests := []struct {
name string
Expand Down Expand Up @@ -1204,7 +1210,7 @@ func TestAPICPull(t *testing.T) {
go func() {
logrus.SetOutput(&buf)

if err := api.Pull(); err != nil {
if err := api.Pull(ctx); err != nil {
panic(err)
}
}()
Expand Down
18 changes: 10 additions & 8 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,17 @@ func (s *APIServer) apicPush() error {
return nil
}

func (s *APIServer) apicPull() error {
if err := s.apic.Pull(); err != nil {
func (s *APIServer) apicPull(ctx context.Context) error {
if err := s.apic.Pull(ctx); err != nil {
log.Errorf("capi pull: %s", err)
return err
}

return nil
}

func (s *APIServer) papiPull() error {
if err := s.papi.Pull(); err != nil {
func (s *APIServer) papiPull(ctx context.Context) error {
if err := s.papi.Pull(ctx); err != nil {

Check warning on line 323 in pkg/apiserver/apiserver.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/apiserver.go#L322-L323

Added lines #L322 - L323 were not covered by tests
log.Errorf("papi pull: %s", err)
return err
}
Expand All @@ -337,16 +337,16 @@ func (s *APIServer) papiSync() error {
return nil
}

func (s *APIServer) initAPIC() {
func (s *APIServer) initAPIC(ctx context.Context) {
s.apic.pushTomb.Go(s.apicPush)
s.apic.pullTomb.Go(s.apicPull)
s.apic.pullTomb.Go(func() error { return s.apicPull(ctx) })

// csConfig.API.Server.ConsoleConfig.ShareCustomScenarios
if s.apic.apiClient.IsEnrolled() {
if s.consoleConfig.IsPAPIEnabled() {
if s.papi.URL != "" {
log.Info("Starting PAPI decision receiver")
s.papi.pullTomb.Go(s.papiPull)
s.papi.pullTomb.Go(func() error { return s.papiPull(ctx) })

Check warning on line 349 in pkg/apiserver/apiserver.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/apiserver.go#L349

Added line #L349 was not covered by tests
s.papi.syncTomb.Go(s.papiSync)
} else {
log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. Run cscli console enable console_management to add it.")
Expand Down Expand Up @@ -381,8 +381,10 @@ func (s *APIServer) Run(apiReady chan bool) error {
TLSConfig: tlsCfg,
}

ctx := context.TODO()

Check warning on line 385 in pkg/apiserver/apiserver.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/apiserver.go#L385

Added line #L385 was not covered by tests
if s.apic != nil {
s.initAPIC()
s.initAPIC(ctx)
}

s.httpServerTomb.Go(func() error {
Expand Down
8 changes: 4 additions & 4 deletions pkg/apiserver/papi.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error {
}

// PullPAPI is the long polling client for real-time decisions from PAPI
func (p *Papi) Pull() error {
func (p *Papi) Pull(ctx context.Context) error {

Check warning on line 233 in pkg/apiserver/papi.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi.go#L233

Added line #L233 was not covered by tests
defer trace.CatchPanic("lapi/PullPAPI")
p.Logger.Infof("Starting Polling API Pull")

lastTimestamp := time.Time{}

lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey)
lastTimestampStr, err := p.DBClient.GetConfigItem(ctx, PapiPullKey)

Check warning on line 239 in pkg/apiserver/papi.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi.go#L239

Added line #L239 was not covered by tests
if err != nil {
p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err)
}
Expand All @@ -248,7 +248,7 @@ func (p *Papi) Pull() error {
return fmt.Errorf("failed to serialize last timestamp: %w", err)
}

if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {

Check warning on line 251 in pkg/apiserver/papi.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi.go#L251

Added line #L251 was not covered by tests
p.Logger.Errorf("error setting papi pull last key: %s", err)
} else {
p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime))
Expand Down Expand Up @@ -277,7 +277,7 @@ func (p *Papi) Pull() error {
continue
}

if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
if err := p.DBClient.SetConfigItem(ctx, PapiPullKey, string(binTime)); err != nil {

Check warning on line 280 in pkg/apiserver/papi.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi.go#L280

Added line #L280 was not covered by tests
return fmt.Errorf("failed to update last timestamp: %w", err)
}

Expand Down
7 changes: 5 additions & 2 deletions pkg/apiserver/papi_cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apiserver

import (
"context"
"encoding/json"
"fmt"
"time"
Expand Down Expand Up @@ -215,17 +216,19 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
}

ctx := context.TODO()

Check warning on line 220 in pkg/apiserver/papi_cmd.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi_cmd.go#L219-L220

Added lines #L219 - L220 were not covered by tests
if forcePullMsg.Blocklist == nil {
p.Logger.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")

err = p.apic.PullTop(true)
err = p.apic.PullTop(ctx, true)

Check warning on line 224 in pkg/apiserver/papi_cmd.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi_cmd.go#L224

Added line #L224 was not covered by tests
if err != nil {
return fmt.Errorf("failed to force pull operation: %w", err)
}
} else {
p.Logger.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name)

err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{
err = p.apic.PullBlocklist(ctx, &modelscapi.BlocklistLink{

Check warning on line 231 in pkg/apiserver/papi_cmd.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/papi_cmd.go#L231

Added line #L231 was not covered by tests
Name: &forcePullMsg.Blocklist.Name,
URL: &forcePullMsg.Blocklist.Url,
Remediation: &forcePullMsg.Blocklist.Remediation,
Expand Down
12 changes: 6 additions & 6 deletions pkg/database/config.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package database

import (
"context"
"github.com/pkg/errors"

"github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
)

func (c *Client) GetConfigItem(key string) (*string, error) {
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(c.CTX)
func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) {
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx)
if err != nil && ent.IsNotFound(err) {
return nil, nil
}
Expand All @@ -19,11 +20,10 @@ func (c *Client) GetConfigItem(key string) (*string, error) {
return &result.Value, nil
}

func (c *Client) SetConfigItem(key string, value string) error {

nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(c.CTX)
func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error {
nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx)
if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { //not found, create
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(c.CTX)
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx)
if err != nil {
return errors.Wrapf(QueryFail, "insert config item: %s", err)
}
Expand Down

0 comments on commit 1579556

Please sign in to comment.