diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index de86e2c2..0905f635 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -78,5 +78,6 @@ go_test( "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//encoding/prototext:go_default_library", "@org_golang_google_protobuf//types/known/structpb:go_default_library", + "@org_golang_google_protobuf//types/known/wrapperspb:go_default_library", ], ) diff --git a/cel/cel_test.go b/cel/cel_test.go index 0aeebc6c..27c63325 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -42,6 +42,9 @@ import ( exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" descpb "google.golang.org/protobuf/types/descriptorpb" dynamicpb "google.golang.org/protobuf/types/dynamicpb" + durationpb "google.golang.org/protobuf/types/known/durationpb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + wrapperspb "google.golang.org/protobuf/types/known/wrapperspb" proto2pb "github.com/google/cel-go/test/proto2pb" proto3pb "github.com/google/cel-go/test/proto3pb" @@ -1622,17 +1625,61 @@ func TestResidualAstModified(t *testing.T) { } } -func TestDeclareContextProto(t *testing.T) { +func TestContextProto(t *testing.T) { descriptor := new(proto3pb.TestAllTypes).ProtoReflect().Descriptor() option := DeclareContextProto(descriptor) env := testEnv(t, option) - expression := `single_int64 == 1 && single_double == 1.0 && single_bool == true && single_string == '' && single_nested_message == google.expr.proto3.test.TestAllTypes.NestedMessage{} - && single_nested_enum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO && single_duration == duration('5s') && single_timestamp == timestamp('1972-01-01T10:00:20.021-05:00') - && single_any == google.protobuf.Any{} && repeated_int32 == [1,2] && map_string_string == {'': ''} && map_int64_nested_type == {0 : google.expr.proto3.test.NestedTestAllTypes{}}` - _, iss := env.Compile(expression) + expression := ` + single_int64 == 1 + && single_double == 1.0 + && single_bool == true + && single_string == '' + && single_nested_message == google.expr.proto3.test.TestAllTypes.NestedMessage{} + && standalone_enum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO + && single_duration == duration('5s') + && single_timestamp == timestamp(63154820) + && single_any == null + && single_uint32_wrapper == null + && single_uint64_wrapper == 0u + && repeated_int32 == [1,2] + && map_string_string == {'': ''} + && map_int64_nested_type == {0 : google.expr.proto3.test.NestedTestAllTypes{}}` + ast, iss := env.Compile(expression) if iss.Err() != nil { t.Fatalf("env.Compile(%s) failed: %s", expression, iss.Err()) } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + in := &proto3pb.TestAllTypes{ + SingleInt64: 1, + SingleDouble: 1.0, + SingleBool: true, + NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{ + SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{}, + }, + StandaloneEnum: proto3pb.TestAllTypes_FOO, + SingleDuration: &durationpb.Duration{Seconds: 5}, + SingleTimestamp: ×tamppb.Timestamp{ + Seconds: 63154820, + }, + SingleUint64Wrapper: wrapperspb.UInt64(0), + RepeatedInt32: []int32{1, 2}, + MapStringString: map[string]string{"": ""}, + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{0: {}}, + } + vars, err := ContextProtoVars(in) + if err != nil { + t.Fatalf("ContextProtoVars(%v) failed: %v", in, err) + } + out, _, err := prg.Eval(vars) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out.Equal(types.True) != types.True { + t.Errorf("prg.Eval() got %v, wanted true", out) + } } func TestRegexOptimizer(t *testing.T) { diff --git a/cel/options.go b/cel/options.go index d3890e84..d47f55d8 100644 --- a/cel/options.go +++ b/cel/options.go @@ -23,7 +23,6 @@ import ( "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/dynamicpb" - "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/types" @@ -491,25 +490,21 @@ func CostLimit(costLimit uint64) ProgramOption { } } -func fieldToCELType(field protoreflect.FieldDescriptor) (*exprpb.Type, error) { +func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) { if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind { msgName := (string)(field.Message().FullName()) - wellKnownType, found := pb.CheckedWellKnowns[msgName] - if found { - return wellKnownType, nil - } - return decls.NewObjectType(msgName), nil + return ObjectType(msgName), nil } - if primitiveType, found := pb.CheckedPrimitives[field.Kind()]; found { + if primitiveType, found := types.ProtoCELPrimitives[field.Kind()]; found { return primitiveType, nil } if field.Kind() == protoreflect.EnumKind { - return decls.Int, nil + return IntType, nil } return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String()) } -func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) { +func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) { name := string(field.Name()) if field.IsMap() { mapKey := field.MapKey() @@ -522,20 +517,20 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) { if err != nil { return nil, err } - return decls.NewVar(name, decls.NewMapType(keyType, valueType)), nil + return Variable(name, MapType(keyType, valueType)), nil } if field.IsList() { elemType, err := fieldToCELType(field) if err != nil { return nil, err } - return decls.NewVar(name, decls.NewListType(elemType)), nil + return Variable(name, ListType(elemType)), nil } celType, err := fieldToCELType(field) if err != nil { return nil, err } - return decls.NewVar(name, celType), nil + return Variable(name, celType), nil } // DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto. @@ -543,23 +538,51 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) { // https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption { return func(e *Env) (*Env, error) { - var decls []*exprpb.Decl fields := descriptor.Fields() for i := 0; i < fields.Len(); i++ { field := fields.Get(i) - decl, err := fieldToDecl(field) + variable, err := fieldToVariable(field) + if err != nil { + return nil, err + } + e, err = variable(e) if err != nil { return nil, err } - decls = append(decls, decl) } - var err error - e, err = Declarations(decls...)(e) + return Types(dynamicpb.NewMessage(descriptor))(e) + } +} + +// ContextProtoVars uses the fields of the input proto.Messages as top-level variables within an Activation. +// +// Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using +// protocol buffers. +func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) { + if ctx == nil || !ctx.ProtoReflect().IsValid() { + return interpreter.EmptyActivation(), nil + } + reg, err := types.NewRegistry(ctx) + if err != nil { + return nil, err + } + pbRef := ctx.ProtoReflect() + typeName := string(pbRef.Descriptor().FullName()) + fields := pbRef.Descriptor().Fields() + vars := make(map[string]any, fields.Len()) + for i := 0; i < fields.Len(); i++ { + field := fields.Get(i) + sft, found := reg.FindStructFieldType(typeName, field.TextName()) + if !found { + return nil, fmt.Errorf("no such field: %s", field.TextName()) + } + fieldVal, err := sft.GetFrom(ctx) if err != nil { return nil, err } - return Types(dynamicpb.NewMessage(descriptor))(e) + vars[field.TextName()] = fieldVal } + return interpreter.NewActivation(vars) } // EnableMacroCallTracking ensures that call expressions which are replaced by macros diff --git a/common/types/pb/type.go b/common/types/pb/type.go index 5f69b339..cf7405ee 100644 --- a/common/types/pb/type.go +++ b/common/types/pb/type.go @@ -467,13 +467,13 @@ func unwrapDynamic(desc description, refMsg protoreflect.Message) (any, bool, er unwrappedAny := &anypb.Any{} err := Merge(unwrappedAny, msg) if err != nil { - return nil, false, err + return nil, false, fmt.Errorf("unwrap dynamic field failed: %v", err) } dynMsg, err := unwrappedAny.UnmarshalNew() if err != nil { // Allow the error to move further up the stack as it should result in an type // conversion error if the caller does not recover it somehow. - return nil, false, err + return nil, false, fmt.Errorf("unmarshal dynamic any failed: %v", err) } // Attempt to unwrap the dynamic type, otherwise return the dynamic message. unwrapped, nested, err := unwrapDynamic(desc, dynMsg.ProtoReflect()) @@ -564,8 +564,10 @@ func zeroValueOf(msg proto.Message) proto.Message { } var ( + jsonValueTypeURL = "types.googleapis.com/google.protobuf.Value" + zeroValueMap = map[string]proto.Message{ - "google.protobuf.Any": &anypb.Any{}, + "google.protobuf.Any": &anypb.Any{TypeUrl: jsonValueTypeURL}, "google.protobuf.Duration": &dpb.Duration{}, "google.protobuf.ListValue": &structpb.ListValue{}, "google.protobuf.Struct": &structpb.Struct{}, diff --git a/common/types/provider.go b/common/types/provider.go index a8cb5983..52e34817 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -346,7 +346,7 @@ func singularFieldDescToCELType(field *pb.FieldDescription) *Type { if field.IsEnum() { return IntType } - return protoCELPrimitives[field.ProtoKind()] + return ProtoCELPrimitives[field.ProtoKind()] } // defaultTypeAdapter converts go native types to CEL values. @@ -657,7 +657,8 @@ func fieldTypeConversionError(field *pb.FieldDescription, err error) error { } var ( - protoCELPrimitives = map[protoreflect.Kind]*Type{ + // ProtoCELPrimitives provides a map from the protoreflect Kind to the equivalent CEL type. + ProtoCELPrimitives = map[protoreflect.Kind]*Type{ protoreflect.BoolKind: BoolType, protoreflect.BytesKind: BytesType, protoreflect.DoubleKind: DoubleType,