diff --git a/pkg/cli/declarative_corpus.go b/pkg/cli/declarative_corpus.go index faacaad74673..43ea4086d0e9 100644 --- a/pkg/cli/declarative_corpus.go +++ b/pkg/cli/declarative_corpus.go @@ -50,7 +50,7 @@ a given corpus file. return jobID }, } - _, err := scplan.MakePlan(*state, params) + _, err := scplan.MakePlan(cmd.Context(), *state, params) if err != nil { fmt.Printf("failed to validate %s with error %v\n", name, err) } else { diff --git a/pkg/sql/explain_ddl.go b/pkg/sql/explain_ddl.go index 58031d7145a2..01bd212606d0 100644 --- a/pkg/sql/explain_ddl.go +++ b/pkg/sql/explain_ddl.go @@ -67,15 +67,17 @@ func (n *explainDDLNode) startExec(params runParams) error { return explainNotPossibleError } } - return n.setExplainValues(scNode.plannedState) + return n.setExplainValues(params.ctx, scNode.plannedState) } -func (n *explainDDLNode) setExplainValues(scState scpb.CurrentState) (err error) { +func (n *explainDDLNode) setExplainValues( + ctx context.Context, scState scpb.CurrentState, +) (err error) { defer func() { err = errors.WithAssertionFailure(err) }() var p scplan.Plan - p, err = scplan.MakePlan(scState, scplan.Params{ + p, err = scplan.MakePlan(ctx, scState, scplan.Params{ ExecutionPhase: scop.StatementPhase, SchemaChangerJobIDSupplier: func() jobspb.JobID { return 1 }, }) diff --git a/pkg/sql/schemachanger/corpus/corpus_test.go b/pkg/sql/schemachanger/corpus/corpus_test.go index e825b371e7a3..52bb87de4e8c 100644 --- a/pkg/sql/schemachanger/corpus/corpus_test.go +++ b/pkg/sql/schemachanger/corpus/corpus_test.go @@ -11,6 +11,7 @@ package corpus_test import ( + "context" "flag" "testing" @@ -40,7 +41,7 @@ func TestValidateCorpuses(t *testing.T) { jobID := jobspb.InvalidJobID name, state := reader.GetCorpus(corpusIdx) t.Run(name, func(t *testing.T) { - _, err := scplan.MakePlan(*state, scplan.Params{ + _, err := scplan.MakePlan(context.Background(), *state, scplan.Params{ ExecutionPhase: scop.LatestPhase, InRollback: state.InRollback, SchemaChangerJobIDSupplier: func() jobspb.JobID { diff --git a/pkg/sql/schemachanger/scdeps/sctestutils/sctestutils.go b/pkg/sql/schemachanger/scdeps/sctestutils/sctestutils.go index ba70e010dab4..ec3a14e4cde8 100644 --- a/pkg/sql/schemachanger/scdeps/sctestutils/sctestutils.go +++ b/pkg/sql/schemachanger/scdeps/sctestutils/sctestutils.go @@ -168,7 +168,7 @@ func ProtoDiff(a, b protoutil.Message, args DiffArgs, rewrites func(interface{}) // MakePlan is a convenient alternative to calling scplan.MakePlan in tests. func MakePlan(t *testing.T, state scpb.CurrentState, phase scop.Phase) scplan.Plan { - plan, err := scplan.MakePlan(state, scplan.Params{ + plan, err := scplan.MakePlan(context.Background(), state, scplan.Params{ ExecutionPhase: phase, SchemaChangerJobIDSupplier: func() jobspb.JobID { return 1 }, }) diff --git a/pkg/sql/schemachanger/scplan/internal/opgen/op_gen.go b/pkg/sql/schemachanger/scplan/internal/opgen/op_gen.go index fc6663e016ed..e64906446e67 100644 --- a/pkg/sql/schemachanger/scplan/internal/opgen/op_gen.go +++ b/pkg/sql/schemachanger/scplan/internal/opgen/op_gen.go @@ -81,17 +81,19 @@ func IterateTransitions( // BuildGraph constructs a graph with operation edges populated from an initial // state. -func BuildGraph(cs scpb.CurrentState) (*scgraph.Graph, error) { - return opRegistry.buildGraph(cs) +func BuildGraph(ctx context.Context, cs scpb.CurrentState) (*scgraph.Graph, error) { + return opRegistry.buildGraph(ctx, cs) } -func (r *registry) buildGraph(cs scpb.CurrentState) (_ *scgraph.Graph, err error) { +func (r *registry) buildGraph( + ctx context.Context, cs scpb.CurrentState, +) (_ *scgraph.Graph, err error) { start := timeutil.Now() defer func() { - if err != nil || !log.V(2) { + if err != nil || !log.ExpensiveLogEnabled(ctx, 2) { return } - log.Infof(context.TODO(), "operation graph generation took %v", timeutil.Since(start)) + log.Infof(ctx, "operation graph generation took %v", timeutil.Since(start)) }() g, err := scgraph.New(cs) if err != nil { diff --git a/pkg/sql/schemachanger/scplan/internal/rules/registry.go b/pkg/sql/schemachanger/scplan/internal/rules/registry.go index ee77102b54ec..8ecae0757246 100644 --- a/pkg/sql/schemachanger/scplan/internal/rules/registry.go +++ b/pkg/sql/schemachanger/scplan/internal/rules/registry.go @@ -27,7 +27,7 @@ import ( // ApplyDepRules adds dependency edges to the graph according to the // registered dependency rules. -func ApplyDepRules(g *scgraph.Graph) error { +func ApplyDepRules(ctx context.Context, g *scgraph.Graph) error { for _, dr := range registry.depRules { start := timeutil.Now() var added int @@ -41,9 +41,15 @@ func ApplyDepRules(g *scgraph.Graph) error { }); err != nil { return errors.Wrapf(err, "applying dep rule %s", dr.name) } - if log.V(2) { + // Applying the dep rules can be slow in some cases. Check for + // cancellation when applying the rules to ensure we don't spin for + // too long while the user is waiting for the task to exit cleanly. + if ctx.Err() != nil { + return ctx.Err() + } + if log.ExpensiveLogEnabled(ctx, 2) { log.Infof( - context.TODO(), "applying dep rule %s %d took %v", + ctx, "applying dep rule %s %d took %v", dr.name, added, timeutil.Since(start), ) } @@ -53,7 +59,7 @@ func ApplyDepRules(g *scgraph.Graph) error { // ApplyOpRules marks op edges as no-op in a shallow copy of the graph according // to the registered rules. -func ApplyOpRules(g *scgraph.Graph) (*scgraph.Graph, error) { +func ApplyOpRules(ctx context.Context, g *scgraph.Graph) (*scgraph.Graph, error) { db := g.Database() m := make(map[*screl.Node][]scgraph.RuleName) for _, rule := range registry.opRules { @@ -68,9 +74,9 @@ func ApplyOpRules(g *scgraph.Graph) (*scgraph.Graph, error) { if err != nil { return nil, errors.Wrapf(err, "applying op rule %s", rule.name) } - if log.V(2) { + if log.ExpensiveLogEnabled(ctx, 2) { log.Infof( - context.TODO(), "applying op rule %s %d took %v", + ctx, "applying op rule %s %d took %v", rule.name, added, timeutil.Since(start), ) } diff --git a/pkg/sql/schemachanger/scplan/internal/scstage/build.go b/pkg/sql/schemachanger/scplan/internal/scstage/build.go index 6c1912504bf6..8c1594799be6 100644 --- a/pkg/sql/schemachanger/scplan/internal/scstage/build.go +++ b/pkg/sql/schemachanger/scplan/internal/scstage/build.go @@ -11,6 +11,7 @@ package scstage import ( + "context" "fmt" "sort" "strings" @@ -31,7 +32,11 @@ import ( // Note that the scJobIDSupplier function is idempotent, and must return the // same value for all calls. func BuildStages( - init scpb.CurrentState, phase scop.Phase, g *scgraph.Graph, scJobIDSupplier func() jobspb.JobID, + ctx context.Context, + init scpb.CurrentState, + phase scop.Phase, + g *scgraph.Graph, + scJobIDSupplier func() jobspb.JobID, ) []Stage { c := buildContext{ rollback: init.InRollback, diff --git a/pkg/sql/schemachanger/scplan/plan.go b/pkg/sql/schemachanger/scplan/plan.go index 8d65e4e72ceb..5818abe00927 100644 --- a/pkg/sql/schemachanger/scplan/plan.go +++ b/pkg/sql/schemachanger/scplan/plan.go @@ -74,19 +74,19 @@ func (p Plan) StagesForCurrentPhase() []scstage.Stage { // MakePlan generates a Plan for a particular phase of a schema change, given // the initial state for a set of targets. Returns an error when planning fails. -func MakePlan(initial scpb.CurrentState, params Params) (p Plan, err error) { +func MakePlan(ctx context.Context, initial scpb.CurrentState, params Params) (p Plan, err error) { p = Plan{ CurrentState: initial, Params: params, } - err = makePlan(&p) - if err != nil { + err = makePlan(ctx, &p) + if err != nil && ctx.Err() == nil { err = p.DecorateErrorWithPlanDetails(err) } return p, err } -func makePlan(p *Plan) (err error) { +func makePlan(ctx context.Context, p *Plan) (err error) { defer func() { if r := recover(); r != nil { rAsErr, ok := r.(error) @@ -99,18 +99,18 @@ func makePlan(p *Plan) (err error) { }() { start := timeutil.Now() - p.Graph = buildGraph(p.CurrentState) - if log.V(2) { - log.Infof(context.TODO(), "graph generation took %v", timeutil.Since(start)) + p.Graph = buildGraph(ctx, p.CurrentState) + if log.ExpensiveLogEnabled(ctx, 2) { + log.Infof(ctx, "graph generation took %v", timeutil.Since(start)) } } { start := timeutil.Now() p.Stages = scstage.BuildStages( - p.CurrentState, p.Params.ExecutionPhase, p.Graph, p.Params.SchemaChangerJobIDSupplier, + ctx, p.CurrentState, p.Params.ExecutionPhase, p.Graph, p.Params.SchemaChangerJobIDSupplier, ) - if log.V(2) { - log.Infof(context.TODO(), "stage generation took %v", timeutil.Since(start)) + if log.ExpensiveLogEnabled(ctx, 2) { + log.Infof(ctx, "stage generation took %v", timeutil.Since(start)) } } if n := len(p.Stages); n > 0 && p.Stages[n-1].Phase > scop.PreCommitPhase { @@ -123,12 +123,12 @@ func makePlan(p *Plan) (err error) { return nil } -func buildGraph(cs scpb.CurrentState) *scgraph.Graph { - g, err := opgen.BuildGraph(cs) +func buildGraph(ctx context.Context, cs scpb.CurrentState) *scgraph.Graph { + g, err := opgen.BuildGraph(ctx, cs) if err != nil { panic(errors.Wrapf(err, "build graph op edges")) } - err = rules.ApplyDepRules(g) + err = rules.ApplyDepRules(ctx, g) if err != nil { panic(errors.Wrapf(err, "build graph dep edges")) } @@ -136,7 +136,7 @@ func buildGraph(cs scpb.CurrentState) *scgraph.Graph { if err != nil { panic(errors.Wrapf(err, "validate graph")) } - g, err = rules.ApplyOpRules(g) + g, err = rules.ApplyOpRules(ctx, g) if err != nil { panic(errors.Wrapf(err, "mark op edges as no-op")) } diff --git a/pkg/sql/schemachanger/scrun/scrun.go b/pkg/sql/schemachanger/scrun/scrun.go index 98557d0ddd70..a4d30d3cdcec 100644 --- a/pkg/sql/schemachanger/scrun/scrun.go +++ b/pkg/sql/schemachanger/scrun/scrun.go @@ -65,7 +65,7 @@ func runTransactionPhase( if len(state.Current) == 0 { return scpb.CurrentState{}, jobspb.InvalidJobID, nil } - sc, err := scplan.MakePlan(state, scplan.Params{ + sc, err := scplan.MakePlan(ctx, state, scplan.Params{ ExecutionPhase: phase, SchemaChangerJobIDSupplier: deps.TransactionalJobRegistry().SchemaChangerJobID, }) @@ -112,7 +112,7 @@ func RunSchemaChangesInJob( } return errors.Wrapf(err, "failed to construct state for job %d", jobID) } - sc, err := scplan.MakePlan(state, scplan.Params{ + sc, err := scplan.MakePlan(ctx, state, scplan.Params{ ExecutionPhase: scop.PostCommitPhase, SchemaChangerJobIDSupplier: func() jobspb.JobID { return jobID }, }) diff --git a/pkg/sql/schemachanger/sctest/end_to_end.go b/pkg/sql/schemachanger/sctest/end_to_end.go index 34960dd4c08a..3d473dc14da0 100644 --- a/pkg/sql/schemachanger/sctest/end_to_end.go +++ b/pkg/sql/schemachanger/sctest/end_to_end.go @@ -243,7 +243,7 @@ func checkExplainDiagrams( params.InRollback = true params.ExecutionPhase = scop.PostCommitNonRevertiblePhase } - pl, err := scplan.MakePlan(state, params) + pl, err := scplan.MakePlan(context.Background(), state, params) require.NoErrorf(t, err, "%s: %s", fileNameSuffix, explainedStmt) action(explainDir, "ddl", pl.ExplainCompact) action(explainVerboseDir, "ddl, verbose", pl.ExplainVerbose)