Skip to content

Commit

Permalink
format: Deep copy inputs to avoid mutating the caller's copy
Browse files Browse the repository at this point in the history
As part of this change, also update the format package to unmangle the
variables slightly differently--just remove the wildcard prefix
instead of translating the variable names. This makes it easier to
tell where the variables came from in the first place and is a bit
less complicated.

Fixes open-policy-agent#2439

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall authored and patrick-east committed May 29, 2020
1 parent 3042b50 commit fe18f11
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 57 deletions.
55 changes: 55 additions & 0 deletions ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,14 @@ func (c *Comment) String() string {
return "#" + string(c.Text)
}

// Copy returns a deep copy of c.
func (c *Comment) Copy() *Comment {
cpy := *c
cpy.Text = make([]byte, len(c.Text))
copy(cpy.Text, c.Text)
return &cpy
}

// Equal returns true if this comment equals the other comment.
// Unlike other equality checks on AST nodes, comment equality
// depends on location.
Expand Down Expand Up @@ -1251,6 +1259,53 @@ func (w *With) SetLoc(loc *Location) {
w.Location = loc
}

// Copy returns a deep copy of the AST node x. If x is not an AST node, x is returned unmodified.
func Copy(x interface{}) interface{} {
switch x := x.(type) {
case *Module:
return x.Copy()
case *Package:
return x.Copy()
case *Import:
return x.Copy()
case *Rule:
return x.Copy()
case *Head:
return x.Copy()
case Args:
return x.Copy()
case Body:
return x.Copy()
case *Expr:
return x.Copy()
case *With:
return x.Copy()
case *SomeDecl:
return x.Copy()
case *Term:
return x.Copy()
case *ArrayComprehension:
return x.Copy()
case *SetComprehension:
return x.Copy()
case *ObjectComprehension:
return x.Copy()
case Set:
return x.Copy()
case Object:
return x.Copy()
case Array:
return x.Copy()
case Ref:
return x.Copy()
case Call:
return x.Copy()
case *Comment:
return x.Copy()
}
return x
}

// RuleSet represents a collection of rules that produce a virtual document.
type RuleSet []*Rule

Expand Down
19 changes: 19 additions & 0 deletions ast/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"reflect"
"testing"

"github.com/open-policy-agent/opa/ast/location"
"github.com/open-policy-agent/opa/util"
)

Expand Down Expand Up @@ -514,6 +515,24 @@ func TestSomeDeclString(t *testing.T) {
}
}

func TestCommentCopy(t *testing.T) {
comment := &Comment{
Text: []byte("foo bar baz"),
Location: &location.Location{}, // location must be set for comment equality
}

cpy := comment.Copy()
if !cpy.Equal(comment) {
t.Fatal("expected copy to be equal")
}

comment.Text[1] = '0'

if cpy.Equal(comment) {
t.Fatal("expected copy to be unmodified")
}
}

