Skip to content

Commit

Permalink
Fixed panics in Read Command Reflection API
Browse files Browse the repository at this point in the history
  • Loading branch information
un000 committed May 13, 2023
1 parent f9dfde8 commit e7e19d6
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 34 deletions.
121 changes: 87 additions & 34 deletions read_command_reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,26 +146,35 @@ func setValue(f reflect.Value, value interface{}) Error {
return nil
}

switch f.Kind() {
switch fieldKind := f.Kind(); fieldKind {
case reflect.Int, reflect.Int64, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint64, reflect.Uint8, reflect.Uint16, reflect.Uint32:
v := reflect.ValueOf(value)
v = v.Convert(f.Type())
f.Set(v)
case reflect.Float64, reflect.Float32:
if v, ok := value.(float32); ok {
value = float64(v)
t := f.Type()
if !v.CanConvert(t) {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}

// if value has returned as a float
if fv, ok := value.(float64); ok {
f.SetFloat(fv)
} else {
// an int value has been set in the float - possibly due to a lua UDF
f.SetFloat(float64(value.(int)))
v = v.Convert(t)
f.Set(v)
case reflect.Float64, reflect.Float32:
switch v := value.(type) {
case float64:
f.SetFloat(v)
case float32:
f.SetFloat(float64(v))
case int:
f.SetFloat(float64(v))
default:
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}
case reflect.String:
rv := reflect.ValueOf(value.(string))
v, ok := value.(string)
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}

rv := reflect.ValueOf(v)
if rv.Type() != f.Type() {
rv = rv.Convert(f.Type())
}
Expand All @@ -177,24 +186,32 @@ func setValue(f reflect.Value, value interface{}) Error {
case bool:
f.SetBool(v)
default:
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for boolean field", value))
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}
case reflect.Interface:
if value != nil {
f.Set(reflect.ValueOf(value))
}
case reflect.Ptr:
switch f.Type().Elem().Kind() {
switch fieldKind := f.Type().Elem().Kind(); fieldKind {
case reflect.String:
tempV := value.(string)
tempV, ok := value.(string)
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind))
}
rv := reflect.ValueOf(&tempV)
if rv.Type() != f.Type() {
rv = rv.Convert(f.Type())
}
f.Set(rv)
case reflect.Int, reflect.Int64, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Uint, reflect.Uint64, reflect.Uint8, reflect.Uint16, reflect.Uint32:
v := reflect.ValueOf(value).Convert(f.Type().Elem())
v := reflect.ValueOf(value)
t := f.Type().Elem()
if !v.CanConvert(t) {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind))
}
v = v.Convert(t)
if f.IsZero() {
f.Set(reflect.New(f.Type().Elem()))
}
Expand All @@ -207,7 +224,11 @@ func setValue(f reflect.Value, value interface{}) Error {
if fv, ok := value.(float64); ok {
tempV = fv
} else {
tempV = math.Float64frombits(uint64(value.(int)))
v, ok := value.(int)
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind))
}
tempV = math.Float64frombits(uint64(v))
}

rv := reflect.ValueOf(&tempV)
Expand All @@ -219,7 +240,7 @@ func setValue(f reflect.Value, value interface{}) Error {
var tempV bool
switch v := value.(type) {
case int:
tempV = (v == 1)
tempV = v == 1
case bool:
tempV = v
default:
Expand All @@ -242,7 +263,11 @@ func setValue(f reflect.Value, value interface{}) Error {
if fv, ok := value.(float64); ok {
tempV64 = fv
} else {
tempV64 = math.Float64frombits(uint64(value.(int)))
v, ok := value.(int)
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind))
}
tempV64 = math.Float64frombits(uint64(v))
}

