Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
martskins committed Feb 7, 2024
1 parent 2975833 commit 3fcf757
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 69 deletions.
126 changes: 58 additions & 68 deletions gopls/internal/analysis/fillswitch/fillswitch.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package fillswitch defines an Analyzer that automatically
// fills the missing cases in type switches or switches over named types.
// Package fillswitch provides diagnostics and fixes to fills the missing cases
// in type switches or switches over named types.
//
// The analyzer's diagnostic is merely a prompt.
// The actual fix is created by a separate direct call from gopls to
Expand Down Expand Up @@ -32,12 +32,6 @@ import (

const FixCategory = "fillswitch" // recognized by gopls ApplyFix

// errNoSuggestedFix is returned when no suggested fix is available. This could
// be because all cases are already covered, or (in the case of a type switch)
// because the remaining cases are for types not accessible by the current
// package.
var errNoSuggestedFix = errors.New("no suggested fix")

// Diagnose computes diagnostics for switch statements with missing cases
// overlapping with the provided start and end position.
//
Expand All @@ -49,8 +43,10 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac
var diags []analysis.Diagnostic
nodeFilter := []ast.Node{(*ast.SwitchStmt)(nil), (*ast.TypeSwitchStmt)(nil)}
inspect.Preorder(nodeFilter, func(n ast.Node) {
if expr, ok := n.(*ast.SwitchStmt); ok {
if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) {
switch expr := n.(type) {
case *ast.SwitchStmt:
if start.IsValid() && expr.End() < start ||
end.IsValid() && expr.Pos() > end {
return // non-overlapping
}

Expand All @@ -63,7 +59,7 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac
return
}

if _, err := suggestedFixSwitch(expr, pkg, info); err != nil {
if fix, err := suggestedFixSwitch(expr, pkg, info); err != nil || fix == nil {
return
}

Expand All @@ -77,10 +73,9 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac
// No TextEdits => computed later by gopls.
}},
})
}

if expr, ok := n.(*ast.TypeSwitchStmt); ok {
if (start.IsValid() && expr.End() < start) || (end.IsValid() && expr.Pos() > end) {
case *ast.TypeSwitchStmt:
if start.IsValid() && expr.End() < start ||
end.IsValid() && expr.Pos() > end {
return // non-overlapping
}

Expand All @@ -93,7 +88,7 @@ func Diagnose(inspect *inspector.Inspector, start, end token.Pos, pkg *types.Pac
return
}

if _, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil {
if fix, err := suggestedFixTypeSwitch(expr, pkg, info); err != nil || fix == nil {
return
}

Expand All @@ -120,43 +115,40 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *
}

scope := namedType.Obj().Pkg().Scope()
variants := make([]string, 0)
var variants []string
for _, name := range scope.Names() {
obj := scope.Lookup(name)
if _, ok := obj.(*types.TypeName); !ok {
continue
continue // not a type
}

if types.Identical(obj.Type(), namedType.Obj().Type()) {
continue
}

if types.AssignableTo(obj.Type(), namedType.Obj().Type()) {
if obj.Pkg().Name() != pkg.Name() {
if !obj.Exported() {
continue
}
if types.IsInterface(obj.Type()) {
continue
}

variants = append(variants, obj.Pkg().Name()+"."+obj.Name())
} else {
variants = append(variants, obj.Name())
name := obj.Name()
samePkg := obj.Pkg() == pkg
if !samePkg {
if !obj.Exported() {
continue // inaccessible
}
} else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) {
if obj.Pkg().Name() != pkg.Name() {
if !obj.Exported() {
continue
}
name = obj.Pkg().Name() + name
}

variants = append(variants, "*"+obj.Pkg().Name()+"."+obj.Name())
} else {
variants = append(variants, "*"+obj.Name())
}
if types.AssignableTo(obj.Type(), namedType.Obj().Type()) {
variants = append(variants, name)
} else if types.AssignableTo(types.NewPointer(obj.Type()), namedType.Obj().Type()) {
variants = append(variants, "*"+name)
}
}

handledVariants := getHandledVariants(stmt.Body)
handledVariants := caseTypes(stmt.Body, info)
if len(variants) == 0 || len(variants) == len(handledVariants) {
return nil, errNoSuggestedFix
return nil, nil
}

newText := buildNewText(variants, handledVariants)
Expand All @@ -165,7 +157,7 @@ func suggestedFixTypeSwitch(stmt *ast.TypeSwitchStmt, pkg *types.Package, info *
TextEdits: []analysis.TextEdit{{
Pos: stmt.End() - 1,
End: stmt.End() - 1,
NewText: indent([]byte(newText), []byte{'\t'}),
NewText: bytes.ReplaceAll([]byte(newText), []byte("\n"), []byte("\n\t")),
}},
}, nil
}
Expand Down Expand Up @@ -198,9 +190,9 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In
}
}

