Skip to content

Commit

Permalink
feat: added support for composite generics
Browse files Browse the repository at this point in the history
  • Loading branch information
defaulterrr committed Mar 17, 2023
1 parent 3c94f9e commit 548317b
Show file tree
Hide file tree
Showing 10 changed files with 1,209 additions and 83 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ generate:
go run ./cmd/minimock/minimock.go -i ./tests.genericInout -o ./tests/generic/generic_inout.go
go run ./cmd/minimock/minimock.go -i ./tests.genericOut -o ./tests/generic/generic_out.go
go run ./cmd/minimock/minimock.go -i ./tests.genericIn -o ./tests/generic/generic_in.go
go run ./cmd/minimock/minimock.go -i ./tests.genericSpecific -o ./tests/generic/generic_specific.go
go run ./cmd/minimock/minimock.go -i ./tests.genericSimpleUnion -o ./tests/generic/generic_simple_union.go
go run ./cmd/minimock/minimock.go -i ./tests.genericComplexUnion -o ./tests/generic/generic_complex_union.go
go run ./cmd/minimock/minimock.go -i ./tests.genericInlineUnion -o ./tests/generic/generic_inline_union.go

./bin:
mkdir ./bin
Expand Down
92 changes: 9 additions & 83 deletions cmd/minimock/minimock.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

minimock "github.com/gojuno/minimock/v3"
"github.com/gojuno/minimock/v3/internal/types"
"github.com/hexdigest/gowrap/generator"
"github.com/hexdigest/gowrap/pkg"
"github.com/pkg/errors"
Expand Down Expand Up @@ -95,10 +96,7 @@ func run(opts *options) (err error) {
}
}

interfaces, err := findInterfaces(astPackage, in.Type)
if err != nil {
return err
}
interfaces := types.FindAllInterfaces(astPackage, in.Type)

