Skip to content

Commit

Permalink
Add support for named pointers to structs
Browse files Browse the repository at this point in the history
This makes it possible to use named struct pointers like this:
type Document struct {
    Title  string
    Owner  *User `db:"owner"`
    Author *User `db:"author"`
}
  • Loading branch information
Pitmairen committed Aug 12, 2015
1 parent 344b1e9 commit cc395bc
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 1 deletion.
2 changes: 1 addition & 1 deletion reflectx/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string)
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
} else if fi.Zero.Kind() == reflect.Struct {
} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {

This comment has been minimized.

Copy link
@atomgas

atomgas Dec 23, 2015

This line creates an endless loop for me, and consumes up to 3GM memory, maybe because the struct embeds itself

type A struct {
  Id int64
  RefA *A 
  ...
}

This comment has been minimized.

Copy link
@Pitmairen

Pitmairen Dec 24, 2015

Author Contributor

I'm not sure what is the best way to fix it, but adding a check like the following (on line 343) should allow a single level self reference like in your example:

if fi.Zero.Kind() != reflect.Ptr || tq.t != Deref(f.Type) || tq.fi.Parent == nil {
    queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
}
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
Expand Down
49 changes: 49 additions & 0 deletions reflectx/reflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,55 @@ func TestPtrFields(t *testing.T) {
}
}

func TestNamedPtrFields(t *testing.T) {
m := NewMapperTagFunc("db", strings.ToLower, nil)

type User struct {
Name string
}

type Asset struct {
Title string

Owner *User `db:"owner"`
}
type Post struct {
Author string

Asset1 *Asset `db:"asset1"`
Asset2 *Asset `db:"asset2"`
}

post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil
pv := reflect.ValueOf(post)

fields := m.TypeMap(reflect.TypeOf(post))
if len(fields.Index) != 9 {
t.Errorf("Expecting 9 fields")
}

v := m.FieldByName(pv, "asset1.title")
if v.Interface().(string) != post.Asset1.Title {
t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string))
}
v = m.FieldByName(pv, "asset1.owner.name")
if v.Interface().(string) != post.Asset1.Owner.Name {
t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string))
}
v = m.FieldByName(pv, "asset2.title")
if v.Interface().(string) != post.Asset2.Title {
t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string))
}
v = m.FieldByName(pv, "asset2.owner.name")
if v.Interface().(string) != post.Asset2.Owner.Name {
t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string))
}
v = m.FieldByName(pv, "author")
if v.Interface().(string) != post.Author {
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
}
}

func TestFieldMap(t *testing.T) {
type Foo struct {
A int
Expand Down
40 changes: 40 additions & 0 deletions sqlx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,46 @@ func TestJoinQuery(t *testing.T) {
})
}

func TestJoinQueryNamedPointerStructs(t *testing.T) {
type Employee struct {
Name string
Id int64
// BossId is an id into the employee table
BossId sql.NullInt64 `db:"boss_id"`
}
type Boss Employee

RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
loadDefaultFixture(db, t)

var employees []struct {
Emp1 *Employee `db:"emp1"`
Emp2 *Employee `db:"emp2"`
*Boss `db:"boss"`
}

err := db.Select(
&employees,
`SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id",
emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id",
boss.id "boss.id", boss.name "boss.name" FROM employees AS emp
JOIN employees AS boss ON emp.boss_id = boss.id
`)
if err != nil {
t.Fatal(err)
}

for _, em := range employees {
if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 {
t.Errorf("Expected non zero lengthed name.")
}
if em.Emp1.BossId.Int64 != em.Boss.Id || em.Emp2.BossId.Int64 != em.Boss.Id {
t.Errorf("Expected boss ids to match")
}
}
})
}

func TestSelectSliceMapTime(t *testing.T) {
RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
loadDefaultFixture(db, t)
Expand Down

0 comments on commit cc395bc

Please sign in to comment.