Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional string.format() test for message field accesses #694

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions ext/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ import (
"time"
"unicode/utf8"

"google.golang.org/protobuf/proto"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"

proto3pb "github.com/google/cel-go/test/proto3pb"
)

Expand Down Expand Up @@ -941,6 +944,19 @@ func TestStringFormat(t *testing.T) {
expectedRuntimeCost: 13,
expectedEstimatedCost: checker.CostEstimate{Min: 13, Max: 13},
},
{
name: "message field support",
format: "message field msg.single_int32: %d, msg.single_double: %.1f",
formatArgs: `msg.single_int32, msg.single_double`,
dynArgs: map[string]any{
"msg": &proto3pb.TestAllTypes{
SingleInt32: 2,
SingleDouble: 1.0,
},
},
locale: "en_US",
expectedOutput: `message field msg.single_int32: 2, msg.single_double: 1.0`,
},
{
name: "unrecognized formatting clause",
format: "%a",
Expand Down Expand Up @@ -1209,8 +1225,31 @@ func TestStringFormat(t *testing.T) {
buildVariables := func(vars map[string]any) []cel.EnvOption {
opts := make([]cel.EnvOption, len(vars))
i := 0
for name := range vars {
opts[i] = cel.Variable(name, cel.DynType)
for name, value := range vars {
t := cel.DynType
switch v := value.(type) {
case proto.Message:
t = cel.ObjectType(string(v.ProtoReflect().Descriptor().FullName()))
case types.Bool:
t = cel.BoolType
case types.Bytes:
t = cel.BytesType
case types.Double:
t = cel.DoubleType
case types.Duration:
t = cel.DurationType
case types.Int:
t = cel.IntType
case types.Null:
t = cel.NullType
case types.String:
t = cel.StringType
case types.Timestamp:
t = cel.TimestampType
case types.Uint:
t = cel.UintType
}
opts[i] = cel.Variable(name, t)
i++
}
return opts
Expand Down Expand Up @@ -1266,7 +1305,7 @@ func TestStringFormat(t *testing.T) {
checkCase(out, tt.expectedOutput, err, tt.err, t)
if tt.locale == "" {
// if the test has no locale specified, then that means it
// should have the same output regardless of lcoale
// should have the same output regardless of locale
t.Run("no change on locale", func(t *testing.T) {
out, err := runCase(tt.format, tt.formatArgs, "da_DK", tt.dynArgs, tt.skipCompileCheck, tt.expectedRuntimeCost, tt.expectedEstimatedCost, t)
checkCase(out, tt.expectedOutput, err, tt.err, t)
Expand Down
19 changes: 16 additions & 3 deletions interpreter/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@ type absoluteAttribute struct {

// ID implements the Attribute interface method.
func (a *absoluteAttribute) ID() int64 {
return a.id
qual_count := len(a.qualifiers)
if qual_count == 0 {
return a.id
}
return a.qualifiers[qual_count-1].ID()
}

// IsOptional returns trivially false for an attribute as the attribute represents a fully
Expand Down Expand Up @@ -315,6 +319,11 @@ type conditionalAttribute struct {

// ID is an implementation of the Attribute interface method.
func (a *conditionalAttribute) ID() int64 {
// There's a field access after the conditional.
if a.truthy.ID() == a.falsy.ID() {
return a.truthy.ID()
}
// Otherwise return the conditional id as the consistent id being tracked.
return a.id
}

Expand Down Expand Up @@ -379,7 +388,7 @@ type maybeAttribute struct {

// ID is an implementation of the Attribute interface method.
func (a *maybeAttribute) ID() int64 {
return a.id
return a.attrs[0].ID()
}

// IsOptional returns trivially false for an attribute as the attribute represents a fully
Expand Down Expand Up @@ -494,7 +503,11 @@ type relativeAttribute struct {

// ID is an implementation of the Attribute interface method.
func (a *relativeAttribute) ID() int64 {
return a.id
qual_count := len(a.qualifiers)
if qual_count == 0 {
return a.id
}
return a.qualifiers[qual_count-1].ID()
}

// IsOptional returns trivially false for an attribute as the attribute represents a fully
Expand Down
53 changes: 44 additions & 9 deletions interpreter/attributes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,8 +869,8 @@ func TestAttributeStateTracking(t *testing.T) {
in: map[string]any{},
out: types.True,
state: map[int64]any{
// overall expression
1: true,
// [{"field": true}]
1: []ref.Val{types.DefaultTypeAdapter.NativeToValue(map[ref.Val]ref.Val{types.String("field"): types.True})},
// [{"field": true}][0]
6: map[ref.Val]ref.Val{types.String("field"): types.True},
// [{"field": true}][0].field
Expand All @@ -893,8 +893,6 @@ func TestAttributeStateTracking(t *testing.T) {
},
out: types.True,
state: map[int64]any{
// overall expression
1: true,
// a[1]
2: map[string]bool{"two": true},
// a[1]["two"]
Expand All @@ -918,8 +916,6 @@ func TestAttributeStateTracking(t *testing.T) {
},
out: types.String("dex"),
state: map[int64]any{
// overall expression
1: "dex",
// a[1]
2: map[int64]any{
1: 0,
Expand Down Expand Up @@ -948,8 +944,6 @@ func TestAttributeStateTracking(t *testing.T) {
},
out: types.String("index"),
state: map[int64]any{
// overall expression
1: "index",
// a[1]
2: map[int64]any{
1: 0,
Expand All @@ -969,6 +963,46 @@ func TestAttributeStateTracking(t *testing.T) {
10: int64(0),
},
},
{
expr: `true ? a : b`,
env: []*exprpb.Decl{
decls.NewVar("a", decls.String),
decls.NewVar("b", decls.String),
},
in: map[string]any{
"a": "hello",
"b": "world",
},
out: types.String("hello"),
state: map[int64]any{
// 'hello'
2: types.String("hello"),
},
},
{
expr: `(a.size() != 0 ? a : b)[0]`,
env: []*exprpb.Decl{
decls.NewVar("a", decls.NewListType(decls.String)),
decls.NewVar("b", decls.NewListType(decls.String)),
},
in: map[string]any{
"a": []string{"hello", "world"},
"b": []string{"world", "hello"},
},
out: types.String("hello"),
state: map[int64]any{
// ["hello", "world"]
1: types.DefaultTypeAdapter.NativeToValue([]string{"hello", "world"}),
// ["hello", "world"].size() // 2
2: types.Int(2),
// ["hello", "world"].size() != 0
3: types.True,
// constant 0
4: types.IntZero,
// 'hello'
8: types.String("hello"),
},
},
}
for _, test := range tests {
tc := test
Expand Down Expand Up @@ -1014,7 +1048,8 @@ func TestAttributeStateTracking(t *testing.T) {
t.Errorf("state not found for %d=%v", id, val)
continue
}
if !reflect.DeepEqual(stVal.Value(), val) {
wantStVal := types.DefaultTypeAdapter.NativeToValue(val)
if wantStVal.Equal(stVal) != types.True {
t.Errorf("got %v, wanted %v for id: %d", stVal.Value(), val, id)
}
}
Expand Down