Skip to content

Commit

Permalink
Decode: fix reuse of slice for array tables (#934)
Browse files Browse the repository at this point in the history
When decoding into a non-empty slice, it needs to be emptied so that only the
tables contained in the document are present in the resulting value.

Arrays are not impacted because their unmarshal offset is tracked separately.

Fixes #931
  • Loading branch information
pelletier authored Feb 27, 2024
1 parent 2e087bd commit 06fb30b
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 31 deletions.
62 changes: 33 additions & 29 deletions internal/tracker/seen.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ func (s *SeenTracker) setExplicitFlag(parentIdx int) {

// CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are
// consistent.
func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
// consistent. It returns true if it is the first time this node's key is seen.
// Useful to clear array tables on first use.
func (s *SeenTracker) CheckExpression(node *unstable.Node) (bool, error) {
if s.entries == nil {
s.reset()
}
Expand All @@ -166,7 +167,7 @@ func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
}
}

func (s *SeenTracker) checkTable(node *unstable.Node) error {
func (s *SeenTracker) checkTable(node *unstable.Node) (bool, error) {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
Expand All @@ -192,7 +193,7 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
} else {
entry := s.entries[idx]
if entry.kind == valueKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
}
}
parentIdx = idx
Expand All @@ -201,25 +202,27 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
k := it.Node().Data
idx := s.find(parentIdx, k)

first := false
if idx >= 0 {
kind := s.entries[idx].kind
if kind != tableKind {
return fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
return false, fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
}
if s.entries[idx].explicit {
return fmt.Errorf("toml: table %s already exists", string(k))
return false, fmt.Errorf("toml: table %s already exists", string(k))
}
s.entries[idx].explicit = true
} else {
idx = s.create(parentIdx, k, tableKind, true, false)
first = true
}

s.currentIdx = idx

return nil
return first, nil
}

func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
func (s *SeenTracker) checkArrayTable(node *unstable.Node) (bool, error) {
if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx)
}
Expand All @@ -242,7 +245,7 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
} else {
entry := s.entries[idx]
if entry.kind == valueKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
}
}

Expand All @@ -252,22 +255,23 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
k := it.Node().Data
idx := s.find(parentIdx, k)

if idx >= 0 {
firstTime := idx < 0
if firstTime {
idx = s.create(parentIdx, k, arrayTableKind, true, false)
} else {
kind := s.entries[idx].kind
if kind != arrayTableKind {
return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
return false, fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
}
s.clear(idx)
} else {
idx = s.create(parentIdx, k, arrayTableKind, true, false)
}

s.currentIdx = idx

return nil
return firstTime, nil
}

func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
func (s *SeenTracker) checkKeyValue(node *unstable.Node) (bool, error) {
parentIdx := s.currentIdx
it := node.Key()

Expand All @@ -281,11 +285,11 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
} else {
entry := s.entries[idx]
if it.IsLast() {
return fmt.Errorf("toml: key %s is already defined", string(k))
return false, fmt.Errorf("toml: key %s is already defined", string(k))
} else if entry.kind != tableKind {
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
} else if entry.explicit {
return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
return false, fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
}
}

Expand All @@ -303,30 +307,30 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
return s.checkArray(value)
}

return nil
return false, nil
}

