Skip to content

Commit

Permalink
Implement nested structs for columns and values (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
ricci2511 authored Jul 15, 2023
1 parent da49f61 commit 0eac140
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 32 deletions.
52 changes: 33 additions & 19 deletions columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,46 +74,60 @@ func columns(v interface{}, strict bool, excluded ...string) ([]string, error) {
return res, nil
}

names := columnNames(model, strict, excluded...)
toCache := append(names, excluded...)
columnsCache.Store(model, toCache)
return names, nil
}

func columnNames(model reflect.Value, strict bool, excluded ...string) []string {
numfield := model.NumField()
names := make([]string, 0, numfield)

isExcluded := func(name string) bool {
for _, ex := range excluded {
if ex == name {
return true
}
}
return false
}

for i := 0; i < numfield; i++ {
valField := model.Field(i)
if !valField.IsValid() || !valField.CanSet() {
continue
}

typeField := model.Type().Field(i)
if tag, ok := typeField.Tag.Lookup(dbTag); ok {
if tag != "-" && !isExcluded(tag) {
names = append(names, tag)
}

if typeField.Type.Kind() == reflect.Struct {
embeddedNames := columnNames(valField, strict, excluded...)
names = append(names, embeddedNames...)
continue
}

if strict {
fieldName := typeField.Name
if tag, hasTag := typeField.Tag.Lookup(dbTag); hasTag {
if tag == "-" {
continue
}
fieldName = tag
} else if strict {
// there's no tag name and we're in strict mode so move on
continue
}

if isExcluded(typeField.Name) || !supportedColumnType(valField.Kind()) {
if isExcluded(fieldName, excluded...) {
continue
}

names = append(names, typeField.Name)
if supportedColumnType(valField.Kind()) {
names = append(names, fieldName)
}
}

toCache := append(names, excluded...)
columnsCache.Store(model, toCache)
return names, nil
return names
}

func isExcluded(name string, excluded ...string) bool {
for _, ex := range excluded {
if ex == name {
return true
}
}
return false
}

func reflectValue(v interface{}) (reflect.Value, error) {
Expand Down
66 changes: 63 additions & 3 deletions columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,76 @@ func TestColumnsIgnoresPrivateFields(t *testing.T) {
assert.EqualValues(t, []string{"Age"}, cols)
}

func TestColumnsAddsComplexTypesWhenStructTag(t *testing.T) {
func TestColumnsAddsComplexTypesWhenNoStructTag(t *testing.T) {
type person struct {
Address struct {
Street string
}
}

cols, err := Columns(&person{})
assert.NoError(t, err)
assert.EqualValues(t, []string{"Street"}, cols)
}

func TestColumnsAddsComplexTypesWhenStructTag(t *testing.T) {
type person struct {
Address struct {
Street string `db:"address.street"`
}
}

cols, err := Columns(&person{})
assert.NoError(t, err)
assert.EqualValues(t, []string{"address.street"}, cols)
}

func TestColumnsDoesNotAddStructTag(t *testing.T) {
type person struct {
Address struct {
Street string `db:"address.street"`
} `db:"address"`
}

cols, err := Columns(&person{})
assert.NoError(t, err)
assert.EqualValues(t, []string{"address"}, cols)
assert.EqualValues(t, []string{"address.street"}, cols)
}

func TestColumnsStrictAddsComplexTypesWhenStructTag(t *testing.T) {
type person struct {
Address struct {
Street string `db:"address.street"`
}
}

cols, err := ColumnsStrict(&person{})
assert.NoError(t, err)
assert.EqualValues(t, []string{"address.street"}, cols)
}

func TestColumnsStrictDoesNotAddComplexTypesWhenNoStructTag(t *testing.T) {
type person struct {
Address struct {
Street string
}
}

cols, err := ColumnsStrict(&person{})
assert.NoError(t, err)
assert.EqualValues(t, []string{}, cols)
}

func TestColumnsStrictAddsComplexTypesRegardlessOfStructTag(t *testing.T) {
type person struct {
Address struct {
Street string `db:"address.street"`
} `db:"-"`
}

cols, err := ColumnsStrict(&person{})
assert.NoError(t, err)
assert.EqualValues(t, []string{"address.street"}, cols)
}

func TestColumnsIgnoresComplexTypesWhenNoStructTag(t *testing.T) {
Expand All @@ -84,7 +144,7 @@ func TestColumnsIgnoresComplexTypesWhenNoStructTag(t *testing.T) {

cols, err := Columns(&person{})
assert.NoError(t, err)
assert.EqualValues(t, []string{}, cols)
assert.EqualValues(t, []string{"Street"}, cols)
}

func TestColumnsExcludesFields(t *testing.T) {
Expand Down
73 changes: 73 additions & 0 deletions examples_values_columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,31 @@ func ExampleValues() {
// [1 Brett]
}

func ExampleValues_nested() {
type Address struct {
Street string
City string
}

person := struct {
ID int
Name string
Address
}{
Name: "Brett",
ID: 1,
Address: Address{
City: "San Francisco",
},
}

cols := []string{"Name", "City"}
vals, _ := scan.Values(cols, &person)
fmt.Printf("%+v", vals)
// Output:
// [Brett San Francisco]
}

func ExampleColumns() {
var person struct {
ID int `db:"person_id"`
Expand Down Expand Up @@ -59,3 +84,51 @@ func ExampleColumnsStrict() {
// Output:
// [id age]
}

func ExampleColumnsNested() {
var person struct {
ID int `db:"person.id"`
Name string `db:"person.name"`
Company struct {
ID int `db:"company.id"`
Name string
}
}

cols, _ := scan.Columns(&person)
fmt.Printf("%+v", cols)
// Output:
// [person.id person.name company.id Name]
}

func ExampleColumnsNestedStrict() {
var person struct {
ID int `db:"person.id"`
Name string `db:"person.name"`
Company struct {
ID int `db:"company.id"`
Name string
}
}

cols, _ := scan.ColumnsStrict(&person)
fmt.Printf("%+v", cols)
// Output:
// [person.id person.name company.id]
}

func ExampleColumnsNested_exclude() {
var person struct {
ID int `db:"person.id"`
Name string `db:"person.name"`
Company struct {
ID int `db:"-"`
Name string `db:"company.name"`
}
}

cols, _ := scan.Columns(&person)
fmt.Printf("%+v", cols)
// Output:
// [person.id person.name company.name]
}
29 changes: 20 additions & 9 deletions values.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,45 @@ func Values(cols []string, v interface{}) ([]interface{}, error) {
return nil, fmt.Errorf("field %T.%q either does not exist or is unexported: %w", v, col, ErrStructFieldMissing)
}

vals[i] = model.Field(j).Interface()
vals[i] = model.FieldByIndex(j).Interface()
}
return vals, nil
}

func loadFields(val reflect.Value) map[string]int {
func loadFields(val reflect.Value) map[string][]int {
if cache, cached := valuesCache.Load(val); cached {
return cache.(map[string]int)
return cache.(map[string][]int)
}
return writeFieldsCache(val)
}

func writeFieldsCache(val reflect.Value) map[string]int {
func writeFieldsCache(val reflect.Value) map[string][]int {
m := map[string][]int{}
writeFields(val, m, []int{})
valuesCache.Store(val, m)
return m
}

func writeFields(val reflect.Value, m map[string][]int, index []int) {
typ := val.Type()
numfield := val.NumField()
m := map[string]int{}

for i := 0; i < numfield; i++ {
if !val.Field(i).CanSet() {
continue
}

field := typ.Field(i)
m[field.Name] = i
fieldIndex := append(index, field.Index...)

if field.Type.Kind() == reflect.Struct {
writeFields(val.Field(i), m, fieldIndex)
continue
}

m[field.Name] = fieldIndex
if tag, ok := field.Tag.Lookup(dbTag); ok {
m[tag] = i
m[tag] = fieldIndex
}
}
valuesCache.Store(val, m)
return m
}
22 changes: 21 additions & 1 deletion values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,26 @@ func TestValuesScansDBTags(t *testing.T) {
assert.EqualValues(t, []interface{}{"Brett"}, vals)
}

func TestValuesScansNestedFields(t *testing.T) {
type Address struct {
Street string
City string
}

type Person struct {
Name string
Age int
Address
}

p := &Person{Name: "Brett", Address: Address{Street: "123 Main St", City: "San Francisco"}}

vals, err := Values([]string{"Name", "Street", "City"}, p)
require.NoError(t, err)

assert.EqualValues(t, []interface{}{"Brett", "123 Main St", "San Francisco"}, vals)
}

func TestValuesReturnsErrorWhenPassingNonPointer(t *testing.T) {
_, err := Values([]string{"Name"}, "")
require.Error(t, err)
Expand Down Expand Up @@ -80,7 +100,7 @@ func TestValuesReadsFromCacheFirst(t *testing.T) {
}

v := reflect.Indirect(reflect.ValueOf(&person))
valuesCache.Store(v, map[string]int{"Name": 0})
valuesCache.Store(v, map[string][]int{"Name": {0}})

vals, err := Values([]string{"Name"}, &person)
require.NoError(t, err)
Expand Down

0 comments on commit 0eac140

Please sign in to comment.