diff --git a/db.go b/db.go index 47e654655..9b8e391dc 100644 --- a/db.go +++ b/db.go @@ -98,6 +98,10 @@ func (db *DB) NewDelete() *DeleteQuery { return NewDeleteQuery(db) } +func (db *DB) NewRaw(query string, args ...interface{}) *RawQuery { + return NewRawQuery(db, query, args...) +} + func (db *DB) NewCreateTable() *CreateTableQuery { return NewCreateTableQuery(db) } @@ -342,6 +346,10 @@ func (c Conn) NewDelete() *DeleteQuery { return NewDeleteQuery(c.db).Conn(c) } +func (c Conn) NewRaw(query string, args ...interface{}) *RawQuery { + return NewRawQuery(c.db, query, args...).Conn(c) +} + func (c Conn) NewCreateTable() *CreateTableQuery { return NewCreateTableQuery(c.db).Conn(c) } @@ -648,6 +656,10 @@ func (tx Tx) NewDelete() *DeleteQuery { return NewDeleteQuery(tx.db).Conn(tx) } +func (tx Tx) NewRaw(query string, args ...interface{}) *RawQuery { + return NewRawQuery(tx.db, query, args...).Conn(tx) +} + func (tx Tx) NewCreateTable() *CreateTableQuery { return NewCreateTableQuery(tx.db).Conn(tx) } diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index d5556b35d..bbebe77f3 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -824,10 +824,25 @@ func testSelectBool(t *testing.T, db *bun.DB) { func testRawQuery(t *testing.T, db *bun.DB) { var num int - err := db.Raw("SELECT ?", 123).Scan(ctx, &num) + err := db.NewRaw("SELECT ?", 123).Scan(ctx, &num) require.NoError(t, err) require.Equal(t, 123, num) + + _ = db.RunInTx(context.Background(), nil, func(ctx context.Context, tx bun.Tx) error { + var num int + err := db.NewRaw("SELECT ?", 456).Scan(ctx, &num) + require.NoError(t, err) + require.Equal(t, 456, num) + return nil + }) + + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + + err = conn.NewRaw("SELECT ?", 789).Scan(ctx, &num) + require.NoError(t, err) + require.Equal(t, 789, num) } func testFKViolation(t *testing.T, db *bun.DB) { diff --git a/query_base.go b/query_base.go index 45b77f028..b140ca24d 100644 --- a/query_base.go +++ b/query_base.go @@ -50,6 +50,7 @@ type IDB interface { NewInsert() *InsertQuery NewUpdate() *UpdateQuery NewDelete() *DeleteQuery + NewRaw(query string, args ...interface{}) *RawQuery NewCreateTable() *CreateTableQuery NewDropTable() *DropTableQuery NewCreateIndex() *CreateIndexQuery @@ -649,6 +650,10 @@ func (q *baseQuery) NewDelete() *DeleteQuery { return NewDeleteQuery(q.db).Conn(q.conn) } +func (q *baseQuery) NewRaw(query string, args ...interface{}) *RawQuery { + return NewRawQuery(q.db, query, args...).Conn(q.conn) +} + func (q *baseQuery) NewCreateTable() *CreateTableQuery { return NewCreateTableQuery(q.db).Conn(q.conn) } diff --git a/query_raw.go b/query_raw.go index 30ae77508..afbe12130 100644 --- a/query_raw.go +++ b/query_raw.go @@ -13,6 +13,7 @@ type RawQuery struct { args []interface{} } +// Deprecated: Use NewRaw instead. When add it to IDB, it conflicts with the sql.Conn#Raw func (db *DB) Raw(query string, args ...interface{}) *RawQuery { return &RawQuery{ baseQuery: baseQuery{ @@ -24,6 +25,22 @@ func (db *DB) Raw(query string, args ...interface{}) *RawQuery { } } +func NewRawQuery(db *DB, query string, args ...interface{}) *RawQuery { + return &RawQuery{ + baseQuery: baseQuery{ + db: db, + conn: db.DB, + }, + query: query, + args: args, + } +} + +func (q *RawQuery) Conn(db IConn) *RawQuery { + q.setConn(db) + return q +} + func (q *RawQuery) Scan(ctx context.Context, dest ...interface{}) error { if q.err != nil { return q.err