diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index b9c083e42b94d..f355e4ef8cea8 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -323,6 +323,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     # engine-specific type mappings to check prior to the defaults
     column_type_mappings: tuple[ColumnTypeMapping, ...] = ()
 
+    # type-specific functions to mutate values received from the database.
+    # Needed on certain databases that return values in an unexpected format.
+    # Keyed by SQLAlchemy type class (not instance).
+    column_type_mutators: dict[type[TypeEngine], Callable[[Any], Any]] = {}
+
     # Does database support join-free timeslot grouping
     time_groupby_inline = False
     limit_method = LimitMethod.FORCE_LIMIT
@@ -743,7 +748,39 @@ def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ..
         try:
             if cls.limit_method == LimitMethod.FETCH_MANY and limit:
-                return cursor.fetchmany(limit)
-            return cursor.fetchall()
+                data = cursor.fetchmany(limit)
+            else:
+                data = cursor.fetchall()
+            # Engines that register no mutators get the driver's rows untouched.
+            if not cls.column_type_mutators:
+                return data
+
+            description = cursor.description or []
+            # Map each column position to the mutator registered for its type.
+            # Keying by position rather than by name keeps columns that share a
+            # name (e.g. `SELECT a AS x, b AS x`) from shadowing one another.
+            # The first two items in the description row are
+            # the column name and type.
+            column_mutators = {
+                idx: func
+                for idx, row in enumerate(description)
+                if (
+                    func := cls.column_type_mutators.get(
+                        type(cls.get_sqla_column_type(cls.get_datatype(row[1])))
+                    )
+                )
+            }
+            if column_mutators:
+                # Build a new list rather than assigning into `data`, as a
+                # driver may return an immutable sequence of rows.
+                mutated_data = []
+                for row in data:
+                    new_row = list(row)
+                    for col_idx, func in column_mutators.items():
+                        new_row[col_idx] = func(row[col_idx])
+                    mutated_data.append(tuple(new_row))
+                data = mutated_data
+
+            return data
         except Exception as ex:
             raise cls.get_dbapi_mapped_exception(ex) from ex
 
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 4d5604222d74e..687ffee7d565b 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -17,8 +17,9 @@
 import contextlib
 import re
 from datetime import datetime
+from decimal import Decimal
 from re import Pattern
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 from urllib import parse
 
 from flask_babel import gettext as __
@@ -126,6 +127,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
             GenericDataType.STRING,
         ),
     )
+    column_type_mutators: dict[type[types.TypeEngine], Callable[[Any], Any]] = {
+        DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val
+    }
 
     _time_grain_expressions = {
         None: "{col}",
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 89abf2321d79b..ed643470176ec 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -16,6 +16,7 @@
 # under the License.
 
 from datetime import datetime
+from decimal import Decimal
 from typing import Any, Optional
 from unittest.mock import Mock, patch
 
@@ -220,3 +221,47 @@ def test_get_schema_from_engine_params() -> None:
         )
         == "db1"
     )
+
+
+@pytest.mark.parametrize(
+    "data,description,expected_result",
+    [
+        (
+            [("1.23456", "abc")],
+            [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+            [(Decimal("1.23456"), "abc")],
+        ),
+        (
+            [(Decimal("1.23456"), "abc")],
+            [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+            [(Decimal("1.23456"), "abc")],
+        ),
+        (
+            [(None, "abc")],
+            [("dec", "decimal(12,6)"), ("str", "varchar(3)")],
+            [(None, "abc")],
+        ),
+        (
+            [("1.23456", "abc")],
+            [("dec", "varchar(255)"), ("str", "varchar(3)")],
+            [("1.23456", "abc")],
+        ),
+        (
+            [("1.23456", "abc")],
+            [("col", "decimal(12,6)"), ("col", "varchar(3)")],
+            [(Decimal("1.23456"), "abc")],
+        ),
+    ],
+)
+def test_column_type_mutator(
+    data: list[tuple[Any, ...]],
+    description: list[Any],
+    expected_result: list[tuple[Any, ...]],
+) -> None:
+    from superset.db_engine_specs.mysql import MySQLEngineSpec as spec
+
+    mock_cursor = Mock()
+    mock_cursor.fetchall.return_value = data
+    mock_cursor.description = description
+
+    assert spec.fetch_data(mock_cursor) == expected_result