func assertExprEqual(t *testing.T, a, b *Expr) {
t.Helper()
if !a.Equal(b) {
Expand Down
4 changes: 2 additions & 2 deletions compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ func (c *Compiler) compileWasm(ctx context.Context) error {
}

store := inmem.NewFromObject(c.bundle.Data)
resultSym := ast.VarTerm(ast.WildcardPrefix + "result")
resultSym := ast.VarTerm(ast.WildcardPrefix + "__result__")

cr, err := rego.New(
rego.ParsedQuery(ast.NewBody(ast.Equality.Expr(resultSym, c.entrypointrefs[0]))),
Expand Down Expand Up @@ -377,7 +377,7 @@ func (o *optimizer) Do(ctx context.Context) error {
}

store := inmem.NewFromObject(data)
resultsym := ast.VarTerm(o.resultsymprefix + "result")
resultsym := ast.VarTerm(o.resultsymprefix + "__result__")
usedFilenames := map[string]int{}

// NOTE(tsandall): the entrypoints are optimized in order so that the optimization
Expand Down
20 changes: 10 additions & 10 deletions compile/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,9 +523,9 @@ func TestOptimizerOutput(t *testing.T) {
"optimized/test.rego": `
package test
p = result { 1 = input.x; result = true }
p = result { 2 = input.x; result = true }
p = result { 3 = input.x; result = true }
p = __result__ { 1 = input.x; __result__ = true }
p = __result__ { 2 = input.x; __result__ = true }
p = __result__ { 3 = input.x; __result__ = true }
`,
"test.rego": `
package test
Expand Down Expand Up @@ -589,12 +589,12 @@ func TestOptimizerOutput(t *testing.T) {
"optimized/test.rego": `
package test
p = result { 1 = input.x; result = true }
p = __result__ { 1 = input.x; __result__ = true }
`,
"optimized/test.1.rego": `
package test
r = result { 1 = input.x; result = true }
r = __result__ { 1 = input.x; __result__ = true }
`,
"test.rego": `
package test
Expand All @@ -617,7 +617,7 @@ func TestOptimizerOutput(t *testing.T) {
"optimized/test.rego": `
package test
foo = result { result = {"bar": {"p": true}} }`,
foo = __result__ { __result__ = {"bar": {"p": true}} }`,
},
},
{
Expand Down Expand Up @@ -647,7 +647,7 @@ func TestOptimizerOutput(t *testing.T) {
"optimized/test.1.rego": `
package test
p = result { data.test.q[input.x]; result = true }
p = __result__ { data.test.q[input.x]; __result__ = true }
`,
"optimized/test.rego": `
package test
Expand Down Expand Up @@ -679,8 +679,8 @@ func TestOptimizerOutput(t *testing.T) {
wantModules: map[string]string{
"optimized/partial/0/0.rego": `
package test["foo bar"]
p = result { 1 = input.x; result = true }
p = result { 2 = input.x; result = true }
p = __result__ { 1 = input.x; __result__ = true }
p = __result__ { 2 = input.x; __result__ = true }
`,
"x.rego": `
package test["foo bar"]
Expand Down Expand Up @@ -719,7 +719,7 @@ func TestOptimizerOutput(t *testing.T) {
"optimized/test.rego": `
package test
p = result { not data.partial.__not1_0__; result = true }
p = __result__ { not data.partial.__not1_0__; __result__ = true }
`,
"test.rego": `
package test
Expand Down
71 changes: 38 additions & 33 deletions format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ func MustAst(x interface{}) []byte {
}

// Ast formats a Rego AST element. If the passed value is not a valid AST
// element, Ast returns nil and an error. Ast relies on all AST elements having
// non-nil Location values. If an AST element with a nil Location value is
// encountered, a default location will be set on the AST node.
// element, Ast returns nil and an error. If AST nodes are missing locations
// an arbitrary location will be used.
func Ast(x interface{}) (formatted []byte, err error) {

wildcards := map[string]*ast.Term{}
wildcardNames := map[string]string{}
wildcardCounter := 0
// The node has to be deep copied because it may be mutated below. Alternatively,
// we could avoid the copy by checking if mtuation will occur first. For now,
// since format is not latency sensitive, just deep copy in all cases.
x = ast.Copy(x)

wildcards := map[ast.Var]*ast.Term{}

// Preprocess the AST. Set any required defaults and calculate
// values required for printing the formatted output.
Expand All @@ -58,40 +60,15 @@ func Ast(x interface{}) (formatted []byte, err error) {
return false
}
case *ast.Term:
if v, ok := n.Value.(ast.Var); ok {
if v.IsWildcard() {
str := string(v)
if _, seen := wildcards[str]; !seen {
// Keep a reference to the wildcard term so we can, if
// needed, rewrite it later
wildcards[str] = n
} else {
// On the second time we have seen the wildcard generate
// a name for it
newName, ok := wildcardNames[str]
if !ok {
newName = fmt.Sprintf("__wildcard%d__", wildcardCounter)
wildcardNames[str] = newName
wildcardCounter++

// Rewrite the first one that was "seen"
wildcards[str].Value = ast.Var(newName)
}

// Rewrite the current wildcard with its generated name
n.Value = ast.Var(newName)
}
}
}
unmangleWildcardVar(wildcards, n)
}

if x.Loc() == nil {
x.SetLoc(defaultLocation(x))
}
return false
})

w := &writer{indent: "\t", wildcardNames: wildcardNames}
w := &writer{indent: "\t"}
switch x := x.(type) {
case *ast.Module:
w.writeModule(x)
Expand Down Expand Up @@ -122,6 +99,34 @@ func Ast(x interface{}) (formatted []byte, err error) {
return squashTrailingNewlines(w.buf.Bytes()), nil
}

func unmangleWildcardVar(wildcards map[ast.Var]*ast.Term, n *ast.Term) {

v, ok := n.Value.(ast.Var)
if !ok || !v.IsWildcard() {
return
}

first, ok := wildcards[v]
if !ok {
wildcards[v] = n
return
}

w := v[len(ast.WildcardPrefix):]

// Prepend an underscore to ensure the variable will parse.
if len(w) == 0 || w[0] != '_' {
w = "_" + w
}

if first != nil {
first.Value = w
wildcards[v] = nil
}

n.Value = w
}

func squashTrailingNewlines(bs []byte) []byte {
if bytes.HasSuffix(bs, []byte("\n")) {
return append(bytes.TrimRight(bs, "\n"), '\n')
Expand Down
50 changes: 38 additions & 12 deletions format/format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,9 @@ func TestFormatAST(t *testing.T) {
},
},
},
expected: `input.arr[__wildcard0__]["some key"][_] = bar
input.arr[__wildcard0__].bar = qux
foo[__wildcard1__][__wildcard0__].bar = bar[__wildcard1__][_][__wildcard0__].bar
expected: `input.arr[_01]["some key"][_] = bar
input.arr[_01].bar = qux
foo[_03][_01].bar = bar[_03][_][_01].bar
`,
},
{
Expand All @@ -252,11 +252,11 @@ foo[__wildcard1__][__wildcard0__].bar = bar[__wildcard1__][_][__wildcard0__].bar
},
&ast.Expr{
Index: 1,
Terms: ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("x")),
Terms: ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("y")),
},
},
expected: `__wildcard0__
__wildcard0__[x]`,
expected: `_x
_x[y]`,
},
{
note: "body shared wildcard - nested ref",
Expand All @@ -267,11 +267,11 @@ __wildcard0__[x]`,
},
&ast.Expr{
Index: 1,
Terms: ast.RefTerm(ast.VarTerm("a"), ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("x"))),
Terms: ast.RefTerm(ast.VarTerm("a"), ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("y"))),
},
},
expected: `__wildcard0__
a[__wildcard0__[x]]`,
expected: `_x
a[_x[y]]`,
},
{
note: "body shared wildcard - nested ref array",
Expand All @@ -282,11 +282,11 @@ a[__wildcard0__[x]]`,
},
&ast.Expr{
Index: 1,
Terms: ast.RefTerm(ast.VarTerm("a"), ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("x"), ast.ArrayTerm(ast.VarTerm("y"), ast.VarTerm("z")))),
Terms: ast.RefTerm(ast.VarTerm("a"), ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("y"), ast.ArrayTerm(ast.VarTerm("z"), ast.VarTerm("w")))),
},
},
expected: `__wildcard0__
a[__wildcard0__[x][[y, z]]]`,
expected: `_x
a[_x[y][[z, w]]]`,
},
}

Expand All @@ -305,6 +305,32 @@ a[__wildcard0__[x][[y, z]]]`,
}
}

func TestFormatDeepCopy(t *testing.T) {

original := ast.Body{
&ast.Expr{
Index: 0,
Terms: ast.VarTerm("$x"),
},
&ast.Expr{
Index: 1,
Terms: ast.RefTerm(ast.VarTerm("$x"), ast.VarTerm("y")),
},
}

cpy := original.Copy()

_, err := Ast(original)
if err != nil {
t.Fatal(err)
}

if !cpy.Equal(original) {
t.Fatal("expected original to be unmodified")
}

}

func differsAt(a, b []byte) (int, int) {
if bytes.Equal(a, b) {
return 0, 0
Expand Down

0 comments on commit fe18f11

Please sign in to comment.