diff --git a/ast/policy.go b/ast/policy.go index d66f7f06ff..10cee9f50d 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -315,6 +315,14 @@ func (c *Comment) String() string { return "#" + string(c.Text) } +// Copy returns a deep copy of c. +func (c *Comment) Copy() *Comment { + cpy := *c + cpy.Text = make([]byte, len(c.Text)) + copy(cpy.Text, c.Text) + return &cpy +} + // Equal returns true if this comment equals the other comment. // Unlike other equality checks on AST nodes, comment equality // depends on location. @@ -1251,6 +1259,53 @@ func (w *With) SetLoc(loc *Location) { w.Location = loc } +// Copy returns a deep copy of the AST node x. If x is not an AST node, x is returned unmodified. +func Copy(x interface{}) interface{} { + switch x := x.(type) { + case *Module: + return x.Copy() + case *Package: + return x.Copy() + case *Import: + return x.Copy() + case *Rule: + return x.Copy() + case *Head: + return x.Copy() + case Args: + return x.Copy() + case Body: + return x.Copy() + case *Expr: + return x.Copy() + case *With: + return x.Copy() + case *SomeDecl: + return x.Copy() + case *Term: + return x.Copy() + case *ArrayComprehension: + return x.Copy() + case *SetComprehension: + return x.Copy() + case *ObjectComprehension: + return x.Copy() + case Set: + return x.Copy() + case Object: + return x.Copy() + case Array: + return x.Copy() + case Ref: + return x.Copy() + case Call: + return x.Copy() + case *Comment: + return x.Copy() + } + return x +} + // RuleSet represents a collection of rules that produce a virtual document. 
type RuleSet []*Rule diff --git a/ast/policy_test.go b/ast/policy_test.go index 400fe6cea7..8ca9e377c9 100644 --- a/ast/policy_test.go +++ b/ast/policy_test.go @@ -10,6 +10,7 @@ import ( "reflect" "testing" + "github.com/open-policy-agent/opa/ast/location" "github.com/open-policy-agent/opa/util" ) @@ -514,6 +515,24 @@ func TestSomeDeclString(t *testing.T) { } } +func TestCommentCopy(t *testing.T) { + comment := &Comment{ + Text: []byte("foo bar baz"), + Location: &location.Location{}, // location must be set for comment equality + } + + cpy := comment.Copy() + if !cpy.Equal(comment) { + t.Fatal("expected copy to be equal") + } + + comment.Text[1] = '0' + + if cpy.Equal(comment) { + t.Fatal("expected copy to be unmodified") + } +} + func assertExprEqual(t *testing.T, a, b *Expr) { t.Helper() if !a.Equal(b) { diff --git a/compile/compile.go b/compile/compile.go index 7142927f0e..bfe43c0893 100644 --- a/compile/compile.go +++ b/compile/compile.go @@ -300,7 +300,7 @@ func (c *Compiler) compileWasm(ctx context.Context) error { } store := inmem.NewFromObject(c.bundle.Data) - resultSym := ast.VarTerm(ast.WildcardPrefix + "result") + resultSym := ast.VarTerm(ast.WildcardPrefix + "__result__") cr, err := rego.New( rego.ParsedQuery(ast.NewBody(ast.Equality.Expr(resultSym, c.entrypointrefs[0]))), @@ -377,7 +377,7 @@ func (o *optimizer) Do(ctx context.Context) error { } store := inmem.NewFromObject(data) - resultsym := ast.VarTerm(o.resultsymprefix + "result") + resultsym := ast.VarTerm(o.resultsymprefix + "__result__") usedFilenames := map[string]int{} // NOTE(tsandall): the entrypoints are optimized in order so that the optimization diff --git a/compile/compile_test.go b/compile/compile_test.go index 019ad19f83..d363e9f5bd 100644 --- a/compile/compile_test.go +++ b/compile/compile_test.go @@ -523,9 +523,9 @@ func TestOptimizerOutput(t *testing.T) { "optimized/test.rego": ` package test - p = result { 1 = input.x; result = true } - p = result { 2 = input.x; result = true } - 
p = result { 3 = input.x; result = true } + p = __result__ { 1 = input.x; __result__ = true } + p = __result__ { 2 = input.x; __result__ = true } + p = __result__ { 3 = input.x; __result__ = true } `, "test.rego": ` package test @@ -589,12 +589,12 @@ func TestOptimizerOutput(t *testing.T) { "optimized/test.rego": ` package test - p = result { 1 = input.x; result = true } + p = __result__ { 1 = input.x; __result__ = true } `, "optimized/test.1.rego": ` package test - r = result { 1 = input.x; result = true } + r = __result__ { 1 = input.x; __result__ = true } `, "test.rego": ` package test @@ -617,7 +617,7 @@ func TestOptimizerOutput(t *testing.T) { "optimized/test.rego": ` package test - foo = result { result = {"bar": {"p": true}} }`, + foo = __result__ { __result__ = {"bar": {"p": true}} }`, }, }, { @@ -647,7 +647,7 @@ func TestOptimizerOutput(t *testing.T) { "optimized/test.1.rego": ` package test - p = result { data.test.q[input.x]; result = true } + p = __result__ { data.test.q[input.x]; __result__ = true } `, "optimized/test.rego": ` package test @@ -679,8 +679,8 @@ func TestOptimizerOutput(t *testing.T) { wantModules: map[string]string{ "optimized/partial/0/0.rego": ` package test["foo bar"] - p = result { 1 = input.x; result = true } - p = result { 2 = input.x; result = true } + p = __result__ { 1 = input.x; __result__ = true } + p = __result__ { 2 = input.x; __result__ = true } `, "x.rego": ` package test["foo bar"] @@ -719,7 +719,7 @@ func TestOptimizerOutput(t *testing.T) { "optimized/test.rego": ` package test - p = result { not data.partial.__not1_0__; result = true } + p = __result__ { not data.partial.__not1_0__; __result__ = true } `, "test.rego": ` package test diff --git a/format/format.go b/format/format.go index 95722146cf..35ed7603de 100644 --- a/format/format.go +++ b/format/format.go @@ -40,14 +40,16 @@ func MustAst(x interface{}) []byte { } // Ast formats a Rego AST element. 
If the passed value is not a valid AST -// element, Ast returns nil and an error. Ast relies on all AST elements having -// non-nil Location values. If an AST element with a nil Location value is -// encountered, a default location will be set on the AST node. +// element, Ast returns nil and an error. If AST nodes are missing locations +// an arbitrary location will be used. func Ast(x interface{}) (formatted []byte, err error) { - wildcards := map[string]*ast.Term{} - wildcardNames := map[string]string{} - wildcardCounter := 0 + // The node has to be deep copied because it may be mutated below. Alternatively, + // we could avoid the copy by checking if mutation will occur first. For now, + // since format is not latency sensitive, just deep copy in all cases. + x = ast.Copy(x) + + wildcards := map[ast.Var]*ast.Term{} // Preprocess the AST. Set any required defaults and calculate // values required for printing the formatted output. @@ -58,40 +60,15 @@ func Ast(x interface{}) (formatted []byte, err error) { return false } case *ast.Term: - if v, ok := n.Value.(ast.Var); ok { - if v.IsWildcard() { - str := string(v) - if _, seen := wildcards[str]; !seen { - // Keep a reference to the wildcard term so we can, if - // needed, rewrite it later - wildcards[str] = n - } else { - // On the second time we have seen the wildcard generate - // a name for it - newName, ok := wildcardNames[str] - if !ok { - newName = fmt.Sprintf("__wildcard%d__", wildcardCounter) - wildcardNames[str] = newName - wildcardCounter++ - - // Rewrite the first one that was "seen" - wildcards[str].Value = ast.Var(newName) - } - - // Rewrite the current wildcard with its generated name - n.Value = ast.Var(newName) - } - } - } + unmangleWildcardVar(wildcards, n) } - if x.Loc() == nil { x.SetLoc(defaultLocation(x)) } return false }) - w := &writer{indent: "\t", wildcardNames: wildcardNames} + w := &writer{indent: "\t"} switch x := x.(type) { case *ast.Module: w.writeModule(x) @@ -122,6 +99,34 @@ func 
Ast(x interface{}) (formatted []byte, err error) { return squashTrailingNewlines(w.buf.Bytes()), nil } +func unmangleWildcardVar(wildcards map[ast.Var]*ast.Term, n *ast.Term) { + + v, ok := n.Value.(ast.Var) + if !ok || !v.IsWildcard() { + return + } + + first, ok := wildcards[v] + if !ok { + wildcards[v] = n + return + } + + w := v[len(ast.WildcardPrefix):] + + // Prepend an underscore to ensure the variable will parse. + if len(w) == 0 || w[0] != '_' { + w = "_" + w + } + + if first != nil { + first.Value = w + wildcards[v] = nil + } + + n.Value = w +} + func squashTrailingNewlines(bs []byte) []byte { if bytes.HasSuffix(bs, []byte("\n")) { return append(bytes.TrimRight(bs, "\n"), '\n') diff --git a/format/format_test.go b/format/format_test.go index a8f5dc9232..1047c0d568 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -238,9 +238,9 @@ func TestFormatAST(t *testing.T) { }, }, }, - expected: `input.arr[__wildcard0__]["some key"][_] = bar -input.arr[__wildcard0__].bar = qux -foo[__wildcard1__][__wildcard0__].bar = bar[__wildcard1__][_][__wildcard0__].bar + expected: `input.arr[_01]["some key"][_] = bar +input.arr[_01].bar = qux +foo[_03][_01].bar = bar[_03][_][_01].bar `, }, { @@ -252,11 +252,11 @@ foo[__wildcard1__][__wildcard0__].bar = bar[__wildcard1__][_][__wildcard0__].bar }, &ast.Expr{ Index: 1, - Terms: ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("x")), + Terms: ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("y")), }, }, - expected: `__wildcard0__ -__wildcard0__[x]`, + expected: `_x +_x[y]`, }, { note: "body shared wildcard - nested ref", @@ -267,11 +267,11 @@ __wildcard0__[x]`, }, &ast.Expr{ Index: 1, - Terms: ast.RefTerm(ast.VarTerm("a"), ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("x"))), + Terms: ast.RefTerm(ast.VarTerm("a"), ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("y"))), }, }, - expected: `__wildcard0__ -a[__wildcard0__[x]]`, + expected: `_x +a[_x[y]]`, }, { note: "body shared wildcard - nested ref array", @@ -282,11 +282,11 @@ 
a[__wildcard0__[x]]`, }, &ast.Expr{ Index: 1, - Terms: ast.RefTerm(ast.VarTerm("a"), ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("x"), ast.ArrayTerm(ast.VarTerm("y"), ast.VarTerm("z")))), + Terms: ast.RefTerm(ast.VarTerm("a"), ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("y"), ast.ArrayTerm(ast.VarTerm("z"), ast.VarTerm("w")))), }, }, - expected: `__wildcard0__ -a[__wildcard0__[x][[y, z]]]`, + expected: `_x +a[_x[y][[z, w]]]`, }, } @@ -305,6 +305,32 @@ a[__wildcard0__[x][[y, z]]]`, } } +func TestFormatDeepCopy(t *testing.T) { + + original := ast.Body{ + &ast.Expr{ + Index: 0, + Terms: ast.VarTerm("$x"), + }, + &ast.Expr{ + Index: 1, + Terms: ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("y")), + }, + } + + cpy := original.Copy() + + _, err := Ast(original) + if err != nil { + t.Fatal(err) + } + + if !cpy.Equal(original) { + t.Fatal("expected original to be unmodified") + } + +} + func differsAt(a, b []byte) (int, int) { if bytes.Equal(a, b) { return 0, 0