Skip to content

Commit

Permalink
sql: add restrictions for pausable portals
Browse files Browse the repository at this point in the history
This commits add the following restrictions for pausable portals:

1. Not an internal queries
2. Read-only queries
3. No sub-quereis or post-queries
4. Local plan only

This is because the current changes to the consumer-receiver model only consider
the local push-based case.

Release note: None
  • Loading branch information
ZhouXing19 committed Mar 23, 2023
1 parent 3484588 commit 3ef34e5
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 7 deletions.
20 changes: 19 additions & 1 deletion pkg/sql/conn_executor_exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -1511,9 +1511,15 @@ func (ex *connExecutor) dispatchToExecutionEngine(
}

ex.sessionTracing.TracePlanCheckStart(ctx)

distSQLMode := ex.sessionData().DistSQLMode
// We only allow non-distributed plan for pausable portals.
if planner.pausablePortal != nil {
distSQLMode = sessiondatapb.DistSQLOff
}
distributePlan := getPlanDistribution(
ctx, planner.Descriptors().HasUncommittedTypes(),
ex.sessionData().DistSQLMode, planner.curPlan.main,
distSQLMode, planner.curPlan.main,
)
ex.sessionTracing.TracePlanCheckEnd(ctx, nil, distributePlan.WillDistribute())

Expand Down Expand Up @@ -2004,6 +2010,18 @@ func (ex *connExecutor) execWithDistSQLEngine(
factoryEvalCtx.SessionID = planner.ExtendedEvalContext().SessionID
return factoryEvalCtx
}
// We don't allow sub / post queries for pausable portal. Set it back to an
// un-pausable (normal) portal.
if planCtx.getPortalPauseInfo() != nil {
// With pauseInfo is nil, no cleanup function will be added to the stack
// and all clean-up steps will be performed as for normal portals.
planCtx.planner.pausablePortal.pauseInfo = nil
// We need this so that the result consumption for this portal cannot be
// paused either.
if err := res.RevokePortalPausability(); err != nil {
return recv.stats, err
}
}
}
err = ex.server.cfg.DistSQLPlanner.PlanAndRunAll(ctx, evalCtx, planCtx, planner, recv, evalCtxFactory)
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/sql/conn_executor_prepare.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func (ex *connExecutor) execBind(
}

