Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use RETURNING clause for batch create #3293

Merged
merged 9 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"TableName": "\"test_models\"",
"ColumnsDecl": "\"created_at\", \"id\", \"int\", \"nid\", \"null_time_ptr\", \"string\", \"traits\", \"updated_at\"",
"Columns": [
"created_at",
"id",
"int",
"nid",
"null_time_ptr",
"string",
"traits",
"updated_at"
],
"Placeholders": "(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, gen_random_uuid(), ?, ?, ?, ?, ?, ?),\n(?, ?, ?, ?, ?, ?, ?, ?)"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[
"0001-01-01T00:00:00Z",
"0001-01-01T00:00:00Z",
"string",
42,
null,
{
"foo": "bar"
}
]
149 changes: 137 additions & 12 deletions persistence/sql/batch/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ package batch

import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"time"

"github.com/jmoiron/sqlx/reflectx"

"github.com/ory/x/dbal"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx/reflectx"
"github.com/pkg/errors"

"github.com/ory/x/otelx"
Expand All @@ -38,7 +42,7 @@ type (
}
)

func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T) insertQueryArgs {
func buildInsertQueryArgs[T any](ctx context.Context, dialect string, mapper *reflectx.Mapper, quoter quoter, models []*T) insertQueryArgs {
var (
v T
model = pop.NewModel(v, ctx)
Expand All @@ -60,8 +64,41 @@ func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T
for _, col := range columns {
quotedColumns = append(quotedColumns, quoter.Quote(col))
}
for range models {
placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(placeholderRow, ", ")))

// We generate a list (for every row one) of VALUE statements here that
// will be substituted by their column values later:
//
// (?, ?, ?, ?),
// (?, ?, ?, ?),
// (?, ?, ?, ?)
for _, m := range models {
m := reflect.ValueOf(m)

pl := make([]string, len(placeholderRow))
copy(pl, placeholderRow)

// There is a special case - when using CockroachDB we want to generate
// UUIDs using "gen_random_uuid()" which ends up in a VALUE statement of:
//
// (gen_random_uuid(), ?, ?, ?),
for k := range placeholderRow {
if columns[k] != "id" {
continue
}

field := mapper.FieldByName(m, columns[k])
val, ok := field.Interface().(uuid.UUID)
if !ok {
continue
}

if val == uuid.Nil && dialect == dbal.DriverCockroachDB {
pl[k] = "gen_random_uuid()"
break
}
}

placeholders = append(placeholders, fmt.Sprintf("(%s)", strings.Join(pl, ", ")))
}

return insertQueryArgs{
Expand All @@ -72,12 +109,11 @@ func buildInsertQueryArgs[T any](ctx context.Context, quoter quoter, models []*T
}
}

func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, models []*T) (values []any, err error) {
now := time.Now().UTC().Truncate(time.Microsecond)

func buildInsertQueryValues[T any](dialect string, mapper *reflectx.Mapper, columns []string, models []*T, nowFunc func() time.Time) (values []any, err error) {
for _, m := range models {
m := reflect.ValueOf(m)

now := nowFunc()
// Append model fields to args
for _, c := range columns {
field := mapper.FieldByName(m, c)
Expand All @@ -89,17 +125,31 @@ func buildInsertQueryValues[T any](mapper *reflectx.Mapper, columns []string, mo
}
case "updated_at":
field.Set(reflect.ValueOf(now))

case "id":
if field.Interface().(uuid.UUID) != uuid.Nil {
break // breaks switch, not for
} else if dialect == dbal.DriverCockroachDB {
// This is a special case:
// 1. We're using cockroach
// 2. It's the primary key field ("ID")
// 3. A UUID was not yet set.
//
// If all these conditions meet, the VALUE statement will look as such:
//
// (gen_random_uuid(), ?, ?, ?, ...)
//
// For that reason, we do not add the ID value to the list of arguments,
// because one of the arguments is using a built-in and thus doesn't need a value.
continue // break switch, not for
}

id, err := uuid.NewV4()
if err != nil {
return nil, err
}
field.Set(reflect.ValueOf(id))
}

values = append(values, field.Interface())

// Special-handling for *sqlxx.NullTime: mapper.FieldByName sets this to a zero time.Time,
Expand All @@ -125,26 +175,101 @@ func Create[T any](ctx context.Context, p *TracerConnection, models []*T) (err e
return nil
}

var v T
model := pop.NewModel(v, ctx)

conn := p.Connection
quoter, ok := conn.Dialect.(quoter)
if !ok {
return errors.Errorf("store is not a quoter: %T", conn.Store)
}

queryArgs := buildInsertQueryArgs(ctx, quoter, models)
values, err := buildInsertQueryValues(conn.TX.Mapper, queryArgs.Columns, models)
queryArgs := buildInsertQueryArgs(ctx, conn.Dialect.Name(), conn.TX.Mapper, quoter, models)
values, err := buildInsertQueryValues(conn.Dialect.Name(), conn.TX.Mapper, queryArgs.Columns, models, func() time.Time { return time.Now().UTC().Truncate(time.Microsecond) })
if err != nil {
return err
}

var returningClause string
if conn.Dialect.Name() != dbal.DriverMySQL {
// PostgreSQL, CockroachDB, SQLite support RETURNING.
returningClause = fmt.Sprintf("RETURNING %s", model.IDField())
}

query := conn.Dialect.TranslateSQL(fmt.Sprintf(
"INSERT INTO %s (%s) VALUES\n%s",
"INSERT INTO %s (%s) VALUES\n%s\n%s",
queryArgs.TableName,
queryArgs.ColumnsDecl,
queryArgs.Placeholders,
returningClause,
))

_, err = conn.Store.ExecContext(ctx, query, values...)
rows, err := conn.TX.QueryContext(ctx, query, values...)
if err != nil {
return sqlcon.HandleError(err)
}
defer rows.Close()

// Hydrate the models from the RETURNING clause.
//
// Databases not supporting RETURNING will just return 0 rows.
count := 0
for rows.Next() {
if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := setModelID(rows, pop.NewModel(models[count], ctx)); err != nil {
return err
}
count++
}

if err := rows.Err(); err != nil {
return sqlcon.HandleError(err)
}

if err := rows.Close(); err != nil {
return sqlcon.HandleError(err)
}

return sqlcon.HandleError(err)
}

// setModelID was copy & pasted from pop. It basically sets
// the primary key to the given value read from the SQL row.
func setModelID(row *sql.Rows, model *pop.Model) error {
el := reflect.ValueOf(model.Value).Elem()
fbn := el.FieldByName("ID")
if !fbn.IsValid() {
return errors.New("model does not have a field named id")
}

pkt, err := model.PrimaryKeyType()
if err != nil {
return errors.WithStack(err)
}

switch pkt {
case "UUID":
var id uuid.UUID
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
fbn.Set(reflect.ValueOf(id))
default:
var id interface{}
if err := row.Scan(&id); err != nil {
return errors.WithStack(err)
}
v := reflect.ValueOf(id)
switch fbn.Kind() {
case reflect.Int, reflect.Int64:
fbn.SetInt(v.Int())
default:
fbn.Set(reflect.ValueOf(id))
}
}

return nil
}
77 changes: 58 additions & 19 deletions persistence/sql/batch/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"testing"
"time"

"github.com/ory/x/dbal"

"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx/reflectx"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -39,11 +41,20 @@ func (i testModel) TableName(ctx context.Context) string {

func (tq testQuoter) Quote(s string) string { return fmt.Sprintf("%q", s) }

func makeModels[T any]() []*T {
models := make([]*T, 10)
for k := range models {
models[k] = new(T)
}
return models
}

func Test_buildInsertQueryArgs(t *testing.T) {
ctx := context.Background()
t.Run("case=testModel", func(t *testing.T) {
models := make([]*testModel, 10)
args := buildInsertQueryArgs(ctx, testQuoter{}, models)
models := makeModels[testModel]()
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)

query := fmt.Sprintf("INSERT INTO %s (%s) VALUES\n%s", args.TableName, args.ColumnsDecl, args.Placeholders)
Expand All @@ -61,20 +72,35 @@ func Test_buildInsertQueryArgs(t *testing.T) {
})

t.Run("case=Identities", func(t *testing.T) {
models := make([]*identity.Identity, 10)
args := buildInsertQueryArgs(ctx, testQuoter{}, models)
models := makeModels[identity.Identity]()
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)
})

t.Run("case=RecoveryAddress", func(t *testing.T) {
models := make([]*identity.RecoveryAddress, 10)
args := buildInsertQueryArgs(ctx, testQuoter{}, models)
models := makeModels[identity.RecoveryAddress]()
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)
})

t.Run("case=RecoveryAddress", func(t *testing.T) {
models := make([]*identity.RecoveryAddress, 10)
args := buildInsertQueryArgs(ctx, testQuoter{}, models)
models := makeModels[identity.RecoveryAddress]()
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "other", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)
})

t.Run("case=cockroach", func(t *testing.T) {
models := makeModels[testModel]()
for k := range models {
if k%3 == 0 {
models[k].ID = uuid.FromStringOrNil(fmt.Sprintf("ae0125a9-2786-4ada-82d2-d169cf75047%d", k))
}
}
mapper := reflectx.NewMapper("db")
args := buildInsertQueryArgs(ctx, "cockroach", mapper, testQuoter{}, models)
snapshotx.SnapshotT(t, args)
})
}
Expand All @@ -87,21 +113,34 @@ func Test_buildInsertQueryValues(t *testing.T) {
Traits: []byte(`{"foo": "bar"}`),
}
mapper := reflectx.NewMapper("db")
values, err := buildInsertQueryValues(mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model})
require.NoError(t, err)

