Skip to content

Commit

Permalink
Merge pull request #839 from kardianos/master
Browse files Browse the repository at this point in the history
pq: pipe the context through to the dialer, make use of the Connector
  • Loading branch information
maddyblue authored Mar 26, 2019
2 parents 5ccc985 + 4c45500 commit d6156e1
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 97 deletions.
132 changes: 50 additions & 82 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pq

import (
"bufio"
"context"
"crypto/md5"
"database/sql"
"database/sql/driver"
Expand Down Expand Up @@ -89,13 +90,24 @@ type Dialer interface {
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

type defaultDialer struct{}
type DialerContext interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

type defaultDialer struct {
d net.Dialer
}

func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
return net.Dial(ntw, addr)
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
return d.d.Dial(network, address)
}
func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(ntw, addr, timeout)
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return d.DialContext(ctx, network, address)
}
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.d.DialContext(ctx, network, address)
}

type conn struct {
Expand Down Expand Up @@ -244,98 +256,43 @@ func (cn *conn) writeBuf(b byte) *writeBuf {
}
}

// Open opens a new connection to the database. name is a connection string.
// Open opens a new connection to the database. dsn is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
func Open(name string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, name)
func Open(dsn string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, dsn)
}

// DialOpen opens a new connection to the database using a dialer.
func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
c, err := NewConnector(dsn)
if err != nil {
return nil, err
}
c.dialer = d
return c.open(context.Background())
}

func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
// Handle any panics during connection initialization. Note that we
// specifically do *not* want to use errRecover(), as that would turn any
// connection errors into ErrBadConns, hiding the real error message from
// the user.
defer errRecoverNoErrBadConn(&err)

o := make(values)

// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "5432"
// N.B.: Extra float digits should be set to 3, but that breaks
// Postgres 8.4 and older, where the max is 2.
o["extra_float_digits"] = "2"
for k, v := range parseEnviron(os.Environ()) {
o[k] = v
}

if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
name, err = ParseURL(name)
if err != nil {
return nil, err
}
}

if err := parseOpts(name, o); err != nil {
return nil, err
}

// Use the "fallback" application name if necessary
if fallback, ok := o["fallback_application_name"]; ok {
if _, ok := o["application_name"]; !ok {
o["application_name"] = fallback
}
}

// We can't work with any client_encoding other than UTF-8 currently.
// However, we have historically allowed the user to set it to UTF-8
// explicitly, and there's no reason to break such programs, so allow that.
// Note that the "options" setting could also set client_encoding, but
// parsing its value is not worth it. Instead, we always explicitly send
// client_encoding as a separate run-time parameter, which should override
// anything set in options.
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
// DateStyle needs a similar treatment.
if datestyle, ok := o["datestyle"]; ok {
if datestyle != "ISO, MDY" {
panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
"ISO, MDY", datestyle))
}
} else {
o["datestyle"] = "ISO, MDY"
}
o := c.opts

// If a user is not provided by any other means, the last
// resort is to use the current operating system provided user
// name.
if _, ok := o["user"]; !ok {
u, err := userCurrent()
if err != nil {
return nil, err
}
o["user"] = u
}

cn := &conn{
cn = &conn{
opts: o,
dialer: d,
dialer: c.dialer,
}
err = cn.handleDriverSettings(o)
if err != nil {
return nil, err
}
cn.handlePgpass(o)

cn.c, err = dial(d, o)
cn.c, err = dial(ctx, c.dialer, o)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -364,10 +321,10 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
return cn, err
}

func dial(d Dialer, o values) (net.Conn, error) {
ntw, addr := network(o)
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
network, address := network(o)
// SSL is not necessary or supported over UNIX domain sockets
if ntw == "unix" {
if network == "unix" {
o["sslmode"] = "disable"
}

Expand All @@ -378,19 +335,30 @@ func dial(d Dialer, o values) (net.Conn, error) {
return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
}
duration := time.Duration(seconds) * time.Second

// connect_timeout should apply to the entire connection establishment
// procedure, so we both use a timeout for the TCP connection
// establishment and set a deadline for doing the initial handshake.
// The deadline is then reset after startup() is done.
deadline := time.Now().Add(duration)
conn, err := d.DialTimeout(ntw, addr, duration)
var conn net.Conn
if dctx, ok := d.(DialerContext); ok {
ctx, cancel := context.WithTimeout(ctx, duration)
defer cancel()
conn, err = dctx.DialContext(ctx, network, address)
} else {
conn, err = d.DialTimeout(network, address, duration)
}
if err != nil {
return nil, err
}
err = conn.SetDeadline(deadline)
return conn, err
}
return d.Dial(ntw, addr)
if dctx, ok := d.(DialerContext); ok {
return dctx.DialContext(ctx, network, address)
}
return d.Dial(network, address)
}

