Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: window in scalar subquery returns wrong results #16567

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ impl SubqueryRewriter {
Arc::new(left.clone()),
Arc::new(flatten_plan),
);
Ok((s_expr, UnnestResult::SingleJoin { output_index: None }))
Ok((s_expr, UnnestResult::SingleJoin))
}
SubqueryType::Exists | SubqueryType::NotExists => {
if is_conjunctive_predicate {
Expand Down
237 changes: 46 additions & 191 deletions src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use crate::optimizer::RelExpr;
use crate::optimizer::SExpr;
use crate::plans::Aggregate;
use crate::plans::AggregateFunction;
use crate::plans::AggregateMode;
use crate::plans::BoundColumnRef;
use crate::plans::CastExpr;
use crate::plans::ComparisonOp;
Expand Down Expand Up @@ -61,7 +60,7 @@ pub enum UnnestResult {
// Semi/Anti Join, Cross join for EXISTS
SimpleJoin { output_index: Option<IndexType> },
MarkJoin { marker_index: IndexType },
SingleJoin { output_index: Option<IndexType> },
SingleJoin,
}

pub struct FlattenInfo {
Expand Down Expand Up @@ -164,6 +163,17 @@ impl SubqueryRewriter {
Ok(SExpr::create_unary(Arc::new(plan.into()), Arc::new(input)))
}

RelOperator::Sort(mut sort) => {
let mut input = self.rewrite(s_expr.child(0)?)?;
for item in sort.window_partition.iter_mut() {
let res = self.try_rewrite_subquery(&item.scalar, &input, false)?;
input = res.1;
item.scalar = res.0;
}

Ok(SExpr::create_unary(Arc::new(sort.into()), Arc::new(input)))
}

RelOperator::Join(_) | RelOperator::UnionAll(_) | RelOperator::MaterializedCte(_) => {
Ok(SExpr::create_binary(
Arc::new(s_expr.plan().clone()),
Expand All @@ -172,13 +182,12 @@ impl SubqueryRewriter {
))
}

RelOperator::Limit(_)
| RelOperator::Sort(_)
| RelOperator::Udf(_)
| RelOperator::AsyncFunction(_) => Ok(SExpr::create_unary(
Arc::new(s_expr.plan().clone()),
Arc::new(self.rewrite(s_expr.child(0)?)?),
)),
RelOperator::Limit(_) | RelOperator::Udf(_) | RelOperator::AsyncFunction(_) => {
Ok(SExpr::create_unary(
Arc::new(s_expr.plan().clone()),
Arc::new(self.rewrite(s_expr.child(0)?)?),
))
}

RelOperator::DummyTableScan(_)
| RelOperator::Scan(_)
Expand Down Expand Up @@ -294,20 +303,15 @@ impl SubqueryRewriter {
}
let (index, name) = if let UnnestResult::MarkJoin { marker_index } = result {
(marker_index, marker_index.to_string())
} else if let UnnestResult::SingleJoin { output_index } = result {
if let Some(output_idx) = output_index {
// uncorrelated scalar subquery
(output_idx, "_if_scalar_subquery".to_string())
} else {
let mut output_column = subquery.output_column;
if let Some(index) = self.derived_columns.get(&output_column.index) {
output_column.index = *index;
}
(
output_column.index,
format!("scalar_subquery_{:?}", output_column.index),
)
} else if let UnnestResult::SingleJoin = result {
let mut output_column = subquery.output_column;
if let Some(index) = self.derived_columns.get(&output_column.index) {
output_column.index = *index;
}
(
output_column.index,
format!("scalar_subquery_{:?}", output_column.index),
)
} else {
let index = subquery.output_column.index;
(index, format!("subquery_{}", index))
Expand Down Expand Up @@ -423,7 +427,26 @@ impl SubqueryRewriter {
is_conjunctive_predicate: bool,
) -> Result<(SExpr, UnnestResult)> {
match subquery.typ {
SubqueryType::Scalar => self.rewrite_uncorrelated_scalar_subquery(left, subquery),
SubqueryType::Scalar => {
let join_plan = Join {
non_equi_conditions: vec![],
join_type: JoinType::LeftSingle,
marker_index: None,
from_correlated_subquery: false,
equi_conditions: vec![],
need_hold_hash_table: false,
is_lateral: false,
single_to_inner: None,
build_side_cache_info: None,
}
.into();
let s_expr = SExpr::create_binary(
Arc::new(join_plan),
Arc::new(left.clone()),
Arc::new(*subquery.subquery.clone()),
);
Ok((s_expr, UnnestResult::SingleJoin))
}
SubqueryType::Exists | SubqueryType::NotExists => {
let mut subquery_expr = *subquery.subquery.clone();
// Wrap Limit to current subquery
Expand Down Expand Up @@ -617,174 +640,6 @@ impl SubqueryRewriter {
_ => unreachable!(),
}
}

fn rewrite_uncorrelated_scalar_subquery(
&mut self,
left: &SExpr,
subquery: &SubqueryExpr,
) -> Result<(SExpr, UnnestResult)> {
// Use cross join which brings chance to push down filter under cross join.
// Such as `SELECT * FROM c WHERE c_id=(SELECT max(c_id) FROM o WHERE ship='WA');`
// We can push down `c_id = max(c_id)` to cross join then make it as inner join.
let join_plan = Join {
equi_conditions: JoinEquiCondition::new_conditions(vec![], vec![], vec![]),
non_equi_conditions: vec![],
join_type: JoinType::Cross,
marker_index: None,
from_correlated_subquery: false,
need_hold_hash_table: false,
is_lateral: false,
single_to_inner: None,
build_side_cache_info: None,
}
.into();

// For some cases, empty result set will be occur, we should return null instead of empty set.
// So let wrap an expression: `if(count()=0, null, any(subquery.output_column)`
let count_func = ScalarExpr::AggregateFunction(AggregateFunction {
span: subquery.span,
func_name: "count".to_string(),
distinct: false,
params: vec![],
args: vec![ScalarExpr::BoundColumnRef(BoundColumnRef {
span: None,
column: subquery.output_column.clone(),
})],
return_type: Box::new(DataType::Number(NumberDataType::UInt64)),
display_name: "count".to_string(),
});
let any_func = ScalarExpr::AggregateFunction(AggregateFunction {
span: subquery.span,
func_name: "any".to_string(),
distinct: false,
params: vec![],
return_type: subquery.output_column.data_type.clone(),
args: vec![ScalarExpr::BoundColumnRef(BoundColumnRef {
span: None,
column: subquery.output_column.clone(),
})],
display_name: "any".to_string(),
});
// Add `count_func` and `any_func` to metadata
let count_idx = self.metadata.write().add_derived_column(
"_count_scalar_subquery".to_string(),
DataType::Number(NumberDataType::UInt64),
None,
);
let any_idx = self.metadata.write().add_derived_column(
"_any_scalar_subquery".to_string(),
*subquery.output_column.data_type.clone(),
None,
);
// Aggregate operator
let agg = SExpr::create_unary(
Arc::new(
Aggregate {
mode: AggregateMode::Initial,
group_items: vec![],
aggregate_functions: vec![
ScalarItem {
scalar: count_func,
index: count_idx,
},
ScalarItem {
scalar: any_func,
index: any_idx,
},
],
..Default::default()
}
.into(),
),
Arc::new(*subquery.subquery.clone()),
);

let limit = SExpr::create_unary(
Arc::new(
Limit {
limit: Some(1),
offset: 0,
before_exchange: false,
}
.into(),
),
Arc::new(agg),
);

// Wrap expression
let count_col_ref = ScalarExpr::BoundColumnRef(BoundColumnRef {
span: None,
column: ColumnBindingBuilder::new(
"_count_scalar_subquery".to_string(),
count_idx,
Box::new(DataType::Number(NumberDataType::UInt64)),
Visibility::Visible,
)
.build(),
});
let any_col_ref = ScalarExpr::BoundColumnRef(BoundColumnRef {
span: None,
column: ColumnBindingBuilder::new(
"_any_scalar_subquery".to_string(),
any_idx,
subquery.output_column.data_type.clone(),
Visibility::Visible,
)
.build(),
});
let eq_func = ScalarExpr::FunctionCall(FunctionCall {
span: None,
func_name: "eq".to_string(),
params: vec![],
arguments: vec![
count_col_ref,
ScalarExpr::ConstantExpr(ConstantExpr {
span: None,
value: Scalar::Number(NumberScalar::UInt8(0)),
}),
],
});
// If function
let if_func = ScalarExpr::FunctionCall(FunctionCall {
span: None,
func_name: "if".to_string(),
params: vec![],
arguments: vec![
eq_func,
ScalarExpr::ConstantExpr(ConstantExpr {
span: None,
value: Scalar::Null,
}),
any_col_ref,
],
});
let if_func_idx = self.metadata.write().add_derived_column(
"_if_scalar_subquery".to_string(),
*subquery.output_column.data_type.clone(),
None,
);
let scalar_expr = SExpr::create_unary(
Arc::new(
EvalScalar {
items: vec![ScalarItem {
scalar: if_func,
index: if_func_idx,
}],
}
.into(),
),
Arc::new(limit),
);

let s_expr = SExpr::create_binary(
Arc::new(join_plan),
Arc::new(left.clone()),
Arc::new(scalar_expr),
);
Ok((s_expr, UnnestResult::SingleJoin {
output_index: Some(if_func_idx),
}))
}
}

pub fn check_child_expr_in_subquery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,17 @@ pub fn try_push_down_filter_join(s_expr: &SExpr, metadata: MetadataRef) -> Resul
}
JoinPredicate::Other(_) => original_predicates.push(predicate),
JoinPredicate::Both { is_equal_op, .. } => {
if matches!(join.join_type, JoinType::Inner | JoinType::Cross) {
if matches!(join.join_type, JoinType::Inner | JoinType::Cross)
|| join.single_to_inner.is_some()
{
if is_equal_op {
push_down_predicates.push(predicate);
} else {
non_equi_predicates.push(predicate);
}
join.join_type = JoinType::Inner;
if join.join_type == JoinType::Cross {
join.join_type = JoinType::Inner;
}
} else {
original_predicates.push(predicate);
}
Expand Down
4 changes: 4 additions & 0 deletions src/query/sql/src/planner/semantic/distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,16 @@ impl DistinctToGroupBy {
distinct,
name,
args,
window,
..
},
},
alias,
} = &select_list[0]
{
if window.is_some() {
return;
}
let sub_query_name = "_distinct_group_by_subquery";
if ((name.name.to_ascii_lowercase() == "count" && *distinct)
|| name.name.to_ascii_lowercase() == "count_distinct")
Expand Down
Loading
Loading