Skip to content

Commit

Permalink
Fix cast process for decoding of anchor value (#602)
Browse files Browse the repository at this point in the history
* add test case

* fix cast process for decoding of anchor value
  • Loading branch information
goccy authored Dec 21, 2024
1 parent 54cd51f commit aeed806
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 36 deletions.
68 changes: 44 additions & 24 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,11 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No
if err := d.decodeValue(ctx, v, src); err != nil {
return err
}
dst.Set(d.castToAssignableValue(v, dst.Type()))
castedValue, err := d.castToAssignableValue(v, dst.Type(), src)
if err != nil {
return err
}
dst.Set(castedValue)
case reflect.Interface:
if dst.Type() == astNodeType {
dst.Set(reflect.ValueOf(src))
Expand Down Expand Up @@ -1121,33 +1125,43 @@ func (d *Decoder) createDecodableValue(typ reflect.Type) reflect.Value {
return reflect.New(typ).Elem()
}

func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type) reflect.Value {
func (d *Decoder) castToAssignableValue(value reflect.Value, target reflect.Type, src ast.Node) (reflect.Value, error) {
if target.Kind() != reflect.Ptr {
return value
}
maxTryCount := 5
tryCount := 0
for {
if tryCount > maxTryCount {
return value
if !value.Type().AssignableTo(target) {
return reflect.Value{}, errors.ErrTypeMismatch(target, value.Type(), src.GetToken())
}
return value, nil
}

const maxAddrCount = 5

for i := 0; i < maxAddrCount; i++ {
if value.Type().AssignableTo(target) {
break
}
value = value.Addr()
tryCount++
}
return value
if !value.Type().AssignableTo(target) {
return reflect.Value{}, errors.ErrTypeMismatch(target, value.Type(), src.GetToken())
}
return value, nil
}

func (d *Decoder) createDecodedNewValue(
ctx context.Context, typ reflect.Type, defaultVal reflect.Value, node ast.Node,
) (reflect.Value, error) {
if node.Type() == ast.AliasType {
aliasName := node.(*ast.AliasNode).Value.GetToken().Value
newValue := d.anchorValueMap[aliasName]
if newValue.IsValid() {
return newValue, nil
value := d.anchorValueMap[aliasName]
if value.IsValid() {
v, err := d.castToAssignableValue(value, typ, node)
if err == nil {
return v, nil
}
}
anchor, exists := d.anchorNodeMap[aliasName]
if exists {
node = anchor
}
}
var newValue reflect.Value
Expand All @@ -1164,10 +1178,10 @@ func (d *Decoder) createDecodedNewValue(
}
if node.Type() != ast.NullType {
if err := d.decodeValue(ctx, newValue, node); err != nil {
return newValue, err
return reflect.Value{}, err
}
}
return newValue, nil
return d.castToAssignableValue(newValue, typ, node)
}

func (d *Decoder) keyToNodeMap(node ast.Node, ignoreMergeKey bool, getKeyOrValueNode func(*ast.MapNodeIter) ast.Node) (map[string]ast.Node, error) {
Expand Down Expand Up @@ -1238,6 +1252,9 @@ func (d *Decoder) keyToValueNodeMap(node ast.Node, ignoreMergeKey bool) (map[str
}

func (d *Decoder) setDefaultValueIfConflicted(v reflect.Value, fieldMap StructFieldMap) error {
for v.Type().Kind() == reflect.Ptr {
v = v.Elem()
}
typ := v.Type()
if typ.Kind() != reflect.Struct {
return nil
Expand Down Expand Up @@ -1413,7 +1430,11 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
if aliasName != "" {
newFieldValue := d.anchorValueMap[aliasName]
if newFieldValue.IsValid() {
fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type()))
value, err := d.castToAssignableValue(newFieldValue, fieldValue.Type(), d.anchorNodeMap[aliasName])
if err != nil {
return err
}
fieldValue.Set(value)
}
}
continue
Expand Down Expand Up @@ -1459,7 +1480,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
continue
}
_ = d.setDefaultValueIfConflicted(newFieldValue, structFieldMap)
fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type()))
fieldValue.Set(newFieldValue)
continue
}
v, exists := keyToNodeMap[structField.RenderName]
Expand Down Expand Up @@ -1488,7 +1509,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N
}
continue
}
fieldValue.Set(d.castToAssignableValue(newFieldValue, fieldValue.Type()))
fieldValue.Set(newFieldValue)
}
if foundErr != nil {
return foundErr
Expand Down Expand Up @@ -1566,9 +1587,8 @@ func (d *Decoder) decodeArray(ctx context.Context, dst reflect.Value, src ast.No
foundErr = err
}
continue
} else {
arrayValue.Index(idx).Set(d.castToAssignableValue(dstValue, elemType))
}
arrayValue.Index(idx).Set(dstValue)
}
idx++
}
Expand Down Expand Up @@ -1613,7 +1633,7 @@ func (d *Decoder) decodeSlice(ctx context.Context, dst reflect.Value, src ast.No
}
continue
}
sliceValue = reflect.Append(sliceValue, d.castToAssignableValue(dstValue, elemType))
sliceValue = reflect.Append(sliceValue, dstValue)
}
dst.Set(sliceValue)
if foundErr != nil {
Expand Down Expand Up @@ -1796,7 +1816,7 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node
}
if !k.IsValid() {
// expect nil key
mapValue.SetMapIndex(d.createDecodableValue(keyType), d.castToAssignableValue(dstValue, valueType))
mapValue.SetMapIndex(d.createDecodableValue(keyType), dstValue)
continue
}
if keyType.Kind() != k.Kind() {
Expand All @@ -1805,7 +1825,7 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node
key.GetToken(),
)
}
mapValue.SetMapIndex(k, d.castToAssignableValue(dstValue, valueType))
mapValue.SetMapIndex(k, dstValue)
}
dst.Set(mapValue)
if foundErr != nil {
Expand Down
40 changes: 28 additions & 12 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1321,12 +1321,6 @@ func TestDecoder_TypeConversionError(t *testing.T) {
if !strings.Contains(err.Error(), msg) {
t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg)
}
if len(v) == 0 || len(v["v"]) == 0 {
t.Fatal("failed to decode value")
}
if v["v"][0] != 1 {
t.Fatal("failed to decode value")
}
})
t.Run("string to int", func(t *testing.T) {
var v map[string][]int
Expand All @@ -1338,12 +1332,6 @@ func TestDecoder_TypeConversionError(t *testing.T) {
if !strings.Contains(err.Error(), msg) {
t.Fatalf("expected error message: %s to contain: %s", err.Error(), msg)
}
if len(v) == 0 || len(v["v"]) == 0 {
t.Fatal("failed to decode value")
}
if v["v"][0] != 1 {
t.Fatal("failed to decode value")
}
})
})
t.Run("overflow error", func(t *testing.T) {
Expand Down Expand Up @@ -2739,6 +2727,34 @@ func (u *unmarshalList) UnmarshalYAML(b []byte) error {
return nil
}

func TestDecoder_DecodeWithAnchorAnyValue(t *testing.T) {
type Config struct {
Env []string `json:"env"`
}

type Schema struct {
Def map[string]any `json:"def"`
Config Config `json:"config"`
}

data := `
def:
myenv: &my_env
- VAR1=1
- VAR2=2
config:
env: *my_env
`

var cfg Schema
if err := yaml.NewDecoder(strings.NewReader(data)).Decode(&cfg); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(cfg.Config.Env, []string{"VAR1=1", "VAR2=2"}) {
t.Fatalf("failed to decode value. actual = %+v", cfg)
}
}

func TestDecoder_UnmarshalBytesWithSeparatedList(t *testing.T) {
yml := `
a:
Expand Down

0 comments on commit aeed806

Please sign in to comment.