From 9db863b54bca34ca5dbe977490a23054cdb8c0de Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 1 Nov 2021 17:54:56 -0700 Subject: [PATCH] Fix tests --- superset/examples/birth_names.py | 14 ++++++-------- tests/integration_tests/dashboard_utils.py | 4 ++++ .../fixtures/birth_names_dashboard.py | 13 ++++++++++--- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index e1d8aff9221fb..fa9d188040e1c 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -101,18 +101,21 @@ def load_birth_names( only_metadata: bool = False, force: bool = False, sample: bool = False ) -> None: """Loading birth name dataset from a zip file in the repo""" - tbl_name = "birth_names" database = get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name + + tbl_name = "birth_names" table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): load_data(tbl_name, database, sample=sample) table = get_table_connector_registry() - obj = db.session.query(table).filter_by(table_name=tbl_name).first() + obj = db.session.query(table).filter_by(table_name=tbl_name, schema=schema).first() if not obj: print(f"Creating table [{tbl_name}] reference") - obj = table(table_name=tbl_name) + obj = table(table_name=tbl_name, schema=schema) db.session.add(obj) _set_table_metadata(obj, database) @@ -125,13 +128,8 @@ def load_birth_names( def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None: - engine = database.get_sqla_engine() - schema = inspect(engine).default_schema_name - datasource.main_dttm_col = "ds" datasource.database = database - if schema: - datasource.schema = schema datasource.filter_select_enabled = True datasource.fetch_metadata() diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index 85daa0b1b8d09..3d91f6178d73b 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional from pandas import DataFrame +from sqlalchemy import inspect from superset import ConnectorRegistry, db from superset.connectors.sqla.models import SqlaTable @@ -37,6 +38,9 @@ def create_table_for_dashboard( fetch_values_predicate: Optional[str] = None, schema: Optional[str] = None, ) -> SqlaTable: + engine = database.get_sqla_engine() + schema = schema or inspect(engine).default_schema_name + df.to_sql( table_name, database.get_sqla_engine(), diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py index 7d78a22656a2f..bbdfc77da8cee 100644 --- a/tests/integration_tests/fixtures/birth_names_dashboard.py +++ b/tests/integration_tests/fixtures/birth_names_dashboard.py @@ -23,7 +23,7 @@ import pandas as pd import pytest from pandas import DataFrame -from sqlalchemy import DateTime, String, TIMESTAMP +from sqlalchemy import DateTime, inspect, String, TIMESTAMP from superset import ConnectorRegistry, db from superset.connectors.sqla.models import SqlaTable @@ -103,12 +103,19 @@ def _create_table( def _cleanup(dash_id: int, slices_ids: List[int]) -> None: - table_id = db.session.query(SqlaTable).filter_by(table_name="birth_names").one().id + engine = get_example_database().get_sqla_engine() + schema = inspect(engine).default_schema_name + + table_id = ( + db.session.query(SqlaTable) + .filter_by(table_name="birth_names", schema=schema) + .one() + .id + ) datasource = ConnectorRegistry.get_datasource("table", table_id, db.session) columns = [column for column in datasource.columns] metrics = [metric for metric in datasource.metrics] - engine = get_example_database().get_sqla_engine() engine.execute("DROP TABLE IF EXISTS birth_names") for column in columns: db.session.delete(column)