diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index bf4e82f4c..d5556b35d 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -269,7 +269,7 @@ func TestDB(t *testing.T) { {testUpsert}, {testMultiUpdate}, {testUpdateWithSkipupdateTag}, - {testTxScanAndCount}, + {testScanAndCount}, {testEmbedModelValue}, {testEmbedModelPointer}, {testJSONMarshaler}, @@ -1301,7 +1301,7 @@ func testUpdateWithSkipupdateTag(t *testing.T, db *bun.DB) { require.NotEqual(t, model.CreatedAt.UTC(), model_.CreatedAt.UTC()) } -func testTxScanAndCount(t *testing.T, db *bun.DB) { +func testScanAndCount(t *testing.T, db *bun.DB) { type Model struct { ID int64 `bun:",pk,autoincrement"` Str string @@ -1312,14 +1312,33 @@ func testTxScanAndCount(t *testing.T, db *bun.DB) { err := db.ResetModel(ctx, (*Model)(nil)) require.NoError(t, err) - for i := 0; i < 100; i++ { - err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { - var models []Model - _, err := tx.NewSelect().Model(&models).ScanAndCount(ctx) - return err - }) + t.Run("tx", func(t *testing.T) { + for i := 0; i < 100; i++ { + err := db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + var models []Model + count, err := tx.NewSelect().Model(&models).ScanAndCount(ctx) + require.NoError(t, err) + require.Equal(t, 0, count) + return err + }) + require.NoError(t, err) + } + }) + + t.Run("no limit", func(t *testing.T) { + src := []Model{ + {Str: "str1"}, + {Str: "str2"}, + } + _, err = db.NewInsert().Model(&src).Exec(ctx) require.NoError(t, err) - } + + var dest []Model + count, err := db.NewSelect().Model(&dest).ScanAndCount(ctx) + require.NoError(t, err) + require.Equal(t, 2, count) + require.Equal(t, 2, len(dest)) + }) } func testEmbedModelValue(t *testing.T, db *bun.DB) { diff --git a/query_select.go b/query_select.go index 1c5116c96..b61bcfaf0 100644 --- a/query_select.go +++ b/query_select.go @@ -48,8 +48,6 @@ func NewSelectQuery(db *DB) *SelectQuery { conn: db.DB, }, }, - offset: -1, - limit: -1, } } @@ -400,7 +398,6 @@ func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQ if len(join.Relation.Condition) > 0 { apply1 = func(q *SelectQuery) *SelectQuery { - for _, opt := range join.Relation.Condition { q.addWhere(schema.SafeQueryWithSep(opt, nil, " AND ")) } @@ -601,7 +598,7 @@ func (q *SelectQuery) appendQuery( } if fmter.Dialect().Features().Has(feature.OffsetFetch) { - if q.limit >= 0 && q.offset >= 0 { + if q.limit > 0 && q.offset > 0 { b = append(b, " OFFSET "...) b = strconv.AppendInt(b, int64(q.offset), 10) b = append(b, " ROWS"...) @@ -609,23 +606,23 @@ func (q *SelectQuery) appendQuery( b = append(b, " FETCH NEXT "...) b = strconv.AppendInt(b, int64(q.limit), 10) b = append(b, " ROWS ONLY"...) - } else if q.limit >= 0 { + } else if q.limit > 0 { b = append(b, " OFFSET 0 ROWS"...) b = append(b, " FETCH NEXT "...) b = strconv.AppendInt(b, int64(q.limit), 10) b = append(b, " ROWS ONLY"...) - } else if q.offset >= 0 { + } else if q.offset > 0 { b = append(b, " OFFSET "...) b = strconv.AppendInt(b, int64(q.offset), 10) b = append(b, " ROWS"...) } } else { - if q.limit >= 0 { + if q.limit > 0 { b = append(b, " LIMIT "...) b = strconv.AppendInt(b, int64(q.limit), 10) } - if q.offset >= 0 { + if q.offset > 0 { b = append(b, " OFFSET "...) b = strconv.AppendInt(b, int64(q.offset), 10) }