From 646251ec02a1e2ec717e907e6f128d8b51f17c6d Mon Sep 17 00:00:00 2001 From: oGi4i Date: Tue, 5 Jul 2022 11:45:33 +0300 Subject: [PATCH] feat(pgdialect): add identity support --- dialect/feature/feature.go | 1 + dialect/pgdialect/dialect.go | 5 +- internal/dbtest/pg_test.go | 74 +++++++++++++++++++ internal/dbtest/query_test.go | 6 ++ .../testdata/snapshots/TestQuery-mariadb-151 | 1 + .../snapshots/TestQuery-mssql2019-151 | 1 + .../testdata/snapshots/TestQuery-mysql5-151 | 1 + .../testdata/snapshots/TestQuery-mysql8-151 | 1 + .../testdata/snapshots/TestQuery-pg-151 | 1 + .../testdata/snapshots/TestQuery-pgx-151 | 1 + .../testdata/snapshots/TestQuery-sqlite-151 | 1 + query_table_create.go | 5 ++ schema/field.go | 1 + schema/table.go | 6 +- 14 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mariadb-151 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mssql2019-151 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql5-151 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-mysql8-151 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pg-151 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-pgx-151 create mode 100644 internal/dbtest/testdata/snapshots/TestQuery-sqlite-151 diff --git a/dialect/feature/feature.go b/dialect/feature/feature.go index a2bba2c47..956dc4985 100644 --- a/dialect/feature/feature.go +++ b/dialect/feature/feature.go @@ -30,4 +30,5 @@ const ( SelectExists UpdateFromTable MSSavepoint + GeneratedIdentity ) diff --git a/dialect/pgdialect/dialect.go b/dialect/pgdialect/dialect.go index 1b64ea753..d524f0a1a 100644 --- a/dialect/pgdialect/dialect.go +++ b/dialect/pgdialect/dialect.go @@ -46,7 +46,8 @@ func New() *Dialect { feature.TableTruncate | feature.TableNotExists | feature.InsertOnConflict | - feature.SelectExists + feature.SelectExists | + feature.GeneratedIdentity return d } @@ -73,7 +74,7 @@ func (d *Dialect) OnTable(table *schema.Table) { func (d *Dialect) onField(field *schema.Field) { field.DiscoveredSQLType = fieldSQLType(field) - if field.AutoIncrement { + if field.AutoIncrement && !field.Identity { switch field.DiscoveredSQLType { case sqltype.SmallInt: field.CreateTableSQLType = pgTypeSmallSerial diff --git a/internal/dbtest/pg_test.go b/internal/dbtest/pg_test.go index 67c196fc8..61b95a9b1 100644 --- a/internal/dbtest/pg_test.go +++ b/internal/dbtest/pg_test.go @@ -222,6 +222,43 @@ func TestPostgresInsertNoRows(t *testing.T) { } } +func TestPostgresInsertNoRowsIdentity(t *testing.T) { + type User struct { + ID int64 `bun:",pk,identity"` + } + + db := pg(t) + + err := db.ResetModel(ctx, (*User)(nil)) + require.NoError(t, err) + + { + res, err := db.NewInsert(). + Model(&User{ID: 1}). + On("CONFLICT DO NOTHING"). + Returning("*"). + Exec(ctx) + require.NoError(t, err) + + n, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), n) + } + + { + res, err := db.NewInsert(). + Model(&User{ID: 1}). + On("CONFLICT DO NOTHING"). + Returning("*"). + Exec(ctx) + require.NoError(t, err) + + n, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(0), n) + } +} + func TestPostgresScanonlyField(t *testing.T) { type Model struct { Array []string `bun:",scanonly,array"` @@ -463,6 +500,43 @@ func TestPostgresOnConflictDoUpdate(t *testing.T) { } } +func TestPostgresOnConflictDoUpdateIdentity(t *testing.T) { + type Model struct { + ID int64 `bun:",pk,identity"` + UpdatedAt time.Time + } + + ctx := context.Background() + + db := pg(t) + defer db.Close() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + model := &Model{ID: 1} + + _, err = db.NewInsert(). + Model(model). + On("CONFLICT (id) DO UPDATE"). + Set("updated_at = now()"). + Returning("id, updated_at"). + Exec(ctx) + require.NoError(t, err) + require.Zero(t, model.UpdatedAt) + + for i := 0; i < 2; i++ { + _, err = db.NewInsert(). + Model(model). + On("CONFLICT (id) DO UPDATE"). + Set("updated_at = now()"). + Returning("id, updated_at"). + Exec(ctx) + require.NoError(t, err) + require.NotZero(t, model.UpdatedAt) + } +} + func TestPostgresCopyFromCopyTo(t *testing.T) { ctx := context.Background() diff --git a/internal/dbtest/query_test.go b/internal/dbtest/query_test.go index 3daf09b5c..5966411c6 100644 --- a/internal/dbtest/query_test.go +++ b/internal/dbtest/query_test.go @@ -896,6 +896,12 @@ func TestQuery(t *testing.T) { Str: "hello", }).UseIndex("ix1", "ix2").UseIndex("ix3").Where("id = 3") }, + func(db *bun.DB) schema.QueryAppender { + type User struct { + ID int64 `bun:",pk,autoincrement,identity"` + } + return db.NewCreateTable().Model(new(User)) + }, } timeRE := regexp.MustCompile(`'2\d{3}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+\d{2}:\d{2})?'`) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mariadb-151 b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-151 new file mode 100644 index 000000000..589d23b86 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mariadb-151 @@ -0,0 +1 @@ +CREATE TABLE `users` (`id` BIGINT NOT NULL AUTO_INCREMENT, PRIMARY KEY (`id`)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-151 b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-151 new file mode 100644 index 000000000..867ce83e5 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mssql2019-151 @@ -0,0 +1 @@ +CREATE TABLE "users" ("id" BIGINT NOT NULL IDENTITY, PRIMARY KEY ("id")) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql5-151 b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-151 new file mode 100644 index 000000000..589d23b86 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql5-151 @@ -0,0 +1 @@ +CREATE TABLE `users` (`id` BIGINT NOT NULL AUTO_INCREMENT, PRIMARY KEY (`id`)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-mysql8-151 b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-151 new file mode 100644 index 000000000..589d23b86 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-mysql8-151 @@ -0,0 +1 @@ +CREATE TABLE `users` (`id` BIGINT NOT NULL AUTO_INCREMENT, PRIMARY KEY (`id`)) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pg-151 b/internal/dbtest/testdata/snapshots/TestQuery-pg-151 new file mode 100644 index 000000000..fa11c8aa3 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pg-151 @@ -0,0 +1 @@ +CREATE TABLE "users" ("id" BIGINT NOT NULL GENERATED BY DEFAULT AS IDENTITY, PRIMARY KEY ("id")) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-pgx-151 b/internal/dbtest/testdata/snapshots/TestQuery-pgx-151 new file mode 100644 index 000000000..fa11c8aa3 --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-pgx-151 @@ -0,0 +1 @@ +CREATE TABLE "users" ("id" BIGINT NOT NULL GENERATED BY DEFAULT AS IDENTITY, PRIMARY KEY ("id")) diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-151 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-151 new file mode 100644 index 000000000..2b1eb5a4b --- /dev/null +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-151 @@ -0,0 +1 @@ +CREATE TABLE "users" ("id" INTEGER NOT NULL, PRIMARY KEY ("id")) diff --git a/query_table_create.go b/query_table_create.go index daa1ccca9..c795e8a97 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -168,6 +168,11 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by b = append(b, " IDENTITY"...) } } + if field.Identity { + if fmter.Dialect().Features().Has(feature.GeneratedIdentity) { + b = append(b, " GENERATED BY DEFAULT AS IDENTITY"...) + } + } if field.SQLDefault != "" { b = append(b, " DEFAULT "...) b = append(b, field.SQLDefault...) diff --git a/schema/field.go b/schema/field.go index ade7d5f2e..283a3b992 100644 --- a/schema/field.go +++ b/schema/field.go @@ -32,6 +32,7 @@ type Field struct { NotNull bool NullZero bool AutoIncrement bool + Identity bool Append AppenderFunc Scan ScannerFunc diff --git a/schema/table.go b/schema/table.go index 82043debb..9791f8ff1 100644 --- a/schema/table.go +++ b/schema/table.go @@ -353,6 +353,9 @@ func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Fie field.AutoIncrement = true field.NullZero = true } + if tag.HasOption("identity") { + field.Identity = true + } if v, ok := tag.Options["unique"]; ok { var names []string @@ -911,7 +914,8 @@ func isKnownFieldOption(name string) bool { "on_update", "on_delete", "m2m", - "polymorphic": + "polymorphic", + "identity": return true } return false