From 71a32d6adc14d7d65d8d56944a7f3c1242316271 Mon Sep 17 00:00:00 2001 From: un000 Date: Sun, 14 May 2023 00:48:51 +0400 Subject: [PATCH] Fixed panics in Read Command Reflection API --- read_command_reflect.go | 121 +++++++++++++++++++++++++---------- read_command_reflect_test.go | 94 +++++++++++++++++++++++++++ 2 files changed, 181 insertions(+), 34 deletions(-) create mode 100644 read_command_reflect_test.go diff --git a/read_command_reflect.go b/read_command_reflect.go index 1b51a387..6beac7f0 100644 --- a/read_command_reflect.go +++ b/read_command_reflect.go @@ -146,26 +146,35 @@ func setValue(f reflect.Value, value interface{}) Error { return nil } - switch f.Kind() { + switch fieldKind := f.Kind(); fieldKind { case reflect.Int, reflect.Int64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint64, reflect.Uint8, reflect.Uint16, reflect.Uint32: v := reflect.ValueOf(value) - v = v.Convert(f.Type()) - f.Set(v) - case reflect.Float64, reflect.Float32: - if v, ok := value.(float32); ok { - value = float64(v) + t := f.Type() + if !v.CanConvert(t) { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) } - // if value has returned as a float - if fv, ok := value.(float64); ok { - f.SetFloat(fv) - } else { - // an int value has been set in the float - possibly due to a lua UDF - f.SetFloat(float64(value.(int))) + v = v.Convert(t) + f.Set(v) + case reflect.Float64, reflect.Float32: + switch v := value.(type) { + case float64: + f.SetFloat(v) + case float32: + f.SetFloat(float64(v)) + case int: + f.SetFloat(float64(v)) + default: + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) } case reflect.String: - rv := reflect.ValueOf(value.(string)) + v, ok := value.(string) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) + } + + rv := reflect.ValueOf(v) if rv.Type() != f.Type() { rv = rv.Convert(f.Type()) } @@ -177,16 +186,19 @@ func setValue(f reflect.Value, value interface{}) Error { case bool: f.SetBool(v) default: - return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for boolean field", value)) + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) } case reflect.Interface: if value != nil { f.Set(reflect.ValueOf(value)) } case reflect.Ptr: - switch f.Type().Elem().Kind() { + switch fieldKind := f.Type().Elem().Kind(); fieldKind { case reflect.String: - tempV := value.(string) + tempV, ok := value.(string) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind)) + } rv := reflect.ValueOf(&tempV) if rv.Type() != f.Type() { rv = rv.Convert(f.Type()) @@ -194,7 +206,12 @@ func setValue(f reflect.Value, value interface{}) Error { f.Set(rv) case reflect.Int, reflect.Int64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint64, reflect.Uint8, reflect.Uint16, reflect.Uint32: - v := reflect.ValueOf(value).Convert(f.Type().Elem()) + v := reflect.ValueOf(value) + t := f.Type().Elem() + if !v.CanConvert(t) { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind)) + } + v = v.Convert(t) if f.IsZero() { f.Set(reflect.New(f.Type().Elem())) } @@ -207,7 +224,11 @@ func setValue(f reflect.Value, value interface{}) Error { if fv, ok := value.(float64); ok { tempV = fv } else { - tempV = math.Float64frombits(uint64(value.(int))) + v, ok := value.(int) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind)) + } + tempV = math.Float64frombits(uint64(v)) } rv := reflect.ValueOf(&tempV) @@ -219,7 +240,7 @@ func setValue(f reflect.Value, value interface{}) Error { var tempV bool switch v := value.(type) { case int: - tempV = (v == 1) + tempV = v == 1 case bool: tempV = v default: @@ -242,7 +263,11 @@ func setValue(f reflect.Value, value interface{}) Error { if fv, ok := value.(float64); ok { tempV64 = fv } else { - tempV64 = math.Float64frombits(uint64(value.(int))) + v, ok := value.(int) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind)) + } + tempV64 = math.Float64frombits(uint64(v)) } tempV := float32(tempV64) @@ -256,12 +281,19 @@ func setValue(f reflect.Value, value interface{}) Error { case reflect.Struct: // support time.Time if f.Type().Elem().PkgPath() == "time" && f.Type().Elem().Name() == "Time" { - tm := time.Unix(0, int64(value.(int))) + v, ok := value.(int) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind)) + } + tm := time.Unix(0, int64(v)) f.Set(reflect.ValueOf(&tm)) break } - valMap := value.(map[interface{}]interface{}) - // iteraste over struct fields and recursively fill them up + valMap, ok := value.(map[interface{}]interface{}) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) + } + // iterate over struct fields and recursively fill them up if valMap != nil { newObjPtr := f if f.IsNil() { @@ -280,6 +312,9 @@ func setValue(f reflect.Value, value interface{}) Error { case reflect.Slice, reflect.Array: // BLOBs come back as []byte theArray := reflect.ValueOf(value) + if theArray.Kind() != reflect.Slice { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) + } if f.Kind() == reflect.Slice { if f.IsNil() { @@ -297,37 +332,48 @@ func setValue(f reflect.Value, value interface{}) Error { } case reflect.Map: emptyStruct := reflect.ValueOf(struct{}{}) - theMap := value.(map[interface{}]interface{}) + theMap, ok := value.(map[interface{}]interface{}) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) + } if theMap != nil { newMap := reflect.MakeMap(f.Type()) var newKey, newVal reflect.Value for key, elem := range theMap { + fKeyType := f.Type().Key() if key != nil { newKey = reflect.ValueOf(key) } else { - newKey = reflect.Zero(f.Type().Key()) + newKey = reflect.Zero(fKeyType) } - if newKey.Type() != f.Type().Key() { - newKey = newKey.Convert(f.Type().Key()) + if newKey.Type() != fKeyType { + if !newKey.CanConvert(fKeyType) { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid key `%#v` for %s field", value, fieldKind)) + } + newKey = newKey.Convert(fKeyType) } + fElemType := f.Type().Elem() if elem != nil { newVal = reflect.ValueOf(elem) } else { - newVal = reflect.Zero(f.Type().Elem()) + newVal = reflect.Zero(fElemType) } - if newVal.Type() != f.Type().Elem() { + if newVal.Type() != fElemType { switch newVal.Kind() { case reflect.Map, reflect.Slice, reflect.Array: - newVal = reflect.New(f.Type().Elem()) + newVal = reflect.New(fElemType) if err := setValue(newVal.Elem(), elem); err != nil { return err } newVal = reflect.Indirect(newVal) default: - newVal = newVal.Convert(f.Type().Elem()) + if !newVal.CanConvert(fElemType) { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) + } + newVal = newVal.Convert(fElemType) } } @@ -347,12 +393,19 @@ func setValue(f reflect.Value, value interface{}) Error { case reflect.Struct: // support time.Time if f.Type().PkgPath() == "time" && f.Type().Name() == "Time" { - f.Set(reflect.ValueOf(time.Unix(0, int64(value.(int))))) + v, ok := value.(int) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for time %s field", value, fieldKind)) + } + f.Set(reflect.ValueOf(time.Unix(0, int64(v)))) break } - valMap := value.(map[interface{}]interface{}) - // iteraste over struct fields and recursively fill them up + valMap, ok := value.(map[interface{}]interface{}) + if !ok { + return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind)) + } + // iterate over struct fields and recursively fill them up if err := setStructValue(f, valMap, f.Type(), nil); err != nil { return err } diff --git a/read_command_reflect_test.go b/read_command_reflect_test.go new file mode 100644 index 00000000..3891442d --- /dev/null +++ b/read_command_reflect_test.go @@ -0,0 +1,94 @@ +// Copyright 2014-2022 Aerospike, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aerospike + +import ( + "reflect" + + gg "github.com/onsi/ginkgo/v2" + gm "github.com/onsi/gomega" +) + +var _ = gg.Describe("Read Command Reflect setValue", func() { + type testStruct struct { + Int int + Bool bool + Int64 int64 + String string + Float64 float64 + SliceString []string + SliceInt []int + SliceFloat64 []float64 + MapStringFloat64 map[string]float64 + MapStringString map[string]string + MapInterfaceInterface map[interface{}]interface{} + } + + ts := &testStruct{} + reflectField := func(name string) reflect.Value { + return reflect.Indirect(reflect.ValueOf(ts)).FieldByName(name) + } + + tests := []struct { + name string + field reflect.Value + obj interface{} + error bool + }{ + {name: "int->int", field: reflectField("Int"), obj: 5}, + {name: "int->int64", field: reflectField("Int64"), obj: int64(5)}, + {name: "int->bool", field: reflectField("Bool"), obj: true}, + {name: "int->float64", field: reflectField("Float64"), obj: 5}, + {name: "[]string->[]string", field: reflectField("SliceString"), obj: []string{"1", "2"}}, + {name: "[]int->[]int", field: reflectField("SliceInt"), obj: []int{1, 2}}, + {name: "map[string]string->map[string]string", field: reflectField("MapStringString"), obj: map[interface{}]interface{}{"1": "2"}}, + {name: "map[string]float64->map[string]float64", field: reflectField("MapStringFloat64"), obj: map[interface{}]interface{}{"1": 2}}, + {name: "map[interface{}]interface{}->map[interface{}]interface{}", field: reflectField("MapInterfaceInterface"), obj: map[interface{}]interface{}{"1": 2}}, + + {name: "string->int", field: reflectField("Int"), obj: "5", error: true}, + {name: "string->bool", field: reflectField("Bool"), obj: "true", error: true}, + {name: "int->string", field: reflectField("String"), obj: 5, error: true}, + {name: "bool->int", field: reflectField("Int"), obj: true, error: true}, + {name: "bool->string", field: reflectField("String"), obj: true, error: true}, + {name: "int->[]string", field: reflectField("SliceString"), obj: 5, error: true}, + {name: "int->[]int", field: reflectField("SliceInt"), obj: 5, error: true}, + {name: "int->[]float64", field: reflectField("SliceFloat64"), obj: 5, error: true}, + {name: "[]string->int", field: reflectField("Int"), obj: []string{"1", "2"}, error: true}, + {name: "[]string->int64", field: reflectField("Int64"), obj: []string{"1", "2"}, error: true}, + {name: "[]string->float64", field: reflectField("Float64"), obj: []string{"1", "2"}, error: true}, + {name: "[]int->int", field: reflectField("Int"), obj: []int{1, 2}, error: true}, + {name: "[]int->int64", field: reflectField("Int64"), obj: []int{1, 2}, error: true}, + {name: "[]string->[]int", field: reflectField("SliceInt"), obj: []string{"1", "2"}, error: true}, + {name: "map[string]string->[]int", field: reflectField("SliceInt"), obj: map[interface{}]interface{}{"1": "2"}, error: true}, + {name: "[]int->map[string]string", field: reflectField("MapStringString"), obj: []int{1, 2}, error: true}, + {name: "map[string]string->map[string]float64", field: reflectField("MapStringFloat64"), obj: map[interface{}]interface{}{"1": "2"}, error: true}, + } + + for _, tt := range tests { + tc := tt + gg.Context(tc.name, func() { + gg.It("Should return correct error", func() { + gm.Expect(func() { + err := setValue(tc.field, tc.obj) + if tc.error { + gm.Expect(err).To(gm.HaveOccurred()) + return + } + gm.Expect(err).ToNot(gm.HaveOccurred()) + }).To(gm.Not(gm.Panic())) + }) + }) + } +})