From aee2bd9246ec5f3e6da9d8260defcf1f1f733578 Mon Sep 17 00:00:00 2001 From: Patrick East Date: Tue, 10 Mar 2020 19:28:23 -0700 Subject: [PATCH] ast: Don't allow calls in function signatures Rather than causing a panic calls in function declaration args will now just raise a parse error. In the future we could potentially support them but we would need to sort out some additional details/ambiguity around behavior. Its unclear that any users need this behavior so for now we'll just correct the panic. Fixes: #2081 Signed-off-by: Patrick East --- ast/compile.go | 5 +- ast/compile_test.go | 131 ++++++++++++++++++++++++++++++++------------ ast/env.go | 4 ++ 3 files changed, 103 insertions(+), 37 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index b043c708fe..a6b0298926 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -1207,8 +1207,11 @@ func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor { // Scalars are no-ops. Comprehensions are handled above. Sets must not // contain variables. return nil + case Call: + vis.errs = append(vis.errs, NewError(CompileErr, t.Location, "rule arguments cannot contain calls")) + return nil default: - // Recurse on refs, arrays, and calls. Any embedded + // Recurse on refs and arrays. Any embedded // variables can be rewritten. return vis } diff --git a/ast/compile_test.go b/ast/compile_test.go index 7fc1f06811..879bf3500e 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -930,47 +930,106 @@ func TestCompilerExprExpansion(t *testing.T) { } func TestCompilerRewriteExprTerms(t *testing.T) { - module := ` - package test - - p { x = a + b * y } - - q[[data.test.f(x)]] { x = 1 } - - r = [data.test.f(x)] { x = 1 } - - f(x) = data.test.g(x) - - pi = 3 + .14 - - with_value { 1 with input as f(1) } - ` - - compiler := NewCompiler() - compiler.Modules = map[string]*Module{ - "test": MustParseModule(module), + cases := []struct { + note string + module string + expected interface{} + }{ + { + note: "base", + module: ` + package test + + p { x = a + b * y } + + q[[data.test.f(x)]] { x = 1 } + + r = [data.test.f(x)] { x = 1 } + + f(x) = data.test.g(x) + + pi = 3 + .14 + + with_value { 1 with input as f(1) } + `, + expected: ` + package test + + p { mul(b, y, __local1__); plus(a, __local1__, __local2__); eq(x, __local2__) } + + q[[__local3__]] { x = 1; data.test.f(x, __local3__) } + + r = [__local4__] { x = 1; data.test.f(x, __local4__) } + + f(__local0__) = __local5__ { true; data.test.g(__local0__, __local5__) } + + pi = __local6__ { true; plus(3, 0.14, __local6__) } + + with_value { data.test.f(1, __local7__); 1 with input as __local7__ } + `, + }, + { + note: "builtin calls in head", + module: ` + package test + + f(1+1) = 7 + `, + expected: Errors{&Error{Message: "rule arguments cannot contain calls"}}, + }, + { + note: "builtin calls in head", + module: ` + package test + + f(object.get(x)) { object := {"a": 1}; object.a == x } + `, + expected: Errors{&Error{Message: "rule arguments cannot contain calls"}}, + }, } - compileStages(compiler, compiler.rewriteExprTerms) - assertNotFailed(t, compiler) - expected := MustParseModule(` - package test - - p { mul(b, y, __local1__); plus(a, __local1__, __local2__); eq(x, __local2__) } - - q[[__local3__]] { x = 1; data.test.f(x, __local3__) } - - r = [__local4__] { x = 1; data.test.f(x, __local4__) } + for _, tc := range cases { + t.Run(tc.note, func(t *testing.T) { + compiler := NewCompiler() + compiler.Modules = map[string]*Module{ + "test": MustParseModule(tc.module), + } + compileStages(compiler, compiler.rewriteExprTerms) - f(__local0__) = __local5__ { true; data.test.g(__local0__, __local5__) } + switch exp := tc.expected.(type) { + case string: + assertNotFailed(t, compiler) - pi = __local6__ { true; plus(3, 0.14, __local6__) } + expected := MustParseModule(exp) - with_value { data.test.f(1, __local7__); 1 with input as __local7__ } - `) + if !expected.Equal(compiler.Modules["test"]) { + t.Fatalf("Expected modules to be equal. Expected:\n\n%v\n\nGot:\n\n%v", expected, compiler.Modules["test"]) + } + case Errors: + if len(exp) != len(compiler.Errors) { + t.Fatalf("Expected %d errors, got %d:\n\n%s\n", len(exp), len(compiler.Errors), compiler.Errors.Error()) + } + incorrectErrs := false + for _, e := range exp { + found := false + for _, actual := range compiler.Errors { + if e.Message == actual.Message { + found = true + break + } + } + if !found { + incorrectErrs = true + } + } + if incorrectErrs { + t.Fatalf("Expected errors:\n\n%s\n\nGot:\n\n%s\n", exp.Error(), compiler.Errors.Error()) + } + default: + t.Fatalf("Unsupported value type for test case 'expected' field: %v", exp) + } - if !expected.Equal(compiler.Modules["test"]) { - t.Fatalf("Expected modules to be equal. Expected:\n\n%v\n\nGot:\n\n%v", expected, compiler.Modules["test"]) + }) } } @@ -3336,7 +3395,7 @@ func assertCompilerErrorStrings(t *testing.T, compiler *Compiler, expected []str func assertNotFailed(t *testing.T, c *Compiler) { if c.Failed() { - t.Errorf("Unexpected compilation error: %v", c.Errors) + t.Fatalf("Unexpected compilation error: %v", c.Errors) } } diff --git a/ast/env.go b/ast/env.go index 0519c30be5..ca32fb0ae8 100644 --- a/ast/env.go +++ b/ast/env.go @@ -129,6 +129,10 @@ func (env *TypeEnv) Get(x interface{}) types.Type { } return nil + // Calls. + case Call: + return nil + default: panic("unreachable") }