Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for pruning macro call nodes with limited optional support #705

Merged
merged 4 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 137 additions & 90 deletions interpreter/prune.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package interpreter

import (
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
Expand Down Expand Up @@ -94,35 +95,50 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
}

func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) {
switch val.Type() {
case types.BoolType:
switch v := val.(type) {
case types.Bool:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true
case types.IntType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true
case types.Bytes:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true
case types.UintType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true
case types.Double:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true
case types.StringType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true
case types.Duration:
p.state.SetValue(id, val)
durationString := string(v.ConvertToType(types.StringType).(types.String))
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: overloads.TypeConvertDuration,
Args: []*exprpb.Expr{
p.createLiteral(p.nextID(),
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}),
},
},
},
}, true
case types.Int:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true
case types.DoubleType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true
case types.Uint:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true
case types.BytesType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true
case types.String:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true
case types.NullType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true
case types.Null:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true
}

// Attempt to build a list literal.
Expand Down Expand Up @@ -196,29 +212,37 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
return nil, false
}

func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) {
elemVal, found := p.value(elem.GetId())
if found && elemVal.Type() == types.OptionalType {
opt := elemVal.(*types.Optional)
if !opt.HasValue() {
return nil, true
}
if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned {
return newElem, true
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
}
}
return elem, false
}

func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
val, valueExists := p.value(call.GetArgs()[1].GetId())
if !valueExists {
v, exists := p.value(call.GetArgs()[1].GetId())
if !exists || types.IsUnknownOrError(v) {
return nil, false
}
if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero {
if sz, ok := v.(traits.Sizer); ok && sz.Size() == types.IntZero {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return nil, false
}

func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}
call := node.GetCallExpr()
arg := call.GetArgs()[0]
v, exists := p.value(arg.GetId())
if !exists {
if !exists || types.IsUnknownOrError(v) {
return nil, false
}
if b, ok := v.(types.Bool); ok {
Expand All @@ -228,37 +252,28 @@ func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool)
}

func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}

call := node.GetCallExpr()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if p.existsWithKnownValue(call.Args[0].GetId()) {
return call.Args[1], true
if p.existsWithKnownValue(call.GetArgs()[0].GetId()) {
return call.GetArgs()[1], true
}
if p.existsWithKnownValue(call.Args[1].GetId()) {
return call.Args[0], true
if p.existsWithKnownValue(call.GetArgs()[1].GetId()) {
return call.GetArgs()[0], true
}
return nil, false
}

func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}

call := node.GetCallExpr()
condVal, condValueExists := p.value(call.Args[0].GetId())
if !condValueExists || types.IsUnknownOrError(condVal) {
cond, exists := p.value(call.GetArgs()[0].GetId())
if !exists || types.IsUnknownOrError(cond) {
return nil, false
}

if condVal.Value().(bool) {
return call.Args[1], true
if cond.Value().(bool) {
return call.GetArgs()[1], true
}
return call.Args[2], true
return call.GetArgs()[2], true
}

func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
Expand All @@ -279,23 +294,28 @@ func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
}

func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) {
out, pruned := p.prune(node)
if pruned {
delete(p.macroCalls, node.GetId())
}
return out, pruned
return p.prune(node)
}

