Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ContextProtoVars() to simplify proto-based inputs #779

Merged
merged 1 commit into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cel/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,6 @@ go_test(
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
],
)
57 changes: 52 additions & 5 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
descpb "google.golang.org/protobuf/types/descriptorpb"
dynamicpb "google.golang.org/protobuf/types/dynamicpb"
durationpb "google.golang.org/protobuf/types/known/durationpb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
wrapperspb "google.golang.org/protobuf/types/known/wrapperspb"

proto2pb "github.com/google/cel-go/test/proto2pb"
proto3pb "github.com/google/cel-go/test/proto3pb"
Expand Down Expand Up @@ -1622,17 +1625,61 @@ func TestResidualAstModified(t *testing.T) {
}
}

func TestDeclareContextProto(t *testing.T) {
func TestContextProto(t *testing.T) {
descriptor := new(proto3pb.TestAllTypes).ProtoReflect().Descriptor()
option := DeclareContextProto(descriptor)
env := testEnv(t, option)
expression := `single_int64 == 1 && single_double == 1.0 && single_bool == true && single_string == '' && single_nested_message == google.expr.proto3.test.TestAllTypes.NestedMessage{}
&& single_nested_enum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO && single_duration == duration('5s') && single_timestamp == timestamp('1972-01-01T10:00:20.021-05:00')
&& single_any == google.protobuf.Any{} && repeated_int32 == [1,2] && map_string_string == {'': ''} && map_int64_nested_type == {0 : google.expr.proto3.test.NestedTestAllTypes{}}`
_, iss := env.Compile(expression)
expression := `
single_int64 == 1
&& single_double == 1.0
&& single_bool == true
&& single_string == ''
&& single_nested_message == google.expr.proto3.test.TestAllTypes.NestedMessage{}
&& standalone_enum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO
&& single_duration == duration('5s')
&& single_timestamp == timestamp(63154820)
&& single_any == null
&& single_uint32_wrapper == null
&& single_uint64_wrapper == 0u
&& repeated_int32 == [1,2]
&& map_string_string == {'': ''}
&& map_int64_nested_type == {0 : google.expr.proto3.test.NestedTestAllTypes{}}`
ast, iss := env.Compile(expression)
if iss.Err() != nil {
t.Fatalf("env.Compile(%s) failed: %s", expression, iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
in := &proto3pb.TestAllTypes{
SingleInt64: 1,
SingleDouble: 1.0,
SingleBool: true,
NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{
SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{},
},
StandaloneEnum: proto3pb.TestAllTypes_FOO,
SingleDuration: &durationpb.Duration{Seconds: 5},
SingleTimestamp: &timestamppb.Timestamp{
Seconds: 63154820,
},
SingleUint64Wrapper: wrapperspb.UInt64(0),
RepeatedInt32: []int32{1, 2},
MapStringString: map[string]string{"": ""},
MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{0: {}},
}
vars, err := ContextProtoVars(in)
if err != nil {
t.Fatalf("ContextProtoVars(%v) failed: %v", in, err)
}
out, _, err := prg.Eval(vars)
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out.Equal(types.True) != types.True {
t.Errorf("prg.Eval() got %v, wanted true", out)
}
}

func TestRegexOptimizer(t *testing.T) {
Expand Down
61 changes: 42 additions & 19 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"

"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -491,25 +490,21 @@ func CostLimit(costLimit uint64) ProgramOption {
}
}

func fieldToCELType(field protoreflect.FieldDescriptor) (*exprpb.Type, error) {
func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) {
if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind {
msgName := (string)(field.Message().FullName())
wellKnownType, found := pb.CheckedWellKnowns[msgName]
if found {
return wellKnownType, nil
}
return decls.NewObjectType(msgName), nil
return ObjectType(msgName), nil
}
if primitiveType, found := pb.CheckedPrimitives[field.Kind()]; found {
if primitiveType, found := types.ProtoCELPrimitives[field.Kind()]; found {
return primitiveType, nil
}
if field.Kind() == protoreflect.EnumKind {
return decls.Int, nil
return IntType, nil
}
return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String())
}

func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) {
name := string(field.Name())
if field.IsMap() {
mapKey := field.MapKey()
Expand All @@ -522,44 +517,72 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewMapType(keyType, valueType)), nil
return Variable(name, MapType(keyType, valueType)), nil
}
if field.IsList() {
elemType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewListType(elemType)), nil
return Variable(name, ListType(elemType)), nil
}
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, celType), nil
return Variable(name, celType), nil
}

// DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto.
// Each field of the proto defines a variable of the same name in the environment.
// https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment
func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption {
return func(e *Env) (*Env, error) {
var decls []*exprpb.Decl
fields := descriptor.Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
decl, err := fieldToDecl(field)
variable, err := fieldToVariable(field)
if err != nil {
return nil, err
}
e, err = variable(e)
if err != nil {
return nil, err
}
decls = append(decls, decl)
}
var err error
e, err = Declarations(decls...)(e)
return Types(dynamicpb.NewMessage(descriptor))(e)
}
}

// ContextProtoVars uses the fields of the input proto.Messages as top-level variables within an Activation.
//
// Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using
// protocol buffers.
func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) {
if ctx == nil || !ctx.ProtoReflect().IsValid() {
return interpreter.EmptyActivation(), nil
}
reg, err := types.NewRegistry(ctx)
if err != nil {
return nil, err
}
pbRef := ctx.ProtoReflect()
typeName := string(pbRef.Descriptor().FullName())
fields := pbRef.Descriptor().Fields()
vars := make(map[string]any, fields.Len())
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
sft, found := reg.FindStructFieldType(typeName, field.TextName())
if !found {
return nil, fmt.Errorf("no such field: %s", field.TextName())
}
fieldVal, err := sft.GetFrom(ctx)
if err != nil {
return nil, err
}
return Types(dynamicpb.NewMessage(descriptor))(e)
vars[field.TextName()] = fieldVal
}
return interpreter.NewActivation(vars)
}

// EnableMacroCallTracking ensures that call expressions which are replaced by macros
Expand Down
8 changes: 5 additions & 3 deletions common/types/pb/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,13 @@ func unwrapDynamic(desc description, refMsg protoreflect.Message) (any, bool, er
unwrappedAny := &anypb.Any{}
err := Merge(unwrappedAny, msg)
if err != nil {
return nil, false, err
return nil, false, fmt.Errorf("unwrap dynamic field failed: %v", err)
}
dynMsg, err := unwrappedAny.UnmarshalNew()
if err != nil {
// Allow the error to move further up the stack as it should result in an type
// conversion error if the caller does not recover it somehow.
return nil, false, err
return nil, false, fmt.Errorf("unmarshal dynamic any failed: %v", err)
}
// Attempt to unwrap the dynamic type, otherwise return the dynamic message.
unwrapped, nested, err := unwrapDynamic(desc, dynMsg.ProtoReflect())
Expand Down Expand Up @@ -564,8 +564,10 @@ func zeroValueOf(msg proto.Message) proto.Message {
}

var (
jsonValueTypeURL = "types.googleapis.com/google.protobuf.Value"
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved

zeroValueMap = map[string]proto.Message{
"google.protobuf.Any": &anypb.Any{},
"google.protobuf.Any": &anypb.Any{TypeUrl: jsonValueTypeURL},
"google.protobuf.Duration": &dpb.Duration{},
"google.protobuf.ListValue": &structpb.ListValue{},
"google.protobuf.Struct": &structpb.Struct{},
Expand Down
5 changes: 3 additions & 2 deletions common/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func singularFieldDescToCELType(field *pb.FieldDescription) *Type {
if field.IsEnum() {
return IntType
}
return protoCELPrimitives[field.ProtoKind()]
return ProtoCELPrimitives[field.ProtoKind()]
}

// defaultTypeAdapter converts go native types to CEL values.
Expand Down Expand Up @@ -657,7 +657,8 @@ func fieldTypeConversionError(field *pb.FieldDescription, err error) error {
}

var (
protoCELPrimitives = map[protoreflect.Kind]*Type{
// ProtoCELPrimitives provides a map from the protoreflect Kind to the equivalent CEL type.
ProtoCELPrimitives = map[protoreflect.Kind]*Type{
protoreflect.BoolKind: BoolType,
protoreflect.BytesKind: BytesType,
protoreflect.DoubleKind: DoubleType,
Expand Down