diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 331961c7c5e5a..d3d9dab0cb220 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -24,8 +24,10 @@ from datetime import datetime from typing import ( Any, + Callable, Dict, List, + Match, NamedTuple, Optional, Pattern, @@ -142,6 +144,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ] = None # used for user messages, overridden in child classes _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} + column_type_mappings: Tuple[ + Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ..., + ] = () time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False @@ -886,12 +891,18 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]: """ Return a sqlalchemy native column type that corresponds to the column type defined in the data source (return None to use default type inferred by - SQLAlchemy). Needs to be overridden if column requires special handling + SQLAlchemy). Override `_column_type_mappings` for specific needs (see MSSQL for example of NCHAR/NVARCHAR handling). :param type_: Column type returned by inspector :return: SqlAlchemy column type """ + for regex, sqla_type in cls.column_type_mappings: + match = regex.match(type_) + if match: + if callable(sqla_type): + return sqla_type(match) + return sqla_type return None @staticmethod diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index abe1f6c2a2b57..70bd9b5e36b38 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -19,7 +19,7 @@ from datetime import datetime from typing import Any, List, Optional, Tuple, TYPE_CHECKING -from sqlalchemy.types import String, TypeEngine, UnicodeText +from sqlalchemy.types import String, UnicodeText from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.utils import core as utils @@ -73,18 +73,11 @@ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) - column_types = ( - (String(), re.compile(r"^(? Optional[TypeEngine]: - for sqla_type, regex in cls.column_types: - if regex.match(type_): - return sqla_type - return None - @classmethod def extract_error_message(cls, ex: Exception) -> str: if str(ex).startswith("(8155,"): diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 16e6a4c53b2c9..9a53d5d06b22e 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -28,7 +28,7 @@ import pandas as pd import simplejson as json from flask_babel import lazy_gettext as _ -from sqlalchemy import Column, literal_column +from sqlalchemy import Column, literal_column, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.result import RowProxy @@ -40,7 +40,13 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.exceptions import SupersetTemplateException from superset.models.sql_lab import Query -from superset.models.sql_types.presto_sql_types import type_map as presto_type_map +from superset.models.sql_types.presto_sql_types import ( + Array, + Interval, + Map, + Row, + TinyInteger, +) from superset.result_set import destringify from superset.sql_parse import ParsedQuery from superset.utils import core as utils @@ -260,13 +266,16 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch field_info = cls._split_data_type(single_field, r"\s") # check if there is a structural data type within # overall structural data type + column_type = cls.get_sqla_column_type(field_info[1]) + if column_type is None: + raise NotImplementedError( + _("Unknown column type: %(col)s", col=field_info[1]) + ) if field_info[1] == "array" or field_info[1] == "row": stack.append((field_info[0], field_info[1])) full_parent_path = cls._get_full_name(stack) result.append( - cls._create_column_info( - full_parent_path, presto_type_map[field_info[1]]() - ) + cls._create_column_info(full_parent_path, column_type) ) else: # otherwise this field is a basic data type full_parent_path = cls._get_full_name(stack) @@ -274,9 +283,7 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch full_parent_path, field_info[0] ) result.append( - cls._create_column_info( - column_name, presto_type_map[field_info[1]]() - ) + cls._create_column_info(column_name, column_type) ) # If the component type ends with a structural data type, do not pop # the stack. We have run across a structural data type within the @@ -318,6 +325,34 @@ def _show_columns( columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table)) return columns + column_type_mappings = ( + (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()), + (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()), + (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()), + (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()), + (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()), + (re.compile(r"^real.*", re.IGNORECASE), types.Float()), + (re.compile(r"^double.*", re.IGNORECASE), types.Float()), + (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()), + ( + re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE), + lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(), + ), + ( + re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE), + lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(), + ), + (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()), + (re.compile(r"^json.*", re.IGNORECASE), types.JSON()), + (re.compile(r"^date.*", re.IGNORECASE), types.DATE()), + (re.compile(r"^time.*", re.IGNORECASE), types.Time()), + (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()), + (re.compile(r"^interval.*", re.IGNORECASE), Interval()), + (re.compile(r"^array.*", re.IGNORECASE), Array()), + (re.compile(r"^map.*", re.IGNORECASE), Map()), + (re.compile(r"^row.*", re.IGNORECASE), Row()), + ) + @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: Optional[str] @@ -334,28 +369,24 @@ def get_columns( columns = cls._show_columns(inspector, table_name, schema) result: List[Dict[str, Any]] = [] for column in columns: - try: - # parse column if it is a row or array - if is_feature_enabled("PRESTO_EXPAND_DATA") and ( - "array" in column.Type or "row" in column.Type - ): - structural_column_index = len(result) - cls._parse_structural_column(column.Column, column.Type, result) - result[structural_column_index]["nullable"] = getattr( - column, "Null", True - ) - result[structural_column_index]["default"] = None - continue - - # otherwise column is a basic data type - column_type = presto_type_map[column.Type]() - except KeyError: - logger.info( - "Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation - column.Type, column.Column - ) + # parse column if it is a row or array + if is_feature_enabled("PRESTO_EXPAND_DATA") and ( + "array" in column.Type or "row" in column.Type + ): + structural_column_index = len(result) + cls._parse_structural_column(column.Column, column.Type, result) + result[structural_column_index]["nullable"] = getattr( + column, "Null", True + ) + result[structural_column_index]["default"] = None + continue + + # otherwise column is a basic data type + column_type = cls.get_sqla_column_type(column.Type) + if column_type is None: + raise NotImplementedError( + _("Unknown column type: %(col)s", col=column_type) ) - column_type = "OTHER" column_info = cls._create_column_info(column.Column, column_type) column_info["nullable"] = getattr(column, "Null", True) column_info["default"] = None diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py index d6f6d3995f6d7..a314639ca6907 100644 --- a/superset/models/sql_types/presto_sql_types.py +++ b/superset/models/sql_types/presto_sql_types.py @@ -16,7 +16,6 @@ # under the License. from typing import Any, Dict, List, Optional, Type -from sqlalchemy import types from sqlalchemy.sql.sqltypes import Integer from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.sql.visitors import Visitable @@ -92,26 +91,3 @@ def python_type(self) -> Optional[Type[Any]]: @classmethod def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str: return "ROW" - - -type_map = { - "boolean": types.Boolean, - "tinyint": TinyInteger, - "smallint": types.SmallInteger, - "integer": types.Integer, - "bigint": types.BigInteger, - "real": types.Float, - "double": types.Float, - "decimal": types.DECIMAL, - "varchar": types.String, - "char": types.CHAR, - "varbinary": types.VARBINARY, - "JSON": types.JSON, - "date": types.DATE, - "time": types.Time, - "timestamp": types.TIMESTAMP, - "interval": Interval, - "array": Array, - "map": Map, - "row": Row, -} diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py index 9d1d384615275..3a0346bfe2591 100644 --- a/tests/db_engine_specs/presto_tests.py +++ b/tests/db_engine_specs/presto_tests.py @@ -17,6 +17,7 @@ from unittest import mock, skipUnless import pandas as pd +from sqlalchemy import types from sqlalchemy.engine.result import RowProxy from sqlalchemy.sql import select @@ -490,3 +491,23 @@ def test_presto_expand_data_array(self): self.assertEqual(actual_cols, expected_cols) self.assertEqual(actual_data, expected_data) self.assertEqual(actual_expanded_cols, expected_expanded_cols) + + def test_get_sqla_column_type(self): + sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)") + assert isinstance(sqla_type, types.VARCHAR) + assert sqla_type.length == 255 + + sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar") + assert isinstance(sqla_type, types.String) + assert sqla_type.length is None + + sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)") + assert isinstance(sqla_type, types.CHAR) + assert sqla_type.length == 10 + + sqla_type = PrestoEngineSpec.get_sqla_column_type("char") + assert isinstance(sqla_type, types.CHAR) + assert sqla_type.length is None + + sqla_type = PrestoEngineSpec.get_sqla_column_type("integer") + assert isinstance(sqla_type, types.Integer)