From ae428697e35b2815f3ae7aa6f75e881825990bf3 Mon Sep 17 00:00:00 2001 From: Stefan Majer Date: Wed, 31 Jul 2024 10:01:30 +0200 Subject: [PATCH] Add missing tx.Rollback in case of errors (#160) --- sql.go | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/sql.go b/sql.go index d9e3b90..cc043a2 100644 --- a/sql.go +++ b/sql.go @@ -73,11 +73,19 @@ func (s *sql) CreatePrefix(ctx context.Context, prefix Prefix, namespace string) if err != nil { return Prefix{}, fmt.Errorf("unable to start transaction:%w", err) } + // Defer a rollback in case anything fails. + defer func() { + _ = tx.Rollback() + }() + _, err = tx.ExecContext(ctx, "INSERT INTO "+getTableName(namespace)+"(cidr, prefix) VALUES ($1, $2)", prefix.Cidr, pj) if err != nil { return Prefix{}, fmt.Errorf("unable to insert prefix:%w", err) } - return prefix, tx.Commit() + if err := tx.Commit(); err != nil { + return Prefix{}, err + } + return prefix, nil } func (s *sql) ReadPrefix(ctx context.Context, prefix, namespace string) (Prefix, error) { @@ -157,6 +165,11 @@ func (s *sql) UpdatePrefix(ctx context.Context, prefix Prefix, namespace string) if err != nil { return Prefix{}, fmt.Errorf("unable to start transaction:%w", err) } + // Defer a rollback in case anything fails. + defer func() { + _ = tx.Rollback() + }() + result, err := tx.ExecContext(ctx, "SELECT prefix FROM "+getTableName(namespace)+" WHERE cidr=$1 AND prefix->>'Version'=$2 FOR UPDATE", prefix.Cidr, oldVersion) if err != nil { return Prefix{}, fmt.Errorf("%w: unable to select for update prefix:%s", ErrOptimisticLockError, prefix.Cidr) @@ -183,7 +196,10 @@ func (s *sql) UpdatePrefix(ctx context.Context, prefix Prefix, namespace string) _ = tx.Rollback() return Prefix{}, fmt.Errorf("%w: updatePrefix did not effect any row", ErrOptimisticLockError) } - return prefix, tx.Commit() + if err := tx.Commit(); err != nil { + return Prefix{}, err + } + return prefix, nil } func (s *sql) DeletePrefix(ctx context.Context, prefix Prefix, namespace string) (Prefix, error) { @@ -194,11 +210,19 @@ func (s *sql) DeletePrefix(ctx context.Context, prefix Prefix, namespace string) if err != nil { return Prefix{}, fmt.Errorf("unable to start transaction: %w", err) } + // Defer a rollback in case anything fails. + defer func() { + _ = tx.Rollback() + }() + _, err = tx.ExecContext(ctx, "DELETE from "+getTableName(namespace)+" WHERE cidr=$1", prefix.Cidr) if err != nil { return Prefix{}, fmt.Errorf("unable delete prefix: %w", err) } - return prefix, tx.Commit() + if err := tx.Commit(); err != nil { + return Prefix{}, err + } + return prefix, nil } func (s *sql) Name() string { return "postgres" @@ -241,6 +265,10 @@ func (s *sql) DeleteNamespace(ctx context.Context, namespace string) error { if err != nil { return fmt.Errorf("unable to start transaction: %w", err) } + // Defer a rollback in case anything fails. + defer func() { + _ = tx.Rollback() + }() _, err = tx.ExecContext(ctx, "DROP TABLE "+getTableName(namespace)) if err != nil { return fmt.Errorf("unable delete prefix:%w", err)