From bb66be6551da90bb7209c1d73e03ad6a6c807b56 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 23 Aug 2020 10:52:36 +0300 Subject: [PATCH 1/4] fix: improve Presto column type matching --- superset/db_engine_specs/base.py | 6 +- superset/db_engine_specs/mssql.py | 9 +-- superset/db_engine_specs/presto.py | 81 ++++++++++++------- superset/models/sql_types/presto_sql_types.py | 24 ------ 4 files changed, 58 insertions(+), 62 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 331961c7c5e5a..2a73e5798b4ee 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -142,6 +142,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ] = None # used for user messages, overridden in child classes _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} + _column_type_mappings: Tuple[Tuple[TypeEngine, Pattern[str]], ...] = () time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False @@ -886,12 +887,15 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]: """ Return a sqlalchemy native column type that corresponds to the column type defined in the data source (return None to use default type inferred by - SQLAlchemy). Needs to be overridden if column requires special handling + SQLAlchemy). Override `_column_type_mappings` for specific needs (see MSSQL for example of NCHAR/NVARCHAR handling). :param type_: Column type returned by inspector :return: SqlAlchemy column type """ + for sqla_type, regex in cls._column_type_mappings: + if regex.match(type_): + return sqla_type return None @staticmethod diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index abe1f6c2a2b57..6c597db2a9801 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -73,18 +73,11 @@ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) - column_types = ( + _column_type_mappings = ( (String(), re.compile(r"^(? Optional[TypeEngine]: - for sqla_type, regex in cls.column_types: - if regex.match(type_): - return sqla_type - return None - @classmethod def extract_error_message(cls, ex: Exception) -> str: if str(ex).startswith("(8155,"): diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 16e6a4c53b2c9..fca54eefdf6ef 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -28,7 +28,7 @@ import pandas as pd import simplejson as json from flask_babel import lazy_gettext as _ -from sqlalchemy import Column, literal_column +from sqlalchemy import Column, literal_column, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.result import RowProxy @@ -40,7 +40,13 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.exceptions import SupersetTemplateException from superset.models.sql_lab import Query -from superset.models.sql_types.presto_sql_types import type_map as presto_type_map +from superset.models.sql_types.presto_sql_types import ( + Array, + Interval, + Map, + Row, + TinyInteger, +) from superset.result_set import destringify from superset.sql_parse import ParsedQuery from superset.utils import core as utils @@ -260,13 +266,16 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch field_info = cls._split_data_type(single_field, r"\s") # check if there is a structural data type within # overall structural data type + column_type = cls.get_sqla_column_type(field_info[1]) + if column_type is None: + raise NotImplementedError( + "Unknown column type: %s", field_info[1] + ) if field_info[1] == "array" or field_info[1] == "row": stack.append((field_info[0], field_info[1])) full_parent_path = cls._get_full_name(stack) result.append( - cls._create_column_info( - full_parent_path, presto_type_map[field_info[1]]() - ) + cls._create_column_info(full_parent_path, column_type) ) else: # otherwise this field is a basic data type full_parent_path = cls._get_full_name(stack) @@ -274,9 +283,7 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch full_parent_path, field_info[0] ) result.append( - cls._create_column_info( - column_name, presto_type_map[field_info[1]]() - ) + cls._create_column_info(column_name, column_type) ) # If the component type ends with a structural data type, do not pop # the stack. We have run across a structural data type within the @@ -318,6 +325,28 @@ def _show_columns( columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table)) return columns + _column_type_mappings = ( + (types.Boolean(), re.compile(r"^boolean.*", re.IGNORECASE)), + (TinyInteger(), re.compile(r"^tinyint.*", re.IGNORECASE)), + (types.SmallInteger(), re.compile(r"^smallint.*", re.IGNORECASE)), + (types.Integer(), re.compile(r"^integer.*", re.IGNORECASE)), + (types.BigInteger(), re.compile(r"^bigint.*", re.IGNORECASE)), + (types.Float(), re.compile(r"^real.*", re.IGNORECASE)), + (types.Float(), re.compile(r"^double.*", re.IGNORECASE)), + (types.DECIMAL(), re.compile(r"^decimal.*", re.IGNORECASE)), + (types.String(), re.compile(r"^varchar.*", re.IGNORECASE),), + (types.CHAR(), re.compile(r"^char.*", re.IGNORECASE)), + (types.VARBINARY(), re.compile(r"^varbinary.*", re.IGNORECASE)), + (types.JSON(), re.compile(r"^json.*", re.IGNORECASE)), + (types.DATE(), re.compile(r"^date.*", re.IGNORECASE)), + (types.Time(), re.compile(r"^time.*", re.IGNORECASE)), + (types.TIMESTAMP(), re.compile(r"^timestamp.*", re.IGNORECASE)), + (Interval(), re.compile(r"^interval.*", re.IGNORECASE)), + (Array(), re.compile(r"^array.*", re.IGNORECASE),), + (Map(), re.compile(r"^map.*", re.IGNORECASE)), + (Row(), re.compile(r"^row.*", re.IGNORECASE),), + ) + @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: Optional[str] @@ -334,28 +363,22 @@ def get_columns( columns = cls._show_columns(inspector, table_name, schema) result: List[Dict[str, Any]] = [] for column in columns: - try: - # parse column if it is a row or array - if is_feature_enabled("PRESTO_EXPAND_DATA") and ( - "array" in column.Type or "row" in column.Type - ): - structural_column_index = len(result) - cls._parse_structural_column(column.Column, column.Type, result) - result[structural_column_index]["nullable"] = getattr( - column, "Null", True - ) - result[structural_column_index]["default"] = None - continue - - # otherwise column is a basic data type - column_type = presto_type_map[column.Type]() - except KeyError: - logger.info( - "Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation - column.Type, column.Column - ) + # parse column if it is a row or array + if is_feature_enabled("PRESTO_EXPAND_DATA") and ( + "array" in column.Type or "row" in column.Type + ): + structural_column_index = len(result) + cls._parse_structural_column(column.Column, column.Type, result) + result[structural_column_index]["nullable"] = getattr( + column, "Null", True ) - column_type = "OTHER" + result[structural_column_index]["default"] = None + continue + + # otherwise column is a basic data type + column_type = cls.get_sqla_column_type(column.Type) + if column_type is None: + raise NotImplementedError("Unknown column type: %s", column.Type) column_info = cls._create_column_info(column.Column, column_type) column_info["nullable"] = getattr(column, "Null", True) column_info["default"] = None diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py index d6f6d3995f6d7..a314639ca6907 100644 --- a/superset/models/sql_types/presto_sql_types.py +++ b/superset/models/sql_types/presto_sql_types.py @@ -16,7 +16,6 @@ # under the License. from typing import Any, Dict, List, Optional, Type -from sqlalchemy import types from sqlalchemy.sql.sqltypes import Integer from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.sql.visitors import Visitable @@ -92,26 +91,3 @@ def python_type(self) -> Optional[Type[Any]]: @classmethod def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: return "ROW" - - -type_map = { - "boolean": types.Boolean, - "tinyint": TinyInteger, - "smallint": types.SmallInteger, - "integer": types.Integer, - "bigint": types.BigInteger, - "real": types.Float, - "double": types.Float, - "decimal": types.DECIMAL, - "varchar": types.String, - "char": types.CHAR, - "varbinary": types.VARBINARY, - "JSON": types.JSON, - "date": types.DATE, - "time": types.Time, - "timestamp": types.TIMESTAMP, - "interval": Interval, - "array": Array, - "map": Map, - "row": Row, -} From 63308f4f4c49de72a8728762c224e69c4364b296 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Sun, 23 Aug 2020 14:55:55 +0300 Subject: [PATCH 2/4] add optional callback to type map and add tests --- superset/db_engine_specs/base.py | 13 ++++++-- superset/db_engine_specs/mssql.py | 4 +-- superset/db_engine_specs/presto.py | 46 ++++++++++++++++----------- tests/db_engine_specs/presto_tests.py | 13 ++++++++ 4 files changed, 52 insertions(+), 24 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 2a73e5798b4ee..8255d27a98fd7 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -24,6 +24,7 @@ from datetime import datetime from typing import ( Any, + Callable, Dict, List, NamedTuple, @@ -142,7 +143,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ] = None # used for user messages, overridden in child classes _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} - _column_type_mappings: Tuple[Tuple[TypeEngine, Pattern[str]], ...] = () + _column_type_mappings: Tuple[ + Tuple[Pattern[str], Union[TypeEngine, Callable[[re.Match[str]], TypeEngine]]], + ..., + ] = () time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False @@ -893,8 +897,11 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]: :param type_: Column type returned by inspector :return: SqlAlchemy column type """ - for sqla_type, regex in cls._column_type_mappings: - if regex.match(type_): + for regex, sqla_type in cls._column_type_mappings: + match = regex.match(type_) + if match: + if callable(sqla_type): + return sqla_type(match) return sqla_type return None diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 6c597db2a9801..370e1a2ee5e90 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -74,8 +74,8 @@ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: return cls.pyodbc_rows_to_tuples(data) _column_type_mappings = ( - (String(), re.compile(r"^(? Date: Sun, 23 Aug 2020 15:03:05 +0300 Subject: [PATCH 3/4] lint --- superset/db_engine_specs/base.py | 4 ++-- superset/db_engine_specs/mssql.py | 6 +++--- superset/db_engine_specs/presto.py | 16 ++++++++-------- tests/db_engine_specs/presto_tests.py | 9 +++++++++ 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 8255d27a98fd7..5b159273a505d 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -27,6 +27,7 @@ Callable, Dict, List, + Match, NamedTuple, Optional, Pattern, @@ -144,8 +145,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} _column_type_mappings: Tuple[ - Tuple[Pattern[str], Union[TypeEngine, Callable[[re.Match[str]], TypeEngine]]], - ..., + Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ..., ] = () time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 370e1a2ee5e90..c74e4e01a6b9f 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -19,7 +19,7 @@ from datetime import datetime from typing import Any, List, Optional, Tuple, TYPE_CHECKING -from sqlalchemy.types import String, TypeEngine, UnicodeText +from sqlalchemy.types import String, UnicodeText from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.utils import core as utils @@ -74,8 +74,8 @@ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: return cls.pyodbc_rows_to_tuples(data) _column_type_mappings = ( - (re.compile(r"^(? Date: Mon, 24 Aug 2020 18:22:23 +0300 Subject: [PATCH 4/4] change private to public --- superset/db_engine_specs/base.py | 4 ++-- superset/db_engine_specs/mssql.py | 2 +- superset/db_engine_specs/presto.py | 2 +- tests/db_engine_specs/presto_tests.py | 3 +-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 5b159273a505d..d3d9dab0cb220 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -144,7 +144,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ] = None # used for user messages, overridden in child classes _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} - _column_type_mappings: Tuple[ + column_type_mappings: Tuple[ Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ..., ] = () time_groupby_inline = False @@ -897,7 +897,7 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]: :param type_: Column type returned by inspector :return: SqlAlchemy column type """ - for regex, sqla_type in cls._column_type_mappings: + for regex, sqla_type in cls.column_type_mappings: match = regex.match(type_) if match: if callable(sqla_type): diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index c74e4e01a6b9f..70bd9b5e36b38 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -73,7 +73,7 @@ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) - _column_type_mappings = ( + column_type_mappings = ( (re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()), (re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()), ) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index d0a03525257d1..9a53d5d06b22e 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -325,7 +325,7 @@ def _show_columns( columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table)) return columns - _column_type_mappings = ( + column_type_mappings = ( (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()), (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()), (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()), diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py index b0e0e58bb3095..3a0346bfe2591 100644 --- a/tests/db_engine_specs/presto_tests.py +++ b/tests/db_engine_specs/presto_tests.py @@ -492,8 +492,7 @@ def test_presto_expand_data_array(self): self.assertEqual(actual_data, expected_data) self.assertEqual(actual_expanded_cols, expected_expanded_cols) - @staticmethod - def test_get_sqla_column_type(): + def test_get_sqla_column_type(self): sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)") assert isinstance(sqla_type, types.VARCHAR) assert sqla_type.length == 255