From 6199427fa5ff1a4ae5292c4e86e4f25475969477 Mon Sep 17 00:00:00 2001 From: Inteon <42113979+inteon@users.noreply.github.com> Date: Thu, 4 Nov 2021 16:59:42 +0100 Subject: [PATCH] update fields.go based on upstream changes Signed-off-by: Inteon <42113979+inteon@users.noreply.github.com> --- fields.go | 279 +++++++++++++----------------------------------------- yaml.go | 20 +--- 2 files changed, 69 insertions(+), 230 deletions(-) diff --git a/fields.go b/fields.go index ba3dd4d..14d0759 100644 --- a/fields.go +++ b/fields.go @@ -5,7 +5,6 @@ package yaml import ( - "bytes" "encoding" "encoding/json" "reflect" @@ -13,26 +12,25 @@ import ( "strings" "sync" "unicode" - "unicode/utf8" ) // indirect walks down 'value' allocating pointers as needed, // until it gets to a non-pointer. -// if it encounters an Unmarshaler, indirect stops and returns that. -// if decodingNull is true, indirect stops at the last pointer so it can be set to nil. -func indirect(value reflect.Value, decodingNull bool) (json.Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { +// If it encounters an Unmarshaler, indirect stops and returns nil. +func indirect(value reflect.Value) *reflect.Value { // If 'value' is a named type and is addressable, // start with its address, so that if the type has pointer methods, // we find them. if value.Kind() != reflect.Ptr && value.Type().Name() != "" && value.CanAddr() { value = value.Addr() } + for { // Load value from interface, but only if the result will be // usefully addressable. if value.Kind() == reflect.Interface && !value.IsNil() { element := value.Elem() - if element.Kind() == reflect.Ptr && !element.IsNil() && (!decodingNull || element.Elem().Kind() == reflect.Ptr) { + if element.Kind() == reflect.Ptr && !element.IsNil() { value = element continue } @@ -42,30 +40,38 @@ func indirect(value reflect.Value, decodingNull bool) (json.Unmarshaler, encodin break } - if value.Elem().Kind() != reflect.Ptr && decodingNull && value.CanSet() { + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if value.Elem().Kind() == reflect.Interface && value.Elem().Elem() == value { + value = value.Elem() break } + if value.IsNil() { value = reflect.New(value.Type().Elem()) } - if value.Type().NumMethod() > 0 { - if u, ok := value.Interface().(json.Unmarshaler); ok { - return u, nil, reflect.Value{} + + // We have a JSON or Text Umarshaler at this level, so we can't be trying + // to decode into a string. + if value.Type().NumMethod() > 0 && value.CanInterface() { + if _, ok := value.Interface().(json.Unmarshaler); ok { + return nil } - if u, ok := value.Interface().(encoding.TextUnmarshaler); ok { - return nil, u, reflect.Value{} + if _, ok := value.Interface().(encoding.TextUnmarshaler); ok { + return nil } } + value = value.Elem() } - return nil, nil, value + + return &value } // A field represents a single field found in a struct. type field struct { - name string - nameBytes []byte // []byte(name) - equalFold func(s, t []byte) bool // bytes.EqualFold or equivalent + name string tag bool index []int @@ -74,12 +80,6 @@ type field struct { quoted bool } -func fillField(f field) field { - f.nameBytes = []byte(f.name) - f.equalFold = foldFunc(f.nameBytes) - return f -} - // byName sorts field by name, breaking ties with depth, // then breaking ties with "name came from json tag", then // breaking ties with index sequence. @@ -130,8 +130,7 @@ func typeFields(t reflect.Type) []field { next := []field{{typ: t}} // Count of queued names for current level and the next. - var count map[reflect.Type]int - var nextCount map[reflect.Type]int + var count, nextCount map[reflect.Type]int // Types already visited at an earlier level. visited := map[reflect.Type]bool{} @@ -152,7 +151,19 @@ func typeFields(t reflect.Type) []field { // Scan f.typ for fields to include. for i := 0; i < f.typ.NumField(); i++ { sf := f.typ.Field(i) - if sf.PkgPath != "" { // unexported + if sf.Anonymous { + t := sf.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if !sf.IsExported() && t.Kind() != reflect.Struct { + // Ignore embedded fields of unexported non-struct types. + continue + } + // Do not ignore embedded fields of unexported struct types + // since they may have exported fields. + } else if !sf.IsExported() { + // Ignore unexported non-embedded fields. continue } tag := sf.Tag.Get("json") @@ -173,20 +184,34 @@ func typeFields(t reflect.Type) []field { ft = ft.Elem() } + // Only strings, floats, integers, and booleans can be quoted. + quoted := false + if opts.Contains("string") { + switch ft.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, + reflect.String: + quoted = true + } + } + // Record found field and index sequence. if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { tagged := name != "" if name == "" { name = sf.Name } - fields = append(fields, fillField(field{ + fields = append(fields, field{ name: name, tag: tagged, index: index, typ: ft, omitEmpty: opts.Contains("omitempty"), - quoted: opts.Contains("string"), - })) + quoted: quoted, + }) + if count[f.typ] > 1 { // If there were multiple instances, add a second, // so that the annihilation code will see a duplicate. @@ -200,7 +225,7 @@ func typeFields(t reflect.Type) []field { // Record new anonymous struct to explore in next round. nextCount[ft]++ if nextCount[ft] == 1 { - next = append(next, fillField(field{name: ft.Name(), index: index, typ: ft})) + next = append(next, field{name: ft.Name(), index: index, typ: ft}) } } } @@ -249,65 +274,24 @@ func typeFields(t reflect.Type) []field { // will be false: This condition is an error in Go and we skip all // the fields. func dominantField(fields []field) (field, bool) { - // The fields are sorted in increasing index-length order. The winner - // must therefore be one with the shortest index length. Drop all - // longer entries, which is easy: just truncate the slice. - length := len(fields[0].index) - tagged := -1 // Index of first tagged field. - for i, f := range fields { - if len(f.index) > length { - fields = fields[:i] - break - } - if f.tag { - if tagged >= 0 { - // Multiple tagged fields at the same level: conflict. - // Return no field. - return field{}, false - } - tagged = i - } - } - if tagged >= 0 { - return fields[tagged], true - } - // All remaining fields have the same length. If there's more than one, - // we have a conflict (two fields named "X" at the same level) and we - // return no field. - if len(fields) > 1 { + // The fields are sorted in increasing index-length order, then by presence of tag. + // That means that the first field is the dominant one. We need only check + // for error cases: two fields at top level, either both tagged or neither tagged. + if len(fields) > 1 && len(fields[0].index) == len(fields[1].index) && fields[0].tag == fields[1].tag { return field{}, false } return fields[0], true } -var fieldCache struct { - sync.RWMutex - m map[reflect.Type][]field -} +var fieldCache sync.Map // map[reflect.Type]structFields // cachedTypeFields is like typeFields but uses a cache to avoid repeated work. func cachedTypeFields(t reflect.Type) []field { - fieldCache.RLock() - f := fieldCache.m[t] - fieldCache.RUnlock() - if f != nil { - return f - } - - // Compute fields without lock. - // Might duplicate effort but won't hold other computations back. - f = typeFields(t) - if f == nil { - f = []field{} + if f, ok := fieldCache.Load(t); ok { + return f.([]field) } - - fieldCache.Lock() - if fieldCache.m == nil { - fieldCache.m = map[reflect.Type][]field{} - } - fieldCache.m[t] = f - fieldCache.Unlock() - return f + f, _ := fieldCache.LoadOrStore(t, typeFields(t)) + return f.([]field) } func isValidTag(s string) bool { @@ -316,144 +300,11 @@ func isValidTag(s string) bool { } for _, c := range s { switch { - case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~ ", c): + case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c): // Backslash and quote chars are reserved, but // otherwise any punctuation chars are allowed // in a tag name. - default: - if !unicode.IsLetter(c) && !unicode.IsDigit(c) { - return false - } - } - } - return true -} - -const ( - caseMask = ^byte(0x20) // Mask to ignore case in ASCII. - kelvin = '\u212a' - smallLongEss = '\u017f' -) - -// foldFunc returns one of four different case folding equivalence -// functions, from most general (and slow) to fastest: -// -// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8 -// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S') -// 3) asciiEqualFold, no special, but includes non-letters (including _) -// 4) simpleLetterEqualFold, no specials, no non-letters. -// -// The letters S and K are special because they map to 3 runes, not just 2: -// * S maps to s and to U+017F 'ſ' Latin small letter long s -// * k maps to K and to U+212A 'K' Kelvin sign -// See http://play.golang.org/p/tTxjOc0OGo -// -// The returned function is specialized for matching against s and -// should only be given s. It's not curried for performance reasons. -func foldFunc(s []byte) func(s, t []byte) bool { - nonLetter := false - special := false // special letter - for _, b := range s { - if b >= utf8.RuneSelf { - return bytes.EqualFold - } - upper := b & caseMask - if upper < 'A' || upper > 'Z' { - nonLetter = true - } else if upper == 'K' || upper == 'S' { - // See above for why these letters are special. - special = true - } - } - if special { - return equalFoldRight - } - if nonLetter { - return asciiEqualFold - } - return simpleLetterEqualFold -} - -// equalFoldRight is a specialization of bytes.EqualFold when s is -// known to be all ASCII (including punctuation), but contains an 's', -// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t. -// See comments on foldFunc. -func equalFoldRight(s, t []byte) bool { - for _, sb := range s { - if len(t) == 0 { - return false - } - tb := t[0] - if tb < utf8.RuneSelf { - if sb != tb { - sbUpper := sb & caseMask - if 'A' <= sbUpper && sbUpper <= 'Z' { - if sbUpper != tb&caseMask { - return false - } - } else { - return false - } - } - t = t[1:] - continue - } - // sb is ASCII and t is not. t must be either kelvin - // sign or long s; sb must be s, S, k, or K. - tr, size := utf8.DecodeRune(t) - switch sb { - case 's', 'S': - if tr != smallLongEss { - return false - } - case 'k', 'K': - if tr != kelvin { - return false - } - default: - return false - } - t = t[size:] - - } - - return len(t) <= 0 -} - -// asciiEqualFold is a specialization of bytes.EqualFold for use when -// s is all ASCII (but may contain non-letters) and contains no -// special-folding letters. -// See comments on foldFunc. -func asciiEqualFold(s, t []byte) bool { - if len(s) != len(t) { - return false - } - for i, sb := range s { - tb := t[i] - if sb == tb { - continue - } - if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') { - if sb&caseMask != tb&caseMask { - return false - } - } else { - return false - } - } - return true -} - -// simpleLetterEqualFold is a specialization of bytes.EqualFold for -// use when s is all ASCII letters (no underscores, etc) and also -// doesn't contain 'k', 'K', 's', or 'S'. -// See comments on foldFunc. -func simpleLetterEqualFold(s, t []byte) bool { - if len(s) != len(t) { - return false - } - for i, b := range s { - if b&caseMask != t[i]&caseMask { + case !unicode.IsLetter(c) && !unicode.IsDigit(c): return false } } diff --git a/yaml.go b/yaml.go index 668071e..939709f 100644 --- a/yaml.go +++ b/yaml.go @@ -188,14 +188,7 @@ func convertToJSONableObject(yamlObj interface{}, jsonTarget *reflect.Value) (in // decoding into the value, we're just checking if the ultimate target is a // string. if jsonTarget != nil { - jsonUnmarshaler, textUnmarshaler, pointerValue := indirect(*jsonTarget, false) - // We have a JSON or Text Umarshaler at this level, so we can't be trying - // to decode into a string. - if jsonUnmarshaler != nil || textUnmarshaler != nil { - jsonTarget = nil - } else { - jsonTarget = &pointerValue - } + jsonTarget = indirect(*jsonTarget) } // Transform map[string]interface{} into map[interface{}]interface{} @@ -267,21 +260,16 @@ func convertToJSONableObject(yamlObj interface{}, jsonTarget *reflect.Value) (in if jsonTarget != nil { t := *jsonTarget if t.Kind() == reflect.Struct { - keyBytes := []byte(keyString) // Find the field that the JSON library would use. var f *field fields := cachedTypeFields(t.Type()) for i := range fields { - ff := &fields[i] - if bytes.Equal(ff.nameBytes, keyBytes) { - f = ff + f = &fields[i] + if f.name == keyString { break } - // Do case-insensitive comparison. - if f == nil && ff.equalFold(ff.nameBytes, keyBytes) { - f = ff - } } + if f != nil { // Find the reflect.Value of the most preferential // struct field.