Skip to content

Commit

Permalink
return edits on diagnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
martskins committed Feb 9, 2024
1 parent bab3012 commit c8fda66
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 93 deletions.
65 changes: 65 additions & 0 deletions gopls/internal/analysis/fillswitch/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package fillswitch identifies switches with missing cases.
//
// It will provide diagnostics for type switches or switches over named types
// that are missing cases and provides a code action to fill those in.
//
// If the switch statement is over a named type, it will suggest cases for all
// const values that are assignable to the named type.
//
// type T int
// const (
// A T = iota
// B
// C
// )
//
// var t T
// switch t {
// case A:
// }
//
// It will provide a diagnostic with a suggested edit to fill in the remaining
// cases:
//
// var t T
// switch t {
// case A:
// case B:
// case C:
// }
//
// If the switch statement is over type of an interface, it will suggest cases for all types
// that implement the interface.
//
// type I interface {
// M()
// }
//
// type T struct{}
// func (t *T) M() {}
//
// type E struct{}
// func (e *E) M() {}
//
// var i I
// switch i.(type) {
// case *T:
// }
//
// It will provide a diagnostic with a suggested edit to fill in the remaining
// cases:
//
// var i I
// switch i.(type) {
// case *T:
// case *E:
// }
//
// The provided diagnostics will only suggest cases for types that are defined
// on the same package as the switch statement, or for types that are exported;
// and it will not suggest any case if the switch handles the default case.
package fillswitch
96 changes: 24 additions & 72 deletions gopls/internal/analysis/fillswitch/fillswitch.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ package fillswitch

import (
"bytes"
"context"
"errors"
"fmt"
"go/ast"
Expand All @@ -24,20 +23,14 @@ import (
"strings"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/gopls/internal/cache"
"golang.org/x/tools/gopls/internal/cache/parsego"
)

const FixCategory = "fillswitch" // recognized by gopls ApplyFix
const FixCategory = "fillswitch"

// Diagnose computes diagnostics for switch statements with missing cases
// overlapping with the provided start and end position.
//
// The diagnostic contains a lazy fix; the actual patch is computed
// (via the ApplyFix command) by a call to [SuggestedFix].
//
// If either start or end is invalid, the entire package is inspected.
func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic {
var diags []analysis.Diagnostic
Expand All @@ -50,49 +43,35 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac
return // non-overlapping
}

namedType, err := namedTypeFromSwitch(expr, info)
if err != nil {
return
}

if fix, err := suggestedFixSwitch(expr, pkg, info); err != nil || fix == nil {
fix, err := suggestedFixSwitch(expr, pkg, info)
if err != nil || fix == nil {
return
}

diags = append(diags, analysis.Diagnostic{
Message: "Switch has missing cases",
Pos: expr.Pos(),
End: expr.End(),
Category: FixCategory,
SuggestedFixes: []analysis.SuggestedFix{{
Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()),
// No TextEdits => computed later by gopls.
}},
Message: fix.Message,
Pos: expr.Pos(),
End: expr.End(),
Category: FixCategory,
SuggestedFixes: []analysis.SuggestedFix{*fix},
})
case *ast.TypeSwitchStmt:
if start.IsValid() && expr.End() < start ||
end.IsValid() && expr.Pos() > end {
return // non-overlapping
}

namedType, err := namedTypeFromTypeSwitch(expr, info)
if err != nil {
return
}

if fix, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil || fix == nil {
fix, err := suggestedFixTypeSwitch(expr, pkg, info)
if err != nil || fix == nil {
return
}

diags = append(diags, analysis.Diagnostic{
Message: "Switch has missing cases",
Pos: expr.Pos(),
End: expr.End(),
Category: FixCategory,
SuggestedFixes: []analysis.SuggestedFix{{
Message: fmt.Sprintf("Add cases for %v", namedType.Obj().Name()),
// No TextEdits => computed later by gopls.
}},
Message: fix.Message,
Pos: expr.Pos(),
End: expr.End(),
Category: FixCategory,
SuggestedFixes: []analysis.SuggestedFix{*fix},
})
}
})
Expand Down Expand Up @@ -134,7 +113,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *
}
}

