Skip to content

Commit

Permalink
fix: enhance transaction functionality (#1281)
Browse files Browse the repository at this point in the history
### Motivation
Various fixes and refactoring for transaction.

### Modifications

* Employ context in the `Commit` and `Abort` methods
* Use client operation timeout
* Use `atomic.Int32` for the state
* Make all state reads atomic
* Clean up and improve error messages
  • Loading branch information
reugn authored Dec 17, 2024
1 parent 0612938 commit c9245bc
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Run `make lint` from the root path of this project to check code with golangci-lint.

run:
deadline: 6m
timeout: 5m

linters:
# Uncomment this line to run only the explicitly enabled linters
Expand Down
4 changes: 2 additions & 2 deletions pulsar/consumer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ func (pc *partitionConsumer) internalAckWithTxn(req *ackWithTxnRequest) {
req.err = newError(ConsumerClosed, "Failed to ack by closing or closed consumer")
return
}
if req.Transaction.state != TxnOpen {
pc.log.WithField("state", req.Transaction.state).Error("Failed to ack by a non-open transaction.")
if req.Transaction.state.Load() != int32(TxnOpen) {
pc.log.WithField("state", req.Transaction.state.Load()).Error("Failed to ack by a non-open transaction.")
req.err = newError(InvalidStatus, "Failed to ack by a non-open transaction.")
return
}
Expand Down
4 changes: 2 additions & 2 deletions pulsar/producer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1135,8 +1135,8 @@ func (p *partitionProducer) prepareTransaction(sr *sendRequest) error {
}

txn := (sr.msg.Transaction).(*transaction)
if txn.state != TxnOpen {
p.log.WithField("state", txn.state).Error("Failed to send message" +
if txn.state.Load() != int32(TxnOpen) {
p.log.WithField("state", txn.state.Load()).Error("Failed to send message" +
" by a non-open transaction.")
return joinErrors(ErrTransaction,
fmt.Errorf("failed to send message by a non-open transaction"))
Expand Down
112 changes: 64 additions & 48 deletions pulsar/transaction_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package pulsar

import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
Expand All @@ -33,9 +35,9 @@ type subscription struct {
}

type transaction struct {
sync.Mutex
mu sync.Mutex
txnID TxnID
state TxnState
state atomic.Int32
tcClient *transactionCoordinatorClient
registerPartitions map[string]bool
registerAckSubscriptions map[subscription]bool
Expand All @@ -54,96 +56,106 @@ type transaction struct {
// 1. When the transaction is committed or aborted, a bool will be read from opsFlow chan.
// 2. When the opsCount increment from 0 to 1, a bool will be read from opsFlow chan.
opsFlow chan bool
opsCount int32
opsCount atomic.Int32
opTimeout time.Duration
log log.Logger
}

func newTransaction(id TxnID, tcClient *transactionCoordinatorClient, timeout time.Duration) *transaction {
transaction := &transaction{
txnID: id,
state: TxnOpen,
registerPartitions: make(map[string]bool),
registerAckSubscriptions: make(map[subscription]bool),
opsFlow: make(chan bool, 1),
opTimeout: 5 * time.Second,
opTimeout: tcClient.client.operationTimeout,
tcClient: tcClient,
}
//This means there are not pending requests with this transaction. The transaction can be committed or aborted.
transaction.state.Store(int32(TxnOpen))
// This means there are not pending requests with this transaction. The transaction can be committed or aborted.
transaction.opsFlow <- true
go func() {
//Set the state of the transaction to timeout after timeout
// Set the state of the transaction to timeout after timeout
<-time.After(timeout)
atomic.CompareAndSwapInt32((*int32)(&transaction.state), int32(TxnOpen), int32(TxnTimeout))
transaction.state.CompareAndSwap(int32(TxnOpen), int32(TxnTimeout))
}()
transaction.log = tcClient.log.SubLogger(log.Fields{})
return transaction
}

func (txn *transaction) GetState() TxnState {
return txn.state
return TxnState(txn.state.Load())
}

func (txn *transaction) Commit(_ context.Context) error {
if !(atomic.CompareAndSwapInt32((*int32)(&txn.state), int32(TxnOpen), int32(TxnCommitting)) ||
txn.state == TxnCommitting) {
return newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
func (txn *transaction) Commit(ctx context.Context) error {
if !(txn.state.CompareAndSwap(int32(TxnOpen), int32(TxnCommitting))) {
txnState := txn.state.Load()
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
}

//Wait for all operations to complete
// Wait for all operations to complete
select {
case <-txn.opsFlow:
case <-ctx.Done():
txn.state.Store(int32(TxnOpen))
return ctx.Err()
case <-time.After(txn.opTimeout):
txn.state.Store(int32(TxnTimeout))
return newError(TimeoutError, "There are some operations that are not completed after the timeout.")
}
//Send commit transaction command to transaction coordinator
// Send commit transaction command to transaction coordinator
err := txn.tcClient.endTxn(&txn.txnID, pb.TxnAction_COMMIT)
if err == nil {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnCommitted))
txn.state.Store(int32(TxnCommitted))
} else {
if e, ok := err.(*Error); ok && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
var e *Error
if errors.As(err, &e) && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
txn.state.Store(int32(TxnError))
return err
}
txn.opsFlow <- true
}
return err
}

func (txn *transaction) Abort(_ context.Context) error {
if !(atomic.CompareAndSwapInt32((*int32)(&txn.state), int32(TxnOpen), int32(TxnAborting)) ||
txn.state == TxnAborting) {
return newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
func (txn *transaction) Abort(ctx context.Context) error {
if !(txn.state.CompareAndSwap(int32(TxnOpen), int32(TxnAborting))) {
txnState := txn.state.Load()
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
}

//Wait for all operations to complete
// Wait for all operations to complete
select {
case <-txn.opsFlow:
case <-ctx.Done():
txn.state.Store(int32(TxnOpen))
return ctx.Err()
case <-time.After(txn.opTimeout):
txn.state.Store(int32(TxnTimeout))
return newError(TimeoutError, "There are some operations that are not completed after the timeout.")
}
//Send abort transaction command to transaction coordinator
// Send abort transaction command to transaction coordinator
err := txn.tcClient.endTxn(&txn.txnID, pb.TxnAction_ABORT)
if err == nil {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnAborted))
txn.state.Store(int32(TxnAborted))
} else {
if e, ok := err.(*Error); ok && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
} else {
txn.opsFlow <- true
var e *Error
if errors.As(err, &e) && (e.Result() == TransactionNoFoundError || e.Result() == InvalidStatus) {
txn.state.Store(int32(TxnError))
return err
}
txn.opsFlow <- true
}
return err
}

func (txn *transaction) registerSendOrAckOp() error {
if atomic.AddInt32(&txn.opsCount, 1) == 1 {
//There are new operations that not completed
if txn.opsCount.Add(1) == 1 {
// There are new operations that were not completed
select {
case <-txn.opsFlow:
return nil
case <-time.After(txn.opTimeout):
if _, err := txn.checkIfOpen(); err != nil {
if err := txn.verifyOpen(); err != nil {
return err
}
return newError(TimeoutError, "Failed to get the semaphore to register the send/ack operation")
Expand All @@ -154,23 +166,22 @@ func (txn *transaction) registerSendOrAckOp() error {

func (txn *transaction) endSendOrAckOp(err error) {
if err != nil {
atomic.StoreInt32((*int32)(&txn.state), int32(TxnError))
txn.state.Store(int32(TxnError))
}
if atomic.AddInt32(&txn.opsCount, -1) == 0 {
//This means there are not pending send/ack requests
if txn.opsCount.Add(-1) == 0 {
// This means there are no pending send/ack requests
txn.opsFlow <- true
}
}

func (txn *transaction) registerProducerTopic(topic string) error {
isOpen, err := txn.checkIfOpen()
if !isOpen {
if err := txn.verifyOpen(); err != nil {
return err
}
_, ok := txn.registerPartitions[topic]
if !ok {
txn.Lock()
defer txn.Unlock()
txn.mu.Lock()
defer txn.mu.Unlock()
if _, ok = txn.registerPartitions[topic]; !ok {
err := txn.tcClient.addPublishPartitionToTxn(&txn.txnID, []string{topic})
if err != nil {
Expand All @@ -183,8 +194,7 @@ func (txn *transaction) registerProducerTopic(topic string) error {
}

func (txn *transaction) registerAckTopic(topic string, subName string) error {
isOpen, err := txn.checkIfOpen()
if !isOpen {
if err := txn.verifyOpen(); err != nil {
return err
}
sub := subscription{
Expand All @@ -193,8 +203,8 @@ func (txn *transaction) registerAckTopic(topic string, subName string) error {
}
_, ok := txn.registerAckSubscriptions[sub]
if !ok {
txn.Lock()
defer txn.Unlock()
txn.mu.Lock()
defer txn.mu.Unlock()
if _, ok = txn.registerAckSubscriptions[sub]; !ok {
err := txn.tcClient.addSubscriptionToTxn(&txn.txnID, topic, subName)
if err != nil {
Expand All @@ -210,14 +220,15 @@ func (txn *transaction) GetTxnID() TxnID {
return txn.txnID
}

func (txn *transaction) checkIfOpen() (bool, error) {
if txn.state == TxnOpen {
return true, nil
func (txn *transaction) verifyOpen() error {
txnState := txn.state.Load()
if txnState != int32(TxnOpen) {
return newError(InvalidStatus, txnStateErrorMessage(TxnOpen, TxnState(txnState)))
}
return false, newError(InvalidStatus, "Expect transaction state is TxnOpen but "+txn.state.string())
return nil
}

func (state TxnState) string() string {
func (state TxnState) String() string {
switch state {
case TxnOpen:
return "TxnOpen"
Expand All @@ -237,3 +248,8 @@ func (state TxnState) string() string {
return "Unknown"
}
}

//nolint:unparam
func txnStateErrorMessage(expected, actual TxnState) string {
return fmt.Sprintf("Expected transaction state: %s, actual: %s", expected, actual)
}
Loading

0 comments on commit c9245bc

Please sign in to comment.