Skip to content

Commit

Permalink
feat(pgdialect): add identity support
Browse files Browse the repository at this point in the history
  • Loading branch information
oGi4i committed Jul 5, 2022
1 parent 57bd1fa commit 646251e
Show file tree
Hide file tree
Showing 14 changed files with 102 additions and 3 deletions.
1 change: 1 addition & 0 deletions dialect/feature/feature.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ const (
SelectExists
UpdateFromTable
MSSavepoint
GeneratedIdentity
)
5 changes: 3 additions & 2 deletions dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ func New() *Dialect {
feature.TableTruncate |
feature.TableNotExists |
feature.InsertOnConflict |
feature.SelectExists
feature.SelectExists |
feature.GeneratedIdentity
return d
}

Expand All @@ -73,7 +74,7 @@ func (d *Dialect) OnTable(table *schema.Table) {
func (d *Dialect) onField(field *schema.Field) {
field.DiscoveredSQLType = fieldSQLType(field)

if field.AutoIncrement {
if field.AutoIncrement && !field.Identity {
switch field.DiscoveredSQLType {
case sqltype.SmallInt:
field.CreateTableSQLType = pgTypeSmallSerial
Expand Down
74 changes: 74 additions & 0 deletions internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,43 @@ func TestPostgresInsertNoRows(t *testing.T) {
}
}

func TestPostgresInsertNoRowsIdentity(t *testing.T) {
type User struct {
ID int64 `bun:",pk,identity"`
}

db := pg(t)

err := db.ResetModel(ctx, (*User)(nil))
require.NoError(t, err)

{
res, err := db.NewInsert().
Model(&User{ID: 1}).
On("CONFLICT DO NOTHING").
Returning("*").
Exec(ctx)
require.NoError(t, err)

n, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), n)
}

{
res, err := db.NewInsert().
Model(&User{ID: 1}).
On("CONFLICT DO NOTHING").
Returning("*").
Exec(ctx)
require.NoError(t, err)

n, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(0), n)
}
}

func TestPostgresScanonlyField(t *testing.T) {
type Model struct {
Array []string `bun:",scanonly,array"`
Expand Down Expand Up @@ -463,6 +500,43 @@ func TestPostgresOnConflictDoUpdate(t *testing.T) {
}
}

func TestPostgresOnConflictDoUpdateIdentity(t *testing.T) {
type Model struct {
ID int64 `bun:",pk,identity"`
UpdatedAt time.Time
}

ctx := context.Background()

db := pg(t)
defer db.Close()

err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

model := &Model{ID: 1}

_, err = db.NewInsert().
Model(model).
On("CONFLICT (id) DO UPDATE").
Set("updated_at = now()").
Returning("id, updated_at").
Exec(ctx)
require.NoError(t, err)
require.Zero(t, model.UpdatedAt)

for i := 0; i < 2; i++ {
_, err = db.NewInsert().
Model(model).
On("CONFLICT (id) DO UPDATE").
Set("updated_at = now()").
Returning("id, updated_at").
Exec(ctx)
require.NoError(t, err)
require.NotZero(t, model.UpdatedAt)
}
}

func TestPostgresCopyFromCopyTo(t *testing.T) {
ctx := context.Background()

Expand Down
6 changes: 6 additions & 0 deletions internal/dbtest/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,12 @@ func TestQuery(t *testing.T) {
Str: "hello",
}).UseIndex("ix1", "ix2").UseIndex("ix3").Where("id = 3")
},
func(db *bun.DB) schema.QueryAppender {
type User struct {
ID int64 `bun:",pk,autoincrement,identity"`
}
return db.NewCreateTable().Model(new(User))
},
}

timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`)
Expand Down
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mariadb-151
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `users` (`id` BIGINT NOT NULL AUTO_INCREMENT, PRIMARY KEY (`id`))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mssql2019-151
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "users" ("id" BIGINT NOT NULL IDENTITY, PRIMARY KEY ("id"))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql5-151
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `users` (`id` BIGINT NOT NULL AUTO_INCREMENT, PRIMARY KEY (`id`))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-mysql8-151
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE `users` (`id` BIGINT NOT NULL AUTO_INCREMENT, PRIMARY KEY (`id`))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pg-151
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "users" ("id" BIGINT NOT NULL GENERATED BY DEFAULT AS IDENTITY, PRIMARY KEY ("id"))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-pgx-151
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "users" ("id" BIGINT NOT NULL GENERATED BY DEFAULT AS IDENTITY, PRIMARY KEY ("id"))
1 change: 1 addition & 0 deletions internal/dbtest/testdata/snapshots/TestQuery-sqlite-151
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE "users" ("id" INTEGER NOT NULL, PRIMARY KEY ("id"))
5 changes: 5 additions & 0 deletions query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
b = append(b, " IDENTITY"...)
}
}
if field.Identity {
if fmter.Dialect().Features().Has(feature.GeneratedIdentity) {
b = append(b, " GENERATED BY DEFAULT AS IDENTITY"...)
}
}
if field.SQLDefault != "" {
b = append(b, " DEFAULT "...)
b = append(b, field.SQLDefault...)
Expand Down
1 change: 1 addition & 0 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Field struct {
NotNull bool
NullZero bool
AutoIncrement bool
Identity bool

Append AppenderFunc
Scan ScannerFunc
Expand Down
6 changes: 5 additions & 1 deletion schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie
field.AutoIncrement = true
field.NullZero = true
}
if tag.HasOption("identity") {
field.Identity = true
}

if v, ok := tag.Options["unique"]; ok {
var names []string
Expand Down Expand Up @@ -911,7 +914,8 @@ func isKnownFieldOption(name string) bool {
"on_update",
"on_delete",
"m2m",
"polymorphic":
"polymorphic",
"identity":
return true
}
return false
Expand Down

0 comments on commit 646251e

Please sign in to comment.