Commit

fix: set correct schema on config import (#16041)
* fix: set correct schema on config import

* Fix lint

* Fix test

* Fix tests

* Fix another test

* Fix another test

* Fix base test

* Add helper function

* Fix examples

* Fix test

* Fix test

* Fixing more tests

(cherry picked from commit 1fbce88)
betodealmeida authored and eschutho committed Dec 10, 2021
1 parent 3d8ce13 commit 77c4f2c
Showing 30 changed files with 309 additions and 116 deletions.
8 changes: 6 additions & 2 deletions superset/commands/importers/v1/examples.py
@@ -42,7 +42,7 @@
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.dashboard import dashboard_slices
- from superset.utils.core import get_example_database
+ from superset.utils.core import get_example_database, get_example_default_schema


class ImportExamplesCommand(ImportModelsCommand):
@@ -85,7 +85,7 @@ def _get_uuids(cls) -> Set[str]:
)

@staticmethod
- def _import( # pylint: disable=arguments-differ,too-many-locals
+ def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-branches
session: Session,
configs: Dict[str, Any],
overwrite: bool = False,
@@ -114,6 +114,10 @@ def _import( # pylint: disable=arguments-differ,too-many-locals
else:
config["database_id"] = database_ids[config["database_uuid"]]

+ # set schema
+ if config["schema"] is None:
+     config["schema"] = get_example_default_schema()
+
dataset = import_dataset(
session, config, overwrite=overwrite, force_data=force_data
)
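The new `get_example_default_schema` helper (the "Add helper function" item in the commit message) lives in `superset/utils/core.py`, whose diff is not among the excerpts shown here. Judging from how it is called, it is presumably a thin wrapper along these lines (a sketch, not the verbatim source):

```python
from typing import Optional

from sqlalchemy import inspect


def get_example_default_schema() -> Optional[str]:
    """Return the default schema of the examples database, if any."""
    # get_example_database() is the existing helper it is imported alongside.
    database = get_example_database()
    engine = database.get_sqla_engine()
    # E.g. "public" on PostgreSQL, "main" on SQLite, None where the backend
    # has no notion of a default schema.
    return inspect(engine).default_schema_name
```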
2 changes: 1 addition & 1 deletion superset/connectors/sqla/models.py
@@ -1660,7 +1660,7 @@ def before_update(
target: "SqlaTable",
) -> None:
"""
- Check whether before update if the target table already exists.
+ Check before update if the target table already exists.
Note this listener is called when any fields are being updated and thus it is
necessary to first check whether the reference table is being updated.
15 changes: 14 additions & 1 deletion superset/datasets/commands/importers/v1/utils.py
@@ -25,6 +25,7 @@
from flask import current_app, g
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
from sqlalchemy.orm import Session
+ from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy.sql.visitors import VisitableType

from superset.connectors.sqla.models import SqlaTable
@@ -110,7 +111,19 @@ def import_dataset(
data_uri = config.get("data")

# import recursively to include columns and metrics
- dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
+ try:
+     dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
+ except MultipleResultsFound:
+     # Finding multiple results when importing a dataset only happens because initially
+     # datasets were imported without schemas (e.g., `examples.NULL.users`), and later
+     # they were fixed to have the default schema (e.g., `examples.public.users`). If a
+     # user created `examples.public.users` during that time the second import will
+     # fail because the UUID match will try to update `examples.NULL.users` to
+     # `examples.public.users`, resulting in a conflict.
+     #
+     # When that happens, we return the original dataset, unmodified.
+     dataset = session.query(SqlaTable).filter_by(uuid=config["uuid"]).one()

if dataset.id is None:
session.flush()

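The `MultipleResultsFound` fallback is easier to see in isolation. The sketch below uses a toy SQLAlchemy model (a hypothetical `Dataset` standing in for `SqlaTable`, with era-appropriate SQLAlchemy 1.3-style imports): a lookup that does not pin the schema matches both the legacy schema-less row and the user-created one, and only a UUID lookup disambiguates.

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import MultipleResultsFound

Base = declarative_base()


class Dataset(Base):  # hypothetical stand-in for SqlaTable
    __tablename__ = "dataset"
    id = Column(Integer, primary_key=True)
    uuid = Column(String, unique=True)
    table_name = Column(String)
    schema = Column(String, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = Session(bind=engine)

# One row left over from the old schema-less import, one created by a user later.
session.add_all([
    Dataset(uuid="aaa-111", table_name="users", schema=None),
    Dataset(uuid="bbb-222", table_name="users", schema="public"),
])
session.commit()

try:
    # A lookup that ignores the schema matches both rows and raises.
    session.query(Dataset).filter_by(table_name="users").one()
except MultipleResultsFound:
    # The fallback above: fetch exactly the row the config's UUID points at.
    dataset = session.query(Dataset).filter_by(uuid="aaa-111").one()
```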
9 changes: 6 additions & 3 deletions superset/examples/bart_lines.py
@@ -18,7 +18,7 @@

import pandas as pd
import polyline
- from sqlalchemy import String, Text
+ from sqlalchemy import inspect, String, Text

from superset import db
from superset.utils.core import get_example_database
@@ -29,6 +29,8 @@
def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl_name = "bart_lines"
database = get_example_database()
+ engine = database.get_sqla_engine()
+ schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

if not only_metadata and (not table_exists or force):
@@ -40,7 +42,8 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:

df.to_sql(
tbl_name,
- database.get_sqla_engine(),
+ engine,
+ schema=schema,
if_exists="replace",
chunksize=500,
dtype={
@@ -56,7 +59,7 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
- tbl = table(table_name=tbl_name)
+ tbl = table(table_name=tbl_name, schema=schema)
tbl.description = "BART lines"
tbl.database = database
tbl.filter_select_enabled = True
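Every example loader in this commit derives the schema the same way, so the pattern is worth verifying standalone. A minimal sketch (SQLite shown for illustration; the reported name is backend-specific):

```python
from sqlalchemy import create_engine, inspect

engine = create_engine("sqlite://")
# SQLite reports "main"; a PostgreSQL engine would report "public".
print(inspect(engine).default_schema_name)
```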
32 changes: 18 additions & 14 deletions superset/examples/birth_names.py
@@ -20,12 +20,11 @@

import pandas as pd
from flask_appbuilder.security.sqla.models import User
- from sqlalchemy import DateTime, String
+ from sqlalchemy import DateTime, inspect, String
from sqlalchemy.sql import column

from superset import app, db, security_manager
- from superset.connectors.base.models import BaseDatasource
- from superset.connectors.sqla.models import SqlMetric, TableColumn
+ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.exceptions import NoDataException
from superset.models.core import Database
from superset.models.dashboard import Dashboard
@@ -75,9 +74,13 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
pdf = pdf.head(100) if sample else pdf

+ engine = database.get_sqla_engine()
+ schema = inspect(engine).default_schema_name
+
pdf.to_sql(
tbl_name,
database.get_sqla_engine(),
+ schema=schema,
if_exists="replace",
chunksize=500,
dtype={
@@ -98,18 +101,21 @@ def load_birth_names(
only_metadata: bool = False, force: bool = False, sample: bool = False
) -> None:
"""Loading birth name dataset from a zip file in the repo"""
tbl_name = "birth_names"
database = get_example_database()
- table_exists = database.has_table_by_name(tbl_name)
+ engine = database.get_sqla_engine()
+ schema = inspect(engine).default_schema_name
+
+ tbl_name = "birth_names"
+ table_exists = database.has_table_by_name(tbl_name, schema=schema)

if not only_metadata and (not table_exists or force):
load_data(tbl_name, database, sample=sample)

table = get_table_connector_registry()
- obj = db.session.query(table).filter_by(table_name=tbl_name).first()
+ obj = db.session.query(table).filter_by(table_name=tbl_name, schema=schema).first()
if not obj:
print(f"Creating table [{tbl_name}] reference")
- obj = table(table_name=tbl_name)
+ obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)

_set_table_metadata(obj, database)
@@ -121,14 +127,14 @@ def load_birth_names(
create_dashboard(slices)


def _set_table_metadata(datasource: "BaseDatasource", database: "Database") -> None:
datasource.main_dttm_col = "ds" # type: ignore
def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None:
datasource.main_dttm_col = "ds"
datasource.database = database
datasource.filter_select_enabled = True
datasource.fetch_metadata()


def _add_table_metrics(datasource: "BaseDatasource") -> None:
def _add_table_metrics(datasource: SqlaTable) -> None:
if not any(col.column_name == "num_california" for col in datasource.columns):
col_state = str(column("state").compile(db.engine))
col_num = str(column("num").compile(db.engine))
@@ -147,13 +153,11 @@ def _add_table_metrics(datasource: "BaseDatasource") -> None:

for col in datasource.columns:
if col.column_name == "ds":
- col.is_dttm = True # type: ignore
+ col.is_dttm = True
break


- def create_slices(
-     tbl: BaseDatasource, admin_owner: bool
- ) -> Tuple[List[Slice], List[Slice]]:
+ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[Slice]]:
metrics = [
{
"expressionType": "SIMPLE",
9 changes: 6 additions & 3 deletions superset/examples/country_map.py
@@ -17,7 +17,7 @@
import datetime

import pandas as pd
- from sqlalchemy import BigInteger, Date, String
+ from sqlalchemy import BigInteger, Date, inspect, String
from sqlalchemy.sql import column

from superset import db
@@ -38,6 +38,8 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> None:
"""Loading data for map with country map"""
tbl_name = "birth_france_by_region"
database = utils.get_example_database()
+ engine = database.get_sqla_engine()
+ schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

if not only_metadata and (not table_exists or force):
@@ -48,7 +50,8 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> None:
data["dttm"] = datetime.datetime.now().date()
data.to_sql(
tbl_name,
- database.get_sqla_engine(),
+ engine,
+ schema=schema,
if_exists="replace",
chunksize=500,
dtype={
@@ -76,7 +79,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> None:
table = get_table_connector_registry()
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
- obj = table(table_name=tbl_name)
+ obj = table(table_name=tbl_name, schema=schema)
obj.main_dttm_col = "dttm"
obj.database = database
obj.filter_select_enabled = True
9 changes: 6 additions & 3 deletions superset/examples/energy.py
@@ -18,7 +18,7 @@
import textwrap

import pandas as pd
- from sqlalchemy import Float, String
+ from sqlalchemy import Float, inspect, String
from sqlalchemy.sql import column

from superset import db
@@ -40,6 +40,8 @@ def load_energy(
"""Loads an energy related dataset to use with sankey and graphs"""
tbl_name = "energy_usage"
database = utils.get_example_database()
+ engine = database.get_sqla_engine()
+ schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

if not only_metadata and (not table_exists or force):
@@ -48,7 +50,8 @@
pdf = pdf.head(100) if sample else pdf
pdf.to_sql(
tbl_name,
- database.get_sqla_engine(),
+ engine,
+ schema=schema,
if_exists="replace",
chunksize=500,
dtype={"source": String(255), "target": String(255), "value": Float()},
@@ -60,7 +63,7 @@
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
- tbl = table(table_name=tbl_name)
+ tbl = table(table_name=tbl_name, schema=schema)
tbl.description = "Energy consumption"
tbl.database = database
tbl.filter_select_enabled = True
9 changes: 6 additions & 3 deletions superset/examples/flights.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import pandas as pd
- from sqlalchemy import DateTime
+ from sqlalchemy import DateTime, inspect

from superset import db
from superset.utils import core as utils
@@ -27,6 +27,8 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
"""Loading random time series data from a zip file in the repo"""
tbl_name = "flights"
database = utils.get_example_database()
+ engine = database.get_sqla_engine()
+ schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

if not only_metadata and (not table_exists or force):
@@ -47,7 +49,8 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
pdf.to_sql(
tbl_name,
- database.get_sqla_engine(),
+ engine,
+ schema=schema,
if_exists="replace",
chunksize=500,
dtype={"ds": DateTime},
@@ -57,7 +60,7 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
- tbl = table(table_name=tbl_name)
+ tbl = table(table_name=tbl_name, schema=schema)
tbl.description = "Random set of flights in the US"
tbl.database = database
tbl.filter_select_enabled = True
9 changes: 6 additions & 3 deletions superset/examples/long_lat.py
@@ -19,7 +19,7 @@

import geohash
import pandas as pd
- from sqlalchemy import DateTime, Float, String
+ from sqlalchemy import DateTime, Float, inspect, String

from superset import db
from superset.models.slice import Slice
@@ -38,6 +38,8 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None:
"""Loading lat/long data from a csv file in the repo"""
tbl_name = "long_lat"
database = utils.get_example_database()
+ engine = database.get_sqla_engine()
+ schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

if not only_metadata and (not table_exists or force):
@@ -56,7 +58,8 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None:
pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
pdf.to_sql(
tbl_name,
- database.get_sqla_engine(),
+ engine,
+ schema=schema,
if_exists="replace",
chunksize=500,
dtype={
@@ -85,7 +88,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None:
table = get_table_connector_registry()
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
- obj = table(table_name=tbl_name)
+ obj = table(table_name=tbl_name, schema=schema)
obj.main_dttm_col = "datetime"
obj.database = database
obj.filter_select_enabled = True
9 changes: 6 additions & 3 deletions superset/examples/multiformat_time_series.py
@@ -17,7 +17,7 @@
from typing import Dict, Optional, Tuple

import pandas as pd
- from sqlalchemy import BigInteger, Date, DateTime, String
+ from sqlalchemy import BigInteger, Date, DateTime, inspect, String

from superset import app, db
from superset.models.slice import Slice
@@ -38,6 +38,8 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
"""Loading time series data from a zip file in the repo"""
tbl_name = "multiformat_time_series"
database = get_example_database()
+ engine = database.get_sqla_engine()
+ schema = inspect(engine).default_schema_name
table_exists = database.has_table_by_name(tbl_name)

if not only_metadata and (not table_exists or force):
@@ -55,7 +57,8 @@

pdf.to_sql(
tbl_name,
- database.get_sqla_engine(),
+ engine,
+ schema=schema,
if_exists="replace",
chunksize=500,
dtype={
@@ -77,7 +80,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
table = get_table_connector_registry()
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
- obj = table(table_name=tbl_name)
+ obj = table(table_name=tbl_name, schema=schema)
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
(diffs for the remaining 20 changed files not shown)
