From 609c3594ef74ad875d1f47e7f2a4c631036503c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nikola=20Gigi=C4=87?= Date: Fri, 12 Mar 2021 09:36:43 +0100 Subject: [PATCH] feat(explore): Postgres datatype conversion (#13294) * test * unnecessary import * fix lint * changes * fix lint * changes * changes * changes * changes * answering comments & changes * answering comments * answering comments * changes * changes * changes * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests --- superset/connectors/sqla/models.py | 28 ++-- superset/db_engine_specs/base.py | 191 +++++++++++++++++++------- superset/db_engine_specs/mssql.py | 13 +- superset/db_engine_specs/mysql.py | 68 ++++++++- superset/db_engine_specs/postgres.py | 54 +++++++- superset/db_engine_specs/presto.py | 138 ++++++++++++++++--- superset/result_set.py | 7 +- superset/utils/core.py | 18 ++- tests/db_engine_specs/mssql_tests.py | 31 +++-- tests/db_engine_specs/mysql_tests.py | 15 +- tests/db_engine_specs/presto_tests.py | 57 ++++---- tests/sqla_models_tests.py | 4 +- 12 files changed, 471 insertions(+), 153 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 9f745f96a1b48..ff4f819f4e9a4 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -69,6 +69,7 @@ from superset.sql_parse import ParsedQuery from superset.typing import Metric, QueryObjectDict from superset.utils import core as utils +from superset.utils.core import GenericDataType config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -186,20 +187,20 @@ def is_numeric(self) -> bool: """ Check if the column has a numeric datatype. 
""" - db_engine_spec = self.table.database.db_engine_spec - return db_engine_spec.is_db_column_type_match( - self.type, utils.GenericDataType.NUMERIC - ) + column_spec = self.table.database.db_engine_spec.get_column_spec(self.type) + if column_spec is None: + return False + return column_spec.generic_type == GenericDataType.NUMERIC @property def is_string(self) -> bool: """ Check if the column has a string datatype. """ - db_engine_spec = self.table.database.db_engine_spec - return db_engine_spec.is_db_column_type_match( - self.type, utils.GenericDataType.STRING - ) + column_spec = self.table.database.db_engine_spec.get_column_spec(self.type) + if column_spec is None: + return False + return column_spec.generic_type == GenericDataType.STRING @property def is_temporal(self) -> bool: @@ -211,10 +212,10 @@ def is_temporal(self) -> bool: """ if self.is_dttm is not None: return self.is_dttm - db_engine_spec = self.table.database.db_engine_spec - return db_engine_spec.is_db_column_type_match( - self.type, utils.GenericDataType.TEMPORAL - ) + column_spec = self.table.database.db_engine_spec.get_column_spec(self.type) + if column_spec is None: + return False + return column_spec.is_dttm def get_sqla_col(self, label: Optional[str] = None) -> Column: label = label or self.column_name @@ -222,7 +223,8 @@ def get_sqla_col(self, label: Optional[str] = None) -> Column: col = literal_column(self.expression) else: db_engine_spec = self.table.database.db_engine_spec - type_ = db_engine_spec.get_sqla_column_type(self.type) + column_spec = db_engine_spec.get_column_spec(self.type) + type_ = column_spec.sqla_type if column_spec else None col = column(self.column_name, type_=type_) col = self.table.make_sqla_column_compatible(col, label) return col diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3052ff7b4da11..e3e83cdd43db1 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -41,7 +41,7 @@ import sqlparse from 
flask import g from flask_babel import gettext as __, lazy_gettext as _ -from sqlalchemy import column, DateTime, select +from sqlalchemy import column, DateTime, select, types from sqlalchemy.engine.base import Engine from sqlalchemy.engine.interfaces import Compiled, Dialect from sqlalchemy.engine.reflection import Inspector @@ -50,13 +50,14 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom -from sqlalchemy.types import TypeEngine +from sqlalchemy.types import String, TypeEngine, UnicodeText from superset import app, security_manager, sql_parse from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query from superset.sql_parse import ParsedQuery, Table from superset.utils import core as utils +from superset.utils.core import ColumnSpec, GenericDataType if TYPE_CHECKING: # prevent circular imports @@ -145,8 +146,87 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} column_type_mappings: Tuple[ - Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ..., - ] = () + Tuple[ + Pattern[str], + Union[TypeEngine, Callable[[Match[str]], TypeEngine]], + GenericDataType, + ], + ..., + ] = ( + ( + re.compile(r"^smallint", re.IGNORECASE), + types.SmallInteger(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^integer", re.IGNORECASE), + types.Integer(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^bigint", re.IGNORECASE), + types.BigInteger(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^decimal", re.IGNORECASE), + types.Numeric(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^numeric", re.IGNORECASE), + types.Numeric(), + GenericDataType.NUMERIC, + ), + (re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,), + ( + 
re.compile(r"^smallserial", re.IGNORECASE), + types.SmallInteger(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^serial", re.IGNORECASE), + types.Integer(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^bigserial", re.IGNORECASE), + types.BigInteger(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^string", re.IGNORECASE), + types.String(), + utils.GenericDataType.STRING, + ), + ( + re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), + UnicodeText(), + utils.GenericDataType.STRING, + ), + ( + re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), + String(), + utils.GenericDataType.STRING, + ), + (re.compile(r"^date", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL,), + ( + re.compile(r"^timestamp", re.IGNORECASE), + types.TIMESTAMP(), + GenericDataType.TEMPORAL, + ), + ( + re.compile(r"^interval", re.IGNORECASE), + types.Interval(), + GenericDataType.TEMPORAL, + ), + (re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,), + ( + re.compile(r"^boolean", re.IGNORECASE), + types.Boolean(), + GenericDataType.BOOLEAN, + ), + ) time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT time_secondary_columns = False @@ -160,25 +240,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods try_remove_schema_from_table_name = True # pylint: disable=invalid-name run_multiple_statements_as_one = False - # default matching patterns to convert database specific column types to - # more generic types - db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[str], ...]] = { - utils.GenericDataType.NUMERIC: ( - re.compile(r"BIT", re.IGNORECASE), - re.compile( - r".*(DOUBLE|FLOAT|INT|NUMBER|REAL|NUMERIC|DECIMAL|MONEY).*", - re.IGNORECASE, - ), - re.compile(r".*LONG$", re.IGNORECASE), - ), - utils.GenericDataType.STRING: ( - re.compile(r".*(CHAR|STRING|TEXT).*", re.IGNORECASE), - ), - utils.GenericDataType.TEMPORAL: ( - re.compile(r".*(DATE|TIME).*", re.IGNORECASE), - ), - } - @classmethod def 
get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: """ @@ -208,25 +269,6 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: return exception return new_exception(str(exception)) - @classmethod - def is_db_column_type_match( - cls, db_column_type: Optional[str], target_column_type: utils.GenericDataType - ) -> bool: - """ - Check if a column type satisfies a pattern in a collection of regexes found in - `db_column_types`. For example, if `db_column_type == "NVARCHAR"`, - it would be a match for "STRING" due to being a match for the regex ".*CHAR.*". - - :param db_column_type: Column type to evaluate - :param target_column_type: The target type to evaluate for - :return: `True` if a `db_column_type` matches any pattern corresponding to - `target_column_type` - """ - if not db_column_type: - return False - patterns = cls.db_column_types[target_column_type] - return any(pattern.match(db_column_type) for pattern in patterns) - @classmethod def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: return False @@ -967,24 +1009,35 @@ def make_label_compatible(cls, label: str) -> Union[str, quoted_name]: return label_mutated @classmethod - def get_sqla_column_type(cls, type_: Optional[str]) -> Optional[TypeEngine]: + def get_sqla_column_type( + cls, + column_type: Optional[str], + column_type_mappings: Tuple[ + Tuple[ + Pattern[str], + Union[TypeEngine, Callable[[Match[str]], TypeEngine]], + GenericDataType, + ], + ..., + ] = column_type_mappings, + ) -> Union[Tuple[TypeEngine, GenericDataType], None]: """ Return a sqlalchemy native column type that corresponds to the column type defined in the data source (return None to use default type inferred by SQLAlchemy). Override `column_type_mappings` for specific needs (see MSSQL for example of NCHAR/NVARCHAR handling). 
- :param type_: Column type returned by inspector + :param column_type: Column type returned by inspector :return: SqlAlchemy column type """ - if not type_: + if not column_type: return None - for regex, sqla_type in cls.column_type_mappings: - match = regex.match(type_) + for regex, sqla_type, generic_type in column_type_mappings: + match = regex.match(column_type) if match: if callable(sqla_type): - return sqla_type(match) - return sqla_type + return sqla_type(match), generic_type + return sqla_type, generic_type return None @staticmethod @@ -1101,3 +1154,43 @@ def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: or parsed_query.is_explain() or parsed_query.is_show() ) + + @classmethod + def get_column_spec( + cls, + native_type: Optional[str], + source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, + column_type_mappings: Tuple[ + Tuple[ + Pattern[str], + Union[TypeEngine, Callable[[Match[str]], TypeEngine]], + GenericDataType, + ], + ..., + ] = column_type_mappings, + ) -> Union[ColumnSpec, None]: + """ + Converts native database type to sqlalchemy column type. 
+ :param native_type: Native database type + :param source: Type coming from the database table or cursor description + :return: ColumnSpec object + """ + column_type = None + + if ( + cls.get_sqla_column_type( + native_type, column_type_mappings=column_type_mappings + ) + is not None + ): + column_type, generic_type = cls.get_sqla_column_type( # type: ignore + native_type, column_type_mappings=column_type_mappings + ) + is_dttm = generic_type == GenericDataType.TEMPORAL + + if column_type: + return ColumnSpec( + sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm + ) + + return None diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index b105c709d5518..67b9ec1b62dee 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -15,18 +15,12 @@ # specific language governing permissions and limitations # under the License. import logging -import re from datetime import datetime -from typing import Any, List, Optional, Tuple, TYPE_CHECKING - -from sqlalchemy.types import String, UnicodeText +from typing import Any, List, Optional, Tuple from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod from superset.utils import core as utils -if TYPE_CHECKING: - from superset.models.core import Database - logger = logging.getLogger(__name__) @@ -77,11 +71,6 @@ def fetch_data( # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) - column_type_mappings = ( - (re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()), - (re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()), - ) - @classmethod def extract_error_message(cls, ex: Exception) -> str: if str(ex).startswith("(8155,"): diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 481a7693762c9..3cb35e308c05e 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -14,14 +14,29 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +import re from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Match, Optional, Pattern, Tuple, Union from urllib import parse +from sqlalchemy.dialects.mysql import ( + BIT, + DECIMAL, + DOUBLE, + FLOAT, + INTEGER, + LONGTEXT, + MEDIUMINT, + MEDIUMTEXT, + TINYINT, + TINYTEXT, +) from sqlalchemy.engine.url import URL +from sqlalchemy.types import TypeEngine from superset.db_engine_specs.base import BaseEngineSpec from superset.utils import core as utils +from superset.utils.core import ColumnSpec, GenericDataType class MySQLEngineSpec(BaseEngineSpec): @@ -29,6 +44,34 @@ class MySQLEngineSpec(BaseEngineSpec): engine_name = "MySQL" max_column_name_length = 64 + column_type_mappings: Tuple[ + Tuple[ + Pattern[str], + Union[TypeEngine, Callable[[Match[str]], TypeEngine]], + GenericDataType, + ], + ..., + ] = ( + (re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,), + (re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,), + ( + re.compile(r"^mediumint", re.IGNORECASE), + MEDIUMINT(), + GenericDataType.NUMERIC, + ), + (re.compile(r"^decimal", re.IGNORECASE), DECIMAL(), GenericDataType.NUMERIC,), + (re.compile(r"^float", re.IGNORECASE), FLOAT(), GenericDataType.NUMERIC,), + (re.compile(r"^double", re.IGNORECASE), DOUBLE(), GenericDataType.NUMERIC,), + (re.compile(r"^bit", re.IGNORECASE), BIT(), GenericDataType.NUMERIC,), + (re.compile(r"^tinytext", re.IGNORECASE), TINYTEXT(), GenericDataType.STRING,), + ( + re.compile(r"^mediumtext", re.IGNORECASE), + MEDIUMTEXT(), + GenericDataType.STRING, + ), + (re.compile(r"^longtext", re.IGNORECASE), LONGTEXT(), GenericDataType.STRING,), + ) + _time_grain_expressions = { None: "{col}", "PT1S": "DATE_ADD(DATE({col}), " @@ -98,3 +141,26 @@ def _extract_error_message(cls, ex: Exception) -> str: except (AttributeError, KeyError): 
pass return message + + @classmethod + def get_column_spec( # type: ignore + cls, + native_type: Optional[str], + source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, + column_type_mappings: Tuple[ + Tuple[ + Pattern[str], + Union[TypeEngine, Callable[[Match[str]], TypeEngine]], + GenericDataType, + ], + ..., + ] = column_type_mappings, + ) -> Union[ColumnSpec, None]: + + column_spec = super().get_column_spec(native_type) + if column_spec: + return column_spec + + return super().get_column_spec( + native_type, column_type_mappings=column_type_mappings + ) diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index a63ffdd8b707e..38c4a6dca4df1 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -18,14 +18,28 @@ import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + List, + Match, + Optional, + Pattern, + Tuple, + TYPE_CHECKING, + Union, +) from pytz import _FixedOffset # type: ignore +from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON from sqlalchemy.dialects.postgresql.base import PGInspector +from sqlalchemy.types import String, TypeEngine from superset.db_engine_specs.base import BaseEngineSpec from superset.exceptions import SupersetException from superset.utils import core as utils +from superset.utils.core import ColumnSpec, GenericDataType if TYPE_CHECKING: from superset.models.core import Database # pragma: no cover @@ -77,6 +91,21 @@ class PostgresEngineSpec(PostgresBaseEngineSpec): max_column_name_length = 63 try_remove_schema_from_table_name = False + column_type_mappings = ( + ( + re.compile(r"^double precision", re.IGNORECASE), + DOUBLE_PRECISION(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^array.*", re.IGNORECASE), + lambda match: ARRAY(int(match[2])) if match[2] else String(), + 
utils.GenericDataType.STRING, + ), + (re.compile(r"^json.*", re.IGNORECASE), JSON(), utils.GenericDataType.STRING,), + (re.compile(r"^enum.*", re.IGNORECASE), ENUM(), utils.GenericDataType.STRING,), + ) + @classmethod def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: return True @@ -144,3 +173,26 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: engine_params["connect_args"] = connect_args extra["engine_params"] = engine_params return extra + + @classmethod + def get_column_spec( # type: ignore + cls, + native_type: Optional[str], + source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, + column_type_mappings: Tuple[ + Tuple[ + Pattern[str], + Union[TypeEngine, Callable[[Match[str]], TypeEngine]], + GenericDataType, + ], + ..., + ] = column_type_mappings, + ) -> Union[ColumnSpec, None]: + + column_spec = super().get_column_spec(native_type) + if column_spec: + return column_spec + + return super().get_column_spec( + native_type, column_type_mappings=column_type_mappings + ) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 6ea687fcfe025..27fad223e7874 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -23,7 +23,19 @@ from contextlib import closing from datetime import datetime from distutils.version import StrictVersion -from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Match, + Optional, + Pattern, + Tuple, + TYPE_CHECKING, + Union, +) from urllib import parse import pandas as pd @@ -36,6 +48,7 @@ from sqlalchemy.engine.url import make_url, URL from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select +from sqlalchemy.types import TypeEngine from superset import app, cache_manager, is_feature_enabled from superset.db_engine_specs.base import BaseEngineSpec @@ -52,6 +65,7 @@ from superset.result_set import 
destringify from superset.sql_parse import ParsedQuery from superset.utils import core as utils +from superset.utils.core import ColumnSpec, GenericDataType if TYPE_CHECKING: # prevent circular imports @@ -293,7 +307,8 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch field_info = cls._split_data_type(single_field, r"\s") # check if there is a structural data type within # overall structural data type - column_type = cls.get_sqla_column_type(field_info[1]) + column_spec = cls.get_column_spec(field_info[1]) + column_type = column_spec.sqla_type if column_spec else None if column_type is None: column_type = types.String() logger.info( @@ -356,31 +371,89 @@ def _show_columns( return columns column_type_mappings = ( - (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()), - (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()), - (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()), - (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()), - (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()), - (re.compile(r"^real.*", re.IGNORECASE), types.Float()), - (re.compile(r"^double.*", re.IGNORECASE), types.Float()), - (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()), + ( + re.compile(r"^boolean.*", re.IGNORECASE), + types.BOOLEAN, + utils.GenericDataType.BOOLEAN, + ), + ( + re.compile(r"^tinyint.*", re.IGNORECASE), + TinyInteger(), + utils.GenericDataType.NUMERIC, + ), + ( + re.compile(r"^smallint.*", re.IGNORECASE), + types.SMALLINT(), + utils.GenericDataType.NUMERIC, + ), + ( + re.compile(r"^integer.*", re.IGNORECASE), + types.INTEGER(), + utils.GenericDataType.NUMERIC, + ), + ( + re.compile(r"^bigint.*", re.IGNORECASE), + types.BIGINT(), + utils.GenericDataType.NUMERIC, + ), + ( + re.compile(r"^real.*", re.IGNORECASE), + types.FLOAT(), + utils.GenericDataType.NUMERIC, + ), + ( + re.compile(r"^double.*", re.IGNORECASE), + types.FLOAT(), + utils.GenericDataType.NUMERIC, + ), + ( + 
re.compile(r"^decimal.*", re.IGNORECASE), + types.DECIMAL(), + utils.GenericDataType.NUMERIC, + ), ( re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE), lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(), + utils.GenericDataType.STRING, ), ( re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE), lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(), + utils.GenericDataType.STRING, ), - (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()), - (re.compile(r"^json.*", re.IGNORECASE), types.JSON()), - (re.compile(r"^date.*", re.IGNORECASE), types.DATE()), - (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()), - (re.compile(r"^time.*", re.IGNORECASE), types.Time()), - (re.compile(r"^interval.*", re.IGNORECASE), Interval()), - (re.compile(r"^array.*", re.IGNORECASE), Array()), - (re.compile(r"^map.*", re.IGNORECASE), Map()), - (re.compile(r"^row.*", re.IGNORECASE), Row()), + ( + re.compile(r"^varbinary.*", re.IGNORECASE), + types.VARBINARY(), + utils.GenericDataType.STRING, + ), + ( + re.compile(r"^json.*", re.IGNORECASE), + types.JSON(), + utils.GenericDataType.STRING, + ), + ( + re.compile(r"^date.*", re.IGNORECASE), + types.DATE(), + utils.GenericDataType.TEMPORAL, + ), + ( + re.compile(r"^timestamp.*", re.IGNORECASE), + types.TIMESTAMP(), + utils.GenericDataType.TEMPORAL, + ), + ( + re.compile(r"^interval.*", re.IGNORECASE), + Interval(), + utils.GenericDataType.TEMPORAL, + ), + ( + re.compile(r"^time.*", re.IGNORECASE), + types.Time(), + utils.GenericDataType.TEMPORAL, + ), + (re.compile(r"^array.*", re.IGNORECASE), Array(), utils.GenericDataType.STRING), + (re.compile(r"^map.*", re.IGNORECASE), Map(), utils.GenericDataType.STRING), + (re.compile(r"^row.*", re.IGNORECASE), Row(), utils.GenericDataType.STRING), ) @classmethod @@ -412,7 +485,8 @@ def get_columns( continue # otherwise column is a basic data type - column_type = cls.get_sqla_column_type(column.Type) + column_spec = 
cls.get_column_spec(column.Type) + column_type = column_spec.sqla_type if column_spec else None if column_type is None: column_type = types.String() logger.info( @@ -1111,3 +1185,27 @@ def extract_errors(cls, ex: Exception) -> List[Dict[str, Any]]: def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: """Pessimistic readonly, 100% sure statement won't mutate anything""" return super().is_readonly_query(parsed_query) or parsed_query.is_show() + + @classmethod + def get_column_spec( # type: ignore + cls, + native_type: Optional[str], + source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, + column_type_mappings: Tuple[ + Tuple[ + Pattern[str], + Union[TypeEngine, Callable[[Match[str]], TypeEngine]], + GenericDataType, + ], + ..., + ] = column_type_mappings, + ) -> Union[ColumnSpec, None]: + + column_spec = super().get_column_spec( + native_type, column_type_mappings=column_type_mappings + ) + + if column_spec: + return column_spec + + return super().get_column_spec(native_type) diff --git a/superset/result_set.py b/superset/result_set.py index f3f68ac2dc813..34d5dc909d630 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -181,9 +181,10 @@ def first_nonempty(items: List[Any]) -> Any: return next((i for i in items if i), None) def is_temporal(self, db_type_str: Optional[str]) -> bool: - return self.db_engine_spec.is_db_column_type_match( - db_type_str, utils.GenericDataType.TEMPORAL - ) + column_spec = self.db_engine_spec.get_column_spec(db_type_str) + if column_spec is None: + return False + return column_spec.is_dttm def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]: """Given a pyarrow data type, Returns a generic database type""" diff --git a/superset/utils/core.py b/superset/utils/core.py index 893592bbc8749..a1e3d113660b2 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -82,7 +82,7 @@ from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.reflection import 
Inspector from sqlalchemy.sql.type_api import Variant -from sqlalchemy.types import TEXT, TypeDecorator +from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine from typing_extensions import TypedDict import _thread # pylint: disable=C0411 @@ -148,6 +148,10 @@ class GenericDataType(IntEnum): STRING = 1 TEMPORAL = 2 BOOLEAN = 3 + # ARRAY = 4 # Mapping all the complex data types to STRING for now + # JSON = 5 # and leaving these as a reminder. + # MAP = 6 + # ROW = 7 class ChartDataResultFormat(str, Enum): @@ -306,6 +310,18 @@ class TemporalType(str, Enum): TIMESTAMP = "TIMESTAMP" +class ColumnTypeSource(Enum): + GET_TABLE = 1 + CURSOR_DESCRIPION = 2 + + +class ColumnSpec(NamedTuple): + sqla_type: Union[TypeEngine, str] + generic_type: GenericDataType + is_dttm: bool + python_date_format: Optional[str] = None + + try: # Having might not have been imported. class DimSelector(Having): diff --git a/tests/db_engine_specs/mssql_tests.py b/tests/db_engine_specs/mssql_tests.py index 149ed692c93ed..74c3715f28a92 100644 --- a/tests/db_engine_specs/mssql_tests.py +++ b/tests/db_engine_specs/mssql_tests.py @@ -24,32 +24,37 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.mssql import MssqlEngineSpec +from superset.utils.core import GenericDataType from tests.db_engine_specs.base_tests import TestDbEngineSpec class TestMssqlEngineSpec(TestDbEngineSpec): def test_mssql_column_types(self): - def assert_type(type_string, type_expected): - type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string) + def assert_type(type_string, type_expected, generic_type_expected): if type_expected is None: + type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string) self.assertIsNone(type_assigned) else: - self.assertIsInstance(type_assigned, type_expected) + column_spec = MssqlEngineSpec.get_column_spec(type_string) + if column_spec != None: + self.assertIsInstance(column_spec.sqla_type, type_expected) + 
self.assertEquals(column_spec.generic_type, generic_type_expected) - assert_type("INT", None) - assert_type("STRING", String) - assert_type("CHAR(10)", String) - assert_type("VARCHAR(10)", String) - assert_type("TEXT", String) - assert_type("NCHAR(10)", UnicodeText) - assert_type("NVARCHAR(10)", UnicodeText) - assert_type("NTEXT", UnicodeText) + assert_type("STRING", String, GenericDataType.STRING) + assert_type("CHAR(10)", String, GenericDataType.STRING) + assert_type("VARCHAR(10)", String, GenericDataType.STRING) + assert_type("TEXT", String, GenericDataType.STRING) + assert_type("NCHAR(10)", UnicodeText, GenericDataType.STRING) + assert_type("NVARCHAR(10)", UnicodeText, GenericDataType.STRING) + assert_type("NTEXT", UnicodeText, GenericDataType.STRING) def test_where_clause_n_prefix(self): dialect = mssql.dialect() spec = MssqlEngineSpec - str_col = column("col", type_=spec.get_sqla_column_type("VARCHAR(10)")) - unicode_col = column("unicode_col", type_=spec.get_sqla_column_type("NTEXT")) + type_, _ = spec.get_sqla_column_type("VARCHAR(10)") + str_col = column("col", type_=type_) + type_, _ = spec.get_sqla_column_type("NTEXT") + unicode_col = column("unicode_col", type_=type_) tbl = table("tbl") sel = ( select([str_col, unicode_col]) diff --git a/tests/db_engine_specs/mysql_tests.py b/tests/db_engine_specs/mysql_tests.py index ba56b6c9fd296..035b06f682e38 100644 --- a/tests/db_engine_specs/mysql_tests.py +++ b/tests/db_engine_specs/mysql_tests.py @@ -89,18 +89,9 @@ def test_is_db_column_type_match(self): ("TIME", GenericDataType.TEMPORAL), ) - for type_expectation in type_expectations: - type_str = type_expectation[0] - col_type = type_expectation[1] - assert MySQLEngineSpec.is_db_column_type_match( - type_str, GenericDataType.NUMERIC - ) is (col_type == GenericDataType.NUMERIC) - assert MySQLEngineSpec.is_db_column_type_match( - type_str, GenericDataType.STRING - ) is (col_type == GenericDataType.STRING) - assert MySQLEngineSpec.is_db_column_type_match( - 
type_str, GenericDataType.TEMPORAL - ) is (col_type == GenericDataType.TEMPORAL) + for type_str, col_type in type_expectations: + column_spec = MySQLEngineSpec.get_column_spec(type_str) + assert column_spec.generic_type == col_type def test_extract_error_message(self): from MySQLdb._exceptions import OperationalError diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py index d0343a32d792d..5fd16a69ac847 100644 --- a/tests/db_engine_specs/presto_tests.py +++ b/tests/db_engine_specs/presto_tests.py @@ -24,7 +24,7 @@ from superset.db_engine_specs.presto import PrestoEngineSpec from superset.sql_parse import ParsedQuery -from superset.utils.core import DatasourceName +from superset.utils.core import DatasourceName, GenericDataType from tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -535,30 +535,37 @@ def test_presto_expand_data_array(self): self.assertEqual(actual_expanded_cols, expected_expanded_cols) def test_get_sqla_column_type(self): - sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)") - assert isinstance(sqla_type, types.VARCHAR) - assert sqla_type.length == 255 - - sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar") - assert isinstance(sqla_type, types.String) - assert sqla_type.length is None - - sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)") - assert isinstance(sqla_type, types.CHAR) - assert sqla_type.length == 10 - - sqla_type = PrestoEngineSpec.get_sqla_column_type("char") - assert isinstance(sqla_type, types.CHAR) - assert sqla_type.length is None - - sqla_type = PrestoEngineSpec.get_sqla_column_type("integer") - assert isinstance(sqla_type, types.Integer) - - sqla_type = PrestoEngineSpec.get_sqla_column_type("time") - assert isinstance(sqla_type, types.Time) - - sqla_type = PrestoEngineSpec.get_sqla_column_type("timestamp") - assert isinstance(sqla_type, types.TIMESTAMP) + column_spec = PrestoEngineSpec.get_column_spec("varchar(255)") + assert 
isinstance(column_spec.sqla_type, types.VARCHAR) + assert column_spec.sqla_type.length == 255 + self.assertEqual(column_spec.generic_type, GenericDataType.STRING) + + column_spec = PrestoEngineSpec.get_column_spec("varchar") + assert isinstance(column_spec.sqla_type, types.String) + assert column_spec.sqla_type.length is None + self.assertEqual(column_spec.generic_type, GenericDataType.STRING) + + column_spec = PrestoEngineSpec.get_column_spec("char(10)") + assert isinstance(column_spec.sqla_type, types.CHAR) + assert column_spec.sqla_type.length == 10 + self.assertEqual(column_spec.generic_type, GenericDataType.STRING) + + column_spec = PrestoEngineSpec.get_column_spec("char") + assert isinstance(column_spec.sqla_type, types.CHAR) + assert column_spec.sqla_type.length is None + self.assertEqual(column_spec.generic_type, GenericDataType.STRING) + + column_spec = PrestoEngineSpec.get_column_spec("integer") + assert isinstance(column_spec.sqla_type, types.Integer) + self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC) + + column_spec = PrestoEngineSpec.get_column_spec("time") + assert isinstance(column_spec.sqla_type, types.Time) + self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL) + + column_spec = PrestoEngineSpec.get_column_spec("timestamp") + assert isinstance(column_spec.sqla_type, types.TIMESTAMP) + self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL) sqla_type = PrestoEngineSpec.get_sqla_column_type(None) assert sqla_type is None diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index e03460980a90c..cdd77c270b90d 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -84,11 +84,9 @@ def test_db_column_types(self): "TEXT": GenericDataType.STRING, "NTEXT": GenericDataType.STRING, # numeric - "INT": GenericDataType.NUMERIC, + "INTEGER": GenericDataType.NUMERIC, "BIGINT": GenericDataType.NUMERIC, - "FLOAT": GenericDataType.NUMERIC, "DECIMAL": GenericDataType.NUMERIC, - 
"MONEY": GenericDataType.NUMERIC, # temporal "DATE": GenericDataType.TEMPORAL, "DATETIME": GenericDataType.TEMPORAL,