diff --git a/codegen/build.go b/codegen/build.go index 20129e4f929..42dff9243bf 100644 --- a/codegen/build.go +++ b/codegen/build.go @@ -81,6 +81,7 @@ func (cfg *Config) bind() (*Build, error) { Interfaces: cfg.buildInterfaces(namedTypes, prog), Inputs: inputs, Imports: imports.finalize(), + SchemaRaw: cfg.SchemaStr, } if qr, ok := cfg.schema.EntryPoints["query"]; ok { diff --git a/codegen/codegen.go b/codegen/codegen.go index 8c3e0a82107..0a6f2e1a4e7 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -30,7 +30,6 @@ func Generate(cfg Config) error { return errors.Wrap(err, "model plan failed") } if len(modelsBuild.Models) > 0 || len(modelsBuild.Enums) > 0 { - modelsBuild.PackageName = cfg.Model.Package var buf *bytes.Buffer buf, err = templates.Run("models.gotpl", modelsBuild) if err != nil { @@ -57,8 +56,6 @@ func Generate(cfg Config) error { if err != nil { return errors.Wrap(err, "exec plan failed") } - build.SchemaRaw = cfg.SchemaStr - build.PackageName = cfg.Exec.Package var buf *bytes.Buffer buf, err = templates.Run("generated.gotpl", build) @@ -129,20 +126,18 @@ func abs(path string) string { return filepath.ToSlash(absPath) } -func importPath(dir string, pkgName string) string { - fullPkgName := filepath.Join(filepath.Dir(dir), pkgName) - +func importPath(dir string) string { for _, gopath := range filepath.SplitList(build.Default.GOPATH) { gopath = filepath.Join(gopath, "src") + string(os.PathSeparator) - if len(gopath) > len(fullPkgName) { + if len(gopath) > len(dir) { continue } - if strings.EqualFold(gopath, fullPkgName[0:len(gopath)]) { - fullPkgName = fullPkgName[len(gopath):] + if strings.EqualFold(gopath, dir[0:len(gopath)]) { + dir = dir[len(gopath):] break } } - return filepath.ToSlash(fullPkgName) + return filepath.ToSlash(dir) } func gofmt(filename string, b []byte) ([]byte, error) { diff --git a/codegen/codegen_test.go b/codegen/codegen_test.go index a632e7b379a..d777447ef8c 100644 --- a/codegen/codegen_test.go +++ b/codegen/codegen_test.go @@ -14,7 +14,7 @@ func Test_fullPackageName(t *testing.T) { t.Run("gopath longer than package name", func(t *testing.T) { build.Default.GOPATH = "/a/src/xxxxxxxxxxxxxxxxxxxxxxxx:/b/src/y" var got string - ok := assert.NotPanics(t, func() { got = importPath("/b/src/y/foo/bar", "bar") }) + ok := assert.NotPanics(t, func() { got = importPath("/b/src/y/foo/bar") }) if ok { assert.Equal(t, "/b/src/y/foo/bar", got) } @@ -23,7 +23,7 @@ func Test_fullPackageName(t *testing.T) { build.Default.GOPATH = "/a/src/x:/b/src/y" var got string - ok := assert.NotPanics(t, func() { got = importPath("/a/src/x/foo/bar", "bar") }) + ok := assert.NotPanics(t, func() { got = importPath("/a/src/x/foo/bar") }) if ok { assert.Equal(t, "/a/src/x/foo/bar", got) } diff --git a/codegen/config.go b/codegen/config.go index f98f079bcb5..977db7ac2c3 100644 --- a/codegen/config.go +++ b/codegen/config.go @@ -2,6 +2,7 @@ package codegen import ( "fmt" + "go/build" "io/ioutil" "os" "path/filepath" @@ -80,16 +81,24 @@ func (c *PackageConfig) normalize() error { if c.Filename == "" { return errors.New("Filename is required") } - c.Filename = abs(c.Filename) + // If Package is not set, first attempt to load the package at the output dir. If that fails + // fallback to just the base dir name of the output filename. if c.Package == "" { - c.Package = filepath.Base(c.Dir()) + cwd, _ := os.Getwd() + pkg, err := build.Default.Import(c.Dir(), cwd, 0) + if err != nil { + c.Package = filepath.Base(c.Dir()) + } else { + c.Package = pkg.Name + } } c.Package = sanitizePackageName(c.Package) + c.Filename = abs(c.Filename) return nil } func (c *PackageConfig) ImportPath() string { - return importPath(c.Dir(), c.Package) + return importPath(c.Dir()) } func (c *PackageConfig) Dir() string { @@ -97,7 +106,7 @@ func (c *PackageConfig) Dir() string { } func (c *PackageConfig) Check() error { - if strings.ContainsAny(c.Package, "./") { + if strings.ContainsAny(c.Package, "./\\") { return fmt.Errorf("package should be the output package name only, do not include the output filename") } return nil