Skip to content

Commit

Permalink
perf: improve perf in SIP-68 migration (#19416)
Browse files Browse the repository at this point in the history
* chore: improve perf in SIP-68 migration

* Small fixes

* Create tables referenced in SQL

* Update logic in SqlaTable as well

* Fix unit tests
  • Loading branch information
betodealmeida authored Mar 30, 2022
1 parent 0968f86 commit 63b5e2e
Show file tree
Hide file tree
Showing 9 changed files with 364 additions and 41 deletions.
5 changes: 4 additions & 1 deletion scripts/benchmark_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
while tables:
table = tables.pop()
seen.add(table)
model = getattr(Base.classes, table)
try:
model = getattr(Base.classes, table)
except AttributeError:
continue
model.__tablename__ = table
models.append(model)

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def get_git_sha() -> str:
"slackclient==2.5.0", # PINNED! slack changes file upload api in the future versions
"sqlalchemy>=1.3.16, <1.4, !=1.3.21",
"sqlalchemy-utils>=0.37.8, <0.38",
"sqloxide==0.1.15",
"sqlparse==0.3.0", # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562
"tabulate==0.8.9",
# needed to support Literal (3.8) and TypeGuard (3.10)
Expand Down
36 changes: 19 additions & 17 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from superset.connectors.sqla.utils import (
get_physical_table_metadata,
get_virtual_table_metadata,
load_or_create_tables,
validate_adhoc_subquery,
)
from superset.datasets.models import Dataset as NewDataset
Expand Down Expand Up @@ -2242,7 +2243,10 @@ def write_shadow_dataset( # pylint: disable=too-many-locals
if column.is_active is False:
continue

extra_json = json.loads(column.extra or "{}")
try:
extra_json = json.loads(column.extra or "{}")
except json.decoder.JSONDecodeError:
extra_json = {}
for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
value = getattr(column, attr)
if value:
Expand All @@ -2269,7 +2273,10 @@ def write_shadow_dataset( # pylint: disable=too-many-locals

# create metrics
for metric in dataset.metrics:
extra_json = json.loads(metric.extra or "{}")
try:
extra_json = json.loads(metric.extra or "{}")
except json.decoder.JSONDecodeError:
extra_json = {}
for attr in {"verbose_name", "metric_type", "d3format"}:
value = getattr(metric, attr)
if value:
Expand Down Expand Up @@ -2300,8 +2307,7 @@ def write_shadow_dataset( # pylint: disable=too-many-locals
)

# physical dataset
tables = []
if dataset.sql is None:
if not dataset.sql:
physical_columns = [column for column in columns if column.is_physical]

# create table
Expand All @@ -2314,7 +2320,7 @@ def write_shadow_dataset( # pylint: disable=too-many-locals
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
)
tables.append(table)
tables = [table]

# virtual dataset
else:
Expand All @@ -2325,18 +2331,14 @@ def write_shadow_dataset( # pylint: disable=too-many-locals
# find referenced tables
parsed = ParsedQuery(dataset.sql)
referenced_tables = parsed.tables

# predicate for finding the referenced tables
predicate = or_(
*[
and_(
NewTable.schema == (table.schema or dataset.schema),
NewTable.name == table.table,
)
for table in referenced_tables
]
tables = load_or_create_tables(
session,
dataset.database_id,
dataset.schema,
referenced_tables,
conditional_quote,
engine,
)
tables = session.query(NewTable).filter(predicate).all()

# create the new dataset
new_dataset = NewDataset(
Expand All @@ -2345,7 +2347,7 @@ def write_shadow_dataset( # pylint: disable=too-many-locals
expression=dataset.sql or conditional_quote(dataset.table_name),
tables=tables,
columns=columns,
is_physical=dataset.sql is None,
is_physical=not dataset.sql,
is_managed_externally=dataset.is_managed_externally,
external_url=dataset.external_url,
)
Expand Down
87 changes: 85 additions & 2 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,34 @@
# specific language governing permissions and limitations
# under the License.
from contextlib import closing
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING

import sqlparse
from flask_babel import lazy_gettext as _
from sqlalchemy import and_, inspect, or_
from sqlalchemy.engine import Engine
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Session
from sqlalchemy.sql.type_api import TypeEngine

from superset.columns.models import Column as NewColumn
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
SupersetGenericDBErrorException,
SupersetSecurityException,
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, ParsedQuery
from superset.sql_parse import has_table_query, ParsedQuery, Table
from superset.tables.models import Table as NewTable

if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable


TEMPORAL_TYPES = {"DATETIME", "DATE", "TIME", "TIMEDELTA"}


def get_physical_table_metadata(
database: Database,
table_name: str,
Expand Down Expand Up @@ -151,3 +159,78 @@ def validate_adhoc_subquery(raw_sql: str) -> None:
)
)
return


def load_or_create_tables( # pylint: disable=too-many-arguments
session: Session,
database_id: int,
default_schema: Optional[str],
tables: Set[Table],
conditional_quote: Callable[[str], str],
engine: Engine,
) -> List[NewTable]:
"""
Load or create new table model instances.
"""
if not tables:
return []

# set the default schema in tables that don't have it
if default_schema:
fixed_tables = list(tables)
for i, table in enumerate(fixed_tables):
if table.schema is None:
fixed_tables[i] = Table(table.table, default_schema, table.catalog)
tables = set(fixed_tables)

# load existing tables
predicate = or_(
*[
and_(
NewTable.database_id == database_id,
NewTable.schema == table.schema,
NewTable.name == table.table,
)
for table in tables
]
)
new_tables = session.query(NewTable).filter(predicate).all()

# add missing tables
existing = {(table.schema, table.name) for table in new_tables}
for table in tables:
if (table.schema, table.table) not in existing:
try:
inspector = inspect(engine)
column_metadata = inspector.get_columns(
table.table, schema=table.schema
)
except Exception: # pylint: disable=broad-except
continue
columns = [
NewColumn(
name=column["name"],
type=str(column["type"]),
expression=conditional_quote(column["name"]),
is_temporal=column["type"].python_type.__name__.upper()
in TEMPORAL_TYPES,
is_aggregation=False,
is_physical=True,
is_spatial=False,
is_partition=False,
is_increase_desired=True,
)
for column in column_metadata
]
new_tables.append(
NewTable(
name=table.table,
schema=table.schema,
catalog=None,
database_id=database_id,
columns=columns,
)
)
existing.add((table.schema, table.table))

return new_tables
66 changes: 66 additions & 0 deletions superset/migrations/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,39 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Iterator, Optional, Set

from alembic import op
from sqlalchemy import engine_from_config
from sqlalchemy.engine import reflection
from sqlalchemy.exc import NoSuchTableError
from sqloxide import parse_sql

from superset.sql_parse import ParsedQuery, Table

logger = logging.getLogger("alembic")


# mapping between sqloxide and SQLAlchemy dialects
sqloxide_dialects = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
"mysql": {"mysql"},
"postgres": {
"cockroachdb",
"hana",
"netezza",
"postgres",
"postgresql",
"redshift",
"vertica",
},
"snowflake": {"snowflake"},
"sqlite": {"sqlite", "gsheets", "shillelagh"},
"clickhouse": {"clickhouse"},
}


def table_has_column(table: str, column: str) -> bool:
Expand All @@ -38,3 +67,40 @@ def table_has_column(table: str, column: str) -> bool:
return any(col["name"] == column for col in insp.get_columns(table))
except NoSuchTableError:
return False


def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
"""
Find all nodes in a SQL tree matching a given key.
"""
if isinstance(element, list):
for child in element:
yield from find_nodes_by_key(child, target)
elif isinstance(element, dict):
for key, value in element.items():
if key == target:
yield value
else:
yield from find_nodes_by_key(value, target)


def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]:
"""
Return all the dependencies from a SQL sql_text.
"""
dialect = "generic"
for dialect, sqla_dialects in sqloxide_dialects.items():
if sqla_dialect in sqla_dialects:
break
try:
tree = parse_sql(sql_text, dialect=dialect)
except Exception: # pylint: disable=broad-except
logger.warning("Unable to parse query with sqloxide: %s", sql_text)
# fallback to sqlparse
parsed = ParsedQuery(sql_text)
return parsed.tables

return {
Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table")
}
Loading

0 comments on commit 63b5e2e

Please sign in to comment.