assert.NotNil(t, model.CreatedAt)
assert.Equal(t, model.CreatedAt, values[0])
nowFunc := func() time.Time {
return time.Time{}
}
t.Run("case=cockroach", func(t *testing.T) {
values, err := buildInsertQueryValues(dbal.DriverCockroachDB, mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
require.NoError(t, err)
snapshotx.SnapshotT(t, values)
})

t.Run("case=others", func(t *testing.T) {
values, err := buildInsertQueryValues("other", mapper, []string{"created_at", "updated_at", "id", "string", "int", "null_time_ptr", "traits"}, []*testModel{model}, nowFunc)
require.NoError(t, err)

assert.NotNil(t, model.CreatedAt)
assert.Equal(t, model.CreatedAt, values[0])

assert.NotNil(t, model.UpdatedAt)
assert.Equal(t, model.UpdatedAt, values[1])

assert.NotNil(t, model.UpdatedAt)
assert.Equal(t, model.UpdatedAt, values[1])
assert.NotZero(t, model.ID)
assert.Equal(t, model.ID, values[2])

assert.NotNil(t, model.ID)
assert.Equal(t, model.ID, values[2])
assert.Equal(t, model.String, values[3])
assert.Equal(t, model.Int, values[4])

assert.Equal(t, model.String, values[3])
assert.Equal(t, model.Int, values[4])
assert.Nil(t, model.NullTimePtr)

assert.Nil(t, model.NullTimePtr)
})
})
}
3 changes: 1 addition & 2 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,8 @@ func (p *IdentityPersister) createVerifiableAddresses(ctx context.Context, conn
work = append(work, &id.VerifiableAddresses[i])
}
}
err = batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work)

