diff --git a/code.go b/code.go index 8b1d704..672132a 100644 --- a/code.go +++ b/code.go @@ -138,12 +138,12 @@ func (l line) qualify(imports *Imports) { for i := range l.args { switch a := l.args[i].(type) { case Package: - p := imports.RegisterImportForPackage(a) + p := imports.registerPackage(a) if p != a.Name { l.args[i] = Package{Name: p, ImportPath: a.ImportPath} } case *Package: - p := imports.RegisterImportForPackage(*a) + p := imports.registerPackage(*a) if p != a.Name { l.args[i] = Package{Name: p, ImportPath: a.ImportPath} } @@ -209,7 +209,7 @@ func (l line) qualify(imports *Imports) { } case *types.Package: p := PackageForGoType(a) - l.args[i] = imports.RegisterImportForPackage(p) + l.args[i] = imports.registerPackage(p) } } } diff --git a/imports.go b/imports.go index f36415a..6a159ae 100644 --- a/imports.go +++ b/imports.go @@ -39,6 +39,17 @@ func (i *Imports) RegisterImportForPackage(pkg Package) string { return i.RegisterImport(pkg.ImportPath, pkg.Name) } +// registerPackage is like RegisterImportForPackage, but instead of returning +// the prefix (which includes a trailing dot if non-empty), it just returns the +// package alias. +func (i *Imports) registerPackage(pkg Package) string { + p := i.RegisterImportForPackage(pkg) + if len(p) > 0 && p[len(p)-1] == '.' { + p = p[:len(p)-1] + } + return p +} + // RegisterImport "imports" the specified package and returns the package prefix // to use for symbols in the imported package. It is safe to import the same // package repeatedly -- the same prefix will be returned every time. If an diff --git a/templates.go b/templates.go index fb6df3a..c0f8302 100644 --- a/templates.go +++ b/templates.go @@ -9,6 +9,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b switch data.Kind() { case reflect.Interface: if data.Elem().IsValid() { + qualified := true var newRv reflect.Value switch d := data.Interface().(type) { case TypeName: @@ -44,20 +45,25 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b } case *types.Package: origP := PackageForGoType(d) - newP := imports.RegisterImportForPackage(origP) + newP := imports.registerPackage(origP) if newP != origP.Name { newRv = reflect.ValueOf(Package{Name: newP, ImportPath: origP.ImportPath}) } + default: + qualified = false } - if newRv.IsValid() && newRv.Type().Implements(data.Type()) { - // For TypeName, this should always be true; but for cases - // where we've changed the type of the value, if we try to - // return an incompatible type, the result will be a panic - // with a location and message that is not awesome for - // users of this package. So we'll ignore the new value if - // it's not the right type. - return newRv, true + if qualified { + if newRv.IsValid() && newRv.Type().Implements(data.Type()) { + // For TypeName, this should always be true; but for cases + // where we've changed the type of the value, if we try to + // return an incompatible type, the result will be a panic + // with a location and message that is not awesome for + // users of this package. So we'll ignore the new value if + // it's not the right type. + return newRv, true + } + return data, false } return qualifyTemplateData(imports, data.Elem()) @@ -66,7 +72,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b case reflect.Struct: switch t := data.Interface().(type) { case Package: - p := imports.RegisterImportForPackage(t) + p := imports.registerPackage(t) if p != t.Name { return reflect.ValueOf(&Package{Name: p, ImportPath: t.ImportPath}).Elem(), true } @@ -89,7 +95,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b case ConstSpec: if t.parent != nil { oldPkg := t.parent.PackageName - newPkg := imports.RegisterImportForPackage(t.parent.Package()) + newPkg := imports.registerPackage(t.parent.Package()) if newPkg != oldPkg { newCs := t newCs.parent = &GoFile{PackageName: newPkg} @@ -99,7 +105,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b case VarSpec: if t.parent != nil { oldPkg := t.parent.PackageName - newPkg := imports.RegisterImportForPackage(t.parent.Package()) + newPkg := imports.registerPackage(t.parent.Package()) if newPkg != oldPkg { newVs := t newVs.parent = &GoFile{PackageName: newPkg} @@ -109,7 +115,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b case TypeSpec: if t.parent != nil { oldPkg := t.parent.PackageName - newPkg := imports.RegisterImportForPackage(t.parent.Package()) + newPkg := imports.registerPackage(t.parent.Package()) if newPkg != oldPkg { newTs := t newTs.parent = &GoFile{PackageName: newPkg} @@ -119,7 +125,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b case FuncSpec: if t.parent != nil { oldPkg := t.parent.PackageName - newPkg := imports.RegisterImportForPackage(t.parent.Package()) + newPkg := imports.registerPackage(t.parent.Package()) if newPkg != oldPkg { newFs := t newFs.parent = &GoFile{PackageName: newPkg} @@ -141,7 +147,16 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b default: var newStruct reflect.Value for i := 0; i < data.NumField(); i++ { - newV, changedV := qualifyTemplateData(imports, data.Field(i)) + var newV reflect.Value + var changedV bool + fld, ok := getField(data, i) + if !ok { + // do not recurse + newV = data.Field(i) + changedV = false + } else { + newV, changedV = qualifyTemplateData(imports, fld) + } if newStruct.IsValid() { newStruct.Field(i).Set(newV) } else if changedV { diff --git a/templates_no_unsafe.go b/templates_no_unsafe.go new file mode 100644 index 0000000..a7d9769 --- /dev/null +++ b/templates_no_unsafe.go @@ -0,0 +1,15 @@ +//+build appengine gopherjs purego +// NB: other environments where unsafe is unappropriate should use "purego" build tag +// https://github.com/golang/go/issues/23172 + +package gopoet + +import ( + "reflect" +) + +func getField(v reflect.Value, index int) (reflect.Value, bool) { + fld := v.Field(index) + // We can't use unsafe, so return false for unexported fields :( + return fld, !fld.IsValid() || fld.CanInterface() +} diff --git a/templates_unsafe.go b/templates_unsafe.go new file mode 100644 index 0000000..df1d609 --- /dev/null +++ b/templates_unsafe.go @@ -0,0 +1,25 @@ +//+build !appengine,!gopherjs,!purego +// NB: other environments where unsafe is unappropriate should use "purego" build tag +// https://github.com/golang/go/issues/23172 + +package gopoet + +import ( + "reflect" + "unsafe" +) + +func getField(v reflect.Value, index int) (reflect.Value, bool) { + fld := v.Field(index) + if !fld.IsValid() || fld.CanInterface() { + return fld, true + } + + // NB: We are being super-sneaky. Go reflection will not let us call + // fld.Interface() if fld was obtained via unexported fields (which it + // was!). So we use unsafe to create an alternate reflect.Value instance + // that represents the same value (same type and address). We can then + // call Interface() on *that*. + val := reflect.NewAt(fld.Type(), unsafe.Pointer(fld.UnsafeAddr())).Elem() + return val, true +}