From 178239ebea99cac945709adf30b34c14416d1b7e Mon Sep 17 00:00:00 2001 From: Yahor Yuzefovich Date: Fri, 22 Jan 2021 15:32:56 -0800 Subject: [PATCH] sql: make internal executor streaming This commit updates the internal executor to operate in a streaming fashion by refactoring its internal logic to implement an iterator pattern. A new method `QueryInternalEx` (and its counterpart `QueryInternal`) is introduced (both not used currently) while all existing methods of `InternalExecutor` interface are implemented using the new iterator logic. The communication between the iterator goroutine (the receiver) and the connExecutor goroutine (the sender) is done via a buffered (of 32 size in non-test setting) channel. The channel is closed when the connExecutor goroutine exits its run() loop. Care needs to be taken when closing the iterator - we need to make sure to close the stmtBuf (so that there are no more commands for the connExecutor goroutine to execute) and then we need to drain the channel (since the connExecutor goroutine might be blocked on adding a row to the channel). After that we have to wait for the connExecutor goroutine to exit so that we can finish the tracing span. For convenience purposes, if the iterator is fully exhausted, it will get closed automatically. Release note: None --- .../protectedts/ptstorage/storage_test.go | 17 + pkg/sql/conn_executor_internal_test.go | 52 +-- pkg/sql/conn_io.go | 78 ++-- pkg/sql/copy.go | 2 +- pkg/sql/internal.go | 388 ++++++++++++++---- pkg/sql/internal_test.go | 2 +- pkg/sql/sqlutil/internal_executor.go | 57 +++ 7 files changed, 454 insertions(+), 142 deletions(-) diff --git a/pkg/kv/kvserver/protectedts/ptstorage/storage_test.go b/pkg/kv/kvserver/protectedts/ptstorage/storage_test.go index 235b4377c33d..db135d691214 100644 --- a/pkg/kv/kvserver/protectedts/ptstorage/storage_test.go +++ b/pkg/kv/kvserver/protectedts/ptstorage/storage_test.go @@ -711,6 +711,23 @@ func (ie *wrappedInternalExecutor) QueryRow( panic("not implemented") } +func (ie *wrappedInternalExecutor) QueryIterator( + ctx context.Context, opName string, txn *kv.Txn, stmt string, qargs ...interface{}, +) (sqlutil.InternalRows, error) { + panic("not implemented") +} + +func (ie *wrappedInternalExecutor) QueryIteratorEx( + ctx context.Context, + opName string, + txn *kv.Txn, + session sessiondata.InternalExecutorOverride, + stmt string, + qargs ...interface{}, +) (sqlutil.InternalRows, error) { + panic("not implemented") +} + func (ie *wrappedInternalExecutor) getErrFunc() func(statement string) error { ie.mu.RLock() defer ie.mu.RUnlock() diff --git a/pkg/sql/conn_executor_internal_test.go b/pkg/sql/conn_executor_internal_test.go index 987f5efc1c5a..30530ede0999 100644 --- a/pkg/sql/conn_executor_internal_test.go +++ b/pkg/sql/conn_executor_internal_test.go @@ -55,7 +55,7 @@ func TestPortalsDestroyedOnTxnFinish(t *testing.T) { defer log.Scope(t).Close(t) ctx := context.Background() - buf, syncResults, finished, stopper, err := startConnExecutor(ctx) + buf, syncResults, finished, stopper, _, err := startConnExecutor(ctx) if err != nil { t.Fatal(err) } @@ -69,11 +69,7 @@ func TestPortalsDestroyedOnTxnFinish(t *testing.T) { // succeed and the 2nd one to fail (since the portal is destroyed after the // Execute). cmdPos := 0 - stmt := mustParseOne("SELECT 1") - if err != nil { - t.Fatal(err) - } - if err = buf.Push(ctx, PrepareStmt{Name: "ps_nontxn", Statement: stmt}); err != nil { + if err = buf.Push(ctx, PrepareStmt{Name: "ps_nontxn", Statement: mustParseOne("SELECT 1")}); err != nil { t.Fatal(err) } @@ -121,7 +117,7 @@ func TestPortalsDestroyedOnTxnFinish(t *testing.T) { if numResults != cmdPos+1 { t.Fatalf("expected %d results, got: %d", cmdPos+1, len(results)) } - if err := results[successfulDescribePos].err; err != nil { + if err = results[successfulDescribePos].err; err != nil { t.Fatalf("expected first Describe to succeed, got err: %s", err) } if !testutils.IsError(results[failedDescribePos].err, "unknown portal") { @@ -134,20 +130,12 @@ func TestPortalsDestroyedOnTxnFinish(t *testing.T) { // after the COMMIT). The point of the SELECT is to show that the portal // survives execution of a statement. cmdPos++ - stmt = mustParseOne("BEGIN") - if err != nil { - t.Fatal(err) - } - if err := buf.Push(ctx, ExecStmt{Statement: stmt}); err != nil { + if err = buf.Push(ctx, ExecStmt{Statement: mustParseOne("BEGIN")}); err != nil { t.Fatal(err) } cmdPos++ - stmt = mustParseOne("SELECT 1") - if err != nil { - t.Fatal(err) - } - if err = buf.Push(ctx, PrepareStmt{Name: "ps1", Statement: stmt}); err != nil { + if err = buf.Push(ctx, PrepareStmt{Name: "ps1", Statement: mustParseOne("SELECT 1")}); err != nil { t.Fatal(err) } @@ -160,11 +148,7 @@ func TestPortalsDestroyedOnTxnFinish(t *testing.T) { } cmdPos++ - stmt = mustParseOne("SELECT 2") - if err != nil { - t.Fatal(err) - } - if err := buf.Push(ctx, ExecStmt{Statement: stmt}); err != nil { + if err = buf.Push(ctx, ExecStmt{Statement: mustParseOne("SELECT 2")}); err != nil { t.Fatal(err) } @@ -178,11 +162,7 @@ func TestPortalsDestroyedOnTxnFinish(t *testing.T) { } cmdPos++ - stmt = mustParseOne("COMMIT") - if err != nil { - t.Fatal(err) - } - if err := buf.Push(ctx, ExecStmt{Statement: stmt}); err != nil { + if err = buf.Push(ctx, ExecStmt{Statement: mustParseOne("COMMIT")}); err != nil { t.Fatal(err) } @@ -207,7 +187,7 @@ func TestPortalsDestroyedOnTxnFinish(t *testing.T) { t.Fatalf("expected %d results, got: %d", exp, len(results)) } succDescIdx := successfulDescribePos - numResults - if err := results[succDescIdx].err; err != nil { + if err = results[succDescIdx].err; err != nil { t.Fatalf("expected first Describe to succeed, got err: %s", err) } failDescIdx := failedDescribePos - numResults @@ -216,7 +196,7 @@ func TestPortalsDestroyedOnTxnFinish(t *testing.T) { } buf.Close() - if err := <-finished; err != nil { + if err = <-finished; err != nil { t.Fatal(err) } } @@ -240,9 +220,13 @@ func mustParseOne(s string) parser.Statement { // gets the error from closing down the executor once the StmtBuf is closed, a // stopper that must be stopped when the test completes (this does not stop the // executor but stops other background work). +// +// It also returns a channel that AddRow might block on which can buffer up to +// 16 items (including column types when applicable), so the caller might need +// to receive from it occasionally. func startConnExecutor( ctx context.Context, -) (*StmtBuf, <-chan []resWithPos, <-chan error, *stop.Stopper, error) { +) (*StmtBuf, <-chan []resWithPos, <-chan error, *stop.Stopper, <-chan ieIteratorResult, error) { // A lot of boilerplate for creating a connExecutor. stopper := stop.NewStopper() clock := hlc.NewClock(hlc.UnixNano, 0 /* maxOffset */) @@ -258,7 +242,7 @@ func startConnExecutor( gw := gossip.MakeOptionalGossip(nil) tempEngine, tempFS, err := storage.NewTempEngine(ctx, base.DefaultTestTempStorageConfig(st), base.DefaultTestStoreSpec) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } defer tempEngine.Close() cfg := &ExecutorConfig{ @@ -305,16 +289,18 @@ func startConnExecutor( s := NewServer(cfg, pool) buf := NewStmtBuf() syncResults := make(chan []resWithPos, 1) + iteratorCh := make(chan ieIteratorResult, 16) var cc ClientComm = &internalClientComm{ sync: func(res []resWithPos) { syncResults <- res }, + ch: iteratorCh, } sqlMetrics := MakeMemMetrics("test" /* endpoint */, time.Second /* histogramWindow */) conn, err := s.SetupConn(ctx, SessionArgs{}, buf, cc, sqlMetrics) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } finished := make(chan error) @@ -324,7 +310,7 @@ func startConnExecutor( go func() { finished <- s.ServeConn(ctx, conn, mon.BoundAccount{}, nil /* cancel */) }() - return buf, syncResults, finished, stopper, nil + return buf, syncResults, finished, stopper, iteratorCh, nil } // Test that a client session can close without deadlocking when the closing diff --git a/pkg/sql/conn_io.go b/pkg/sql/conn_io.go index a0e75edd7904..e4f95c5756f5 100644 --- a/pkg/sql/conn_io.go +++ b/pkg/sql/conn_io.go @@ -858,108 +858,110 @@ type resCloseType bool const closed resCloseType = true const discarded resCloseType = false -// bufferedCommandResult is a CommandResult that buffers rows and can call a -// provided callback when closed. -type bufferedCommandResult struct { +// streamingCommandResult is a CommandResult that streams rows on the channel +// and can call a provided callback when closed. +type streamingCommandResult struct { + ch chan ieIteratorResult err error - rows []tree.Datums rowsAffected int - cols colinfo.ResultColumns - - // errOnly, if set, makes AddRow() panic. This can be used when the execution - // of the query is not expected to produce any results. - errOnly bool // closeCallback, if set, is called when Close()/Discard() is called. - closeCallback func(*bufferedCommandResult, resCloseType, error) + closeCallback func(*streamingCommandResult, resCloseType) } -var _ RestrictedCommandResult = &bufferedCommandResult{} -var _ CommandResultClose = &bufferedCommandResult{} +var _ RestrictedCommandResult = &streamingCommandResult{} +var _ CommandResultClose = &streamingCommandResult{} // SetColumns is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) SetColumns(_ context.Context, cols colinfo.ResultColumns) { - if r.errOnly { - panic("SetColumns() called when errOnly is set") - } - r.cols = cols +func (r *streamingCommandResult) SetColumns(ctx context.Context, cols colinfo.ResultColumns) { + r.ch <- ieIteratorResult{cols: cols} } // BufferParamStatusUpdate is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) BufferParamStatusUpdate(key string, val string) { +func (r *streamingCommandResult) BufferParamStatusUpdate(key string, val string) { panic("unimplemented") } // BufferNotice is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) BufferNotice(notice pgnotice.Notice) { +func (r *streamingCommandResult) BufferNotice(notice pgnotice.Notice) { panic("unimplemented") } // ResetStmtType is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) ResetStmtType(stmt tree.Statement) { +func (r *streamingCommandResult) ResetStmtType(stmt tree.Statement) { panic("unimplemented") } // AddRow is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) AddRow(ctx context.Context, row tree.Datums) error { - if r.errOnly { - panic("AddRow() called when errOnly is set") - } +func (r *streamingCommandResult) AddRow(ctx context.Context, row tree.Datums) error { + // AddRow() and IncrementRowsAffected() are never called on the same command + // result, so we will not double count the affected rows by an increment + // here. + r.rowsAffected++ rowCopy := make(tree.Datums, len(row)) copy(rowCopy, row) - r.rows = append(r.rows, rowCopy) + r.ch <- ieIteratorResult{row: rowCopy} return nil } -func (r *bufferedCommandResult) DisableBuffering() { +func (r *streamingCommandResult) DisableBuffering() { panic("cannot disable buffering here") } // SetError is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) SetError(err error) { +func (r *streamingCommandResult) SetError(err error) { r.err = err + // Note that we intentionally do not send the error on the channel (when it + // is present) since we might replace the error with another one later which + // is allowed by the interface. An example of this is queryDone() closure + // in execStmtInOpenState(). } // Err is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) Err() error { +func (r *streamingCommandResult) Err() error { return r.err } // IncrementRowsAffected is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) IncrementRowsAffected(n int) { +func (r *streamingCommandResult) IncrementRowsAffected(n int) { r.rowsAffected += n + if r.ch != nil { + // streamingCommandResult might be used outside of the internal executor + // (i.e. not by rowsIterator) in which case the channel is not set. + r.ch <- ieIteratorResult{rowsAffectedIncrement: &n} + } } // RowsAffected is part of the RestrictedCommandResult interface. -func (r *bufferedCommandResult) RowsAffected() int { +func (r *streamingCommandResult) RowsAffected() int { return r.rowsAffected } // Close is part of the CommandResultClose interface. -func (r *bufferedCommandResult) Close(context.Context, TransactionStatusIndicator) { +func (r *streamingCommandResult) Close(context.Context, TransactionStatusIndicator) { if r.closeCallback != nil { - r.closeCallback(r, closed, nil /* err */) + r.closeCallback(r, closed) } } // Discard is part of the CommandResult interface. -func (r *bufferedCommandResult) Discard() { +func (r *streamingCommandResult) Discard() { if r.closeCallback != nil { - r.closeCallback(r, discarded, nil /* err */) + r.closeCallback(r, discarded) } } // SetInferredTypes is part of the DescribeResult interface. -func (r *bufferedCommandResult) SetInferredTypes([]oid.Oid) {} +func (r *streamingCommandResult) SetInferredTypes([]oid.Oid) {} // SetNoDataRowDescription is part of the DescribeResult interface. -func (r *bufferedCommandResult) SetNoDataRowDescription() {} +func (r *streamingCommandResult) SetNoDataRowDescription() {} // SetPrepStmtOutput is part of the DescribeResult interface. -func (r *bufferedCommandResult) SetPrepStmtOutput(context.Context, colinfo.ResultColumns) {} +func (r *streamingCommandResult) SetPrepStmtOutput(context.Context, colinfo.ResultColumns) {} // SetPortalOutput is part of the DescribeResult interface. -func (r *bufferedCommandResult) SetPortalOutput( +func (r *streamingCommandResult) SetPortalOutput( context.Context, colinfo.ResultColumns, []pgwirebase.FormatCode, ) { } diff --git a/pkg/sql/copy.go b/pkg/sql/copy.go index bf46834cd48f..4259b740000b 100644 --- a/pkg/sql/copy.go +++ b/pkg/sql/copy.go @@ -595,7 +595,7 @@ func (c *copyMachine) insertRows(ctx context.Context) (retErr error) { return err } - var res bufferedCommandResult + var res streamingCommandResult err := c.execInsertPlan(ctx, &c.p, &res) if err != nil { return err diff --git a/pkg/sql/internal.go b/pkg/sql/internal.go index a8acac7b5ec5..ca8adf2c3e0d 100644 --- a/pkg/sql/internal.go +++ b/pkg/sql/internal.go @@ -30,6 +30,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sessiondatapb" "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" "github.com/cockroachdb/cockroach/pkg/sql/sqlutil" + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util" "github.com/cockroachdb/cockroach/pkg/util/mon" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" @@ -128,18 +130,23 @@ func (ie *InternalExecutor) SetSessionData(sessionData *sessiondata.SessionData) // // If txn is not nil, the statement will be executed in the respective txn. // +// ch is used by the connExecutor goroutine to send the rows and will be closed +// once the goroutine exits its run() loop. +// // sd will constitute the executor's session state. func (ie *InternalExecutor) initConnEx( ctx context.Context, txn *kv.Txn, + ch chan ieIteratorResult, sd *sessiondata.SessionData, syncCallback func([]resWithPos), errCallback func(error), -) (*StmtBuf, *sync.WaitGroup, error) { +) (*StmtBuf, *sync.WaitGroup) { clientComm := &internalClientComm{ - sync: syncCallback, + ch: ch, // init lastDelivered below the position of the first result (0). lastDelivered: -1, + sync: syncCallback, } // When the connEx is serving an internal executor, it can inherit the @@ -196,6 +203,7 @@ func (ie *InternalExecutor) initConnEx( sqltelemetry.RecordError(ctx, err, &ex.server.cfg.Settings.SV) errCallback(err) } + close(ch) closeMode := normalClose if txn != nil { closeMode = externalTxnClose @@ -203,7 +211,155 @@ func (ie *InternalExecutor) initConnEx( ex.close(ctx, closeMode) wg.Done() }() - return stmtBuf, &wg, nil + return stmtBuf, &wg +} + +type ieIteratorResult struct { + // Exactly one of these 4 fields will be set. + row tree.Datums + rowsAffectedIncrement *int + cols colinfo.ResultColumns + err error +} + +type rowsIterator struct { + // ch is the channel on which the connExecutor goroutine sends the rows (in + // streamingCommandResult.AddRow). The iterator goroutine receives rows (or + // other metadata) and will block on this channel if it is empty. The + // channel will be closed when the connExecutor goroutine exits its run() + // loop. + ch chan ieIteratorResult + rowsAffected int + resultCols colinfo.ResultColumns + + lastRow tree.Datums + lastErr error + done bool + + // errCallback is an optional callback that will be called exactly once + // before an error is returned by Next() or Close(). + errCallback func(err error) error + + // stmtBuf will be closed on Close(). This is necessary in order to tell + // the connExecutor's goroutine to exit when the iterator's user wants to + // short-circuit the iteration (i.e. before Next() returns false). + stmtBuf *StmtBuf + + // wg can be used to wait for the connExecutor's goroutine to exit. + wg *sync.WaitGroup + + // sp will finished on Close(). + sp *tracing.Span +} + +var _ sqlutil.InternalRows = &rowsIterator{} + +func (r *rowsIterator) Next(ctx context.Context) (_ bool, retErr error) { + if r.done { + return false, r.lastErr + } + + // Due to recursive calls to Next() below, this deferred function might get + // executed multiple times, yet it is not a problem because Close() is + // idempotent and we're unsetting the error callback. + defer func() { + // If the iterator has just reached its terminal state, we'll close it + // automatically. + if r.done { + // We can ignore the returned error because Close() will update + // r.lastErr if necessary. + _ /* err */ = r.Close() + } + if r.errCallback != nil { + r.lastErr = r.errCallback(r.lastErr) + r.errCallback = nil + } + retErr = r.lastErr + }() + + select { + case next, ok := <-r.ch: + if !ok { + r.done = true + return false, nil + } + if next.row != nil { + r.rowsAffected++ + // No need to make a copy because streamingCommandResult does that + // for us. + r.lastRow = next.row + return true, nil + } + if next.rowsAffectedIncrement != nil { + r.rowsAffected += *next.rowsAffectedIncrement + return r.Next(ctx) + } + if next.cols != nil { + // Ignore the result columns if they are already set on the + // iterator: it is possible for ROWS statement type to be executed + // in a 'rows affected' mode, in such case the correct columns are + // set manually when instantiating an iterator, but the result + // columns of the statement are also sent by SetColumns() (we need + // to keep the former). + if r.resultCols == nil { + r.resultCols = next.cols + } + return r.Next(ctx) + } + if next.err == nil { + next.err = errors.AssertionFailedf("unexpectedly empty ieIteratorResult object") + } + r.lastErr = next.err + r.done = true + return false, r.lastErr + + case <-ctx.Done(): + r.lastErr = ctx.Err() + r.done = true + return false, r.lastErr + } +} + +func (r *rowsIterator) Cur() tree.Datums { + return r.lastRow +} + +func (r *rowsIterator) Close() error { + // Closing the stmtBuf will tell the connExecutor to stop executing commands + // (if it hasn't exited yet). + r.stmtBuf.Close() + // We need to finish the span but only after the connExecutor goroutine is + // done. + defer func() { + if r.sp != nil { + r.wg.Wait() + r.sp.Finish() + r.sp = nil + } + }() + + // We also need to exhaust the channel since the connExecutor goroutine + // might be blocked on sending the row in AddRow(). + // TODO(yuzefovich): at the moment, the connExecutor goroutine will not stop + // execution of the current command right away when the stmtBuf is closed + // (e.g. if it is currently executing ExecStmt command, all rows will still + // be pushed into the channel). Improve this. + for res := range r.ch { + // We are only interested in possible errors if we haven't already seen + // one. All other things are simply ignored. + if res.err != nil && r.lastErr == nil { + r.lastErr = res.err + if r.errCallback != nil { + r.lastErr = r.errCallback(r.lastErr) + r.errCallback = nil + } + } + } + return r.lastErr +} + +func (r *rowsIterator) Types() colinfo.ResultColumns { + return r.resultCols } // Query executes the supplied SQL statement and returns the resulting rows. @@ -233,7 +389,7 @@ func (ie *InternalExecutor) QueryEx( stmt string, qargs ...interface{}, ) ([]tree.Datums, error) { - datums, _, err := ie.queryInternal(ctx, opName, txn, session, stmt, qargs...) + datums, _, err := ie.queryInternalBuffered(ctx, opName, txn, session, stmt, 0 /* limit */, qargs...) return datums, err } @@ -247,22 +403,38 @@ func (ie *InternalExecutor) QueryWithCols( stmt string, qargs ...interface{}, ) ([]tree.Datums, colinfo.ResultColumns, error) { - return ie.queryInternal(ctx, opName, txn, session, stmt, qargs...) + return ie.queryInternalBuffered(ctx, opName, txn, session, stmt, 0 /* limit */, qargs...) } -func (ie *InternalExecutor) queryInternal( +func (ie *InternalExecutor) queryInternalBuffered( ctx context.Context, opName string, txn *kv.Txn, sessionDataOverride sessiondata.InternalExecutorOverride, stmt string, + // Non-zero limit specifies the limit on the number of rows returned. + limit int, qargs ...interface{}, ) ([]tree.Datums, colinfo.ResultColumns, error) { - res, err := ie.execInternal(ctx, opName, txn, sessionDataOverride, stmt, qargs...) + it, err := ie.execInternal(ctx, opName, txn, sessionDataOverride, stmt, qargs...) + if err != nil { + return nil, nil, err + } + var rows []tree.Datums + var ok bool + for ok, err = it.Next(ctx); ok; ok, err = it.Next(ctx) { + rows = append(rows, it.Cur()) + if limit != 0 && len(rows) == limit { + // We have accumulated the requested number of rows, so we can + // short-circuit the iteration. + err = it.Close() + break + } + } if err != nil { return nil, nil, err } - return res.rows, res.cols, res.err + return rows, it.Types(), nil } // QueryRow is like Query, except it returns a single row, or nil if not row is @@ -288,7 +460,7 @@ func (ie *InternalExecutor) QueryRowEx( stmt string, qargs ...interface{}, ) (tree.Datums, error) { - rows, err := ie.QueryEx(ctx, opName, txn, session, stmt, qargs...) + rows, _, err := ie.queryInternalBuffered(ctx, opName, txn, session, stmt, 2 /* limit */, qargs...) if err != nil { return nil, err } @@ -329,18 +501,45 @@ func (ie *InternalExecutor) ExecEx( stmt string, qargs ...interface{}, ) (int, error) { - res, err := ie.execInternal(ctx, opName, txn, session, stmt, qargs...) + it, err := ie.execInternal(ctx, opName, txn, session, stmt, qargs...) + if err != nil { + return 0, err + } + // We need to exhaust the iterator so that it can count the number of rows + // affected. + var ok bool + for ok, err = it.Next(ctx); ok; ok, err = it.Next(ctx) { + } if err != nil { return 0, err } - return res.rowsAffected, res.err + return it.rowsAffected, nil } -type result struct { - rows []tree.Datums - rowsAffected int - cols colinfo.ResultColumns - err error +// QueryIterator executes the query, returning an iterator that can be used +// to get the results. If the call is successful, the returned iterator +// *must* be closed. +// +// QueryIterator is deprecated because it may transparently execute a query +// as root. Use QueryIteratorEx instead. +func (ie *InternalExecutor) QueryIterator( + ctx context.Context, opName string, txn *kv.Txn, stmt string, qargs ...interface{}, +) (sqlutil.InternalRows, error) { + return ie.QueryIteratorEx(ctx, opName, txn, ie.maybeRootSessionDataOverride(opName), stmt, qargs...) +} + +// QueryIteratorEx executes the query, returning an iterator that can be used +// to get the results. If the call is successful, the returned iterator +// *must* be closed. +func (ie *InternalExecutor) QueryIteratorEx( + ctx context.Context, + opName string, + txn *kv.Txn, + session sessiondata.InternalExecutorOverride, + stmt string, + qargs ...interface{}, +) (sqlutil.InternalRows, error) { + return ie.execInternal(ctx, opName, txn, session, stmt, qargs...) } // applyOverrides overrides the respective fields from sd for all the fields set on o. @@ -381,6 +580,20 @@ func (ie *InternalExecutor) maybeRootSessionDataOverride( return o } +var rowsAffectedResultColumns = colinfo.ResultColumns{ + colinfo.ResultColumn{ + Name: "rows_affected", + Typ: types.Int, + }, +} + +var ieIteratorChannelBufferSize = util.ConstantWithMetamorphicTestRange( + "iterator-channel-buffer-size", + 32, /* defaultValue */ + 1, /* min */ + 32, /* max */ +) + // execInternal executes a statement. // // sessionDataOverride can be used to control select fields in the executor's @@ -393,7 +606,7 @@ func (ie *InternalExecutor) execInternal( sessionDataOverride sessiondata.InternalExecutorOverride, stmt string, qargs ...interface{}, -) (retRes result, retErr error) { +) (r *rowsIterator, retErr error) { ctx = logtags.AddTag(ctx, "intExec", opName) var sd *sessiondata.SessionData @@ -406,12 +619,21 @@ func (ie *InternalExecutor) execInternal( } applyOverrides(sessionDataOverride, sd) if sd.User().Undefined() { - return result{}, errors.AssertionFailedf("no user specified for internal query") + return nil, errors.AssertionFailedf("no user specified for internal query") } if sd.ApplicationName == "" { sd.ApplicationName = catconstants.InternalAppNamePrefix + "-" + opName } + // The returned span is finished by this function in all error paths, but if + // an iterator is returned, then we transfer the responsibility of closing + // the span to the iterator. This is necessary so that the connExecutor + // exits before the span is finished. + ctx, sp := tracing.EnsureChildSpan(ctx, ie.s.cfg.AmbientCtx.Tracer, opName) + + var stmtBuf *StmtBuf + var wg *sync.WaitGroup + defer func() { // We wrap errors with the opName, but not if they're retriable - in that // case we need to leave the error intact so that it can be retried at a @@ -419,64 +641,82 @@ func (ie *InternalExecutor) execInternal( // // TODO(knz): track the callers and check whether opName could be turned // into a type safe for reporting. - if retErr != nil && !errIsRetriable(retErr) { - retErr = errors.Wrapf(retErr, "%s", opName) - } - if retRes.err != nil && !errIsRetriable(retRes.err) { - retRes.err = errors.Wrapf(retRes.err, "%s", opName) + if retErr != nil { + if !errIsRetriable(retErr) { + retErr = errors.Wrapf(retErr, "%s", opName) + } + if stmtBuf != nil { + // If stmtBuf is non-nil, then the connExecutor goroutine has + // been spawn up - we gotta wait for it to exit. + // + // Note that at the moment of writing when retErr is non-nil, + // the stmtBuf is necessarily nil (the only errors emitted after + // the connExecutor is initialized are the errors on pushing + // into the stmtBuf, and those could occur only if the stmtBuf + // is closed which would indicate problems with + // synchronization). In any case, we want to be safe and handle + // such a scenario accordingly. + stmtBuf.Close() + wg.Wait() + } + sp.Finish() + } else { + // r must be non-nil here. + r.errCallback = func(err error) error { + if err != nil && !errIsRetriable(err) { + err = errors.Wrapf(err, "%s", opName) + } + return err + } + r.sp = sp } }() - ctx, sp := tracing.EnsureChildSpan(ctx, ie.s.cfg.AmbientCtx.Tracer, opName) - defer sp.Finish() - timeReceived := timeutil.Now() parseStart := timeReceived parsed, err := parser.ParseOne(stmt) if err != nil { - return result{}, err + return nil, err } parseEnd := timeutil.Now() + // Transforms the args to datums. The datum types will be passed as type + // hints to the PrepareStmt command below. + datums, err := golangFillQueryArguments(qargs...) + if err != nil { + return nil, err + } + // resPos will be set to the position of the command that represents the // statement we care about before that command is sent for execution. var resPos CmdPos - resCh := make(chan result) - var resultsReceived bool + ch := make(chan ieIteratorResult, ieIteratorChannelBufferSize) syncCallback := func(results []resWithPos) { - resultsReceived = true + // Close the stmtBuf so that the connExecutor exits its run() loop. + stmtBuf.Close() for _, res := range results { - if res.pos == resPos { - resCh <- result{rows: res.rows, rowsAffected: res.RowsAffected(), cols: res.cols, err: res.Err()} + if res.Err() != nil { + // If we encounter an error, there's no point in looking + // further; the rest of the commands in the batch have been + // skipped. + ch <- ieIteratorResult{err: res.Err()} return } - if res.err != nil { - // If we encounter an error, there's no point in looking further; the - // rest of the commands in the batch have been skipped. - resCh <- result{err: res.Err()} + if res.pos == resPos { return } } - resCh <- result{err: errors.AssertionFailedf("missing result for pos: %d and no previous error", resPos)} + ch <- ieIteratorResult{err: errors.AssertionFailedf("missing result for pos: %d and no previous error", resPos)} } errCallback := func(err error) { - if resultsReceived { - return - } - resCh <- result{err: err} - } - stmtBuf, wg, err := ie.initConnEx(ctx, txn, sd, syncCallback, errCallback) - if err != nil { - return result{}, err + // The connExecutor exited its run() loop, so the stmtBuf must have been + // closed. Still, since Close() is idempotent, we'll call it here too. + stmtBuf.Close() + ch <- ieIteratorResult{err: err} } + stmtBuf, wg = ie.initConnEx(ctx, txn, ch, sd, syncCallback, errCallback) - // Transforms the args to datums. The datum types will be passed as type hints - // to the PrepareStmt command. - datums, err := golangFillQueryArguments(qargs...) - if err != nil { - return result{}, err - } typeHints := make(tree.PlaceholderTypes, len(datums)) for i, d := range datums { // Arg numbers start from 1. @@ -492,7 +732,7 @@ func (ie *InternalExecutor) execInternal( ParseStart: parseStart, ParseEnd: parseEnd, }); err != nil { - return result{}, err + return nil, err } } else { resPos = 2 @@ -505,25 +745,31 @@ func (ie *InternalExecutor) execInternal( TypeHints: typeHints, }, ); err != nil { - return result{}, err + return nil, err } if err := stmtBuf.Push(ctx, BindStmt{internalArgs: datums}); err != nil { - return result{}, err + return nil, err } if err := stmtBuf.Push(ctx, ExecPortal{TimeReceived: timeReceived}); err != nil { - return result{}, err + return nil, err } } if err := stmtBuf.Push(ctx, Sync{}); err != nil { - return result{}, err + return nil, err } - res := <-resCh - stmtBuf.Close() - wg.Wait() - return res, nil + var resultColumns colinfo.ResultColumns + if parsed.AST.StatementType() != tree.Rows { + resultColumns = rowsAffectedResultColumns + } + return &rowsIterator{ + ch: ch, + resultCols: resultColumns, + stmtBuf: stmtBuf, + wg: wg, + }, nil } // internalClientComm is an implementation of ClientComm used by the @@ -533,17 +779,20 @@ type internalClientComm struct { // InternalExecutor. results []resWithPos + // ch is the channel on which the results of the query execution (ExecStmt + // or ExecPortal commands) are propagated to the consumer (the iterator). + ch chan ieIteratorResult + lastDelivered CmdPos - // sync, if set, is called whenever a Sync is executed. It returns all the - // results since the previous Sync. + // sync, if set, is called whenever a Sync is executed. sync func([]resWithPos) } var _ ClientComm = &internalClientComm{} type resWithPos struct { - *bufferedCommandResult + *streamingCommandResult pos CmdPos } @@ -564,15 +813,16 @@ func (icc *internalClientComm) CreateStatementResult( // createRes creates a result. onClose, if not nil, is called when the result is // closed. -func (icc *internalClientComm) createRes(pos CmdPos, onClose func(error)) *bufferedCommandResult { - res := &bufferedCommandResult{ - closeCallback: func(res *bufferedCommandResult, typ resCloseType, err error) { +func (icc *internalClientComm) createRes(pos CmdPos, onClose func()) *streamingCommandResult { + res := &streamingCommandResult{ + ch: icc.ch, + closeCallback: func(res *streamingCommandResult, typ resCloseType) { if typ == discarded { return } - icc.results = append(icc.results, resWithPos{bufferedCommandResult: res, pos: pos}) + icc.results = append(icc.results, resWithPos{streamingCommandResult: res, pos: pos}) if onClose != nil { - onClose(err) + onClose() } }, } @@ -593,7 +843,7 @@ func (icc *internalClientComm) CreateBindResult(pos CmdPos) BindResult { // // The returned SyncResult will call the sync callback when its closed. func (icc *internalClientComm) CreateSyncResult(pos CmdPos) SyncResult { - return icc.createRes(pos, func(err error) { + return icc.createRes(pos, func() { results := make([]resWithPos, len(icc.results)) copy(results, icc.results) icc.results = icc.results[:0] diff --git a/pkg/sql/internal_test.go b/pkg/sql/internal_test.go index 90a70fbc2de1..1b67503dfdb6 100644 --- a/pkg/sql/internal_test.go +++ b/pkg/sql/internal_test.go @@ -458,7 +458,7 @@ func testInternalExecutorAppNameInitialization( } // We'll want to look at statistics below, and finish the test with - // no goroutine leakage. To achieve this, cancel the query. and + // no goroutine leakage. To achieve this, cancel the query and // drain the goroutine. if _, err := ie.Exec(context.Background(), "cancel-query", nil, "CANCEL QUERY $1", queryID); err != nil { t.Fatal(err) diff --git a/pkg/sql/sqlutil/internal_executor.go b/pkg/sql/sqlutil/internal_executor.go index 4d6b65791387..ecdfaf864d31 100644 --- a/pkg/sql/sqlutil/internal_executor.go +++ b/pkg/sql/sqlutil/internal_executor.go @@ -104,6 +104,63 @@ type InternalExecutor interface { stmt string, qargs ...interface{}, ) (tree.Datums, error) + + // QueryIterator executes the query, returning an iterator that can be used + // to get the results. If the call is successful, the returned iterator + // *must* be closed. + // + // QueryIterator is deprecated because it may transparently execute a query + // as root. Use QueryIteratorEx instead. + QueryIterator( + ctx context.Context, + opName string, + txn *kv.Txn, + stmt string, + qargs ...interface{}, + ) (InternalRows, error) + + // QueryIteratorEx executes the query, returning an iterator that can be + // used to get the results. If the call is successful, the returned iterator + // *must* be closed. + QueryIteratorEx( + ctx context.Context, + opName string, + txn *kv.Txn, + session sessiondata.InternalExecutorOverride, + stmt string, + qargs ...interface{}, + ) (InternalRows, error) +} + +// InternalRows is an iterator interface that's exposed by the internal +// executor. It provides access to the rows from a query. +type InternalRows interface { + // Next advances the iterator by one row, returning false if there are no + // more rows in this iterator or if an error is encountered (the latter is + // then returned). + // + // The iterator is automatically closed when false is returned, consequent + // calls to Next will return the same values as when the iterator was + // closed. + Next(context.Context) (bool, error) + + // Cur returns the row at the current position of the iterator. The row is + // safe to hold onto (meaning that calling Next() or Close() will not + // invalidate it). + Cur() tree.Datums + + // Close closes this iterator, releasing any resources it held open. Close + // is idempotent and *must* be called once the caller is done with the + // iterator. + Close() error + + // Types returns the types of the columns returned by this iterator. The + // returned array is guaranteed to correspond 1:1 with the tree.Datums rows + // returned by Cur(). + // + // WARNING: this method is safe to call anytime *after* the first call to + // Next() (including after Close() was called). + Types() colinfo.ResultColumns } // SessionBoundInternalExecutorFactory is a function that produces a "session