diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index e79224c417e..ab1cdfea917 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -426,6 +426,7 @@ func (a *apic) CAPIPullIsOld() (bool, error) { } func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) { + ctx := context.TODO() nbDeleted := 0 for _, decision := range deletedDecisions { @@ -438,7 +439,7 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet filter["scopes"] = []string{*decision.Scope} } - dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return 0, fmt.Errorf("expiring decisions error: %w", err) } @@ -458,6 +459,8 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) { var nbDeleted int + ctx := context.TODO() + for _, decisions := range deletedDecisions { scope := decisions.Scope @@ -470,7 +473,7 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi filter["scopes"] = []string{*scope} } - dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter) + dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return 0, fmt.Errorf("expiring decisions error: %w", err) } diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 979fa330c4b..09c40450642 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -94,7 +94,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { return } - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(decisionID) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID) if err != nil { c.HandleDBErrors(gctx, err) @@ -116,7 +118,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { } func (c *Controller) DeleteDecisions(gctx *gin.Context) { - nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(gctx.Request.URL.Query()) + ctx := gctx.Request.Context() + + nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 04fd55b55c2..78f5dc9b0fe 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -43,6 +43,8 @@ type listUnsubscribe struct { } func DecisionCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + switch message.Header.OperationCmd { case "delete": data, err := json.Marshal(message.Data) @@ -65,7 +67,7 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error { filter := make(map[string][]string) filter["uuid"] = UUIDs - _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err) } @@ -170,6 +172,8 @@ func AlertCmd(message *Message, p *Papi, sync bool) error { } func ManagementCmd(message *Message, p *Papi, sync bool) error { + ctx := context.TODO() + if sync { p.Logger.Infof("Ignoring management command from PAPI in sync mode") return nil @@ -197,7 +201,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { filter["origin"] = []string{types.ListOrigin} filter["scenario"] = []string{unsubscribeMsg.Name} - _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter) + _, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter) if err != nil { return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err) } diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index ebf6cffdec2..8547990c25f 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -317,20 +317,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *t return data, nil } -func (c *Client) DeleteDecisionById(decisionID int) ([]*ent.Decision, error) { - toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) - if err != nil { - c.Log.Warningf("DeleteDecisionById : %s", err) - return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID) - } - - count, err := c.DeleteDecisions(toDelete) - c.Log.Debugf("deleted %d decisions", count) - - return toDelete, err -} - -func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -433,13 +420,13 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - toDelete, err := decisions.All(c.CTX) + toDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("DeleteDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") } - count, err := c.DeleteDecisions(toDelete) + count, err := c.DeleteDecisions(ctx, toDelete) if err != nil { c.Log.Warningf("While deleting decisions : %s", err) return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter") @@ -449,7 +436,7 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, } // ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items -func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) { +func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -558,13 +545,13 @@ func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz) } - DecisionsToDelete, err := decisions.All(c.CTX) + DecisionsToDelete, err := decisions.All(ctx) if err != nil { c.Log.Warningf("ExpireDecisionsWithFilter : %s", err) return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter") } - count, err := c.ExpireDecisions(DecisionsToDelete) + count, err := c.ExpireDecisions(ctx, DecisionsToDelete) if err != nil { return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err) } @@ -583,13 +570,13 @@ func decisionIDs(decisions []*ent.Decision) []int { // ExpireDecisions sets the expiration of a list of decisions to now() // It returns the number of impacted decisions for the CAPI/PAPI -func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { +func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { if len(decisions) <= decisionDeleteBulkSize { ids := decisionIDs(decisions) rows, err := c.Ent.Decision.Update().Where( decision.IDIn(ids...), - ).SetUntil(time.Now().UTC()).Save(c.CTX) + ).SetUntil(time.Now().UTC()).Save(ctx) if err != nil { return 0, fmt.Errorf("expire decisions with provided filter: %w", err) } @@ -602,7 +589,7 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { total := 0 for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { - rows, err := c.ExpireDecisions(chunk) + rows, err := c.ExpireDecisions(ctx, chunk) if err != nil { return total, err } @@ -615,13 +602,13 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) { // DeleteDecisions removes a list of decisions from the database // It returns the number of impacted decisions for the CAPI/PAPI -func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { +func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) { if len(decisions) < decisionDeleteBulkSize { ids := decisionIDs(decisions) rows, err := c.Ent.Decision.Delete().Where( decision.IDIn(ids...), - ).Exec(c.CTX) + ).Exec(ctx) if err != nil { return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err) } @@ -634,7 +621,7 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { tot := 0 for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) { - rows, err := c.DeleteDecisions(chunk) + rows, err := c.DeleteDecisions(ctx, chunk) if err != nil { return tot, err } @@ -646,8 +633,8 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) { } // ExpireDecision set the expiration of a decision to now() -func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error) { - toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) +func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) { + toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx) // XXX: do we want 500 or 404 here? if err != nil || len(toUpdate) == 0 { @@ -659,12 +646,12 @@ func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error return 0, nil, ItemNotFound } - count, err := c.ExpireDecisions(toUpdate) + count, err := c.ExpireDecisions(ctx, toUpdate) return count, toUpdate, err } -func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int @@ -682,7 +669,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -690,7 +677,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { return count, nil } -func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) { +func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz, count int @@ -710,7 +697,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) decisions = decisions.Where(decision.UntilGT(time.Now().UTC())) - count, err = decisions.Count(c.CTX) + count, err = decisions.Count(ctx) if err != nil { return 0, fmt.Errorf("fail to count decisions: %w", err) } @@ -718,7 +705,7 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) return count, nil } -func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.Duration, error) { +func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) { var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -740,7 +727,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D decisions = decisions.Order(ent.Desc(decision.FieldUntil)) - decision, err := decisions.First(c.CTX) + decision, err := decisions.First(ctx) if err != nil && !ent.IsNotFound(err) { return 0, fmt.Errorf("fail to get decision: %w", err) } @@ -752,7 +739,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D return decision.Until.Sub(time.Now().UTC()), nil } -func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { +func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) { ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) @@ -768,7 +755,7 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err := decisions.Count(c.CTX) + count, err := decisions.Count(ctx) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } diff --git a/pkg/exprhelpers/helpers.go b/pkg/exprhelpers/helpers.go index 2ca7d0be79a..6b7eb0840e9 100644 --- a/pkg/exprhelpers/helpers.go +++ b/pkg/exprhelpers/helpers.go @@ -2,6 +2,7 @@ package exprhelpers import ( "bufio" + "context" "encoding/base64" "errors" "fmt" @@ -592,7 +593,10 @@ func GetDecisionsCount(params ...any) (any, error) { return 0, nil } - count, err := dbClient.CountDecisionsByValue(value) + + ctx := context.TODO() + + count, err := dbClient.CountDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -613,8 +617,11 @@ func GetDecisionsSinceCount(params ...any) (any, error) { log.Errorf("Failed to parse since parameter '%s' : %s", since, err) return 0, nil } + + ctx := context.TODO() sinceTime := time.Now().UTC().Add(-sinceDuration) - count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime) + + count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime) if err != nil { log.Errorf("Failed to get decisions count from value '%s'", value) return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility @@ -628,7 +635,8 @@ func GetActiveDecisionsCount(params ...any) (any, error) { log.Error("No database config to call GetActiveDecisionsCount()") return 0, nil } - count, err := dbClient.CountActiveDecisionsByValue(value) + ctx := context.TODO() + count, err := dbClient.CountActiveDecisionsByValue(ctx, value) if err != nil { log.Errorf("Failed to get active decisions count from value '%s'", value) return 0, err @@ -642,7 +650,8 @@ func GetActiveDecisionsTimeLeft(params ...any) (any, error) { log.Error("No database config to call GetActiveDecisionsTimeLeft()") return 0, nil } - timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(value) + ctx := context.TODO() + timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value) if err != nil { log.Errorf("Failed to get active decisions time left from value '%s'", value) return 0, err