From 0dde8e3f55797f82905fca847f31ba2ae3f48124 Mon Sep 17 00:00:00 2001 From: marco Date: Wed, 18 Sep 2024 14:33:06 +0200 Subject: [PATCH] drop CTX from dbclient --- pkg/apiserver/apic.go | 6 +++--- pkg/apiserver/apic_test.go | 10 ++++++---- pkg/apiserver/middlewares/v1/jwt.go | 4 +++- pkg/database/database.go | 2 -- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index ab1cdfea917..9b56fef6549 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -406,13 +406,13 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { } } -func (a *apic) CAPIPullIsOld() (bool, error) { +func (a *apic) CAPIPullIsOld(ctx context.Context) (bool, error) { /*only pull community blocklist if it's older than 1h30 */ alerts := a.dbClient.Ent.Alert.Query() alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert - count, err := alerts.Count(a.dbClient.CTX) + count, err := alerts.Count(ctx) if err != nil { return false, fmt.Errorf("while looking for CAPI alert: %w", err) } @@ -634,7 +634,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error { } if !forcePull { - if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil { + if lastPullIsOld, err := a.CAPIPullIsOld(ctx); err != nil { return err } else if !lastPullIsOld { return nil diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 97943b495e5..3bb158acf35 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -113,7 +113,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) { func TestAPICCAPIPullIsOld(t *testing.T) { api := getAPIC(t) - isOld, err := api.CAPIPullIsOld() + ctx := context.Background() + + isOld, err := api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.True(t, isOld) @@ -124,7 +126,7 @@ func TestAPICCAPIPullIsOld(t *testing.T) { SetScope("Country"). SetValue("Blah"). SetOrigin(types.CAPIOrigin). - SaveX(context.Background()) + SaveX(ctx) api.dbClient.Ent.Alert.Create(). SetCreatedAt(time.Now()). @@ -132,9 +134,9 @@ func TestAPICCAPIPullIsOld(t *testing.T) { AddDecisions( decision, ). - SaveX(context.Background()) + SaveX(ctx) - isOld, err = api.CAPIPullIsOld() + isOld, err = api.CAPIPullIsOld(ctx) require.NoError(t, err) assert.False(t, isOld) diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index b7cccd1aa39..5877c945106 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -129,6 +129,8 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { err error ) + ctx := c.Request.Context() + ret := authInput{} if err = c.ShouldBindJSON(&loginInput); err != nil { @@ -145,7 +147,7 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). - First(j.DbClient.CTX) + First(ctx) if err != nil { log.Infof("Error machine login for %s : %+v ", ret.machineID, err) return nil, err diff --git a/pkg/database/database.go b/pkg/database/database.go index e513459199f..bb41dd3b645 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -21,7 +21,6 @@ import ( type Client struct { Ent *ent.Client - CTX context.Context Log *log.Logger CanFlush bool Type string @@ -106,7 +105,6 @@ func NewClient(ctx context.Context, config *csconfig.DatabaseCfg) (*Client, erro return &Client{ Ent: client, - CTX: ctx, Log: clog, CanFlush: true, Type: config.Type,