diff --git a/db.go b/db.go index 83328015f..96a39e84a 100644 --- a/db.go +++ b/db.go @@ -32,6 +32,7 @@ func WithDiscardUnknownColumns() DBOption { type DB struct { *sql.DB + dialect schema.Dialect features feature.Feature diff --git a/query_select.go b/query_select.go index c9d908bac..1b84672f9 100644 --- a/query_select.go +++ b/query_select.go @@ -8,6 +8,7 @@ import ( "fmt" "strconv" "strings" + "sync" "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" @@ -781,6 +782,53 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { } func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { + if _, ok := q.conn.(*DB); ok { + return q.scanAndCountConc(ctx, dest...) + } + return q.scanAndCountSeq(ctx, dest...) +} + +func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) { + var count int + var wg sync.WaitGroup + var mu sync.Mutex + var firstErr error + + if q.limit >= 0 { + wg.Add(1) + go func() { + defer wg.Done() + + if err := q.Scan(ctx, dest...); err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + + var err error + count, err = q.Count(ctx) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + } + }() + + wg.Wait() + return count, firstErr +} + +func (q *SelectQuery) scanAndCountSeq(ctx context.Context, dest ...interface{}) (int, error) { var firstErr error if q.limit >= 0 {