From df0b65c5dea3e23213b37a109ff74439babc8b93 Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Thu, 14 Oct 2021 18:52:02 +0200 Subject: [PATCH 1/2] TestDecoder_DefaultValues(): also cover nested structs --- decode_test.go | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/decode_test.go b/decode_test.go index ca473d43..fb13b9af 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1692,13 +1692,63 @@ func TestDecoder_DefaultValues(t *testing.T) { A string `yaml:"a"` B string `yaml:"b"` c string // private + D struct { + E string `yaml:"e"` + F struct { + G string `yaml:"g"` + } `yaml:"f"` + H struct { + I string `yaml:"i"` + } `yaml:",inline"` + } `yaml:"d"` + J struct { + K string `yaml:"k"` + L struct { + M string `yaml:"m"` + } `yaml:"l"` + N struct { + O string `yaml:"o"` + } `yaml:",inline"` + } `yaml:",inline"` + P struct { + Q string `yaml:"q"` + R struct { + S string `yaml:"s"` + } `yaml:"r"` + T struct { + U string `yaml:"u"` + } `yaml:",inline"` + } `yaml:"p"` + V struct { + W string `yaml:"w"` + X struct { + Y string `yaml:"y"` + } `yaml:"x"` + Z struct { + Ä string `yaml:"ä"` + } `yaml:",inline"` + } `yaml:",inline"` }{ B: "defaultBValue", c: "defaultCValue", } + v.D.E = "defaultEValue" + v.D.F.G = "defaultGValue" + v.D.H.I = "defaultIValue" + v.J.K = "defaultKValue" + v.J.L.M = "defaultMValue" + v.J.N.O = "defaultOValue" + v.P.R.S = "defaultSValue" + v.P.T.U = "defaultUValue" + v.V.X.Y = "defaultYValue" + v.V.Z.Ä = "defaultÄValue" + const src = `--- a: a_value +p: + q: q_value +w: w_value ` if err := yaml.NewDecoder(strings.NewReader(src)).Decode(&v); err != nil { t.Fatalf(`parsing should succeed: %s`, err) @@ -1714,6 +1764,54 @@ a: a_value if v.c != "defaultCValue" { t.Fatalf("v.c should be `defaultCValue`, got `%s`", v.c) } + + if v.D.E != "defaultEValue" { + t.Fatalf("v.D.E should be `defaultEValue`, got `%s`", v.D.E) + } + + if v.D.F.G != "defaultGValue" { + t.Fatalf("v.D.F.G should be `defaultGValue`, got `%s`", v.D.F.G) + } + + if v.D.H.I != "defaultIValue" { + t.Fatalf("v.D.H.I should be `defaultIValue`, got `%s`", v.D.H.I) + } + + if v.J.K != "defaultKValue" { + t.Fatalf("v.J.K should be `defaultKValue`, got `%s`", v.J.K) + } + + if v.J.L.M != "defaultMValue" { + t.Fatalf("v.J.L.M should be `defaultMValue`, got `%s`", v.J.L.M) + } + + if v.J.N.O != "defaultOValue" { + t.Fatalf("v.J.N.O should be `defaultOValue`, got `%s`", v.J.N.O) + } + + if v.P.Q != "q_value" { + t.Fatalf("v.P.Q should be `q_value`, got `%s`", v.P.Q) + } + + if v.P.R.S != "defaultSValue" { + t.Fatalf("v.P.R.S should be `defaultSValue`, got `%s`", v.P.R.S) + } + + if v.P.T.U != "defaultUValue" { + t.Fatalf("v.P.T.U should be `defaultUValue`, got `%s`", v.P.T.U) + } + + if v.V.W != "w_value" { + t.Fatalf("v.V.W should be `w_value`, got `%s`", v.V.W) + } + + if v.V.X.Y != "defaultYValue" { + t.Fatalf("v.V.X.Y should be `defaultYValue`, got `%s`", v.V.X.Y) + } + + if v.V.Z.Ä != "defaultÄValue" { + t.Fatalf("v.V.Z.Ä should be `defaultÄValue`, got `%s`", v.V.Z.Ä) + } } func Example_YAMLTags() { From c6bcb80455359c483de9b53eb0316bbd07748f1b Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Fri, 15 Oct 2021 15:11:46 +0200 Subject: [PATCH 2/2] Preserve defaults while decoding nested structs --- decode.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/decode.go b/decode.go index a313541d..45fb6e65 100644 --- a/decode.go +++ b/decode.go @@ -776,7 +776,9 @@ func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type return value } -func (d *Decoder) createDecodedNewValue(ctx context.Context, typ reflect.Type, node ast.Node) (reflect.Value, error) { +func (d *Decoder) createDecodedNewValue( + ctx context.Context, typ reflect.Type, defaultVal reflect.Value, node ast.Node, +) (reflect.Value, error) { if node.Type() == ast.AliasType { aliasName := node.(*ast.AliasNode).Value.GetToken().Value newValue := d.anchorValueMap[aliasName] @@ -788,6 +790,12 @@ func (d *Decoder) createDecodedNewValue(ctx context.Context, typ reflect.Type, n return reflect.Zero(typ), nil } newValue := d.createDecodableValue(typ) + for defaultVal.Kind() == reflect.Ptr { + defaultVal = defaultVal.Elem() + } + if defaultVal.IsValid() && defaultVal.Type().AssignableTo(newValue.Type()) { + newValue.Set(defaultVal) + } if err := d.decodeValue(ctx, newValue, node); err != nil { return newValue, errors.Wrapf(err, "failed to decode value") } @@ -1033,7 +1041,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N key := &ast.StringNode{BaseNode: &ast.BaseNode{}, Value: k} mapNode.Values = append(mapNode.Values, ast.MappingValue(nil, key, v)) } - newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), mapNode) + newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), fieldValue, mapNode) if d.disallowUnknownField { if err := d.deleteStructKeys(fieldValue.Type(), unknownFields); err != nil { return errors.Wrapf(err, "cannot delete struct keys") @@ -1075,7 +1083,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N fieldValue.Set(reflect.Zero(fieldValue.Type())) continue } - newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), v) + newFieldValue, err := d.createDecodedNewValue(ctx, fieldValue.Type(), fieldValue, v) if err != nil { if foundErr != nil { continue @@ -1156,7 +1164,7 @@ func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.No // set nil value to pointer arrayValue.Index(idx).Set(reflect.Zero(elemType)) } else { - dstValue, err := d.createDecodedNewValue(ctx, elemType, v) + dstValue, err := d.createDecodedNewValue(ctx, elemType, reflect.Value{}, v) if err != nil { if foundErr == nil { foundErr = err @@ -1196,7 +1204,7 @@ func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.No sliceValue = reflect.Append(sliceValue, reflect.Zero(elemType)) continue } - dstValue, err := d.createDecodedNewValue(ctx, elemType, v) + dstValue, err := d.createDecodedNewValue(ctx, elemType, reflect.Value{}, v) if err != nil { if foundErr == nil { foundErr = err @@ -1338,7 +1346,7 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node mapValue.SetMapIndex(k, reflect.Zero(valueType)) continue } - dstValue, err := d.createDecodedNewValue(ctx, valueType, value) + dstValue, err := d.createDecodedNewValue(ctx, valueType, reflect.Value{}, value) if err != nil { if foundErr == nil { foundErr = err