Skip to content

Commit

Permalink
ast: Relaxing strict-mode check for unused args in else-branching fun…
Browse files Browse the repository at this point in the history
…ctions (#5760)

Check is relaxed from requiring all arguments to be used in _all_ of the function's else-separated bodies,
to only require all arguments to be used in _any_ of the function's else-separated bodies.

Fixes: #5758

Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling authored Mar 16, 2023
1 parent 5ff0bcf commit 9474b8f
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 58 deletions.
148 changes: 108 additions & 40 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -2171,7 +2171,6 @@ func (c *Compiler) rewriteLocalVars() {
gen := c.localvargen

WalkRules(mod, func(rule *Rule) bool {

// Rewrite assignments contained in head of rule. Assignments can
// occur in rule head if they're inside a comprehension. Note,
// assigned vars in comprehensions in the head will be rewritten
Expand All @@ -2194,58 +2193,98 @@ func (c *Compiler) rewriteLocalVars() {
c.err(err)
}

// Rewrite assignments in body.
used := NewVarSet()

last := rule.Head.Ref()[len(rule.Head.Ref())-1]
used.Update(last.Vars())
argsStack := newLocalDeclaredVars()

if rule.Head.Key != nil {
used.Update(rule.Head.Key.Vars())
args := NewVarVisitor()
if c.strict {
args.Walk(rule.Head.Args)
}
unusedArgs := args.Vars()

if rule.Head.Value != nil {
used.Update(rule.Head.Value.Vars())
}
c.rewriteLocalArgVars(gen, argsStack, rule)

stack := newLocalDeclaredVars()
// Rewrite local vars in each else-branch of the rule.
// Note: this is done instead of a walk so that we can capture any unused function arguments
// across else-branches.
for rule := rule; rule != nil; rule = rule.Else {
stack, errs := c.rewriteLocalVarsInRule(rule, unusedArgs, argsStack, gen)

c.rewriteLocalArgVars(gen, stack, rule)
for arg := range unusedArgs {
if stack.Count(arg) > 1 {
delete(unusedArgs, arg)
}
}

body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict)
for _, err := range errs {
c.err(err)
for _, err := range errs {
c.err(err)
}
}

// For rewritten vars use the collection of all variables that
// were in the stack at some point in time.
for k, v := range stack.rewritten {
c.RewrittenVars[k] = v
if c.strict {
// Report an error for each unused function argument
for arg := range unusedArgs {
if !arg.IsWildcard() {
c.err(NewError(CompileErr, rule.Head.Location, "unused argument %v", arg))
}
}
}

rule.Body = body
return true
})
}
}

// Rewrite vars in head that refer to locally declared vars in the body.
localXform := rewriteHeadVarLocalTransform{declared: declared}
func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) {
// Rewrite assignments in body.
used := NewVarSet()

for i := range rule.Head.Args {
rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i])
}
last := rule.Head.Ref()[len(rule.Head.Ref())-1]
used.Update(last.Vars())

for i := 1; i < len(rule.Head.Ref()); i++ {
rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i])
}
if rule.Head.Key != nil {
rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key)
}
if rule.Head.Key != nil {
used.Update(rule.Head.Key.Vars())
}

if rule.Head.Value != nil {
rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value)
if rule.Head.Value != nil {
valueVars := rule.Head.Value.Vars()
used.Update(valueVars)
for arg := range unusedArgs {
if valueVars.Contains(arg) {
delete(unusedArgs, arg)
}
}
}

return false
})
stack := argsStack.Copy()

body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict)

// For rewritten vars use the collection of all variables that
// were in the stack at some point in time.
for k, v := range stack.rewritten {
c.RewrittenVars[k] = v
}

rule.Body = body

// Rewrite vars in head that refer to locally declared vars in the body.
localXform := rewriteHeadVarLocalTransform{declared: declared}

for i := range rule.Head.Args {
rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i])
}

for i := 1; i < len(rule.Head.Ref()); i++ {
rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i])
}
if rule.Head.Key != nil {
rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key)
}

if rule.Head.Value != nil {
rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value)
}
return stack, errs
}

type rewriteNestedHeadVarLocalTransform struct {
Expand Down Expand Up @@ -4607,6 +4646,35 @@ func newLocalDeclaredVars() *localDeclaredVars {
}
}

func (s *localDeclaredVars) Copy() *localDeclaredVars {
stack := &localDeclaredVars{
vars: []*declaredVarSet{},
rewritten: map[Var]Var{},
}

for i := range s.vars {
stack.vars = append(stack.vars, newDeclaredVarSet())
for k, v := range s.vars[i].vs {
stack.vars[0].vs[k] = v
}
for k, v := range s.vars[i].reverse {
stack.vars[0].reverse[k] = v
}
for k, v := range s.vars[i].count {
stack.vars[0].count[k] = v
}
for k, v := range s.vars[i].occurrence {
stack.vars[0].occurrence[k] = v
}
}

for k, v := range s.rewritten {
stack.rewritten[k] = v
}

return stack
}

func (s *localDeclaredVars) Push() {
s.vars = append(s.vars, newDeclaredVarSet())
}
Expand Down Expand Up @@ -4701,7 +4769,7 @@ func (s localDeclaredVars) Count(x Var) int {
func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, strict bool) (Body, map[Var]Var, Errors) {
var errs Errors
body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs, strict)
return body, stack.Pop().vs, errs
return body, stack.Peek().vs, errs
}

func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) {
Expand Down Expand Up @@ -4732,11 +4800,11 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u
cpy.Append(NewExpr(BooleanTerm(true)))
}

errs = checkUnusedAssignedAndArgVars(body[0].Loc(), stack, used, errs, strict)
errs = checkUnusedAssignedVars(body[0].Loc(), stack, used, errs, strict)
return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs)
}

func checkUnusedAssignedAndArgVars(loc *Location, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {
func checkUnusedAssignedVars(loc *Location, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {

if !strict || len(errs) > 0 {
return errs
Expand All @@ -4748,7 +4816,7 @@ func checkUnusedAssignedAndArgVars(loc *Location, stack *localDeclaredVars, used
for v, occ := range dvs.occurrence {
// A var that was assigned in this scope must have been seen (used) more than once (the time of assignment) in
// the same, or nested, scope to be counted as used.
if !v.IsWildcard() && stack.Count(v) <= 1 && (occ == assignedVar || occ == argVar) {
if !v.IsWildcard() && stack.Count(v) <= 1 && occ == assignedVar {
unused.Add(dvs.vs[v])
}
}
Expand Down
Loading

0 comments on commit 9474b8f

Please sign in to comment.