From d362e751fc1a060ed0320ab899fe23c70bc93402 Mon Sep 17 00:00:00 2001 From: Martin Asquino Date: Fri, 2 Feb 2024 00:30:53 +0000 Subject: [PATCH] gopls: add fill switch cases code action --- .../analysis/fillswitch/fillswitch.go | 350 ++++++++++++++++++ .../analysis/fillswitch/fillswitch_test.go | 38 ++ .../fillswitch/testdata/switch/switch.go | 34 ++ .../testdata/typeswitch/typeswitch.go | 36 ++ gopls/internal/golang/codeaction.go | 20 + gopls/internal/golang/fix.go | 2 + 6 files changed, 480 insertions(+) create mode 100644 gopls/internal/analysis/fillswitch/fillswitch.go create mode 100644 gopls/internal/analysis/fillswitch/fillswitch_test.go create mode 100644 gopls/internal/analysis/fillswitch/testdata/switch/switch.go create mode 100644 gopls/internal/analysis/fillswitch/testdata/typeswitch/typeswitch.go diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go new file mode 100644 index 00000000000..2b41adbdb2e --- /dev/null +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -0,0 +1,350 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fillswitch defines an Analyzer that automatically +// fills the missing cases in type switches or switches over named types. +// +// The analyzer's diagnostic is merely a prompt. +// The actual fix is created by a separate direct call from gopls to +// the SuggestedFixes function. +// Tests of Analyzer.Run can be found in ./testdata/src. +// Tests of the SuggestedFixes logic live in ../../testdata/fillswitch. +package fillswitch + +import ( + "bytes" + "context" + "errors" + "fmt" + "go/ast" + "go/token" + "go/types" + "slices" + "strings" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/gopls/internal/cache" + "golang.org/x/tools/gopls/internal/cache/parsego" +) + +const FixCategory = "fillswitch" // recognized by gopls ApplyFix + +// errNoSuggestedFix is returned when no suggested fix is available. This could +// be because all cases are already covered, or (in the case of a type switch) +// because the remaining cases are for types not accessible by the current +// package. +var errNoSuggestedFix = errors.New("no suggested fix") + +// Diagnose computes diagnostics for switch statements with missing cases +// overlapping with the provided start and end position. +// +// The diagnostic contains a lazy fix; the actual patch is computed +// (via the ApplyFix command) by a call to [SuggestedFix]. +// +// If either start or end is invalid, the entire package is inspected. +func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic { + var diags []analysis.Diagnostic + nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)} + inspect.Preorder(nodeFilter, func(n ast.Node) { + if expr, ok := n.(*ast.SwitchStmt); ok { + if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) { + return // non-overlapping + } + + if defaultHandled(expr.Body) { + return + } + + namedType, err := namedTypeFromSwitch(expr, info) + if err != nil { + return + } + + if _, err := suggestedFixSwitch(expr, pkg, info); err != nil { + return + } + + diags = append(diags, analysis.Diagnostic{ + Message: "Switch has missing cases", + Pos: expr.Pos(), + End: expr.End(), + Category: FixCategory, + SuggestedFixes: []analysis.SuggestedFix{{ + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), + // No TextEdits => computed later by gopls. + }}, + }) + } + + if expr, ok := n.(*ast.TypeSwitchStmt); ok { + if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) { + return // non-overlapping + } + + if defaultHandled(expr.Body) { + return + } + + namedType, err := namedTypeFromTypeSwitch(expr, info) + if err != nil { + return + } + + if _, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil { + return + } + + diags = append(diags, analysis.Diagnostic{ + Message: "Switch has missing cases", + Pos: expr.Pos(), + End: expr.End(), + Category: FixCategory, + SuggestedFixes: []analysis.SuggestedFix{{ + Message: fmt.Sprintf("Add cases for %v", namedType.Obj().Name()), + // No TextEdits => computed later by gopls. + }}, + }) + } + }) + + return diags +} + +func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { + namedType, err := namedTypeFromTypeSwitch(stmt, info) + if err != nil { + return nil, err + } + + scope := namedType.Obj().Pkg().Scope() + variants := make([]string, 0) + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if _, ok := obj.(*types.TypeName); !ok { + continue + } + + if types.Identical(obj.Type(), namedType.Obj().Type()) { + continue + } + + if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { + if obj.Pkg().Name() != pkg.Name() { + if !obj.Exported() { + continue + } + + variants = append(variants, obj.Pkg().Name()+"."+obj.Name()) + } else { + variants = append(variants, obj.Name()) + } + } else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) { + if obj.Pkg().Name() != pkg.Name() { + if !obj.Exported() { + continue + } + + variants = append(variants, "*"+obj.Pkg().Name()+"."+obj.Name()) + } else { + variants = append(variants, "*"+obj.Name()) + } + } + } + + handledVariants := getHandledVariants(stmt.Body) + if len(variants) == 0 || len(variants) == len(handledVariants) { + return nil, errNoSuggestedFix + } + + newText := buildNewText(variants, handledVariants) + return &analysis.SuggestedFix{ + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), + TextEdits: []analysis.TextEdit{{ + Pos: stmt.End() - 1, + End: stmt.End() - 1, + NewText: indent([]byte(newText), []byte{'\t'}), + }}, + }, nil +} + +func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { + namedType, err := namedTypeFromSwitch(stmt, info) + if err != nil { + return nil, err + } + + scope := namedType.Obj().Pkg().Scope() + variants := make([]string, 0) + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if obj.Id() == namedType.Obj().Id() { + continue + } + + if types.Identical(obj.Type(), namedType.Obj().Type()) { + // TODO: comparing the package name like this feels wrong, is it? + if obj.Pkg().Name() != pkg.Name() { + if !obj.Exported() { + continue + } + + variants = append(variants, obj.Pkg().Name()+"."+obj.Name()) + } else { + variants = append(variants, obj.Name()) + } + } + } + + handledVariants := getHandledVariants(stmt.Body) + if len(variants) == 0 || len(variants) == len(handledVariants) { + return nil, errNoSuggestedFix + } + + newText := buildNewText(variants, handledVariants) + return &analysis.SuggestedFix{ + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), + TextEdits: []analysis.TextEdit{{ + Pos: stmt.End() - 1, + End: stmt.End() - 1, + NewText: indent([]byte(newText), []byte{'\t'}), + }}, + }, nil +} + +func namedTypeFromSwitch(stmt *ast.SwitchStmt, info *types.Info) (*types.Named, error) { + typ := info.TypeOf(stmt.Tag) + if typ == nil { + return nil, errors.New("expected switch statement to have a tag") + } + + namedType, ok := typ.(*types.Named) + if !ok { + return nil, errors.New("switch statement is not on a named type") + } + + return namedType, nil +} + +func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) (*types.Named, error) { + switch s := stmt.Assign.(type) { + case *ast.ExprStmt: + typ := s.X.(*ast.TypeAssertExpr) + namedType, ok := info.TypeOf(typ.X).(*types.Named) + if !ok { + return nil, errors.New("type switch expression is not on a named type") + } + + return namedType, nil + case *ast.AssignStmt: + for _, expr := range s.Rhs { + typ, ok := expr.(*ast.TypeAssertExpr) + if !ok { + continue + } + + namedType, ok := info.TypeOf(typ.X).(*types.Named) + if !ok { + continue + } + + return namedType, nil + } + + return nil, errors.New("expected type switch expression to have a named type") + default: + return nil, errors.New("node is not a type switch statement") + } +} + +func defaultHandled(body *ast.BlockStmt) bool { + for _, bl := range body.List { + if len(bl.(*ast.CaseClause).List) == 0 { + return true + } + } + + return false +} + +func buildNewText(variants []string, handledVariants []string) string { + var textBuilder strings.Builder + for _, c := range variants { + if slices.Contains(handledVariants, c) { + continue + } + + textBuilder.WriteString("case ") + textBuilder.WriteString(c) + textBuilder.WriteString(":\n") + } + + return textBuilder.String() +} + +func getHandledVariants(body *ast.BlockStmt) []string { + out := make([]string, 0) + for _, bl := range body.List { + for _, c := range bl.(*ast.CaseClause).List { + switch v := c.(type) { + case *ast.Ident: + out = append(out, v.Name) + case *ast.SelectorExpr: + out = append(out, v.X.(*ast.Ident).Name+"."+v.Sel.Name) + case *ast.StarExpr: + out = append(out, "*"+v.X.(*ast.Ident).Name) + } + } + } + + return out +} + +// SuggestedFix computes the suggested fix for the kinds of +// diagnostics produced by the Analyzer above. +func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) { + pos := start // don't use the end + path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos) + if len(path) < 2 { + return nil, nil, fmt.Errorf("no expression found") + } + + switch stmt := path[0].(type) { + case *ast.SwitchStmt: + fix, err := suggestedFixSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo()) + if err != nil { + return nil, nil, err + } + + return pkg.FileSet(), fix, nil + case *ast.TypeSwitchStmt: + fix, err := suggestedFixTypeSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo()) + if err != nil { + return nil, nil, err + } + + return pkg.FileSet(), fix, nil + default: + return nil, nil, fmt.Errorf("no switch statement found") + } +} + +// indent works line by line through str, prefixing each line with +// prefix. +func indent(str, prefix []byte) []byte { + split := bytes.Split(str, []byte("\n")) + newText := bytes.NewBuffer(nil) + for i, s := range split { + if i != 0 { + newText.Write(prefix) + } + + newText.Write(s) + if i < len(split)-1 { + newText.WriteByte('\n') + } + } + return newText.Bytes() +} diff --git a/gopls/internal/analysis/fillswitch/fillswitch_test.go b/gopls/internal/analysis/fillswitch/fillswitch_test.go new file mode 100644 index 00000000000..9d8e6ad2704 --- /dev/null +++ b/gopls/internal/analysis/fillswitch/fillswitch_test.go @@ -0,0 +1,38 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch_test + +import ( + "go/token" + "testing" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/analysistest" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/gopls/internal/analysis/fillswitch" +) + +// analyzer allows us to test the fillswitch code action using the analysistest +// harness. (fillswitch used to be a gopls analyzer.) +var analyzer = &analysis.Analyzer{ + Name: "fillswitch", + Doc: "test only", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: func(pass *analysis.Pass) (any, error) { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + for _, d := range fillswitch.Diagnose(inspect, token.NoPos, token.NoPos, pass.Pkg, pass.TypesInfo) { + pass.Report(d) + } + return nil, nil + }, + URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/fillswitch", + RunDespiteErrors: true, +} + +func Test(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, analyzer, "switch", "typeswitch") +} diff --git a/gopls/internal/analysis/fillswitch/testdata/switch/switch.go b/gopls/internal/analysis/fillswitch/testdata/switch/switch.go new file mode 100644 index 00000000000..e006ae95be4 --- /dev/null +++ b/gopls/internal/analysis/fillswitch/testdata/switch/switch.go @@ -0,0 +1,34 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillstruct + +type typeA int + +const ( + typeAOne typeOne = iota + typeATwo + typeAThree +) + +func doSwitch() { + var a typeA + switch a { // want `Switch has missing cases` + } + + switch a { // want `Switch has missing cases` + case typeAOne: + } + + switch a { + case typeAOne: + default: + } + + switch a { + case typeAOne: + case typeATwo: + case typeAThree: + } +} diff --git a/gopls/internal/analysis/fillswitch/testdata/typeswitch/typeswitch.go b/gopls/internal/analysis/fillswitch/testdata/typeswitch/typeswitch.go new file mode 100644 index 00000000000..64a4ec12d2d --- /dev/null +++ b/gopls/internal/analysis/fillswitch/testdata/typeswitch/typeswitch.go @@ -0,0 +1,36 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch + +type notification interface { + isNotification() +} + +type notificationOne struct{} + +func (notificationOne) isNotification() {} + +type notificationTwo struct{} + +func (notificationTwo) isNotification() {} + +func doSwitch() { + var not notification + switch not { // want `Switch has missing cases` + } + + switch not { // want `Switch has missing cases` + case notificationOne: + } + + switch not { + case notificationOne: + case notificationTwo: + } + + switch not { + default: + } +} diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index a9aae821be8..8365629b067 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -13,6 +13,7 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/analysis/fillstruct" + "golang.org/x/tools/gopls/internal/analysis/fillswitch" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/file" @@ -328,6 +329,25 @@ func getRewriteCodeActions(pkg *cache.Package, pgf *ParsedGoFile, fh file.Handle } } + for _, diag := range fillswitch.Diagnose(inspect, start, end, pkg.GetTypes(), pkg.GetTypesInfo()) { + rng, err := pgf.Mapper.PosRange(pgf.Tok, diag.Pos, diag.End) + if err != nil { + return nil, err + } + for _, fix := range diag.SuggestedFixes { + cmd, err := command.NewApplyFixCommand(fix.Message, command.ApplyFixArgs{ + Fix: diag.Category, + URI: pgf.URI, + Range: rng, + ResolveEdits: true, + }) + if err != nil { + return nil, err + } + commands = append(commands, cmd) + } + } + for i := range commands { actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorRewrite, &commands[i], nil, options)) } diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go index 6f07cb869c5..5e4cdfd408e 100644 --- a/gopls/internal/golang/fix.go +++ b/gopls/internal/golang/fix.go @@ -14,6 +14,7 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/gopls/internal/analysis/embeddirective" "golang.org/x/tools/gopls/internal/analysis/fillstruct" + "golang.org/x/tools/gopls/internal/analysis/fillswitch" "golang.org/x/tools/gopls/internal/analysis/stubmethods" "golang.org/x/tools/gopls/internal/analysis/undeclaredname" "golang.org/x/tools/gopls/internal/analysis/unusedparams" @@ -107,6 +108,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file fillstruct.FixCategory: singleFile(fillstruct.SuggestedFix), stubmethods.FixCategory: stubMethodsFixer, undeclaredname.FixCategory: singleFile(undeclaredname.SuggestedFix), + fillswitch.FixCategory: fillswitch.SuggestedFix, // Ad-hoc fixers: these are used when the command is // constructed directly by logic in server/code_action.