diff --git a/cel/cel_test.go b/cel/cel_test.go index 042cc93e..5bcffa25 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -15,6 +15,7 @@ package cel import ( + "bytes" "context" "fmt" "io/ioutil" @@ -446,9 +447,37 @@ func TestCrossTypeNumericComparisons(t *testing.T) { } } +func TestExtendStdlibFunction(t *testing.T) { + e, err := NewEnv( + Function(overloads.Contains, + MemberOverload("bytes_contains_bytes", []*Type{BytesType, BytesType}, BoolType, + BinaryBinding(func(bstr, bsub ref.Val) ref.Val { + return types.Bool(bytes.Contains([]byte(bstr.(types.Bytes)), []byte(bsub.(types.Bytes)))) + }))), + ) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + ast, iss := e.Compile(`b'string'.contains(b'tri') && 'string'.contains('tri')`) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + prg, err := e.Program(ast) + if err != nil { + t.Fatalf("Progarm(ast) failed: %v", err) + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + t.Fatalf("contains check errored: %v", err) + } + if out != types.True { + t.Errorf("contains check got %v, wanted true", out) + } +} + func TestCustomTypes(t *testing.T) { reg := types.NewEmptyRegistry() - e, _ := NewEnv( + e, err := NewEnv( CustomTypeAdapter(reg), CustomTypeProvider(reg), Container("google.api.expr.v1alpha1"), @@ -459,6 +488,9 @@ func TestCustomTypes(t *testing.T) { types.StringType), Variable("expr", ObjectType("google.api.expr.v1alpha1.Expr")), ) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } ast, iss := e.Compile(` expr == Expr{id: 2, diff --git a/cel/library.go b/cel/library.go index 10037988..8de7b40a 100644 --- a/cel/library.go +++ b/cel/library.go @@ -109,29 +109,28 @@ func (stdLibrary) LibraryName() string { // CompileOptions returns options for the standard CEL function declarations and macros. func (stdLibrary) CompileOptions() []EnvOption { return []EnvOption{ + func(e *Env) (*Env, error) { + var err error + for _, fn := range stdlib.Functions() { + existing, found := e.functions[fn.Name] + if found { + fn, err = existing.Merge(fn) + if err != nil { + return nil, err + } + } + e.functions[fn.Name] = fn + } + return e, nil + }, Declarations(stdlib.TypeExprDecls()...), - Declarations(stdlib.FunctionExprDecls()...), Macros(StandardMacros...), } } // ProgramOptions returns function implementations for the standard CEL functions. func (stdLibrary) ProgramOptions() []ProgramOption { - return []ProgramOption{ - func(p *prog) (*prog, error) { - for _, fn := range stdlib.Functions() { - bindings, err := fn.Bindings() - if err != nil { - return nil, err - } - err = p.dispatcher.Add(bindings...) - if err != nil { - return nil, err - } - } - return p, nil - }, - } + return []ProgramOption{} } // OptionalTypes enable support for optional syntax and types in CEL.