diff --git a/conn.go b/conn.go index 394e659d..ee6b53ff 100644 --- a/conn.go +++ b/conn.go @@ -32,8 +32,11 @@ var ( ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + ErrNoMoreResults = errors.New("pq: no more results") ) +const NextResults = "NEXT" + type drv struct{} func (d *drv) Open(name string) (driver.Conn, error) { @@ -115,6 +118,9 @@ type conn struct { // Whether to always send []byte parameters over as binary. Enables single // round-trip mode for non-prepared Query calls. binaryParameters bool + + // Whether the connection is ready to execute a query. + readyForQuery bool } // Handle driver-side settings in parsed connection string. @@ -164,7 +170,7 @@ func (c *conn) handlePgpass(o values) { return } mode := fileinfo.Mode() - if mode & (0x77) != 0 { + if mode&(0x77) != 0 { // XXX should warn about incorrect .pgpass permissions as psql does return } @@ -180,7 +186,7 @@ func (c *conn) handlePgpass(o values) { db := o.Get("dbname") username := o.Get("user") // From: https://github.com/tg/pgpass/blob/master/reader.go - getFields := func (s string) []string { + getFields := func(s string) []string { fs := make([]string, 0, 5) f := make([]rune, 0, len(s)) @@ -200,7 +206,7 @@ func (c *conn) handlePgpass(o values) { } } return append(fs, string(f)) - } + } for scanner.Scan() { line := scanner.Text() if len(line) == 0 || line[0] == '#' { @@ -210,7 +216,7 @@ func (c *conn) handlePgpass(o values) { if len(split) != 5 { continue } - if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { + if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { o["password"] = split[4] return } @@ -587,6 +593,8 @@ func (cn *conn) gname() string { } func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { + cn.waitReadyForQuery() + b := cn.writeBuf('Q') b.string(q) cn.send(b) @@ -614,56 +622,73 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err func (cn *conn) simpleQuery(q string) (res *rows, err error) { defer cn.errRecover(&err) - st := &stmt{cn: cn, name: ""} - - b := cn.writeBuf('Q') - b.string(q) - cn.send(b) + querySent := false + nextResult := q == NextResults for { + if cn.readyForQuery && !querySent { + if nextResult { + return nil, ErrNoMoreResults + } + + // Mark the connection as having sent a query. + cn.readyForQuery = false + b := cn.writeBuf('Q') + b.string(q) + cn.send(b) + querySent = true + } + t, r := cn.recv1() switch t { case 'C', 'I': - // We allow queries which don't return any results through Query as - // well as Exec. We still have to give database/sql a rows object - // the user can close, though, to avoid connections from being - // leaked. A "rows" with done=true works fine for that purpose. - if err != nil { - cn.bad = true - errorf("unexpected message %q in simple query execution", t) - } - if res == nil { - res = &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, + if nextResult || querySent { + // We allow queries which don't return any results through Query as + // well as Exec. We still have to give database/sql a rows object + // the user can close, though, to avoid connections from being + // leaked. A "rows" with done=true works fine for that purpose. + if err != nil { + cn.bad = true + errorf("unexpected message %q in simple query execution", t) } + if res == nil { + res = &rows{ + cn: cn, + } + } + res.done = true } - res.done = true case 'Z': cn.processReadyForQuery(r) - // done - return + if querySent { + // done + return + } case 'E': - res = nil - err = parseError(r) + if nextResult || querySent { + res = nil + err = parseError(r) + } case 'D': - if res == nil { - cn.bad = true - errorf("unexpected DataRow in simple query execution") + if nextResult || querySent { + if res == nil { + cn.bad = true + errorf("unexpected DataRow in simple query execution") + } + // the query didn't fail; kick off to Next + cn.saveMessage(t, r) + return } - // the query didn't fail; kick off to Next - cn.saveMessage(t, r) - return case 'T': - // res might be non-nil here if we received a previous - // CommandComplete, but that's fine; just overwrite it - res = &rows{cn: cn} - res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) - - // To work around a bug in QueryRow in Go 1.2 and earlier, wait - // until the first DataRow has been received. + if nextResult || querySent { + // res might be non-nil here if we received a previous + // CommandComplete, but that's fine; just overwrite it + res = &rows{cn: cn} + res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) + + // To work around a bug in QueryRow in Go 1.2 and earlier, wait + // until the first DataRow has been received. + } default: cn.bad = true errorf("unknown response for simple query: %q", t) @@ -747,6 +772,8 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { } defer cn.errRecover(&err) + cn.waitReadyForQuery() + if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { return cn.prepareCopyIn(q) } @@ -782,6 +809,8 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err return cn.simpleQuery(query) } + cn.waitReadyForQuery() + if cn.binaryParameters { cn.sendBinaryModeQuery(query, args) @@ -818,6 +847,8 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err return r, err } + cn.waitReadyForQuery() + if cn.binaryParameters { cn.sendBinaryModeQuery(query, args) @@ -1306,6 +1337,10 @@ func (st *stmt) exec(v []driver.Value) { } cn := st.cn + cn.waitReadyForQuery() + // Mark the connection has having sent a query. + cn.readyForQuery = false + w := cn.writeBuf('B') w.byte(0) // unnamed portal w.string(st.name) @@ -1436,7 +1471,11 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'E': err = parseError(&rs.rb) case 'C', 'I': - continue + rs.done = true + if err != nil { + return err + } + return io.EOF case 'Z': conn.processReadyForQuery(&rs.rb) rs.done = true @@ -1532,6 +1571,9 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) } + // Mark the connection has having sent a query. + cn.readyForQuery = false + b := cn.writeBuf('P') b.byte(0) // unnamed statement b.string(query) @@ -1581,6 +1623,7 @@ func (c *conn) processParameterStatus(r *readBuf) { func (c *conn) processReadyForQuery(r *readBuf) { c.txnStatus = transactionStatus(r.byte()) + c.readyForQuery = true } func (cn *conn) readReadyForQuery() { @@ -1595,6 +1638,21 @@ func (cn *conn) readReadyForQuery() { } } +func (cn *conn) waitReadyForQuery() { + // The postgres server sends a 'Z' command when it is ready to receive a + // query. We use this as a sync marker to skip over commands we're not + // handling in our current state. For example, we might be skipping over + // subsequent results when a query contained multiple statements and only the + // first result was retrieved. + for !cn.readyForQuery { + t, r := cn.recv1() + switch t { + case 'Z': + cn.processReadyForQuery(r) + } + } +} + func (cn *conn) readParseResponse() { t, r := cn.recv1() switch t { diff --git a/conn_test.go b/conn_test.go index 2639c8ef..0c7db598 100644 --- a/conn_test.go +++ b/conn_test.go @@ -136,6 +136,7 @@ func TestOpenURL(t *testing.T) { } const pgpass_file = "/tmp/pqgotest_pgpass" + func TestPgpass(t *testing.T) { testAssert := func(conninfo string, expected string, reason string) { conn, err := openTestConnConninfo(conninfo) @@ -339,6 +340,70 @@ func TestRowsCloseBeforeDone(t *testing.T) { } } +func TestMultipleResults(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + var val int + if err := db.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil { + t.Fatal(err) + } + if val != 1 { + t.Fatalf("expected 1, but found %d", val) + } + if err := db.QueryRow(NextResults).Scan(&val); err != nil { + t.Fatal(err) + } + if val != 2 { + t.Fatalf("expected 2, but found %d", val) + } + if err := db.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults { + t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err) + } + + // Now test discarding the second result. + if err := db.QueryRow("SELECT 3; SELECT 4").Scan(&val); err != nil { + t.Fatal(err) + } + if val != 3 { + t.Fatalf("expected 3, but found %d", val) + } + if err := db.QueryRow("SELECT 5").Scan(&val); err != nil { + t.Fatal(err) + } + if val != 5 { + t.Fatalf("expected 5, but found %d", val) + } +} + +func TestTxnMultipleResults(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + var val int + if err := tx.QueryRow("SELECT 1; SELECT 2").Scan(&val); err != nil { + t.Fatal(err) + } + if val != 1 { + t.Fatalf("expected 1, but found %d", val) + } + if err := tx.QueryRow(NextResults).Scan(&val); err != nil { + t.Fatal(err) + } + if val != 2 { + t.Fatalf("expected 2, but found %d", val) + } + if err := tx.QueryRow(NextResults).Scan(&val); err != ErrNoMoreResults { + t.Fatalf("expected %s, but found %v", ErrNoMoreResults, err) + } +} + func TestParameterCountMismatch(t *testing.T) { db := openTestConn(t) defer db.Close()