Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: Implement a preliminary version of multiple active portals #96358

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/sql/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ func runPlanInsidePlan(
// Make a copy of the EvalContext so it can be safely modified.
evalCtx := params.p.ExtendedEvalContextCopy()
plannerCopy := *params.p
// If we reach this part when re-executing a pausable portal, we won't want to
// resume the flow bound to it. The inner-plan should have its own lifecycle
// for its flow.
plannerCopy.pausablePortal = nil
distributePlan := getPlanDistribution(
ctx, plannerCopy.Descriptors().HasUncommittedTypes(),
plannerCopy.SessionData().DistSQLMode, plan.main,
Expand Down
78 changes: 63 additions & 15 deletions pkg/sql/conn_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,14 @@ func (ex *connExecutor) close(ctx context.Context, closeType closeType) {
txnEvType = txnRollback
}

// Close all portals, otherwise there will be leftover bytes.
ex.extraTxnState.prepStmtsNamespace.closeAllPortals(
ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc,
)
ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.closeAllPortals(
ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc,
)

if closeType == normalClose {
// We'll cleanup the SQL txn by creating a non-retriable (commit:true) event.
// This event is guaranteed to be accepted in every state.
Expand Down Expand Up @@ -1760,6 +1768,26 @@ func (ns *prepStmtNamespace) touchLRUEntry(name string) {
ns.addLRUEntry(name, 0)
}

func (ns *prepStmtNamespace) closeAllPortals(
ctx context.Context, prepStmtsNamespaceMemAcc *mon.BoundAccount,
) {
for name, p := range ns.portals {
p.close(ctx, prepStmtsNamespaceMemAcc, name)
delete(ns.portals, name)
}
}

func (ns *prepStmtNamespace) closeAllPausablePortals(
ctx context.Context, prepStmtsNamespaceMemAcc *mon.BoundAccount,
) {
for name, p := range ns.portals {
if p.pauseInfo != nil {
p.close(ctx, prepStmtsNamespaceMemAcc, name)
delete(ns.portals, name)
}
}
}

// MigratablePreparedStatements returns a mapping of all prepared statements.
func (ns *prepStmtNamespace) MigratablePreparedStatements() []sessiondatapb.MigratableSession_PreparedStatement {
ret := make([]sessiondatapb.MigratableSession_PreparedStatement, 0, len(ns.prepStmts))
Expand Down Expand Up @@ -1836,10 +1864,7 @@ func (ns *prepStmtNamespace) resetTo(
for name := range ns.prepStmtsLRU {
delete(ns.prepStmtsLRU, name)
}
for name, p := range ns.portals {
p.close(ctx, prepStmtsNamespaceMemAcc, name)
delete(ns.portals, name)
}
ns.closeAllPortals(ctx, prepStmtsNamespaceMemAcc)

for name, ps := range to.prepStmts {
ps.incRef(ctx)
Expand Down Expand Up @@ -1880,10 +1905,9 @@ func (ex *connExecutor) resetExtraTxnState(ctx context.Context, ev txnEvent) {
}

// Close all portals.
for name, p := range ex.extraTxnState.prepStmtsNamespace.portals {
p.close(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, name)
delete(ex.extraTxnState.prepStmtsNamespace.portals, name)
}
ex.extraTxnState.prepStmtsNamespace.closeAllPortals(
ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc,
)

// Close all cursors.
if err := ex.extraTxnState.sqlCursors.closeAll(false /* errorOnWithHold */); err != nil {
Expand All @@ -1894,10 +1918,9 @@ func (ex *connExecutor) resetExtraTxnState(ctx context.Context, ev txnEvent) {

switch ev.eventType {
case txnCommit, txnRollback:
for name, p := range ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.portals {
p.close(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, name)
delete(ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.portals, name)
}
ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.closeAllPortals(
ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc,
)
ex.extraTxnState.savepoints.clear()
ex.onTxnFinish(ctx, ev)
case txnRestart:
Expand Down Expand Up @@ -2044,7 +2067,6 @@ func (ex *connExecutor) run(
return err
}
}

}

// errDrainingComplete is returned by execCmd when the connExecutor previously got
Expand Down Expand Up @@ -2116,7 +2138,7 @@ func (ex *connExecutor) execCmd() (retErr error) {
(tcmd.LastInBatchBeforeShowCommitTimestamp ||
tcmd.LastInBatch || !implicitTxnForBatch)
ev, payload, err = ex.execStmt(
ctx, tcmd.Statement, nil /* prepared */, nil /* pinfo */, stmtRes, canAutoCommit,
ctx, tcmd.Statement, nil /* portal */, nil /* pinfo */, stmtRes, canAutoCommit,
)

return err
Expand Down Expand Up @@ -2204,6 +2226,8 @@ func (ex *connExecutor) execCmd() (retErr error) {
// - ex.statsCollector merely contains a copy of the times, that
// was created when the statement started executing (via the
// reset() method).
// TODO(sql-sessions): fix the phase time for pausable portals.
// https://github.com/cockroachdb/cockroach/issues/99410
ex.statsCollector.PhaseTimes().SetSessionPhaseTime(sessionphase.SessionQueryServiced, timeutil.Now())
if err != nil {
return err
Expand Down Expand Up @@ -2314,6 +2338,12 @@ func (ex *connExecutor) execCmd() (retErr error) {
// If an event was generated, feed it to the state machine.
if ev != nil {
var err error
if _, ok := payload.(eventNonRetriableErrPayload); ok {
// We need this as otherwise, there'll be leftover bytes when
// txnState.finishSQLTxn() is being called, as the underlying resources of
// pausable portals hasn't been cleared yet.
ex.extraTxnState.prepStmtsNamespace.closeAllPausablePortals(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc)
}
advInfo, err = ex.txnStateTransitionsApplyWrapper(ev, payload, res, pos)
if err != nil {
return err
Expand Down Expand Up @@ -2364,6 +2394,17 @@ func (ex *connExecutor) execCmd() (retErr error) {
res.SetError(pe.errorCause())
}
}
// For a pausable portal, we don't log the affected rows until we close the
// portal. However, we update the result for each execution. Thus, we need
// to accumulate the number of affected rows before closing the result.
switch tcmd := cmd.(type) {
case ExecPortal:
if portal, ok := ex.extraTxnState.prepStmtsNamespace.portals[tcmd.Name]; ok {
if portal.pauseInfo != nil {
portal.pauseInfo.rowsAffected += res.(RestrictedCommandResult).RowsAffected()
}
}
}
res.Close(ctx, stateToTxnStatusIndicator(ex.machine.CurState()))
} else {
res.Discard()
Expand Down Expand Up @@ -3605,8 +3646,15 @@ func (ex *connExecutor) txnStateTransitionsApplyWrapper(
}

fallthrough
case txnRestart, txnRollback:
case txnRestart:
ex.resetExtraTxnState(ex.Ctx(), advInfo.txnEvent)
case txnRollback:
ex.resetExtraTxnState(ex.Ctx(), advInfo.txnEvent)
// Since we're doing a complete rollback, there's no need to keep the
// prepared stmts for a txn rewind.
ex.extraTxnState.prepStmtsNamespaceAtTxnRewindPos.closeAllPortals(
ex.Ctx(), &ex.extraTxnState.prepStmtsNamespaceMemAcc,
)
default:
return advanceInfo{}, errors.AssertionFailedf(
"unexpected event: %v", errors.Safe(advInfo.txnEvent))
Expand Down
Loading