func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
if node == nil {
return node, false
}
val, valueExists := p.value(node.GetId())
if valueExists && !types.IsUnknownOrError(val) {
if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok {
v, exists := p.value(node.GetId())
if exists && !types.IsUnknownOrError(v) {
if newNode, ok := p.maybeCreateLiteral(node.GetId(), v); ok {
// if the macro completely evaluated, then delete the reference to it, if one exists.
delete(p.macroCalls, node.GetId())
// return the literal value.
return newNode, true
}
}
if macro, found := p.macroCalls[node.GetId()]; found {
// prune the expression in terms of the macro call instead of the expanded form.
if newMacro, pruned := p.prune(macro); pruned {
p.macroCalls[node.GetId()] = newMacro
}
}

// We have either an unknown/error value, or something we don't want to
// transform, or expression was not evaluated. If possible, drill down
Expand Down Expand Up @@ -351,21 +371,51 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}
case *exprpb.Expr_ListExpr:
elems := node.GetListExpr().GetElements()
newElems := make([]*exprpb.Expr, len(elems))
optIndices := node.GetListExpr().GetOptionalIndices()
optIndexMap := map[int32]bool{}
for _, i := range optIndices {
optIndexMap[i] = true
}
newOptIndexMap := make(map[int32]bool, len(optIndexMap))
newElems := make([]*exprpb.Expr, 0, len(elems))
var prunedList bool

prunedIdx := 0
for i, elem := range elems {
newElems[i] = elem
_, isOpt := optIndexMap[int32(i)]
if isOpt {
newElem, pruned := p.maybePruneOptional(elem)
if pruned {
if newElem != nil {
newElems = append(newElems, newElem)
prunedIdx++
}
prunedList = true
continue
}
newOptIndexMap[int32(prunedIdx)] = true
}
if newElem, prunedElem := p.maybePrune(elem); prunedElem {
newElems[i] = newElem
newElems = append(newElems, newElem)
prunedList = true
} else {
newElems = append(newElems, elem)
}
prunedIdx++
}
optIndices = make([]int32, len(newOptIndexMap))
idx := 0
for i := range newOptIndexMap {
optIndices[idx] = i
idx++
}
if prunedList {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: newElems,
Elements: newElems,
OptionalIndices: optIndices,
},
},
}, true
Expand Down Expand Up @@ -395,6 +445,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
MapKey: newKey,
}
}
newEntry.OptionalEntry = entry.GetOptionalEntry()
newEntries[i] = newEntry
}
if prunedStruct {
Expand All @@ -408,27 +459,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
},
}, true
}
case *exprpb.Expr_ComprehensionExpr:
compre := node.GetComprehensionExpr()
// Only the range of the comprehension is pruned since the state tracking only records
// the last iteration of the comprehension and not each step in the evaluation which
// means that the any residuals computed in between might be inaccurate.
if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: compre.GetIterVar(),
IterRange: newRange,
AccuVar: compre.GetAccuVar(),
AccuInit: compre.GetAccuInit(),
LoopCondition: compre.GetLoopCondition(),
LoopStep: compre.GetLoopStep(),
Result: compre.GetResult(),
},
},
}, true
}
}
return node, false
}
Expand All @@ -438,14 +468,9 @@ func (p *astPruner) value(id int64) (ref.Val, bool) {
return val, (found && val != nil)
}

func (p *astPruner) existsWithUnknownValue(id int64) bool {
val, valueExists := p.value(id)
return valueExists && types.IsUnknown(val)
}

func (p *astPruner) existsWithKnownValue(id int64) bool {
val, valueExists := p.value(id)
return valueExists && !types.IsUnknown(val)
return valueExists && !types.IsUnknownOrError(val)
}

func (p *astPruner) nextID() int64 {
Expand All @@ -460,14 +485,39 @@ func (p *astPruner) nextID() int64 {
}
}

type astVisitor struct {
// visitEntry is called on every expr node, including those within a map/struct entry.
visitExpr func(expr *exprpb.Expr)
// visitEntry is called before entering the key, value of a map/struct entry.
visitEntry func(entry *exprpb.Expr_CreateStruct_Entry)
}

func getMaxID(expr *exprpb.Expr) int64 {
maxID := int64(1)
visit(expr, maxIDVisitor(&maxID))
return maxID
}

func maxIDVisitor(maxID *int64) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
}
},
visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
}
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
},
}
}

func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs := []*exprpb.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
if e.GetId() >= maxID {
maxID = e.GetId() + 1
}
visitor.visitExpr(e)
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
Expand All @@ -490,16 +540,13 @@ func getMaxID(expr *exprpb.Expr) int64 {
list := e.GetListExpr()
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range expr.GetStructExpr().GetEntries() {
for _, entry := range e.GetStructExpr().GetEntries() {
visitor.visitEntry(entry)
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
}
exprs = append(exprs, entry.GetValue())
if entry.GetId() >= maxID {
maxID = entry.GetId() + 1
}
}
}
}
return maxID
}
Loading