Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Apply normalization to all dttm columns #25147

Merged
merged 7 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions superset/common/query_context_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def _apply_granularity(
filter
for filter in query_object.filter
if filter["col"] != filter_to_remove
or filter["op"] != "TEMPORAL_RANGE"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reason for this change, i.e., why are we including a filter which was flagged for removal if its operation is not a temporal range?

]

def _apply_filters(self, query_object: QueryObject) -> None:
Expand Down
5 changes: 3 additions & 2 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,11 @@ def _get_timestamp_format(
datasource = self._qc_datasource
labels = tuple(
label
for label in [
for label in {
*get_base_axis_labels(query_object.columns),
*[col for col in query_object.columns or [] if isinstance(col, str)],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there risk that labels will contain duplicated columns? If so, maybe we should put a set(...) inside the tuple(...) to dedupe them?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I don't know if there's such risk but can't hurt to make it bulletproof 👍

query_object.granularity,
]
}
if datasource
# Query datasource didn't support `get_column`
and hasattr(datasource, "get_column")
Expand Down
66 changes: 64 additions & 2 deletions superset/common/query_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,24 @@
# under the License.
from __future__ import annotations

from datetime import datetime
from typing import Any, TYPE_CHECKING

from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
from superset.utils.core import (
apply_max_row_limit,
DatasourceDict,
DatasourceType,
FilterOperator,
QueryObjectFilterClause,
)

if TYPE_CHECKING:
from sqlalchemy.orm import sessionmaker

from superset.connectors.base.models import BaseDatasource
from superset.connectors.base.models import BaseColumn, BaseDatasource
from superset.daos.datasource import DatasourceDAO


Expand Down Expand Up @@ -66,6 +73,10 @@ def create( # pylint: disable=too-many-arguments
)
kwargs["from_dttm"] = from_dttm
kwargs["to_dttm"] = to_dttm
if datasource_model_instance and kwargs.get("filters", []):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if datasource_model_instance and kwargs.get("filters", []):
if datasource_model_instance and kwargs.get("filters"):

kwargs["filters"] = self._process_filters(
datasource_model_instance, kwargs["filters"]
)
return QueryObject(
datasource=datasource_model_instance,
extras=extras,
Expand Down Expand Up @@ -102,3 +113,54 @@ def _process_row_limit(
# light version of the view.utils.core
# import view.utils require application context
# Todo: move it and the view.utils.core to utils package

def _process_filters(
    self, datasource: BaseDatasource, query_filters: list[QueryObjectFilterClause]
) -> list[QueryObjectFilterClause]:
    """Normalize temporal filter values to each column's configured format.

    Filter values arrive from the frontend as epoch timestamps in
    milliseconds. For every filter that is not a TEMPORAL_RANGE, if the
    target column is a datetime column with a ``python_date_format``
    (either set on the column or supplied via the datasource's
    ``db_extra["python_date_format_by_column_name"]`` mapping), the raw
    epoch value is converted accordingly.

    Note: ``query_filters`` is mutated in place and also returned for
    caller convenience.
    """

    def get_dttm_filter_value(
        value: Any, col: BaseColumn, date_format: str
    ) -> int | str:
        """Convert an epoch-ms int into ``date_format``.

        Non-int values are passed through unchanged; the result is cast
        back to int when the column's type is numeric.
        """
        if not isinstance(value, int):
            return value
        if date_format in {"epoch_ms", "epoch_s"}:
            if date_format == "epoch_s":
                # NOTE(review): value is epoch ms from the frontend, so
                # epoch_s presumably needs a /1000 conversion — confirm.
                value = str(value)
            else:
                value = str(value * 1000)
        else:
            dttm = datetime.utcfromtimestamp(value / 1000)
            value = dttm.strftime(date_format)

        if col.type in col.num_types:
            value = int(value)
        return value

    for query_filter in query_filters:
        # Temporal-range filters are handled by the granularity logic.
        if query_filter.get("op") == FilterOperator.TEMPORAL_RANGE:
            continue
        filter_col = query_filter.get("col")
        if not isinstance(filter_col, str):
            # Adhoc/expression columns are not looked up by name.
            continue
        column = datasource.get_column(filter_col)
        if not column:
            continue
        filter_value = query_filter.get("val")

        date_format = column.python_date_format
        if not date_format and datasource.db_extra:
            # Fall back to the per-datasource override mapping.
            date_format = datasource.db_extra.get(
                "python_date_format_by_column_name", {}
            ).get(column.column_name)

        if column.is_dttm and date_format:
            if isinstance(filter_value, list):
                query_filter["val"] = [
                    get_dttm_filter_value(value, column, date_format)
                    for value in filter_value
                ]
            else:
                query_filter["val"] = get_dttm_filter_value(
                    filter_value, column, date_format
                )

    return query_filters
8 changes: 3 additions & 5 deletions tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,11 +836,9 @@ def test_special_chars_in_column_name(app_context, physical_dataset):

query_object = qc.queries[0]
df = qc.get_df_payload(query_object)["df"]
if query_object.datasource.database.backend == "sqlite":
# sqlite returns string as timestamp column
assert df["time column with spaces"][0] == "2002-01-03 00:00:00"
assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00"
else:

# sqlite doesn't have timestamp columns
if query_object.datasource.database.backend != "sqlite":
assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03"
assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01"

Expand Down
90 changes: 89 additions & 1 deletion tests/unit_tests/common/test_query_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,45 @@ def session_factory() -> Mock:
return Mock()


class SimpleDatasetColumn:
    """Minimal stand-in for a dataset column: exposes every key of the
    supplied dict as an instance attribute."""

    def __init__(self, col_params: dict[str, Any]):
        for attr_name, attr_value in col_params.items():
            setattr(self, attr_name, attr_value)


# Names of the two temporal test columns: the first has no
# python_date_format configured, the second formats as a four-digit year.
TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"]
# Column-name -> stub column lookup used by the mocked datasource's
# get_column in the connector_registry fixture.
TEMPORAL_COLUMNS = {
    # Datetime column without a python_date_format: filter values on it
    # should pass through unchanged.
    TEMPORAL_COLUMN_NAMES[0]: SimpleDatasetColumn(
        {
            "column_name": TEMPORAL_COLUMN_NAMES[0],
            "is_dttm": True,
            "python_date_format": None,
            "type": "string",
            "num_types": ["BIGINT"],
        }
    ),
    # Numeric datetime column formatted as "%Y": epoch-ms filter values
    # should be normalized to an integer year.
    TEMPORAL_COLUMN_NAMES[1]: SimpleDatasetColumn(
        {
            "column_name": TEMPORAL_COLUMN_NAMES[1],
            "type": "BIGINT",
            "is_dttm": True,
            "python_date_format": "%Y",
            "num_types": ["BIGINT"],
        }
    ),
}


@fixture
def connector_registry() -> Mock:
    """Mock DatasourceDAO whose datasource resolves temporal columns.

    ``get_column`` returns the matching ``SimpleDatasetColumn`` for the
    known temporal column names and a fresh ``Mock`` otherwise;
    ``db_extra`` is ``None`` so no per-datasource date-format overrides
    apply.
    """
    # Bug fix: stale `return Mock(spec=["get_datasource"])` diff residue
    # made the rest of the fixture unreachable; also configure the
    # datasource via `.return_value` instead of calling the mock.
    datasource_dao_mock = Mock(spec=["get_datasource"])
    datasource = datasource_dao_mock.get_datasource.return_value
    datasource.get_column = Mock(
        # A new Mock is created per unknown column name on each call.
        side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name]
        if col_name in TEMPORAL_COLUMN_NAMES
        else Mock()
    )
    datasource.db_extra = None
    return datasource_dao_mock


def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
Expand Down Expand Up @@ -112,3 +148,55 @@ def test_query_context_null_post_processing_op(
raw_query_context["result_type"], **raw_query_object
)
assert query_object.post_processing == []

def test_query_context_no_python_date_format_filters(
    self,
    query_object_factory: QueryObjectFactory,
    raw_query_context: dict[str, Any],
):
    """A temporal filter on a column without a python_date_format keeps
    its raw epoch-ms value untouched."""
    query = raw_query_context["queries"][0]
    epoch_ms = 315532800000
    query["filters"].append(
        {"col": TEMPORAL_COLUMN_NAMES[0], "op": "==", "val": epoch_ms}
    )
    query_object = query_object_factory.create(
        raw_query_context["result_type"],
        raw_query_context["datasource"],
        **query,
    )
    assert query_object.filter[3]["val"] == epoch_ms

def test_query_context_python_date_format_filters(
    self,
    query_object_factory: QueryObjectFactory,
    raw_query_context: dict[str, Any],
):
    """An epoch-ms filter value on the '%Y'-formatted numeric temporal
    column is normalized to the matching integer year."""
    query = raw_query_context["queries"][0]
    # 315532800000 ms == 1980-01-01T00:00:00Z
    query["filters"].append(
        {"col": TEMPORAL_COLUMN_NAMES[1], "op": "==", "val": 315532800000}
    )
    query_object = query_object_factory.create(
        raw_query_context["result_type"],
        raw_query_context["datasource"],
        **query,
    )
    assert query_object.filter[3]["val"] == 1980

def test_query_context_python_date_format_filters_list_of_values(
    self,
    query_object_factory: QueryObjectFactory,
    raw_query_context: dict[str, Any],
):
    """Each element of a list-valued filter on the '%Y'-formatted column
    is normalized to its integer year."""
    query = raw_query_context["queries"][0]
    # 1980-01-01 and 1990-01-01 as epoch milliseconds.
    epoch_values = [315532800000, 631152000000]
    query["filters"].append(
        {
            "col": TEMPORAL_COLUMN_NAMES[1],
            "op": "==",
            "val": epoch_values,
        }
    )
    query_object = query_object_factory.create(
        raw_query_context["result_type"],
        raw_query_context["datasource"],
        **query,
    )
    assert query_object.filter[3]["val"] == [1980, 1990]
Loading