Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(db_engine_specs): improve Presto column type matching #10658

Merged
merged 4 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
from datetime import datetime
from typing import (
Any,
Callable,
Dict,
List,
Match,
NamedTuple,
Optional,
Pattern,
Expand Down Expand Up @@ -142,6 +144,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
] = None # used for user messages, overridden in child classes
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
] = ()
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
Expand Down Expand Up @@ -886,12 +891,18 @@ def get_sqla_column_type(cls, type_: str) -> Optional[TypeEngine]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Needs to be overridden if column requires special handling
SQLAlchemy). Override `column_type_mappings` for specific needs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we expect subclasses to override _column_type_mappings should we make it public?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, the thought crossed my mind while writing this but there were other similar properties that were private so I went with the flow. I changed this one to public, and I'll update the other ones in a later PR to keep this PR as light as possible.

(see MSSQL for example of NCHAR/NVARCHAR handling).

:param type_: Column type returned by inspector
:return: SqlAlchemy column type
"""
for regex, sqla_type in cls.column_type_mappings:
match = regex.match(type_)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@villebro we tested this during a deploy and got errors here because sometimes type_ is None. get_sqla_column_type is used in connectors/sqla/models.py where sometimes the type on a column can be null.

if match:
if callable(sqla_type):
return sqla_type(match)
return sqla_type
return None

@staticmethod
Expand Down
15 changes: 4 additions & 11 deletions superset/db_engine_specs/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING

from sqlalchemy.types import String, TypeEngine, UnicodeText
from sqlalchemy.types import String, UnicodeText

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
Expand Down Expand Up @@ -73,18 +73,11 @@ def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

column_types = (
(String(), re.compile(r"^(?<!N)((VAR){0,1}CHAR|TEXT|STRING)", re.IGNORECASE)),
(UnicodeText(), re.compile(r"^N((VAR){0,1}CHAR|TEXT)", re.IGNORECASE)),
column_type_mappings = (
(re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
(re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
)

@classmethod
def get_sqla_column_type(cls, type_: Optional[str]) -> Optional[TypeEngine]:
    """
    Map a column type string reported by the inspector to a SQLAlchemy
    native type.

    :param type_: column type returned by the inspector; may be ``None``,
        as the type on a column can be null (see connectors/sqla/models.py)
    :return: matching SQLAlchemy type, or ``None`` to fall back to the
        default type inferred by SQLAlchemy
    """
    # Guard against null column types, which occur in practice and would
    # otherwise raise a TypeError inside ``regex.match``.
    if type_ is None:
        return None
    for sqla_type, regex in cls.column_types:
        if regex.match(type_):
            return sqla_type
    return None

@classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):
Expand Down
89 changes: 60 additions & 29 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import pandas as pd
import simplejson as json
from flask_babel import lazy_gettext as _
from sqlalchemy import Column, literal_column
from sqlalchemy import Column, literal_column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
Expand All @@ -40,7 +40,13 @@
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.models.sql_types.presto_sql_types import (
Array,
Interval,
Map,
Row,
TinyInteger,
)
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
Expand Down Expand Up @@ -260,23 +266,24 @@ def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branch
field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within
# overall structural data type
column_type = cls.get_sqla_column_type(field_info[1])
if column_type is None:
raise NotImplementedError(
_("Unknown column type: %(col)s", col=field_info[1])
)
if field_info[1] == "array" or field_info[1] == "row":
stack.append((field_info[0], field_info[1]))
full_parent_path = cls._get_full_name(stack)
result.append(
cls._create_column_info(
full_parent_path, presto_type_map[field_info[1]]()
)
cls._create_column_info(full_parent_path, column_type)
)
else: # otherwise this field is a basic data type
full_parent_path = cls._get_full_name(stack)
column_name = "{}.{}".format(
full_parent_path, field_info[0]
)
result.append(
cls._create_column_info(
column_name, presto_type_map[field_info[1]]()
)
cls._create_column_info(column_name, column_type)
)
# If the component type ends with a structural data type, do not pop
# the stack. We have run across a structural data type within the
Expand Down Expand Up @@ -318,6 +325,34 @@ def _show_columns(
columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
return columns

# Ordered (regex, SQLAlchemy type) pairs mapping Presto column type strings
# to SQLAlchemy types; the first matching pattern wins. A callable value
# receives the regex ``Match`` so the resulting type can be parameterized
# from the matched groups (e.g. the declared VARCHAR/CHAR length).
column_type_mappings = (
    (re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
    (re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
    (re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
    (re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
    (re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
    (re.compile(r"^real.*", re.IGNORECASE), types.Float()),
    (re.compile(r"^double.*", re.IGNORECASE), types.Float()),
    (re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
    (
        re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
        # group 2 is the optional declared length, e.g. varchar(255)
        lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
    ),
    (
        re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
        lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
    ),
    (re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
    (re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
    (re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
    (re.compile(r"^time.*", re.IGNORECASE), types.Time()),
    (re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
    (re.compile(r"^interval.*", re.IGNORECASE), Interval()),
    (re.compile(r"^array.*", re.IGNORECASE), Array()),
    (re.compile(r"^map.*", re.IGNORECASE), Map()),
    (re.compile(r"^row.*", re.IGNORECASE), Row()),
)

@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
Expand All @@ -334,28 +369,24 @@ def get_columns(
columns = cls._show_columns(inspector, table_name, schema)
result: List[Dict[str, Any]] = []
for column in columns:
try:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
"array" in column.Type or "row" in column.Type
):
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue

# otherwise column is a basic data type
column_type = presto_type_map[column.Type]()
except KeyError:
logger.info(
"Did not recognize type {} of column {}".format( # pylint: disable=logging-format-interpolation
column.Type, column.Column
)
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
"array" in column.Type or "row" in column.Type
):
structural_column_index = len(result)
cls._parse_structural_column(column.Column, column.Type, result)
result[structural_column_index]["nullable"] = getattr(
column, "Null", True
)
result[structural_column_index]["default"] = None
continue

# otherwise column is a basic data type
column_type = cls.get_sqla_column_type(column.Type)
if column_type is None:
raise NotImplementedError(
_("Unknown column type: %(col)s", col=column_type)
)
column_type = "OTHER"
column_info = cls._create_column_info(column.Column, column_type)
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None
Expand Down
24 changes: 0 additions & 24 deletions superset/models/sql_types/presto_sql_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from typing import Any, Dict, List, Optional, Type

from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.sql.visitors import Visitable
Expand Down Expand Up @@ -92,26 +91,3 @@ def python_type(self) -> Optional[Type[Any]]:
@classmethod
def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
    # Always compile this custom type to the literal type name "ROW",
    # regardless of the visitor or keyword arguments.
    return "ROW"


# Maps a Presto type name string to the corresponding SQLAlchemy type
# class (values are classes, not instances — callers instantiate them).
# Lookup is by exact string key, so parameterized forms such as
# "varchar(255)" will not match.
# NOTE(review): only the "JSON" key is uppercase while all other keys are
# lowercase — verify callers supply the same casing.
type_map = {
    "boolean": types.Boolean,
    "tinyint": TinyInteger,
    "smallint": types.SmallInteger,
    "integer": types.Integer,
    "bigint": types.BigInteger,
    "real": types.Float,
    "double": types.Float,
    "decimal": types.DECIMAL,
    "varchar": types.String,
    "char": types.CHAR,
    "varbinary": types.VARBINARY,
    "JSON": types.JSON,
    "date": types.DATE,
    "time": types.Time,
    "timestamp": types.TIMESTAMP,
    "interval": Interval,
    "array": Array,
    "map": Map,
    "row": Row,
}
21 changes: 21 additions & 0 deletions tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest import mock, skipUnless

import pandas as pd
from sqlalchemy import types
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.sql import select

Expand Down Expand Up @@ -490,3 +491,23 @@ def test_presto_expand_data_array(self):
self.assertEqual(actual_cols, expected_cols)
self.assertEqual(actual_data, expected_data)
self.assertEqual(actual_expanded_cols, expected_expanded_cols)

def test_get_sqla_column_type(self):
    """Verify Presto column type strings resolve to the expected
    SQLAlchemy types, including length extraction for parameterized
    varchar/char declarations."""
    sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
    assert isinstance(sqla_type, types.VARCHAR)
    assert sqla_type.length == 255

    # A bare "varchar" leaves the length undefined.
    sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
    assert isinstance(sqla_type, types.String)
    assert sqla_type.length is None

    sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
    assert isinstance(sqla_type, types.CHAR)
    assert sqla_type.length == 10

    # A bare "char" likewise leaves the length undefined.
    sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
    assert isinstance(sqla_type, types.CHAR)
    assert sqla_type.length is None

    sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
    assert isinstance(sqla_type, types.Integer)