Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure variables in comprehensions don't collide #1062

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 34 additions & 32 deletions ext/comprehensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package ext

import (
"fmt"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
Expand Down Expand Up @@ -220,14 +222,11 @@ func (compreV2Lib) ProgramOptions() []cel.ProgramOption {
}

func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}

return mef.NewComprehensionTwoVar(
target,
iterVar1,
Expand All @@ -241,14 +240,11 @@ func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
}

func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}

return mef.NewComprehensionTwoVar(
target,
iterVar1,
Expand All @@ -262,14 +258,11 @@ func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr
}

func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}

return mef.NewComprehensionTwoVar(
target,
iterVar1,
Expand All @@ -285,11 +278,7 @@ func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.E
}

func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -324,11 +313,7 @@ func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
}

func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -362,11 +347,7 @@ func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (a
}

func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
iterVar1, err := extractIterVar(mef, args[0])
if err != nil {
return nil, err
}
iterVar2, err := extractIterVar(mef, args[1])
iterVar1, iterVar2, err := extractIterVars(mef, args[0], args[1])
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -399,10 +380,31 @@ func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Exp
), nil
}

func extractIterVar(meh cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) {
func extractIterVars(mef cel.MacroExprFactory, arg0, arg1 ast.Expr) (string, string, *cel.Error) {
iterVar1, err := extractIterVar(mef, arg0)
if err != nil {
return "", "", err
}
iterVar2, err := extractIterVar(mef, arg1)
if err != nil {
return "", "", err
}
if iterVar1 == iterVar2 {
return "", "", mef.NewError(arg1.ID(), fmt.Sprintf("duplicate variable name: %s", iterVar1))
}
if iterVar1 == parser.AccumulatorName {
return "", "", mef.NewError(arg0.ID(), "iteration variable overwrites accumulator variable")
}
if iterVar2 == parser.AccumulatorName {
return "", "", mef.NewError(arg1.ID(), "iteration variable overwrites accumulator variable")
}
return iterVar1, iterVar2, nil
}

func extractIterVar(mef cel.MacroExprFactory, target ast.Expr) (string, *cel.Error) {
iterVar, found := extractIdent(target)
if !found {
return "", meh.NewError(target.ID(), "argument must be a simple name")
return "", mef.NewError(target.ID(), "argument must be a simple name")
}
return iterVar, nil
}
12 changes: 12 additions & 0 deletions ext/comprehensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,18 @@ func TestTwoVarComprehensionsStaticErrors(t *testing.T) {
expr string
err string
}{
{
expr: "[].all(i, i, i < i)",
err: "duplicate variable name: i",
},
{
expr: "[].all(__result__, i, __result__ < i)",
err: "iteration variable overwrites accumulator variable",
},
{
expr: "[].all(j, __result__, __result__ < j)",
err: "iteration variable overwrites accumulator variable",
},
{
expr: "[].all(i.j, k, i.j < k)",
err: "argument must be a simple name",
Expand Down