func network(o values) (string, string) {
Expand Down
14 changes: 11 additions & 3 deletions conn_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"io/ioutil"
"time"
)

// Implement the "QueryerContext" interface
Expand Down Expand Up @@ -92,7 +93,14 @@ func (cn *conn) watchCancel(ctx context.Context) func() {
go func() {
select {
case <-done:
_ = cn.cancel()
// At this point the function level context is canceled,
// so it must not be used for the additional network
// request to cancel the query.
// Create a new context to pass into the dial.
ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

_ = cn.cancel(ctxCancel)
finished <- struct{}{}
case <-finished:
}
Expand All @@ -107,8 +115,8 @@ func (cn *conn) watchCancel(ctx context.Context) func() {
return nil
}

func (cn *conn) cancel() error {
c, err := dial(cn.dialer, cn.opts)
func (cn *conn) cancel(ctx context.Context) error {
c, err := dial(ctx, cn.dialer, cn.opts)
if err != nil {
return err
}
Expand Down
91 changes: 79 additions & 12 deletions connector.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// +build go1.10

package pq

import (
"context"
"database/sql/driver"
"errors"
"fmt"
"os"
"strings"
)

// Connector represents a fixed configuration for the pq driver with a given
Expand All @@ -14,30 +16,95 @@ import (
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
type connector struct {
name string
type Connector struct {
opts values
dialer Dialer
}

// Connect returns a connection to the database using the fixed configuration
// of this Connector. Context is not used.
func (c *connector) Connect(_ context.Context) (driver.Conn, error) {
return (&Driver{}).Open(c.name)
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.open(ctx)
}

// Driver returnst the underlying driver of this Connector.
func (c *connector) Driver() driver.Driver {
func (c *Connector) Driver() driver.Driver {
return &Driver{}
}

var _ driver.Connector = &connector{}

// NewConnector returns a connector for the pq driver in a fixed configuration
// with the given name. The returned connector can be used to create any number
// with the given dsn. The returned connector can be used to create any number
// of equivalent Conn's. The returned connector is intended to be used with
// database/sql.OpenDB.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
func NewConnector(name string) (driver.Connector, error) {
return &connector{name: name}, nil
func NewConnector(dsn string) (*Connector, error) {
var err error
o := make(values)

// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "5432"
// N.B.: Extra float digits should be set to 3, but that breaks
// Postgres 8.4 and older, where the max is 2.
o["extra_float_digits"] = "2"
for k, v := range parseEnviron(os.Environ()) {
o[k] = v
}

if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
dsn, err = ParseURL(dsn)
if err != nil {
return nil, err
}
}

if err := parseOpts(dsn, o); err != nil {
return nil, err
}

// Use the "fallback" application name if necessary
if fallback, ok := o["fallback_application_name"]; ok {
if _, ok := o["application_name"]; !ok {
o["application_name"] = fallback
}
}

// We can't work with any client_encoding other than UTF-8 currently.
// However, we have historically allowed the user to set it to UTF-8
// explicitly, and there's no reason to break such programs, so allow that.
// Note that the "options" setting could also set client_encoding, but
// parsing its value is not worth it. Instead, we always explicitly send
// client_encoding as a separate run-time parameter, which should override
// anything set in options.
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
// DateStyle needs a similar treatment.
if datestyle, ok := o["datestyle"]; ok {
if datestyle != "ISO, MDY" {
return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)
}
} else {
o["datestyle"] = "ISO, MDY"
}

// If a user is not provided by any other means, the last
// resort is to use the current operating system provided user
// name.
if _, ok := o["user"]; !ok {
u, err := userCurrent()
if err != nil {
return nil, err
}
o["user"] = u
}

return &Connector{opts: o, dialer: defaultDialer{}}, nil
}

0 comments on commit d6156e1

Please sign in to comment.