Skip to content

Commit

Permalink
feat(pgdriver): add Config.ConnParams to session config params
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Sep 27, 2021
1 parent 2abf3ba commit 408caf0
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 60 deletions.
79 changes: 45 additions & 34 deletions driver/pgdriver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net"
"net/url"
"os"
"sort"
"strconv"
"strings"
"time"
Expand All @@ -34,6 +33,8 @@ type Config struct {
Password string
Database string
AppName string
// PostgreSQL session parameters updated with `SET` command when a connection is created.
ConnParams map[string]interface{}

// Timeout for socket reads. If reached, commands fail with a timeout instead of blocking.
ReadTimeout time.Duration
Expand Down Expand Up @@ -68,87 +69,93 @@ func newDefaultConfig() *Config {
return cfg
}

type DriverOption func(*Connector)
type DriverOption func(cfg *Config)

func WithAddr(addr string) DriverOption {
if addr == "" {
panic("addr is empty")
}
return func(d *Connector) {
d.cfg.Addr = addr
return func(cfg *Config) {
cfg.Addr = addr
}
}

func WithTLSConfig(cfg *tls.Config) DriverOption {
return func(d *Connector) {
d.cfg.TLSConfig = cfg
func WithTLSConfig(tlsConfig *tls.Config) DriverOption {
return func(cfg *Config) {
cfg.TLSConfig = tlsConfig
}
}

func WithUser(user string) DriverOption {
if user == "" {
panic("user is empty")
}
return func(d *Connector) {
d.cfg.User = user
return func(cfg *Config) {
cfg.User = user
}
}

func WithPassword(password string) DriverOption {
return func(d *Connector) {
d.cfg.Password = password
return func(cfg *Config) {
cfg.Password = password
}
}

func WithDatabase(database string) DriverOption {
if database == "" {
panic("database is empty")
}
return func(d *Connector) {
d.cfg.Database = database
return func(cfg *Config) {
cfg.Database = database
}
}

func WithApplicationName(appName string) DriverOption {
return func(d *Connector) {
d.cfg.AppName = appName
return func(cfg *Config) {
cfg.AppName = appName
}
}

func WithConnParams(params map[string]interface{}) DriverOption {
return func(cfg *Config) {
cfg.ConnParams = params
}
}

func WithTimeout(timeout time.Duration) DriverOption {
return func(d *Connector) {
d.cfg.DialTimeout = timeout
d.cfg.ReadTimeout = timeout
d.cfg.WriteTimeout = timeout
return func(cfg *Config) {
cfg.DialTimeout = timeout
cfg.ReadTimeout = timeout
cfg.WriteTimeout = timeout
}
}

func WithDialTimeout(dialTimeout time.Duration) DriverOption {
return func(d *Connector) {
d.cfg.DialTimeout = dialTimeout
return func(cfg *Config) {
cfg.DialTimeout = dialTimeout
}
}

func WithReadTimeout(readTimeout time.Duration) DriverOption {
return func(d *Connector) {
d.cfg.ReadTimeout = readTimeout
return func(cfg *Config) {
cfg.ReadTimeout = readTimeout
}
}

func WithWriteTimeout(writeTimeout time.Duration) DriverOption {
return func(d *Connector) {
d.cfg.WriteTimeout = writeTimeout
return func(cfg *Config) {
cfg.WriteTimeout = writeTimeout
}
}

func WithDSN(dsn string) DriverOption {
return func(d *Connector) {
return func(cfg *Config) {
opts, err := parseDSN(dsn)
if err != nil {
panic(err)
}
for _, opt := range opts {
opt(d)
opt(cfg)
}
}
}
Expand Down Expand Up @@ -225,8 +232,13 @@ func parseDSN(dsn string) ([]DriverOption, error) {
if err != nil {
return nil, q.err
}

if len(rem) > 0 {
return nil, fmt.Errorf("pgdriver: unexpected option: %s", strings.Join(rem, ", "))
params := make(map[string]interface{}, len(rem))
for k, v := range rem {
params[k] = v
}
opts = append(opts, WithConnParams(params))
}

return opts, nil
Expand Down Expand Up @@ -279,17 +291,16 @@ func (o *queryOptions) duration(name string) time.Duration {
return 0
}

func (o *queryOptions) remaining() ([]string, error) {
func (o *queryOptions) remaining() (map[string]string, error) {
if o.err != nil {
return nil, o.err
}
if len(o.q) == 0 {
return nil, nil
}
keys := make([]string, 0, len(o.q))
for k := range o.q {
keys = append(keys, k)
m := make(map[string]string, len(o.q))
for k, ss := range o.q {
m[k] = ss[len(ss)-1]
}
sort.Strings(keys)
return keys, nil
return m, nil
}
16 changes: 16 additions & 0 deletions driver/pgdriver/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ func TestParseDSN(t *testing.T) {
WriteTimeout: 3 * time.Second,
},
},
{
dsn: "postgres://postgres:1@localhost:5432/testDatabase?search_path=foo",
cfg: &pgdriver.Config{
Network: "tcp",
Addr: "localhost:5432",
User: "postgres",
Password: "1",
Database: "testDatabase",
ConnParams: map[string]interface{}{
"search_path": "foo",
},
DialTimeout: 5 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 5 * time.Second,
},
},
{
dsn: "postgres://postgres:password@app.xxx.us-east-1.rds.amazonaws.com:5432/test?sslmode=disable",
cfg: &pgdriver.Config{
Expand Down
50 changes: 31 additions & 19 deletions driver/pgdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,35 +71,34 @@ type Connector struct {
}

func NewConnector(opts ...DriverOption) *Connector {
d := &Connector{cfg: newDefaultConfig()}
c := &Connector{cfg: newDefaultConfig()}
for _, opt := range opts {
opt(d)
opt(c.cfg)
}
return d
return c
}

var _ driver.Connector = (*Connector)(nil)

func (d *Connector) Connect(ctx context.Context) (driver.Conn, error) {
if err := d.cfg.verify(); err != nil {
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
if err := c.cfg.verify(); err != nil {
return nil, err
}

return newConn(ctx, d)
return newConn(ctx, c.cfg)
}

func (d *Connector) Driver() driver.Driver {
return Driver{connector: d}
func (c *Connector) Driver() driver.Driver {
return Driver{connector: c}
}

func (d *Connector) Config() *Config {
return d.cfg
func (c *Connector) Config() *Config {
return c.cfg
}

//------------------------------------------------------------------------------

type Conn struct {
driver *Connector
cfg *Config

netConn net.Conn
rd *reader
Expand All @@ -112,20 +111,20 @@ type Conn struct {
closed int32
}

func newConn(ctx context.Context, driver *Connector) (*Conn, error) {
netConn, err := driver.cfg.Dialer(ctx, driver.cfg.Network, driver.cfg.Addr)
func newConn(ctx context.Context, cfg *Config) (*Conn, error) {
netConn, err := cfg.Dialer(ctx, cfg.Network, cfg.Addr)
if err != nil {
return nil, err
}

cn := &Conn{
driver: driver,
cfg: cfg,
netConn: netConn,
rd: newReader(netConn),
}

if cn.driver.cfg.TLSConfig != nil {
if err := enableSSL(ctx, cn, cn.driver.cfg.TLSConfig); err != nil {
if cfg.TLSConfig != nil {
if err := enableSSL(ctx, cn, cfg.TLSConfig); err != nil {
return nil, err
}
}
Expand All @@ -134,6 +133,19 @@ func newConn(ctx context.Context, driver *Connector) (*Conn, error) {
return nil, err
}

for k, v := range cfg.ConnParams {
if v != nil {
_, err = cn.ExecContext(ctx, fmt.Sprintf("SET %s TO $1", k), []driver.NamedValue{
{Value: v},
})
} else {
_, err = cn.ExecContext(ctx, fmt.Sprintf("SET %s TO DEFAULT", k), nil)
}
if err != nil {
return nil, err
}
}

return cn, nil
}

Expand Down Expand Up @@ -277,14 +289,14 @@ func (cn *Conn) Ping(ctx context.Context) error {

func (cn *Conn) setReadDeadline(ctx context.Context, timeout time.Duration) {
if timeout == -1 {
timeout = cn.driver.cfg.ReadTimeout
timeout = cn.cfg.ReadTimeout
}
_ = cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout))
}

func (cn *Conn) setWriteDeadline(ctx context.Context, timeout time.Duration) {
if timeout == -1 {
timeout = cn.driver.cfg.WriteTimeout
timeout = cn.cfg.WriteTimeout
}
_ = cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout))
}
Expand Down
14 changes: 14 additions & 0 deletions driver/pgdriver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,20 @@ func TestFloat64(t *testing.T) {
require.Equal(t, 1.1, f)
}

func TestConnParams(t *testing.T) {
db := sql.OpenDB(pgdriver.NewConnector(
pgdriver.WithDSN(dsn()),
pgdriver.WithConnParams(map[string]interface{}{
"search_path": "foo",
}),
))

var searchPath string
err := db.QueryRow("SHOW search_path").Scan(&searchPath)
require.NoError(t, err)
require.Equal(t, "foo", searchPath)
}

func sqlDB() *sql.DB {
db, err := sql.Open("pg", dsn())
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions driver/pgdriver/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ func formatQuery(query string, args []driver.NamedValue) (string, error) {
switch c := p.Next(); c {
case '$':
if i, ok := p.Number(); ok {
if i < 1 {
return "", fmt.Errorf("pgdriver: got $%d, but the minimal arg index is 1", i)
}
if i > len(args) {
return "", fmt.Errorf("pgdriver: got %d args, wanted %d", len(args), i)
}
Expand Down
14 changes: 7 additions & 7 deletions driver/pgdriver/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,12 @@ func writeStartup(ctx context.Context, cn *Conn) error {
wb.StartMessage(0)
wb.WriteInt32(196608)
wb.WriteString("user")
wb.WriteString(cn.driver.cfg.User)
wb.WriteString(cn.cfg.User)
wb.WriteString("database")
wb.WriteString(cn.driver.cfg.Database)
if cn.driver.cfg.AppName != "" {
wb.WriteString(cn.cfg.Database)
if cn.cfg.AppName != "" {
wb.WriteString("application_name")
wb.WriteString(cn.driver.cfg.AppName)
wb.WriteString(cn.cfg.AppName)
}
wb.WriteString("")
wb.FinishMessage()
Expand Down Expand Up @@ -233,7 +233,7 @@ func auth(ctx context.Context, cn *Conn, rd *reader) error {
}

func authCleartext(ctx context.Context, cn *Conn, rd *reader) error {
if err := writePassword(ctx, cn, cn.driver.cfg.Password); err != nil {
if err := writePassword(ctx, cn, cn.cfg.Password); err != nil {
return err
}
return readAuthOK(cn, rd)
Expand Down Expand Up @@ -274,7 +274,7 @@ func authMD5(ctx context.Context, cn *Conn, rd *reader) error {
return err
}

secret := "md5" + md5s(md5s(cn.driver.cfg.Password+cn.driver.cfg.User)+string(b))
secret := "md5" + md5s(md5s(cn.cfg.Password+cn.cfg.User)+string(b))
if err := writePassword(ctx, cn, secret); err != nil {
return err
}
Expand Down Expand Up @@ -323,7 +323,7 @@ loop:
}

creds := sasl.Credentials(func() (Username, Password, Identity []byte) {
return []byte(cn.driver.cfg.User), []byte(cn.driver.cfg.Password), nil
return []byte(cn.cfg.User), []byte(cn.cfg.Password), nil
})
client := sasl.NewClient(saslMech, creds)

Expand Down

0 comments on commit 408caf0

Please sign in to comment.