Skip to content

Commit

Permalink
fix postgres edge case
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro committed May 29, 2023
1 parent 14b6155 commit f5a2f73
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 14 deletions.
2 changes: 1 addition & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def fetch_data(
for row in cursor.description
if (
func := cls.column_type_mutators.get(
type(cls.get_sqla_column_type(row[1]))
type(cls.get_sqla_column_type(cls.get_datatype(row[1])))
)
)
}
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
),
)
column_type_mutators: dict[types.TypeEngine, Callable[[Any], Any]] = {
DECIMAL: lambda val: float(val) if isinstance(val, (str, Decimal)) else val
DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val
}

_time_grain_expressions = {
Expand Down
9 changes: 1 addition & 8 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
allows_alias_to_source_column = False

column_type_mutators: dict[TypeEngine, Callable[[Any], Any]] = {
DECIMAL: lambda val: float(val) if isinstance(val, (str, Decimal)) else val
DECIMAL: lambda val: Decimal(val) if isinstance(val, str) else val
}

@classmethod
Expand Down Expand Up @@ -275,10 +275,3 @@ def get_dbapi_exception_mapping(cls) -> dict[Type[Exception], Type[Exception]]:
return {
requests_exceptions.ConnectionError: SupersetDBAPIConnectionError,
}

@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
) -> list[tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
return data
4 changes: 2 additions & 2 deletions tests/unit_tests/db_engine_specs/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,12 @@ def test_get_schema_from_engine_params() -> None:
(
[["1.23456", "abc"]],
[("dec", "decimal(12,6)"), ("str", "varchar(3)")],
[(1.23456, "abc")],
[(Decimal("1.23456"), "abc")],
),
(
[[Decimal("1.23456"), "abc"]],
[("dec", "decimal(12,6)"), ("str", "varchar(3)")],
[(1.23456, "abc")],
[(Decimal("1.23456"), "abc")],
),
(
[["1.23456", "abc"]],
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,12 +377,12 @@ def test_handle_cursor_early_cancel(
(
[["1.23456", "abc"]],
[("dec", "decimal(12,6)"), ("str", "varchar(3)")],
[(1.23456, "abc")],
[(Decimal("1.23456"), "abc")],
),
(
[[Decimal("1.23456"), "abc"]],
[("dec", "decimal(12,6)"), ("str", "varchar(3)")],
[(1.23456, "abc")],
[(Decimal("1.23456"), "abc")],
),
(
[["1.23456", "abc"]],
Expand Down

0 comments on commit f5a2f73

Please sign in to comment.