return err
return batch.Create(ctx, &batch.TracerConnection{Tracer: p.r.Tracer(ctx), Connection: conn}, work)
}

func updateAssociation[T interface {
Expand Down
2 changes: 1 addition & 1 deletion script/testenv.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

docker rm -f kratos_test_database_mysql kratos_test_database_postgres kratos_test_database_cockroach kratos_test_hydra || true
docker run --platform linux/amd64 --name kratos_test_database_mysql -p 3444:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8.0.23
docker run --platform linux/amd64 --name kratos_test_database_mysql -p 3444:3306 -e MYSQL_ROOT_PASSWORD=secret -d mysql:8.0.26
docker run --platform linux/amd64 --name kratos_test_database_postgres -p 3445:5432 -e POSTGRES_PASSWORD=secret -e POSTGRES_DB=postgres -d postgres:11.8 postgres -c log_statement=all
docker run --platform linux/amd64 --name kratos_test_database_cockroach -p 3446:26257 -p 3447:8080 -d cockroachdb/cockroach:v22.2.6 start-single-node --insecure
docker run --platform linux/amd64 --name kratos_test_hydra -p 4444:4444 -p 4445:4445 -d -e DSN=memory -e URLS_SELF_ISSUER=http://localhost:4444/ -e URLS_LOGIN=http://localhost:4446/login -e URLS_CONSENT=http://localhost:4446/consent oryd/hydra:v2.0.2 serve all --dev
Expand Down