func (s *SeenTracker) checkArray(node *unstable.Node) error {
func (s *SeenTracker) checkArray(node *unstable.Node) (first bool, err error) {
it := node.Children()
for it.Next() {
n := it.Node()
switch n.Kind {
case unstable.InlineTable:
err := s.checkInlineTable(n)
first, err = s.checkInlineTable(n)
if err != nil {
return err
return false, err
}
case unstable.Array:
err := s.checkArray(n)
first, err = s.checkArray(n)
if err != nil {
return err
return false, err
}
}
}
return nil
return first, nil
}

func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
func (s *SeenTracker) checkInlineTable(node *unstable.Node) (first bool, err error) {
if pool.New == nil {
pool.New = func() interface{} {
return &SeenTracker{}
Expand All @@ -339,9 +343,9 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
it := node.Children()
for it.Next() {
n := it.Node()
err := s.checkKeyValue(n)
first, err = s.checkKeyValue(n)
if err != nil {
return err
return false, err
}
}

Expand All @@ -352,5 +356,5 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
// redefinition of its keys: check* functions cannot walk into
// a value.
pool.Put(s)
return nil
return first, nil
}
18 changes: 16 additions & 2 deletions unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ type decoder struct {
// need to be skipped.
skipUntilTable bool

// Flag indicating that the current array/slice table should be cleared because
// it is the first encounter of an array table.
clearArrayTable bool

// Tracks position in Go arrays.
// This is used when decoding [[array tables]] into Go arrays. Given array
// tables are separate TOML expression, we need to keep track of where we
Expand Down Expand Up @@ -246,9 +250,10 @@ Rules for the unmarshal code:
func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
var x reflect.Value
var err error
var first bool // used for to clear array tables on first use

if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
err = d.seen.CheckExpression(expr)
first, err = d.seen.CheckExpression(expr)
if err != nil {
return err
}
Expand All @@ -267,6 +272,7 @@ func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) err
case unstable.ArrayTable:
d.skipUntilTable = false
d.strict.EnterArrayTable(expr)
d.clearArrayTable = first
x, err = d.handleArrayTable(expr.Key(), v)
default:
panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind))
Expand Down Expand Up @@ -307,6 +313,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
reflect.Copy(nelem, elem)
elem = nelem
}
if d.clearArrayTable && elem.Len() > 0 {
elem.SetLen(0)
d.clearArrayTable = false
}
}
return d.handleArrayTableCollectionLast(key, elem)
case reflect.Ptr:
Expand All @@ -325,6 +335,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec

return v, nil
case reflect.Slice:
if d.clearArrayTable && v.Len() > 0 {
v.SetLen(0)
d.clearArrayTable = false
}
elemType := v.Type().Elem()
var elem reflect.Value
if elemType.Kind() == reflect.Interface {
Expand Down Expand Up @@ -576,7 +590,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
break
}

err := d.seen.CheckExpression(expr)
_, err := d.seen.CheckExpression(expr)
if err != nil {
return reflect.Value{}, err
}
Expand Down
70 changes: 70 additions & 0 deletions unmarshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2823,6 +2823,76 @@ blah.a = "def"`)
require.Equal(t, "def", cfg.A)
}

func TestIssue931(t *testing.T) {
type item struct {
Name string
}

type items struct {
Slice []item
}

its := items{[]item{{"a"}, {"b"}}}

b := []byte(`
[[Slice]]
Name = 'c'
[[Slice]]
Name = 'd'
`)

toml.Unmarshal(b, &its)
require.Equal(t, items{[]item{{"c"}, {"d"}}}, its)
}

func TestIssue931Interface(t *testing.T) {
type items struct {
Slice interface{}
}

type item = map[string]interface{}

its := items{[]interface{}{item{"Name": "a"}, item{"Name": "b"}}}

b := []byte(`
[[Slice]]
Name = 'c'
[[Slice]]
Name = 'd'
`)

toml.Unmarshal(b, &its)
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
}

func TestIssue931SliceInterface(t *testing.T) {
type items struct {
Slice []interface{}
}

type item = map[string]interface{}

its := items{
[]interface{}{
item{"Name": "a"},
item{"Name": "b"},
},
}

b := []byte(`
[[Slice]]
Name = 'c'
[[Slice]]
Name = 'd'
`)

toml.Unmarshal(b, &its)
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
}

func TestUnmarshalDecodeErrors(t *testing.T) {
examples := []struct {
desc string
Expand Down

0 comments on commit 06fb30b

Please sign in to comment.