gopts := generator.Options{
SourcePackage: sourcePackage.PkgPath,
Expand All @@ -125,71 +123,26 @@ func run(opts *options) (err error) {
return nil
}

func getTypeParams(typeSpec *ast.TypeSpec) []interfaceSpecificationParam {
params := []interfaceSpecificationParam{}

// Check whether node has any type params at all
if typeSpec == nil || typeSpec.TypeParams == nil {
return nil
}

// If node has any type params - store them in slice and return as a spec
for _, param := range typeSpec.TypeParams.List {
names := []string{}
for _, name := range param.Names {
names = append(names, name.Name)
}

paramType := ""

if ident, ok := param.Type.(*ast.Ident); ok {
paramType = ident.Name
}

params = append(params, interfaceSpecificationParam{
paramNames: names,
paramType: paramType,
})
}

return params
}

// interfaceSpecification represents abstraction over interface type. It contains all the metadata
// required to render a mock for given interface. One could deduce whether interface is generic
// by looking for type params
type interfaceSpecification struct {
interfaceName string
interfaceParams []interfaceSpecificationParam
}

// interfaceSpecificationParam represents a group of type param variables and their type
// I.e. [T,K any] would result in names "T","K" and type "any"
type interfaceSpecificationParam struct {
paramNames []string
paramType string
}

func processPackage(opts generator.Options, interfaces []interfaceSpecification, writeTo, suffix, mockName string) (err error) {
func processPackage(opts generator.Options, interfaces []types.InterfaceSpecification, writeTo, suffix, mockName string) (err error) {
for _, iface := range interfaces {
opts.InterfaceName = iface.interfaceName
opts.InterfaceName = iface.InterfaceName

params := ""
paramsReferences := ""

for _, param := range iface.interfaceParams {
names := strings.Join(param.paramNames, ",")
params += fmt.Sprintf("%s %s", names, param.paramType)
for _, param := range iface.InterfaceParams {
names := strings.Join(param.ParamNames, ",")
params += fmt.Sprintf("%s %s", names, param.ParamType)
if paramsReferences == "" {
paramsReferences = names
} else {
paramsReferences = strings.Join([]string{paramsReferences, names}, ",")
}
}

opts.OutputFile, err = destinationFile(iface.interfaceName, writeTo, suffix)
opts.OutputFile, err = destinationFile(iface.InterfaceName, writeTo, suffix)
if err != nil {
return errors.Wrapf(err, "failed to generate mock for %s", iface.interfaceName)
return errors.Wrapf(err, "failed to generate mock for %s", iface.InterfaceName)
}

opts.Vars["MockName"] = fmt.Sprintf("%sMock", opts.InterfaceName)
Expand Down Expand Up @@ -284,33 +237,6 @@ func generate(o generator.Options) (err error) {
return ioutil.WriteFile(o.OutputFile, buf.Bytes(), 0644)
}

func findInterfaces(p *ast.Package, pattern string) ([]interfaceSpecification, error) {
var interfaceSpecifications []interfaceSpecification

for _, f := range p.Files {
for _, d := range f.Decls {
if gd, ok := d.(*ast.GenDecl); ok && gd.Tok == token.TYPE {
for _, spec := range gd.Specs {
if ts, ok := spec.(*ast.TypeSpec); ok {
if _, ok := ts.Type.(*ast.InterfaceType); ok && match(ts.Name.Name, pattern) {
interfaceSpecifications = append(interfaceSpecifications, interfaceSpecification{
interfaceName: ts.Name.Name,
interfaceParams: getTypeParams(ts),
})
}
}
}
}
}
}

if len(interfaceSpecifications) == 0 {
return nil, errors.Errorf("failed to find any interfaces matching %s in %s", pattern, p.Name)
}

return interfaceSpecifications, nil
}

func match(s, pattern string) bool {
return pattern == "*" || s == pattern
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/pmezard/go-difflib v1.0.0
github.com/stretchr/testify v1.7.0
golang.org/x/tools v0.1.3
google.golang.org/protobuf v1.30.0
)

require (
Expand Down
7 changes: 7 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/hexdigest/gowrap v1.1.7/go.mod h1:Z+nBFUDLa01iaNM+/jzoOA1JJ7sm51rnYFauKFUB5fs=
github.com/hexdigest/gowrap v1.1.8 h1:xGTnuMvHou3sa+PSHphOCxPJTJyqNRvGl21t/p3eLes=
github.com/hexdigest/gowrap v1.1.8/go.mod h1:H/JiFmQMp//tedlV8qt2xBdGzmne6bpbaSuiHmygnMw=
Expand Down Expand Up @@ -101,6 +104,7 @@ golang.org/x/tools v0.1.3 h1:L69ShwSZEyCsLKoAxDKeMvLDZkumEe8gXUZAjab0tX8=
golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
Expand All @@ -110,6 +114,9 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.27.1/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
Expand Down
139 changes: 139 additions & 0 deletions internal/types/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package types

import (
"go/ast"
"go/token"
)

// InterfaceSpecification represents abstraction over interface type. It contains all the metadata
// required to render a mock for given interface. One could deduce whether interface is generic
// by looking for type params
type InterfaceSpecification struct {
InterfaceName string
InterfaceParams []InterfaceSpecificationParam
}

// InterfaceSpecificationParam represents a group of type param variables and their type
// I.e. [T,K any] would result in names "T","K" and type "any"
type InterfaceSpecificationParam struct {
ParamNames []string
ParamType string
}

func FindAllInterfaces(p *ast.Package, pattern string) []InterfaceSpecification {
// Find all declared types in a single package
types := []*ast.TypeSpec{}
for _, file := range p.Files {
types = append(types, findAllTypeSpecsInFile(file)...)
}

// Filter interfaces from all the declarations
interfaces := []*ast.TypeSpec{}
for _, typeSpec := range types {
if isInterface(typeSpec) {
interfaces = append(interfaces, typeSpec)
}
}

// Filter interfaces with the given pattern
filteredInterfaces := []*ast.TypeSpec{}
for _, iface := range interfaces {
if match(iface.Name.Name, pattern) {
filteredInterfaces = append(filteredInterfaces, iface)
}
}

// Transform AST nodes into specifications
interfaceSpecifications := make([]InterfaceSpecification, 0, len(filteredInterfaces))
for _, iface := range filteredInterfaces {
interfaceSpecifications = append(interfaceSpecifications, InterfaceSpecification{
InterfaceName: iface.Name.Name,
InterfaceParams: getTypeParams(iface),
})
}

return interfaceSpecifications
}

func isInterface(typeSpec *ast.TypeSpec) bool {
// Check if this type declaration is specifically an interface declaration
_, ok := typeSpec.Type.(*ast.InterfaceType)
return ok
}

// findAllInterfaceNodesInFile ranges over file's AST nodes and extracts all interfaces inside
// returned *ast.TypeSpecs can be safely interpreted as interface declaration nodes
func findAllTypeSpecsInFile(f *ast.File) []*ast.TypeSpec {
typeSpecs := []*ast.TypeSpec{}

// Range over all declarations in a single file
for _, declaration := range f.Decls {
// Check if declaration is an import, constant, type or variable declaration.
// If it is, check specifically if it's a TYPE as all interfaces are types
if genericDeclaration, ok := declaration.(*ast.GenDecl); ok && genericDeclaration.Tok == token.TYPE {
// Range over all specifications and find ones that are Type declarations
// This is mostly a precaution
for _, spec := range genericDeclaration.Specs {
// Check directly for a type spec declaration
if typeSpec, ok := spec.(*ast.TypeSpec); ok {
typeSpecs = append(typeSpecs, typeSpec)
}
}
}
}

return typeSpecs
}

// match returns true if pattern is a wildcard or directly matches the given name
func match(name, pattern string) bool {
return pattern == "*" || name == pattern
}

func getTypeParams(typeSpec *ast.TypeSpec) []InterfaceSpecificationParam {
params := []InterfaceSpecificationParam{}

// Check whether node has any type params at all
if typeSpec == nil || typeSpec.TypeParams == nil {
return nil
}

// If node has any type params - store them in slice and return as a spec
for _, param := range typeSpec.TypeParams.List {
names := []string{}
for _, name := range param.Names {
names = append(names, name.Name)
}

paramType := ""

ast.Print(token.NewFileSet(), param.Type)

switch node := param.Type.(type) {
// Direct declarations in form of
// [T int] or [T any]
case *ast.Ident:
paramType = node.Name
// Reference to a type, i.e.
// proto.Message

// we can reference those without worrying about external imports
// due to Go tooling being able to deduce missing imports from
// the surrounding context (files and existing references to types).
// i.e. user already referenced the type in the nearby file.
case *ast.SelectorExpr:
paramType = node.X.(*ast.Ident).Name + "." + node.Sel.Name
// Inline reference, i.e.
// int | float64
case *ast.BinaryExpr:
paramType = node.X.(*ast.Ident).Name + " | " + node.Y.(*ast.Ident).Name
}

params = append(params, InterfaceSpecificationParam{
ParamNames: names,
ParamType: paramType,
})
}

return params
}
Loading

0 comments on commit 548317b

Please sign in to comment.