-
Notifications
You must be signed in to change notification settings - Fork 14.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(db_engine_specs): improve Presto column type matching #10658
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,8 +24,10 @@ | |
from datetime import datetime | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Dict, | ||
List, | ||
Match, | ||
NamedTuple, | ||
Optional, | ||
Pattern, | ||
|
@@ -142,6 +144,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods | |
] = None # used for user messages, overridden in child classes | ||
_date_trunc_functions: Dict[str, str] = {} | ||
_time_grain_expressions: Dict[Optional[str], str] = {} | ||
column_type_mappings: Tuple[ | ||
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ..., | ||
] = () | ||
time_groupby_inline = False | ||
limit_method = LimitMethod.FORCE_LIMIT | ||
time_secondary_columns = False | ||
|
@@ -886,12 +891,18 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]: | |
""" | ||
Return a sqlalchemy native column type that corresponds to the column type | ||
defined in the data source (return None to use default type inferred by | ||
SQLAlchemy). Needs to be overridden if column requires special handling | ||
SQLAlchemy). Override `_column_type_mappings` for specific needs | ||
(see MSSQL for example of NCHAR/NVARCHAR handling). | ||
|
||
:param type_: Column type returned by inspector | ||
:return: SqlAlchemy column type | ||
""" | ||
for regex, sqla_type in cls.column_type_mappings: | ||
match = regex.match(type_) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @villebro we tested this during a deploy and got errors here because sometimes type_ is None. get_sqla_column_type is used in |
||
if match: | ||
if callable(sqla_type): | ||
return sqla_type(match) | ||
return sqla_type | ||
return None | ||
|
||
@staticmethod | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ | |
import pandas as pd | ||
import simplejson as json | ||
from flask_babel import lazy_gettext as _ | ||
from sqlalchemy import Column, literal_column | ||
from sqlalchemy import Column, literal_column, types | ||
from sqlalchemy.engine.base import Engine | ||
from sqlalchemy.engine.reflection import Inspector | ||
from sqlalchemy.engine.result import RowProxy | ||
|
@@ -40,7 +40,13 @@ | |
from superset.db_engine_specs.base import BaseEngineSpec | ||
from superset.exceptions import SupersetTemplateException | ||
from superset.models.sql_lab import Query | ||
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map | ||
from superset.models.sql_types.presto_sql_types import ( | ||
Array, | ||
Interval, | ||
Map, | ||
Row, | ||
TinyInteger, | ||
) | ||
from superset.result_set import destringify | ||
from superset.sql_parse import ParsedQuery | ||
from superset.utils import core as utils | ||
|
@@ -260,23 +266,24 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch | |
field_info = cls._split_data_type(single_field, r"\s") | ||
# check if there is a structural data type within | ||
# overall structural data type | ||
column_type = cls.get_sqla_column_type(field_info[1]) | ||
if column_type is None: | ||
raise NotImplementedError( | ||
_("Unknown column type: %(col)s", col=field_info[1]) | ||
) | ||
if field_info[1] == "array" or field_info[1] == "row": | ||
stack.append((field_info[0], field_info[1])) | ||
full_parent_path = cls._get_full_name(stack) | ||
result.append( | ||
cls._create_column_info( | ||
full_parent_path, presto_type_map[field_info[1]]() | ||
) | ||
cls._create_column_info(full_parent_path, column_type) | ||
) | ||
else: # otherwise this field is a basic data type | ||
full_parent_path = cls._get_full_name(stack) | ||
column_name = "{}.{}".format( | ||
full_parent_path, field_info[0] | ||
) | ||
result.append( | ||
cls._create_column_info( | ||
column_name, presto_type_map[field_info[1]]() | ||
) | ||
cls._create_column_info(column_name, column_type) | ||
) | ||
# If the component type ends with a structural data type, do not pop | ||
# the stack. We have run across a structural data type within the | ||
|
@@ -318,6 +325,34 @@ def _show_columns( | |
columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table)) | ||
return columns | ||
|
||
column_type_mappings = ( | ||
(re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()), | ||
(re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()), | ||
(re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()), | ||
(re.compile(r"^integer.*", re.IGNORECASE), types.Integer()), | ||
(re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()), | ||
(re.compile(r"^real.*", re.IGNORECASE), types.Float()), | ||
(re.compile(r"^double.*", re.IGNORECASE), types.Float()), | ||
(re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()), | ||
( | ||
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE), | ||
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(), | ||
), | ||
Comment on lines
+337
to
+340
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we're instantiating a |
||
( | ||
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE), | ||
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(), | ||
), | ||
(re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()), | ||
(re.compile(r"^json.*", re.IGNORECASE), types.JSON()), | ||
(re.compile(r"^date.*", re.IGNORECASE), types.DATE()), | ||
(re.compile(r"^time.*", re.IGNORECASE), types.Time()), | ||
(re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()), | ||
(re.compile(r"^interval.*", re.IGNORECASE), Interval()), | ||
(re.compile(r"^array.*", re.IGNORECASE), Array()), | ||
(re.compile(r"^map.*", re.IGNORECASE), Map()), | ||
(re.compile(r"^row.*", re.IGNORECASE), Row()), | ||
) | ||
|
||
@classmethod | ||
def get_columns( | ||
cls, inspector: Inspector, table_name: str, schema: Optional[str] | ||
|
@@ -334,28 +369,24 @@ def get_columns( | |
columns = cls._show_columns(inspector, table_name, schema) | ||
result: List[Dict[str, Any]] = [] | ||
for column in columns: | ||
try: | ||
# parse column if it is a row or array | ||
if is_feature_enabled("PRESTO_EXPAND_DATA") and ( | ||
"array" in column.Type or "row" in column.Type | ||
): | ||
structural_column_index = len(result) | ||
cls._parse_structural_column(column.Column, column.Type, result) | ||
result[structural_column_index]["nullable"] = getattr( | ||
column, "Null", True | ||
) | ||
result[structural_column_index]["default"] = None | ||
continue | ||
|
||
# otherwise column is a basic data type | ||
column_type = presto_type_map[column.Type]() | ||
except KeyError: | ||
logger.info( | ||
"Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation | ||
column.Type, column.Column | ||
) | ||
# parse column if it is a row or array | ||
if is_feature_enabled("PRESTO_EXPAND_DATA") and ( | ||
"array" in column.Type or "row" in column.Type | ||
): | ||
structural_column_index = len(result) | ||
cls._parse_structural_column(column.Column, column.Type, result) | ||
result[structural_column_index]["nullable"] = getattr( | ||
column, "Null", True | ||
) | ||
result[structural_column_index]["default"] = None | ||
continue | ||
|
||
# otherwise column is a basic data type | ||
column_type = cls.get_sqla_column_type(column.Type) | ||
if column_type is None: | ||
raise NotImplementedError( | ||
_("Unknown column type: %(col)s", col=column_type) | ||
) | ||
column_type = "OTHER" | ||
column_info = cls._create_column_info(column.Column, column_type) | ||
column_info["nullable"] = getattr(column, "Null", True) | ||
column_info["default"] = None | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
from unittest import mock, skipUnless | ||
|
||
import pandas as pd | ||
from sqlalchemy import types | ||
from sqlalchemy.engine.result import RowProxy | ||
from sqlalchemy.sql import select | ||
|
||
|
@@ -490,3 +491,23 @@ def test_presto_expand_data_array(self): | |
self.assertEqual(actual_cols, expected_cols) | ||
self.assertEqual(actual_data, expected_data) | ||
self.assertEqual(actual_expanded_cols, expected_expanded_cols) | ||
|
||
def test_get_sqla_column_type(self): | ||
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it would be nice to add some more test cases here especially the ones that are not supported by presto There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this is blocking another fairly critical feature I'll defer the more comprehensive tests to a forthcoming PR and just cover the discovered bug here. I'm hoping I can carve out some time to build a template for testing this kind of functionality, and don't mind doing the initial work on the Presto spec. |
||
assert isinstance(sqla_type, types.VARCHAR) | ||
assert sqla_type.length == 255 | ||
|
||
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar") | ||
assert isinstance(sqla_type, types.String) | ||
Comment on lines
+496
to
+501
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we can see that |
||
assert sqla_type.length is None | ||
|
||
sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)") | ||
assert isinstance(sqla_type, types.CHAR) | ||
assert sqla_type.length == 10 | ||
|
||
sqla_type = PrestoEngineSpec.get_sqla_column_type("char") | ||
assert isinstance(sqla_type, types.CHAR) | ||
assert sqla_type.length is None | ||
|
||
sqla_type = PrestoEngineSpec.get_sqla_column_type("integer") | ||
assert isinstance(sqla_type, types.Integer) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we expect subclasses to override
_column_type_mappings
should we make it public?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, the thought crossed my mind while writing this but there were other similar properties that were private so I went with the flow. I changed this one to public, and I'll update the other ones in a later PR to keep this PR as light as possible.