Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - optimize filter elements node? #1432

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
43 changes: 27 additions & 16 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,28 +801,34 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> SqlD

def visit_filter_elements_node(self, node: FilterElementsNode) -> SqlDataSet:
"""Generates the query that realizes the behavior of FilterElementsNode."""
from_data_set: SqlDataSet = node.parent_node.accept(self)
output_instance_set = from_data_set.instance_set.transform(FilterElements(node.include_specs))
from_data_set_alias = self._next_unique_table_alias()

# Also, the output columns should always follow the resolver format.
output_instance_set = output_instance_set.transform(ChangeAssociatedColumns(self._column_association_resolver))

# This creates select expressions for all columns referenced in the instance set.
select_columns = output_instance_set.transform(
CreateSelectColumnsForInstances(from_data_set_alias, self._column_association_resolver)
).as_tuple()
parent_data_set: SqlDataSet = node.parent_node.accept(self)
output_instance_set = parent_data_set.instance_set.transform(FilterElements(node.include_specs))
output_column_names = [instance.associated_column.column_name for instance in output_instance_set.as_tuple]
output_select_columns = [
select_column
for select_column in parent_data_set.checked_sql_select_node.select_columns
if select_column.column_alias in output_column_names
]
# where is the limiting factor! need those columns in the select statement
# could have a conditional - use subquery if where clase, else don't
where = parent_data_set.checked_sql_select_node.where

# If distinct values requested, group by all select columns.
group_bys = select_columns if node.distinct else ()
group_bys = tuple(output_select_columns if node.distinct else parent_data_set.checked_sql_select_node.group_bys)
return SqlDataSet(
instance_set=output_instance_set,
# add method from parent node with override params
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=select_columns,
from_source=from_data_set.checked_sql_select_node,
from_source_alias=from_data_set_alias,
description=node.parent_node.description + "\n" + node.description,
select_columns=tuple(output_select_columns),
from_source=parent_data_set.checked_sql_select_node.from_source,
from_source_alias=parent_data_set.checked_sql_select_node.from_source_alias,
join_descs=parent_data_set.checked_sql_select_node.join_descs,
where=where,
group_bys=group_bys,
order_bys=parent_data_set.checked_sql_select_node.order_bys,
limit=parent_data_set.checked_sql_select_node.limit,
distinct=parent_data_set.checked_sql_select_node.distinct,
),
)

Expand Down Expand Up @@ -1499,6 +1505,11 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
from_source=parent_data_set.checked_sql_select_node.from_source,
from_source_alias=parent_alias,
join_descs=parent_data_set.checked_sql_select_node.join_descs + (join_description,),
where=parent_data_set.checked_sql_select_node.where,
group_bys=parent_data_set.checked_sql_select_node.group_bys,
order_bys=parent_data_set.checked_sql_select_node.order_bys,
limit=parent_data_set.checked_sql_select_node.limit,
distinct=parent_data_set.checked_sql_select_node.distinct,
),
)

Expand Down
Loading
Loading