Skip to content

Commit

Permalink
Merge pull request #358 from kkHAIKE/encode_MarshalText_fix
Browse files Browse the repository at this point in the history
change eindirect behave match with indirect from decode
  • Loading branch information
arp242 authored Jun 25, 2022
2 parents 0a9f2b0 + c03a31c commit f0ccf71
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 76 deletions.
136 changes: 69 additions & 67 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ var dblQuotedReplacer = strings.NewReplacer(
"\x7f", `\u007f`,
)

var (
marshalToml = reflect.TypeOf((*Marshaler)(nil)).Elem()
marshalText = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
)

// Marshaler is the interface implemented by types that can marshal themselves
// into valid TOML.
type Marshaler interface {
Expand Down Expand Up @@ -154,12 +160,12 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
// If we can marshal the type to text, then we use that. This prevents the
// encoder for handling these types as generic structs (or whatever the
// underlying type of a TextMarshaler is).
switch t := rv.Interface().(type) {
case encoding.TextMarshaler, Marshaler:
switch {
case isMarshaler(rv):
enc.writeKeyValue(key, rv, false)
return
case Primitive: // TODO: #76 would make this superfluous after implemented.
enc.encode(key, reflect.ValueOf(t.undecoded))
case rv.Type() == primitiveType: // TODO: #76 would make this superfluous after implemented.
enc.encode(key, reflect.ValueOf(rv.Interface().(Primitive).undecoded))
return
}

Expand Down Expand Up @@ -318,7 +324,7 @@ func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
length := rv.Len()
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
elem := eindirect(rv.Index(i))
enc.eElement(elem)
if i != length-1 {
enc.wf(", ")
Expand All @@ -332,7 +338,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
encPanic(errNoKey)
}
for i := 0; i < rv.Len(); i++ {
trv := rv.Index(i)
trv := eindirect(rv.Index(i))
if isNil(trv) {
continue
}
Expand All @@ -357,7 +363,7 @@ func (enc *Encoder) eTable(key Key, rv reflect.Value) {
}

func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) {
switch rv := eindirect(rv); rv.Kind() {
switch rv.Kind() {
case reflect.Map:
enc.eMap(key, rv, inline)
case reflect.Struct:
Expand All @@ -379,7 +385,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) {
if typeIsTable(tomlTypeOfGo(eindirect(rv.MapIndex(mapKey)))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
Expand All @@ -389,7 +395,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
var writeMapKeys = func(mapKeys []string, trailC bool) {
sort.Strings(mapKeys)
for i, mapKey := range mapKeys {
val := rv.MapIndex(reflect.ValueOf(mapKey))
val := eindirect(rv.MapIndex(reflect.ValueOf(mapKey)))
if isNil(val) {
continue
}
Expand Down Expand Up @@ -417,6 +423,13 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {

const is32Bit = (32 << (^uint(0) >> 63)) == 32

func pointerTo(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
return pointerTo(t.Elem())
}
return t
}

func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table then all keys under it will be in that
Expand All @@ -433,35 +446,25 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields.
isEmbed := f.Anonymous && pointerTo(f.Type).Kind() == reflect.Struct
if f.PkgPath != "" && !isEmbed { /// Skip unexported fields.
continue
}
opts := getOptions(f.Tag)
if opts.skip {
continue
}

frv := rv.Field(i)
frv := eindirect(rv.Field(i))

// Treat anonymous struct fields with tag names as though they are
// not anonymous, like encoding/json does.
//
// Non-struct anonymous fields use the normal encoding logic.
if f.Anonymous {
t := f.Type
switch t.Kind() {
case reflect.Struct:
if getOptions(f.Tag).name == "" {
addFields(t, frv, append(start, f.Index...))
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), append(start, f.Index...))
}
continue
}
if isEmbed {
if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct {
addFields(frv.Type(), frv, append(start, f.Index...))
continue
}
}

Expand All @@ -487,7 +490,7 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
writeFields := func(fields [][]int) {
for _, fieldIndex := range fields {
fieldType := rt.FieldByIndex(fieldIndex)
fieldVal := rv.FieldByIndex(fieldIndex)
fieldVal := eindirect(rv.FieldByIndex(fieldIndex))

if isNil(fieldVal) { /// Don't write anything for nil fields.
continue
Expand Down Expand Up @@ -540,6 +543,21 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
}

if rv.Kind() == reflect.Struct {
if rv.Type() == timeType {
return tomlDatetime
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
}

if isMarshaler(rv) {
return tomlString
}

switch rv.Kind() {
case reflect.Bool:
return tomlBool
Expand All @@ -561,42 +579,14 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
return tomlString
case reflect.Map:
return tomlHash
case reflect.Struct:
if _, ok := rv.Interface().(time.Time); ok {
return tomlDatetime
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
default:
if isMarshaler(rv) {
return tomlString
}

encPanic(errors.New("unsupported type: " + rv.Kind().String()))
panic("unreachable")
}
}

func isMarshaler(rv reflect.Value) bool {
switch rv.Interface().(type) {
case encoding.TextMarshaler:
return true
case Marshaler:
return true
}

// Someone used a pointer receiver: we can make it work for pointer values.
if rv.CanAddr() {
if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok {
return true
}
if _, ok := rv.Addr().Interface().(Marshaler); ok {
return true
}
}
return false
return rv.Type().Implements(marshalText) || rv.Type().Implements(marshalToml)
}

// isTableArray reports if all entries in the array or slice are a table.
Expand All @@ -605,19 +595,19 @@ func isTableArray(arr reflect.Value) bool {
return false
}

/// Don't allow nil.
ret := true
for i := 0; i < arr.Len(); i++ {
if tomlTypeOfGo(arr.Index(i)) == nil {
tt := tomlTypeOfGo(eindirect(arr.Index(i)))
// Don't allow nil.
if tt == nil {
encPanic(errArrayNilElement)
}
}

for i := 0; i < arr.Len(); i++ {
if !typeEqual(tomlHash, tomlTypeOfGo(arr.Index(i))) {
return false
if ret && !typeEqual(tomlHash, tt) {
ret = false
}
}
return true
return ret
}

type tagOptions struct {
Expand Down Expand Up @@ -715,13 +705,25 @@ func encPanic(err error) {
panic(tomlEncodeError{err})
}

// Resolve any level of pointers to the actual value (e.g. **string → string).
func eindirect(v reflect.Value) reflect.Value {
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
return eindirect(v.Elem())
default:
if v.Kind() != reflect.Ptr && v.Kind() != reflect.Interface {
if isMarshaler(v) {
return v
}
if v.CanAddr() { /// Special case for marshalers; see #358.
if pv := v.Addr(); isMarshaler(pv) {
return pv
}
}
return v
}

if v.IsNil() {
return v
}

return eindirect(v.Elem())
}

func isNil(rv reflect.Value) bool {
Expand Down
Loading

0 comments on commit f0ccf71

Please sign in to comment.