diff --git a/columns/columns.go b/columns/columns.go index 8ee8f15c..4ea26564 100644 --- a/columns/columns.go +++ b/columns/columns.go @@ -13,6 +13,7 @@ type Columns struct { lock *sync.RWMutex TableName string TableAlias string + IDField string } // Add a column to the list. @@ -74,7 +75,7 @@ func (c *Columns) Add(names ...string) []*Column { } else if xs[1] == "w" { col.Readable = false } - } else if col.Name == "id" { + } else if col.Name == c.IDField { col.Writeable = false } @@ -98,7 +99,7 @@ func (c *Columns) Remove(names ...string) { // Writeable gets a list of the writeable columns from the column list. func (c Columns) Writeable() *WriteableColumns { - w := &WriteableColumns{NewColumnsWithAlias(c.TableName, c.TableAlias)} + w := &WriteableColumns{NewColumnsWithAlias(c.TableName, c.TableAlias, c.IDField)} for _, col := range c.Cols { if col.Writeable { w.Cols[col.Name] = col @@ -109,7 +110,7 @@ func (c Columns) Writeable() *WriteableColumns { // Readable gets a list of the readable columns from the column list. func (c Columns) Readable() *ReadableColumns { - w := &ReadableColumns{NewColumnsWithAlias(c.TableName, c.TableAlias)} + w := &ReadableColumns{NewColumnsWithAlias(c.TableName, c.TableAlias, c.IDField)} for _, col := range c.Cols { if col.Readable { w.Cols[col.Name] = col @@ -157,17 +158,18 @@ func (c Columns) SymbolizedString() string { } // NewColumns constructs a list of columns for a given table name. -func NewColumns(tableName string) Columns { - return NewColumnsWithAlias(tableName, "") +func NewColumns(tableName, idField string) Columns { + return NewColumnsWithAlias(tableName, "", idField) } // NewColumnsWithAlias constructs a list of columns for a given table // name, using a given alias for the table. -func NewColumnsWithAlias(tableName string, tableAlias string) Columns { +func NewColumnsWithAlias(tableName, tableAlias, idField string) Columns { return Columns{ lock: &sync.RWMutex{}, Cols: map[string]*Column{}, TableName: tableName, TableAlias: tableAlias, + IDField: idField, } } diff --git a/columns/columns_for_struct.go b/columns/columns_for_struct.go index 22cdbebc..a20cd426 100644 --- a/columns/columns_for_struct.go +++ b/columns/columns_for_struct.go @@ -6,17 +6,17 @@ import ( // ForStruct returns a Columns instance for // the struct passed in. -func ForStruct(s interface{}, tableName string) (columns Columns) { - return ForStructWithAlias(s, tableName, "") +func ForStruct(s interface{}, tableName, idField string) (columns Columns) { + return ForStructWithAlias(s, tableName, "", idField) } // ForStructWithAlias returns a Columns instance for the struct passed in. // If the tableAlias is not empty, it will be used. -func ForStructWithAlias(s interface{}, tableName string, tableAlias string) (columns Columns) { - columns = NewColumnsWithAlias(tableName, tableAlias) +func ForStructWithAlias(s interface{}, tableName, tableAlias, idField string) (columns Columns) { + columns = NewColumnsWithAlias(tableName, tableAlias, idField) defer func() { if r := recover(); r != nil { - columns = NewColumnsWithAlias(tableName, tableAlias) + columns = NewColumnsWithAlias(tableName, tableAlias, idField) columns.Add("*") } }() diff --git a/columns/columns_test.go b/columns/columns_test.go index caa0716a..f4699dc4 100644 --- a/columns/columns_test.go +++ b/columns/columns_test.go @@ -21,8 +21,8 @@ type foos []foo func Test_Column_MapsSlice(t *testing.T) { r := require.New(t) - c1 := columns.ForStruct(&foo{}, "foo") - c2 := columns.ForStruct(&foos{}, "foo") + c1 := columns.ForStruct(&foo{}, "foo", "id") + c2 := columns.ForStruct(&foos{}, "foo", "id") r.Equal(c1.String(), c2.String()) } @@ -30,7 +30,7 @@ func Test_Columns_Basics(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") r.Equal(len(c.Cols), 4) r.Equal(c.Cols["first_name"], &columns.Column{Name: "first_name", Writeable: false, Readable: true, SelectSQL: "first_name as f"}) r.Equal(c.Cols["LastName"], &columns.Column{Name: "LastName", Writeable: true, Readable: true, SelectSQL: "foo.LastName"}) @@ -43,7 +43,7 @@ func Test_Columns_Add(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") r.Equal(len(c.Cols), 4) c.Add("foo", "first_name") r.Equal(len(c.Cols), 5) @@ -55,7 +55,7 @@ func Test_Columns_Remove(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") r.Equal(len(c.Cols), 4) c.Remove("foo", "first_name") r.Equal(len(c.Cols), 3) @@ -75,9 +75,43 @@ func (fooQuoter) Quote(key string) string { func Test_Columns_Sorted(t *testing.T) { r := require.New(t) - c := columns.ForStruct(fooWithSuffix{}, "fooWithSuffix") + c := columns.ForStruct(fooWithSuffix{}, "fooWithSuffix", "id") r.Equal(len(c.Cols), 2) r.Equal(c.SymbolizedString(), ":amount, :amount_units") r.Equal(c.String(), "amount, amount_units") r.Equal(c.QuotedString(fooQuoter{}), "`amount`, `amount_units`") } + +func Test_Columns_IDField(t *testing.T) { + type withID struct { + ID string `db:"id"` + } + + r := require.New(t) + c := columns.ForStruct(withID{}, "with_id", "id") + r.Equal(1, len(c.Cols), "%+v", c) + r.Equal(&columns.Column{Name: "id", Writeable: false, Readable: true, SelectSQL: "with_id.id"}, c.Cols["id"]) +} + +func Test_Columns_IDField_Readonly(t *testing.T) { + type withIDReadonly struct { + ID string `db:"id" rw:"r"` + } + + r := require.New(t) + c := columns.ForStruct(withIDReadonly{}, "with_id_readonly", "id") + r.Equal(1, len(c.Cols), "%+v", c) + r.Equal(&columns.Column{Name: "id", Writeable: false, Readable: true, SelectSQL: "with_id_readonly.id"}, c.Cols["id"]) +} + +func Test_Columns_ID_Field_Not_ID(t *testing.T) { + type withNonStandardID struct { + PK string `db:"notid"` + } + + r := require.New(t) + + c := columns.ForStruct(withNonStandardID{}, "non_standard_id", "notid") + r.Equal(1, len(c.Cols), "%+v", c) + r.Equal(&columns.Column{Name: "notid", Writeable: false, Readable: true, SelectSQL: "non_standard_id.notid"}, c.Cols["notid"]) +} diff --git a/columns/readable_columns_test.go b/columns/readable_columns_test.go index 8394b967..a563d789 100644 --- a/columns/readable_columns_test.go +++ b/columns/readable_columns_test.go @@ -10,7 +10,7 @@ import ( func Test_Columns_ReadableString(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") u := c.Readable().String() r.Equal(u, "LastName, first_name, read") } @@ -19,7 +19,7 @@ func Test_Columns_ReadableString(t *testing.T) { func Test_Columns_Readable_SelectString(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") u := c.Readable().SelectString() r.Equal(u, "first_name as f, foo.LastName, foo.read") } @@ -28,7 +28,7 @@ func Test_Columns_Readable_SelectString(t *testing.T) { func Test_Columns_ReadableString_Symbolized(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") u := c.Readable().SymbolizedString() r.Equal(u, ":LastName, :first_name, :read") } diff --git a/columns/writeable_columns_test.go b/columns/writeable_columns_test.go index 269735f3..053dbdaf 100644 --- a/columns/writeable_columns_test.go +++ b/columns/writeable_columns_test.go @@ -10,7 +10,7 @@ import ( func Test_Columns_WriteableString_Symbolized(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") u := c.Writeable().SymbolizedString() r.Equal(u, ":LastName, :write") } @@ -19,7 +19,7 @@ func Test_Columns_WriteableString_Symbolized(t *testing.T) { func Test_Columns_UpdateString(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") u := c.Writeable().UpdateString() r.Equal(u, "LastName = :LastName, write = :write") } @@ -35,7 +35,7 @@ func Test_Columns_QuotedUpdateString(t *testing.T) { r := require.New(t) q := testQuoter{} for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") u := c.Writeable().QuotedUpdateString(q) r.Equal(u, "\"LastName\" = :LastName, \"write\" = :write") } @@ -44,7 +44,7 @@ func Test_Columns_QuotedUpdateString(t *testing.T) { func Test_Columns_WriteableString(t *testing.T) { r := require.New(t) for _, f := range []interface{}{foo{}, &foo{}} { - c := columns.ForStruct(f, "foo") + c := columns.ForStruct(f, "foo", "id") u := c.Writeable().String() r.Equal(u, "LastName, write") } diff --git a/executors.go b/executors.go index 9704c919..b3560141 100644 --- a/executors.go +++ b/executors.go @@ -228,7 +228,7 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { } tn := m.TableName() - cols := columns.ForStructWithAlias(m.Value, tn, m.As) + cols := m.Columns() if tn == sm.TableName() { cols.Remove(excludeColumns...) @@ -350,8 +350,8 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error { } tn := m.TableName() - cols := columns.ForStructWithAlias(model, tn, m.As) - cols.Remove("id", "created_at") + cols := columns.ForStructWithAlias(model, tn, m.As, m.IDField()) + cols.Remove(m.IDField(), "created_at") if tn == sm.TableName() { cols.Remove(excludeColumns...) @@ -393,11 +393,11 @@ func (c *Connection) UpdateColumns(model interface{}, columnNames ...string) err cols := columns.Columns{} if len(columnNames) > 0 && tn == sm.TableName() { - cols = columns.NewColumnsWithAlias(tn, m.As) + cols = columns.NewColumnsWithAlias(tn, m.As, sm.IDField()) cols.Add(columnNames...) } else { - cols = columns.ForStructWithAlias(model, tn, m.As) + cols = columns.ForStructWithAlias(model, tn, m.As, m.IDField()) } cols.Remove("id", "created_at") diff --git a/executors_test.go b/executors_test.go index 4fb4bb74..bee9a5e8 100644 --- a/executors_test.go +++ b/executors_test.go @@ -510,6 +510,28 @@ func Test_Create_With_Non_ID_PK_String(t *testing.T) { }) } +func Test_Create_Non_PK_ID(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + r.NoError(tx.Create(&NonStandardID{OutfacingID: "make sure the tested entry does not have pk=0"})) + + count, err := tx.Count(&NonStandardID{}) + entry := &NonStandardID{ + OutfacingID: "beautiful to the outside ID", + } + r.NoError(tx.Create(entry)) + + ctx, err := tx.Count(&NonStandardID{}) + r.NoError(err) + r.Equal(count+1, ctx) + r.NotZero(entry.ID) + }) +} + func Test_Eager_Create_Has_Many(t *testing.T) { if PDB == nil { t.Skip("skipping integration tests") @@ -1470,6 +1492,54 @@ func Test_Update_UUID(t *testing.T) { }) } +func Test_Update_With_Non_ID_PK(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + r.NoError(tx.Create(&CrookedColour{Name: "cc is not the first one"})) + + cc := CrookedColour{ + Name: "You?", + } + err := tx.Create(&cc) + r.NoError(err) + r.NotZero(cc.ID) + id := cc.ID + + updatedName := "Me!" + cc.Name = updatedName + r.NoError(tx.Update(&cc)) + r.Equal(id, cc.ID) + + r.NoError(tx.Reload(&cc)) + r.Equal(updatedName, cc.Name) + r.Equal(id, cc.ID) + }) +} + +func Test_Update_Non_PK_ID(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + client := &NonStandardID{ + OutfacingID: "my awesome hydra client", + } + r.NoError(tx.Create(client)) + + updatedID := "your awesome hydra client" + client.OutfacingID = updatedID + r.NoError(tx.Update(client)) + r.NoError(tx.Reload(client)) + r.Equal(updatedID, client.OutfacingID) + }) +} + func Test_Destroy(t *testing.T) { if PDB == nil { t.Skip("skipping integration tests") diff --git a/model.go b/model.go index 62efd25b..58019f73 100644 --- a/model.go +++ b/model.go @@ -2,6 +2,7 @@ package pop import ( "fmt" + "github.com/gobuffalo/pop/v5/columns" "github.com/pkg/errors" "reflect" "sync" @@ -46,7 +47,18 @@ func (m *Model) ID() interface{} { // IDField returns the name of the DB field used for the ID. // By default, it will return "id". func (m *Model) IDField() string { - field, ok := reflect.TypeOf(m.Value).Elem().FieldByName("ID") + modelType := reflect.TypeOf(m.Value) + + // remove all indirections + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Array { + modelType = modelType.Elem() + } + + if modelType.Kind() == reflect.String { + return "id" + } + + field, ok := modelType.FieldByName("ID") if !ok { return "id" } @@ -101,6 +113,10 @@ func (m *Model) TableName() string { return tableMap[cacheKey] } +func (m *Model) Columns() columns.Columns { + return columns.ForStructWithAlias(m.Value, m.TableName(), m.As, m.IDField()) +} + func (m *Model) cacheKey(t reflect.Type) string { return t.PkgPath() + "." + t.Name() } diff --git a/pop_test.go b/pop_test.go index 33efcbe2..4ac25b68 100644 --- a/pop_test.go +++ b/pop_test.go @@ -419,3 +419,8 @@ type CrookedSong struct { CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` } + +type NonStandardID struct { + ID int `db:"pk"` + OutfacingID string `db:"id"` +} diff --git a/sql_builder.go b/sql_builder.go index 4ce35703..edd05c46 100644 --- a/sql_builder.go +++ b/sql_builder.go @@ -229,7 +229,7 @@ func (sq *sqlBuilder) buildColumns() columns.Columns { if ok && cols.TableAlias == asName { return cols } - cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName) + cols = columns.ForStructWithAlias(sq.Model.Value, tableName, asName, sq.Model.IDField()) columnCacheMutex.Lock() columnCache[tableName] = cols columnCacheMutex.Unlock() @@ -237,7 +237,7 @@ func (sq *sqlBuilder) buildColumns() columns.Columns { } // acl > 0 - cols := columns.NewColumns("") + cols := columns.NewColumns("", sq.Model.IDField()) cols.Add(sq.AddColumns...) return cols } diff --git a/testdata/migrations/20201028153041_non_standard_id.down.fizz b/testdata/migrations/20201028153041_non_standard_id.down.fizz new file mode 100644 index 00000000..5c56284f --- /dev/null +++ b/testdata/migrations/20201028153041_non_standard_id.down.fizz @@ -0,0 +1 @@ +drop_table("non_standard_ids") diff --git a/testdata/migrations/20201028153041_non_standard_id.up.fizz b/testdata/migrations/20201028153041_non_standard_id.up.fizz new file mode 100644 index 00000000..7590e4b4 --- /dev/null +++ b/testdata/migrations/20201028153041_non_standard_id.up.fizz @@ -0,0 +1,6 @@ +create_table("non_standard_ids") { + t.Column("pk", "int", { primary: true }) + t.Column("id", "string", {}) + + t.DisableTimestamps() +}