Skip to content

Commit

Permalink
fix: transaction error (#654)
Browse files Browse the repository at this point in the history
* fix: transaction error

* chore: update mocks

* remove mock

* fix error

---------

Co-authored-by: hwbrzzl <hwbrzzl@users.noreply.github.com>
  • Loading branch information
hwbrzzl and hwbrzzl authored Sep 24, 2024
1 parent a8d9d6c commit ef9e9d6
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 3,537 deletions.
20 changes: 8 additions & 12 deletions contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,18 @@ type Orm interface {
// Observe registers an observer with the Orm.
Observe(model any, observer Observer)
// Transaction runs a callback wrapped in a database transaction.
Transaction(txFunc func(tx Transaction) error) error
Transaction(txFunc func(tx Query) error) error
// WithContext sets the context to be used by the Orm.
WithContext(ctx context.Context) Orm
}

type Transaction interface {
Query
// Commit commits the changes in a transaction.
Commit() error
// Rollback rolls back the changes in a transaction.
Rollback() error
}

type Query interface {
// Association gets an association instance by name.
Association(association string) Association
// Begin begins a new transaction
Begin() (Transaction, error)
// Driver gets the driver for the query.
Driver() Driver
Begin() (Query, error)
// Commit commits the changes in a transaction.
Commit() error
// Count retrieve the "count" result of the query.
Count(count *int64) error
// Create inserts new record into the database.
Expand All @@ -47,6 +39,8 @@ type Query interface {
Delete(value any, conds ...any) (*Result, error)
// Distinct specifies distinct fields to query.
Distinct(args ...any) Query
// Driver gets the driver for the query.
Driver() Driver
// Exec executes raw sql
Exec(sql string, values ...any) (*Result, error)
// Exists returns true if matching records exist; otherwise, it returns false.
Expand Down Expand Up @@ -118,6 +112,8 @@ type Query interface {
Pluck(column string, dest any) error
// Raw creates a raw query.
Raw(sql string, values ...any) Query
// Rollback rolls back the changes in a transaction.
Rollback() error
// Save updates value in a database
Save(value any) error
// SaveQuietly updates value in a database without firing events
Expand Down
19 changes: 15 additions & 4 deletions database/gorm/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,17 @@ func (r *QueryImpl) Association(association string) ormcontract.Association {
return query.instance.Association(association)
}

func (r *QueryImpl) Begin() (ormcontract.Transaction, error) {
func (r *QueryImpl) Begin() (ormcontract.Query, error) {
tx := r.instance.Begin()
if tx.Error != nil {
return nil, tx.Error
}

return NewTransaction(tx, r.config, r.connection), tx.Error
return r.new(tx), nil
}

func (r *QueryImpl) Driver() ormcontract.Driver {
return ormcontract.Driver(r.instance.Dialector.Name())
func (r *QueryImpl) Commit() error {
return r.instance.Commit().Error
}

func (r *QueryImpl) Count(count *int64) error {
Expand Down Expand Up @@ -167,6 +170,10 @@ func (r *QueryImpl) Distinct(args ...any) ormcontract.Query {
return r.setConditions(conditions)
}

func (r *QueryImpl) Driver() ormcontract.Driver {
return ormcontract.Driver(r.instance.Dialector.Name())
}

func (r *QueryImpl) Exec(sql string, values ...any) (*ormcontract.Result, error) {
query := r.buildConditions()
result := query.instance.Exec(sql, values...)
Expand Down Expand Up @@ -580,6 +587,10 @@ func (r *QueryImpl) Raw(sql string, values ...any) ormcontract.Query {
return r.new(r.instance.Raw(sql, values...))
}

func (r *QueryImpl) Rollback() error {
return r.instance.Rollback().Error
}

func (r *QueryImpl) Save(value any) error {
query, err := r.refreshConnection(value)
if err != nil {
Expand Down
25 changes: 0 additions & 25 deletions database/gorm/transaction.go

This file was deleted.

2 changes: 1 addition & 1 deletion database/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (r *OrmImpl) Observe(model any, observer ormcontract.Observer) {
})
}

func (r *OrmImpl) Transaction(txFunc func(tx ormcontract.Transaction) error) error {
func (r *OrmImpl) Transaction(txFunc func(tx ormcontract.Query) error) error {
tx, err := r.Query().Begin()
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions database/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (s *OrmSuite) TestTransactionSuccess() {
for _, connection := range connections {
user := User{Name: "transaction_success_user", Avatar: "transaction_success_avatar"}
user1 := User{Name: "transaction_success_user1", Avatar: "transaction_success_avatar1"}
s.Nil(s.orm.Connection(connection.String()).Transaction(func(tx contractsorm.Transaction) error {
s.Nil(s.orm.Connection(connection.String()).Transaction(func(tx contractsorm.Query) error {
s.Nil(tx.Create(&user))
s.Nil(tx.Create(&user1))

Expand All @@ -172,7 +172,7 @@ func (s *OrmSuite) TestTransactionSuccess() {

func (s *OrmSuite) TestTransactionError() {
for _, connection := range connections {
s.NotNil(s.orm.Connection(connection.String()).Transaction(func(tx contractsorm.Transaction) error {
s.NotNil(s.orm.Connection(connection.String()).Transaction(func(tx contractsorm.Query) error {
user := User{Name: "transaction_error_user", Avatar: "transaction_error_avatar"}
s.Nil(tx.Create(&user))

Expand Down
12 changes: 6 additions & 6 deletions mocks/database/orm/Orm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 97 additions & 7 deletions mocks/database/orm/Query.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ef9e9d6

Please sign in to comment.