Skip to content

Commit

Permalink
feat(pgdriver): implement database/sql/driver.SessionResetter
Browse files Browse the repository at this point in the history
  • Loading branch information
htdvisser committed Aug 1, 2022
1 parent 1517410 commit bda298a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
12 changes: 12 additions & 0 deletions driver/pgdriver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ type Config struct {
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands fail with a timeout instead of blocking.
WriteTimeout time.Duration

// ResetSessionFunc is called prior to executing a query on a connection that has been used before.
ResetSessionFunc func(context.Context, *Conn) error
}

func newDefaultConfig() *Config {
Expand Down Expand Up @@ -173,6 +176,15 @@ func WithWriteTimeout(writeTimeout time.Duration) Option {
}
}

// WithResetSessionFunc configures a function that is called prior to executing
// a query on a connection that has been used before.
// If the func returns driver.ErrBadConn, the connection is discarded.
func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option {
return func(cfg *Config) {
cfg.ResetSessionFunc = fn
}
}

func WithDSN(dsn string) Option {
return func(cfg *Config) {
opts, err := parseDSN(dsn)
Expand Down
12 changes: 12 additions & 0 deletions driver/pgdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,18 @@ func (cn *Conn) IsValid() bool {
return !cn.isClosed()
}

var _ driver.SessionResetter = (*Conn)(nil)

func (cn *Conn) ResetSession(ctx context.Context) error {
if cn.isClosed() {
return driver.ErrBadConn
}
if cn.cfg.ResetSessionFunc != nil {
return cn.cfg.ResetSessionFunc(ctx, cn)
}
return nil
}

func (cn *Conn) checkBadConn(err error) error {
if isBadConn(err, false) {
// Close and return driver.ErrBadConn next time the conn is used.
Expand Down
24 changes: 24 additions & 0 deletions driver/pgdriver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ func TestConnector(t *testing.T) {
require.NoError(t, err)
}

func TestConnector_WithResetSessionFunc(t *testing.T) {
var resetCalled int

db := sql.OpenDB(pgdriver.NewConnector(
pgdriver.WithDSN(dsn()),
pgdriver.WithResetSessionFunc(func(context.Context, *pgdriver.Conn) error {
resetCalled++
return nil
}),
))

db.SetMaxOpenConns(1)

for i := 0; i < 3; i++ {
err := db.Ping()
require.NoError(t, err)
}

require.Equal(t, 2, resetCalled)

err := db.Close()
require.NoError(t, err)
}

func TestStmtSelect(t *testing.T) {
ctx := context.Background()
db := sqlDB()
Expand Down

0 comments on commit bda298a

Please sign in to comment.