Skip to content

Commit

Permalink
Merge pull request #89 from kozmod/bugfix/88_context_operator_B_compa…
Browse files Browse the repository at this point in the history
…rable

[#88] fix default `ContextOperator`: change `B` to comparable
  • Loading branch information
kozmod authored Jan 23, 2024
2 parents d903aba + 4465a69 commit f0263dc
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 48 deletions.
20 changes: 11 additions & 9 deletions operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,26 @@ import (
)

// ContextOperator inject and extract Tx from context.Context.
type ContextOperator[B any, T Tx] struct {
beginner *B
//
// Default ContextOperator uses comparable key for context.Context value operation.
type ContextOperator[K comparable, T Tx] struct {
key K
}

// NewContextOperator returns new ContextOperator.
func NewContextOperator[B any, T Tx](b *B) *ContextOperator[B, T] {
return &ContextOperator[B, T]{
beginner: b,
func NewContextOperator[K comparable, T Tx](key K) *ContextOperator[K, T] {
return &ContextOperator[K, T]{
key: key,
}
}

// Inject returns new context.Context contains Tx as value.
func (p *ContextOperator[B, T]) Inject(ctx context.Context, tx T) context.Context {
return context.WithValue(ctx, p.beginner, tx)
func (p *ContextOperator[K, T]) Inject(ctx context.Context, tx T) context.Context {
return context.WithValue(ctx, p.key, tx)
}

// Extract returns Tx extracted from context.Context.
func (p *ContextOperator[B, T]) Extract(ctx context.Context) (T, bool) {
c, ok := ctx.Value(p.beginner).(T)
func (p *ContextOperator[K, T]) Extract(ctx context.Context) (T, bool) {
c, ok := ctx.Value(p.key).(T)
return c, ok
}
6 changes: 3 additions & 3 deletions stdlib/transactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type dbWrapper struct {
}

// BeginTx starts a transaction.
func (db *dbWrapper) BeginTx(ctx context.Context, opts ...oniontx.Option[*sql.TxOptions]) (*txWrapper, error) {
func (db dbWrapper) BeginTx(ctx context.Context, opts ...oniontx.Option[*sql.TxOptions]) (*txWrapper, error) {
var txOptions sql.TxOptions
for _, opt := range opts {
opt.Apply(&txOptions)
Expand Down Expand Up @@ -58,12 +58,12 @@ type Transactor struct {
// NewTransactor returns new Transactor.
func NewTransactor(db *sql.DB) *Transactor {
var (
base = &dbWrapper{DB: db}
base = dbWrapper{DB: db}
operator = oniontx.NewContextOperator[*dbWrapper, *txWrapper](&base)
transactor = Transactor{
operator: operator,
transactor: oniontx.NewTransactor[*dbWrapper, *txWrapper, *sql.TxOptions](
base,
&base,
operator,
),
}
Expand Down
129 changes: 93 additions & 36 deletions transactor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,51 @@ import (

func Test_CtxOperator(t *testing.T) {
t.Run("success_extract_committer", func(t *testing.T) {
var (
ctx = context.Background()
c = committerMock{}
b = &beginnerMock[*committerMock, any]{}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
)
ctx = o.Inject(ctx, &c)
extracted, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, extracted == &c)
t.Run("extract_pointer", func(t *testing.T) {
var (
ctx = context.Background()
c = committerMock{}
b = beginnerMock[*committerMock, any]{}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
)
ctx = o.Inject(ctx, &c)
extracted, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, extracted == &c)
})
t.Run("extract_value", func(t *testing.T) {
var (
ctx = context.Background()
c = committerValueMock{
committer: &committerMock{},
}
b = beginnerValueMock[committerValueMock, any]{
beginner: &beginnerMock[committerValueMock, any]{},
}
o = NewContextOperator[beginnerValueMock[committerValueMock, any], committerValueMock](b)
)
ctx = o.Inject(ctx, c)
extracted, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, extracted == c)
})
t.Run("extract_nil_value", func(t *testing.T) {
var (
ctx = context.Background()
c = committerValueMock{
committer: nil,
}
b = beginnerValueMock[committerValueMock, any]{
beginner: nil,
}
o = NewContextOperator[beginnerValueMock[committerValueMock, any], committerValueMock](b)
)
ctx = o.Inject(ctx, c)
extracted, ok := o.Extract(ctx)
assertTrue(t, ok)
assertTrue(t, extracted == c)
})

})
}

Expand All @@ -36,15 +71,15 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginnerCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := tr.TryGetTx(ctx)
Expand All @@ -67,20 +102,20 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginnerCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
beginner := tr.TxBeginner()
assertTrue(t, beginner != nil)
assertTrue(t, b == beginner)
assertTrue(t, &b == beginner)
return nil
})
assertTrue(t, err == nil)
Expand All @@ -99,15 +134,15 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginnerCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand All @@ -129,9 +164,9 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{}
b = beginnerMock[*committerMock, any]{}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
ctx = o.Inject(ctx, &c)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
Expand All @@ -155,14 +190,14 @@ func Test_Transactor(t *testing.T) {
return expError
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand All @@ -187,15 +222,15 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand All @@ -222,15 +257,15 @@ func Test_Transactor(t *testing.T) {
return rollbackErr
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand All @@ -256,15 +291,15 @@ func Test_Transactor(t *testing.T) {
return nil
},
}
b = &beginnerMock[*committerMock, any]{
b = beginnerMock[*committerMock, any]{
beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) {
beginCalled = true
assertTrue(t, opts == nil)
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
tx, ok := o.Extract(ctx)
Expand Down Expand Up @@ -300,7 +335,7 @@ func Test_Transactor(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
Expand All @@ -327,7 +362,7 @@ func Test_Transactor(t *testing.T) {
return nil, expError
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
Expand Down Expand Up @@ -435,7 +470,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)
err := tr.WithinTx(ctx, func(ctx context.Context) error {
Expand Down Expand Up @@ -472,7 +507,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

Expand Down Expand Up @@ -535,7 +570,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

Expand Down Expand Up @@ -600,7 +635,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

Expand Down Expand Up @@ -667,7 +702,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
return &c, nil
},
}
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b)
o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b)
tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o)
)

Expand Down Expand Up @@ -712,7 +747,7 @@ func Test_Transactor_recursive_call(t *testing.T) {
})
}

// beginnerMock was added to avoid to use external dependencies for mocking
// beginnerMock was added to avoid to use external dependencies for mocking (pointer receiver).
type beginnerMock[T Tx, O any] struct {
beginFn func(ctx context.Context, opts ...Option[O]) (T, error)
}
Expand All @@ -721,7 +756,7 @@ func (b *beginnerMock[T, O]) BeginTx(ctx context.Context, opts ...Option[O]) (T,
return b.beginFn(ctx, opts...)
}

// committerMock was added to avoid to use external dependencies for mocking
// committerMock was added to avoid to use external dependencies for mocking (pointer receiver).
type committerMock struct {
commitFn func(ctx context.Context) error
rollbackFn func(ctx context.Context) error
Expand All @@ -735,6 +770,28 @@ func (c *committerMock) Rollback(ctx context.Context) error {
return c.rollbackFn(ctx)
}

// beginnerValueMock was added to avoid to use external dependencies for mocking (value receiver).
type beginnerValueMock[T Tx, O any] struct {
beginner *beginnerMock[T, O]
}

func (b beginnerValueMock[T, O]) BeginTx(ctx context.Context, opts ...Option[O]) (T, error) {
return b.beginner.beginFn(ctx, opts...)
}

// committerValueMock was added to avoid to use external dependencies for mocking (value receiver).
type committerValueMock struct {
committer *committerMock
}

func (c committerValueMock) Commit(ctx context.Context) error {
return c.committer.commitFn(ctx)
}

func (c committerValueMock) Rollback(ctx context.Context) error {
return c.committer.commitFn(ctx)
}

// assertTrue was added to avoid to use external dependencies for mocking
func assertTrue(t *testing.T, val bool) {
t.Helper()
Expand Down

0 comments on commit f0263dc

Please sign in to comment.