Skip to content

Commit

Permalink
Add missing tx.Rollback in case of errors (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
majst01 authored Jul 31, 2024
1 parent e141dbc commit ae42869
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,19 @@ func (s *sql) CreatePrefix(ctx context.Context, prefix Prefix, namespace string)
if err != nil {
return Prefix{}, fmt.Errorf("unable to start transaction:%w", err)
}
// Defer a rollback in case anything fails.
defer func() {
_ = tx.Rollback()
}()

_, err = tx.ExecContext(ctx, "INSERT INTO "+getTableName(namespace)+"(cidr, prefix) VALUES ($1, $2)", prefix.Cidr, pj)
if err != nil {
return Prefix{}, fmt.Errorf("unable to insert prefix:%w", err)
}
return prefix, tx.Commit()
if err := tx.Commit(); err != nil {
return Prefix{}, err
}
return prefix, nil
}

func (s *sql) ReadPrefix(ctx context.Context, prefix, namespace string) (Prefix, error) {
Expand Down Expand Up @@ -157,6 +165,11 @@ func (s *sql) UpdatePrefix(ctx context.Context, prefix Prefix, namespace string)
if err != nil {
return Prefix{}, fmt.Errorf("unable to start transaction:%w", err)
}
// Defer a rollback in case anything fails.
defer func() {
_ = tx.Rollback()
}()

result, err := tx.ExecContext(ctx, "SELECT prefix FROM "+getTableName(namespace)+" WHERE cidr=$1 AND prefix->>'Version'=$2 FOR UPDATE", prefix.Cidr, oldVersion)
if err != nil {
return Prefix{}, fmt.Errorf("%w: unable to select for update prefix:%s", ErrOptimisticLockError, prefix.Cidr)
Expand All @@ -183,7 +196,10 @@ func (s *sql) UpdatePrefix(ctx context.Context, prefix Prefix, namespace string)
_ = tx.Rollback()
return Prefix{}, fmt.Errorf("%w: updatePrefix did not effect any row", ErrOptimisticLockError)
}
return prefix, tx.Commit()
if err := tx.Commit(); err != nil {
return Prefix{}, err
}
return prefix, nil
}

func (s *sql) DeletePrefix(ctx context.Context, prefix Prefix, namespace string) (Prefix, error) {
Expand All @@ -194,11 +210,19 @@ func (s *sql) DeletePrefix(ctx context.Context, prefix Prefix, namespace string)
if err != nil {
return Prefix{}, fmt.Errorf("unable to start transaction: %w", err)
}
// Defer a rollback in case anything fails.
defer func() {
_ = tx.Rollback()
}()

_, err = tx.ExecContext(ctx, "DELETE from "+getTableName(namespace)+" WHERE cidr=$1", prefix.Cidr)
if err != nil {
return Prefix{}, fmt.Errorf("unable delete prefix: %w", err)
}
return prefix, tx.Commit()
if err := tx.Commit(); err != nil {
return Prefix{}, err
}
return prefix, nil
}
func (s *sql) Name() string {
return "postgres"
Expand Down Expand Up @@ -241,6 +265,10 @@ func (s *sql) DeleteNamespace(ctx context.Context, namespace string) error {
if err != nil {
return fmt.Errorf("unable to start transaction: %w", err)
}
// Defer a rollback in case anything fails.
defer func() {
_ = tx.Rollback()
}()
_, err = tx.ExecContext(ctx, "DROP TABLE "+getTableName(namespace))
if err != nil {
return fmt.Errorf("unable delete prefix:%w", err)
Expand Down

0 comments on commit ae42869

Please sign in to comment.