Skip to content

Commit

Permalink
Ensure delete/update is applied on actual table specified in query
Browse files Browse the repository at this point in the history
Optimizer could flip the join and we don't want to apply the delete/update for incorrect table.
  • Loading branch information
Praveen2112 committed Oct 8, 2021
1 parent 85b6419 commit 29cb820
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren;
Expand Down Expand Up @@ -254,7 +255,9 @@ private WriterTarget createWriterTarget(WriterTarget target)
private TableHandle findTableScanHandle(PlanNode node)
{
if (node instanceof TableScanNode) {
return ((TableScanNode) node).getTable();
TableScanNode tableScanNode = (TableScanNode) node;
checkArgument(((TableScanNode) node).isUpdateTarget(), "TableScanNode should be an updatable target");
return tableScanNode.getTable();
}
if (node instanceof FilterNode) {
return findTableScanHandle(((FilterNode) node).getSource());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.TableWriterNode.DeleteTarget;
import io.trino.sql.planner.plan.TableWriterNode.UpdateTarget;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.TopNRankingNode;
import io.trino.sql.planner.plan.TopNRankingNode.RankingType;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.UpdateNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.plan.WindowNode.Specification;
Expand Down Expand Up @@ -694,6 +696,49 @@ private DeleteTarget deleteTarget(SchemaTableName schemaTableName)
schemaTableName);
}

public TableFinishNode tableUpdate(SchemaTableName schemaTableName, PlanNode updateSource, Symbol updateRowId, List<Symbol> columnsToBeUpdated)
{
UpdateTarget updateTarget = updateTarget(
schemaTableName,
columnsToBeUpdated.stream()
.map(Symbol::getName)
.collect(toImmutableList()));
return new TableFinishNode(
idAllocator.getNextId(),
exchange(e -> e
.addSource(new UpdateNode(
idAllocator.getNextId(),
updateSource,
updateTarget,
updateRowId,
ImmutableList.<Symbol>builder()
.addAll(columnsToBeUpdated)
.add(updateRowId)
.build(),
ImmutableList.of(updateRowId)))
.addInputsSet(updateRowId)
.singleDistributionPartitioningScheme(updateRowId)),
updateTarget,
updateRowId,
Optional.empty(),
Optional.empty());
}

private UpdateTarget updateTarget(SchemaTableName schemaTableName, List<String> columnsToBeUpdated)
{
return new UpdateTarget(
Optional.of(new TableHandle(
new CatalogName("testConnector"),
new TestingTableHandle(),
TestingTransactionHandle.create(),
Optional.of(TestingHandle.INSTANCE))),
schemaTableName,
columnsToBeUpdated,
columnsToBeUpdated.stream()
.map(TestingColumnHandle::new)
.collect(toImmutableList()));
}

public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child)
{
return exchange(builder -> builder.type(ExchangeNode.Type.GATHER)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.AbstractMockMetadata;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.PlanNode;
import org.testng.annotations.Test;

import java.util.List;
import java.util.function.Function;

import static io.trino.sql.planner.TypeProvider.empty;
import static io.trino.sql.planner.plan.JoinNode.Type.INNER;
import static io.trino.testing.TestingSession.testSessionBuilder;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class TestBeginTableWrite
{
@Test
public void testValidDelete()
{
assertThatCode(() -> applyOptimization(
p -> p.tableDelete(
new SchemaTableName("sch", "tab"),
p.tableScan(ImmutableList.of(p.symbol("rowId")), true),
p.symbol("rowId", BigintType.BIGINT))))
.doesNotThrowAnyException();
}

@Test
public void testValidUpdate()
{
assertThatCode(() -> applyOptimization(
p -> p.tableUpdate(
new SchemaTableName("sch", "tab"),
p.tableScan(ImmutableList.of(p.symbol("columnToBeUpdated")), true),
p.symbol("rowId", BigintType.BIGINT),
ImmutableList.of(p.symbol("columnToBeUpdated")))))
.doesNotThrowAnyException();
}

@Test
public void testDeleteWithNonDeletableTableScan()
{
assertThatThrownBy(() -> applyOptimization(
p -> p.tableDelete(
new SchemaTableName("sch", "tab"),
p.join(
INNER,
p.tableScan(ImmutableList.of(), false),
p.limit(
1,
p.tableScan(ImmutableList.of(p.symbol("rowId")), true))),
p.symbol("rowId", BigintType.BIGINT))))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("TableScanNode should be an updatable target");
}

@Test
public void testUpdateWithNonUpdatableTableScan()
{
assertThatThrownBy(() -> applyOptimization(
p -> p.tableUpdate(
new SchemaTableName("sch", "tab"),
p.join(
INNER,
p.tableScan(ImmutableList.of(), false),
p.limit(
1,
p.tableScan(ImmutableList.of(p.symbol("columnToBeUpdated"), p.symbol("rowId")), true))),
p.symbol("rowId", BigintType.BIGINT),
ImmutableList.of(p.symbol("columnToBeUpdated")))))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("TableScanNode should be an updatable target");
}

@Test
public void testDeleteWithInvalidNode()
{
assertThatThrownBy(() -> applyOptimization(
p -> p.tableDelete(
new SchemaTableName("sch", "tab"),
p.distinctLimit(
10,
ImmutableList.of(p.symbol("rowId")),
p.tableScan(ImmutableList.of(p.symbol("a")), true)),
p.symbol("rowId", BigintType.BIGINT))))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Invalid descendant for DeleteNode or UpdateNode: io.trino.sql.planner.plan.DistinctLimitNode");
}

@Test
public void testUpdateWithInvalidNode()
{
assertThatThrownBy(() -> applyOptimization(
p -> p.tableUpdate(
new SchemaTableName("sch", "tab"),
p.distinctLimit(
10,
ImmutableList.of(p.symbol("a"), p.symbol("rowId")),
p.tableScan(ImmutableList.of(p.symbol("a")), true)),
p.symbol("rowId", BigintType.BIGINT),
ImmutableList.of(p.symbol("columnToBeUpdated")))))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Invalid descendant for DeleteNode or UpdateNode: io.trino.sql.planner.plan.DistinctLimitNode");
}

private void applyOptimization(Function<PlanBuilder, PlanNode> planProvider)
{
Metadata metadata = new MockMetadata();
new BeginTableWrite(metadata)
.optimize(
planProvider.apply(new PlanBuilder(new PlanNodeIdAllocator(), metadata)),
testSessionBuilder().build(),
empty(),
new SymbolAllocator(),
new PlanNodeIdAllocator(),
WarningCollector.NOOP);
}

private static class MockMetadata
extends AbstractMockMetadata
{
@Override
public TableHandle beginDelete(Session session, TableHandle tableHandle)
{
return tableHandle;
}

@Override
public TableHandle beginUpdate(Session session, TableHandle tableHandle, List<ColumnHandle> updatedColumns)
{
return tableHandle;
}
}
}

0 comments on commit 29cb820

Please sign in to comment.