Skip to content

Commit

Permalink
fix(db_engine_specs): improve Presto column type matching (apache#10658)
Browse files Browse the repository at this point in the history
* fix: improve Presto column type matching

* add optional callback to type map and add tests

* lint

* change private to public
  • Loading branch information
villebro authored and Ofeknielsen committed Oct 5, 2020
1 parent 4085597 commit 2c6a0c8
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 65 deletions.
13 changes: 12 additions & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
from datetime import datetime
from typing import (
Any,
Callable,
Dict,
List,
Match,
NamedTuple,
Optional,
Pattern,
Expand Down Expand Up @@ -142,6 +144,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
] = None # used for user messages, overridden in child classes
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
] = ()
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
Expand Down Expand Up @@ -886,12 +891,18 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Needs to be overridden if column requires special handling
SQLAlchemy). Override `_column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param type_: Column type returned by inspector
:return: SqlAlchemy column type
"""
for regex, sqla_type in cls.column_type_mappings:
match = regex.match(type_)
if match:
if callable(sqla_type):
return sqla_type(match)
return sqla_type
return None

@staticmethod
Expand Down
15 changes: 4 additions & 11 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING

from sqlalchemy.types import String, TypeEngine, UnicodeText
from sqlalchemy.types import String, UnicodeText

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
Expand Down Expand Up @@ -73,18 +73,11 @@ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

column_types = (
(String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
(UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
column_type_mappings = (
(re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
(re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
)

@classmethod
def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
for sqla_type, regex in cls.column_types:
if regex.match(type_):
return sqla_type
return None

@classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):
Expand Down
89 changes: 60 additions & 29 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import pandas as pd
import simplejson as json
from flask_babel import lazy_gettext as _
from sqlalchemy import Column, literal_column
from sqlalchemy import Column, literal_column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
Expand All @@ -40,7 +40,13 @@
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.models.sql_types.presto_sql_types import (
Array,
Interval,
Map,
Row,
TinyInteger,
)
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
Expand Down Expand Up @@ -260,23 +266,24 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch
field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within
# overall structural data type
column_type = cls.get_sqla_column_type(field_info[1])
if column_type is None:
raise NotImplementedError(
_("Unknown column type: %(col)s", col=field_info[1])
)
if field_info[1] == "array" or field_info[1] == "row":
stack.append((field_info[0], field_info[1]))
full_parent_path = cls._get_full_name(stack)
result.append(
cls._create_column_info(
full_parent_path, presto_type_map[field_info[1]]()
)
cls._create_column_info(full_parent_path, column_type)
)
else: # otherwise this field is a basic data type
full_parent_path = cls._get_full_name(stack)
column_name = "{}.{}".format(
full_parent_path, field_info[0]
)
result.append(
cls._create_column_info(
column_name, presto_type_map[field_info[1]]()
)
cls._create_column_info(column_name, column_type)
)
# If the component type ends with a structural data type, do not pop
# the stack. We have run across a structural data type within the
Expand Down Expand Up @@ -318,6 +325,34 @@ def _show_columns(
columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
return columns

column_type_mappings = (
(re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
(re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
(re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
(re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
(re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
(re.compile(r"^real.*", re.IGNORECASE), types.Float()),
(re.compile(r"^double.*", re.IGNORECASE), types.Float()),
(re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
(
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
),
(
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
),
(re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
(re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
(re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
(re.compile(r"^time.*", re.IGNORECASE), types.Time()),
(re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
(re.compile(r"^interval.*", re.IGNORECASE), Interval()),
(re.compile(r"^array.*", re.IGNORECASE), Array()),
(re.compile(r"^map.*", re.IGNORECASE), Map()),
(re.compile(r"^row.*", re.IGNORECASE), Row()),
)

@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
Expand All @@ -334,28 +369,24 @@ def get_columns(
columns = cls._show_columns(inspector, table_name, schema)
result: List[Dict[str, Any]] = []
for column in columns:
try:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
"array" in column.Type or "row" in column.Type
):
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue

# otherwise column is a basic data type
column_type = presto_type_map[column.Type]()
except KeyError:
logger.info(
"Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation
column.Type, column.Column
)
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
"array" in column.Type or "row" in column.Type
):
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue

# otherwise column is a basic data type
column_type = cls.get_sqla_column_type(column.Type)
if column_type is None:
raise NotImplementedError(
_("Unknown column type: %(col)s", col=column_type)
)
column_type = "OTHER"
column_info = cls._create_column_info(column.Column, column_type)
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None
Expand Down
24 changes: 0 additions & 24 deletions superset/models/sql_types/presto_sql_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from typing import Any, Dict, List, Optional, Type

from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.sql.visitors import Visitable
Expand Down Expand Up @@ -92,26 +91,3 @@ def python_type(self) -> Optional[Type[Any]]:
@classmethod
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "ROW"


type_map = {
"boolean": types.Boolean,
"tinyint": TinyInteger,
"smallint": types.SmallInteger,
"integer": types.Integer,
"bigint": types.BigInteger,
"real": types.Float,
"double": types.Float,
"decimal": types.DECIMAL,
"varchar": types.String,
"char": types.CHAR,
"varbinary": types.VARBINARY,
"JSON": types.JSON,
"date": types.DATE,
"time": types.Time,
"timestamp": types.TIMESTAMP,
"interval": Interval,
"array": Array,
"map": Map,
"row": Row,
}
21 changes: 21 additions & 0 deletions tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest import mock, skipUnless

import pandas as pd
from sqlalchemy import types
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import select

Expand Down Expand Up @@ -490,3 +491,23 @@ def test_presto_expand_data_array(self):
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)

def test_get_sqla_column_type(self):
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
assert isinstance(sqla_type, types.VARCHAR)
assert sqla_type.length == 255

sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
assert isinstance(sqla_type, types.String)
assert sqla_type.length is None

sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length == 10

sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length is None

sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
assert isinstance(sqla_type, types.Integer)

0 comments on commit 2c6a0c8

Please sign in to comment.