Skip to content

Commit

Permalink
also check function literals, fixes #19
Browse files Browse the repository at this point in the history
  • Loading branch information
vankleefjim authored and sivchari committed Jun 15, 2022
1 parent a75b385 commit cc152e3
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 35 deletions.
78 changes: 43 additions & 35 deletions tenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,49 +34,58 @@ func run(pass *analysis.Pass) (interface{}, error) {
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

nodeFilter := []ast.Node{
(*ast.File)(nil),
(*ast.FuncDecl)(nil),
(*ast.FuncLit)(nil),
}

inspect.Preorder(nodeFilter, func(n ast.Node) {
switch n := n.(type) {
case *ast.File:
for _, decl := range n.Decls {

funcDecl, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
checkFunc(pass, funcDecl, pass.Fset.File(n.Pos()).Name())
}
case *ast.FuncDecl:
checkFuncDecl(pass, n, pass.Fset.File(n.Pos()).Name())
case *ast.FuncLit:
checkFuncLit(pass, n, pass.Fset.File(n.Pos()).Name())
}
})

return nil, nil
}

func checkFunc(pass *analysis.Pass, n *ast.FuncDecl, fileName string) {
argName, ok := targetRunner(n, fileName)
if ok {
for _, stmt := range n.Body.List {
switch stmt := stmt.(type) {
case *ast.ExprStmt:
if !checkExprStmt(pass, stmt, n, argName) {
continue
}
case *ast.IfStmt:
if !checkIfStmt(pass, stmt, n, argName) {
continue
}
case *ast.AssignStmt:
if !checkAssignStmt(pass, stmt, n, argName) {
continue
}
func checkFuncDecl(pass *analysis.Pass, f *ast.FuncDecl, fileName string) {
argName, ok := targetRunner(f.Type.Params.List, fileName)
if !ok {
return
}
checkStmts(pass, f.Body.List, f.Name.Name, argName)
}

func checkFuncLit(pass *analysis.Pass, f *ast.FuncLit, fileName string) {
argName, ok := targetRunner(f.Type.Params.List, fileName)
if !ok {
return
}
checkStmts(pass, f.Body.List, "function literal", argName)
}

func checkStmts(pass *analysis.Pass, stmts []ast.Stmt, funcName, argName string) {
for _, stmt := range stmts {
switch stmt := stmt.(type) {
case *ast.ExprStmt:
if !checkExprStmt(pass, stmt, funcName, argName) {
continue
}
case *ast.IfStmt:
if !checkIfStmt(pass, stmt, funcName, argName) {
continue
}
case *ast.AssignStmt:
if !checkAssignStmt(pass, stmt, funcName, argName) {
continue
}
}
}
}

func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, n *ast.FuncDecl, argName string) bool {
func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, funcName, argName string) bool {
callExpr, ok := stmt.X.(*ast.CallExpr)
if !ok {
return false
Expand All @@ -94,12 +103,12 @@ func checkExprStmt(pass *analysis.Pass, stmt *ast.ExprStmt, n *ast.FuncDecl, arg
if argName == "" {
argName = "testing"
}
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name)
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName)
}
return true
}

func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, n *ast.FuncDecl, argName string) bool {
func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, funcName, argName string) bool {
assignStmt, ok := stmt.Init.(*ast.AssignStmt)
if !ok {
return false
Expand All @@ -121,12 +130,12 @@ func checkIfStmt(pass *analysis.Pass, stmt *ast.IfStmt, n *ast.FuncDecl, argName
if argName == "" {
argName = "testing"
}
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name)
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName)
}
return true
}

func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, n *ast.FuncDecl, argName string) bool {
func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, funcName, argName string) bool {
rhs, ok := stmt.Rhs[0].(*ast.CallExpr)
if !ok {
return false
Expand All @@ -144,13 +153,12 @@ func checkAssignStmt(pass *analysis.Pass, stmt *ast.AssignStmt, n *ast.FuncDecl,
if argName == "" {
argName = "testing"
}
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, n.Name.Name)
pass.Reportf(stmt.Pos(), "os.Setenv() can be replaced by `%s.Setenv()` in %s", argName, funcName)
}
return true
}

func targetRunner(funcDecl *ast.FuncDecl, fileName string) (string, bool) {
params := funcDecl.Type.Params.List
func targetRunner(params []*ast.Field, fileName string) (string, bool) {
for _, p := range params {
switch typ := p.Type.(type) {
case *ast.StarExpr:
Expand Down
12 changes: 12 additions & 0 deletions testdata/src/a/a_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,15 @@ func FuzzF(f *testing.F) {
_ = err
}
}

func TestFunctionLiteral(t *testing.T) {
testsetup()
t.Run("test", func(t *testing.T) {
os.Setenv("a", "b") // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal"
err := os.Setenv("a", "b") // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal"
_ = err
if err := os.Setenv("a", "b"); err != nil { // want "os\\.Setenv\\(\\) can be replaced by `t\\.Setenv\\(\\)` in function literal"
_ = err
}
})
}

0 comments on commit cc152e3

Please sign in to comment.