Skip to content

Commit

Permalink
rest of decisions
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Sep 19, 2024
1 parent 6036018 commit d79e5e5
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 46 deletions.
7 changes: 5 additions & 2 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ func (a *apic) CAPIPullIsOld() (bool, error) {
}

func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) {
ctx := context.TODO()
nbDeleted := 0

for _, decision := range deletedDecisions {
Expand All @@ -438,7 +439,7 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
filter["scopes"] = []string{*decision.Scope}
}

dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter)
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter)
if err != nil {
return 0, fmt.Errorf("expiring decisions error: %w", err)
}
Expand All @@ -458,6 +459,8 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) {
var nbDeleted int

ctx := context.TODO()

for _, decisions := range deletedDecisions {
scope := decisions.Scope

Expand All @@ -470,7 +473,7 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi
filter["scopes"] = []string{*scope}
}

dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(filter)
dbCliRet, _, err := a.dbClient.ExpireDecisionsWithFilter(ctx, filter)
if err != nil {
return 0, fmt.Errorf("expiring decisions error: %w", err)
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/apiserver/controllers/v1/decisions.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
return
}

nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(decisionID)
ctx := gctx.Request.Context()

nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionByID(ctx, decisionID)
if err != nil {
c.HandleDBErrors(gctx, err)

Expand All @@ -116,7 +118,9 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
}

func (c *Controller) DeleteDecisions(gctx *gin.Context) {
nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(gctx.Request.URL.Query())
ctx := gctx.Request.Context()

nbDeleted, deletedFromDB, err := c.DBClient.ExpireDecisionsWithFilter(ctx, gctx.Request.URL.Query())
if err != nil {
c.HandleDBErrors(gctx, err)

Expand Down
8 changes: 6 additions & 2 deletions pkg/apiserver/papi_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ type listUnsubscribe struct {
}

func DecisionCmd(message *Message, p *Papi, sync bool) error {
ctx := context.TODO()

switch message.Header.OperationCmd {
case "delete":
data, err := json.Marshal(message.Data)
Expand All @@ -65,7 +67,7 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error {
filter := make(map[string][]string)
filter["uuid"] = UUIDs

_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter)
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter)
if err != nil {
return fmt.Errorf("unable to expire decisions %+v: %w", UUIDs, err)
}
Expand Down Expand Up @@ -170,6 +172,8 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
}

func ManagementCmd(message *Message, p *Papi, sync bool) error {
ctx := context.TODO()

if sync {
p.Logger.Infof("Ignoring management command from PAPI in sync mode")
return nil
Expand Down Expand Up @@ -197,7 +201,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
filter["origin"] = []string{types.ListOrigin}
filter["scenario"] = []string{unsubscribeMsg.Name}

_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(filter)
_, deletedDecisions, err := p.DBClient.ExpireDecisionsWithFilter(ctx, filter)
if err != nil {
return fmt.Errorf("unable to expire decisions for list %s : %w", unsubscribeMsg.Name, err)
}
Expand Down
59 changes: 23 additions & 36 deletions pkg/database/decisions.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,20 +317,7 @@ func (c *Client) QueryNewDecisionsSinceWithFilters(ctx context.Context, since *t
return data, nil
}

func (c *Client) DeleteDecisionById(decisionID int) ([]*ent.Decision, error) {
toDelete, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX)
if err != nil {
c.Log.Warningf("DeleteDecisionById : %s", err)
return nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID)
}

count, err := c.DeleteDecisions(toDelete)
c.Log.Debugf("deleted %d decisions", count)

return toDelete, err
}

func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) {
func (c *Client) DeleteDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
Expand Down Expand Up @@ -433,13 +420,13 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string,
return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
}

toDelete, err := decisions.All(c.CTX)
toDelete, err := decisions.All(ctx)
if err != nil {
c.Log.Warningf("DeleteDecisionsWithFilter : %s", err)
return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter")
}

count, err := c.DeleteDecisions(toDelete)
count, err := c.DeleteDecisions(ctx, toDelete)
if err != nil {
c.Log.Warningf("While deleting decisions : %s", err)
return "0", nil, errors.Wrap(DeleteFail, "decisions with provided filter")
Expand All @@ -449,7 +436,7 @@ func (c *Client) DeleteDecisionsWithFilter(filter map[string][]string) (string,
}

// ExpireDecisionsWithFilter updates the expiration time to now() for the decisions matching the filter, and returns the updated items
func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string, []*ent.Decision, error) {
func (c *Client) ExpireDecisionsWithFilter(ctx context.Context, filter map[string][]string) (string, []*ent.Decision, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
Expand Down Expand Up @@ -558,13 +545,13 @@ func (c *Client) ExpireDecisionsWithFilter(filter map[string][]string) (string,
return "0", nil, errors.Wrapf(InvalidFilter, "Unknown ip size %d", ip_sz)
}

DecisionsToDelete, err := decisions.All(c.CTX)
DecisionsToDelete, err := decisions.All(ctx)
if err != nil {
c.Log.Warningf("ExpireDecisionsWithFilter : %s", err)
return "0", nil, errors.Wrap(DeleteFail, "expire decisions with provided filter")
}

count, err := c.ExpireDecisions(DecisionsToDelete)
count, err := c.ExpireDecisions(ctx, DecisionsToDelete)
if err != nil {
return "0", nil, errors.Wrapf(DeleteFail, "expire decisions with provided filter : %s", err)
}
Expand All @@ -583,13 +570,13 @@ func decisionIDs(decisions []*ent.Decision) []int {

// ExpireDecisions sets the expiration of a list of decisions to now()
// It returns the number of impacted decisions for the CAPI/PAPI
func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {
func (c *Client) ExpireDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) {
if len(decisions) <= decisionDeleteBulkSize {
ids := decisionIDs(decisions)

rows, err := c.Ent.Decision.Update().Where(
decision.IDIn(ids...),
).SetUntil(time.Now().UTC()).Save(c.CTX)
).SetUntil(time.Now().UTC()).Save(ctx)
if err != nil {
return 0, fmt.Errorf("expire decisions with provided filter: %w", err)
}
Expand All @@ -602,7 +589,7 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {
total := 0

for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) {
rows, err := c.ExpireDecisions(chunk)
rows, err := c.ExpireDecisions(ctx, chunk)
if err != nil {
return total, err
}
Expand All @@ -615,13 +602,13 @@ func (c *Client) ExpireDecisions(decisions []*ent.Decision) (int, error) {

// DeleteDecisions removes a list of decisions from the database
// It returns the number of impacted decisions for the CAPI/PAPI
func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
func (c *Client) DeleteDecisions(ctx context.Context, decisions []*ent.Decision) (int, error) {
if len(decisions) < decisionDeleteBulkSize {
ids := decisionIDs(decisions)

rows, err := c.Ent.Decision.Delete().Where(
decision.IDIn(ids...),
).Exec(c.CTX)
).Exec(ctx)
if err != nil {
return 0, fmt.Errorf("hard delete decisions with provided filter: %w", err)
}
Expand All @@ -634,7 +621,7 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
tot := 0

for _, chunk := range slicetools.Chunks(decisions, decisionDeleteBulkSize) {
rows, err := c.DeleteDecisions(chunk)
rows, err := c.DeleteDecisions(ctx, chunk)
if err != nil {
return tot, err
}
Expand All @@ -646,8 +633,8 @@ func (c *Client) DeleteDecisions(decisions []*ent.Decision) (int, error) {
}

// ExpireDecision set the expiration of a decision to now()
func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error) {
toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX)
func (c *Client) ExpireDecisionByID(ctx context.Context, decisionID int) (int, []*ent.Decision, error) {
toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(ctx)

// XXX: do we want 500 or 404 here?
if err != nil || len(toUpdate) == 0 {
Expand All @@ -659,12 +646,12 @@ func (c *Client) ExpireDecisionByID(decisionID int) (int, []*ent.Decision, error
return 0, nil, ItemNotFound
}

count, err := c.ExpireDecisions(toUpdate)
count, err := c.ExpireDecisions(ctx, toUpdate)

return count, toUpdate, err
}

func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
func (c *Client) CountDecisionsByValue(ctx context.Context, decisionValue string) (int, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz, count int
Expand All @@ -682,15 +669,15 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
}

count, err = decisions.Count(c.CTX)
count, err = decisions.Count(ctx)
if err != nil {
return 0, errors.Wrapf(err, "fail to count decisions")
}

return count, nil
}

func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error) {
func (c *Client) CountActiveDecisionsByValue(ctx context.Context, decisionValue string) (int, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz, count int
Expand All @@ -710,15 +697,15 @@ func (c *Client) CountActiveDecisionsByValue(decisionValue string) (int, error)

decisions = decisions.Where(decision.UntilGT(time.Now().UTC()))

count, err = decisions.Count(c.CTX)
count, err = decisions.Count(ctx)
if err != nil {
return 0, fmt.Errorf("fail to count decisions: %w", err)
}

return count, nil
}

func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.Duration, error) {
func (c *Client) GetActiveDecisionsTimeLeftByValue(ctx context.Context, decisionValue string) (time.Duration, error) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
Expand All @@ -740,7 +727,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D

decisions = decisions.Order(ent.Desc(decision.FieldUntil))

decision, err := decisions.First(c.CTX)
decision, err := decisions.First(ctx)
if err != nil && !ent.IsNotFound(err) {
return 0, fmt.Errorf("fail to get decision: %w", err)
}
Expand All @@ -752,7 +739,7 @@ func (c *Client) GetActiveDecisionsTimeLeftByValue(decisionValue string) (time.D
return decision.Until.Sub(time.Now().UTC()), nil
}

func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) {
func (c *Client) CountDecisionsSinceByValue(ctx context.Context, decisionValue string, since time.Time) (int, error) {
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue)
if err != nil {
return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err)
Expand All @@ -768,7 +755,7 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
}

count, err := decisions.Count(c.CTX)
count, err := decisions.Count(ctx)
if err != nil {
return 0, errors.Wrapf(err, "fail to count decisions")
}
Expand Down
17 changes: 13 additions & 4 deletions pkg/exprhelpers/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package exprhelpers

import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
Expand Down Expand Up @@ -592,7 +593,10 @@ func GetDecisionsCount(params ...any) (any, error) {
return 0, nil

}
count, err := dbClient.CountDecisionsByValue(value)

ctx := context.TODO()

count, err := dbClient.CountDecisionsByValue(ctx, value)
if err != nil {
log.Errorf("Failed to get decisions count from value '%s'", value)
return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
Expand All @@ -613,8 +617,11 @@ func GetDecisionsSinceCount(params ...any) (any, error) {
log.Errorf("Failed to parse since parameter '%s' : %s", since, err)
return 0, nil
}

ctx := context.TODO()
sinceTime := time.Now().UTC().Add(-sinceDuration)
count, err := dbClient.CountDecisionsSinceByValue(value, sinceTime)

count, err := dbClient.CountDecisionsSinceByValue(ctx, value, sinceTime)
if err != nil {
log.Errorf("Failed to get decisions count from value '%s'", value)
return 0, nil //nolint:nilerr // This helper did not return an error before the move to expr.Function, we keep this behavior for backward compatibility
Expand All @@ -628,7 +635,8 @@ func GetActiveDecisionsCount(params ...any) (any, error) {
log.Error("No database config to call GetActiveDecisionsCount()")
return 0, nil
}
count, err := dbClient.CountActiveDecisionsByValue(value)
ctx := context.TODO()
count, err := dbClient.CountActiveDecisionsByValue(ctx, value)
if err != nil {
log.Errorf("Failed to get active decisions count from value '%s'", value)
return 0, err
Expand All @@ -642,7 +650,8 @@ func GetActiveDecisionsTimeLeft(params ...any) (any, error) {
log.Error("No database config to call GetActiveDecisionsTimeLeft()")
return 0, nil
}
timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(value)
ctx := context.TODO()
timeLeft, err := dbClient.GetActiveDecisionsTimeLeftByValue(ctx, value)
if err != nil {
log.Errorf("Failed to get active decisions time left from value '%s'", value)
return 0, err
Expand Down

0 comments on commit d79e5e5

Please sign in to comment.