Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Revert "fix(db_engine_specs): improve Presto column type matching (ap…
Browse files Browse the repository at this point in the history
…ache#10658)"

This reverts commit 9461f9c.
  • Loading branch information
serenajiang committed Sep 2, 2020
1 parent 5a106eb commit c3dd451
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 97 deletions.
13 changes: 1 addition & 12 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
from datetime import datetime
from typing import (
Any,
Callable,
Dict,
List,
Match,
NamedTuple,
Optional,
Pattern,
Expand Down Expand Up @@ -144,9 +142,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
] = None # used for user messages, overridden in child classes
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
] = ()
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
Expand Down Expand Up @@ -882,18 +877,12 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Override `_column_type_mappings` for specific needs
SQLAlchemy). Needs to be overridden if column requires special handling
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param type_: Column type returned by inspector
:return: SqlAlchemy column type
"""
for regex, sqla_type in cls.column_type_mappings:
match = regex.match(type_)
if match:
if callable(sqla_type):
return sqla_type(match)
return sqla_type
return None

@staticmethod
Expand Down
15 changes: 11 additions & 4 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING

from sqlalchemy.types import String, UnicodeText
from sqlalchemy.types import String, TypeEngine, UnicodeText

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
Expand Down Expand Up @@ -75,11 +75,18 @@ def fetch_data(
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

column_type_mappings = (
(re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
(re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
column_types = (
(String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
(UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
)

@classmethod
def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
for sqla_type, regex in cls.column_types:
if regex.match(type_):
return sqla_type
return None

@classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):
Expand Down
89 changes: 29 additions & 60 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import pandas as pd
import simplejson as json
from flask_babel import lazy_gettext as _
from sqlalchemy import Column, literal_column, types
from sqlalchemy import Column, literal_column
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
Expand All @@ -40,13 +40,7 @@
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import (
Array,
Interval,
Map,
Row,
TinyInteger,
)
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
Expand Down Expand Up @@ -263,24 +257,23 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch
field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within
# overall structural data type
column_type = cls.get_sqla_column_type(field_info[1])
if column_type is None:
raise NotImplementedError(
_("Unknown column type: %(col)s", col=field_info[1])
)
if field_info[1] == "array" or field_info[1] == "row":
stack.append((field_info[0], field_info[1]))
full_parent_path = cls._get_full_name(stack)
result.append(
cls._create_column_info(full_parent_path, column_type)
cls._create_column_info(
full_parent_path, presto_type_map[field_info[1]]()
)
)
else: # otherwise this field is a basic data type
full_parent_path = cls._get_full_name(stack)
column_name = "{}.{}".format(
full_parent_path, field_info[0]
)
result.append(
cls._create_column_info(column_name, column_type)
cls._create_column_info(
column_name, presto_type_map[field_info[1]]()
)
)
# If the component type ends with a structural data type, do not pop
# the stack. We have run across a structural data type within the
Expand Down Expand Up @@ -322,34 +315,6 @@ def _show_columns(
columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
return columns

column_type_mappings = (
(re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
(re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
(re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
(re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
(re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
(re.compile(r"^real.*", re.IGNORECASE), types.Float()),
(re.compile(r"^double.*", re.IGNORECASE), types.Float()),
(re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
(
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
),
(
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
),
(re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
(re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
(re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
(re.compile(r"^time.*", re.IGNORECASE), types.Time()),
(re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
(re.compile(r"^interval.*", re.IGNORECASE), Interval()),
(re.compile(r"^array.*", re.IGNORECASE), Array()),
(re.compile(r"^map.*", re.IGNORECASE), Map()),
(re.compile(r"^row.*", re.IGNORECASE), Row()),
)

@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
Expand All @@ -366,24 +331,28 @@ def get_columns(
columns = cls._show_columns(inspector, table_name, schema)
result: List[Dict[str, Any]] = []
for column in columns:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
"array" in column.Type or "row" in column.Type
):
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue

# otherwise column is a basic data type
column_type = cls.get_sqla_column_type(column.Type)
if column_type is None:
raise NotImplementedError(
_("Unknown column type: %(col)s", col=column_type)
try:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
"array" in column.Type or "row" in column.Type
):
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue

# otherwise column is a basic data type
column_type = presto_type_map[column.Type]()
except KeyError:
logger.info(
"Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation
column.Type, column.Column
)
)
column_type = "OTHER"
column_info = cls._create_column_info(column.Column, column_type)
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None
Expand Down
24 changes: 24 additions & 0 deletions superset/models/sql_types/presto_sql_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from typing import Any, Dict, List, Optional, Type

from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.sql.visitors import Visitable
Expand Down Expand Up @@ -91,3 +92,26 @@ def python_type(self) -> Optional[Type[Any]]:
@classmethod
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "ROW"


type_map = {
"boolean": types.Boolean,
"tinyint": TinyInteger,
"smallint": types.SmallInteger,
"integer": types.Integer,
"bigint": types.BigInteger,
"real": types.Float,
"double": types.Float,
"decimal": types.DECIMAL,
"varchar": types.String,
"char": types.CHAR,
"varbinary": types.VARBINARY,
"JSON": types.JSON,
"date": types.DATE,
"time": types.Time,
"timestamp": types.TIMESTAMP,
"interval": Interval,
"array": Array,
"map": Map,
"row": Row,
}
21 changes: 0 additions & 21 deletions tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from unittest import mock, skipUnless

import pandas as pd
from sqlalchemy import types
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import select

Expand Down Expand Up @@ -491,23 +490,3 @@ def test_presto_expand_data_array(self):
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)

def test_get_sqla_column_type(self):
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
assert isinstance(sqla_type, types.VARCHAR)
assert sqla_type.length == 255

sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
assert isinstance(sqla_type, types.String)
assert sqla_type.length is None

sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length == 10

sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length is None

sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
assert isinstance(sqla_type, types.Integer)

0 comments on commit c3dd451

Please sign in to comment.