diff --git a/driver/pgdriver/config.go b/driver/pgdriver/config.go index 8e8abfe59..c3e1f66f7 100644 --- a/driver/pgdriver/config.go +++ b/driver/pgdriver/config.go @@ -8,6 +8,8 @@ import ( "net" "net/url" "os" + "sort" + "strconv" "strings" "time" ) @@ -33,11 +35,9 @@ type Config struct { Database string AppName string - // Timeout for socket reads. If reached, commands will fail - // with a timeout instead of blocking. + // Timeout for socket reads. If reached, commands fail with a timeout instead of blocking. ReadTimeout time.Duration - // Timeout for socket writes. If reached, commands will fail - // with a timeout instead of blocking. + // Timeout for socket writes. If reached, commands fail with a timeout instead of blocking. WriteTimeout time.Duration } @@ -153,6 +153,15 @@ func WithDSN(dsn string) DriverOption { } } +func env(key, defValue string) string { + if s := os.Getenv(key); s != "" { + return s + } + return defValue +} + +//------------------------------------------------------------------------------ + func parseDSN(dsn string) ([]DriverOption, error) { u, err := url.Parse(dsn) if err != nil { @@ -163,11 +172,6 @@ func parseDSN(dsn string) ([]DriverOption, error) { return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme) } - query, err := url.ParseQuery(u.RawQuery) - if err != nil { - return nil, err - } - var opts []DriverOption if u.Host != "" { @@ -187,39 +191,45 @@ func parseDSN(dsn string) ([]DriverOption, error) { opts = append(opts, WithDatabase(u.Path[1:])) } - if appName := query.Get("application_name"); appName != "" { + q := queryOptions{q: u.Query()} + + if appName := q.string("application_name"); appName != "" { opts = append(opts, WithApplicationName(appName)) } - delete(query, "application_name") - - if sslMode := query.Get("sslmode"); sslMode != "" { - switch sslMode { - case "verify-ca", "verify-full": - opts = append(opts, WithTLSConfig(new(tls.Config))) - case "allow", "prefer", "require": - opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) - case "disable": - // no TLS config - default: - return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode) - } - } else { + + switch sslMode := q.string("sslmode"); sslMode { + case "verify-ca", "verify-full": + opts = append(opts, WithTLSConfig(new(tls.Config))) + case "allow", "prefer", "require": opts = append(opts, WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) + case "disable", "": + // no TLS config + default: + return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode) } - delete(query, "sslmode") - for key := range query { - return nil, fmt.Errorf("pgdriver: unsupported option=%q", key) + if d := q.duration("timeout"); d != 0 { + opts = append(opts, WithTimeout(d)) + } + if d := q.duration("dial_timeout"); d != 0 { + opts = append(opts, WithDialTimeout(d)) + } + if d := q.duration("read_timeout"); d != 0 { + opts = append(opts, WithReadTimeout(d)) + } + if d := q.duration("write_timeout"); d != 0 { + opts = append(opts, WithWriteTimeout(d)) } - return opts, nil -} - -func env(key, defValue string) string { - if s := os.Getenv(key); s != "" { - return s + rem, err := q.remaining() + if err != nil { + return nil, q.err } - return defValue + if len(rem) > 0 { + return nil, fmt.Errorf("pgdriver: unexpected option: %s", strings.Join(rem, ", ")) + } + + return opts, nil } // verify is a method to make sure if the config is legitimate @@ -231,3 +241,55 @@ func (c *Config) verify() error { } return nil } + +type queryOptions struct { + q url.Values + err error +} + +func (o *queryOptions) string(name string) string { + vs := o.q[name] + if len(vs) == 0 { + return "" + } + delete(o.q, name) // enable detection of unknown parameters + return vs[len(vs)-1] +} + +func (o *queryOptions) duration(name string) time.Duration { + s := o.string(name) + if s == "" { + return 0 + } + // try plain number first + if i, err := strconv.Atoi(s); err == nil { + if i <= 0 { + // disable timeouts + return -1 + } + return time.Duration(i) * time.Second + } + dur, err := time.ParseDuration(s) + if err == nil { + return dur + } + if o.err == nil { + o.err = fmt.Errorf("pgdriver: invalid %s duration: %w", name, err) + } + return 0 +} + +func (o *queryOptions) remaining() ([]string, error) { + if o.err != nil { + return nil, o.err + } + if len(o.q) == 0 { + return nil, nil + } + keys := make([]string, 0, len(o.q)) + for k := range o.q { + keys = append(keys, k) + } + sort.Strings(keys) + return keys, nil +} diff --git a/driver/pgdriver/config_test.go b/driver/pgdriver/config_test.go index af8cb4565..ac2632bef 100644 --- a/driver/pgdriver/config_test.go +++ b/driver/pgdriver/config_test.go @@ -28,6 +28,19 @@ func TestParseDSN(t *testing.T) { WriteTimeout: 5 * time.Second, }, }, + { + dsn: "postgres://postgres:1@localhost:5432/testDatabase?sslmode=disable&dial_timeout=1&read_timeout=2s&write_timeout=3", + cfg: &pgdriver.Config{ + Network: "tcp", + Addr: "localhost:5432", + User: "postgres", + Password: "1", + Database: "testDatabase", + DialTimeout: 1 * time.Second, + ReadTimeout: 2 * time.Second, + WriteTimeout: 3 * time.Second, + }, + }, { dsn: "postgres://postgres:password@app.xxx.us-east-1.rds.amazonaws.com:5432/test?sslmode=disable", cfg: &pgdriver.Config{ diff --git a/example/pg-faceted-search/main.go b/example/pg-faceted-search/main.go index 9b706af86..c545fee6b 100644 --- a/example/pg-faceted-search/main.go +++ b/example/pg-faceted-search/main.go @@ -65,8 +65,7 @@ func main() { panic(err) } - fmt.Println("\n") - fmt.Println("all facets:\n") + fmt.Printf("\n\nall facets:\n\n") spew.Dump(facets) facets, err = selectFacets(ctx, db, "moods:mysterious") @@ -74,8 +73,7 @@ func main() { panic(err) } - fmt.Println("\n") - fmt.Println("moods:mysterious facets:\n") + fmt.Printf("\n\nmoods:mysterious facets:\n\n") spew.Dump(facets) }