handledVariants := getHandledVariants(stmt.Body)
handledVariants := caseTypes(stmt.Body, info)
if len(variants) == 0 || len(variants) == len(handledVariants) {
return nil, errNoSuggestedFix
return nil, nil
}

newText := buildNewText(variants, handledVariants)
Expand All @@ -209,7 +201,7 @@ func suggestedFixSwitch(stmt *ast.SwitchStmt, pkg *types.Package, info *types.In
TextEdits: []analysis.TextEdit{{
Pos: stmt.End() - 1,
End: stmt.End() - 1,
NewText: indent([]byte(newText), []byte{'\t'}),
NewText: bytes.ReplaceAll([]byte(newText), []byte("\n"), []byte("\n\t")),
}},
}, nil
}
Expand Down Expand Up @@ -288,20 +280,36 @@ func buildNewText(variants []string, handledVariants []string) string {
return textBuilder.String()
}

func getHandledVariants(body *ast.BlockStmt) []string {
out := make([]string, 0)
for _, bl := range body.List {
for _, c := range bl.(*ast.CaseClause).List {
switch v := c.(type) {
func caseTypes(body *ast.BlockStmt, info *types.Info) []string {
var out []string
for _, stmt := range body.List {
for _, e := range stmt.(*ast.CaseClause).List {
switch e := e.(type) {
case *ast.Ident:
out = append(out, v.Name)
out = append(out, e.Name)
case *ast.SelectorExpr:
out = append(out, v.X.(*ast.Ident).Name+"."+v.Sel.Name)
if _, ok := e.X.(*ast.Ident); !ok {
continue
}

out = append(out, e.X.(*ast.Ident).Name+"."+e.Sel.Name)
case *ast.StarExpr:
switch v := v.X.(type) {
switch v := e.X.(type) {
case *ast.Ident:
if !info.Types[v].IsType() {
continue
}

out = append(out, "*"+v.Name)
case *ast.SelectorExpr:
if !info.Types[v].IsType() {
continue
}

if _, ok := e.X.(*ast.Ident); !ok {
continue
}

out = append(out, "*"+v.X.(*ast.Ident).Name+"."+v.Sel.Name)
}
}
Expand Down Expand Up @@ -339,21 +347,3 @@ func SuggestedFix(ctx context.Context, snapshot *cache.Snapshot, pkg *cache.Pack
return nil, nil, fmt.Errorf("no switch statement found")
}
}

// indent works line by line through str, prefixing each line with
// prefix.
func indent(str, prefix []byte) []byte {
split := bytes.Split(str, []byte("\n"))
newText := bytes.NewBuffer(nil)
for i, s := range split {
if i != 0 {
newText.Write(prefix)
}

newText.Write(s)
if i < len(split)-1 {
newText.WriteByte('\n')
}
}
return newText.Bytes()
}
2 changes: 1 addition & 1 deletion gopls/internal/analysis/fillswitch/fillswitch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
)

// analyzer allows us to test the fillswitch code action using the analysistest
// harness. (fillswitch used to be a gopls analyzer.)
// harness.
var analyzer = &analysis.Analyzer{
Name: "fillswitch",
Doc: "test only",
Expand Down

0 comments on commit 3fcf757

Please sign in to comment.