diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index d3d9dab0cb220..331961c7c5e5a 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -24,10 +24,8 @@ from datetime import datetime from typing import ( Any, - Callable, Dict, List, - Match, NamedTuple, Optional, Pattern, @@ -144,9 +142,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ] = None # used for user messages, overridden in child classes _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} - column_type_mappings: Tuple[ - Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ..., - ] = () time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False @@ -891,18 +886,12 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]: """ Return a sqlalchemy native column type that corresponds to the column type defined in the data source (return None to use default type inferred by - SQLAlchemy). Override `_column_type_mappings` for specific needs + SQLAlchemy). Needs to be overridden if column requires special handling (see MSSQL for example of NCHAR/NVARCHAR handling). :param type_: Column type returned by inspector :return: SqlAlchemy column type """ - for regex, sqla_type in cls.column_type_mappings: - match = regex.match(type_) - if match: - if callable(sqla_type): - return sqla_type(match) - return sqla_type return None @staticmethod diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 70bd9b5e36b38..abe1f6c2a2b57 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -19,7 +19,7 @@ from datetime import datetime from typing import Any, List, Optional, Tuple, TYPE_CHECKING -from sqlalchemy.types import String, UnicodeText +from sqlalchemy.types import String, TypeEngine, UnicodeText from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.utils import core as utils @@ -73,11 +73,18 @@ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) - column_type_mappings = ( - (re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()), - (re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()), + column_types = ( + (String(), re.compile(r"^(? Optional[TypeEngine]: + for sqla_type, regex in cls.column_types: + if regex.match(type_): + return sqla_type + return None + @classmethod def extract_error_message(cls, ex: Exception) -> str: if str(ex).startswith("(8155,"): diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 9a53d5d06b22e..16e6a4c53b2c9 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -28,7 +28,7 @@ import pandas as pd import simplejson as json from flask_babel import lazy_gettext as _ -from sqlalchemy import Column, literal_column, types +from sqlalchemy import Column, literal_column from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.result import RowProxy @@ -40,13 +40,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.exceptions import SupersetTemplateException from superset.models.sql_lab import Query -from superset.models.sql_types.presto_sql_types import ( - Array, - Interval, - Map, - Row, - TinyInteger, -) +from superset.models.sql_types.presto_sql_types import type_map as presto_type_map from superset.result_set import destringify from superset.sql_parse import ParsedQuery from superset.utils import core as utils @@ -266,16 +260,13 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch field_info = cls._split_data_type(single_field, r"\s") # check if there is a structural data type within # overall structural data type - column_type = cls.get_sqla_column_type(field_info[1]) - if column_type is None: - raise NotImplementedError( - _("Unknown column type: %(col)s", col=field_info[1]) - ) if field_info[1] == "array" or field_info[1] == "row": stack.append((field_info[0], field_info[1])) full_parent_path = cls._get_full_name(stack) result.append( - cls._create_column_info(full_parent_path, column_type) + cls._create_column_info( + full_parent_path, presto_type_map[field_info[1]]() + ) ) else: # otherwise this field is a basic data type full_parent_path = cls._get_full_name(stack) @@ -283,7 +274,9 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch full_parent_path, field_info[0] ) result.append( - cls._create_column_info(column_name, column_type) + cls._create_column_info( + column_name, presto_type_map[field_info[1]]() + ) ) # If the component type ends with a structural data type, do not pop # the stack. We have run across a structural data type within the @@ -325,34 +318,6 @@ def _show_columns( columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table)) return columns - column_type_mappings = ( - (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()), - (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()), - (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()), - (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()), - (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()), - (re.compile(r"^real.*", re.IGNORECASE), types.Float()), - (re.compile(r"^double.*", re.IGNORECASE), types.Float()), - (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()), - ( - re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE), - lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(), - ), - ( - re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE), - lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(), - ), - (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()), - (re.compile(r"^json.*", re.IGNORECASE), types.JSON()), - (re.compile(r"^date.*", re.IGNORECASE), types.DATE()), - (re.compile(r"^time.*", re.IGNORECASE), types.Time()), - (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()), - (re.compile(r"^interval.*", re.IGNORECASE), Interval()), - (re.compile(r"^array.*", re.IGNORECASE), Array()), - (re.compile(r"^map.*", re.IGNORECASE), Map()), - (re.compile(r"^row.*", re.IGNORECASE), Row()), - ) - @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: Optional[str] @@ -369,24 +334,28 @@ def get_columns( columns = cls._show_columns(inspector, table_name, schema) result: List[Dict[str, Any]] = [] for column in columns: - # parse column if it is a row or array - if is_feature_enabled("PRESTO_EXPAND_DATA") and ( - "array" in column.Type or "row" in column.Type - ): - structural_column_index = len(result) - cls._parse_structural_column(column.Column, column.Type, result) - result[structural_column_index]["nullable"] = getattr( - column, "Null", True - ) - result[structural_column_index]["default"] = None - continue - - # otherwise column is a basic data type - column_type = cls.get_sqla_column_type(column.Type) - if column_type is None: - raise NotImplementedError( - _("Unknown column type: %(col)s", col=column_type) + try: + # parse column if it is a row or array + if is_feature_enabled("PRESTO_EXPAND_DATA") and ( + "array" in column.Type or "row" in column.Type + ): + structural_column_index = len(result) + cls._parse_structural_column(column.Column, column.Type, result) + result[structural_column_index]["nullable"] = getattr( + column, "Null", True + ) + result[structural_column_index]["default"] = None + continue + + # otherwise column is a basic data type + column_type = presto_type_map[column.Type]() + except KeyError: + logger.info( + "Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation + column.Type, column.Column + ) ) + column_type = "OTHER" column_info = cls._create_column_info(column.Column, column_type) column_info["nullable"] = getattr(column, "Null", True) column_info["default"] = None diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py index a314639ca6907..d6f6d3995f6d7 100644 --- a/superset/models/sql_types/presto_sql_types.py +++ b/superset/models/sql_types/presto_sql_types.py @@ -16,6 +16,7 @@ # under the License. from typing import Any, Dict, List, Optional, Type +from sqlalchemy import types from sqlalchemy.sql.sqltypes import Integer from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.sql.visitors import Visitable @@ -91,3 +92,26 @@ def python_type(self) -> Optional[Type[Any]]: @classmethod def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: return "ROW" + + +type_map = { + "boolean": types.Boolean, + "tinyint": TinyInteger, + "smallint": types.SmallInteger, + "integer": types.Integer, + "bigint": types.BigInteger, + "real": types.Float, + "double": types.Float, + "decimal": types.DECIMAL, + "varchar": types.String, + "char": types.CHAR, + "varbinary": types.VARBINARY, + "JSON": types.JSON, + "date": types.DATE, + "time": types.Time, + "timestamp": types.TIMESTAMP, + "interval": Interval, + "array": Array, + "map": Map, + "row": Row, +} diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py index 3a0346bfe2591..9d1d384615275 100644 --- a/tests/db_engine_specs/presto_tests.py +++ b/tests/db_engine_specs/presto_tests.py @@ -17,7 +17,6 @@ from unittest import mock, skipUnless import pandas as pd -from sqlalchemy import types from sqlalchemy.engine.result import RowProxy from sqlalchemy.sql import select @@ -491,23 +490,3 @@ def test_presto_expand_data_array(self): self.assertEqual(actual_cols, expected_cols) self.assertEqual(actual_data, expected_data) self.assertEqual(actual_expanded_cols, expected_expanded_cols) - - def test_get_sqla_column_type(self): - sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)") - assert isinstance(sqla_type, types.VARCHAR) - assert sqla_type.length == 255 - - sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar") - assert isinstance(sqla_type, types.String) - assert sqla_type.length is None - - sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)") - assert isinstance(sqla_type, types.CHAR) - assert sqla_type.length == 10 - - sqla_type = PrestoEngineSpec.get_sqla_column_type("char") - assert isinstance(sqla_type, types.CHAR) - assert sqla_type.length is None - - sqla_type = PrestoEngineSpec.get_sqla_column_type("integer") - assert isinstance(sqla_type, types.Integer)