Skip to content

Commit

Permalink
Update the Go AST representation to handle a second iteration variable (
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Sep 30, 2024
1 parent 4b8b15b commit a118ff0
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 9 deletions.
5 changes: 3 additions & 2 deletions common/ast/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ go_library(
"navigable.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
deps = [
"//common:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
Expand All @@ -35,12 +35,13 @@ go_test(
embed = [
":go_default_library",
],
deps = [
deps = [
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
Expand Down
4 changes: 3 additions & 1 deletion common/ast/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ func exprComprehension(factory ExprFactory, id int64, comp *exprpb.Expr_Comprehe
if err != nil {
return nil, err
}
return factory.NewComprehension(id,
return factory.NewComprehensionTwoVar(id,
iterRange,
comp.GetIterVar(),
comp.GetIterVar2(),
comp.GetAccuVar(),
accuInit,
loopCond,
Expand Down Expand Up @@ -363,6 +364,7 @@ func protoComprehension(id int64, comp ComprehensionExpr) (*exprpb.Expr, error)
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterVar: comp.IterVar(),
IterVar2: comp.IterVar2(),
IterRange: iterRange,
AccuVar: comp.AccuVar(),
AccuInit: accuInit,
Expand Down
205 changes: 201 additions & 4 deletions common/ast/conversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
Expand All @@ -35,6 +36,7 @@ import (
)

func TestConvertAST(t *testing.T) {
fac := ast.NewExprFactory()
tests := []struct {
goAST *ast.AST
pbAST *exprpb.CheckedExpr
Expand Down Expand Up @@ -68,6 +70,115 @@ func TestConvertAST(t *testing.T) {
},
},
},
{
goAST: ast.NewAST(
fac.NewComprehensionTwoVar(1,
fac.NewIdent(2, "data"),
"i",
"v",
"__result__",
fac.NewList(3, []ast.Expr{}, []int32{}),
fac.NewLiteral(4, types.True),
fac.NewCall(8, operators.Add,
fac.NewAccuIdent(9),
fac.NewCall(5, operators.Add,
fac.NewIdent(6, "i"),
fac.NewIdent(7, "v"),
)),
fac.NewAccuIdent(10),
), nil),
pbAST: &exprpb.CheckedExpr{
Expr: &exprpb.Expr{
Id: 1,
ExprKind: &exprpb.Expr_ComprehensionExpr{
ComprehensionExpr: &exprpb.Expr_Comprehension{
IterRange: &exprpb.Expr{
Id: 2,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "data",
},
},
},
IterVar: "i",
IterVar2: "v",
AccuVar: "__result__",
AccuInit: &exprpb.Expr{
Id: 3,
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{},
},
},
LoopCondition: &exprpb.Expr{
Id: 4,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: &exprpb.Constant{
ConstantKind: &exprpb.Constant_BoolValue{
BoolValue: true,
},
},
},
},
LoopStep: &exprpb.Expr{
Id: 8,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: operators.Add,
Args: []*exprpb.Expr{
{
Id: 9,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "__result__",
},
},
},
{
Id: 5,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: operators.Add,
Args: []*exprpb.Expr{
{
Id: 6,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "i",
},
},
},
{
Id: 7,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "v",
},
},
},
},
},
},
},
},
},
},
},
Result: &exprpb.Expr{
Id: 10,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "__result__",
},
},
},
},
},
},
SourceInfo: &exprpb.SourceInfo{},
TypeMap: map[int64]*exprpb.Type{},
ReferenceMap: map[int64]*exprpb.Reference{},
},
},
}

for i, tst := range tests {
Expand All @@ -83,11 +194,13 @@ func TestConvertAST(t *testing.T) {
!reflect.DeepEqual(checkedAST.TypeMap(), goAST.TypeMap()) {
t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, goAST)
}
if !checkedAST.ReferenceMap()[1].Equals(goAST.ReferenceMap()[1]) ||
!checkedAST.ReferenceMap()[2].Equals(goAST.ReferenceMap()[2]) {
t.Error("converted reference info values not equal")
if len(checkedAST.ReferenceMap()) > 2 {
if !checkedAST.ReferenceMap()[1].Equals(goAST.ReferenceMap()[1]) ||
!checkedAST.ReferenceMap()[2].Equals(goAST.ReferenceMap()[2]) {
t.Error("converted reference info values not equal")
}
}
checkedExpr, err := ast.ToProto(goAST)
checkedExpr, err := ast.ToProto(checkedAST)
if err != nil {
t.Fatalf("ASTToProto() failed: %v", err)
}
Expand All @@ -98,6 +211,90 @@ func TestConvertAST(t *testing.T) {
}
}

func TestConvertProtoToEntryExpr(t *testing.T) {
fac := ast.NewExprFactory()
tests := []struct {
goAST ast.EntryExpr
pbAST *exprpb.Expr_CreateStruct_Entry
}{
{
goAST: fac.NewMapEntry(1,
fac.NewIdent(2, "var_key"),
fac.NewLiteral(3, types.String("hello")),
true),
pbAST: &exprpb.Expr_CreateStruct_Entry{
Id: 1,
KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{
MapKey: &exprpb.Expr{
Id: 2,
ExprKind: &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: "var_key",
},
},
},
},
Value: &exprpb.Expr{
Id: 3,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: &exprpb.Constant{
ConstantKind: &exprpb.Constant_StringValue{
StringValue: "hello",
},
},
},
},
OptionalEntry: true,
},
},
{
goAST: fac.NewStructField(1,
"field_name",
fac.NewLiteral(2, types.String("hello")),
false),
pbAST: &exprpb.Expr_CreateStruct_Entry{
Id: 1,
KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{
FieldKey: "field_name",
},
Value: &exprpb.Expr{
Id: 2,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: &exprpb.Constant{
ConstantKind: &exprpb.Constant_StringValue{
StringValue: "hello",
},
},
},
},
OptionalEntry: false,
},
},
}