handledVariants := typeSwitchCases(stmt.Body, info)
handledVariants := caseTypes(stmt.Body, info)
if len(variants) == 0 || len(variants) == len(handledVariants) {
return nil, nil
}
Expand All @@ -144,7 +123,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *
TextEdits: []analysis.TextEdit{{
Pos: stmt.End() - 1,
End: stmt.End() - 1,
NewText: buildNewTypesText(variants, handledVariants, pkg),
NewText: buildTypesText(variants, handledVariants, pkg),
}},
}, nil
}
Expand All @@ -170,7 +149,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In

samePkg := obj.Pkg() != pkg
if samePkg && !obj.Exported() {
continue
continue // inaccessible
}

if types.Identical(obj.Type(), namedType.Obj().Type()) {
Expand All @@ -188,7 +167,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In
TextEdits: []analysis.TextEdit{{
Pos: stmt.End() - 1,
End: stmt.End() - 1,
NewText: buildNewConstsText(variants, handledVariants, pkg),
NewText: buildConstsText(variants, handledVariants, pkg),
}},
}, nil
}
Expand Down Expand Up @@ -252,7 +231,7 @@ func hasDefaultCase(body *ast.BlockStmt) bool {
return false
}

func buildNewConstsText(variants []*types.Const, handledVariants []*types.Const, currentPkg *types.Package) []byte {
func buildConstsText(variants []*types.Const, handledVariants []*types.Const, currentPkg *types.Package) []byte {
var textBuilder strings.Builder
for _, c := range variants {
if slices.Contains(handledVariants, c) {
Expand Down Expand Up @@ -287,7 +266,7 @@ func isSameType(c, t types.Type) bool {
return false
}

func buildNewTypesText(variants []types.Type, handledVariants []types.Type, currentPkg *types.Package) []byte {
func buildTypesText(variants []types.Type, handledVariants []types.Type, currentPkg *types.Package) []byte {
var textBuilder strings.Builder
for _, c := range variants {
if slices.ContainsFunc(handledVariants, func(t types.Type) bool { return isSameType(c, t) }) {
Expand All @@ -309,6 +288,7 @@ func buildNewTypesText(variants []types.Type, handledVariants []types.Type, curr
}

if e.Obj().Pkg() != currentPkg {
// TODO: use the correct package name when the import is renamed
textBuilder.WriteString("*" + e.Obj().Pkg().Name() + "." + e.Obj().Name())
} else {
textBuilder.WriteString("*" + e.Obj().Name())
Expand All @@ -335,6 +315,7 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const {
if !ok {
continue
}

c, ok := obj.(*types.Const)
if !ok {
continue
Expand Down Expand Up @@ -365,7 +346,7 @@ func caseConsts(body *ast.BlockStmt, info *types.Info) []*types.Const {
return out
}

func typeSwitchCases(body *ast.BlockStmt, info *types.Info) []types.Type {
func caseTypes(body *ast.BlockStmt, info *types.Info) []types.Type {
var out []types.Type
for _, stmt := range body.List {
for _, e := range stmt.(*ast.CaseClause).List {
Expand Down Expand Up @@ -421,32 +402,3 @@ func typeSwitchCases(body *ast.BlockStmt, info *types.Info) []types.Type {

return out
}

// SuggestedFix computes the suggested fix for the kinds of
// diagnostics produced by the Analyzer above.
func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Package, pgf *parsego.File, start, end token.Pos) (*token.FileSet, *analysis.SuggestedFix, error) {
pos := start // don't use the end
path, _ := astutil.PathEnclosingInterval(pgf.File, pos, pos)
if len(path) < 2 {
return nil, nil, fmt.Errorf("no expression found")
}

switch stmt := path[0].(type) {
case *ast.SwitchStmt:
fix, err := suggestedFixSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo())
if err != nil {
return nil, nil, err
}

return pkg.FileSet(), fix, nil
case *ast.TypeSwitchStmt:
fix, err := suggestedFixTypeSwitch(stmt, pkg.GetTypes(), pkg.GetTypesInfo())
if err != nil {
return nil, nil, err
}

return pkg.FileSet(), fix, nil
default:
return nil, nil, fmt.Errorf("no switch statement found")
}
}
10 changes: 5 additions & 5 deletions gopls/internal/analysis/fillswitch/testdata/src/a/a.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ const (

func doSwitch() {
var a typeA
switch a { // want `Switch has missing cases`
switch a { // want `Add cases for typeA`
}

switch a { // want `Switch has missing cases`
switch a { // want `Add cases for typeA`
case typeAOne:
}

Expand All @@ -37,7 +37,7 @@ func doSwitch() {
}

var b data.TypeB
switch b { // want `Switch has missing cases`
switch b { // want `Add cases for TypeB`
case data.TypeBOne:
}
}
Expand All @@ -56,10 +56,10 @@ func (notificationTwo) isNotification() {}

func doTypeSwitch() {
var not notification
switch not.(type) { // want `Switch has missing cases`
switch not.(type) { // want `Add cases for notification`
}

switch not.(type) { // want `Switch has missing cases`
switch not.(type) { // want `Add cases for notification`
case notificationOne:
}

Expand Down
31 changes: 17 additions & 14 deletions gopls/internal/golang/codeaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle,
return nil, err
}
if want[protocol.RefactorRewrite] {
rewrites, err := getRewriteCodeActions(pkg, pgf, fh, rng, snapshot.Options())
rewrites, err := getRewriteCodeActions(ctx, pkg, snapshot, pgf, fh, rng, snapshot.Options())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -254,7 +254,7 @@ func newCodeAction(title string, kind protocol.CodeActionKind, cmd *protocol.Com
}

// getRewriteCodeActions returns refactor.rewrite code actions available at the specified range.
func getRewriteCodeActions(pkg *cache.Package, pgf *ParsedGoFile, fh file.Handle, rng protocol.Range, options *settings.Options) (_ []protocol.CodeAction, rerr error) {
func getRewriteCodeActions(ctx context.Context, pkg *cache.Package, snapshot *cache.Snapshot, pgf *ParsedGoFile, fh file.Handle, rng protocol.Range, options *settings.Options) (_ []protocol.CodeAction, rerr error) {
// golang/go#61693: code actions were refactored to run outside of the
// analysis framework, but as a result they lost their panic recovery.
//
Expand Down Expand Up @@ -330,24 +330,27 @@ func getRewriteCodeActions(pkg *cache.Package, pgf *ParsedGoFile, fh file.Handle
}

for _, diag := range fillswitch.Diagnose(inspect, start, end, pkg.GetTypes(), pkg.GetTypesInfo()) {
rng, err := pgf.Mapper.PosRange(pgf.Tok, diag.Pos, diag.End)
edits, err := suggestedFixToEdits(ctx, snapshot, pkg.FileSet(), &diag.SuggestedFixes[0])
if err != nil {
return nil, err
}
for _, fix := range diag.SuggestedFixes {
cmd, err := command.NewApplyFixCommand(fix.Message, command.ApplyFixArgs{
Fix: diag.Category,
URI: pgf.URI,
Range: rng,
ResolveEdits: supportsResolveEdits(options),

changes := []protocol.DocumentChanges{} // must be a slice
for _, edit := range edits {
edit := edit
changes = append(changes, protocol.DocumentChanges{
TextDocumentEdit: &edit,
})
if err != nil {
return nil, err
}
commands = append(commands, cmd)
}
}

actions = append(actions, protocol.CodeAction{
Title: diag.Message,
Kind: protocol.RefactorRewrite,
Edit: &protocol.WorkspaceEdit{
DocumentChanges: changes,
},
})
}
for i := range commands {
actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorRewrite, &commands[i], nil, options))
}
Expand Down
2 changes: 0 additions & 2 deletions gopls/internal/golang/fix.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/gopls/internal/analysis/embeddirective"
"golang.org/x/tools/gopls/internal/analysis/fillstruct"
"golang.org/x/tools/gopls/internal/analysis/fillswitch"
"golang.org/x/tools/gopls/internal/analysis/stubmethods"
"golang.org/x/tools/gopls/internal/analysis/undeclaredname"
"golang.org/x/tools/gopls/internal/analysis/unusedparams"
Expand Down Expand Up @@ -108,7 +107,6 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file
fillstruct.FixCategory: singleFile(fillstruct.SuggestedFix),
stubmethods.FixCategory: stubMethodsFixer,
undeclaredname.FixCategory: singleFile(undeclaredname.SuggestedFix),
fillswitch.FixCategory: fillswitch.SuggestedFix,

// Ad-hoc fixers: these are used when the command is
// constructed directly by logic in server/code_action.
Expand Down

0 comments on commit c8fda66

Please sign in to comment.