diff --git a/operator.go b/operator.go index 7e2653c..e647512 100644 --- a/operator.go +++ b/operator.go @@ -5,24 +5,26 @@ import ( ) // ContextOperator inject and extract Tx from context.Context. -type ContextOperator[B any, T Tx] struct { - beginner *B +// +// Default ContextOperator uses comparable key for context.Context value operation. +type ContextOperator[K comparable, T Tx] struct { + key K } // NewContextOperator returns new ContextOperator. -func NewContextOperator[B any, T Tx](b *B) *ContextOperator[B, T] { - return &ContextOperator[B, T]{ - beginner: b, +func NewContextOperator[K comparable, T Tx](key K) *ContextOperator[K, T] { + return &ContextOperator[K, T]{ + key: key, } } // Inject returns new context.Context contains Tx as value. -func (p *ContextOperator[B, T]) Inject(ctx context.Context, tx T) context.Context { - return context.WithValue(ctx, p.beginner, tx) +func (p *ContextOperator[K, T]) Inject(ctx context.Context, tx T) context.Context { + return context.WithValue(ctx, p.key, tx) } // Extract returns Tx extracted from context.Context. -func (p *ContextOperator[B, T]) Extract(ctx context.Context) (T, bool) { - c, ok := ctx.Value(p.beginner).(T) +func (p *ContextOperator[K, T]) Extract(ctx context.Context) (T, bool) { + c, ok := ctx.Value(p.key).(T) return c, ok } diff --git a/stdlib/transactor.go b/stdlib/transactor.go index 51a06bc..48f7284 100644 --- a/stdlib/transactor.go +++ b/stdlib/transactor.go @@ -25,7 +25,7 @@ type dbWrapper struct { } // BeginTx starts a transaction. -func (db *dbWrapper) BeginTx(ctx context.Context, opts ...oniontx.Option[*sql.TxOptions]) (*txWrapper, error) { +func (db dbWrapper) BeginTx(ctx context.Context, opts ...oniontx.Option[*sql.TxOptions]) (*txWrapper, error) { var txOptions sql.TxOptions for _, opt := range opts { opt.Apply(&txOptions) @@ -58,12 +58,12 @@ type Transactor struct { // NewTransactor returns new Transactor. func NewTransactor(db *sql.DB) *Transactor { var ( - base = &dbWrapper{DB: db} + base = dbWrapper{DB: db} operator = oniontx.NewContextOperator[*dbWrapper, *txWrapper](&base) transactor = Transactor{ operator: operator, transactor: oniontx.NewTransactor[*dbWrapper, *txWrapper, *sql.TxOptions]( - base, + &base, operator, ), } diff --git a/transactor_test.go b/transactor_test.go index ca39c5b..57319f4 100644 --- a/transactor_test.go +++ b/transactor_test.go @@ -10,16 +10,51 @@ import ( func Test_CtxOperator(t *testing.T) { t.Run("success_extract_committer", func(t *testing.T) { - var ( - ctx = context.Background() - c = committerMock{} - b = &beginnerMock[*committerMock, any]{} - o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - ) - ctx = o.Inject(ctx, &c) - extracted, ok := o.Extract(ctx) - assertTrue(t, ok) - assertTrue(t, extracted == &c) + t.Run("extract_pointer", func(t *testing.T) { + var ( + ctx = context.Background() + c = committerMock{} + b = beginnerMock[*committerMock, any]{} + o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) + ) + ctx = o.Inject(ctx, &c) + extracted, ok := o.Extract(ctx) + assertTrue(t, ok) + assertTrue(t, extracted == &c) + }) + t.Run("extract_value", func(t *testing.T) { + var ( + ctx = context.Background() + c = committerValueMock{ + committer: &committerMock{}, + } + b = beginnerValueMock[committerValueMock, any]{ + beginner: &beginnerMock[committerValueMock, any]{}, + } + o = NewContextOperator[beginnerValueMock[committerValueMock, any], committerValueMock](b) + ) + ctx = o.Inject(ctx, c) + extracted, ok := o.Extract(ctx) + assertTrue(t, ok) + assertTrue(t, extracted == c) + }) + t.Run("extract_nil_value", func(t *testing.T) { + var ( + ctx = context.Background() + c = committerValueMock{ + committer: nil, + } + b = beginnerValueMock[committerValueMock, any]{ + beginner: nil, + } + o = NewContextOperator[beginnerValueMock[committerValueMock, any], committerValueMock](b) + ) + ctx = o.Inject(ctx, c) + extracted, ok := o.Extract(ctx) + assertTrue(t, ok) + assertTrue(t, extracted == c) + }) + }) } @@ -36,7 +71,7 @@ func Test_Transactor(t *testing.T) { return nil }, } - b = &beginnerMock[*committerMock, any]{ + b = beginnerMock[*committerMock, any]{ beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) { beginnerCalled = true assertTrue(t, opts == nil) @@ -44,7 +79,7 @@ func Test_Transactor(t *testing.T) { }, } o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) + tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := tr.TryGetTx(ctx) @@ -67,7 +102,7 @@ func Test_Transactor(t *testing.T) { return nil }, } - b = &beginnerMock[*committerMock, any]{ + b = beginnerMock[*committerMock, any]{ beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) { beginnerCalled = true assertTrue(t, opts == nil) @@ -75,12 +110,12 @@ func Test_Transactor(t *testing.T) { }, } o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) + tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { beginner := tr.TxBeginner() assertTrue(t, beginner != nil) - assertTrue(t, b == beginner) + assertTrue(t, &b == beginner) return nil }) assertTrue(t, err == nil) @@ -99,7 +134,7 @@ func Test_Transactor(t *testing.T) { return nil }, } - b = &beginnerMock[*committerMock, any]{ + b = beginnerMock[*committerMock, any]{ beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) { beginnerCalled = true assertTrue(t, opts == nil) @@ -107,7 +142,7 @@ func Test_Transactor(t *testing.T) { }, } o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) + tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) @@ -129,9 +164,9 @@ func Test_Transactor(t *testing.T) { return nil }, } - b = &beginnerMock[*committerMock, any]{} + b = beginnerMock[*committerMock, any]{} o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) + tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o) ) ctx = o.Inject(ctx, &c) err := tr.WithinTx(ctx, func(ctx context.Context) error { @@ -155,14 +190,14 @@ func Test_Transactor(t *testing.T) { return expError }, } - b = &beginnerMock[*committerMock, any]{ + b = beginnerMock[*committerMock, any]{ beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) { assertTrue(t, opts == nil) return &c, nil }, } o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) + tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) @@ -187,7 +222,7 @@ func Test_Transactor(t *testing.T) { return nil }, } - b = &beginnerMock[*committerMock, any]{ + b = beginnerMock[*committerMock, any]{ beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) { beginCalled = true assertTrue(t, opts == nil) @@ -195,7 +230,7 @@ func Test_Transactor(t *testing.T) { }, } o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) + tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) @@ -222,7 +257,7 @@ func Test_Transactor(t *testing.T) { return rollbackErr }, } - b = &beginnerMock[*committerMock, any]{ + b = beginnerMock[*committerMock, any]{ beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) { beginCalled = true assertTrue(t, opts == nil) @@ -230,7 +265,7 @@ func Test_Transactor(t *testing.T) { }, } o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) + tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) @@ -256,7 +291,7 @@ func Test_Transactor(t *testing.T) { return nil }, } - b = &beginnerMock[*committerMock, any]{ + b = beginnerMock[*committerMock, any]{ beginFn: func(ctx context.Context, opts ...Option[any]) (*committerMock, error) { beginCalled = true assertTrue(t, opts == nil) @@ -264,7 +299,7 @@ func Test_Transactor(t *testing.T) { }, } o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) - tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) + tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](&b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { tx, ok := o.Extract(ctx) @@ -300,7 +335,7 @@ func Test_Transactor(t *testing.T) { return &c, nil }, } - o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) + o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b) tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { @@ -327,7 +362,7 @@ func Test_Transactor(t *testing.T) { return nil, expError }, } - o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) + o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b) tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { @@ -435,7 +470,7 @@ func Test_Transactor_recursive_call(t *testing.T) { return &c, nil }, } - o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) + o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b) tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) ) err := tr.WithinTx(ctx, func(ctx context.Context) error { @@ -472,7 +507,7 @@ func Test_Transactor_recursive_call(t *testing.T) { return &c, nil }, } - o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) + o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b) tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) ) @@ -535,7 +570,7 @@ func Test_Transactor_recursive_call(t *testing.T) { return &c, nil }, } - o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) + o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b) tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) ) @@ -600,7 +635,7 @@ func Test_Transactor_recursive_call(t *testing.T) { return &c, nil }, } - o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) + o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b) tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) ) @@ -667,7 +702,7 @@ func Test_Transactor_recursive_call(t *testing.T) { return &c, nil }, } - o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](&b) + o = NewContextOperator[*beginnerMock[*committerMock, any], *committerMock](b) tr = NewTransactor[*beginnerMock[*committerMock, any], *committerMock, any](b, o) ) @@ -712,7 +747,7 @@ func Test_Transactor_recursive_call(t *testing.T) { }) } -// beginnerMock was added to avoid to use external dependencies for mocking +// beginnerMock was added to avoid to use external dependencies for mocking (pointer receiver). type beginnerMock[T Tx, O any] struct { beginFn func(ctx context.Context, opts ...Option[O]) (T, error) } @@ -721,7 +756,7 @@ func (b *beginnerMock[T, O]) BeginTx(ctx context.Context, opts ...Option[O]) (T, return b.beginFn(ctx, opts...) } -// committerMock was added to avoid to use external dependencies for mocking +// committerMock was added to avoid to use external dependencies for mocking (pointer receiver). type committerMock struct { commitFn func(ctx context.Context) error rollbackFn func(ctx context.Context) error @@ -735,6 +770,28 @@ func (c *committerMock) Rollback(ctx context.Context) error { return c.rollbackFn(ctx) } +// beginnerValueMock was added to avoid to use external dependencies for mocking (value receiver). +type beginnerValueMock[T Tx, O any] struct { + beginner *beginnerMock[T, O] +} + +func (b beginnerValueMock[T, O]) BeginTx(ctx context.Context, opts ...Option[O]) (T, error) { + return b.beginner.beginFn(ctx, opts...) +} + +// committerValueMock was added to avoid to use external dependencies for mocking (value receiver). +type committerValueMock struct { + committer *committerMock +} + +func (c committerValueMock) Commit(ctx context.Context) error { + return c.committer.commitFn(ctx) +} + +func (c committerValueMock) Rollback(ctx context.Context) error { + return c.committer.commitFn(ctx) +} + // assertTrue was added to avoid to use external dependencies for mocking func assertTrue(t *testing.T, val bool) { t.Helper()