Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for retrieving multiple results. #425

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 100 additions & 42 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ var (
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
ErrNoMoreResults = errors.New("pq: no more results")
)

const NextResults = "NEXT"

type drv struct{}

func (d *drv) Open(name string) (driver.Conn, error) {
Expand Down Expand Up @@ -115,6 +118,9 @@ type conn struct {
// Whether to always send []byte parameters over as binary. Enables single
// round-trip mode for non-prepared Query calls.
binaryParameters bool

// Whether the connection is ready to execute a query.
readyForQuery bool
}

// Handle driver-side settings in parsed connection string.
Expand Down Expand Up @@ -164,7 +170,7 @@ func (c *conn) handlePgpass(o values) {
return
}
mode := fileinfo.Mode()
if mode & (0x77) != 0 {
if mode&(0x77) != 0 {
// XXX should warn about incorrect .pgpass permissions as psql does
return
}
Expand All @@ -180,7 +186,7 @@ func (c *conn) handlePgpass(o values) {
db := o.Get("dbname")
username := o.Get("user")
// From: https://github.com/tg/pgpass/blob/master/reader.go
getFields := func (s string) []string {
getFields := func(s string) []string {
fs := make([]string, 0, 5)
f := make([]rune, 0, len(s))

Expand All @@ -200,7 +206,7 @@ func (c *conn) handlePgpass(o values) {
}
}
return append(fs, string(f))
}
}
for scanner.Scan() {
line := scanner.Text()
if len(line) == 0 || line[0] == '#' {
Expand All @@ -210,7 +216,7 @@ func (c *conn) handlePgpass(o values) {
if len(split) != 5 {
continue
}
if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
o["password"] = split[4]
return
}
Expand Down Expand Up @@ -587,6 +593,8 @@ func (cn *conn) gname() string {
}

func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
cn.waitReadyForQuery()

b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
Expand Down Expand Up @@ -614,56 +622,73 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
defer cn.errRecover(&err)

st := &stmt{cn: cn, name: ""}

b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
querySent := false
nextResult := q == NextResults

for {
if cn.readyForQuery && !querySent {
if nextResult {
return nil, ErrNoMoreResults
}

// Mark the connection as having sent a query.
cn.readyForQuery = false
b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
querySent = true
}

t, r := cn.recv1()
switch t {
case 'C', 'I':
// We allow queries which don't return any results through Query as
// well as Exec. We still have to give database/sql a rows object
// the user can close, though, to avoid connections from being
// leaked. A "rows" with done=true works fine for that purpose.
if err != nil {
cn.bad = true
errorf("unexpected message %q in simple query execution", t)
}
if res == nil {
res = &rows{
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
if nextResult || querySent {
// We allow queries which don't return any results through Query as
// well as Exec. We still have to give database/sql a rows object
// the user can close, though, to avoid connections from being
// leaked. A "rows" with done=true works fine for that purpose.
if err != nil {
cn.bad = true
errorf("unexpected message %q in simple query execution", t)
}
if res == nil {
res = &rows{
cn: cn,
}
}
res.done = true
}
res.done = true
case 'Z':
cn.processReadyForQuery(r)
// done
return
if querySent {
// done
return
}
case 'E':
res = nil
err = parseError(r)
if nextResult || querySent {
res = nil
err = parseError(r)
}
case 'D':
if res == nil {
cn.bad = true
errorf("unexpected DataRow in simple query execution")
if nextResult || querySent {
if res == nil {
cn.bad = true
errorf("unexpected DataRow in simple query execution")
}
// the query didn't fail; kick off to Next
cn.saveMessage(t, r)
return
}
// the query didn't fail; kick off to Next
cn.saveMessage(t, r)
return
case 'T':
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)

// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
if nextResult || querySent {
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)

// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
}
default:
cn.bad = true
errorf("unknown response for simple query: %q", t)
Expand Down Expand Up @@ -747,6 +772,8 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
}
defer cn.errRecover(&err)

cn.waitReadyForQuery()

if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
return cn.prepareCopyIn(q)
}
Expand Down Expand Up @@ -782,6 +809,8 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err
return cn.simpleQuery(query)
}

cn.waitReadyForQuery()

if cn.binaryParameters {
cn.sendBinaryModeQuery(query, args)

Expand Down Expand Up @@ -818,6 +847,8 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
return r, err
}

cn.waitReadyForQuery()

if cn.binaryParameters {
cn.sendBinaryModeQuery(query, args)

Expand Down Expand Up @@ -1306,6 +1337,10 @@ func (st *stmt) exec(v []driver.Value) {
}

cn := st.cn
cn.waitReadyForQuery()
// Mark the connection has having sent a query.
cn.readyForQuery = false

w := cn.writeBuf('B')
w.byte(0) // unnamed portal
w.string(st.name)
Expand Down Expand Up @@ -1436,7 +1471,11 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
case 'E':
err = parseError(&rs.rb)
case 'C', 'I':
continue
rs.done = true
if err != nil {
return err
}
return io.EOF
case 'Z':
conn.processReadyForQuery(&rs.rb)
rs.done = true
Expand Down Expand Up @@ -1532,6 +1571,9 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
}

// Mark the connection has having sent a query.
cn.readyForQuery = false

b := cn.writeBuf('P')
b.byte(0) // unnamed statement
b.string(query)
Expand Down Expand Up @@ -1581,6 +1623,7 @@ func (c *conn) processParameterStatus(r *readBuf) {

func (c *conn) processReadyForQuery(r *readBuf) {
c.txnStatus = transactionStatus(r.byte())
c.readyForQuery = true
}

func (cn *conn) readReadyForQuery() {
Expand All @@ -1595,6 +1638,21 @@ func (cn *conn) readReadyForQuery() {
}
}

func (cn *conn) waitReadyForQuery() {
// The postgres server sends a 'Z' command when it is ready to receive a
// query. We use this as a sync marker to skip over commands we're not
// handling in our current state. For example, we might be skipping over
// subsequent results when a query contained multiple statements and only the
// first result was retrieved.
for !cn.readyForQuery {
t, r := cn.recv1()
switch t {
case 'Z':
cn.processReadyForQuery(r)
}
}
}

func (cn *conn) readParseResponse() {
t, r := cn.recv1()
switch t {
Expand Down
65 changes: 65 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func TestOpenURL(t *testing.T) {
}

const pgpass_file = "/tmp/pqgotest_pgpass"

func TestPgpass(t *testing.T) {
testAssert := func(conninfo string, expected string, reason string) {
conn, err := openTestConnConninfo(conninfo)
Expand Down Expand Up @@ -339,6 +340,70 @@ func TestRowsCloseBeforeDone(t *testing.T) {
}
}

func TestMultipleResults(t *testing.T) {
db := openTestConn(t)
defer db.Close()

var val int
if err := db.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil {
t.Fatal(err)
}
if val != 1 {
t.Fatalf("expected 1, but found %d", val)
}
if err := db.QueryRow(NextResults).Scan(&val); err != nil {
t.Fatal(err)
}
if val != 2 {
t.Fatalf("expected 2, but found %d", val)
}
if err := db.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults {
t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err)
}

// Now test discarding the second result.
if err := db.QueryRow("SELECT 3; SELECT 4").Scan(&val); err != nil {
t.Fatal(err)
}
if val != 3 {
t.Fatalf("expected 3, but found %d", val)
}
if err := db.QueryRow("SELECT 5").Scan(&val); err != nil {
t.Fatal(err)
}
if val != 5 {
t.Fatalf("expected 5, but found %d", val)
}
}

func TestTxnMultipleResults(t *testing.T) {
db := openTestConn(t)
defer db.Close()

tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()

var val int
if err := tx.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil {
t.Fatal(err)
}
if val != 1 {
t.Fatalf("expected 1, but found %d", val)
}
if err := tx.QueryRow(NextResults).Scan(&val); err != nil {
t.Fatal(err)
}
if val != 2 {
t.Fatalf("expected 2, but found %d", val)
}
if err := tx.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults {
t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err)
}
}

func TestParameterCountMismatch(t *testing.T) {
db := openTestConn(t)
defer db.Close()
Expand Down