Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(query): refactor bind set returning functions #16316

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions src/query/ast/src/ast/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,27 +174,6 @@ impl Display for TableRef {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Drive, DriveMut)]
pub struct DictionaryRef {
pub catalog: Option<Identifier>,
pub database: Option<Identifier>,
pub dictionary_name: Identifier,
}

impl Display for DictionaryRef {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
assert!(self.catalog.is_none() || (self.catalog.is_some() && self.database.is_some()));
if let Some(catalog) = &self.catalog {
write!(f, "{}.", catalog)?;
}
if let Some(database) = &self.database {
write!(f, "{}.", database)?;
}
write!(f, "{}", self.dictionary_name)?;
Ok(())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Drive, DriveMut)]
pub struct ColumnRef {
pub database: Option<Identifier>,
Expand Down
10 changes: 0 additions & 10 deletions src/query/ast/src/ast/visitors/walk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,6 @@ pub fn walk_table_ref<'a, V: Visitor<'a>>(visitor: &mut V, table: &'a TableRef)
visitor.visit_identifier(&table.table);
}

pub fn walk_dictionary_ref<'a, V: Visitor<'a>>(visitor: &mut V, dictionary: &'a DictionaryRef) {
if let Some(catalog) = &dictionary.catalog {
visitor.visit_identifier(catalog);
}
if let Some(database) = &dictionary.database {
visitor.visit_identifier(database);
}
visitor.visit_identifier(&dictionary.dictionary_name);
}

pub fn walk_query<'a, V: Visitor<'a>>(visitor: &mut V, query: &'a Query) {
let Query {
with,
Expand Down
8 changes: 4 additions & 4 deletions src/query/sql/src/planner/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use crate::binder::ColumnBindingBuilder;
use crate::normalize_identifier;
use crate::optimizer::SExpr;
use crate::plans::ScalarExpr;
use crate::plans::ScalarItem;
use crate::ColumnSet;
use crate::IndexType;
use crate::MetadataRef;
Expand Down Expand Up @@ -135,8 +136,7 @@ pub struct BindContext {
pub view_info: Option<(String, String)>,

/// Set-returning functions in current context.
/// The key is the `Expr::to_string` of the function.
pub srfs: DashMap<String, ScalarExpr>,
pub srfs: Vec<ScalarItem>,

pub inverted_index_map: Box<IndexMap<IndexType, InvertedIndexInfo>>,

Expand Down Expand Up @@ -177,7 +177,7 @@ impl BindContext {
cte_map_ref: Box::default(),
in_grouping: false,
view_info: None,
srfs: DashMap::new(),
srfs: Vec::new(),
inverted_index_map: Box::default(),
expr_context: ExprContext::default(),
planning_agg_index: false,
Expand All @@ -197,7 +197,7 @@ impl BindContext {
cte_map_ref: parent.cte_map_ref.clone(),
in_grouping: false,
view_info: None,
srfs: DashMap::new(),
srfs: Vec::new(),
inverted_index_map: Box::default(),
expr_context: ExprContext::default(),
planning_agg_index: false,
Expand Down
26 changes: 10 additions & 16 deletions src/query/sql/src/planner/binder/bind_query/bind_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ use derive_visitor::Drive;
use derive_visitor::Visitor;
use log::warn;

use crate::binder::project_set::SrfCollector;
use crate::optimizer::SExpr;
use crate::planner::binder::BindContext;
use crate::planner::binder::Binder;
Expand Down Expand Up @@ -102,27 +101,16 @@ impl Binder {
let new_stmt = rewriter.rewrite(stmt)?;
let stmt = new_stmt.as_ref().unwrap_or(stmt);

// Collect set returning functions
let set_returning_functions = {
let mut collector = SrfCollector::new();
stmt.select_list.iter().for_each(|item| {
if let SelectTarget::AliasedExpr { expr, .. } = item {
collector.visit(expr);
}
});
collector.into_srfs()
};

// Bind set returning functions
s_expr = self.bind_project_set(&mut from_context, &set_returning_functions, s_expr)?;

// Try put window definitions into bind context.
// This operation should be before `normalize_select_list` because window functions can be used in select list.
self.analyze_window_definition(&mut from_context, &stmt.window_list)?;

// Generate a analyzed select list with from context
let mut select_list = self.normalize_select_list(&mut from_context, &stmt.select_list)?;

// analyze set returning functions
self.analyze_project_set_select(&mut from_context, &mut select_list)?;

// This will potentially add some alias group items to `from_context` if find some.
if let Some(group_by) = stmt.group_by.as_ref() {
self.analyze_group_items(&mut from_context, &select_list, group_by)?;
Expand All @@ -140,6 +128,12 @@ impl Binder {
.map(|item| (item.alias.clone(), item.scalar.clone()))
.collect::<Vec<_>>();

let have_srfs = !from_context.srfs.is_empty();
if have_srfs {
// Bind set returning functions first.
s_expr = self.bind_project_set(&mut from_context, s_expr)?;
}

// To support using aliased column in `WHERE` clause,
// we should bind where after `select_list` is rewritten.
let where_scalar = if let Some(expr) = &stmt.selection {
Expand Down Expand Up @@ -179,7 +173,7 @@ impl Binder {
)?;

// After all analysis is done.
if set_returning_functions.is_empty() {
if !have_srfs {
// Ignore SRFs.
self.analyze_lazy_materialization(
&from_context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,19 @@ impl Binder {
lambda: None,
},
};
let srfs = vec![srf.clone()];
let srf_expr = self.bind_project_set(&mut bind_context, &srfs, child)?;

if let Some((_, srf_result)) = bind_context.srfs.remove(&srf.to_string()) {
let select_list = vec![SelectTarget::AliasedExpr {
expr: Box::new(srf.clone()),
alias: None,
}];
let mut select_list =
self.normalize_select_list(&mut bind_context, &select_list)?;
// analyze set returning functions
self.analyze_project_set_select(&mut bind_context, &mut select_list)?;
// bind set returning functions
let srf_expr = self.bind_project_set(&mut bind_context, child)?;

if let Some(item) = select_list.items.pop() {
let srf_result = item.scalar;
let column_binding =
if let ScalarExpr::BoundColumnRef(column_ref) = &srf_result {
column_ref.column.clone()
Expand Down Expand Up @@ -408,7 +417,7 @@ impl Binder {
"The function '{}' is not supported for lateral joins. Lateral joins currently support only Set Returning Functions (SRFs).",
func_name
))
.set_span(*span))
.set_span(*span))
}
}
_ => unreachable!(),
Expand Down
Loading
Loading