for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
goAST := tc.goAST
pbAST := tc.pbAST
gotGoAST, err := ast.ProtoToEntryExpr(pbAST)
if err != nil {
t.Fatalf("ProtoToEntryExpr() failed: %v", err)
}
if !reflect.DeepEqual(goAST, gotGoAST) {
t.Errorf("conversion to go AST did not produce identical results: got %v, wanted %v", gotGoAST, goAST)
}
gotProtoAST, err := ast.EntryExprToProto(gotGoAST)
if err != nil {
t.Fatalf("EntryExprToProto() failed: %v", err)
}
if !proto.Equal(gotProtoAST, pbAST) {
t.Errorf("conversion to protobuf did not produce identical results: got %v, wanted %v", gotProtoAST, pbAST)
}
})
}
}

func TestConvertExpr(t *testing.T) {
fac := ast.NewExprFactory()
tests := []struct {
Expand Down
24 changes: 24 additions & 0 deletions common/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,22 @@ type ComprehensionExpr interface {
IterRange() Expr

// IterVar returns the iteration variable name.
//
// For one-variable comprehensions, the iter var refers to the element value
// when iterating over a list, or the map key when iterating over a map.
//
// For two-variable comprehneions, the iter var refers to the list index or the
// map key.
IterVar() string

// IterVar2 returns the second iteration variable name.
//
// When the value is non-empty, the comprehension is a two-variable comprehension.
IterVar2() string

// HasIterVar2 returns true if the second iteration variable is non-empty.
HasIterVar2() bool

// AccuVar returns the accumulation variable name.
AccuVar() string

Expand Down Expand Up @@ -397,6 +411,7 @@ func (e *expr) SetKindCase(other Expr) {
e.exprKindCase = &baseComprehensionExpr{
iterRange: c.IterRange(),
iterVar: c.IterVar(),
iterVar2: c.IterVar2(),
accuVar: c.AccuVar(),
accuInit: c.AccuInit(),
loopCond: c.LoopCondition(),
Expand Down Expand Up @@ -505,6 +520,7 @@ var _ ComprehensionExpr = &baseComprehensionExpr{}
type baseComprehensionExpr struct {
iterRange Expr
iterVar string
iterVar2 string
accuVar string
accuInit Expr
loopCond Expr
Expand All @@ -527,6 +543,14 @@ func (e *baseComprehensionExpr) IterVar() string {
return e.iterVar
}

func (e *baseComprehensionExpr) IterVar2() string {
return e.iterVar2
}

func (e *baseComprehensionExpr) HasIterVar2() bool {
return e.iterVar2 != ""
}

func (e *baseComprehensionExpr) AccuVar() string {
return e.accuVar
}
Expand Down
14 changes: 12 additions & 2 deletions common/ast/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ type ExprFactory interface {
// NewCall creates an Expr value representing a global function call.
NewCall(id int64, function string, args ...Expr) Expr

// NewComprehension creates an Expr value representing a comprehension over a value range.
// NewComprehension creates an Expr value representing a one-variable comprehension over a value range.
NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr

// NewComprehensionTwoVar creates an Expr value representing a two-variable comprehension over a value range.
NewComprehensionTwoVar(id int64, iterRange Expr, iterVar, iterVar2, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr

// NewMemberCall creates an Expr value representing a member function call.
NewMemberCall(id int64, function string, receiver Expr, args ...Expr) Expr

Expand Down Expand Up @@ -111,11 +114,17 @@ func (fac *baseExprFactory) NewMemberCall(id int64, function string, target Expr
}

func (fac *baseExprFactory) NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
// Set the iter_var2 to empty string to indicate the second variable is omitted
return fac.NewComprehensionTwoVar(id, iterRange, iterVar, "", accuVar, accuInit, loopCond, loopStep, result)
}

func (fac *baseExprFactory) NewComprehensionTwoVar(id int64, iterRange Expr, iterVar, iterVar2, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr {
return fac.newExpr(
id,
&baseComprehensionExpr{
iterRange: iterRange,
iterVar: iterVar,
iterVar2: iterVar2,
accuVar: accuVar,
accuInit: accuInit,
loopCond: loopCond,
Expand Down Expand Up @@ -223,9 +232,10 @@ func (fac *baseExprFactory) CopyExpr(e Expr) Expr {
return fac.NewMemberCall(e.ID(), c.FunctionName(), fac.CopyExpr(c.Target()), argsCopy...)
case ComprehensionKind:
compre := e.AsComprehension()
return fac.NewComprehension(e.ID(),
return fac.NewComprehensionTwoVar(e.ID(),
fac.CopyExpr(compre.IterRange()),
compre.IterVar(),
compre.IterVar2(),
compre.AccuVar(),
fac.CopyExpr(compre.AccuInit()),
fac.CopyExpr(compre.LoopCondition()),
Expand Down
8 changes: 8 additions & 0 deletions common/ast/navigable.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,14 @@ func (comp navigableComprehensionImpl) IterVar() string {
return comp.Expr.AsComprehension().IterVar()
}

func (comp navigableComprehensionImpl) IterVar2() string {
return comp.Expr.AsComprehension().IterVar2()
}

func (comp navigableComprehensionImpl) HasIterVar2() bool {
return comp.Expr.AsComprehension().HasIterVar2()
}

func (comp navigableComprehensionImpl) AccuVar() string {
return comp.Expr.AsComprehension().AccuVar()
}
Expand Down
Loading

0 comments on commit a118ff0

Please sign in to comment.