Skip to content

Commit

Permalink
Validators using navigable expr pattern (#762)
Browse files Browse the repository at this point in the history
* NavigableExpr utilities for querying the expression graph
* Shift navigation root to CheckedAST
  • Loading branch information
TristonianJones authored Jul 7, 2023
1 parent 0f0525a commit 5dc9173
Show file tree
Hide file tree
Showing 5 changed files with 1,203 additions and 44 deletions.
9 changes: 9 additions & 0 deletions common/ast/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ go_library(
name = "go_default_library",
srcs = [
"ast.go",
"expr.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
Expand All @@ -28,15 +29,23 @@ go_test(
name = "go_default_test",
srcs = [
"ast_test.go",
"expr_test.go",
],
embed = [
":go_default_library",
],
deps = [
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//parser:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
Expand Down
88 changes: 45 additions & 43 deletions common/ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package ast
package ast_test

import (
"reflect"
Expand All @@ -22,6 +22,7 @@ import (
"google.golang.org/protobuf/proto"

chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
Expand All @@ -30,16 +31,16 @@ import (
)

func TestConvertAST(t *testing.T) {
ast := &CheckedAST{
goAST := &ast.CheckedAST{
Expr: &exprpb.Expr{},
SourceInfo: &exprpb.SourceInfo{},
TypeMap: map[int64]*types.Type{
1: types.BoolType,
2: types.DynType,
},
ReferenceMap: map[int64]*ReferenceInfo{
1: NewFunctionReference(overloads.LogicalNot),
2: NewIdentReference("TRUE", types.True),
ReferenceMap: map[int64]*ast.ReferenceInfo{
1: ast.NewFunctionReference(overloads.LogicalNot),
2: ast.NewIdentReference("TRUE", types.True),
},
}

Expand All @@ -61,19 +62,19 @@ func TestConvertAST(t *testing.T) {
},
}

checkedAST, err := CheckedExprToCheckedAST(exprAST)
checkedAST, err := ast.CheckedExprToCheckedAST(exprAST)
if err != nil {
t.Fatalf("CheckedExprToCheckedAST() failed: %v", err)
}
if !reflect.DeepEqual(checkedAST.ReferenceMap, ast.ReferenceMap) ||
!reflect.DeepEqual(checkedAST.TypeMap, ast.TypeMap) {
t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, ast)
if !reflect.DeepEqual(checkedAST.ReferenceMap, goAST.ReferenceMap) ||
!reflect.DeepEqual(checkedAST.TypeMap, goAST.TypeMap) {
t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, goAST)
}
if !checkedAST.ReferenceMap[1].Equals(ast.ReferenceMap[1]) ||
!checkedAST.ReferenceMap[2].Equals(ast.ReferenceMap[2]) {
if !checkedAST.ReferenceMap[1].Equals(goAST.ReferenceMap[1]) ||
!checkedAST.ReferenceMap[2].Equals(goAST.ReferenceMap[2]) {
t.Error("converted reference info values not equal")
}
checkedExpr, err := CheckedASTToCheckedExpr(ast)
checkedExpr, err := ast.CheckedASTToCheckedExpr(goAST)
if err != nil {
t.Fatalf("CheckedASTToCheckedExpr() failed: %v", err)
}
Expand All @@ -85,68 +86,68 @@ func TestConvertAST(t *testing.T) {
func TestReferenceInfoEquals(t *testing.T) {
tests := []struct {
name string
a *ReferenceInfo
b *ReferenceInfo
a *ast.ReferenceInfo
b *ast.ReferenceInfo
equal bool
}{
{
name: "single overload equal",
a: NewFunctionReference(overloads.AddBytes),
b: NewFunctionReference(overloads.AddBytes),
a: ast.NewFunctionReference(overloads.AddBytes),
b: ast.NewFunctionReference(overloads.AddBytes),
equal: true,
},
{
name: "single overload not equal",
a: NewFunctionReference(overloads.AddBytes),
b: NewFunctionReference(overloads.AddDouble),
a: ast.NewFunctionReference(overloads.AddBytes),
b: ast.NewFunctionReference(overloads.AddDouble),
equal: false,
},
{
name: "single and multiple overload not equal",
a: NewFunctionReference(overloads.AddBytes),
b: NewFunctionReference(overloads.AddBytes, overloads.AddDouble),
a: ast.NewFunctionReference(overloads.AddBytes),
b: ast.NewFunctionReference(overloads.AddBytes, overloads.AddDouble),
equal: false,
},
{
name: "multiple overloads equal",
a: NewFunctionReference(overloads.AddBytes, overloads.AddDouble),
b: NewFunctionReference(overloads.AddDouble, overloads.AddBytes),
a: ast.NewFunctionReference(overloads.AddBytes, overloads.AddDouble),
b: ast.NewFunctionReference(overloads.AddDouble, overloads.AddBytes),
equal: true,
},
{
name: "identifier reference equal",
a: NewIdentReference("BYTES", nil),
b: NewIdentReference("BYTES", nil),
a: ast.NewIdentReference("BYTES", nil),
b: ast.NewIdentReference("BYTES", nil),
equal: true,
},
{
name: "identifier reference not equal",
a: NewIdentReference("BYTES", nil),
b: NewIdentReference("TRUE", nil),
a: ast.NewIdentReference("BYTES", nil),
b: ast.NewIdentReference("TRUE", nil),
equal: false,
},
{
name: "identifier and constant reference not equal",
a: NewIdentReference("BYTES", nil),
b: NewIdentReference("BYTES", types.Bytes("bytes")),
a: ast.NewIdentReference("BYTES", nil),
b: ast.NewIdentReference("BYTES", types.Bytes("bytes")),
equal: false,
},
{
name: "constant references equal",
a: NewIdentReference("BYTES", types.Bytes("bytes")),
b: NewIdentReference("BYTES", types.Bytes("bytes")),
a: ast.NewIdentReference("BYTES", types.Bytes("bytes")),
b: ast.NewIdentReference("BYTES", types.Bytes("bytes")),
equal: true,
},
{
name: "constant references not equal",
a: NewIdentReference("BYTES", types.Bytes("bytes")),
b: NewIdentReference("BYTES", types.Bytes("bytes-other")),
a: ast.NewIdentReference("BYTES", types.Bytes("bytes")),
b: ast.NewIdentReference("BYTES", types.Bytes("bytes-other")),
equal: false,
},
{
name: "constant and overload reference not equal",
a: NewIdentReference("BYTES", types.Bytes("bytes")),
b: NewFunctionReference(overloads.AddDouble, overloads.AddBytes),
a: ast.NewIdentReference("BYTES", types.Bytes("bytes")),
b: ast.NewFunctionReference(overloads.AddDouble, overloads.AddBytes),
equal: false,
},
}
Expand All @@ -162,26 +163,27 @@ func TestReferenceInfoEquals(t *testing.T) {
}

func TestReferenceInfoAddOverload(t *testing.T) {
add := NewFunctionReference(overloads.AddBytes)
add := ast.NewFunctionReference(overloads.AddBytes)
add.AddOverload(overloads.AddDouble)
if !add.Equals(NewFunctionReference(overloads.AddBytes, overloads.AddDouble)) {
if !add.Equals(ast.NewFunctionReference(overloads.AddBytes, overloads.AddDouble)) {
t.Error("AddOverload() did not produce equal references")
}
add.AddOverload(overloads.AddDouble)
if !add.Equals(NewFunctionReference(overloads.AddBytes, overloads.AddDouble)) {
if !add.Equals(ast.NewFunctionReference(overloads.AddBytes, overloads.AddDouble)) {
t.Error("repeated AddOverload() did not produce equal references")
}
}

func TestReferenceInfoToReferenceExprError(t *testing.T) {
out, err := ReferenceInfoToReferenceExpr(NewIdentReference("SECOND", types.Duration{Duration: time.Duration(1) * time.Second}))
out, err := ast.ReferenceInfoToReferenceExpr(
ast.NewIdentReference("SECOND", types.Duration{Duration: time.Duration(1) * time.Second}))
if err == nil {
t.Errorf("ReferenceInfoToReferenceExpr() got %v, wanted error", out)
}
}

func TestReferenceExprToReferenceInfoError(t *testing.T) {
out, err := ReferenceExprToReferenceInfo(&exprpb.Reference{Value: &exprpb.Constant{}})
out, err := ast.ReferenceExprToReferenceInfo(&exprpb.Reference{Value: &exprpb.Constant{}})
if err == nil {
t.Errorf("ReferenceExprToReferenceInfo() got %v, wanted error", out)
}
Expand All @@ -198,11 +200,11 @@ func TestConvertVal(t *testing.T) {
types.Uint(27),
}
for _, tst := range tests {
c, err := ValToConstant(tst)
c, err := ast.ValToConstant(tst)
if err != nil {
t.Errorf("ValToConstant(%v) failed: %v", tst, err)
}
v, err := ConstantToVal(c)
v, err := ast.ConstantToVal(c)
if err != nil {
t.Errorf("ValToConstant(%v) failed: %v", c, err)
}
Expand All @@ -213,14 +215,14 @@ func TestConvertVal(t *testing.T) {
}

func TestValToConstantError(t *testing.T) {
out, err := ValToConstant(types.Duration{Duration: time.Duration(10)})
out, err := ast.ValToConstant(types.Duration{Duration: time.Duration(10)})
if err == nil {
t.Errorf("ValToConstant() got %v, wanted error", out)
}
}

func TestConstantToValError(t *testing.T) {
out, err := ConstantToVal(&exprpb.Constant{})
out, err := ast.ConstantToVal(&exprpb.Constant{})
if err == nil {
t.Errorf("ConstantToVal() got %v, wanted error", out)
}
Expand Down
Loading

0 comments on commit 5dc9173

Please sign in to comment.