Skip to content

Commit

Permalink
Additional string.format() test for message field accesses (#694)
Browse files Browse the repository at this point in the history
* Additional test for message field accesses
* Improvements in attribute identification yield more accurate state tracking and correlation to type ids from checked expressions
  • Loading branch information
TristonianJones authored May 2, 2023
1 parent 65c30b8 commit 1fcac88
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 15 deletions.
45 changes: 42 additions & 3 deletions ext/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ import (
"time"
"unicode/utf8"

"google.golang.org/protobuf/proto"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"

proto3pb "github.com/google/cel-go/test/proto3pb"
)

Expand Down Expand Up @@ -941,6 +944,19 @@ func TestStringFormat(t *testing.T) {
expectedRuntimeCost: 13,
expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13},
},
{
name: "message field support",
format: "message field msg.single_int32: %d, msg.single_double: %.1f",
formatArgs: `msg.single_int32, msg.single_double`,
dynArgs: map[string]any{
"msg": &proto3pb.TestAllTypes{
SingleInt32: 2,
SingleDouble: 1.0,
},
},
locale: "en_US",
expectedOutput: `message field msg.single_int32: 2, msg.single_double: 1.0`,
},
{
name: "unrecognized formatting clause",
format: "%a",
Expand Down Expand Up @@ -1209,8 +1225,31 @@ func TestStringFormat(t *testing.T) {
buildVariables := func(vars map[string]any) []cel.EnvOption {
opts := make([]cel.EnvOption, len(vars))
i := 0
for name := range vars {
opts[i] = cel.Variable(name, cel.DynType)
for name, value := range vars {
t := cel.DynType
switch v := value.(type) {
case proto.Message:
t = cel.ObjectType(string(v.ProtoReflect().Descriptor().FullName()))
case types.Bool:
t = cel.BoolType
case types.Bytes:
t = cel.BytesType
case types.Double:
t = cel.DoubleType
case types.Duration:
t = cel.DurationType
case types.Int:
t = cel.IntType
case types.Null:
t = cel.NullType
case types.String:
t = cel.StringType
case types.Timestamp:
t = cel.TimestampType
case types.Uint:
t = cel.UintType
}
opts[i] = cel.Variable(name, t)
i++
}
return opts
Expand Down Expand Up @@ -1266,7 +1305,7 @@ func TestStringFormat(t *testing.T) {
checkCase(out, tt.expectedOutput, err, tt.err, t)
if tt.locale == "" {
// if the test has no locale specified, then that means it
// should have the same output regardless of lcoale
// should have the same output regardless of locale
t.Run("no change on locale", func(t *testing.T) {
out, err := runCase(tt.format, tt.formatArgs, "da_DK", tt.dynArgs, tt.skipCompileCheck, tt.expectedRuntimeCost, tt.expectedEstimatedCost, t)
checkCase(out, tt.expectedOutput, err, tt.err, t)
Expand Down
19 changes: 16 additions & 3 deletions interpreter/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,11 @@ type absoluteAttribute struct {

// ID implements the Attribute interface method.
func (a *absoluteAttribute) ID() int64 {
return a.id
qual_count := len(a.qualifiers)
if qual_count == 0 {
return a.id
}
return a.qualifiers[qual_count-1].ID()
}

// IsOptional returns trivially false for an attribute as the attribute represents a fully
Expand Down Expand Up @@ -323,6 +327,11 @@ type conditionalAttribute struct {

// ID is an implementation of the Attribute interface method.
func (a *conditionalAttribute) ID() int64 {
// There's a field access after the conditional.
if a.truthy.ID() == a.falsy.ID() {
return a.truthy.ID()
}
// Otherwise return the conditional id as the consistent id being tracked.
return a.id
}

Expand Down Expand Up @@ -387,7 +396,7 @@ type maybeAttribute struct {

// ID is an implementation of the Attribute interface method.
func (a *maybeAttribute) ID() int64 {
return a.id
return a.attrs[0].ID()
}

// IsOptional returns trivially false for an attribute as the attribute represents a fully
Expand Down Expand Up @@ -504,7 +513,11 @@ type relativeAttribute struct {

// ID is an implementation of the Attribute interface method.
func (a *relativeAttribute) ID() int64 {
return a.id
qual_count := len(a.qualifiers)
if qual_count == 0 {
return a.id
}
return a.qualifiers[qual_count-1].ID()
}

// IsOptional returns trivially false for an attribute as the attribute represents a fully
Expand Down
53 changes: 44 additions & 9 deletions interpreter/attributes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,8 +869,8 @@ func TestAttributeStateTracking(t *testing.T) {
in: map[string]any{},
out: types.True,
state: map[int64]any{
// overall expression
1: true,
// [{"field": true}]
1: []ref.Val{types.DefaultTypeAdapter.NativeToValue(map[ref.Val]ref.Val{types.String("field"): types.True})},
// [{"field": true}][0]
6: map[ref.Val]ref.Val{types.String("field"): types.True},
// [{"field": true}][0].field
Expand All @@ -893,8 +893,6 @@ func TestAttributeStateTracking(t *testing.T) {
},
out: types.True,
state: map[int64]any{
// overall expression
1: true,
// a[1]
2: map[string]bool{"two": true},
// a[1]["two"]
Expand All @@ -918,8 +916,6 @@ func TestAttributeStateTracking(t *testing.T) {
},
out: types.String("dex"),
state: map[int64]any{
// overall expression
1: "dex",
// a[1]
2: map[int64]any{
1: 0,
Expand Down Expand Up @@ -948,8 +944,6 @@ func TestAttributeStateTracking(t *testing.T) {
},
out: types.String("index"),
state: map[int64]any{
// overall expression
1: "index",
// a[1]
2: map[int64]any{
1: 0,
Expand All @@ -969,6 +963,46 @@ func TestAttributeStateTracking(t *testing.T) {
10: int64(0),
},
},
{
expr: `true ? a : b`,
env: []*exprpb.Decl{
decls.NewVar("a", decls.String),
decls.NewVar("b", decls.String),
},
in: map[string]any{
"a": "hello",
"b": "world",
},
out: types.String("hello"),
state: map[int64]any{
// 'hello'
2: types.String("hello"),
},
},
{
expr: `(a.size() != 0 ? a : b)[0]`,
env: []*exprpb.Decl{
decls.NewVar("a", decls.NewListType(decls.String)),
decls.NewVar("b", decls.NewListType(decls.String)),
},
in: map[string]any{
"a": []string{"hello", "world"},
"b": []string{"world", "hello"},
},
out: types.String("hello"),
state: map[int64]any{
// ["hello", "world"]
1: types.DefaultTypeAdapter.NativeToValue([]string{"hello", "world"}),
// ["hello", "world"].size() // 2
2: types.Int(2),
// ["hello", "world"].size() != 0
3: types.True,
// constant 0
4: types.IntZero,
// 'hello'
8: types.String("hello"),
},
},
}
for _, test := range tests {
tc := test
Expand Down Expand Up @@ -1014,7 +1048,8 @@ func TestAttributeStateTracking(t *testing.T) {
t.Errorf("state not found for %d=%v", id, val)
continue
}
if !reflect.DeepEqual(stVal.Value(), val) {
wantStVal := types.DefaultTypeAdapter.NativeToValue(val)
if wantStVal.Equal(stVal) != types.True {
t.Errorf("got %v, wanted %v for id: %d", stVal.Value(), val, id)
}
}
Expand Down

0 comments on commit 1fcac88

Please sign in to comment.