diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 976a70c27..b4e66d254 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -224,6 +224,7 @@ func TestDB(t *testing.T) { {testInterfaceJSON}, {testScanBytes}, {testPointers}, + {testExists}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -873,3 +874,15 @@ func testPointers(t *testing.T, db *bun.DB) { err = db.NewSelect().Model(&models2).Order("id ASC").Scan(ctx) require.NoError(t, err) } + +func testExists(t *testing.T, db *bun.DB) { + ctx := context.Background() + + exists, err := db.NewSelect().ColumnExpr("1").Exists(ctx) + require.NoError(t, err) + require.True(t, exists) + + exists, err = db.NewSelect().ColumnExpr("1").Where("1 = 0").Exists(ctx) + require.NoError(t, err) + require.False(t, exists) +} diff --git a/query_select.go b/query_select.go index 28fafa92c..03f4bc04b 100644 --- a/query_select.go +++ b/query_select.go @@ -738,7 +738,7 @@ func (q *SelectQuery) afterSelectHook(ctx context.Context) error { func (q *SelectQuery) Count(ctx context.Context) (int, error) { qq := countQuery{q} - queryBytes, err := qq.appendQuery(q.db.fmter, nil, true) + queryBytes, err := qq.AppendQuery(q.db.fmter, nil) if err != nil { return 0, err } @@ -794,6 +794,25 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in return count, firstErr } +func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { + qq := existsQuery{q} + + queryBytes, err := qq.AppendQuery(q.db.fmter, nil) + if err != nil { + return false, err + } + + query := internal.String(queryBytes) + ctx, event := q.db.beforeQuery(ctx, qq, query, nil) + + var exists bool + err = q.conn.QueryRowContext(ctx, query).Scan(&exists) + + q.db.afterQuery(ctx, event, nil, err) + + return exists, err +} + //------------------------------------------------------------------------------ type joinQuery struct { @@ -837,3 +856,22 @@ type countQuery struct { func (q countQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { return q.appendQuery(fmter, b, true) } + +//------------------------------------------------------------------------------ + +type existsQuery struct { + *SelectQuery +} + +func (q existsQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { + b = append(b, "SELECT EXISTS ("...) + + b, err = q.appendQuery(fmter, b, false) + if err != nil { + return nil, err + } + + b = append(b, ")"...) + + return b, nil +}