// Create the new PreparedPortal.
if err := ex.addPortal(ctx, portalName, ps, qargs, columnFormatCodes); err != nil {
if err := ex.addPortal(ctx, portalName, ps, qargs, bindCmd.isInternal, columnFormatCodes); err != nil {
return retErr(err)
}

Expand All @@ -513,6 +513,7 @@ func (ex *connExecutor) addPortal(
portalName string,
stmt *PreparedStatement,
qargs tree.QueryArguments,
isInternal bool,
outFormats []pgwirebase.FormatCode,
) error {
if _, ok := ex.extraTxnState.prepStmtsNamespace.portals[portalName]; ok {
Expand All @@ -522,7 +523,7 @@ func (ex *connExecutor) addPortal(
panic(errors.AssertionFailedf("portal already exists as cursor: %q", portalName))
}

portal, err := ex.makePreparedPortal(ctx, portalName, stmt, qargs, outFormats)
portal, err := ex.makePreparedPortal(ctx, portalName, stmt, qargs, isInternal, outFormats)
if err != nil {
return err
}
Expand Down
13 changes: 13 additions & 0 deletions pkg/sql/conn_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ type BindStmt struct {
// inferred types should reflect that).
// If internalArgs is specified, Args and ArgFormatCodes are ignored.
internalArgs []tree.Datum

// isInternal is set to ture when the bound stmt is from an internal executor.
isInternal bool
}

// command implements the Command interface.
Expand Down Expand Up @@ -815,6 +818,11 @@ type RestrictedCommandResult interface {
// GetBulkJobId returns the id of the job for the query, if the query is
// IMPORT, BACKUP or RESTORE.
GetBulkJobId() uint64

// RevokePortalPausability is to make a portal un-pausable. It is called when
// we find the underlying query is not supported for a pausable portal.
// This method is implemented only by pgwire.limitedCommandResult.
RevokePortalPausability() error
}

// DescribeResult represents the result of a Describe command (for either
Expand Down Expand Up @@ -969,6 +977,11 @@ type streamingCommandResult struct {
var _ RestrictedCommandResult = &streamingCommandResult{}
var _ CommandResultClose = &streamingCommandResult{}

// RevokePortalPausability is part of the sql.RestrictedCommandResult interface.
func (r *streamingCommandResult) RevokePortalPausability() error {
return errors.AssertionFailedf("forPausablePortal is for limitedCommandResult only")
}

// SetColumns is part of the RestrictedCommandResult interface.
func (r *streamingCommandResult) SetColumns(ctx context.Context, cols colinfo.ResultColumns) {
// The interface allows for cols to be nil, yet the iterator result expects
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ func (ie *InternalExecutor) execInternal(
return nil, err
}

if err := stmtBuf.Push(ctx, BindStmt{internalArgs: datums}); err != nil {
if err := stmtBuf.Push(ctx, BindStmt{internalArgs: datums, isInternal: true}); err != nil {
return nil, err
}

Expand Down
13 changes: 13 additions & 0 deletions pkg/sql/pgwire/command_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ type paramStatusUpdate struct {

var _ sql.CommandResult = &commandResult{}

// RevokePortalPausability is part of the sql.RestrictedCommandResult interface.
func (r *commandResult) RevokePortalPausability() error {
return errors.AssertionFailedf("RevokePortalPausability is only implemented by limitedCommandResult only")
}

// Close is part of the sql.RestrictedCommandResult interface.
func (r *commandResult) Close(ctx context.Context, t sql.TransactionStatusIndicator) {
r.assertNotReleased()
Expand Down Expand Up @@ -457,6 +462,8 @@ type limitedCommandResult struct {
portalPausablity sql.PortalPausablity
}

var _ sql.RestrictedCommandResult = &limitedCommandResult{}

// AddRow is part of the sql.RestrictedCommandResult interface.
func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) error {
if err := r.commandResult.AddRow(ctx, row); err != nil {
Expand Down Expand Up @@ -486,6 +493,12 @@ func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) erro
return nil
}

// RevokePortalPausability is part of the sql.RestrictedCommandResult interface.
func (r *limitedCommandResult) RevokePortalPausability() error {
r.portalPausablity = sql.NotPausablePortalForUnsupportedStmt
return nil
}

// SupportsAddBatch is part of the sql.RestrictedCommandResult interface.
// TODO(yuzefovich): implement limiting behavior for AddBatch.
func (r *limitedCommandResult) SupportsAddBatch() bool {
Expand Down
13 changes: 10 additions & 3 deletions pkg/sql/prepared_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ func (ex *connExecutor) makePreparedPortal(
name string,
stmt *PreparedStatement,
qargs tree.QueryArguments,
isInternal bool,
outFormats []pgwirebase.FormatCode,
) (PreparedPortal, error) {
portal := PreparedPortal{
Expand All @@ -182,9 +183,15 @@ func (ex *connExecutor) makePreparedPortal(
OutFormats: outFormats,
}

if EnableMultipleActivePortals.Get(&ex.server.cfg.Settings.SV) {
portal.pauseInfo = &portalPauseInfo{}
portal.portalPausablity = PausablePortal
if EnableMultipleActivePortals.Get(&ex.server.cfg.Settings.SV) && !isInternal {
if tree.IsAllowedToPause(stmt.AST) {
portal.pauseInfo = &portalPauseInfo{queryStats: &topLevelQueryStats{}}
portal.portalPausablity = PausablePortal
} else {
// We have set sql.defaults.multiple_active_portals.enabled to true, but
// we don't support the underlying query for a pausable portal.
portal.portalPausablity = NotPausablePortalForUnsupportedStmt
}
}
return portal, portal.accountForCopy(ctx, &ex.extraTxnState.prepStmtsNamespaceMemAcc, name)
}
Expand Down
24 changes: 24 additions & 0 deletions pkg/sql/sem/tree/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,30 @@ type canModifySchema interface {
modifiesSchema() bool
}

// IsAllowedToPause returns true if the stmt cannot either modify the schema or
// write data.
// This function is to gate the queries allowed for pausable portals.
// TODO(janexing): We should be more accurate about the stmt selection here.
// Now we only allow SELECT, but is it too strict? And how to filter out
// SELECT with data writes / schema changes?
func IsAllowedToPause(stmt Statement) bool {
if !CanModifySchema(stmt) && !CanWriteData(stmt) {
switch t := stmt.(type) {
case *Select:
if t.With != nil {
ctes := t.With.CTEList
for _, cte := range ctes {
if !IsAllowedToPause(cte.Stmt) {
return false
}
}
}
return true
}
}
return false
}

// CanModifySchema returns true if the statement can modify
// the database schema.
func CanModifySchema(stmt Statement) bool {
Expand Down

0 comments on commit 3ef34e5

Please sign in to comment.