Skip to content

Commit

Permalink
fix: Remove BASE_AXIS from pre-query (#29084)
Browse files Browse the repository at this point in the history
(cherry picked from commit 17d7e7e)
  • Loading branch information
john-bodley authored and michael-s-molina committed Jun 5, 2024
1 parent c9a6537 commit a608917
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 19 deletions.
Binary file added null_byte.csv
Binary file not shown.
10 changes: 5 additions & 5 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
get_column_names_from_columns,
get_column_names_from_metrics,
get_metric_names,
get_xaxis_label,
get_x_axis_label,
normalize_dttm_col,
TIME_COMPARISON,
)
Expand Down Expand Up @@ -403,7 +403,7 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme
for offset in query_object.time_offsets:
try:
# pylint: disable=line-too-long
# Since the xaxis is also a column name for the time filter, xaxis_label will be set as granularity
# Since the x-axis is also a column name for the time filter, x_axis_label will be set as granularity
# these query objects are equivalent:
# 1) { granularity: 'dttm_col', time_range: '2020 : 2021', time_offsets: ['1 year ago']}
# 2) { columns: [
Expand All @@ -418,9 +418,9 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme
)
query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm)

xaxis_label = get_xaxis_label(query_object.columns)
x_axis_label = get_x_axis_label(query_object.columns)
query_object_clone.granularity = (
query_object_clone.granularity or xaxis_label
query_object_clone.granularity or x_axis_label
)
except ValueError as ex:
raise QueryObjectValidationError(str(ex)) from ex
Expand All @@ -432,7 +432,7 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme
query_object_clone.filter = [
flt
for flt in query_object_clone.filter
if flt.get("col") != xaxis_label
if flt.get("col") != x_axis_label
]

# `offset` is added to the hash function
Expand Down
6 changes: 3 additions & 3 deletions superset/common/query_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
DatasourceDict,
DatasourceType,
FilterOperator,
get_xaxis_label,
get_x_axis_label,
QueryObjectFilterClause,
)

Expand Down Expand Up @@ -122,9 +122,9 @@ def _process_time_range(
# Use the temporal filter as the time range.
# If one of the temporal filters targets the x-axis column, use that filter;
# otherwise use the first temporal filter.
xaxis_label = get_xaxis_label(columns or [])
x_axis_label = get_x_axis_label(columns)
match_flt = [
flt for flt in temporal_flt if flt.get("col") == xaxis_label
flt for flt in temporal_flt if flt.get("col") == x_axis_label
]
if match_flt:
time_range = cast(str, match_flt[0].get("val"))
Expand Down
6 changes: 3 additions & 3 deletions superset/common/utils/time_range_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from superset import app
from superset.common.query_object import QueryObject
from superset.utils.core import FilterOperator, get_xaxis_label
from superset.utils.core import FilterOperator, get_x_axis_label
from superset.utils.date_parser import get_since_until


Expand Down Expand Up @@ -49,7 +49,7 @@ def get_since_until_from_query_object(
"""
This function returns `since` and `until` as a tuple if
1) the time_range is in the query object.
2) the xaxis column is in the columns field
2) the x-axis column is in the columns field
and its corresponding `temporal_range` filter is in the adhoc filters.
:param query_object: a valid query object
:return: since and until by tuple
Expand All @@ -65,7 +65,7 @@ def get_since_until_from_query_object(
for flt in query_object.filter:
if (
flt.get("op") == FilterOperator.TEMPORAL_RANGE.value
and flt.get("col") == get_xaxis_label(query_object.columns)
and flt.get("col") == get_x_axis_label(query_object.columns)
and isinstance(flt.get("val"), str)
):
time_range = cast(str, flt.get("val"))
Expand Down
3 changes: 2 additions & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from superset.utils.core import (
GenericDataType,
get_column_name,
get_non_base_axis_columns,
get_user_id,
is_adhoc_column,
MediumText,
Expand Down Expand Up @@ -2070,7 +2071,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
"filter": filter,
"orderby": orderby,
"extras": extras,
"columns": columns,
"columns": get_non_base_axis_columns(columns),
"order_desc": True,
}

Expand Down
21 changes: 14 additions & 7 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,16 +1177,23 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
)


def is_base_axis(column: Column) -> bool:
    """
    Tell whether a column is flagged as a base axis column.

    :param column: The column to inspect
    :return: True only when the column is an adhoc column whose ``columnType``
        entry equals ``BASE_AXIS``
    """

    # Only adhoc columns are dict-like, so bail out before calling `.get`.
    if not is_adhoc_column(column):
        return False

    return column.get("columnType") == "BASE_AXIS"


def get_base_axis_columns(columns: list[Column] | None) -> list[Column]:
    """
    Collect the columns that are flagged as a base axis.

    :param columns: The columns to scan; ``None`` is treated as an empty list
    :return: The base axis columns, in their original order
    """

    return list(filter(is_base_axis, columns or []))


def get_non_base_axis_columns(columns: list[Column] | None) -> list[Column]:
    """
    Collect the columns that are not flagged as a base axis.

    :param columns: The columns to scan; ``None`` is treated as an empty list
    :return: Every column for which ``is_base_axis`` is false, in original order
    """

    remaining: list[Column] = []

    for candidate in columns or []:
        if is_base_axis(candidate):
            continue

        remaining.append(candidate)

    return remaining


def get_base_axis_labels(columns: list[Column] | None) -> tuple[str, ...]:
    """
    Return the name of every base axis column.

    :param columns: The columns to scan; ``None`` is treated as an empty list
    :return: The result of ``get_column_name`` for each base axis column, in
        original order; an empty tuple when there are none
    """

    # Delegate the BASE_AXIS filtering to the shared helper instead of
    # repeating the predicate inline, leaving a single return path.
    return tuple(get_column_name(column) for column in get_base_axis_columns(columns))


# Return the label of the first base axis column, or None when `columns` is
# None or contains no base axis column (`get_base_axis_labels` then yields an
# empty tuple).
# NOTE(review): two `def` headers appear back to back here, which looks like
# the old name (`get_xaxis_label`) and the new name (`get_x_axis_label`) of a
# rename rendered without diff markers — confirm which one the file keeps.
def get_xaxis_label(columns: list[Column] | None) -> str | None:
def get_x_axis_label(columns: list[Column] | None) -> str | None:
labels = get_base_axis_labels(columns)
return labels[0] if labels else None

Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/db_engine_specs/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def test_where_latest_partition(
PrestoEngineSpec.where_latest_partition( # type: ignore
database=mock.MagicMock(),
table_name="table",
schema="schema",
query=sql.select(text("* FROM table")),
columns=[
{
Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
assert_column_spec,
assert_convert_dttm,
)
from tests.unit_tests.fixtures.common import dttm


def _assert_columns_equal(actual_cols, expected_cols) -> None:
Expand Down Expand Up @@ -575,6 +576,7 @@ def test_where_latest_partition(
TrinoEngineSpec.where_latest_partition( # type: ignore
database=MagicMock(),
table_name="table",
schema="schema",
query=sql.select(text("* FROM table")),
columns=[
{
Expand Down

0 comments on commit a608917

Please sign in to comment.