Skip to content

Commit

Permalink
Handle package name mismatch with dirname
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathew Byrne committed Jul 9, 2018
1 parent ebf1b2a commit 6d38f77
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 16 deletions.
1 change: 1 addition & 0 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (cfg *Config) bind() (*Build, error) {
Interfaces: cfg.buildInterfaces(namedTypes, prog),
Inputs: inputs,
Imports: imports.finalize(),
SchemaRaw: cfg.SchemaStr,
}

if qr, ok := cfg.schema.EntryPoints["query"]; ok {
Expand Down
15 changes: 5 additions & 10 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ func Generate(cfg Config) error {
return errors.Wrap(err, "model plan failed")
}
if len(modelsBuild.Models) > 0 || len(modelsBuild.Enums) > 0 {
modelsBuild.PackageName = cfg.Model.Package
var buf *bytes.Buffer
buf, err = templates.Run("models.gotpl", modelsBuild)
if err != nil {
Expand All @@ -57,8 +56,6 @@ func Generate(cfg Config) error {
if err != nil {
return errors.Wrap(err, "exec plan failed")
}
build.SchemaRaw = cfg.SchemaStr
build.PackageName = cfg.Exec.Package

var buf *bytes.Buffer
buf, err = templates.Run("generated.gotpl", build)
Expand Down Expand Up @@ -129,20 +126,18 @@ func abs(path string) string {
return filepath.ToSlash(absPath)
}

func importPath(dir string, pkgName string) string {
fullPkgName := filepath.Join(filepath.Dir(dir), pkgName)

func importPath(dir string) string {
for _, gopath := range filepath.SplitList(build.Default.GOPATH) {
gopath = filepath.Join(gopath, "src") + string(os.PathSeparator)
if len(gopath) > len(fullPkgName) {
if len(gopath) > len(dir) {
continue
}
if strings.EqualFold(gopath, fullPkgName[0:len(gopath)]) {
fullPkgName = fullPkgName[len(gopath):]
if strings.EqualFold(gopath, dir[0:len(gopath)]) {
dir = dir[len(gopath):]
break
}
}
return filepath.ToSlash(fullPkgName)
return filepath.ToSlash(dir)
}

func gofmt(filename string, b []byte) ([]byte, error) {
Expand Down
4 changes: 2 additions & 2 deletions codegen/codegen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func Test_fullPackageName(t *testing.T) {
t.Run("gopath longer than package name", func(t *testing.T) {
build.Default.GOPATH = "/a/src/xxxxxxxxxxxxxxxxxxxxxxxx:/b/src/y"
var got string
ok := assert.NotPanics(t, func() { got = importPath("/b/src/y/foo/bar", "bar") })
ok := assert.NotPanics(t, func() { got = importPath("/b/src/y/foo/bar") })
if ok {
assert.Equal(t, "/b/src/y/foo/bar", got)
}
Expand All @@ -23,7 +23,7 @@ func Test_fullPackageName(t *testing.T) {
build.Default.GOPATH = "/a/src/x:/b/src/y"

var got string
ok := assert.NotPanics(t, func() { got = importPath("/a/src/x/foo/bar", "bar") })
ok := assert.NotPanics(t, func() { got = importPath("/a/src/x/foo/bar") })
if ok {
assert.Equal(t, "/a/src/x/foo/bar", got)
}
Expand Down
17 changes: 13 additions & 4 deletions codegen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package codegen

import (
"fmt"
"go/build"
"io/ioutil"
"os"
"path/filepath"
Expand Down Expand Up @@ -80,24 +81,32 @@ func (c *PackageConfig) normalize() error {
if c.Filename == "" {
return errors.New("Filename is required")
}
c.Filename = abs(c.Filename)
// If Package is not set, first attempt to load the package at the output dir. If that fails
// fallback to just the base dir name of the output filename.
if c.Package == "" {
c.Package = filepath.Base(c.Dir())
cwd, _ := os.Getwd()
pkg, err := build.Default.Import(c.Dir(), cwd, 0)
if err != nil {
c.Package = filepath.Base(c.Dir())
} else {
c.Package = pkg.Name
}
}
c.Package = sanitizePackageName(c.Package)
c.Filename = abs(c.Filename)
return nil
}

func (c *PackageConfig) ImportPath() string {
return importPath(c.Dir(), c.Package)
return importPath(c.Dir())
}

func (c *PackageConfig) Dir() string {
return filepath.ToSlash(filepath.Dir(c.Filename))
}

func (c *PackageConfig) Check() error {
if strings.ContainsAny(c.Package, "./") {
if strings.ContainsAny(c.Package, "./\\") {
return fmt.Errorf("package should be the output package name only, do not include the output filename")
}
return nil
Expand Down

0 comments on commit 6d38f77

Please sign in to comment.