Skip to content

Commit

Permalink
Merge pull request #57 from kozmod/bugfix/56_fix_recursive_commit_or_…
Browse files Browse the repository at this point in the history
…raolback

[#56] fix recursive call commit or rollback
  • Loading branch information
kozmod authored Jan 15, 2024
2 parents dd95cd9 + d565539 commit 5dd394f
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 80 deletions.
18 changes: 14 additions & 4 deletions transactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,28 @@ func (t *Transactor[B, T, O]) WithinTxWithOpts(ctx context.Context, fn func(ctx
defer func() {
switch p := recover(); {
case p != nil:
if ok {
err = xerrors.Errorf("transactor - panic [%v]", p)
return
}
if rbErr := tx.Rollback(ctx); rbErr != nil {
err = xerrors.Errorf("transactor - panic [%v]: %w", p, errors.Join(rbErr, ErrRollbackFailed))
return
} else {
err = xerrors.Errorf("transactor - panic [%v]: %w", p, ErrRollbackSuccess)
}
err = xerrors.Errorf("transactor - panic [%v]: %w", p, ErrRollbackSuccess)
case err != nil:
if ok {
return
}
if rbErr := tx.Rollback(ctx); rbErr != nil {
err = xerrors.Errorf("transactor - call: %w", errors.Join(err, rbErr, ErrRollbackFailed))
return
} else {
err = xerrors.Errorf("transactor - call: %w", errors.Join(err, ErrRollbackSuccess))
}
err = xerrors.Errorf("transactor - call: %w", errors.Join(err, ErrRollbackSuccess))
default:
if ok {
return
}
if err = tx.Commit(ctx); err != nil {
err = xerrors.Errorf("transactor: %w", errors.Join(err, ErrCommitFailed))
}
Expand Down
247 changes: 171 additions & 76 deletions transactor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func Test_Transactor(t *testing.T) {
assertTrue(t, beginnerCalled)
assertTrue(t, commitCalled)
})
t.Run("success_commit_with_exists_tx", func(t *testing.T) {
t.Run("success_and_not_commit_with_exists_tx", func(t *testing.T) {
var (
ctx = context.Background()
commitCalled bool
Expand All @@ -141,7 +141,7 @@ func Test_Transactor(t *testing.T) {
return nil
})
assertTrue(t, err == nil)
assertTrue(t, commitCalled)
assertTrue(t, !commitCalled)
})
t.Run("failed_commit", func(t *testing.T) {
var (
Expand Down Expand Up @@ -321,7 +321,7 @@ func Test_Transactor(t *testing.T) {
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
_, ok := o.Extract(ctx)
assertTrue(t, !ok)
assertFalse(t, ok)
return nil
})
assertTrue(t, errors.Is(err, ErrBeginTx))
Expand Down Expand Up @@ -353,83 +353,170 @@ func Test_Transactor(t *testing.T) {
})
assertTrue(t, errors.Is(err, ErrNilTxOperator))
})
t.Run("recursive_call", func(t *testing.T) {
t.Run("success_commit", func(t *testing.T) {
var (
ctx = context.Background()
commitCalled bool
beginnerCalled bool
c = committerMock{
commitFn: func(ctx context.Context) error {
commitCalled = true
return nil
},
}
b = &beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginnerCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
trA = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
trB = trA
)
err := trA.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return trB.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return nil
})
}

// nolint: dupl
func Test_Transactor_recursive_call(t *testing.T) {
t.Run("success_commit", func(t *testing.T) {
var (
ctx = context.Background()
commitCalled bool
beginnerCalled bool
c = committerMock{
commitFn: func(ctx context.Context) error {
commitCalled = true
return nil
},
}
b = &beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginnerCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
trA = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
trB = trA
)
err := trA.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return trB.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return nil

})
})
assertTrue(t, err == nil)
assertTrue(t, beginnerCalled)
assertTrue(t, commitCalled)
})
t.Run("success_rollback", func(t *testing.T) {
var (
ctx = context.Background()
rollbackCalled bool
beginCalled bool
c = committerMock{
rollbackFn: func(ctx context.Context) error {
rollbackCalled = true
return nil
},
}
b = &beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
trA = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
trB = trA
)
err := trA.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return trB.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return fmt.Errorf("some error")
})
})
assertTrue(t, errors.Is(err, ErrRollbackSuccess))
assertTrue(t, rollbackCalled)
assertTrue(t, beginCalled)
})
t.Run("success_and_commit_on_top_lvl_func", func(t *testing.T) {
const (
ctxValTopLvl = "top_lvl"
ctxValSecondLvl = "second_lvl"
)

/*
functions to inject and check recursion level
*/
var (
ctxKey struct{}
injectLvl = func(ctx context.Context, lvl string) context.Context {
t.Helper()
return context.WithValue(ctx, ctxKey, lvl)
}

isLvlEqual = func(ctx context.Context, required string) bool {
t.Helper()
lvl, ok := ctx.Value(ctxKey).(string)
if !ok {
return false
}
return strings.EqualFold(lvl, required)
}
)

var (
ctx = context.Background()
commitCalled int
beginCalled int
c = committerMock{
commitFn: func(ctx context.Context) error {
commitCalled++
// assert that commit was called on the recursion "top" level.
assertTrue(t, isLvlEqual(ctx, ctxValTopLvl))
// assert that commit call wasn't called on the "second" recursion level.
assertFalse(t, isLvlEqual(ctx, ctxValSecondLvl))
return nil
},
}
b = &beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled++
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

{
// inject "top" level variable in context.Context
ctx = injectLvl(ctx, ctxValTopLvl)
}

err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
// inject "second" level variable in context.Context.
ctx = injectLvl(ctx, ctxValSecondLvl)

})
})
assertTrue(t, err == nil)
assertTrue(t, beginnerCalled)
assertTrue(t, commitCalled)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return nil
})
t.Run("success_rollback", func(t *testing.T) {
var (
ctx = context.Background()
rollbackCalled bool
beginCalled bool
c = committerMock{
rollbackFn: func(ctx context.Context) error {
rollbackCalled = true
return nil
},
}
b = &beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
trA = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
trB = trA
)
err := trA.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return trB.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return fmt.Errorf("some error")
})
})
assertTrue(t, errors.Is(err, ErrRollbackSuccess))
assertTrue(t, rollbackCalled)
assertTrue(t, beginCalled)
assertTrue(t, err == nil)

err = tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, &c == tx)
return nil
})
assertTrue(t, err == nil)

return err
})
assertTrue(t, beginCalled == 1)
assertTrue(t, err == nil)
assertTrue(t, commitCalled == 1)
})
}

Expand Down Expand Up @@ -463,3 +550,11 @@ func assertTrue(t *testing.T, val bool) {
t.Fatal()
}
}

// assertFalse was added to avoid to use external dependencies for mocking
func assertFalse(t *testing.T, val bool) {
t.Helper()
if val {
t.Fatal()
}
}

0 comments on commit 5dd394f

Please sign in to comment.