tempV := float32(tempV64)
Expand All @@ -256,12 +281,19 @@ func setValue(f reflect.Value, value interface{}) Error {
case reflect.Struct:
// support time.Time
if f.Type().Elem().PkgPath() == "time" && f.Type().Elem().Name() == "Time" {
tm := time.Unix(0, int64(value.(int)))
v, ok := value.(int)
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for *%s field", value, fieldKind))
}
tm := time.Unix(0, int64(v))
f.Set(reflect.ValueOf(&tm))
break
}
valMap := value.(map[interface{}]interface{})
// iteraste over struct fields and recursively fill them up
valMap, ok := value.(map[interface{}]interface{})
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}
// iterate over struct fields and recursively fill them up
if valMap != nil {
newObjPtr := f
if f.IsNil() {
Expand All @@ -280,6 +312,9 @@ func setValue(f reflect.Value, value interface{}) Error {
case reflect.Slice, reflect.Array:
// BLOBs come back as []byte
theArray := reflect.ValueOf(value)
if theArray.Kind() != reflect.Slice {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}

if f.Kind() == reflect.Slice {
if f.IsNil() {
Expand All @@ -297,37 +332,48 @@ func setValue(f reflect.Value, value interface{}) Error {
}
case reflect.Map:
emptyStruct := reflect.ValueOf(struct{}{})
theMap := value.(map[interface{}]interface{})
theMap, ok := value.(map[interface{}]interface{})
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}
if theMap != nil {
newMap := reflect.MakeMap(f.Type())
var newKey, newVal reflect.Value
for key, elem := range theMap {
fKeyType := f.Type().Key()
if key != nil {
newKey = reflect.ValueOf(key)
} else {
newKey = reflect.Zero(f.Type().Key())
newKey = reflect.Zero(fKeyType)
}

if newKey.Type() != f.Type().Key() {
newKey = newKey.Convert(f.Type().Key())
if newKey.Type() != fKeyType {
if !newKey.CanConvert(fKeyType) {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid key `%#v` for %s field", value, fieldKind))
}
newKey = newKey.Convert(fKeyType)
}

fElemType := f.Type().Elem()
if elem != nil {
newVal = reflect.ValueOf(elem)
} else {
newVal = reflect.Zero(f.Type().Elem())
newVal = reflect.Zero(fElemType)
}

if newVal.Type() != f.Type().Elem() {
if newVal.Type() != fElemType {
switch newVal.Kind() {
case reflect.Map, reflect.Slice, reflect.Array:
newVal = reflect.New(f.Type().Elem())
newVal = reflect.New(fElemType)
if err := setValue(newVal.Elem(), elem); err != nil {
return err
}
newVal = reflect.Indirect(newVal)
default:
newVal = newVal.Convert(f.Type().Elem())
if !newVal.CanConvert(fElemType) {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}
newVal = newVal.Convert(fElemType)
}
}

Expand All @@ -347,12 +393,19 @@ func setValue(f reflect.Value, value interface{}) Error {
case reflect.Struct:
// support time.Time
if f.Type().PkgPath() == "time" && f.Type().Name() == "Time" {
f.Set(reflect.ValueOf(time.Unix(0, int64(value.(int)))))
v, ok := value.(int)
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for time %s field", value, fieldKind))
}
f.Set(reflect.ValueOf(time.Unix(0, int64(v))))
break
}

valMap := value.(map[interface{}]interface{})
// iteraste over struct fields and recursively fill them up
valMap, ok := value.(map[interface{}]interface{})
if !ok {
return newError(types.PARSE_ERROR, fmt.Sprintf("Invalid value `%#v` for %s field", value, fieldKind))
}
// iterate over struct fields and recursively fill them up
if err := setStructValue(f, valMap, f.Type(), nil); err != nil {
return err
}
Expand Down
94 changes: 94 additions & 0 deletions read_command_reflect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright 2014-2022 Aerospike, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package aerospike

import (
"reflect"

gg "github.com/onsi/ginkgo/v2"
gm "github.com/onsi/gomega"
)

var _ = gg.Describe("Read Command Reflect setValue", func() {
type testStruct struct {
Int int
Bool bool
Int64 int64
String string
Float64 float64
SliceString []string
SliceInt []int
SliceFloat64 []float64
MapStringFloat64 map[string]float64
MapStringString map[string]string
MapInterfaceInterface map[interface{}]interface{}
}

ts := &testStruct{}
reflectField := func(name string) reflect.Value {
return reflect.Indirect(reflect.ValueOf(ts)).FieldByName(name)
}

tests := []struct {
name string
field reflect.Value
obj interface{}
error bool
}{
{name: "int->int", field: reflectField("Int"), obj: 5},
{name: "int->int64", field: reflectField("Int64"), obj: int64(5)},
{name: "int->bool", field: reflectField("Bool"), obj: true},
{name: "int->float64", field: reflectField("Float64"), obj: 5},
{name: "[]string->[]string", field: reflectField("SliceString"), obj: []string{"1", "2"}},
{name: "[]int->[]int", field: reflectField("SliceInt"), obj: []int{1, 2}},
{name: "map[string]string->map[string]string", field: reflectField("MapStringString"), obj: map[interface{}]interface{}{"1": "2"}},
{name: "map[string]float64->map[string]float64", field: reflectField("MapStringFloat64"), obj: map[interface{}]interface{}{"1": 2}},
{name: "map[interface{}]interface{}->map[interface{}]interface{}", field: reflectField("MapInterfaceInterface"), obj: map[interface{}]interface{}{"1": 2}},

{name: "string->bool", field: reflectField("Bool"), obj: "true", error: true},
{name: "int->string", field: reflectField("String"), obj: 5, error: true},
{name: "bool->int", field: reflectField("Int"), obj: true, error: true},
{name: "bool->string", field: reflectField("String"), obj: true, error: true},
{name: "int->[]string", field: reflectField("SliceString"), obj: 5, error: true},
{name: "int->[]int", field: reflectField("SliceInt"), obj: 5, error: true},
{name: "int->[]float64", field: reflectField("SliceFloat64"), obj: 5, error: true},
{name: "[]string->int", field: reflectField("Int"), obj: []string{"1", "2"}, error: true},
{name: "[]string->int64", field: reflectField("Int64"), obj: []string{"1", "2"}, error: true},
{name: "[]string->float64", field: reflectField("Float64"), obj: []string{"1", "2"}, error: true},
{name: "[]int->int", field: reflectField("Int"), obj: []int{1, 2}, error: true},
{name: "[]int->int64", field: reflectField("Int64"), obj: []int{1, 2}, error: true},
{name: "[]string->[]int", field: reflectField("SliceInt"), obj: []string{"1", "2"}, error: true},
{name: "map[string]string->[]int", field: reflectField("SliceInt"), obj: map[interface{}]interface{}{"1": "2"}, error: true},
{name: "[]int->map[string]string", field: reflectField("MapStringString"), obj: []int{1, 2}, error: true},
{name: "map[string]string->map[string]float64", field: reflectField("MapStringFloat64"), obj: map[interface{}]interface{}{"1": "2"}, error: true},
}

for _, tt := range tests {
tc := tt
gg.Context(tc.name, func() {
gg.It("Should return correct error", func() {
gm.Expect(func() {
err := setValue(tc.field, tc.obj)
if tc.error {
gm.Expect(err).To(gm.HaveOccurred())
} else {
gm.Expect(err).ToNot(gm.HaveOccurred())
}
}).To(gm.Not(gm.Panic()))

})
})
}
})

0 comments on commit e7e19d6

Please sign in to comment.