diff --git a/gopls/internal/analysis/fillswitch/doc.go b/gopls/internal/analysis/fillswitch/doc.go new file mode 100644 index 00000000000..076c3a1323d --- /dev/null +++ b/gopls/internal/analysis/fillswitch/doc.go @@ -0,0 +1,66 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fillswitch identifies switches with missing cases. +// +// It reports a diagnostic for each type switch or 'enum' switch that +// has missing cases, and suggests a fix to fill them in. +// +// The possible cases are: for a type switch, each accessible named +// type T or pointer *T that is assignable to the interface type; and +// for an 'enum' switch, each accessible named constant of the same +// type as the switch value. +// +// For an 'enum' switch, it will suggest cases for all possible values of the +// type. +// +// type Suit int8 +// const ( +// Spades Suit = iota +// Hearts +// Diamonds +// Clubs +// ) +// +// var s Suit +// switch s { +// case Spades: +// } +// +// It will report a diagnostic with a suggested fix to fill in the remaining +// cases: +// +// var s Suit +// switch s { +// case Spades: +// case Hearts: +// case Diamonds: +// case Clubs: +// default: +// panic(fmt.Sprintf("unexpected Suit: %v", s)) +// } +// +// For a type switch, it will suggest cases for all types that implement the +// interface. +// +// var stmt ast.Stmt +// switch stmt.(type) { +// case *ast.IfStmt: +// } +// +// It will report a diagnostic with a suggested fix to fill in the remaining +// cases: +// +// var stmt ast.Stmt +// switch stmt.(type) { +// case *ast.IfStmt: +// case *ast.ForStmt: +// case *ast.RangeStmt: +// case *ast.AssignStmt: +// case *ast.GoStmt: +// ... +// default: +// panic(fmt.Sprintf("unexpected ast.Stmt: %T", stmt)) +// } +package fillswitch diff --git a/gopls/internal/analysis/fillswitch/fillswitch.go b/gopls/internal/analysis/fillswitch/fillswitch.go new file mode 100644 index 00000000000..b93ade01065 --- /dev/null +++ b/gopls/internal/analysis/fillswitch/fillswitch.go @@ -0,0 +1,301 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch + +import ( + "bytes" + "fmt" + "go/ast" + "go/token" + "go/types" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/inspector" +) + +// Diagnose computes diagnostics for switch statements with missing cases +// overlapping with the provided start and end position. +// +// If either start or end is invalid, the entire package is inspected. +func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Package, info *types.Info) []analysis.Diagnostic { + var diags []analysis.Diagnostic + nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)} + inspect.Preorder(nodeFilter, func(n ast.Node) { + if start.IsValid() && n.End() < start || + end.IsValid() && n.Pos() > end { + return // non-overlapping + } + + var fix *analysis.SuggestedFix + switch n := n.(type) { + case *ast.SwitchStmt: + fix = suggestedFixSwitch(n, pkg, info) + case *ast.TypeSwitchStmt: + fix = suggestedFixTypeSwitch(n, pkg, info) + } + + if fix == nil { + return + } + + diags = append(diags, analysis.Diagnostic{ + Message: fix.Message, + Pos: n.Pos(), + End: n.Pos() + token.Pos(len("switch")), + SuggestedFixes: []analysis.SuggestedFix{*fix}, + }) + }) + + return diags +} + +func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix { + if hasDefaultCase(stmt.Body) { + return nil + } + + namedType := namedTypeFromTypeSwitch(stmt, info) + if namedType == nil { + return nil + } + + existingCases := caseTypes(stmt.Body, info) + // Gather accessible package-level concrete types + // that implement the switch interface type. + scope := namedType.Obj().Pkg().Scope() + var buf bytes.Buffer + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if tname, ok := obj.(*types.TypeName); !ok || tname.IsAlias() { + continue // not a defined type + } + + if types.IsInterface(obj.Type()) { + continue + } + + samePkg := obj.Pkg() == pkg + if !samePkg && !obj.Exported() { + continue // inaccessible + } + + var key caseType + if types.AssignableTo(obj.Type(), namedType.Obj().Type()) { + key.named = obj.Type().(*types.Named) + } else if ptr := types.NewPointer(obj.Type()); types.AssignableTo(ptr, namedType.Obj().Type()) { + key.named = obj.Type().(*types.Named) + key.ptr = true + } + + if key.named != nil { + if existingCases[key] { + continue + } + + if buf.Len() > 0 { + buf.WriteString("\t") + } + + buf.WriteString("case ") + if key.ptr { + buf.WriteByte('*') + } + + if p := key.named.Obj().Pkg(); p != pkg { + // TODO: use the correct package name when the import is renamed + buf.WriteString(p.Name()) + buf.WriteByte('.') + } + buf.WriteString(key.named.Obj().Name()) + buf.WriteString(":\n") + } + } + + if buf.Len() == 0 { + return nil + } + + switch assign := stmt.Assign.(type) { + case *ast.AssignStmt: + addDefaultCase(&buf, namedType, assign.Lhs[0]) + case *ast.ExprStmt: + if assert, ok := assign.X.(*ast.TypeAssertExpr); ok { + addDefaultCase(&buf, namedType, assert.X) + } + } + + return &analysis.SuggestedFix{ + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), + TextEdits: []analysis.TextEdit{{ + Pos: stmt.End() - token.Pos(len("}")), + End: stmt.End() - token.Pos(len("}")), + NewText: buf.Bytes(), + }}, + } +} + +func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.Info) *analysis.SuggestedFix { + if hasDefaultCase(stmt.Body) { + return nil + } + + namedType, ok := info.TypeOf(stmt.Tag).(*types.Named) + if !ok { + return nil + } + + existingCases := caseConsts(stmt.Body, info) + // Gather accessible named constants of the same type as the switch value. + scope := namedType.Obj().Pkg().Scope() + var buf bytes.Buffer + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if c, ok := obj.(*types.Const); ok && + (obj.Pkg() == pkg || obj.Exported()) && // accessible + types.Identical(obj.Type(), namedType.Obj().Type()) && + !existingCases[c] { + + if buf.Len() > 0 { + buf.WriteString("\t") + } + + buf.WriteString("case ") + if c.Pkg() != pkg { + buf.WriteString(c.Pkg().Name()) + buf.WriteByte('.') + } + buf.WriteString(c.Name()) + buf.WriteString(":\n") + } + } + + if buf.Len() == 0 { + return nil + } + + addDefaultCase(&buf, namedType, stmt.Tag) + + return &analysis.SuggestedFix{ + Message: fmt.Sprintf("Add cases for %s", namedType.Obj().Name()), + TextEdits: []analysis.TextEdit{{ + Pos: stmt.End() - token.Pos(len("}")), + End: stmt.End() - token.Pos(len("}")), + NewText: buf.Bytes(), + }}, + } +} + +func addDefaultCase(buf *bytes.Buffer, named *types.Named, expr ast.Expr) { + var dottedBuf bytes.Buffer + // writeDotted emits a dotted path a.b.c. + var writeDotted func(e ast.Expr) bool + writeDotted = func(e ast.Expr) bool { + switch e := e.(type) { + case *ast.SelectorExpr: + if !writeDotted(e.X) { + return false + } + dottedBuf.WriteByte('.') + dottedBuf.WriteString(e.Sel.Name) + return true + case *ast.Ident: + dottedBuf.WriteString(e.Name) + return true + } + return false + } + + buf.WriteString("\tdefault:\n") + typeName := fmt.Sprintf("%s.%s", named.Obj().Pkg().Name(), named.Obj().Name()) + if writeDotted(expr) { + // Switch tag expression is a dotted path. + // It is safe to re-evaluate it in the default case. + format := fmt.Sprintf("unexpected %s: %%#v", typeName) + fmt.Fprintf(buf, "\t\tpanic(fmt.Sprintf(%q, %s))\n\t", format, dottedBuf.String()) + } else { + // Emit simpler message, without re-evaluating tag expression. + fmt.Fprintf(buf, "\t\tpanic(%q)\n\t", "unexpected "+typeName) + } +} + +func namedTypeFromTypeSwitch(stmt *ast.TypeSwitchStmt, info *types.Info) *types.Named { + switch assign := stmt.Assign.(type) { + case *ast.ExprStmt: + if typ, ok := assign.X.(*ast.TypeAssertExpr); ok { + if named, ok := info.TypeOf(typ.X).(*types.Named); ok { + return named + } + } + + case *ast.AssignStmt: + if typ, ok := assign.Rhs[0].(*ast.TypeAssertExpr); ok { + if named, ok := info.TypeOf(typ.X).(*types.Named); ok { + return named + } + } + } + + return nil +} + +func hasDefaultCase(body *ast.BlockStmt) bool { + for _, clause := range body.List { + if len(clause.(*ast.CaseClause).List) == 0 { + return true + } + } + + return false +} + +func caseConsts(body *ast.BlockStmt, info *types.Info) map[*types.Const]bool { + out := map[*types.Const]bool{} + for _, stmt := range body.List { + for _, e := range stmt.(*ast.CaseClause).List { + if info.Types[e].Value == nil { + continue // not a constant + } + + if sel, ok := e.(*ast.SelectorExpr); ok { + e = sel.Sel // replace pkg.C with C + } + + if e, ok := e.(*ast.Ident); ok { + if c, ok := info.Uses[e].(*types.Const); ok { + out[c] = true + } + } + } + } + + return out +} + +type caseType struct { + named *types.Named + ptr bool +} + +func caseTypes(body *ast.BlockStmt, info *types.Info) map[caseType]bool { + out := map[caseType]bool{} + for _, stmt := range body.List { + for _, e := range stmt.(*ast.CaseClause).List { + if tv, ok := info.Types[e]; ok && tv.IsType() { + t := tv.Type + ptr := false + if p, ok := t.(*types.Pointer); ok { + t = p.Elem() + ptr = true + } + + if named, ok := t.(*types.Named); ok { + out[caseType{named, ptr}] = true + } + } + } + } + + return out +} diff --git a/gopls/internal/analysis/fillswitch/fillswitch_test.go b/gopls/internal/analysis/fillswitch/fillswitch_test.go new file mode 100644 index 00000000000..15d3ef1dd70 --- /dev/null +++ b/gopls/internal/analysis/fillswitch/fillswitch_test.go @@ -0,0 +1,38 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch_test + +import ( + "go/token" + "testing" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/analysistest" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/gopls/internal/analysis/fillswitch" +) + +// analyzer allows us to test the fillswitch code action using the analysistest +// harness. +var analyzer = &analysis.Analyzer{ + Name: "fillswitch", + Doc: "test only", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: func(pass *analysis.Pass) (any, error) { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + for _, d := range fillswitch.Diagnose(inspect, token.NoPos, token.NoPos, pass.Pkg, pass.TypesInfo) { + pass.Report(d) + } + return nil, nil + }, + URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/fillswitch", + RunDespiteErrors: true, +} + +func Test(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, analyzer, "a") +} diff --git a/gopls/internal/analysis/fillswitch/testdata/src/a/a.go b/gopls/internal/analysis/fillswitch/testdata/src/a/a.go new file mode 100644 index 00000000000..06d01da5f1e --- /dev/null +++ b/gopls/internal/analysis/fillswitch/testdata/src/a/a.go @@ -0,0 +1,78 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch + +import ( + data "b" +) + +type typeA int + +const ( + typeAOne typeA = iota + typeATwo + typeAThree +) + +func doSwitch() { + var a typeA + switch a { // want `Add cases for typeA` + } + + switch a { // want `Add cases for typeA` + case typeAOne: + } + + switch a { + case typeAOne: + default: + } + + switch a { + case typeAOne: + case typeATwo: + case typeAThree: + } + + var b data.TypeB + switch b { // want `Add cases for TypeB` + case data.TypeBOne: + } +} + +type notification interface { + isNotification() +} + +type notificationOne struct{} + +func (notificationOne) isNotification() {} + +type notificationTwo struct{} + +func (notificationTwo) isNotification() {} + +func doTypeSwitch() { + var not notification + switch not.(type) { // want `Add cases for notification` + } + + switch not.(type) { // want `Add cases for notification` + case notificationOne: + } + + switch not.(type) { + case notificationOne: + case notificationTwo: + } + + switch not.(type) { + default: + } + + var t data.ExportedInterface + switch t { + } +} diff --git a/gopls/internal/analysis/fillswitch/testdata/src/b/b.go b/gopls/internal/analysis/fillswitch/testdata/src/b/b.go new file mode 100644 index 00000000000..f65f3a7e6f2 --- /dev/null +++ b/gopls/internal/analysis/fillswitch/testdata/src/b/b.go @@ -0,0 +1,21 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fillswitch + +type TypeB int + +const ( + TypeBOne TypeB = iota + TypeBTwo + TypeBThree +) + +type ExportedInterface interface { + isExportedInterface() +} + +type notExportedType struct{} + +func (notExportedType) isExportedInterface() {} diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index cab1b42f4b7..fa876ac474c 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -13,6 +13,7 @@ import ( "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/gopls/internal/analysis/fillstruct" + "golang.org/x/tools/gopls/internal/analysis/fillswitch" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/parsego" "golang.org/x/tools/gopls/internal/file" @@ -98,7 +99,7 @@ func CodeActions(ctx context.Context, snapshot *cache.Snapshot, fh file.Handle, return nil, err } if want[protocol.RefactorRewrite] { - rewrites, err := getRewriteCodeActions(pkg, pgf, fh, rng, snapshot.Options()) + rewrites, err := getRewriteCodeActions(ctx, pkg, snapshot, pgf, fh, rng, snapshot.Options()) if err != nil { return nil, err } @@ -252,8 +253,7 @@ func newCodeAction(title string, kind protocol.CodeActionKind, cmd *protocol.Com return action } -// getRewriteCodeActions returns refactor.rewrite code actions available at the specified range. -func getRewriteCodeActions(pkg *cache.Package, pgf *parsego.File, fh file.Handle, rng protocol.Range, options *settings.Options) (_ []protocol.CodeAction, rerr error) { +func getRewriteCodeActions(ctx context.Context, pkg *cache.Package, snapshot *cache.Snapshot, pgf *parsego.File, fh file.Handle, rng protocol.Range, options *settings.Options) (_ []protocol.CodeAction, rerr error) { // golang/go#61693: code actions were refactored to run outside of the // analysis framework, but as a result they lost their panic recovery. // @@ -354,6 +354,28 @@ func getRewriteCodeActions(pkg *cache.Package, pgf *parsego.File, fh file.Handle } } + for _, diag := range fillswitch.Diagnose(inspect, start, end, pkg.GetTypes(), pkg.GetTypesInfo()) { + edits, err := suggestedFixToEdits(ctx, snapshot, pkg.FileSet(), &diag.SuggestedFixes[0]) + if err != nil { + return nil, err + } + + changes := []protocol.DocumentChanges{} // must be a slice + for _, edit := range edits { + edit := edit + changes = append(changes, protocol.DocumentChanges{ + TextDocumentEdit: &edit, + }) + } + + actions = append(actions, protocol.CodeAction{ + Title: diag.Message, + Kind: protocol.RefactorRewrite, + Edit: &protocol.WorkspaceEdit{ + DocumentChanges: changes, + }, + }) + } for i := range commands { actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorRewrite, &commands[i], nil, options)) } diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt b/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt new file mode 100644 index 00000000000..2c1b19e130c --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/fill_switch.txt @@ -0,0 +1,105 @@ +This test checks the behavior of the 'fill switch' code action. +See fill_switch_resolve.txt for same test with resolve support. + +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module golang.org/lsptests/fillswitch + +go 1.18 + +-- data/data.go -- +package data + +type TypeB int + +const ( + TypeBOne TypeB = iota + TypeBTwo + TypeBThree +) + +-- a.go -- +package fillswitch + +import ( + "golang.org/lsptests/fillswitch/data" +) + +type typeA int + +const ( + typeAOne typeA = iota + typeATwo + typeAThree +) + +type notification interface { + isNotification() +} + +type notificationOne struct{} + +func (notificationOne) isNotification() {} + +type notificationTwo struct{} + +func (notificationTwo) isNotification() {} + +func doSwitch() { + var b data.TypeB + switch b { + case data.TypeBOne: //@codeactionedit(":", "refactor.rewrite", a1) + } + + var a typeA + switch a { + case typeAThree: //@codeactionedit(":", "refactor.rewrite", a2) + } + + var n notification + switch n.(type) { //@codeactionedit("{", "refactor.rewrite", a3) + } + + switch nt := n.(type) { //@codeactionedit("{", "refactor.rewrite", a4) + } + + var s struct { + a typeA + } + + switch s.a { + case typeAThree: //@codeactionedit(":", "refactor.rewrite", a5) + } +} +-- @a1/a.go -- +@@ -31 +31,4 @@ ++ case data.TypeBThree: ++ case data.TypeBTwo: ++ default: ++ panic(fmt.Sprintf("unexpected data.TypeB: %#v", b)) +-- @a2/a.go -- +@@ -36 +36,4 @@ ++ case typeAOne: ++ case typeATwo: ++ default: ++ panic(fmt.Sprintf("unexpected fillswitch.typeA: %#v", a)) +-- @a3/a.go -- +@@ -40 +40,4 @@ ++ case notificationOne: ++ case notificationTwo: ++ default: ++ panic(fmt.Sprintf("unexpected fillswitch.notification: %#v", n)) +-- @a4/a.go -- +@@ -43 +43,4 @@ ++ case notificationOne: ++ case notificationTwo: ++ default: ++ panic(fmt.Sprintf("unexpected fillswitch.notification: %#v", nt)) +-- @a5/a.go -- +@@ -51 +51,4 @@ ++ case typeAOne: ++ case typeATwo: ++ default: ++ panic(fmt.Sprintf("unexpected fillswitch.typeA: %#v", s.a)) diff --git a/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt new file mode 100644 index 00000000000..504acd6043e --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/fill_switch_resolve.txt @@ -0,0 +1,116 @@ +This test checks the behavior of the 'fill switch' code action, with resolve support. +See fill_switch.txt for same test without resolve support. + +-- capabilities.json -- +{ + "textDocument": { + "codeAction": { + "dataSupport": true, + "resolveSupport": { + "properties": ["edit"] + } + } + } +} +-- flags -- +-ignore_extra_diags + +-- go.mod -- +module golang.org/lsptests/fillswitch + +go 1.18 + +-- data/data.go -- +package data + +type TypeB int + +const ( + TypeBOne TypeB = iota + TypeBTwo + TypeBThree +) + +-- a.go -- +package fillswitch + +import ( + "golang.org/lsptests/fillswitch/data" +) + +type typeA int + +const ( + typeAOne typeA = iota + typeATwo + typeAThree +) + +type notification interface { + isNotification() +} + +type notificationOne struct{} + +func (notificationOne) isNotification() {} + +type notificationTwo struct{} + +func (notificationTwo) isNotification() {} + +func doSwitch() { + var b data.TypeB + switch b { + case data.TypeBOne: //@codeactionedit(":", "refactor.rewrite", a1) + } + + var a typeA + switch a { + case typeAThree: //@codeactionedit(":", "refactor.rewrite", a2) + } + + var n notification + switch n.(type) { //@codeactionedit("{", "refactor.rewrite", a3) + } + + switch nt := n.(type) { //@codeactionedit("{", "refactor.rewrite", a4) + } + + var s struct { + a typeA + } + + switch s.a { + case typeAThree: //@codeactionedit(":", "refactor.rewrite", a5) + } +} +-- @a1/a.go -- +@@ -31 +31,4 @@ ++ case data.TypeBThree: ++ case data.TypeBTwo: ++ default: ++ panic(fmt.Sprintf("unexpected data.TypeB: %#v", b)) +-- @a2/a.go -- +@@ -36 +36,4 @@ ++ case typeAOne: ++ case typeATwo: ++ default: ++ panic(fmt.Sprintf("unexpected fillswitch.typeA: %#v", a)) +-- @a3/a.go -- +@@ -40 +40,4 @@ ++ case notificationOne: ++ case notificationTwo: ++ default: ++ panic(fmt.Sprintf("unexpected fillswitch.notification: %#v", n)) +-- @a4/a.go -- +@@ -43 +43,4 @@ ++ case notificationOne: ++ case notificationTwo: ++ default: ++ panic(fmt.Sprintf("unexpected fillswitch.notification: %#v", nt)) +-- @a5/a.go -- +@@ -51 +51,4 @@ ++ case typeAOne: ++ case typeATwo: ++ default: ++ panic(fmt.Sprintf("unexpected fillswitch.typeA: %#v", s.a))