Skip to content

Commit

Permalink
Refactor tests to use assertJoinConditionallyPushedDown
Browse files Browse the repository at this point in the history
assertJoinConditionallyPushedDown is simpler to use and more generic as
it doesn't get affected by exact plan shape.
  • Loading branch information
vlad-lyutenko authored and hashhar committed Nov 16, 2022
1 parent 603c83d commit fcee24d
Showing 1 changed file with 37 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.BROADCAST;
import static io.trino.sql.planner.OptimizerConfig.JoinDistributionType.PARTITIONED;
import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree;
import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.testing.DataProviders.toDataProvider;
import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder;
Expand Down Expand Up @@ -1129,9 +1128,7 @@ public void testJoinPushdownDisabled()
.build();

assertThat(query(noJoinPushdown, "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey"))
.isNotFullyPushedDown(node(JoinNode.class,
anyTree(node(TableScanNode.class)),
anyTree(node(TableScanNode.class))));
.joinIsNotFullyPushedDown();
}

/**
Expand All @@ -1146,31 +1143,37 @@ public void verifySupportsJoinPushdownDeclaration()
}

assertThat(query(joinPushdownEnabled(getSession()), "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey"))
.isNotFullyPushedDown(
node(JoinNode.class,
anyTree(node(TableScanNode.class)),
anyTree(node(TableScanNode.class))));
.joinIsNotFullyPushedDown();
}

/**
* Verify !SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN declaration is true.
*/
@Test
public void verifySupportsJoinPushdownWithFullJoinDeclaration()
{
if (hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN)) {
// Covered by testJoinPushdown
return;
}

assertThat(query(joinPushdownEnabled(getSession()), "SELECT r.name, n.name FROM nation n FULL JOIN region r ON n.regionkey = r.regionkey"))
.joinIsNotFullyPushedDown();
}

@Test(dataProvider = "joinOperators")
public void testJoinPushdown(JoinOperator joinOperator)
{
PlanMatchPattern joinOverTableScans =
node(JoinNode.class,
anyTree(node(TableScanNode.class)),
anyTree(node(TableScanNode.class)));

Session session = joinPushdownEnabled(getSession());

if (!hasBehavior(SUPPORTS_JOIN_PUSHDOWN)) {
assertThat(query(session, "SELECT r.name, n.name FROM nation n JOIN region r ON n.regionkey = r.regionkey"))
.isNotFullyPushedDown(joinOverTableScans);
.joinIsNotFullyPushedDown();
return;
}

if (joinOperator == FULL_JOIN && !hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN)) {
assertThat(query(session, "SELECT r.name, n.name FROM nation n FULL JOIN region r ON n.regionkey = r.regionkey"))
.isNotFullyPushedDown(joinOverTableScans);
// Covered by verifySupportsJoinPushdownWithFullJoinDeclaration
return;
}

Expand Down Expand Up @@ -1202,16 +1205,14 @@ public void testJoinPushdown(JoinOperator joinOperator)
assertThat(query(session, format("SELECT r.name, n.name FROM nation n %s region r USING(regionkey)", joinOperator))).isFullyPushedDown();

// varchar equality predicate
assertConditionallyPushedDown(
assertJoinConditionallyPushedDown(
session,
format("SELECT n.name, n2.regionkey FROM nation n %s nation n2 ON n.name = n2.name", joinOperator),
hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY),
joinOverTableScans);
assertConditionallyPushedDown(
hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY));
assertJoinConditionallyPushedDown(
session,
format("SELECT n.name, nl.regionkey FROM nation n %s %s nl ON n.name = nl.name", joinOperator, nationLowercaseTable.getName()),
hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY),
joinOverTableScans);
hasBehavior(SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY));

// multiple bigint predicates
assertThat(query(session, format("SELECT n.name, c.name FROM nation n %s customer c ON n.nationkey = c.nationkey and n.regionkey = c.custkey", joinOperator)))
Expand All @@ -1234,20 +1235,18 @@ public void testJoinPushdown(JoinOperator joinOperator)

// inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join
for (String operator : nonEqualities) {
assertConditionallyPushedDown(
assertJoinConditionallyPushedDown(
session,
format("SELECT n.name, c.name FROM nation n %s customer c ON n.nationkey = c.nationkey AND n.regionkey %s c.custkey", joinOperator, operator),
expectJoinPushdown(operator),
joinOverTableScans);
expectJoinPushdown(operator));
}

// varchar inequality along with an equality, which constitutes an equi-condition and allows filter to remain as part of the Join
for (String operator : nonEqualities) {
assertConditionallyPushedDown(
assertJoinConditionallyPushedDown(
session,
format("SELECT n.name, nl.name FROM nation n %s %s nl ON n.regionkey = nl.regionkey AND n.name %s nl.name", joinOperator, nationLowercaseTable.getName(), operator),
expectVarcharJoinPushdown(operator),
joinOverTableScans);
expectVarcharJoinPushdown(operator));
}

// Join over a (double) predicate
Expand All @@ -1258,44 +1257,39 @@ public void testJoinPushdown(JoinOperator joinOperator)
.isFullyPushedDown();

// Join over a varchar equality predicate
assertConditionallyPushedDown(
assertJoinConditionallyPushedDown(
session,
format("SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address = 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " +
"%s nation n ON c.custkey = n.nationkey", joinOperator),
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY),
joinOverTableScans);
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY));

// Join over a varchar inequality predicate
assertConditionallyPushedDown(
assertJoinConditionallyPushedDown(
session,
format("SELECT c.name, n.name FROM (SELECT * FROM customer WHERE address < 'TcGe5gaZNgVePxU5kRrvXBfkasDTea') c " +
"%s nation n ON c.custkey = n.nationkey", joinOperator),
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY),
joinOverTableScans);
hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY));

// join over aggregation
assertConditionallyPushedDown(
assertJoinConditionallyPushedDown(
session,
format("SELECT * FROM (SELECT regionkey rk, count(nationkey) c FROM nation GROUP BY regionkey) n " +
"%s region r ON n.rk = r.regionkey", joinOperator),
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN),
joinOverTableScans);
hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN));

// join over LIMIT
assertConditionallyPushedDown(
assertJoinConditionallyPushedDown(
session,
format("SELECT * FROM (SELECT nationkey FROM nation LIMIT 30) n " +
"%s region r ON n.nationkey = r.regionkey", joinOperator),
hasBehavior(SUPPORTS_LIMIT_PUSHDOWN),
joinOverTableScans);
hasBehavior(SUPPORTS_LIMIT_PUSHDOWN));

// join over TopN
assertConditionallyPushedDown(
assertJoinConditionallyPushedDown(
session,
format("SELECT * FROM (SELECT nationkey FROM nation ORDER BY regionkey LIMIT 5) n " +
"%s region r ON n.nationkey = r.regionkey", joinOperator),
hasBehavior(SUPPORTS_TOPN_PUSHDOWN),
joinOverTableScans);
hasBehavior(SUPPORTS_TOPN_PUSHDOWN));

// join over join
assertThat(query(session, "SELECT * FROM nation n, region r, customer c WHERE n.regionkey = r.regionkey AND r.regionkey = c.custkey"))
Expand Down

0 comments on commit fcee24d

Please sign in to comment.