diff --git a/ast/compile.go b/ast/compile.go index 3fdb342248..ddb074e3d0 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -2171,7 +2171,6 @@ func (c *Compiler) rewriteLocalVars() { gen := c.localvargen WalkRules(mod, func(rule *Rule) bool { - // Rewrite assignments contained in head of rule. Assignments can // occur in rule head if they're inside a comprehension. Note, // assigned vars in comprehensions in the head will be rewritten @@ -2194,58 +2193,98 @@ func (c *Compiler) rewriteLocalVars() { c.err(err) } - // Rewrite assignments in body. - used := NewVarSet() - - last := rule.Head.Ref()[len(rule.Head.Ref())-1] - used.Update(last.Vars()) + argsStack := newLocalDeclaredVars() - if rule.Head.Key != nil { - used.Update(rule.Head.Key.Vars()) + args := NewVarVisitor() + if c.strict { + args.Walk(rule.Head.Args) } + unusedArgs := args.Vars() - if rule.Head.Value != nil { - used.Update(rule.Head.Value.Vars()) - } + c.rewriteLocalArgVars(gen, argsStack, rule) - stack := newLocalDeclaredVars() + // Rewrite local vars in each else-branch of the rule. + // Note: this is done instead of a walk so that we can capture any unused function arguments + // across else-branches. + for rule := rule; rule != nil; rule = rule.Else { + stack, errs := c.rewriteLocalVarsInRule(rule, unusedArgs, argsStack, gen) - c.rewriteLocalArgVars(gen, stack, rule) + for arg := range unusedArgs { + if stack.Count(arg) > 1 { + delete(unusedArgs, arg) + } + } - body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict) - for _, err := range errs { - c.err(err) + for _, err := range errs { + c.err(err) + } } - // For rewritten vars use the collection of all variables that - // were in the stack at some point in time. - for k, v := range stack.rewritten { - c.RewrittenVars[k] = v + if c.strict { + // Report an error for each unused function argument + for arg := range unusedArgs { + if !arg.IsWildcard() { + c.err(NewError(CompileErr, rule.Head.Location, "unused argument %v", arg)) + } + } } - rule.Body = body + return true + }) + } +} - // Rewrite vars in head that refer to locally declared vars in the body. - localXform := rewriteHeadVarLocalTransform{declared: declared} +func (c *Compiler) rewriteLocalVarsInRule(rule *Rule, unusedArgs VarSet, argsStack *localDeclaredVars, gen *localVarGenerator) (*localDeclaredVars, Errors) { + // Rewrite assignments in body. + used := NewVarSet() - for i := range rule.Head.Args { - rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i]) - } + last := rule.Head.Ref()[len(rule.Head.Ref())-1] + used.Update(last.Vars()) - for i := 1; i < len(rule.Head.Ref()); i++ { - rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i]) - } - if rule.Head.Key != nil { - rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key) - } + if rule.Head.Key != nil { + used.Update(rule.Head.Key.Vars()) + } - if rule.Head.Value != nil { - rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value) + if rule.Head.Value != nil { + valueVars := rule.Head.Value.Vars() + used.Update(valueVars) + for arg := range unusedArgs { + if valueVars.Contains(arg) { + delete(unusedArgs, arg) } + } + } - return false - }) + stack := argsStack.Copy() + + body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body, c.strict) + + // For rewritten vars use the collection of all variables that + // were in the stack at some point in time. + for k, v := range stack.rewritten { + c.RewrittenVars[k] = v + } + + rule.Body = body + + // Rewrite vars in head that refer to locally declared vars in the body. + localXform := rewriteHeadVarLocalTransform{declared: declared} + + for i := range rule.Head.Args { + rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i]) + } + + for i := 1; i < len(rule.Head.Ref()); i++ { + rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i]) + } + if rule.Head.Key != nil { + rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key) + } + + if rule.Head.Value != nil { + rule.Head.Value, _ = transformTerm(localXform, rule.Head.Value) } + return stack, errs } type rewriteNestedHeadVarLocalTransform struct { @@ -4607,6 +4646,35 @@ func newLocalDeclaredVars() *localDeclaredVars { } } +func (s *localDeclaredVars) Copy() *localDeclaredVars { + stack := &localDeclaredVars{ + vars: []*declaredVarSet{}, + rewritten: map[Var]Var{}, + } + + for i := range s.vars { + stack.vars = append(stack.vars, newDeclaredVarSet()) + for k, v := range s.vars[i].vs { + stack.vars[0].vs[k] = v + } + for k, v := range s.vars[i].reverse { + stack.vars[0].reverse[k] = v + } + for k, v := range s.vars[i].count { + stack.vars[0].count[k] = v + } + for k, v := range s.vars[i].occurrence { + stack.vars[0].occurrence[k] = v + } + } + + for k, v := range s.rewritten { + stack.rewritten[k] = v + } + + return stack +} + func (s *localDeclaredVars) Push() { s.vars = append(s.vars, newDeclaredVarSet()) } @@ -4701,7 +4769,7 @@ func (s localDeclaredVars) Count(x Var) int { func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, strict bool) (Body, map[Var]Var, Errors) { var errs Errors body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs, strict) - return body, stack.Pop().vs, errs + return body, stack.Peek().vs, errs } func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors, strict bool) (Body, Errors) { @@ -4732,11 +4800,11 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u cpy.Append(NewExpr(BooleanTerm(true))) } - errs = checkUnusedAssignedAndArgVars(body[0].Loc(), stack, used, errs, strict) + errs = checkUnusedAssignedVars(body[0].Loc(), stack, used, errs, strict) return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs) } -func checkUnusedAssignedAndArgVars(loc *Location, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors { +func checkUnusedAssignedVars(loc *Location, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors { if !strict || len(errs) > 0 { return errs @@ -4748,7 +4816,7 @@ func checkUnusedAssignedAndArgVars(loc *Location, stack *localDeclaredVars, used for v, occ := range dvs.occurrence { // A var that was assigned in this scope must have been seen (used) more than once (the time of assignment) in // the same, or nested, scope to be counted as used. - if !v.IsWildcard() && stack.Count(v) <= 1 && (occ == assignedVar || occ == argVar) { + if !v.IsWildcard() && stack.Count(v) <= 1 && occ == assignedVar { unused.Add(dvs.vs[v]) } } diff --git a/ast/compile_test.go b/ast/compile_test.go index d0570b0be5..6ee9d18775 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -4195,16 +4195,15 @@ func TestCompilerRewriteLocalAssignments(t *testing.T) { module := MustParseModule(` package test - f(__local0__) = __local1__ { __local0__ == 1; __local1__ = 2 } else = __local3__ { __local2__ == 3; __local3__ = 4 } + f(__local0__) = __local1__ { __local0__ == 1; __local1__ = 2 } else = __local2__ { __local0__ == 3; __local2__ = 4 } `) - module.Rules[0].Else.Head.Args[0].Value = Var("__local2__") + module.Rules[0].Else.Head.Args[0].Value = Var("__local0__") return module }, expRewrittenMap: map[Var]Var{ Var("__local0__"): Var("x"), Var("__local1__"): Var("y"), - Var("__local2__"): Var("x"), - Var("__local3__"): Var("y"), + Var("__local2__"): Var("y"), }, }, { @@ -5053,7 +5052,7 @@ func TestRewriteDeclaredVars(t *testing.T) { } } -func TestCheckUnusedAssignedAndArgVars(t *testing.T) { +func TestCheckUnusedFunctionArgVars(t *testing.T) { tests := []strictnessTestCase{ { note: "one of the two function args is not used - issue 5602 regression test", @@ -5064,8 +5063,22 @@ func TestCheckUnusedAssignedAndArgVars(t *testing.T) { expectedErrors: Errors{ &Error{ Code: CompileErr, - Location: NewLocation([]byte("x = 1"), "", 3, 5), - Message: "assigned var y unused", + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument y", + }, + }, + }, + { + note: "one of the two ref-head function args is not used", + module: `package test + a.b.c.func(x, y) { + x = 1 + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("a.b.c.func(x, y)"), "", 2, 4), + Message: "unused argument y", }, }, }, @@ -5079,13 +5092,13 @@ func TestCheckUnusedAssignedAndArgVars(t *testing.T) { expectedErrors: Errors{ &Error{ Code: CompileErr, - Location: NewLocation([]byte("input.baz = 1"), "", 3, 5), - Message: "assigned var x unused", + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument x", }, &Error{ Code: CompileErr, - Location: NewLocation([]byte("input.baz = 1"), "", 3, 5), - Message: "assigned var y unused", + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument y", }, }, }, @@ -5099,8 +5112,8 @@ func TestCheckUnusedAssignedAndArgVars(t *testing.T) { expectedErrors: Errors{ &Error{ Code: CompileErr, - Location: NewLocation([]byte("input.test == \"foo\""), "", 3, 5), - Message: "assigned var y unused", + Location: NewLocation([]byte("func(x, y)"), "", 2, 4), + Message: "unused argument y", }, }, }, @@ -5122,8 +5135,8 @@ func TestCheckUnusedAssignedAndArgVars(t *testing.T) { expectedErrors: Errors{ &Error{ Code: CompileErr, - Location: NewLocation([]byte("input.test == \"foo\""), "", 3, 5), - Message: "assigned var x unused", + Location: NewLocation([]byte("func(x, _)"), "", 2, 4), + Message: "unused argument x", }, }, }, @@ -5135,6 +5148,144 @@ func TestCheckUnusedAssignedAndArgVars(t *testing.T) { }`, expectedErrors: Errors{}, }, + { + note: "argvar not used in body but in head value comprehension", + module: `package test + a := {"foo": 1} + func(x) := { x: v | v := a[x] } { + input.test == "foo" + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar not used in body and shadowed in head value comprehension", + module: `package test + a := {"foo": 1} + func(x) := { x: v | x := "foo"; v := a[x] } { + input.test == "foo" + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x) := { x: v | x := \"foo\"; v := a[x] }"), "", 3, 4), + Message: "unused argument x", + }, + }, + }, + { + note: "argvar used in primary body but not in else body", + module: `package test + func(x) { + input.test == x + } else := false { + input.test == "foo" + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar used in primary body but not in else body (with wildcard)", + module: `package test + func(x, _) { + input.test == x + } else := false { + input.test == "foo" + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar not used in primary body but in else body", + module: `package test + func(x) { + input.test == "foo" + } else := false { + input.test == x + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar not used in primary body but in else body (with wildcard)", + module: `package test + func(x, _) { + input.test == "foo" + } else := false { + input.test == x + }`, + expectedErrors: Errors{}, + }, + { + note: "argvar used in primary body but not in implicit else body", + module: `package test + func(x) { + input.test == x + } else := false`, + expectedErrors: Errors{}, + }, + { + note: "argvars usage spread over multiple bodies", + module: `package test + func(x, y, z) { + input.test == x + } else { + input.test == y + } else { + input.test == z + }`, + expectedErrors: Errors{}, + }, + { + note: "argvars usage spread over multiple bodies, missing in first", + module: `package test + func(x, y, z) { + input.test == "foo" + } else { + input.test == y + } else { + input.test == z + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y, z)"), "", 2, 4), + Message: "unused argument x", + }, + }, + }, + { + note: "argvars usage spread over multiple bodies, missing in second", + module: `package test + func(x, y, z) { + input.test == x + } else { + input.test == "bar" + } else { + input.test == z + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y, z)"), "", 2, 4), + Message: "unused argument y", + }, + }, + }, + { + note: "argvars usage spread over multiple bodies, missing in third", + module: `package test + func(x, y, z) { + input.test == x + } else { + input.test == y + } else { + input.test == "baz" + }`, + expectedErrors: Errors{ + &Error{ + Code: CompileErr, + Location: NewLocation([]byte("func(x, y, z)"), "", 2, 4), + Message: "unused argument z", + }, + }, + }, } t.Helper() @@ -5968,14 +6119,14 @@ func TestRewritePrintCallsWithElseImplicitArgs(t *testing.T) { exp := MustParseModuleWithOpts(`package test f(__local0__, __local1__) = true { __local0__ = __local1__ } - else = false { __local6__ = {__local4__ | __local4__ = __local2__}; __local7__ = {__local5__ | __local5__ = __local3__}; internal.print([__local6__, __local7__]) } + else = false { __local4__ = {__local2__ | __local2__ = __local0__}; __local5__ = {__local3__ | __local3__ = __local1__}; internal.print([__local4__, __local5__]) } `, opts) // NOTE(tsandall): we have to patch the implicit args on the else rule // because of how the parser copies the arg names across from the first // rule. - exp.Rules[0].Else.Head.Args[0] = VarTerm("__local2__") - exp.Rules[0].Else.Head.Args[1] = VarTerm("__local3__") + exp.Rules[0].Else.Head.Args[0] = VarTerm("__local0__") + exp.Rules[0].Else.Head.Args[1] = VarTerm("__local1__") if !exp.Equal(c.Modules["test.rego"]) { t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, c.Modules["test.rego"])