Skip to content

Commit

Permalink
Refactor save set updating to use ast.Walk
Browse files Browse the repository at this point in the history
In 6f9ed62 we fixed an issue in partial eval that was causing vars
embedded inside composites to not be added to the save set when they
should have been.

This change simply refactors that fix to use ast.Walk and rely on the
same logic in both saveUnify and saveCall.

Note, the plug call on the var terms has been removed as it's not needed
since we are already plugging terms (or applying) when checking if terms
are contained in the save set.

Signed-off-by: Torin Sandall <torinsandall@gmail.com>
  • Loading branch information
tsandall committed Jul 17, 2018
1 parent 6f9ed62 commit 28582df
Showing 1 changed file with 18 additions and 42 deletions.
60 changes: 18 additions & 42 deletions topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,18 @@ func (e *eval) biunifyComprehensionObject(x *ast.ObjectComprehension, b *ast.Ter
return e.biunify(ast.NewTerm(result), b, b1, b2, iter)
}

func (e *eval) getSaveTerms(x interface{}) (result []*ast.Term) {
vis := ast.NewVarVisitor().WithParams(ast.VarVisitorParams{
SkipClosures: true,
SkipRefHead: true,
})
ast.Walk(vis, x)
for v := range vis.Vars() {
result = append(result, ast.NewTerm(v))
}
return
}

func (e *eval) saveExpr(expr *ast.Expr, b *bindings, iter unifyIterator) error {
e.saveStack.Push(expr, b, nil)
defer e.saveStack.Pop()
Expand All @@ -704,9 +716,11 @@ func (e *eval) saveExpr(expr *ast.Expr, b *bindings, iter unifyIterator) error {

func (e *eval) saveUnify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
expr := ast.Equality.Expr(a, b)
elem := newSaveSetElem(e.getUnifyOutputs(expr))
e.saveSet.Push(elem)
defer e.saveSet.Pop()
if ts := e.getSaveTerms(expr); len(ts) > 0 {
elem := newSaveSetElem(ts)
e.saveSet.Push(elem)
defer e.saveSet.Pop()
}
e.saveStack.Push(expr, b1, b2)
defer e.saveStack.Pop()
e.traceSave(expr)
Expand Down Expand Up @@ -734,9 +748,7 @@ func (e *eval) saveCall(operator *ast.Term, args []*ast.Term, result *ast.Term,
}
terms[len(terms)-1] = result
expr := ast.NewExpr(terms)

// result might be composite (object/array)
if ts := nonGroundedValues(result); len(ts) > 0 {
if ts := e.getSaveTerms(result); len(ts) > 0 {
elem := newSaveSetElem(ts)
e.saveSet.Push(elem)
defer e.saveSet.Pop()
Expand Down Expand Up @@ -847,21 +859,6 @@ func (e *eval) Resolve(ref ast.Ref) (ast.Value, error) {
return nil, fmt.Errorf("illegal ref")
}

func (e *eval) getUnifyOutputs(expr *ast.Expr) []*ast.Term {
vars := expr.Vars(ast.VarVisitorParams{
SkipClosures: true,
SkipRefHead: true,
})
outputs := make([]*ast.Term, 0, len(vars))
for v := range vars {
term := e.bindings.Plug(ast.NewTerm(v))
if !term.IsGround() {
outputs = append(outputs, term)
}
}
return outputs
}

func (e *eval) generateVar(suffix string) *ast.Term {
return ast.VarTerm(fmt.Sprintf("%v_%v", e.genvarprefix, suffix))
}
Expand Down Expand Up @@ -1774,24 +1771,3 @@ func plugSlice(xs []*ast.Term, b *bindings) []*ast.Term {
}
return cpy
}

// Note: this needs to be recursive, as both Arrays and Objects may have nested
// non-grounded values
func nonGroundedValues(result *ast.Term) []*ast.Term {
res := []*ast.Term{}
switch xs := result.Value.(type) {
case ast.Array:
for _, x := range xs {
res = append(res, nonGroundedValues(x)...)
}
case ast.Object:
xs.Foreach(func(_, val *ast.Term) {
res = append(res, nonGroundedValues(val)...)
})
default:
if !result.IsGround() {
res = append(res, result)
}
}
return res
}

0 comments on commit 28582df

Please sign in to comment.