diff --git a/columns.go b/columns.go index b02cfc7..2e97b17 100644 --- a/columns.go +++ b/columns.go @@ -74,18 +74,16 @@ func columns(v interface{}, strict bool, excluded ...string) ([]string, error) { return res, nil } + names := columnNames(model, strict, excluded...) + toCache := append(names, excluded...) + columnsCache.Store(model, toCache) + return names, nil +} + +func columnNames(model reflect.Value, strict bool, excluded ...string) []string { numfield := model.NumField() names := make([]string, 0, numfield) - isExcluded := func(name string) bool { - for _, ex := range excluded { - if ex == name { - return true - } - } - return false - } - for i := 0; i < numfield; i++ { valField := model.Field(i) if !valField.IsValid() || !valField.CanSet() { @@ -93,27 +91,43 @@ func columns(v interface{}, strict bool, excluded ...string) ([]string, error) { } typeField := model.Type().Field(i) - if tag, ok := typeField.Tag.Lookup(dbTag); ok { - if tag != "-" && !isExcluded(tag) { - names = append(names, tag) - } + + if typeField.Type.Kind() == reflect.Struct { + embeddedNames := columnNames(valField, strict, excluded...) + names = append(names, embeddedNames...) continue } - if strict { + fieldName := typeField.Name + if tag, hasTag := typeField.Tag.Lookup(dbTag); hasTag { + if tag == "-" { + continue + } + fieldName = tag + } else if strict { + // there's no tag name and we're in strict mode so move on continue } - if isExcluded(typeField.Name) || !supportedColumnType(valField.Kind()) { + if isExcluded(fieldName, excluded...) { continue } - names = append(names, typeField.Name) + if supportedColumnType(valField.Kind()) { + names = append(names, fieldName) + } } - toCache := append(names, excluded...) - columnsCache.Store(model, toCache) - return names, nil + return names +} + +func isExcluded(name string, excluded ...string) bool { + for _, ex := range excluded { + if ex == name { + return true + } + } + return false } func reflectValue(v interface{}) (reflect.Value, error) { diff --git a/columns_test.go b/columns_test.go index bbeb900..6362534 100644 --- a/columns_test.go +++ b/columns_test.go @@ -63,16 +63,76 @@ func TestColumnsIgnoresPrivateFields(t *testing.T) { assert.EqualValues(t, []string{"Age"}, cols) } -func TestColumnsAddsComplexTypesWhenStructTag(t *testing.T) { +func TestColumnsAddsComplexTypesWhenNoStructTag(t *testing.T) { type person struct { Address struct { Street string + } + } + + cols, err := Columns(&person{}) + assert.NoError(t, err) + assert.EqualValues(t, []string{"Street"}, cols) +} + +func TestColumnsAddsComplexTypesWhenStructTag(t *testing.T) { + type person struct { + Address struct { + Street string `db:"address.street"` + } + } + + cols, err := Columns(&person{}) + assert.NoError(t, err) + assert.EqualValues(t, []string{"address.street"}, cols) +} + +func TestColumnsDoesNotAddStructTag(t *testing.T) { + type person struct { + Address struct { + Street string `db:"address.street"` } `db:"address"` } cols, err := Columns(&person{}) assert.NoError(t, err) - assert.EqualValues(t, []string{"address"}, cols) + assert.EqualValues(t, []string{"address.street"}, cols) +} + +func TestColumnsStrictAddsComplexTypesWhenStructTag(t *testing.T) { + type person struct { + Address struct { + Street string `db:"address.street"` + } + } + + cols, err := ColumnsStrict(&person{}) + assert.NoError(t, err) + assert.EqualValues(t, []string{"address.street"}, cols) +} + +func TestColumnsStrictDoesNotAddComplexTypesWhenNoStructTag(t *testing.T) { + type person struct { + Address struct { + Street string + } + } + + cols, err := ColumnsStrict(&person{}) + assert.NoError(t, err) + assert.EqualValues(t, []string{}, cols) +} + +func TestColumnsStrictAddsComplexTypesRegardlessOfStructTag(t *testing.T) { + type person struct { + Address struct { + Street string `db:"address.street"` + } `db:"-"` + } + + cols, err := ColumnsStrict(&person{}) + assert.NoError(t, err) + assert.EqualValues(t, []string{"address.street"}, cols) } func TestColumnsIgnoresComplexTypesWhenNoStructTag(t *testing.T) { @@ -84,7 +144,7 @@ func TestColumnsIgnoresComplexTypesWhenNoStructTag(t *testing.T) { cols, err := Columns(&person{}) assert.NoError(t, err) - assert.EqualValues(t, []string{}, cols) + assert.EqualValues(t, []string{"Street"}, cols) } func TestColumnsExcludesFields(t *testing.T) { diff --git a/examples_values_columns_test.go b/examples_values_columns_test.go index 282107d..c66324f 100644 --- a/examples_values_columns_test.go +++ b/examples_values_columns_test.go @@ -22,6 +22,31 @@ func ExampleValues() { // [1 Brett] } +func ExampleValues_nested() { + type Address struct { + Street string + City string + } + + person := struct { + ID int + Name string + Address + }{ + Name: "Brett", + ID: 1, + Address: Address{ + City: "San Francisco", + }, + } + + cols := []string{"Name", "City"} + vals, _ := scan.Values(cols, &person) + fmt.Printf("%+v", vals) + // Output: + // [Brett San Francisco] +} + func ExampleColumns() { var person struct { ID int `db:"person_id"` @@ -59,3 +84,51 @@ func ExampleColumnsStrict() { // Output: // [id age] } + +func ExampleColumnsNested() { + var person struct { + ID int `db:"person.id"` + Name string `db:"person.name"` + Company struct { + ID int `db:"company.id"` + Name string + } + } + + cols, _ := scan.Columns(&person) + fmt.Printf("%+v", cols) + // Output: + // [person.id person.name company.id Name] +} + +func ExampleColumnsNestedStrict() { + var person struct { + ID int `db:"person.id"` + Name string `db:"person.name"` + Company struct { + ID int `db:"company.id"` + Name string + } + } + + cols, _ := scan.ColumnsStrict(&person) + fmt.Printf("%+v", cols) + // Output: + // [person.id person.name company.id] +} + +func ExampleColumnsNested_exclude() { + var person struct { + ID int `db:"person.id"` + Name string `db:"person.name"` + Company struct { + ID int `db:"-"` + Name string `db:"company.name"` + } + } + + cols, _ := scan.Columns(&person) + fmt.Printf("%+v", cols) + // Output: + // [person.id person.name company.name] +} diff --git a/values.go b/values.go index 328ce57..0bed199 100644 --- a/values.go +++ b/values.go @@ -26,22 +26,28 @@ func Values(cols []string, v interface{}) ([]interface{}, error) { return nil, fmt.Errorf("field %T.%q either does not exist or is unexported: %w", v, col, ErrStructFieldMissing) } - vals[i] = model.Field(j).Interface() + vals[i] = model.FieldByIndex(j).Interface() } return vals, nil } -func loadFields(val reflect.Value) map[string]int { +func loadFields(val reflect.Value) map[string][]int { if cache, cached := valuesCache.Load(val); cached { - return cache.(map[string]int) + return cache.(map[string][]int) } return writeFieldsCache(val) } -func writeFieldsCache(val reflect.Value) map[string]int { +func writeFieldsCache(val reflect.Value) map[string][]int { + m := map[string][]int{} + writeFields(val, m, []int{}) + valuesCache.Store(val, m) + return m +} + +func writeFields(val reflect.Value, m map[string][]int, index []int) { typ := val.Type() numfield := val.NumField() - m := map[string]int{} for i := 0; i < numfield; i++ { if !val.Field(i).CanSet() { @@ -49,11 +55,16 @@ func writeFieldsCache(val reflect.Value) map[string]int { } field := typ.Field(i) - m[field.Name] = i + fieldIndex := append(index, field.Index...) + + if field.Type.Kind() == reflect.Struct { + writeFields(val.Field(i), m, fieldIndex) + continue + } + + m[field.Name] = fieldIndex if tag, ok := field.Tag.Lookup(dbTag); ok { - m[tag] = i + m[tag] = fieldIndex } } - valuesCache.Store(val, m) - return m } diff --git a/values_test.go b/values_test.go index d267962..9ef3b44 100644 --- a/values_test.go +++ b/values_test.go @@ -33,6 +33,26 @@ func TestValuesScansDBTags(t *testing.T) { assert.EqualValues(t, []interface{}{"Brett"}, vals) } +func TestValuesScansNestedFields(t *testing.T) { + type Address struct { + Street string + City string + } + + type Person struct { + Name string + Age int + Address + } + + p := &Person{Name: "Brett", Address: Address{Street: "123 Main St", City: "San Francisco"}} + + vals, err := Values([]string{"Name", "Street", "City"}, p) + require.NoError(t, err) + + assert.EqualValues(t, []interface{}{"Brett", "123 Main St", "San Francisco"}, vals) +} + func TestValuesReturnsErrorWhenPassingNonPointer(t *testing.T) { _, err := Values([]string{"Name"}, "") require.Error(t, err) @@ -80,7 +100,7 @@ func TestValuesReadsFromCacheFirst(t *testing.T) { } v := reflect.Indirect(reflect.ValueOf(&person)) - valuesCache.Store(v, map[string]int{"Name": 0}) + valuesCache.Store(v, map[string][]int{"Name": {0}}) vals, err := Values([]string{"Name"}, &person) require.NoError(t, err)