Skip to content

Commit

Permalink
drop CTX from dbclient
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 19, 2024
1 parent d79e5e5 commit 0dde8e3
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
6 changes: 3 additions & 3 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
}
}

func (a *apic) CAPIPullIsOld() (bool, error) {
func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) {
/*only pull community blocklist if it's older than 1h30 */
alerts := a.dbClient.Ent.Alert.Query()
alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert

count, err := alerts.Count(a.dbClient.CTX)
count, err := alerts.Count(ctx)
if err != nil {
return false, fmt.Errorf("while looking for CAPI alert: %w", err)
}
Expand Down Expand Up @@ -634,7 +634,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
}

if !forcePull {
if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil {
if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil {
return err
} else if !lastPullIsOld {
return nil
Expand Down
10 changes: 6 additions & 4 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) {
func TestAPICCAPIPullIsOld(t *testing.T) {
api := getAPIC(t)

isOld, err := api.CAPIPullIsOld()
ctx := context.Background()

isOld, err := api.CAPIPullIsOld(ctx)
require.NoError(t, err)
assert.True(t, isOld)

Expand All @@ -124,17 +126,17 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
SetScope("Country").
SetValue("Blah").
SetOrigin(types.CAPIOrigin).
SaveX(context.Background())
SaveX(ctx)

api.dbClient.Ent.Alert.Create().
SetCreatedAt(time.Now()).
SetScenario("crowdsec/test").
AddDecisions(
decision,
).
SaveX(context.Background())
SaveX(ctx)

isOld, err = api.CAPIPullIsOld()
isOld, err = api.CAPIPullIsOld(ctx)
require.NoError(t, err)

assert.False(t, isOld)
Expand Down
4 changes: 3 additions & 1 deletion pkg/apiserver/middlewares/v1/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
err error
)

ctx := c.Request.Context()

ret := authInput{}

if err = c.ShouldBindJSON(&loginInput); err != nil {
Expand All @@ -145,7 +147,7 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {

ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
Where(machine.MachineId(ret.machineID)).
First(j.DbClient.CTX)
First(ctx)
if err != nil {
log.Infof("Error machine login for %s : %+v ", ret.machineID, err)
return nil, err
Expand Down
2 changes: 0 additions & 2 deletions pkg/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

type Client struct {
Ent *ent.Client
CTX context.Context
Log *log.Logger
CanFlush bool
Type string
Expand Down Expand Up @@ -106,7 +105,6 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro

return &Client{
Ent: client,
CTX: ctx,
Log: clog,
CanFlush: true,
Type: config.Type,
Expand Down

0 comments on commit 0dde8e3

Please sign in to comment.