diff --git a/topdown/copypropagation/copypropagation.go b/topdown/copypropagation/copypropagation.go index 82490ca23f..820184b660 100644 --- a/topdown/copypropagation/copypropagation.go +++ b/topdown/copypropagation/copypropagation.go @@ -82,10 +82,6 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) { for _, expr := range query { - // Deep copy the expr as it may be mutated below. The caller that is running - // copy propagation may hold references to the expr. - expr = expr.Copy() - pctx := &plugContext{ bindings: bindings, uf: uf, @@ -93,7 +89,7 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) { headvars: headvars, } - if p.plugBindings(pctx, expr) { + if expr, keep := p.plugBindings(pctx, expr); keep { if p.updateBindings(pctx, expr) { result.Append(expr) } @@ -151,62 +147,94 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) { // plugBindings applies the binding list and union-find to x. This process // removes as many variables as possible. -func (p *CopyPropagator) plugBindings(pctx *plugContext, x interface{}) bool { +func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) (*ast.Expr, bool) { // Kill single term expressions that are in the binding list. They will be // re-added during post-processing if needed. - if expr, ok := x.(*ast.Expr); ok { - if term, ok := expr.Terms.(*ast.Term); ok { - if v, ok := term.Value.(ast.Var); ok { - if root, ok := pctx.uf.Find(v); ok { - if _, ok := pctx.bindings[root.key]; ok { - return false - } + if term, ok := expr.Terms.(*ast.Term); ok { + if v, ok := term.Value.(ast.Var); ok { + if root, ok := pctx.uf.Find(v); ok { + if _, ok := pctx.bindings[root.key]; ok { + return nil, false } } } } - ast.WalkTerms(x, func(t *ast.Term) bool { - // Apply union-find to remove redundant variables from input. - switch v := t.Value.(type) { - case ast.Var: - if root, ok := pctx.uf.Find(v); ok { - t.Value = root.Value() - } - case ast.Ref: - if root, ok := pctx.uf.Find(v[0].Value.(ast.Var)); ok { - v[0].Value = root.Value() + xform := bindingPlugTransform{ + pctx: pctx, + } + + // Deep copy the expression as it may be mutated during the transform and + // the caller running copy propagation may have references to the + // expression. Note, the transform does not contain any error paths and + // should never return a non-expression value for the root so consider + // errors unreachable. + x, err := ast.Transform(xform, expr.Copy()) + + if expr, ok := x.(*ast.Expr); !ok || err != nil { + panic("unreachable") + } else { + return expr, true + } +} + +type bindingPlugTransform struct { + pctx *plugContext +} + +func (t bindingPlugTransform) Transform(x interface{}) (interface{}, error) { + switch x := x.(type) { + case ast.Var: + return t.plugBindingsVar(t.pctx, x), nil + case ast.Ref: + return t.plugBindingsRef(t.pctx, x), nil + default: + return x, nil + } +} + +func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) (result ast.Value) { + + result = v + + // Apply union-find to remove redundant variables from input. + if root, ok := pctx.uf.Find(v); ok { + result = root.Value() + } + + // Apply binding list to substitute remaining vars. + if v, ok := result.(ast.Var); ok { + if b, ok := pctx.bindings[v]; ok { + if !pctx.negated || b.v.IsGround() { + result = b.v } } - // Apply binding list to substitute remaining vars. - switch v := t.Value.(type) { - case ast.Var: - if b, ok := pctx.bindings[v]; ok { - if !pctx.negated || b.v.IsGround() { - t.Value = b.v - } - return true - } - case ast.Ref: - // Refs require special handling. If the head of the ref was killed, then the - // rest of the ref must be concatenated with the new base. - // - // Invariant: ref heads can only be replaced by refs (not calls). - if b, ok := pctx.bindings[v[0].Value.(ast.Var)]; ok { - if !pctx.negated || b.v.IsGround() { - t.Value = b.v.(ast.Ref).Concat(v[1:]) - } - } - for i := 1; i < len(v); i++ { - p.plugBindings(pctx, v[i]) - } - return true + } + + return result +} + +func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref { + + // Apply union-find to remove redundant variables from input. + if root, ok := pctx.uf.Find(v[0].Value.(ast.Var)); ok { + v[0].Value = root.Value() + } + + result := v + + // Refs require special handling. If the head of the ref was killed, then + // the rest of the ref must be concatenated with the new base. + // + // Invariant: ref heads can only be replaced by refs (not calls). + if b, ok := pctx.bindings[v[0].Value.(ast.Var)]; ok { + if !pctx.negated || b.v.IsGround() { + result = b.v.(ast.Ref).Concat(v[1:]) } - return false - }) + } - return true + return result } // updateBindings returns false if the expression can be killed. If the diff --git a/topdown/topdown_partial_test.go b/topdown/topdown_partial_test.go index 0ce7c982a1..f6fb0210dd 100644 --- a/topdown/topdown_partial_test.go +++ b/topdown/topdown_partial_test.go @@ -1046,6 +1046,22 @@ func TestTopDownPartialEval(t *testing.T) { "not input.y = x1; x1 = input.x[i1]", }, }, + { + note: "copy propagation: rewrite object key (bug 1177)", + query: `data.test.p = true`, + modules: []string{ + ` + package test + + p { + x = input.x + y = input.y + x = {y: 1} + } + `, + }, + wantQueries: []string{`input.x = {input.y: 1}`}, + }, { note: "save set vars are namespaced", query: "input = x; data.test.f(1)",