diff --git a/pkg/sql/logictest/testdata/logic_test/fk b/pkg/sql/logictest/testdata/logic_test/fk index 87e4ca004765..fc7841d9eb7e 100644 --- a/pkg/sql/logictest/testdata/logic_test/fk +++ b/pkg/sql/logictest/testdata/logic_test/fk @@ -3748,6 +3748,9 @@ INSERT INTO nonunique_idx_child VALUES (0, 1, 10) # stepping which caused issues when executing postqueries, so a query which # involved a mixed situation would error out. +statement ok +SET enable_multiple_modifications_of_table = true + statement ok CREATE TABLE x ( k INT PRIMARY KEY @@ -3768,6 +3771,8 @@ SELECT FROM a +statement ok +RESET enable_multiple_modifications_of_table # Test that reversing a constraint addition after adding a self foreign key # works correctly. diff --git a/pkg/sql/logictest/testdata/logic_test/udf_fk b/pkg/sql/logictest/testdata/logic_test/udf_fk index e87c81b345c0..10a747a2f937 100644 --- a/pkg/sql/logictest/testdata/logic_test/udf_fk +++ b/pkg/sql/logictest/testdata/logic_test/udf_fk @@ -595,3 +595,122 @@ SELECT * FROM selfref; 2 2 subtest end + + +subtest corruption_check + +statement ok +DROP TABLE IF EXISTS parent CASCADE; + +statement ok +DROP TABLE IF EXISTS child CASCADE; + +statement ok +DROP TABLE IF EXISTS grandchild CASCADE; + +statement ok +CREATE TABLE parent (j INT PRIMARY KEY); + +statement ok +CREATE TABLE child (i INT PRIMARY KEY, j INT REFERENCES parent (j) ON UPDATE CASCADE ON DELETE CASCADE, INDEX (j)); + +statement ok +INSERT INTO parent VALUES (0), (2), (4); + +statement ok +INSERT INTO child VALUES (0, 0); + +statement ok +CREATE OR REPLACE FUNCTION f(k INT) RETURNS INT AS $$ + UPDATE parent SET j = j + 1 WHERE j = k RETURNING j +$$ LANGUAGE SQL; + +# Check 1 level of cascades. +statement error pgcode 0A000 pq: multiple mutations of the same table "child" are not supported unless they all use INSERT without ON CONFLICT; this is to prevent data corruption, see documentation of sql.multiple_modifications_of_table.enabled +WITH x AS (SELECT f(0) AS j), y AS (UPDATE child SET j = 2 WHERE i = 0 RETURNING j) SELECT * FROM x; + +query II rowsort +SELECT i, j FROM child@primary; +---- +0 0 + +query II rowsort +SELECT i, j FROM child@child_j_idx; +---- +0 0 + +statement ok +CREATE FUNCTION f2(old INT, new INT) RETURNS INT AS $$ + UPDATE child SET j = new WHERE i = old RETURNING i +$$ LANGUAGE SQL; + +# Test that we allow mutations in cases were the cascade happens after the +# function call. +# this should not cause corruption, and should be allowed +# (the cascade to cookie will always be strictly after the function call) +statement ok +UPDATE parent SET j = j + 1 WHERE j = f2(0, 2); + +query II rowsort +SELECT i, j FROM child@primary; +---- +0 2 + +query II rowsort +SELECT i, j FROM child@child_j_idx; +---- +0 2 + +statement ok +DROP TABLE IF EXISTS child CASCADE; + +statement ok +TRUNCATE TABLE parent; + +statement ok +CREATE TABLE child (i INT PRIMARY KEY, j INT UNIQUE REFERENCES parent (j) ON UPDATE CASCADE ON DELETE CASCADE, INDEX (j)); + +statement ok +CREATE TABLE grandchild (i INT PRIMARY KEY, j INT REFERENCES child (j) ON UPDATE CASCADE ON DELETE CASCADE, INDEX (j)); + +statement ok +INSERT INTO parent VALUES (0), (2), (4); + +statement ok +INSERT INTO child VALUES (0, 0); + +statement ok +INSERT INTO grandchild VALUES (0,0) + +# Check 2 levels of cascades. +statement error pgcode 0A000 pq: multiple mutations of the same table "grandchild" are not supported unless they all use INSERT without ON CONFLICT; this is to prevent data corruption, see documentation of sql.multiple_modifications_of_table.enabled +WITH x AS (SELECT f(0) AS j), y AS (UPDATE grandchild SET j = 2 WHERE i = 0 RETURNING j) SELECT * FROM x; + +statement ok +DROP TABLE IF EXISTS child CASCADE; + +statement ok +DROP TABLE IF EXISTS grandchild CASCADE; + +statement ok +CREATE TABLE child (i INT PRIMARY KEY, j INT UNIQUE REFERENCES parent (j), k INT UNIQUE REFERENCES parent (j) ON UPDATE RESTRICT, INDEX (j)); + +statement ok +INSERT INTO child VALUES (0,4) + +# Check that we can mutate if there are no actions. +statement ok +WITH x AS (SELECT f(0) AS j), y AS (UPDATE child SET j = 2, k = 2 WHERE i = 0 RETURNING j) SELECT * FROM x; + +query II rowsort +SELECT i, j FROM child@primary; +---- +0 2 + +query II rowsort +SELECT i, j FROM child@child_j_idx; +---- +0 2 + + +subtest end diff --git a/pkg/sql/opt/optbuilder/statement_tree.go b/pkg/sql/opt/optbuilder/statement_tree.go index a8895b06d263..ee63416b5936 100644 --- a/pkg/sql/opt/optbuilder/statement_tree.go +++ b/pkg/sql/opt/optbuilder/statement_tree.go @@ -136,17 +136,39 @@ func (st *statementTree) Pop() { // statement1: UPDATE t1 // ├── statement2: UPDATE t2 // └── statement3: UPDATE t1 -func (st *statementTree) CanMutateTable(tabID cat.StableID, typ mutationType) bool { +// +// isPostStmt indicates that this mutation will be evaluated at the end of the +// statement (e.g., as part of a foreign key constraint). +func (st *statementTree) CanMutateTable( + tabID cat.StableID, typ mutationType, isPostStmt bool, +) bool { if len(st.stmts) == 0 { panic(errors.AssertionFailedf("unexpected empty tree")) } - curr := &st.stmts[len(st.stmts)-1] + if isPostStmt && len(st.stmts) == 1 { + return true + } + offset := 1 + if isPostStmt { + // If this mutation will be evaluated at the end of the current statement, + // the mutation should be added to the parent statement. This is because + // during execution, we step the transaction sequence point before starting + // check evaluations, so all updates in the current statement will be + // visible. + offset = 2 + } + curr := &st.stmts[len(st.stmts)-offset] // Check the children of the current statement for a conflict. - if curr.childrenConflictWithMutation(tabID, typ) { + if !isPostStmt && curr.childrenConflictWithMutation(tabID, typ) { return false } // Check the current statement and all parent statements for a conflict. for i := range st.stmts { + if isPostStmt && i == len(st.stmts)-1 { + // Don't check against the originating statement since we're adding this + // mutation to the parent statement. + break + } n := &st.stmts[i] if n.conflictsWithMutation(tabID, typ) { return false @@ -155,9 +177,17 @@ func (st *statementTree) CanMutateTable(tabID cat.StableID, typ mutationType) bo // The new mutation is valid, so track it. switch typ { case simpleInsert: - curr.simpleInsertTables.Add(int(tabID)) + if isPostStmt { + curr.childrenSimpleInsertTables.Add(int(tabID)) + } else { + curr.simpleInsertTables.Add(int(tabID)) + } case generalMutation: - curr.generalMutationTables.Add(int(tabID)) + if isPostStmt { + curr.childrenGeneralMutationTables.Add(int(tabID)) + } else { + curr.generalMutationTables.Add(int(tabID)) + } } return true } diff --git a/pkg/sql/opt/optbuilder/statement_tree_test.go b/pkg/sql/opt/optbuilder/statement_tree_test.go index b23fa1456adf..5ddd5951f2e0 100644 --- a/pkg/sql/opt/optbuilder/statement_tree_test.go +++ b/pkg/sql/opt/optbuilder/statement_tree_test.go @@ -23,6 +23,7 @@ func TestStatementTree(t *testing.T) { pop mut simple + post t1 t2 fail @@ -374,6 +375,55 @@ func TestStatementTree(t *testing.T) { mut | t1 | fail, }, }, + // 21. + // Push + // CanMutateTable(t1, simpleInsert) + // Push + // CanMutateTable(t2, default) + // CanMutateTable(t1, default, post) FAIL + { + cmds: []cmd{ + push, + mut | t1 | simple, + push, + mut | t2, + mut | t1 | post | fail, + }, + }, + // 22. + // Push + // Push + // CanMutateTable(t1, default) + // CanMutateTable(t2, default, post) + // Pop + // CanMutateTable(t2, simpleInsert) FAIL + { + cmds: []cmd{ + push, + push, + mut | t1, + mut | t2 | post, + pop, + mut | t1 | simple | fail, + }, + }, + // 23. + // Push + // Push + // CanMutateTable(t1, default) + // CanMutateTable(t2, default, post) + // Pop + // Pop + { + cmds: []cmd{ + push, + push, + mut | t1, + mut | t2 | post, + pop, + pop, + }, + }, } for i, tc := range testCases { @@ -400,7 +450,12 @@ func TestStatementTree(t *testing.T) { typ = simpleInsert } - res := mt.CanMutateTable(tabID, typ) + isPost := false + if c&post == post { + isPost = true + } + + res := mt.CanMutateTable(tabID, typ, isPost) expected := c&fail != fail if res != expected { diff --git a/pkg/sql/opt/optbuilder/util.go b/pkg/sql/opt/optbuilder/util.go index ca1794ee8bda..fc589ee96c43 100644 --- a/pkg/sql/opt/optbuilder/util.go +++ b/pkg/sql/opt/optbuilder/util.go @@ -23,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sqlerrors" "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/intsets" "github.com/cockroachdb/errors" ) @@ -513,9 +514,11 @@ func (b *Builder) resolveSchemaForCreate( } func (b *Builder) checkMultipleMutations(tab cat.Table, typ mutationType) { - if !b.stmtTree.CanMutateTable(tab.ID(), typ) && - !multipleModificationsOfTableEnabled.Get(&b.evalCtx.Settings.SV) && - !b.evalCtx.SessionData().MultipleModificationsOfTable { + if multipleModificationsOfTableEnabled.Get(&b.evalCtx.Settings.SV) || + b.evalCtx.SessionData().MultipleModificationsOfTable { + return + } + if !b.stmtTree.CanMutateTable(tab.ID(), typ, false /* isPostStmt */) { panic(pgerror.Newf( pgcode.FeatureNotSupported, "multiple mutations of the same table %q are not supported unless they all "+ @@ -523,6 +526,40 @@ func (b *Builder) checkMultipleMutations(tab cat.Table, typ mutationType) { "documentation of sql.multiple_modifications_of_table.enabled", tab.Name(), )) } + if tab.InboundForeignKeyCount() > 0 { + var visited intsets.Fast + b.checkMultipleMutationsCascade(tab, typ, visited) + } +} + +func (b *Builder) checkMultipleMutationsCascade( + tab cat.Table, typ mutationType, visited intsets.Fast, +) { + // If this table references foreign keys that will also be mutated, then add + // them to the statement tree via a recursive call. We only need to check each + // table once even if there are multiple references to it. + for i := 0; i < tab.InboundForeignKeyCount(); i++ { + fk := tab.InboundForeignKey(i) + if (fk.DeleteReferenceAction() != tree.NoAction && fk.DeleteReferenceAction() != tree.Restrict && typ != simpleInsert) || + (fk.UpdateReferenceAction() != tree.NoAction && fk.UpdateReferenceAction() != tree.Restrict) { + fkTab := resolveTable(b.ctx, b.catalog, fk.OriginTableID()) + // If the origin table is still being added, it will be nil. It's safe to + // do the mutation in this case. + if fkTab == nil || visited.Contains(int(fkTab.ID())) { + continue + } + if !b.stmtTree.CanMutateTable(fkTab.ID(), typ, true /* isPostStmt */) { + panic(pgerror.Newf( + pgcode.FeatureNotSupported, + "multiple mutations of the same table %q are not supported unless they all "+ + "use INSERT without ON CONFLICT; this is to prevent data corruption, see "+ + "documentation of sql.multiple_modifications_of_table.enabled", fkTab.Name(), + )) + } + visited.Add(int(fkTab.ID())) + b.checkMultipleMutationsCascade(fkTab, typ, visited) + } + } } // resolveTableForMutation is a helper method for building mutations. It returns