Skip to content

Commit

Permalink
cleanup column_type_mappings (#17569)
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
  • Loading branch information
dungdm93 authored Jan 14, 2022
1 parent 26dc600 commit 5a74090
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 170 deletions.
137 changes: 61 additions & 76 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import String, TypeEngine, UnicodeText
from sqlalchemy.types import TypeEngine
from typing_extensions import TypedDict

from superset import security_manager, sql_parse
Expand All @@ -71,6 +71,12 @@
from superset.connectors.sqla.models import TableColumn
from superset.models.core import Database

ColumnTypeMapping = Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
]

logger = logging.getLogger()


Expand Down Expand Up @@ -156,26 +162,37 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods

engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Set[str] = set()
engine_name: Optional[
str
] = None # used for user messages, overridden in child classes
engine_name: Optional[str] = None # for user messages, overridden in child classes
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = (
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^n((var)?char|text)", re.IGNORECASE),
types.UnicodeText(),
GenericDataType.STRING,
),
(
re.compile(r"^(var)?char", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^(tiny|medium|long)?text", re.IGNORECASE),
types.String(),
GenericDataType.STRING,
),
(
re.compile(r"^smallint", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^int.*", re.IGNORECASE),
re.compile(r"^int(eger)?", re.IGNORECASE),
types.Integer(),
GenericDataType.NUMERIC,
),
Expand All @@ -184,6 +201,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.BigInteger(),
GenericDataType.NUMERIC,
),
(re.compile(r"^long", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
(
re.compile(r"^decimal", re.IGNORECASE),
types.Numeric(),
Expand Down Expand Up @@ -222,26 +240,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
GenericDataType.NUMERIC,
),
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
UnicodeText(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((TINY|MEDIUM|LONG)?TEXT)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
re.compile(r"^timestamp", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^LONG", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,),
(
re.compile(r"^datetime", re.IGNORECASE),
types.DateTime(),
Expand All @@ -252,19 +254,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
types.DateTime(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^timestamp", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
(
re.compile(r"^interval", re.IGNORECASE),
types.Interval(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
(
re.compile(r"^bool.*", re.IGNORECASE),
re.compile(r"^bool(ean)?", re.IGNORECASE),
types.Boolean(),
GenericDataType.BOOLEAN,
),
Expand Down Expand Up @@ -693,7 +690,6 @@ def df_to_sql(
to_sql_kwargs["name"] = table.table

if table.schema:

# Only add schema when it is preset and non empty.
to_sql_kwargs["schema"] = table.schema

Expand Down Expand Up @@ -844,6 +840,7 @@ def get_table_names( # pylint: disable=unused-argument
"""
Get all tables from schema
:param database: The database to get info
:param inspector: SqlAlchemy inspector
:param schema: Schema to inspect. If omitted, uses default schema for database
:return: All tables in schema
Expand All @@ -860,6 +857,7 @@ def get_view_names( # pylint: disable=unused-argument
"""
Get all views from schema
:param database: The database to get info
:param inspector: SqlAlchemy inspector
:param schema: Schema name. If omitted, uses default schema for database
:return: All views in schema
Expand Down Expand Up @@ -924,7 +922,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen
:param database: Database instance
:param query: SqlAlchemy query
:param columns: List of TableColumns
:return: SqlAlchemy query with additional where clause referencing latest
:return: SqlAlchemy query with additional where clause referencing the latest
partition
"""
# TODO: Fix circular import caused by importing Database, TableColumn
Expand Down Expand Up @@ -954,12 +952,12 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
:param database: Database instance
:param table_name: Table name, unquoted
:param engine: SqlALchemy Engine instance
:param engine: SqlAlchemy Engine instance
:param schema: Schema, unquoted
:param limit: limit to impose on query
:param show_cols: Show columns in query; otherwise use "*"
:param indent: Add indentation to query
:param latest_partition: Only query latest partition
:param latest_partition: Only query the latest partition
:param cols: Columns to include in query
:return: SQL query
"""
Expand Down Expand Up @@ -993,7 +991,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
return sql

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any,) -> Dict[str, Any]:
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
Expand Down Expand Up @@ -1024,7 +1022,7 @@ def process_statement(
:param statement: A single SQL statement
:param database: Database instance
:param username: Effective username
:param user_name: Effective username
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
Expand Down Expand Up @@ -1089,7 +1087,6 @@ def update_impersonation_config(
:param connect_args: config to be updated
:param uri: URI
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: None
"""
Expand Down Expand Up @@ -1122,8 +1119,8 @@ def make_label_compatible(cls, label: str) -> Union[str, quoted_name]:
Conditionally mutate and/or quote a sqlalchemy expression label. If
force_column_alias_quotes is set to True, return the label as a
sqlalchemy.sql.elements.quoted_name object to ensure that the select query
and query results have same case. Otherwise return the mutated label as a
regular string. If maxmimum supported column name length is exceeded,
and query results have same case. Otherwise, return the mutated label as a
regular string. If maximum supported column name length is exceeded,
generate a truncated label by calling truncate_label().
:param label: expected expression label/alias
Expand All @@ -1143,32 +1140,27 @@ def make_label_compatible(cls, label: str) -> Union[str, quoted_name]:
def get_sqla_column_type(
cls,
column_type: Optional[str],
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[Tuple[TypeEngine, GenericDataType], None]:
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[Tuple[TypeEngine, GenericDataType]]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Override `column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param column_type: Column type returned by inspector
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: SqlAlchemy column type
"""
if not column_type:
return None
for regex, sqla_type, generic_type in column_type_mappings:
match = regex.match(column_type)
if match:
if callable(sqla_type):
return sqla_type(match), generic_type
return sqla_type, generic_type
if not match:
continue
if callable(sqla_type):
return sqla_type(match), generic_type
return sqla_type, generic_type
return None

@staticmethod
Expand All @@ -1192,7 +1184,7 @@ def _truncate_label(cls, label: str) -> str:
"""
In the case that a label exceeds the max length supported by the engine,
this method is used to construct a deterministic and unique label based on
the original label. By default this returns an md5 hash of the original label,
the original label. By default, this returns a md5 hash of the original label,
conditionally truncated if the length of the hash exceeds the max column length
of the engine.
Expand All @@ -1211,8 +1203,8 @@ def column_datatype_to_string(
) -> str:
"""
Convert sqlalchemy column type to string representation.
By default removes collation and character encoding info to avoid unnecessarily
long datatypes.
By default, removes collation and character encoding info to avoid
unnecessarily long datatypes.
:param sqla_column_type: SqlAlchemy column type
:param dialect: Sqlalchemy dialect
Expand Down Expand Up @@ -1304,20 +1296,14 @@ def get_column_spec( # pylint: disable=unused-argument
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
"""
Converts native database type to sqlalchemy column type.
:param native_type: Native database typee
:param source: Type coming from the database table or cursor description
:param native_type: Native database type
:param db_extra: The database extra object
:param source: Type coming from the database table or cursor description
:param column_type_mappings: Maps from string to SqlAlchemy TypeEngine
:return: ColumnSpec object
"""
col_types = cls.get_sqla_column_type(
Expand Down Expand Up @@ -1417,7 +1403,6 @@ class BasicParametersType(TypedDict, total=False):


class BasicParametersMixin:

"""
Mixin for configuring DB engine specs via a dictionary.
Expand Down
29 changes: 9 additions & 20 deletions superset/db_engine_specs/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
import re
from datetime import datetime
from typing import Any, Callable, Dict, Match, Optional, Pattern, Tuple, Union
from typing import Any, Dict, Optional, Pattern, Tuple
from urllib import parse

from flask_babel import gettext as __
Expand All @@ -33,9 +33,12 @@
TINYTEXT,
)
from sqlalchemy.engine.url import URL
from sqlalchemy.types import TypeEngine

from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
ColumnTypeMapping,
)
from superset.errors import SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils import core as utils
Expand Down Expand Up @@ -70,14 +73,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
)
encryption_parameters = {"ssl": "1"}

column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = (
column_type_mappings = (
(re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,),
(re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,),
(
Expand Down Expand Up @@ -208,15 +204,8 @@ def get_column_spec(
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[ColumnSpec, None]:
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:

column_spec = super().get_column_spec(native_type)
if column_spec:
Expand Down
Loading

0 comments on commit 5a74090

Please sign in to comment.