Skip to content

Commit

Permalink
Merge pull request #260 from Al2Klimov/bugfix/nested-default
Browse files Browse the repository at this point in the history
Preserve defaults while decoding nested structs
  • Loading branch information
goccy authored Nov 18, 2021
2 parents 5f46a66 + c6bcb80 commit ecc53fd
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 6 deletions.
20 changes: 14 additions & 6 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,9 @@ func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type
return value
}

func (d *Decoder) createDecodedNewValue(ctx context.Context, typ reflect.Type, node ast.Node) (reflect.Value, error) {
func (d *Decoder) createDecodedNewValue(
ctx context.Context, typ reflect.Type, defaultVal reflect.Value, node ast.Node,
) (reflect.Value, error) {
if node.Type() == ast.AliasType {
aliasName := node.(*ast.AliasNode).Value.GetToken().Value
newValue := d.anchorValueMap[aliasName]
Expand All @@ -788,6 +790,12 @@ func (d *Decoder) createDecodedNewValue(ctx context.Context, typ reflect.Type, n
return reflect.Zero(typ), nil
}
newValue := d.createDecodableValue(typ)
for defaultVal.Kind() == reflect.Ptr {
defaultVal = defaultVal.Elem()
}
if defaultVal.IsValid() && defaultVal.Type().AssignableTo(newValue.Type()) {
newValue.Set(defaultVal)
}
if err := d.decodeValue(ctx, newValue, node); err != nil {
return newValue, errors.Wrapf(err, "failed to decode value")
}
Expand Down Expand Up @@ -1033,7 +1041,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
key := &ast.StringNode{BaseNode: &ast.BaseNode{}, Value: k}
mapNode.Values = append(mapNode.Values, ast.MappingValue(nil, key, v))
}
newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), mapNode)
newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), fieldValue, mapNode)
if d.disallowUnknownField {
if err := d.deleteStructKeys(fieldValue.Type(), unknownFields); err != nil {
return errors.Wrapf(err, "cannot delete struct keys")
Expand Down Expand Up @@ -1075,7 +1083,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
fieldValue.Set(reflect.Zero(fieldValue.Type()))
continue
}
newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), v)
newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), fieldValue, v)
if err != nil {
if foundErr != nil {
continue
Expand Down Expand Up @@ -1156,7 +1164,7 @@ func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.No
// set nil value to pointer
arrayValue.Index(idx).Set(reflect.Zero(elemType))
} else {
dstValue, err := d.createDecodedNewValue(ctx, elemType, v)
dstValue, err := d.createDecodedNewValue(ctx, elemType, reflect.Value{}, v)
if err != nil {
if foundErr == nil {
foundErr = err
Expand Down Expand Up @@ -1196,7 +1204,7 @@ func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.No
sliceValue = reflect.Append(sliceValue, reflect.Zero(elemType))
continue
}
dstValue, err := d.createDecodedNewValue(ctx, elemType, v)
dstValue, err := d.createDecodedNewValue(ctx, elemType, reflect.Value{}, v)
if err != nil {
if foundErr == nil {
foundErr = err
Expand Down Expand Up @@ -1338,7 +1346,7 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node
mapValue.SetMapIndex(k, reflect.Zero(valueType))
continue
}
dstValue, err := d.createDecodedNewValue(ctx, valueType, value)
dstValue, err := d.createDecodedNewValue(ctx, valueType, reflect.Value{}, value)
if err != nil {
if foundErr == nil {
foundErr = err
Expand Down
98 changes: 98 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1692,13 +1692,63 @@ func TestDecoder_DefaultValues(t *testing.T) {
A string `yaml:"a"`
B string `yaml:"b"`
c string // private
D struct {
E string `yaml:"e"`
F struct {
G string `yaml:"g"`
} `yaml:"f"`
H struct {
I string `yaml:"i"`
} `yaml:",inline"`
} `yaml:"d"`
J struct {
K string `yaml:"k"`
L struct {
M string `yaml:"m"`
} `yaml:"l"`
N struct {
O string `yaml:"o"`
} `yaml:",inline"`
} `yaml:",inline"`
P struct {
Q string `yaml:"q"`
R struct {
S string `yaml:"s"`
} `yaml:"r"`
T struct {
U string `yaml:"u"`
} `yaml:",inline"`
} `yaml:"p"`
V struct {
W string `yaml:"w"`
X struct {
Y string `yaml:"y"`
} `yaml:"x"`
Z struct {
Ä string `yaml:"ä"`
} `yaml:",inline"`
} `yaml:",inline"`
}{
B: "defaultBValue",
c: "defaultCValue",
}

v.D.E = "defaultEValue"
v.D.F.G = "defaultGValue"
v.D.H.I = "defaultIValue"
v.J.K = "defaultKValue"
v.J.L.M = "defaultMValue"
v.J.N.O = "defaultOValue"
v.P.R.S = "defaultSValue"
v.P.T.U = "defaultUValue"
v.V.X.Y = "defaultYValue"
v.V.Z.Ä = "defaultÄValue"

const src = `---
a: a_value
p:
q: q_value
w: w_value
`
if err := yaml.NewDecoder(strings.NewReader(src)).Decode(&v); err != nil {
t.Fatalf(`parsing should succeed: %s`, err)
Expand All @@ -1714,6 +1764,54 @@ a: a_value
if v.c != "defaultCValue" {
t.Fatalf("v.c should be `defaultCValue`, got `%s`", v.c)
}

if v.D.E != "defaultEValue" {
t.Fatalf("v.D.E should be `defaultEValue`, got `%s`", v.D.E)
}

if v.D.F.G != "defaultGValue" {
t.Fatalf("v.D.F.G should be `defaultGValue`, got `%s`", v.D.F.G)
}

if v.D.H.I != "defaultIValue" {
t.Fatalf("v.D.H.I should be `defaultIValue`, got `%s`", v.D.H.I)
}

if v.J.K != "defaultKValue" {
t.Fatalf("v.J.K should be `defaultKValue`, got `%s`", v.J.K)
}

if v.J.L.M != "defaultMValue" {
t.Fatalf("v.J.L.M should be `defaultMValue`, got `%s`", v.J.L.M)
}

if v.J.N.O != "defaultOValue" {
t.Fatalf("v.J.N.O should be `defaultOValue`, got `%s`", v.J.N.O)
}

if v.P.Q != "q_value" {
t.Fatalf("v.P.Q should be `q_value`, got `%s`", v.P.Q)
}

if v.P.R.S != "defaultSValue" {
t.Fatalf("v.P.R.S should be `defaultSValue`, got `%s`", v.P.R.S)
}

if v.P.T.U != "defaultUValue" {
t.Fatalf("v.P.T.U should be `defaultUValue`, got `%s`", v.P.T.U)
}

if v.V.W != "w_value" {
t.Fatalf("v.V.W should be `w_value`, got `%s`", v.V.W)
}

if v.V.X.Y != "defaultYValue" {
t.Fatalf("v.V.X.Y should be `defaultYValue`, got `%s`", v.V.X.Y)
}

if v.V.Z.Ä != "defaultÄValue" {
t.Fatalf("v.V.Z.Ä should be `defaultÄValue`, got `%s`", v.V.Z.Ä)
}
}

func Example_YAMLTags() {
Expand Down

0 comments on commit ecc53fd

Please sign in to comment.