Skip to content

Commit

Permalink
Merge pull request #137 from louissobel/add-receiver-checking-against…
Browse files Browse the repository at this point in the history
…-exclude

Check all levels of embedded interfaces against the exclude list.
  • Loading branch information
kisielk committed May 22, 2018
2 parents e5700e7 + cd18fe7 commit 55d8f50
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 10 deletions.
144 changes: 144 additions & 0 deletions internal/errcheck/embedded_walker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package errcheck

import (
"fmt"
"go/types"
)

// walkThroughEmbeddedInterfaces returns a slice of Interfaces that
// we need to walk through in order to reach the actual definition,
// in an Interface, of the method selected by the given selection.
//
// false will be returned in the second return value if:
// - the right side of the selection is not a function
// - the actual definition of the function is not in an Interface
//
// The returned slice will contain all the interface types that need
// to be walked through to reach the actual definition.
//
// For example, say we have:
//
// type Inner interface {Method()}
// type Middle interface {Inner}
// type Outer interface {Middle}
// type T struct {Outer}
// type U struct {T}
// type V struct {U}
//
// And then the selector:
//
// V.Method
//
// We'll return [Outer, Middle, Inner] by first walking through the embedded structs
// until we reach the Outer interface, then descending through the embedded interfaces
// until we find the one that actually explicitly defines Method.
func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) {
fn, ok := sel.Obj().(*types.Func)
if !ok {
return nil, false
}

// Start off at the receiver.
currentT := sel.Recv()

// First, we can walk through any Struct fields provided
// by the selection Index() method. We ignore the last
// index because it would give the method itself.
indexes := sel.Index()
for _, fieldIndex := range indexes[:len(indexes)-1] {
currentT = getTypeAtFieldIndex(currentT, fieldIndex)
}

// Now currentT is either a type implementing the actual function,
// an Invalid type (if the receiver is a package), or an interface.
//
// If it's not an Interface, then we're done, as this function
// only cares about Interface-defined functions.
//
// If it is an Interface, we potentially need to continue digging until
// we find the Interface that actually explicitly defines the function.
interfaceT, ok := maybeUnname(currentT).(*types.Interface)
if !ok {
return nil, false
}

// The first interface we pass through is this one we've found. We return the possibly
// wrapping types.Named because it is more useful to work with for callers.
result := []types.Type{currentT}

// If this interface itself explicitly defines the given method
// then we're done digging.
for !explicitlyDefinesMethod(interfaceT, fn) {
// Otherwise, we find which of the embedded interfaces _does_
// define the method, add it to our list, and loop.
namedInterfaceT, ok := getEmbeddedInterfaceDefiningMethod(interfaceT, fn)
if !ok {
// This should be impossible as long as we type-checked: either the
// interface or one of its embedded ones must implement the method...
panic(fmt.Sprintf("either %v or one of its embedded interfaces must implement %v", currentT, fn))
}
result = append(result, namedInterfaceT)
interfaceT = namedInterfaceT.Underlying().(*types.Interface)
}

return result, true
}

func getTypeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type {
t := maybeUnname(maybeDereference(startingAt))
s, ok := t.(*types.Struct)
if !ok {
panic(fmt.Sprintf("cannot get Field of a type that is not a struct, got a %T", t))
}

return s.Field(fieldIndex).Type()
}

// getEmbeddedInterfaceDefiningMethod searches through any embedded interfaces of the
// passed interface searching for one that defines the given function. If found, the
// types.Named wrapping that interface will be returned along with true in the second value.
//
// If no such embedded interface is found, nil and false are returned.
func getEmbeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) {
for i := 0; i < interfaceT.NumEmbeddeds(); i++ {
embedded := interfaceT.Embedded(i)
if definesMethod(embedded.Underlying().(*types.Interface), fn) {
return embedded, true
}
}
return nil, false
}

func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool {
for i := 0; i < interfaceT.NumExplicitMethods(); i++ {
if interfaceT.ExplicitMethod(i) == fn {
return true
}
}
return false
}

func definesMethod(interfaceT *types.Interface, fn *types.Func) bool {
for i := 0; i < interfaceT.NumMethods(); i++ {
if interfaceT.Method(i) == fn {
return true
}
}
return false
}

func maybeDereference(t types.Type) types.Type {
p, ok := t.(*types.Pointer)
if ok {
return p.Elem()
}
return t
}

func maybeUnname(t types.Type) types.Type {
n, ok := t.(*types.Named)
if ok {
return n.Underlying()
}
return t
}
93 changes: 93 additions & 0 deletions internal/errcheck/embedded_walker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package errcheck

import (
"go/ast"
"go/parser"
"go/token"
"go/types"
"testing"
)

const commonSrc = `
package p
type Inner struct {}
func (Inner) Method()
type Outer struct {Inner}
type OuterP struct {*Inner}
type InnerInterface interface {
Method()
}
type OuterInterface interface {InnerInterface}
type MiddleInterfaceStruct struct {OuterInterface}
type OuterInterfaceStruct struct {MiddleInterfaceStruct}
var c = `

type testCase struct {
selector string
expectedOk bool
expected []string
}

