
[WIP][SPARK-31919][SQL] Push down more predicates through Join #28741

Conversation

gengliangwang
Member

What changes were proposed in this pull request?

Currently, in the rule PushPredicateThroughJoin, if a predicate under an Or operator can't be entirely pushed down, it is thrown away.
In fact, predicates under Or operators can be partially pushed down.
For example, say a and b can be pushed into one of the joined tables while c can't; then the predicate
a or (b and c)
can be rewritten as
(a or b) and (a or c)
and we can still push down (a or b).
A disjunctive predicate becomes unpushable only when one of its children contains nothing that can be pushed down.

The common approach is to convert the condition into conjunctive normal form (CNF), so that all pushable predicates can be found by going over the CNF result.
However, the CNF result can be huge, and a recursive implementation can cause stack overflow on complex predicates. There have been PRs for this, such as #10444, #15558, and #28575.
There is also a non-recursive implementation, #28733. It should be workable, but this PR proposes a simpler implementation.

Essentially, we just need to traverse the predicate and extract the convertible sub-predicates, as we did in #24598. There is no need to maintain the CNF result set; see the sketch below.
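
A minimal sketch of that traversal in Scala (illustrative only, not the PR's actual code; the helper name extractConvertible and its direct use of Catalyst's And/Or/AttributeSet are assumptions for this example):

import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, Or}

// Returns the weakest sub-predicate of `cond` that only references attributes in
// `output`, or None if nothing can be pushed down.
def extractConvertible(cond: Expression, output: AttributeSet): Option[Expression] = cond match {
  case And(left, right) =>
    // For AND, either side alone is still a valid (weaker) filter.
    (extractConvertible(left, output), extractConvertible(right, output)) match {
      case (Some(l), Some(r))     => Some(And(l, r))
      case (some @ Some(_), None) => some
      case (None, some @ Some(_)) => some
      case _                      => None
    }
  case Or(left, right) =>
    // For OR, both sides must contribute something, otherwise nothing can be pushed.
    for {
      l <- extractConvertible(left, output)
      r <- extractConvertible(right, output)
    } yield Or(l, r)
  case other if other.deterministic && other.references.subsetOf(output) =>
    Some(other)
  case _ => None
}

For a or (b and c), where a and b only reference output, this returns a or b, matching the rewrite above; the original join condition is kept as-is, so the extracted predicate acts only as an additional filter.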

Why are the changes needed?

Improve query performance. PostgreSQL, Impala, and Hive support a similar feature.

Does this PR introduce any user-facing change?

No

How was this patch tested?

Unit tests and benchmark tests.

SQL            Before this PR   After this PR
TPCDS 5T Q13   84s              21s
TPCDS 5T Q85   66s              34s
TPCH 1T Q19    37s              32s

@gengliangwang
Member Author

I will add more test cases.

@gengliangwang gengliangwang changed the title [SPARK-31919][SQL] Push down more predicates through Join [WIP][SPARK-31919][SQL] Push down more predicates through Join Jun 6, 2020

def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally

val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
Member Author


This method is copied from the original one, except that the join condition is left unchanged when enablePushingExtraPredicates is true.

@SparkQA

SparkQA commented Jun 6, 2020

Test build #123587 has finished for PR 28741 at commit 2ca483e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@maropu
Member

maropu commented Jun 6, 2020

Just to check: it seems this PR doesn't depend on CNF conversion, but could we get exactly the same performance gains as with #28733?

import org.apache.spark.sql.catalyst.rules.Rule

trait PushPredicateThroughJoinBase extends Rule[LogicalPlan] with PredicateHelper {
protected def enablePushingExtraPredicates: Boolean
Member


Why did you split this into two rules via PushPredicateThroughJoinBase? Couldn't you realize this optimization in a single rule?
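
(For reference, a hypothetical sketch of the split being discussed; the object names and flag values below are assumptions, not necessarily the PR's actual code:)

// The base trait carries the shared pushdown logic; the two rule objects only
// differ in whether extra (partially convertible) predicates are pushed.
object PushPredicateThroughJoin extends PushPredicateThroughJoinBase {
  override protected def enablePushingExtraPredicates: Boolean = false
}

object PushExtraPredicateThroughJoin extends PushPredicateThroughJoinBase {
  override protected def enablePushingExtraPredicates: Boolean = true
}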

// if we do not understand what the parent filters are.
//
// Here is an example used to explain the reason.
// Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to
Member


nit: What about the case NOT(a = 2 OR b in ('1'))? It can be transformed into NOT(a = 2) AND NOT(b in ('1')), and then I think it can be partially pushed down.
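
(A tiny illustrative sketch of that De Morgan step in Catalyst terms; this is not code from this PR:)

import org.apache.spark.sql.catalyst.expressions.{And, Expression, Not, Or}

// NOT(a OR b) is equivalent to NOT(a) AND NOT(b), so each resulting conjunct can
// then be considered for pushdown independently.
def pushNotThroughOr(e: Expression): Expression = e match {
  case Not(Or(left, right)) => And(Not(left), Not(right))
  case other                => other
}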

@wangyum
Member

wangyum commented Jun 7, 2020

It seems this solution cannot fully optimize the query.
PostgreSQL:

 Aggregate  (cost=77.33..77.34 rows=1 width=128)
   ->  Nested Loop  (cost=0.00..77.31 rows=1 width=32)
         Join Filter: (store_sales.ss_sold_date_sk = date_dim.d_date_sk)
         ->  Nested Loop  (cost=0.00..67.18 rows=1 width=36)
               Join Filter: ((store_sales.ss_addr_sk = customer_address.ca_address_sk) AND ((((customer_address.ca_state)::text = ANY ('{TX,OH,TX}'::text[])) AND (store_sales.ss_net_profit >= '100'::numeric) AND (store_sales.ss_net_profit <= '200'::numeric)) OR (((customer_address.ca_state)::text = ANY ('{OR,NM,KY}'::text[])) AND (store_sales.ss_net_profit >= '150'::numeric) AND (store_sales.ss_net_profit <= '300'::numeric)) OR (((customer_address.ca_state)::text = ANY ('{VA,TX,MS}'::text[])) AND (store_sales.ss_net_profit >= '50'::numeric) AND (store_sales.ss_net_profit <= '250'::numeric))))
               ->  Nested Loop  (cost=0.00..56.90 rows=1 width=54)
                     Join Filter: ((store_sales.ss_cdemo_sk = customer_demographics.cd_demo_sk) AND ((((customer_demographics.cd_marital_status)::text = 'M'::text) AND ((customer_demographics.cd_education_status)::text = 'Advanced Degree'::text) AND (store_sales.ss_sales_price >= 100.00) AND (store_sales.ss_sales_price <= 150.00) AND (household_demographics.hd_dep_count = 3)) OR (((customer_demographics.cd_marital_status)::text = 'S'::text) AND ((customer_demographics.cd_education_status)::text = 'College'::text) AND (store_sales.ss_sales_price >= 50.00) AND (store_sales.ss_sales_price <= 100.00) AND (household_demographics.hd_dep_count = 1)) OR (((customer_demographics.cd_marital_status)::text = 'W'::text) AND ((customer_demographics.cd_education_status)::text = '2 yr Degree'::text) AND (store_sales.ss_sales_price >= 150.00) AND (store_sales.ss_sales_price <= 200.00) AND (household_demographics.hd_dep_count = 1))))
                     ->  Nested Loop  (cost=0.00..46.10 rows=1 width=76)
                           Join Filter: (store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk)
                           ->  Nested Loop  (cost=0.00..33.61 rows=1 width=76)
                                 Join Filter: (store_sales.ss_store_sk = store.s_store_sk)
                                 ->  Seq Scan on store_sales  (cost=0.00..23.60 rows=1 width=80)
                                       Filter: ((((ss_sales_price >= 100.00) AND (ss_sales_price <= 150.00)) OR ((ss_sales_price >= 50.00) AND (ss_sales_price <= 100.00)) OR ((ss_sales_price >= 150.00) AND (ss_sales_price <= 200.00))) AND (((ss_net_profit >= '100'::numeric) AND (ss_net_profit <= '200'::numeric)) OR ((ss_net_profit >= '150'::numeric) AND (ss_net_profit <= '300'::numeric)) OR ((ss_net_profit >= '50'::numeric) AND (ss_net_profit <= '250'::numeric))))
                                 ->  Seq Scan on store  (cost=0.00..10.00 rows=1 width=4)
                           ->  Seq Scan on household_demographics  (cost=0.00..12.45 rows=3 width=8)
                                 Filter: ((hd_dep_count = 3) OR (hd_dep_count = 1) OR (hd_dep_count = 1))
                     ->  Seq Scan on customer_demographics  (cost=0.00..10.75 rows=1 width=1036)
                            Filter: ((((cd_marital_status)::text = 'M'::text) AND ((cd_education_status)::text = 'Advanced Degree'::text)) OR (((cd_marital_status)::text = 'S'::text) AND ((cd_education_status)::text = 'College'::text)) OR (((cd_marital_status)::text = 'W'::text) AND ((cd_education_status)::text = '2 yr Degree'::text)))
               ->  Seq Scan on customer_address  (cost=0.00..10.24 rows=1 width=520)
                      Filter: (((ca_country)::text = 'United States'::text) AND (((ca_state)::text = ANY ('{TX,OH,TX}'::text[])) OR ((ca_state)::text = ANY ('{OR,NM,KY}'::text[])) OR ((ca_state)::text = ANY ('{VA,TX,MS}'::text[]))))
         ->  Seq Scan on date_dim  (cost=0.00..10.12 rows=1 width=4)
               Filter: (d_year = 2001)
(22 rows)

After this PR (with spark.sql.constraintPropagation.enabled=false set to skip inferring IsNotNull):

*(7) HashAggregate(keys=[], functions=[avg(cast(ss_quantity#10 as bigint)), avg(UnscaledValue(ss_ext_sales_price#15)), avg(UnscaledValue(ss_ext_wholesale_cost#16)), sum(UnscaledValue(ss_ext_wholesale_cost#16))])
+- Exchange SinglePartition, true, [id=#252]
   +- *(6) HashAggregate(keys=[], functions=[partial_avg(cast(ss_quantity#10 as bigint)), partial_avg(UnscaledValue(ss_ext_sales_price#15)), partial_avg(UnscaledValue(ss_ext_wholesale_cost#16)), partial_sum(UnscaledValue(ss_ext_wholesale_cost#16))])
      +- *(6) Project [ss_quantity#10, ss_ext_sales_price#15, ss_ext_wholesale_cost#16]
         +- *(6) BroadcastHashJoin [ss_hdemo_sk#5], [hd_demo_sk#61], Inner, BuildRight, (((((((cd_marital_status#54 = M) AND (cd_education_status#55 = Advanced Degree)) AND (ss_sales_price#13 >= 100.00)) AND (ss_sales_price#13 <= 150.00)) AND (hd_dep_count#64 = 3)) OR (((((cd_marital_status#54 = S) AND (cd_education_status#55 = College)) AND (ss_sales_price#13 >= 50.00)) AND (ss_sales_price#13 <= 100.00)) AND (hd_dep_count#64 = 1))) OR (((((cd_marital_status#54 = W) AND (cd_education_status#55 = 2 yr Degree)) AND (ss_sales_price#13 >= 150.00)) AND (ss_sales_price#13 <= 200.00)) AND (hd_dep_count#64 = 1)))
            :- *(6) Project [ss_hdemo_sk#5, ss_quantity#10, ss_sales_price#13, ss_ext_sales_price#15, ss_ext_wholesale_cost#16, cd_marital_status#54, cd_education_status#55]
            :  +- *(6) BroadcastHashJoin [ss_cdemo_sk#4], [cd_demo_sk#52], Inner, BuildRight, ((((((cd_marital_status#54 = M) AND (cd_education_status#55 = Advanced Degree)) AND (ss_sales_price#13 >= 100.00)) AND (ss_sales_price#13 <= 150.00)) OR ((((cd_marital_status#54 = S) AND (cd_education_status#55 = College)) AND (ss_sales_price#13 >= 50.00)) AND (ss_sales_price#13 <= 100.00))) OR ((((cd_marital_status#54 = W) AND (cd_education_status#55 = 2 yr Degree)) AND (ss_sales_price#13 >= 150.00)) AND (ss_sales_price#13 <= 200.00)))
            :     :- *(6) Project [ss_cdemo_sk#4, ss_hdemo_sk#5, ss_quantity#10, ss_sales_price#13, ss_ext_sales_price#15, ss_ext_wholesale_cost#16]
            :     :  +- *(6) BroadcastHashJoin [ss_sold_date_sk#0], [d_date_sk#79], Inner, BuildRight
            :     :     :- *(6) Project [ss_sold_date_sk#0, ss_cdemo_sk#4, ss_hdemo_sk#5, ss_quantity#10, ss_sales_price#13, ss_ext_sales_price#15, ss_ext_wholesale_cost#16]
            :     :     :  +- *(6) BroadcastHashJoin [ss_addr_sk#6], [ca_address_sk#66], Inner, BuildRight, ((((ca_state#74 IN (TX,OH) AND (ss_net_profit#22 >= 100.00)) AND (ss_net_profit#22 <= 200.00)) OR ((ca_state#74 IN (OR,NM,KY) AND (ss_net_profit#22 >= 150.00)) AND (ss_net_profit#22 <= 300.00))) OR ((ca_state#74 IN (VA,TX,MS) AND (ss_net_profit#22 >= 50.00)) AND (ss_net_profit#22 <= 250.00)))
            :     :     :     :- *(6) Project [ss_sold_date_sk#0, ss_cdemo_sk#4, ss_hdemo_sk#5, ss_addr_sk#6, ss_quantity#10, ss_sales_price#13, ss_ext_sales_price#15, ss_ext_wholesale_cost#16, ss_net_profit#22]
            :     :     :     :  +- *(6) BroadcastHashJoin [ss_store_sk#7], [s_store_sk#23], Inner, BuildRight
            :     :     :     :     :- *(6) Filter ((((ss_net_profit#22 >= 100.00) AND (ss_net_profit#22 <= 200.00)) OR ((ss_net_profit#22 >= 150.00) AND (ss_net_profit#22 <= 300.00))) OR ((ss_net_profit#22 >= 50.00) AND (ss_net_profit#22 <= 250.00)))
            :     :     :     :     :  +- *(6) ColumnarToRow
            :     :     :     :     :     +- FileScan parquet default.store_sales[ss_sold_date_sk#0,ss_cdemo_sk#4,ss_hdemo_sk#5,ss_addr_sk#6,ss_store_sk#7,ss_quantity#10,ss_sales_price#13,ss_ext_sales_price#15,ss_ext_wholesale_cost#16,ss_net_profit#22] Batched: true, DataFilters: [((((ss_net_profit#22 >= 100.00) AND (ss_net_profit#22 <= 200.00)) OR ((ss_net_profit#22 >= 150.0..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/sql/core/spark-warehouse/org.apache.spark...., PartitionFilters: [], PushedFilters: [Or(Or(And(GreaterThanOrEqual(ss_net_profit,100.00),LessThanOrEqual(ss_net_profit,200.00)),And(Gr..., ReadSchema: struct<ss_sold_date_sk:int,ss_cdemo_sk:int,ss_hdemo_sk:int,ss_addr_sk:int,ss_store_sk:int,ss_quan...
            :     :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#213]
            :     :     :     :        +- *(1) ColumnarToRow
            :     :     :     :           +- FileScan parquet default.store[s_store_sk#23] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/sql/core/spark-warehouse/org.apache.spark...., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<s_store_sk:int>
            :     :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#222]
            :     :     :        +- *(2) Project [ca_address_sk#66, ca_state#74]
            :     :     :           +- *(2) Filter ((ca_country#76 = United States) AND ((ca_state#74 IN (TX,OH) OR ca_state#74 IN (OR,NM,KY)) OR ca_state#74 IN (VA,TX,MS)))
            :     :     :              +- *(2) ColumnarToRow
            :     :     :                 +- FileScan parquet default.customer_address[ca_address_sk#66,ca_state#74,ca_country#76] Batched: true, DataFilters: [(ca_country#76 = United States), ((ca_state#74 IN (TX,OH) OR ca_state#74 IN (OR,NM,KY)) OR ca_st..., Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/sql/core/spark-warehouse/org.apache.spark...., PartitionFilters: [], PushedFilters: [EqualTo(ca_country,United States), Or(Or(In(ca_state, [TX,OH]),In(ca_state, [OR,NM,KY])),In(ca_s..., ReadSchema: struct<ca_address_sk:int,ca_state:string,ca_country:string>
            :     :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#231]
            :     :        +- *(3) Project [d_date_sk#79]
            :     :           +- *(3) Filter (d_year#85 = 2001)
            :     :              +- *(3) ColumnarToRow
            :     :                 +- FileScan parquet default.date_dim[d_date_sk#79,d_year#85] Batched: true, DataFilters: [(d_year#85 = 2001)], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/sql/core/spark-warehouse/org.apache.spark...., PartitionFilters: [], PushedFilters: [EqualTo(d_year,2001)], ReadSchema: struct<d_date_sk:int,d_year:int>
            :     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#238]
            :        +- *(4) ColumnarToRow
            :           +- FileScan parquet default.customer_demographics[cd_demo_sk#52,cd_marital_status#54,cd_education_status#55] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/sql/core/spark-warehouse/org.apache.spark...., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<cd_demo_sk:int,cd_marital_status:string,cd_education_status:string>
            +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#246]
               +- *(5) Filter (((hd_dep_count#64 = 3) OR (hd_dep_count#64 = 1)) OR (hd_dep_count#64 = 1))
                  +- *(5) ColumnarToRow
                     +- FileScan parquet default.household_demographics[hd_demo_sk#61,hd_dep_count#64] Batched: true, DataFilters: [(((hd_dep_count#64 = 3) OR (hd_dep_count#64 = 1)) OR (hd_dep_count#64 = 1))], Format: Parquet, Location: InMemoryFileIndex[file:/Users/yumwang/opensource/spark/sql/core/spark-warehouse/org.apache.spark...., PartitionFilters: [], PushedFilters: [Or(Or(EqualTo(hd_dep_count,3),EqualTo(hd_dep_count,1)),EqualTo(hd_dep_count,1))], ReadSchema: struct<hd_demo_sk:int,hd_dep_count:int>

@gengliangwang
Member Author

gengliangwang commented Jun 7, 2020

@wangyum Thanks for pointing it out.
I made a mistake. This solution is not as powerful as CNF conversion.
