diff --git a/query_base.go b/query_base.go index a1e357fe5..a5aa07710 100644 --- a/query_base.go +++ b/query_base.go @@ -37,10 +37,10 @@ var ( _ IConn = (*sql.Tx)(nil) _ IConn = (*DB)(nil) _ IConn = (*Conn)(nil) - _ IConn = (*Tx)(nil) + _ IConn = (Tx{}) ) -// IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx. +// IDB is a common interface for *bun.DB, and bun.Tx. type IDB interface { IConn Dialect() schema.Dialect @@ -60,9 +60,8 @@ type IDB interface { } var ( - _ IConn = (*DB)(nil) - _ IConn = (*Conn)(nil) - _ IConn = (*Tx)(nil) + _ IDB = (*DB)(nil) + _ IDB = Tx{} ) type baseQuery struct { diff --git a/query_select.go b/query_select.go index 27a3208fc..a07cb5d5f 100644 --- a/query_select.go +++ b/query_select.go @@ -344,9 +344,9 @@ func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) err case schema.HasOneRelation, schema.BelongsToRelation: err = q.selectJoins(ctx, j.JoinModel.getJoins()) case schema.HasManyRelation: - err = j.selectMany(ctx, q.db.NewSelect()) + err = j.selectMany(ctx, q.db.NewSelect().Conn(q.conn)) case schema.ManyToManyRelation: - err = j.selectM2M(ctx, q.db.NewSelect()) + err = j.selectM2M(ctx, q.db.NewSelect().Conn(q.conn)) default: panic("not reached") }