Skip to content

Commit

Permalink
feat(explore): Postgres datatype conversion (#13294)
Browse files Browse the repository at this point in the history
* test

* unnecessary import

* fix lint

* changes

* fix lint

* changes

* changes

* changes

* changes

* answering comments & changes

* answering comments

* answering comments

* changes

* changes

* changes

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests
  • Loading branch information
nikolagigic authored Mar 12, 2021
1 parent 1a46f93 commit 609c359
Show file tree
Hide file tree
Showing 12 changed files with 471 additions and 153 deletions.
28 changes: 15 additions & 13 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from superset.sql_parse import ParsedQuery
from superset.typing import Metric, QueryObjectDict
from superset.utils import core as utils
from superset.utils.core import GenericDataType

config = app.config
metadata = Model.metadata # pylint: disable=no-member
Expand Down Expand Up @@ -186,20 +187,20 @@ def is_numeric(self) -> bool:
"""
Check if the column has a numeric datatype.
"""
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.GenericDataType.NUMERIC
)
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
if column_spec is None:
return False
return column_spec.generic_type == GenericDataType.NUMERIC

@property
def is_string(self) -> bool:
"""
Check if the column has a string datatype.
"""
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.GenericDataType.STRING
)
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
if column_spec is None:
return False
return column_spec.generic_type == GenericDataType.STRING

@property
def is_temporal(self) -> bool:
Expand All @@ -211,18 +212,19 @@ def is_temporal(self) -> bool:
"""
if self.is_dttm is not None:
return self.is_dttm
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.GenericDataType.TEMPORAL
)
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
if column_spec is None:
return False
return column_spec.is_dttm

def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.column_name
if self.expression:
col = literal_column(self.expression)
else:
db_engine_spec = self.table.database.db_engine_spec
type_ = db_engine_spec.get_sqla_column_type(self.type)
column_spec = db_engine_spec.get_column_spec(self.type)
type_ = column_spec.sqla_type if column_spec else None
col = column(self.column_name, type_=type_)
col = self.table.make_sqla_column_compatible(col, label)
return col
Expand Down
191 changes: 142 additions & 49 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import sqlparse
from flask import g
from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import column, DateTime, select
from sqlalchemy import column, DateTime, select, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Compiled, Dialect
from sqlalchemy.engine.reflection import Inspector
Expand All @@ -50,13 +50,14 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.types import TypeEngine
from sqlalchemy.types import String, TypeEngine, UnicodeText

from superset import app, security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType

if TYPE_CHECKING:
# prevent circular imports
Expand Down Expand Up @@ -145,8 +146,87 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
] = ()
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = (
(
re.compile(r"^smallint", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^integer", re.IGNORECASE),
types.Integer(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigint", re.IGNORECASE),
types.BigInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal", re.IGNORECASE),
types.Numeric(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^numeric", re.IGNORECASE),
types.Numeric(),
GenericDataType.NUMERIC,
),
(re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,),
(
re.compile(r"^smallserial", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^serial", re.IGNORECASE),
types.Integer(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigserial", re.IGNORECASE),
types.BigInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
UnicodeText(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
(re.compile(r"^date", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL,),
(
re.compile(r"^timestamp", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^interval", re.IGNORECASE),
types.Interval(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
(
re.compile(r"^boolean", re.IGNORECASE),
types.Boolean(),
GenericDataType.BOOLEAN,
),
)
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
Expand All @@ -160,25 +240,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False

# default matching patterns to convert database specific column types to
# more generic types
db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[str], ...]] = {
utils.GenericDataType.NUMERIC: (
re.compile(r"BIT", re.IGNORECASE),
re.compile(
r".*(DOUBLE|FLOAT|INT|NUMBER|REAL|NUMERIC|DECIMAL|MONEY).*",
re.IGNORECASE,
),
re.compile(r".*LONG$", re.IGNORECASE),
),
utils.GenericDataType.STRING: (
re.compile(r".*(CHAR|STRING|TEXT).*", re.IGNORECASE),
),
utils.GenericDataType.TEMPORAL: (
re.compile(r".*(DATE|TIME).*", re.IGNORECASE),
),
}

@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
"""
Expand Down Expand Up @@ -208,25 +269,6 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception:
return exception
return new_exception(str(exception))

@classmethod
def is_db_column_type_match(
cls, db_column_type: Optional[str], target_column_type: utils.GenericDataType
) -> bool:
"""
Check if a column type satisfies a pattern in a collection of regexes found in
`db_column_types`. For example, if `db_column_type == "NVARCHAR"`,
it would be a match for "STRING" due to being a match for the regex ".*CHAR.*".
:param db_column_type: Column type to evaluate
:param target_column_type: The target type to evaluate for
:return: `True` if a `db_column_type` matches any pattern corresponding to
`target_column_type`
"""
if not db_column_type:
return False
patterns = cls.db_column_types[target_column_type]
return any(pattern.match(db_column_type) for pattern in patterns)

@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return False
Expand Down Expand Up @@ -967,24 +1009,35 @@ def make_label_compatible(cls, label: str) -> Union[str, quoted_name]:
return label_mutated

@classmethod
def get_sqla_column_type(cls, type_: Optional[str]) -> Optional[TypeEngine]:
def get_sqla_column_type(
cls,
column_type: Optional[str],
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[Tuple[TypeEngine, GenericDataType], None]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Override `column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param type_: Column type returned by inspector
:param column_type: Column type returned by inspector
:return: SqlAlchemy column type
"""
if not type_:
if not column_type:
return None
for regex, sqla_type in cls.column_type_mappings:
match = regex.match(type_)
for regex, sqla_type, generic_type in column_type_mappings:
match = regex.match(column_type)
if match:
if callable(sqla_type):
return sqla_type(match)
return sqla_type
return sqla_type(match), generic_type
return sqla_type, generic_type
return None

@staticmethod
Expand Down Expand Up @@ -1101,3 +1154,43 @@ def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
or parsed_query.is_explain()
or parsed_query.is_show()
)

@classmethod
def get_column_spec(
cls,
native_type: Optional[str],
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
"""
Converts native database type to sqlalchemy column type.
:param native_type: Native database typee
:param source: Type coming from the database table or cursor description
:return: ColumnSpec object
"""
column_type = None

if (
cls.get_sqla_column_type(
native_type, column_type_mappings=column_type_mappings
)
is not None
):
column_type, generic_type = cls.get_sqla_column_type( # type: ignore
native_type, column_type_mappings=column_type_mappings
)
is_dttm = generic_type == GenericDataType.TEMPORAL

if column_type:
return ColumnSpec(
sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm
)

return None
13 changes: 1 addition & 12 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import logging
import re
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING

from sqlalchemy.types import String, UnicodeText
from typing import Any, List, Optional, Tuple

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.models.core import Database

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -77,11 +71,6 @@ def fetch_data(
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

column_type_mappings = (
(re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
(re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
)

@classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):
Expand Down
Loading

0 comments on commit 609c359

Please sign in to comment.