Skip to content

Commit

Permalink
Merge pull request #73 from xataio/add-tx-support-to-postgres-lib
Browse files Browse the repository at this point in the history
Add tx support to postgres lib
  • Loading branch information
eminano authored Sep 23, 2024
2 parents accab0d + 4a29918 commit 8642f97
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 21 deletions.
10 changes: 9 additions & 1 deletion internal/postgres/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,23 @@ package postgres
import (
"errors"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

var ErrConnTimeout = errors.New("connection timeout")
var (
ErrConnTimeout = errors.New("connection timeout")
ErrNoRows = errors.New("no rows")
)

func mapError(err error) error {
if pgconn.Timeout(err) {
return ErrConnTimeout
}

if errors.Is(err, pgx.ErrNoRows) {
return ErrNoRows
}

return err
}
5 changes: 5 additions & 0 deletions internal/postgres/mocks/mock_pg_querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Querier struct {
QueryRowFn func(ctx context.Context, query string, args ...any) postgres.Row
QueryFn func(ctx context.Context, query string, args ...any) (postgres.Rows, error)
ExecFn func(context.Context, string, ...any) (postgres.CommandTag, error)
ExecInTxFn func(context.Context, func(tx postgres.Tx) error) error
CloseFn func(context.Context) error
}

Expand All @@ -27,6 +28,10 @@ func (m *Querier) Exec(ctx context.Context, query string, args ...any) (postgres
return m.ExecFn(ctx, query, args...)
}

func (m *Querier) ExecInTx(ctx context.Context, fn func(tx postgres.Tx) error) error {
return m.ExecInTxFn(ctx, fn)
}

func (m *Querier) Close(ctx context.Context) error {
return m.CloseFn(ctx)
}
35 changes: 19 additions & 16 deletions internal/postgres/pg_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,12 @@ import (
"fmt"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

type Conn struct {
conn *pgx.Conn
}

type Row interface {
pgx.Row
}

type Rows interface {
pgx.Rows
}

type CommandTag struct {
pgconn.CommandTag
}

func NewConn(ctx context.Context, url string) (*Conn, error) {
pgCfg, err := pgx.ParseConfig(url)
if err != nil {
Expand All @@ -41,16 +28,32 @@ func NewConn(ctx context.Context, url string) (*Conn, error) {
}

func (c *Conn) QueryRow(ctx context.Context, query string, args ...any) Row {
return c.conn.QueryRow(ctx, query, args...)
row := c.conn.QueryRow(ctx, query, args...)
return &mappedRow{inner: row}
}

func (c *Conn) Query(ctx context.Context, query string, args ...any) (Rows, error) {
return c.conn.Query(ctx, query, args...)
rows, err := c.conn.Query(ctx, query, args...)
return rows, mapError(err)
}

func (c *Conn) Exec(ctx context.Context, query string, args ...any) (CommandTag, error) {
tag, err := c.conn.Exec(ctx, query, args...)
return CommandTag{tag}, err
return CommandTag{tag}, mapError(err)
}

func (c *Conn) ExecInTx(ctx context.Context, fn func(Tx) error) error {
tx, err := c.conn.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return mapError(err)
}

if err := fn(&Txn{Tx: tx}); err != nil {
tx.Rollback(ctx)
return mapError(err)
}

return tx.Commit(ctx)
}

func (c *Conn) Close(ctx context.Context) error {
Expand Down
23 changes: 20 additions & 3 deletions internal/postgres/pg_conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

Expand All @@ -28,16 +29,32 @@ func NewConnPool(ctx context.Context, url string) (*Pool, error) {
}

func (c *Pool) QueryRow(ctx context.Context, query string, args ...any) Row {
return c.Pool.QueryRow(ctx, query, args...)
row := c.Pool.QueryRow(ctx, query, args...)
return &mappedRow{inner: row}
}

func (c *Pool) Query(ctx context.Context, query string, args ...any) (Rows, error) {
return c.Pool.Query(ctx, query, args...)
rows, err := c.Pool.Query(ctx, query, args...)
return rows, mapError(err)
}

func (c *Pool) Exec(ctx context.Context, query string, args ...any) (CommandTag, error) {
tag, err := c.Pool.Exec(ctx, query, args...)
return CommandTag{tag}, err
return CommandTag{tag}, mapError(err)
}

func (c *Pool) ExecInTx(ctx context.Context, fn func(Tx) error) error {
tx, err := c.Pool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return mapError(err)
}

if err := fn(&Txn{Tx: tx}); err != nil {
tx.Rollback(ctx)
return mapError(err)
}

return tx.Commit(ctx)
}

func (c *Pool) Close(_ context.Context) error {
Expand Down
35 changes: 34 additions & 1 deletion internal/postgres/pg_querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,44 @@

package postgres

import "context"
import (
"context"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

type Querier interface {
Query(ctx context.Context, query string, args ...any) (Rows, error)
QueryRow(ctx context.Context, query string, args ...any) Row
Exec(ctx context.Context, query string, args ...any) (CommandTag, error)
ExecInTx(ctx context.Context, fn func(tx Tx) error) error
Close(ctx context.Context) error
}

type Row interface {
pgx.Row
}

type Rows interface {
pgx.Rows
}

type Tx interface {
Query(ctx context.Context, query string, args ...any) (Rows, error)
QueryRow(ctx context.Context, query string, args ...any) Row
Exec(ctx context.Context, query string, args ...any) (CommandTag, error)
}

type CommandTag struct {
pgconn.CommandTag
}

type mappedRow struct {
inner Row
}

func (mr *mappedRow) Scan(dest ...any) error {
err := mr.inner.Scan(dest...)
return mapError(err)
}
28 changes: 28 additions & 0 deletions internal/postgres/pg_tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// SPDX-License-Identifier: Apache-2.0

package postgres

import (
"context"

"github.com/jackc/pgx/v5"
)

type Txn struct {
pgx.Tx
}

func (t *Txn) QueryRow(ctx context.Context, query string, args ...any) Row {
row := t.Tx.QueryRow(ctx, query, args...)
return &mappedRow{inner: row}
}

func (t *Txn) Query(ctx context.Context, query string, args ...any) (Rows, error) {
rows, err := t.Tx.Query(ctx, query, args...)
return rows, mapError(err)
}

func (t *Txn) Exec(ctx context.Context, query string, args ...any) (CommandTag, error) {
tag, err := t.Tx.Exec(ctx, query, args...)
return CommandTag{tag}, mapError(err)
}

0 comments on commit 8642f97

Please sign in to comment.