Skip to content

Commit

Permalink
Add support for generating mocks for generic interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Cruickshank authored and LandonTClipp committed May 24, 2022
1 parent 68d25fe commit dc5539e
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 23 deletions.
28 changes: 24 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
module github.com/vektra/mockery/v2

go 1.16
go 1.18

require (
github.com/mitchellh/go-homedir v1.1.0
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.26.1
github.com/spf13/afero v1.8.0 // indirect
github.com/spf13/cobra v1.3.0
github.com/spf13/viper v1.10.1
github.com/stretchr/testify v1.7.0
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce
golang.org/x/tools v0.1.10
gopkg.in/yaml.v2 v2.4.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/magiconair/properties v1.8.5 // indirect
github.com/mitchellh/mapstructure v1.4.3 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/afero v1.8.0 // indirect
github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.1.1 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect
golang.org/x/tools v0.1.10
golang.org/x/text v0.3.7 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gopkg.in/ini.v1 v1.66.3 // indirect
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
10 changes: 0 additions & 10 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/hashicorp/consul/api v1.11.0/go.mod h1:XjsvQN+RJGWI2TWy1/kqaE16HrR2J/FWgkYjdZQsX9M=
github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0=
github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
Expand Down Expand Up @@ -317,7 +316,6 @@ github.com/rs/zerolog v1.26.1/go.mod h1:/wSSJWX7lVrsOwlbyTRSOJvqRlc+WjWlfes+CiJ+
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/sagikazarmark/crypt v0.3.0/go.mod h1:uD/D+6UF4SrIR1uGEv7bBNkNqLGqUr43MRiaGWX1Nig=
github.com/sagikazarmark/crypt v0.4.0/go.mod h1:ALv2SRj7GxYV4HO9elxH9nS6M9gW+xDNxqmyJ6RfDFM=
github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
Expand Down Expand Up @@ -356,7 +354,6 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
go.etcd.io/etcd/api/v3 v3.5.1/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs=
go.etcd.io/etcd/client/pkg/v3 v3.5.1/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g=
go.etcd.io/etcd/client/v2 v2.305.1/go.mod h1:pMEacxZW7o8pg4CrFE7pquyCJJzZvkvdD2RibOCCCGs=
Expand All @@ -382,7 +379,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211215165025-cf75a172585e/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce h1:Roh6XWxHFKrPgC/EQhVubSAGQ6Ozk6IdxHSzt1mR0EI=
Expand Down Expand Up @@ -468,8 +464,6 @@ golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy
golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
Expand Down Expand Up @@ -564,10 +558,8 @@ golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
Expand Down Expand Up @@ -681,7 +673,6 @@ google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdr
google.golang.org/api v0.59.0/go.mod h1:sT2boj7M9YJxZzgeZqXogmhfmRWDtPzT31xkieUbuZU=
google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I=
google.golang.org/api v0.62.0/go.mod h1:dKmwPCydfsad4qCH08MSdgWjfHOyfpd4VtDGgRFdavw=
google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
Expand Down Expand Up @@ -781,7 +772,6 @@ google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnD
google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
google.golang.org/grpc v1.40.1/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34=
google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU=
google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
Expand Down
21 changes: 21 additions & 0 deletions pkg/fixtures/generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package test

type Constraint interface {
int
}

type Generic[T Constraint] interface {
Get() T
}

type GenericAny[T any] interface {
Get() T
}

type GenericComparable[T comparable] interface {
Get() T
}

type Embedded interface {
Generic[int]
}
75 changes: 66 additions & 9 deletions pkg/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ func (g *Generator) populateImports(ctx context.Context) {

log.Debug().Msgf("populating imports")

// imports from generic type constraints
if tParams := g.iface.NamedType.TypeParams(); tParams != nil && tParams.Len() > 0 {
for i := 0; i < tParams.Len(); i++ {
g.renderType(ctx, tParams.At(i).Constraint())
}
}

// imports from type arguments
if tArgs := g.iface.NamedType.TypeArgs(); tArgs != nil && tArgs.Len() > 0 {
for i := 0; i < tArgs.Len(); i++ {
g.renderType(ctx, tArgs.At(i))
}
}

for _, method := range g.iface.Methods() {
ftype := method.Signature
g.addImportsFromTuple(ctx, ftype.Params())
Expand All @@ -88,6 +102,13 @@ func (g *Generator) addImportsFromTuple(ctx context.Context, list *types.Tuple)
}
}

func (g *Generator) addPackageScopedType(ctx context.Context, o *types.TypeName) string {
if o.Pkg() == nil || o.Pkg().Name() == "main" || (!g.KeepTree && g.InPackage && o.Pkg() == g.iface.Pkg) {
return o.Name()
}
return g.addPackageImport(ctx, o.Pkg()) + "." + o.Name()
}

func (g *Generator) addPackageImport(ctx context.Context, pkg *types.Package) string {
return g.addPackageImportWithName(ctx, pkg.Path(), pkg.Name())
}
Expand Down Expand Up @@ -231,6 +252,36 @@ func (g *Generator) mockName() string {
return g.maybeMakeNameExported(g.iface.Name, g.Exported)
}

func (g *Generator) typeConstraints(ctx context.Context) string {
tp := g.iface.NamedType.TypeParams()
if tp == nil || tp.Len() == 0 {
return ""
}
qualifiedParams := make([]string, 0, tp.Len())
for i := 0; i < tp.Len(); i++ {
param := tp.At(i)
switch constraint := param.Constraint().(type) {
case *types.Named:
qualifiedParams = append(qualifiedParams, fmt.Sprintf("%s %s", param.String(), g.addPackageScopedType(ctx, constraint.Obj())))
case *types.Interface:
qualifiedParams = append(qualifiedParams, fmt.Sprintf("%s %s", param.String(), constraint.String()))
}
}
return fmt.Sprintf("[%s]", strings.Join(qualifiedParams, ", "))
}

func (g *Generator) typeParams() string {
tp := g.iface.NamedType.TypeParams()
if tp == nil || tp.Len() == 0 {
return ""
}
params := make([]string, 0, tp.Len())
for i := 0; i < tp.Len(); i++ {
params = append(params, tp.At(i).String())
}
return fmt.Sprintf("[%s]", strings.Join(params, ", "))
}

func (g *Generator) expecterName() string {
return g.mockName() + "_Expecter"
}
Expand Down Expand Up @@ -335,11 +386,12 @@ type namer interface {
func (g *Generator) renderType(ctx context.Context, typ types.Type) string {
switch t := typ.(type) {
case *types.Named:
o := t.Obj()
if o.Pkg() == nil || o.Pkg().Name() == "main" || (!g.KeepTree && g.InPackage && o.Pkg() == g.iface.Pkg) {
return o.Name()
return g.addPackageScopedType(ctx, t.Obj())
case *types.TypeParam:
if t.Constraint() != nil {
return t.Obj().Name()
}
return g.addPackageImport(ctx, o.Pkg()) + "." + o.Name()
return g.addPackageScopedType(ctx, t.Obj())
case *types.Basic:
if t.Kind() == types.UnsafePointer {
return "unsafe.Pointer"
Expand Down Expand Up @@ -512,7 +564,7 @@ func (g *Generator) Generate(ctx context.Context) error {
)

g.printf(
"type %s struct {\n\tmock.Mock\n}\n\n", g.mockName(),
"type %s%s struct {\n\tmock.Mock\n}\n\n", g.mockName(), g.typeConstraints(ctx),
)

if g.WithExpecter {
Expand Down Expand Up @@ -541,7 +593,7 @@ func (g *Generator) Generate(ctx context.Context) error {
)
}
g.printf(
"func (_m *%s) %s(%s) ", g.mockName(), fname,
"func (_m *%s%s) %s(%s) ", g.mockName(), g.typeParams(), fname,
strings.Join(params.Params, ", "),
)

Expand Down Expand Up @@ -612,7 +664,7 @@ func (g *Generator) Generate(ctx context.Context) error {
}
}

g.generateConstructor()
g.generateConstructor(ctx)

return nil
}
Expand Down Expand Up @@ -712,16 +764,17 @@ func (_c *{{.CallStruct}}) Return({{range .Returns.Params}}{{.}},{{end}}) *{{.Ca
`)
}

func (g *Generator) generateConstructor() {
func (g *Generator) generateConstructor(ctx context.Context) {
const constructorTemplate = `
type {{ .ConstructorTestingInterfaceName }} interface {
mock.TestingT
Cleanup(func())
}
// {{ .ConstructorName }} creates a new instance of {{ .MockName }}. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func {{ .ConstructorName }}(t {{ .ConstructorTestingInterfaceName }}) *{{ .MockName }} {
func {{ .ConstructorName }}{{ .TypeConstraint }}(t {{ .ConstructorTestingInterfaceName }}) *{{ .MockName }}{{ .TypeParams }} {
mock := &{{ .MockName }}{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
Expand All @@ -736,10 +789,14 @@ func {{ .ConstructorName }}(t {{ .ConstructorTestingInterfaceName }}) *{{ .MockN
ConstructorName string
ConstructorTestingInterfaceName string
MockName string
TypeConstraint string
TypeParams string
}{
ConstructorName: constructorName,
ConstructorTestingInterfaceName: constructorName + "T",
MockName: mockName,
TypeConstraint: g.typeConstraints(ctx),
TypeParams: g.typeParams(),
}
g.printTemplate(data, constructorTemplate)
}
Expand Down
Loading

0 comments on commit dc5539e

Please sign in to comment.