Skip to content

Commit

Permalink
Fix bug with MetricFlowQueryRequest.sql_optimization_level handling (
Browse files Browse the repository at this point in the history
…#1524)

`MetricFlowQueryRequest` has a `sql_optimization_level` field that is
supposed to control the SQL optimization level. However, the engine did not
pass that value to the execution plan converter, so this PR fixes that issue
by building the converter per request with the requested level.
  • Loading branch information
plypaul authored Nov 14, 2024
1 parent 5c6e448 commit b5b6d3f
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 15 deletions.
13 changes: 7 additions & 6 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,6 @@ def __init__(
column_association_resolver=self._column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)
self._to_execution_plan_converter = DataflowToExecutionPlanConverter(
sql_plan_converter=self._to_sql_query_plan_converter,
sql_plan_renderer=self._sql_client.sql_query_plan_renderer,
sql_client=sql_client,
)
self._executor = SequentialPlanExecutor()

self._query_parser = query_parser or MetricFlowQueryParser(
Expand Down Expand Up @@ -539,7 +534,13 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
)

logger.info(LazyFormat("Building execution plan"))
convert_to_execution_plan_result = self._to_execution_plan_converter.convert_to_execution_plan(dataflow_plan)
_to_execution_plan_converter = DataflowToExecutionPlanConverter(
sql_plan_converter=self._to_sql_query_plan_converter,
sql_plan_renderer=self._sql_client.sql_query_plan_renderer,
sql_client=self._sql_client,
sql_optimization_level=mf_query_request.sql_optimization_level,
)
convert_to_execution_plan_result = _to_execution_plan_converter.convert_to_execution_plan(dataflow_plan)
return MetricFlowExplainResult(
query_spec=query_spec,
dataflow_plan=dataflow_plan,
Expand Down
5 changes: 5 additions & 0 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.sql.render.sql_plan_renderer import SqlPlanRenderResult, SqlQueryPlanRenderer

logger = logging.getLogger(__name__)
Expand All @@ -53,22 +54,26 @@ def __init__(
sql_plan_converter: DataflowToSqlQueryPlanConverter,
sql_plan_renderer: SqlQueryPlanRenderer,
sql_client: SqlClient,
sql_optimization_level: SqlQueryOptimizationLevel,
) -> None:
"""Constructor.
Args:
sql_plan_converter: Converts a dataflow plan node to a SQL query plan
sql_plan_renderer: Converts a SQL query plan to SQL text
sql_client: The client to use for running queries.
sql_optimization_level: The optimization level to use for generating the SQL.
"""
self._sql_plan_converter = sql_plan_converter
self._sql_plan_renderer = sql_plan_renderer
self._sql_client = sql_client
self._optimization_level = sql_optimization_level

def _convert_to_sql_plan(self, node: DataflowPlanNode) -> ConvertToSqlPlanResult:
logger.debug(LazyFormat(lambda: f"Generating SQL query plan from {node.node_id}"))
result = self._sql_plan_converter.convert_to_sql_query_plan(
sql_engine_type=self._sql_client.sql_engine_type,
optimization_level=self._optimization_level,
dataflow_plan_node=node,
)
logger.debug(LazyFormat(lambda: f"Generated SQL query plan is:\n{result.sql_plan.structure_text()}"))
Expand Down
4 changes: 4 additions & 0 deletions metricflow/sql/optimizer/optimization_levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class SqlQueryOptimizationLevel(Enum):
O4 = "O4"
O5 = "O5"

@staticmethod
def default_level() -> SqlQueryOptimizationLevel:
    """Return the enum member designated as the default level (currently O4)."""
    level = SqlQueryOptimizationLevel.O4
    return level


@dataclass(frozen=True)
class SqlGenerationOptionSet:
Expand Down
25 changes: 25 additions & 0 deletions tests_metricflow/integration/test_mf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from _pytest.fixtures import FixtureRequest
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.engine.metricflow_engine import MetricFlowExplainResult, MetricFlowQueryRequest
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from tests_metricflow.integration.conftest import IntegrationTestHelpers
from tests_metricflow.snapshot_utils import assert_object_snapshot_equal

Expand All @@ -16,3 +18,26 @@ def test_list_dimensions( # noqa: D103
obj_id="result0",
obj=sorted([dim.qualified_name for dim in it_helpers.mf_engine.list_dimensions()]),
)


def test_sql_optimization_level(it_helpers: IntegrationTestHelpers) -> None:
    """Verify that the requested SQL optimization level changes the rendered SQL.

    The same query (`bookings` grouped by `metric_time`) is explained once at
    the default level and once at `O0`; the two rendered SQL strings must differ.
    """
    default_level = SqlQueryOptimizationLevel.default_level()
    lowest_level = SqlQueryOptimizationLevel.O0

    # The comparison below is only meaningful when the two levels are distinct.
    assert (
        default_level != lowest_level
    ), "The default optimization level should be different from the lowest level."

    def _explain_at(level: SqlQueryOptimizationLevel) -> MetricFlowExplainResult:
        # Build a fresh request per level so only the optimization level varies.
        request = MetricFlowQueryRequest.create_with_random_request_id(
            metric_names=("bookings",),
            group_by_names=("metric_time",),
            sql_optimization_level=level,
        )
        return it_helpers.mf_engine.explain(request)

    result_at_default = _explain_at(default_level)
    result_at_lowest = _explain_at(lowest_level)

    assert result_at_default.rendered_sql.sql_query != result_at_lowest.rendered_sql.sql_query
13 changes: 4 additions & 9 deletions tests_metricflow/plan_conversion/test_dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow.execution.dataflow_to_execution import DataflowToExecutionPlanConverter
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from tests_metricflow.snapshot_utils import assert_execution_plan_text_equal

Expand All @@ -30,6 +31,7 @@ def make_execution_plan_converter( # noqa: D103
),
sql_plan_renderer=DefaultSqlQueryPlanRenderer(),
sql_client=sql_client,
sql_optimization_level=SqlQueryOptimizationLevel.O4,
)


Expand Down Expand Up @@ -172,17 +174,10 @@ def test_multihop_joined_plan(
)
)

to_execution_plan_converter = DataflowToExecutionPlanConverter(
sql_plan_converter=DataflowToSqlQueryPlanConverter(
column_association_resolver=DunderColumnAssociationResolver(
partitioned_multi_hop_join_semantic_manifest_lookup
),
semantic_manifest_lookup=partitioned_multi_hop_join_semantic_manifest_lookup,
),
sql_plan_renderer=DefaultSqlQueryPlanRenderer(),
to_execution_plan_converter = make_execution_plan_converter(
semantic_manifest_lookup=partitioned_multi_hop_join_semantic_manifest_lookup,
sql_client=sql_client,
)

execution_plan = to_execution_plan_converter.convert_to_execution_plan(dataflow_plan).execution_plan

assert_execution_plan_text_equal(
Expand Down

0 comments on commit b5b6d3f

Please sign in to comment.