From 408caf0bb579e23e26fc6149efd6851814c22517 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Mon, 27 Sep 2021 13:23:00 +0300 Subject: [PATCH] feat(pgdriver): add Config.ConnParams to session config params --- driver/pgdriver/config.go | 79 +++++++++++++++++++--------------- driver/pgdriver/config_test.go | 16 +++++++ driver/pgdriver/driver.go | 50 +++++++++++++-------- driver/pgdriver/driver_test.go | 14 ++++++ driver/pgdriver/format.go | 3 ++ driver/pgdriver/proto.go | 14 +++--- 6 files changed, 116 insertions(+), 60 deletions(-) diff --git a/driver/pgdriver/config.go b/driver/pgdriver/config.go index c3e1f66f7..8931b24c0 100644 --- a/driver/pgdriver/config.go +++ b/driver/pgdriver/config.go @@ -8,7 +8,6 @@ import ( "net" "net/url" "os" - "sort" "strconv" "strings" "time" @@ -34,6 +33,8 @@ type Config struct { Password string Database string AppName string + // PostgreSQL session parameters updated with `SET` command when a connection is created. + ConnParams map[string]interface{} // Timeout for socket reads. If reached, commands fail with a timeout instead of blocking. ReadTimeout time.Duration @@ -68,20 +69,20 @@ func newDefaultConfig() *Config { return cfg } -type DriverOption func(*Connector) +type DriverOption func(cfg *Config) func WithAddr(addr string) DriverOption { if addr == "" { panic("addr is empty") } - return func(d *Connector) { - d.cfg.Addr = addr + return func(cfg *Config) { + cfg.Addr = addr } } -func WithTLSConfig(cfg *tls.Config) DriverOption { - return func(d *Connector) { - d.cfg.TLSConfig = cfg +func WithTLSConfig(tlsConfig *tls.Config) DriverOption { + return func(cfg *Config) { + cfg.TLSConfig = tlsConfig } } @@ -89,14 +90,14 @@ func WithUser(user string) DriverOption { if user == "" { panic("user is empty") } - return func(d *Connector) { - d.cfg.User = user + return func(cfg *Config) { + cfg.User = user } } func WithPassword(password string) DriverOption { - return func(d *Connector) { - d.cfg.Password = password + return func(cfg *Config) { + cfg.Password = password } } @@ -104,51 +105,57 @@ func WithDatabase(database string) DriverOption { if database == "" { panic("database is empty") } - return func(d *Connector) { - d.cfg.Database = database + return func(cfg *Config) { + cfg.Database = database } } func WithApplicationName(appName string) DriverOption { - return func(d *Connector) { - d.cfg.AppName = appName + return func(cfg *Config) { + cfg.AppName = appName + } +} + +func WithConnParams(params map[string]interface{}) DriverOption { + return func(cfg *Config) { + cfg.ConnParams = params } } func WithTimeout(timeout time.Duration) DriverOption { - return func(d *Connector) { - d.cfg.DialTimeout = timeout - d.cfg.ReadTimeout = timeout - d.cfg.WriteTimeout = timeout + return func(cfg *Config) { + cfg.DialTimeout = timeout + cfg.ReadTimeout = timeout + cfg.WriteTimeout = timeout } } func WithDialTimeout(dialTimeout time.Duration) DriverOption { - return func(d *Connector) { - d.cfg.DialTimeout = dialTimeout + return func(cfg *Config) { + cfg.DialTimeout = dialTimeout } } func WithReadTimeout(readTimeout time.Duration) DriverOption { - return func(d *Connector) { - d.cfg.ReadTimeout = readTimeout + return func(cfg *Config) { + cfg.ReadTimeout = readTimeout } } func WithWriteTimeout(writeTimeout time.Duration) DriverOption { - return func(d *Connector) { - d.cfg.WriteTimeout = writeTimeout + return func(cfg *Config) { + cfg.WriteTimeout = writeTimeout } } func WithDSN(dsn string) DriverOption { - return func(d *Connector) { + return func(cfg *Config) { opts, err := parseDSN(dsn) if err != nil { panic(err) } for _, opt := range opts { - opt(d) + opt(cfg) } } } @@ -225,8 +232,13 @@ func parseDSN(dsn string) ([]DriverOption, error) { if err != nil { return nil, q.err } + if len(rem) > 0 { - return nil, fmt.Errorf("pgdriver: unexpected option: %s", strings.Join(rem, ", ")) + params := make(map[string]interface{}, len(rem)) + for k, v := range rem { + params[k] = v + } + opts = append(opts, WithConnParams(params)) } return opts, nil @@ -279,17 +291,16 @@ func (o *queryOptions) duration(name string) time.Duration { return 0 } -func (o *queryOptions) remaining() ([]string, error) { +func (o *queryOptions) remaining() (map[string]string, error) { if o.err != nil { return nil, o.err } if len(o.q) == 0 { return nil, nil } - keys := make([]string, 0, len(o.q)) - for k := range o.q { - keys = append(keys, k) + m := make(map[string]string, len(o.q)) + for k, ss := range o.q { + m[k] = ss[len(ss)-1] } - sort.Strings(keys) - return keys, nil + return m, nil } diff --git a/driver/pgdriver/config_test.go b/driver/pgdriver/config_test.go index ac2632bef..f2ce9c492 100644 --- a/driver/pgdriver/config_test.go +++ b/driver/pgdriver/config_test.go @@ -41,6 +41,22 @@ func TestParseDSN(t *testing.T) { WriteTimeout: 3 * time.Second, }, }, + { + dsn: "postgres://postgres:1@localhost:5432/testDatabase?search_path=foo", + cfg: &pgdriver.Config{ + Network: "tcp", + Addr: "localhost:5432", + User: "postgres", + Password: "1", + Database: "testDatabase", + ConnParams: map[string]interface{}{ + "search_path": "foo", + }, + DialTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 5 * time.Second, + }, + }, { dsn: "postgres://postgres:password@app.xxx.us-east-1.rds.amazonaws.com:5432/test?sslmode=disable", cfg: &pgdriver.Config{ diff --git a/driver/pgdriver/driver.go b/driver/pgdriver/driver.go index 2df2d530b..3d5646d73 100644 --- a/driver/pgdriver/driver.go +++ b/driver/pgdriver/driver.go @@ -71,35 +71,34 @@ type Connector struct { } func NewConnector(opts ...DriverOption) *Connector { - d := &Connector{cfg: newDefaultConfig()} + c := &Connector{cfg: newDefaultConfig()} for _, opt := range opts { - opt(d) + opt(c.cfg) } - return d + return c } var _ driver.Connector = (*Connector)(nil) -func (d *Connector) Connect(ctx context.Context) (driver.Conn, error) { - if err := d.cfg.verify(); err != nil { +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { + if err := c.cfg.verify(); err != nil { return nil, err } - - return newConn(ctx, d) + return newConn(ctx, c.cfg) } -func (d *Connector) Driver() driver.Driver { - return Driver{connector: d} +func (c *Connector) Driver() driver.Driver { + return Driver{connector: c} } -func (d *Connector) Config() *Config { - return d.cfg +func (c *Connector) Config() *Config { + return c.cfg } //------------------------------------------------------------------------------ type Conn struct { - driver *Connector + cfg *Config netConn net.Conn rd *reader @@ -112,20 +111,20 @@ type Conn struct { closed int32 } -func newConn(ctx context.Context, driver *Connector) (*Conn, error) { - netConn, err := driver.cfg.Dialer(ctx, driver.cfg.Network, driver.cfg.Addr) +func newConn(ctx context.Context, cfg *Config) (*Conn, error) { + netConn, err := cfg.Dialer(ctx, cfg.Network, cfg.Addr) if err != nil { return nil, err } cn := &Conn{ - driver: driver, + cfg: cfg, netConn: netConn, rd: newReader(netConn), } - if cn.driver.cfg.TLSConfig != nil { - if err := enableSSL(ctx, cn, cn.driver.cfg.TLSConfig); err != nil { + if cfg.TLSConfig != nil { + if err := enableSSL(ctx, cn, cfg.TLSConfig); err != nil { return nil, err } } @@ -134,6 +133,19 @@ func newConn(ctx context.Context, driver *Connector) (*Conn, error) { return nil, err } + for k, v := range cfg.ConnParams { + if v != nil { + _, err = cn.ExecContext(ctx, fmt.Sprintf("SET %s TO $1", k), []driver.NamedValue{ + {Value: v}, + }) + } else { + _, err = cn.ExecContext(ctx, fmt.Sprintf("SET %s TO DEFAULT", k), nil) + } + if err != nil { + return nil, err + } + } + return cn, nil } @@ -277,14 +289,14 @@ func (cn *Conn) Ping(ctx context.Context) error { func (cn *Conn) setReadDeadline(ctx context.Context, timeout time.Duration) { if timeout == -1 { - timeout = cn.driver.cfg.ReadTimeout + timeout = cn.cfg.ReadTimeout } _ = cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)) } func (cn *Conn) setWriteDeadline(ctx context.Context, timeout time.Duration) { if timeout == -1 { - timeout = cn.driver.cfg.WriteTimeout + timeout = cn.cfg.WriteTimeout } _ = cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)) } diff --git a/driver/pgdriver/driver_test.go b/driver/pgdriver/driver_test.go index 26f90f72e..a61ae69d7 100644 --- a/driver/pgdriver/driver_test.go +++ b/driver/pgdriver/driver_test.go @@ -217,6 +217,20 @@ func TestFloat64(t *testing.T) { require.Equal(t, 1.1, f) } +func TestConnParams(t *testing.T) { + db := sql.OpenDB(pgdriver.NewConnector( + pgdriver.WithDSN(dsn()), + pgdriver.WithConnParams(map[string]interface{}{ + "search_path": "foo", + }), + )) + + var searchPath string + err := db.QueryRow("SHOW search_path").Scan(&searchPath) + require.NoError(t, err) + require.Equal(t, "foo", searchPath) +} + func sqlDB() *sql.DB { db, err := sql.Open("pg", dsn()) if err != nil { diff --git a/driver/pgdriver/format.go b/driver/pgdriver/format.go index c85967da5..18252aa40 100644 --- a/driver/pgdriver/format.go +++ b/driver/pgdriver/format.go @@ -22,6 +22,9 @@ func formatQuery(query string, args []driver.NamedValue) (string, error) { switch c := p.Next(); c { case '$': if i, ok := p.Number(); ok { + if i < 1 { + return "", fmt.Errorf("pgdriver: got $%d, but the minimal arg index is 1", i) + } if i > len(args) { return "", fmt.Errorf("pgdriver: got %d args, wanted %d", len(args), i) } diff --git a/driver/pgdriver/proto.go b/driver/pgdriver/proto.go index c6fbf05c5..310b380ba 100644 --- a/driver/pgdriver/proto.go +++ b/driver/pgdriver/proto.go @@ -194,12 +194,12 @@ func writeStartup(ctx context.Context, cn *Conn) error { wb.StartMessage(0) wb.WriteInt32(196608) wb.WriteString("user") - wb.WriteString(cn.driver.cfg.User) + wb.WriteString(cn.cfg.User) wb.WriteString("database") - wb.WriteString(cn.driver.cfg.Database) - if cn.driver.cfg.AppName != "" { + wb.WriteString(cn.cfg.Database) + if cn.cfg.AppName != "" { wb.WriteString("application_name") - wb.WriteString(cn.driver.cfg.AppName) + wb.WriteString(cn.cfg.AppName) } wb.WriteString("") wb.FinishMessage() @@ -233,7 +233,7 @@ func auth(ctx context.Context, cn *Conn, rd *reader) error { } func authCleartext(ctx context.Context, cn *Conn, rd *reader) error { - if err := writePassword(ctx, cn, cn.driver.cfg.Password); err != nil { + if err := writePassword(ctx, cn, cn.cfg.Password); err != nil { return err } return readAuthOK(cn, rd) @@ -274,7 +274,7 @@ func authMD5(ctx context.Context, cn *Conn, rd *reader) error { return err } - secret := "md5" + md5s(md5s(cn.driver.cfg.Password+cn.driver.cfg.User)+string(b)) + secret := "md5" + md5s(md5s(cn.cfg.Password+cn.cfg.User)+string(b)) if err := writePassword(ctx, cn, secret); err != nil { return err } @@ -323,7 +323,7 @@ loop: } creds := sasl.Credentials(func() (Username, Password, Identity []byte) { - return []byte(cn.driver.cfg.User), []byte(cn.driver.cfg.Password), nil + return []byte(cn.cfg.User), []byte(cn.cfg.Password), nil }) client := sasl.NewClient(saslMech, creds)