Skip to content

Commit

Permalink
Add support for retrieving multiple results.
Browse files Browse the repository at this point in the history
Reworked how the msg-ready command (`Z`) is processed. Previously
execution of a query would look for the msg-ready command before
completing the operation. Now, when executing a query, the driver places
the connection into a state where it knows there may be more results. If
another query is subsequently executed, the driver waits for the
msg-ready command to arrive, discarding any other commands, before
sending the new query. But if the special `NEXT` query is executed, the
driver looks for another set of results on the connection. If no such
results are found, ErrNoMoreResults is returned.
  • Loading branch information
petermattis committed Feb 3, 2016
1 parent 3e033bd commit 73fc92e
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 33 deletions.
124 changes: 91 additions & 33 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ var (
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
ErrNoMoreResults = errors.New("pq: no more results")
)

const NextResults = "NEXT"

type drv struct{}

func (d *drv) Open(name string) (driver.Conn, error) {
Expand Down Expand Up @@ -115,6 +118,9 @@ type conn struct {
// Whether to always send []byte parameters over as binary. Enables single
// round-trip mode for non-prepared Query calls.
binaryParameters bool

// Whether the connection is ready to execute a query.
readyForQuery bool
}

// Handle driver-side settings in parsed connection string.
Expand Down Expand Up @@ -587,6 +593,8 @@ func (cn *conn) gname() string {
}

func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
cn.waitReadyForQuery()

b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
Expand Down Expand Up @@ -614,51 +622,73 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
defer cn.errRecover(&err)

b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
querySent := false
nextResult := q == NextResults

for {
if cn.readyForQuery && !querySent {
if nextResult {
return nil, ErrNoMoreResults
}

// Mark the connection has having sent a query.
cn.readyForQuery = false
b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
querySent = true
}

t, r := cn.recv1()
switch t {
case 'C', 'I':
// We allow queries which don't return any results through Query as
// well as Exec. We still have to give database/sql a rows object
// the user can close, though, to avoid connections from being
// leaked. A "rows" with done=true works fine for that purpose.
if err != nil {
cn.bad = true
errorf("unexpected message %q in simple query execution", t)
}
if res == nil {
res = &rows{
cn: cn,
if nextResult || querySent {
// We allow queries which don't return any results through Query as
// well as Exec. We still have to give database/sql a rows object
// the user can close, though, to avoid connections from being
// leaked. A "rows" with done=true works fine for that purpose.
if err != nil {
cn.bad = true
errorf("unexpected message %q in simple query execution", t)
}
if res == nil {
res = &rows{
cn: cn,
}
}
res.done = true
}
res.done = true
case 'Z':
cn.processReadyForQuery(r)
// done
return
if querySent {
// done
return
}
case 'E':
res = nil
err = parseError(r)
if nextResult || querySent {
res = nil
err = parseError(r)
}
case 'D':
if res == nil {
cn.bad = true
errorf("unexpected DataRow in simple query execution")
if nextResult || querySent {
if res == nil {
cn.bad = true
errorf("unexpected DataRow in simple query execution")
}
// the query didn't fail; kick off to Next
cn.saveMessage(t, r)
return
}
// the query didn't fail; kick off to Next
cn.saveMessage(t, r)
return
case 'T':
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)

// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
if nextResult || querySent {
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)

// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
}
default:
cn.bad = true
errorf("unknown response for simple query: %q", t)
Expand Down Expand Up @@ -742,6 +772,8 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
}
defer cn.errRecover(&err)

cn.waitReadyForQuery()

if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
return cn.prepareCopyIn(q)
}
Expand Down Expand Up @@ -777,6 +809,8 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err
return cn.simpleQuery(query)
}

cn.waitReadyForQuery()

if cn.binaryParameters {
cn.sendBinaryModeQuery(query, args)

Expand Down Expand Up @@ -813,6 +847,8 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
return r, err
}

cn.waitReadyForQuery()

if cn.binaryParameters {
cn.sendBinaryModeQuery(query, args)

Expand Down Expand Up @@ -1301,6 +1337,10 @@ func (st *stmt) exec(v []driver.Value) {
}

cn := st.cn
cn.waitReadyForQuery()
// Mark the connection has having sent a query.
cn.readyForQuery = false

w := cn.writeBuf('B')
w.byte(0) // unnamed portal
w.string(st.name)
Expand Down Expand Up @@ -1431,7 +1471,11 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
case 'E':
err = parseError(&rs.rb)
case 'C', 'I':
continue
rs.done = true
if err != nil {
return err
}
return io.EOF
case 'Z':
conn.processReadyForQuery(&rs.rb)
rs.done = true
Expand Down Expand Up @@ -1527,6 +1571,9 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
}

// Mark the connection has having sent a query.
cn.readyForQuery = false

b := cn.writeBuf('P')
b.byte(0) // unnamed statement
b.string(query)
Expand Down Expand Up @@ -1576,6 +1623,7 @@ func (c *conn) processParameterStatus(r *readBuf) {

func (c *conn) processReadyForQuery(r *readBuf) {
c.txnStatus = transactionStatus(r.byte())
c.readyForQuery = true
}

func (cn *conn) readReadyForQuery() {
Expand All @@ -1590,6 +1638,16 @@ func (cn *conn) readReadyForQuery() {
}
}

func (cn *conn) waitReadyForQuery() {
for !cn.readyForQuery {
t, r := cn.recv1()
switch t {
case 'Z':
cn.processReadyForQuery(r)
}
}
}

func (cn *conn) readParseResponse() {
t, r := cn.recv1()
switch t {
Expand Down
65 changes: 65 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func TestOpenURL(t *testing.T) {
}

const pgpass_file = "/tmp/pqgotest_pgpass"

func TestPgpass(t *testing.T) {
testAssert := func(conninfo string, expected string, reason string) {
conn, err := openTestConnConninfo(conninfo)
Expand Down Expand Up @@ -339,6 +340,70 @@ func TestRowsCloseBeforeDone(t *testing.T) {
}
}

func TestMultipleResults(t *testing.T) {
db := openTestConn(t)
defer db.Close()

var val int
if err := db.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil {
t.Fatal(err)
}
if val != 1 {
t.Fatalf("expected 1, but found %d", val)
}
if err := db.QueryRow(NextResults).Scan(&val); err != nil {
t.Fatal(err)
}
if val != 2 {
t.Fatalf("expected 2, but found %d", val)
}
if err := db.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults {
t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err)
}

// Now test discarding the second result.
if err := db.QueryRow("SELECT 3; SELECT 4").Scan(&val); err != nil {
t.Fatal(err)
}
if val != 3 {
t.Fatalf("expected 3, but found %d", val)
}
if err := db.QueryRow("SELECT 5").Scan(&val); err != nil {
t.Fatal(err)
}
if val != 5 {
t.Fatalf("expected 5, but found %d", val)
}
}

func TestTxnMultipleResults(t *testing.T) {
db := openTestConn(t)
defer db.Close()

tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()

var val int
if err := tx.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil {
t.Fatal(err)
}
if val != 1 {
t.Fatalf("expected 1, but found %d", val)
}
if err := tx.QueryRow(NextResults).Scan(&val); err != nil {
t.Fatal(err)
}
if val != 2 {
t.Fatalf("expected 2, but found %d", val)
}
if err := tx.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults {
t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err)
}
}

func TestParameterCountMismatch(t *testing.T) {
db := openTestConn(t)
defer db.Close()
Expand Down

0 comments on commit 73fc92e

Please sign in to comment.