func TestWalkThroughEmbeddedInterfaces(t *testing.T) {
cases := []testCase{
testCase{"Inner{}.Method", false, nil},
testCase{"(&Inner{}).Method", false, nil},
testCase{"Outer{}.Method", false, nil},
testCase{"InnerInterface.Method", true, []string{"test.InnerInterface"}},
testCase{"OuterInterface.Method", true, []string{"test.OuterInterface", "test.InnerInterface"}},
testCase{"OuterInterfaceStruct.Method", true, []string{"test.OuterInterface", "test.InnerInterface"}},
}

for _, c := range cases {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "test", commonSrc+c.selector, 0)
if err != nil {
t.Fatal(err)
}

conf := types.Config{}
info := types.Info{
Selections: make(map[*ast.SelectorExpr]*types.Selection),
}
_, err = conf.Check("test", fset, []*ast.File{f}, &info)
if err != nil {
t.Fatal(err)
}
ast.Inspect(f, func(n ast.Node) bool {
s, ok := n.(*ast.SelectorExpr)
if ok {
selection, ok := info.Selections[s]
if !ok {
t.Fatalf("no Selection!")
}
ts, ok := walkThroughEmbeddedInterfaces(selection)
if ok != c.expectedOk {
t.Errorf("expected ok %v got %v", c.expectedOk, ok)
return false
}
if !ok {
return false
}

if len(ts) != len(c.expected) {
t.Fatalf("expected %d types, got %d", len(c.expected), len(ts))
}

for i, e := range c.expected {
if e != ts[i].String() {
t.Errorf("mismatch at index %d: expected %s got %s", i, e, ts[i])
}
}
}

return true
})

}

}
99 changes: 89 additions & 10 deletions internal/errcheck/errcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ func (c *Checker) SetExclude(l map[string]bool) {
"(*strings.Builder).WriteByte",
"(*strings.Builder).WriteRune",
"(*strings.Builder).WriteString",

// hash
"(hash.Hash).Write",
} {
c.exclude[exc] = true
}
Expand Down Expand Up @@ -236,29 +239,105 @@ type visitor struct {
errors []UncheckedError
}

func (v *visitor) fullName(call *ast.CallExpr) (string, bool) {
// selectorAndFunc tries to get the selector and function from call expression.
// For example, given the call expression representing "a.b()", the selector
// is "a.b" and the function is "b" itself.
//
// The final return value will be true if it is able to do extract a selector
// from the call and look up the function object it refers to.
//
// If the call does not include a selector (like if it is a plain "f()" function call)
// then the final return value will be false.
func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return "", false
return nil, nil, false
}

fn, ok := v.pkg.ObjectOf(sel.Sel).(*types.Func)
if !ok {
// Shouldn't happen, but be paranoid
return "", false
return nil, nil, false
}
// The name is fully qualified by the import path, possible type,
// function/method name and pointer receiver.
//

return sel, fn, true

}

// fullName will return a package / receiver-type qualified name for a called function
// if the function is the result of a selector. Otherwise it will return
// the empty string.
//
// The name is fully qualified by the import path, possible type,
// function/method name and pointer receiver.
//
// For example,
// - for "fmt.Printf(...)" it will return "fmt.Printf"
// - for "base64.StdEncoding.Decode(...)" it will return "(*encoding/base64.Encoding).Decode"
// - for "myFunc()" it will return ""
func (v *visitor) fullName(call *ast.CallExpr) string {
_, fn, ok := v.selectorAndFunc(call)
if !ok {
return ""
}

// TODO(dh): vendored packages will have /vendor/ in their name,
// thus not matching vendored standard library packages. If we
// want to support vendored stdlib packages, we need to implement
// FullName with our own logic.
return fn.FullName(), true
return fn.FullName()
}

// namesForExcludeCheck will return a list of fully-qualified function names
// from a function call that can be used to check against the exclusion list.
//
// If a function call is against a local function (like "myFunc()") then no
// names are returned. If the function is package-qualified (like "fmt.Printf()")
// then just that function's fullName is returned.
//
// Otherwise, we walk through all the potentially embeddded interfaces of the receiver
// the collect a list of type-qualified function names that we will check.
func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string {
sel, fn, ok := v.selectorAndFunc(call)
if !ok {
return nil
}

name := v.fullName(call)
if name == "" {
return nil
}

// This will be missing for functions without a receiver (like fmt.Printf),
// so just fall back to the the function's fullName in that case.
selection, ok := v.pkg.Selections[sel]
if !ok {
return []string{name}
}

// This will return with ok false if the function isn't defined
// on an interface, so just fall back to the fullName.
ts, ok := walkThroughEmbeddedInterfaces(selection)
if !ok {
return []string{name}
}

result := make([]string, len(ts))
for i, t := range ts {
// Like in fullName, vendored packages will have /vendor/ in their name,
// thus not matching vendored standard library packages. If we
// want to support vendored stdlib packages, we need to implement
// additional logic here.
result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name())
}
return result
}

func (v *visitor) excludeCall(call *ast.CallExpr) bool {
if name, ok := v.fullName(call); ok {
return v.exclude[name]
for _, name := range v.namesForExcludeCheck(call) {
if v.exclude[name] {
return true
}
}

return false
Expand Down Expand Up @@ -390,7 +469,7 @@ func (v *visitor) addErrorAtPosition(position token.Pos, call *ast.CallExpr) {

var name string
if call != nil {
name, _ = v.fullName(call)
name = v.fullName(call)
}

v.errors = append(v.errors, UncheckedError{pos, line, name})
Expand Down
3 changes: 3 additions & 0 deletions internal/errcheck/errcheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ func test(t *testing.T, f flags) {
checker := NewChecker()
checker.Asserts = asserts
checker.Blank = blank
checker.SetExclude(map[string]bool{
fmt.Sprintf("(%s.ErrorMakerInterface).MakeNilError", testPackage): true,
})
err := checker.CheckPackages(testPackage)
uerr, ok := err.(*UncheckedErrors)
if !ok {
Expand Down
Loading

0 comments on commit 55d8f50

Please sign in to comment.