Skip to content

Commit

Permalink
fix: Apply normalization to all dttm columns (#25147)
Browse files Browse the repository at this point in the history
(cherry picked from commit 58fcd29)
  • Loading branch information
kgabryje authored and michael-s-molina committed Oct 9, 2023
1 parent 8b66603 commit dd769eb
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 10 deletions.
1 change: 1 addition & 0 deletions superset/common/query_context_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def _apply_granularity(
filter
for filter in query_object.filter
if filter["col"] != filter_to_remove
or filter["op"] != "TEMPORAL_RANGE"
]

def _apply_filters(self, query_object: QueryObject) -> None:
Expand Down
5 changes: 3 additions & 2 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,11 @@ def _get_timestamp_format(
datasource = self._qc_datasource
labels = tuple(
label
for label in [
for label in {
*get_base_axis_labels(query_object.columns),
*[col for col in query_object.columns or [] if isinstance(col, str)],
query_object.granularity,
]
}
if datasource
# Query datasource didn't support `get_column`
and hasattr(datasource, "get_column")
Expand Down
67 changes: 65 additions & 2 deletions superset/common/query_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,24 @@
# under the License.
from __future__ import annotations

from datetime import datetime
from typing import Any, TYPE_CHECKING

from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
from superset.utils.core import (
apply_max_row_limit,
DatasourceDict,
DatasourceType,
FilterOperator,
QueryObjectFilterClause,
)

if TYPE_CHECKING:
from sqlalchemy.orm import sessionmaker

from superset.connectors.base.models import BaseDatasource
from superset.connectors.base.models import BaseColumn, BaseDatasource
from superset.daos.datasource import DatasourceDAO


Expand Down Expand Up @@ -66,6 +73,10 @@ def create( # pylint: disable=too-many-arguments
)
kwargs["from_dttm"] = from_dttm
kwargs["to_dttm"] = to_dttm
if datasource_model_instance and kwargs.get("filters", []):
kwargs["filters"] = self._process_filters(
datasource_model_instance, kwargs["filters"]
)
return QueryObject(
datasource=datasource_model_instance,
extras=extras,
Expand Down Expand Up @@ -102,3 +113,55 @@ def _process_row_limit(
# light version of the view.utils.core
# import view.utils require application context
# Todo: move it and the view.utils.core to utils package

# pylint: disable=no-self-use
def _process_filters(
self, datasource: BaseDatasource, query_filters: list[QueryObjectFilterClause]
) -> list[QueryObjectFilterClause]:
def get_dttm_filter_value(
value: Any, col: BaseColumn, date_format: str
) -> int | str:
if not isinstance(value, int):
return value
if date_format in {"epoch_ms", "epoch_s"}:
if date_format == "epoch_s":
value = str(value)
else:
value = str(value * 1000)
else:
dttm = datetime.utcfromtimestamp(value / 1000)
value = dttm.strftime(date_format)

if col.type in col.num_types:
value = int(value)
return value

for query_filter in query_filters:
if query_filter.get("op") == FilterOperator.TEMPORAL_RANGE:
continue
filter_col = query_filter.get("col")
if not isinstance(filter_col, str):
continue
column = datasource.get_column(filter_col)
if not column:
continue
filter_value = query_filter.get("val")

date_format = column.python_date_format
if not date_format and datasource.db_extra:
date_format = datasource.db_extra.get(
"python_date_format_by_column_name", {}
).get(column.column_name)

if column.is_dttm and date_format:
if isinstance(filter_value, list):
query_filter["val"] = [
get_dttm_filter_value(value, column, date_format)
for value in filter_value
]
else:
query_filter["val"] = get_dttm_filter_value(
filter_value, column, date_format
)

return query_filters
8 changes: 3 additions & 5 deletions tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,11 +836,9 @@ def test_special_chars_in_column_name(app_context, physical_dataset):

query_object = qc.queries[0]
df = qc.get_df_payload(query_object)["df"]
if query_object.datasource.database.backend == "sqlite":
# sqlite returns string as timestamp column
assert df["time column with spaces"][0] == "2002-01-03 00:00:00"
assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00"
else:

# sqlite doesn't have timestamp columns
if query_object.datasource.database.backend != "sqlite":
assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03"
assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01"

Expand Down
90 changes: 89 additions & 1 deletion tests/unit_tests/common/test_query_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,45 @@ def session_factory() -> Mock:
return Mock()


class SimpleDatasetColumn:
def __init__(self, col_params: dict[str, Any]):
self.__dict__.update(col_params)


TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"]
TEMPORAL_COLUMNS = {
TEMPORAL_COLUMN_NAMES[0]: SimpleDatasetColumn(
{
"column_name": TEMPORAL_COLUMN_NAMES[0],
"is_dttm": True,
"python_date_format": None,
"type": "string",
"num_types": ["BIGINT"],
}
),
TEMPORAL_COLUMN_NAMES[1]: SimpleDatasetColumn(
{
"column_name": TEMPORAL_COLUMN_NAMES[1],
"type": "BIGINT",
"is_dttm": True,
"python_date_format": "%Y",
"num_types": ["BIGINT"],
}
),
}


@fixture
def connector_registry() -> Mock:
return Mock(spec=["get_datasource"])
datasource_dao_mock = Mock(spec=["get_datasource"])
datasource_dao_mock.get_datasource.return_value = Mock()
datasource_dao_mock.get_datasource().get_column = Mock(
side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name]
if col_name in TEMPORAL_COLUMN_NAMES
else Mock()
)
datasource_dao_mock.get_datasource().db_extra = None
return datasource_dao_mock


def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
Expand Down Expand Up @@ -112,3 +148,55 @@ def test_query_context_null_post_processing_op(
raw_query_context["result_type"], **raw_query_object
)
assert query_object.post_processing == []

def test_query_context_no_python_date_format_filters(
self,
query_object_factory: QueryObjectFactory,
raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object["filters"].append(
{"col": TEMPORAL_COLUMN_NAMES[0], "op": "==", "val": 315532800000}
)
query_object = query_object_factory.create(
raw_query_context["result_type"],
raw_query_context["datasource"],
**raw_query_object
)
assert query_object.filter[3]["val"] == 315532800000

def test_query_context_python_date_format_filters(
self,
query_object_factory: QueryObjectFactory,
raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object["filters"].append(
{"col": TEMPORAL_COLUMN_NAMES[1], "op": "==", "val": 315532800000}
)
query_object = query_object_factory.create(
raw_query_context["result_type"],
raw_query_context["datasource"],
**raw_query_object
)
assert query_object.filter[3]["val"] == 1980

def test_query_context_python_date_format_filters_list_of_values(
self,
query_object_factory: QueryObjectFactory,
raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object["filters"].append(
{
"col": TEMPORAL_COLUMN_NAMES[1],
"op": "==",
"val": [315532800000, 631152000000],
}
)
query_object = query_object_factory.create(
raw_query_context["result_type"],
raw_query_context["datasource"],
**raw_query_object
)
assert query_object.filter[3]["val"] == [1980, 1990]

0 comments on commit dd769eb

Please sign in to comment.