From a8ef02fe726de51e8ec1a9a0a4c631753bc96d3b Mon Sep 17 00:00:00 2001 From: Sam Xie Date: Sat, 18 Jan 2020 13:36:17 +0800 Subject: [PATCH 1/3] feat: use "." to refer to the current path's package --- mockgen/mockgen.go | 13 ++++++++++- mockgen/parse.go | 56 +++++++++++++++++++++++++++++++++++++--------- 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 8126c814..ffceeb06 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -68,7 +68,18 @@ func main() { usage() log.Fatal("Expected exactly two arguments") } - pkg, err = reflectMode(flag.Arg(0), strings.Split(flag.Arg(1), ",")) + packageName := flag.Arg(0) + if packageName == "." { + dir, err := os.Getwd() + if err != nil { + log.Fatalf("Get current directory failed: %v", err) + } + packageName, err = packageNameOfDir(dir) + if err != nil { + log.Fatalf("Parse package name failed: %v", err) + } + } + pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ",")) } if err != nil { log.Fatalf("Loading input failed: %v", err) diff --git a/mockgen/parse.go b/mockgen/parse.go index e35d16f5..1aead1b0 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -24,6 +24,7 @@ import ( "go/build" "go/parser" "go/token" + "io/ioutil" "log" "path" "path/filepath" @@ -48,19 +49,10 @@ func sourceMode(source string) (*model.Package, error) { return nil, fmt.Errorf("failed getting source directory: %v", err) } - cfg := &packages.Config{Mode: packages.LoadFiles, Tests: true, Dir: srcDir} - pkgs, err := packages.Load(cfg, "file="+source) + packageImport, err := parsePackageImport(source, srcDir) if err != nil { return nil, err } - if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 { - return nil, errors.New("loading package failed") - } - - packageImport := pkgs[0].PkgPath - - // It is illegal to import a _test package. - packageImport = strings.TrimSuffix(packageImport, "_test") fs := token.NewFileSet() file, err := parser.ParseFile(fs, source, nil, 0) @@ -519,3 +511,47 @@ func isVariadic(f *ast.FuncType) bool { _, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis) return ok } + +// packageNameOfDir get package import path via dir +func packageNameOfDir(srcDir string) (string, error) { + files, err := ioutil.ReadDir(srcDir) + if err != nil { + log.Fatal(err) + } + + var goFilePath string + for _, file := range files { + log.Println(file.Name()) + if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") { + goFilePath = file.Name() + break + } + } + if goFilePath == "" { + return "", fmt.Errorf("go source file not found %s", srcDir) + } + + packageImport, err := parsePackageImport(goFilePath, srcDir) + if err != nil { + return "", err + } + return packageImport, nil +} + +// parseImportPackage get package import path via source file +func parsePackageImport(source, srcDir string) (string, error) { + cfg := &packages.Config{Mode: packages.LoadFiles, Tests: true, Dir: srcDir} + pkgs, err := packages.Load(cfg, "file="+source) + if err != nil { + return "", err + } + if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 { + return "", errors.New("loading package failed") + } + + packageImport := pkgs[0].PkgPath + + // It is illegal to import a _test package. + packageImport = strings.TrimSuffix(packageImport, "_test") + return packageImport, nil +} From 17ebec651714b485a91450f30c3b02d3cd068c2c Mon Sep 17 00:00:00 2001 From: Sam Xie Date: Sat, 18 Jan 2020 13:31:49 +0800 Subject: [PATCH 2/3] doc: update reflect mode --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 9204baf9..72c09d66 100644 --- a/README.md +++ b/README.md @@ -48,10 +48,15 @@ that uses reflection to understand interfaces. It is enabled by passing two non-flag arguments: an import path, and a comma-separated list of symbols. +You can use "." to refer to the current path's package. + Example: ```bash mockgen database/sql/driver Conn,Driver + +# Convenient for `go:generate`. +mockgen . Conn,Driver ``` The `mockgen` command is used to generate source code for a mock From c018a8c1687206b46ea6eacb98850777bfed8844 Mon Sep 17 00:00:00 2001 From: Sam Xie Date: Sat, 1 Feb 2020 16:52:42 +0800 Subject: [PATCH 3/3] fix: generated code lose package name --- mockgen/mockgen.go | 5 +++-- mockgen/parse.go | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index ffceeb06..664ad2f6 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -61,6 +61,7 @@ func main() { var pkg *model.Package var err error + var packageName string if *source != "" { pkg, err = sourceMode(*source) } else { @@ -68,7 +69,7 @@ func main() { usage() log.Fatal("Expected exactly two arguments") } - packageName := flag.Arg(0) + packageName = flag.Arg(0) if packageName == "." { dir, err := os.Getwd() if err != nil { @@ -133,7 +134,7 @@ func main() { if *source != "" { g.filename = *source } else { - g.srcPackage = flag.Arg(0) + g.srcPackage = packageName g.srcInterfaces = flag.Arg(1) } diff --git a/mockgen/parse.go b/mockgen/parse.go index 1aead1b0..32844d0e 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -521,7 +521,6 @@ func packageNameOfDir(srcDir string) (string, error) { var goFilePath string for _, file := range files { - log.Println(file.Name()) if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") { goFilePath = file.Name() break