diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 6b37fe9d08dcf..b86f7a25c66b1 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -160,6 +160,9 @@ def _add_table_metrics(datasource: SqlaTable) -> None: col.is_dttm = True break + datasource.columns = columns + datasource.metrics = metrics + def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[Slice]]: metrics = [ diff --git a/superset/views/core.py b/superset/views/core.py index 2985350110aa9..0598c72446832 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1979,6 +1979,8 @@ def dashboard( return self.render_template( "superset/spa.html", entry="spa", + # dashboard title is always visible + title=dashboard.dashboard_title, bootstrap_data=json.dumps( bootstrap_data, default=utils.pessimistic_json_iso_dttm_ser ), diff --git a/tests/integration_tests/advanced_data_type/api_tests.py b/tests/integration_tests/advanced_data_type/api_tests.py index 7daa5f0067ddc..5bfe308e1683b 100644 --- a/tests/integration_tests/advanced_data_type/api_tests.py +++ b/tests/integration_tests/advanced_data_type/api_tests.py @@ -18,17 +18,9 @@ """Unit tests for Superset""" import json import prison -from sqlalchemy import null -from superset.connectors.sqla.models import SqlaTable from superset.utils.core import get_example_default_schema -from tests.integration_tests.base_tests import ( - SupersetTestCase, - logged_in_admin, - test_client, -) -from tests.integration_tests.test_app import app from tests.integration_tests.utils.get_dashboards import get_dashboards_ids from unittest import mock from sqlalchemy import Column @@ -80,7 +72,7 @@ def translate_filter_func(col: Column, op: FilterOperator, values: List[Any]): "superset.advanced_data_type.api.ADVANCED_DATA_TYPES", {"type": 1}, ) -def test_types_type_request(logged_in_admin): +def test_types_type_request(test_client, login_as_admin): """ Advanced Data Type API: Test to see if the API call returns all the valid advanced data types """ @@ -91,7 +83,7 @@ def test_types_type_request(logged_in_admin): assert data == {"result": ["type"]} -def test_types_convert_bad_request_no_vals(logged_in_admin): +def test_types_convert_bad_request_no_vals(test_client, login_as_admin): """ Advanced Data Type API: Test request to see if it behaves as expected when no values are passed """ @@ -101,7 +93,7 @@ def test_types_convert_bad_request_no_vals(logged_in_admin): assert response_value.status_code == 400 -def test_types_convert_bad_request_no_type(logged_in_admin): +def test_types_convert_bad_request_no_type(test_client, login_as_admin): """ Advanced Data Type API: Test request to see if it behaves as expected when no type is passed """ @@ -115,7 +107,7 @@ def test_types_convert_bad_request_no_type(logged_in_admin): "superset.advanced_data_type.api.ADVANCED_DATA_TYPES", {"type": 1}, ) -def test_types_convert_bad_request_type_not_found(logged_in_admin): +def test_types_convert_bad_request_type_not_found(test_client, login_as_admin): """ Advanced Data Type API: Test request to see if it behaves as expected when passed in type is not found/not valid @@ -130,7 +122,7 @@ def test_types_convert_bad_request_type_not_found(logged_in_admin): "superset.advanced_data_type.api.ADVANCED_DATA_TYPES", {"type": test_type}, ) -def test_types_convert_request(logged_in_admin): +def test_types_convert_request(test_client, login_as_admin): """ Advanced Data Type API: Test request to see if it behaves as expected when a valid type and valid values are passed in diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 280d58774ab06..ee9eee299ab11 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -24,7 +24,6 @@ from unittest.mock import Mock, patch, MagicMock import pandas as pd -import pytest from flask import Response from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase @@ -34,7 +33,7 @@ from sqlalchemy.sql import func from sqlalchemy.dialects.mysql import dialect -from tests.integration_tests.test_app import app +from tests.integration_tests.test_app import app, login from superset.sql_parse import CtasMethod from superset import db, security_manager from superset.connectors.base.models import BaseDatasource @@ -52,11 +51,6 @@ test_client = app.test_client() -def login(client: Any, username: str = "admin", password: str = "general"): - resp = get_resp(client, "/login/", data=dict(username=username, password=password)) - assert "User confirmation needed" not in resp - - def get_resp( client: Any, url: str, @@ -101,15 +95,6 @@ def post_assert_metric( return rv -@pytest.fixture -def logged_in_admin(): - """Fixture with app context and logged in admin user.""" - with app.app_context(): - login(test_client, username="admin") - yield - test_client.get("/logout/", follow_redirects=True) - - class SupersetTestCase(TestCase): default_schema_backend_map = { "sqlite": "main", diff --git a/tests/integration_tests/cachekeys/api_tests.py b/tests/integration_tests/cachekeys/api_tests.py index e994380e9d998..d3552bfc8df26 100644 --- a/tests/integration_tests/cachekeys/api_tests.py +++ b/tests/integration_tests/cachekeys/api_tests.py @@ -18,7 +18,7 @@ """Unit tests for Superset""" from typing import Dict, Any -from tests.integration_tests.test_app import app # noqa +import pytest from superset.extensions import cache_manager, db from superset.models.cache import CacheKey @@ -26,23 +26,25 @@ from tests.integration_tests.base_tests import ( SupersetTestCase, post_assert_metric, - test_client, - logged_in_admin, -) # noqa +) -def invalidate(params: Dict[str, Any]): - return post_assert_metric( - test_client, "api/v1/cachekey/invalidate", params, "invalidate" - ) +@pytest.fixture +def invalidate(test_client, login_as_admin): + def _invalidate(params: Dict[str, Any]): + return post_assert_metric( + test_client, "api/v1/cachekey/invalidate", params, "invalidate" + ) + + return _invalidate -def test_invalidate_cache(logged_in_admin): +def test_invalidate_cache(invalidate): rv = invalidate({"datasource_uids": ["3__table"]}) assert rv.status_code == 201 -def test_invalidate_existing_cache(logged_in_admin): +def test_invalidate_existing_cache(invalidate): db.session.add(CacheKey(cache_key="cache_key", datasource_uid="3__table")) db.session.commit() cache_manager.cache.set("cache_key", "value") @@ -56,7 +58,7 @@ def test_invalidate_existing_cache(logged_in_admin): ) -def test_invalidate_cache_empty_input(logged_in_admin): +def test_invalidate_cache_empty_input(invalidate): rv = invalidate({"datasource_uids": []}) assert rv.status_code == 201 @@ -67,7 +69,7 @@ def test_invalidate_cache_empty_input(logged_in_admin): assert rv.status_code == 201 -def test_invalidate_cache_bad_request(logged_in_admin): +def test_invalidate_cache_bad_request(invalidate): rv = invalidate( { "datasource_uids": [], @@ -93,7 +95,7 @@ def test_invalidate_cache_bad_request(logged_in_admin): assert rv.status_code == 400 -def test_invalidate_existing_caches(logged_in_admin): +def test_invalidate_existing_caches(invalidate): schema = get_example_default_schema() or "" bn = SupersetTestCase.get_birth_names_dataset() diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index 9a9447640f120..f057d3128e574 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -17,7 +17,6 @@ # isort:skip_file """Unit tests for Superset Celery worker""" import datetime -import json import random import string import time @@ -33,9 +32,6 @@ import flask from flask import current_app -from tests.integration_tests.base_tests import login -from tests.integration_tests.conftest import CTAS_SCHEMA_NAME -from tests.integration_tests.test_app import app from superset import db, sql_lab from superset.common.db_query_status import QueryStatus from superset.result_set import SupersetResultSet @@ -46,6 +42,8 @@ from superset.sql_parse import ParsedQuery, CtasMethod from superset.utils.core import backend from superset.utils.database import get_example_database +from tests.integration_tests.conftest import CTAS_SCHEMA_NAME +from tests.integration_tests.test_app import app CELERY_SLEEP_TIME = 6 QUERY = "SELECT name FROM birth_names LIMIT 1" @@ -63,9 +61,6 @@ ] -test_client = app.test_client() - - def get_query_by_id(id: int): db.session.commit() query = db.session.query(Query).filter_by(id=id).first() @@ -74,10 +69,10 @@ def get_query_by_id(id: int): @pytest.fixture(autouse=True, scope="module") def setup_sqllab(): - + yield + # clean up after all tests are done + # use a new app context with app.app_context(): - yield - db.session.query(Query).delete() db.session.commit() for tbl in TMP_TABLES: @@ -92,11 +87,15 @@ def setup_sqllab(): def run_sql( - sql, cta=False, ctas_method=CtasMethod.TABLE, tmp_table="tmp", async_=False + test_client, + sql, + cta=False, + ctas_method=CtasMethod.TABLE, + tmp_table="tmp", + async_=False, ): - login(test_client, username="admin") db_id = get_example_database().id - resp = test_client.post( + return test_client.post( "/superset/sql_json/", json=dict( database_id=db_id, @@ -107,9 +106,7 @@ def run_sql( client_id="".join(random.choice(string.ascii_lowercase) for i in range(5)), ctas_method=ctas_method, ), - ) - test_client.get("/logout/", follow_redirects=True) - return json.loads(resp.data) + ).json def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None: @@ -144,12 +141,13 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None): return f"SELECT *\nFROM {table}\nLIMIT {limit}" +@pytest.mark.usefixtures("login_as_admin") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_sync_query_dont_exist(setup_sqllab, ctas_method): +def test_run_sync_query_dont_exist(test_client, ctas_method): examples_db = get_example_database() engine_name = examples_db.db_engine_spec.engine_name sql_dont_exist = "SELECT name FROM table_dont_exist" - result = run_sql(sql_dont_exist, cta=True, ctas_method=ctas_method) + result = run_sql(test_client, sql_dont_exist, cta=True, ctas_method=ctas_method) if backend() == "sqlite" and ctas_method == CtasMethod.VIEW: assert QueryStatus.SUCCESS == result["status"], result elif backend() == "presto": @@ -188,27 +186,29 @@ def test_run_sync_query_dont_exist(setup_sqllab, ctas_method): } -@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices", "login_as_admin") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_sync_query_cta(setup_sqllab, ctas_method): +def test_run_sync_query_cta(test_client, ctas_method): tmp_table_name = f"{TEST_SYNC}_{ctas_method.lower()}" - result = run_sql(QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method) + result = run_sql( + test_client, QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method + ) assert QueryStatus.SUCCESS == result["query"]["state"], result assert cta_result(ctas_method) == (result["data"], result["columns"]) # Check the data in the tmp table. select_query = get_query_by_id(result["query"]["serverId"]) - results = run_sql(select_query.select_sql) + results = run_sql(test_client, select_query.select_sql) assert QueryStatus.SUCCESS == results["status"], results assert len(results["data"]) > 0 delete_tmp_view_or_table(tmp_table_name, ctas_method) -@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") -def test_run_sync_query_cta_no_data(setup_sqllab): +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices", "login_as_admin") +def test_run_sync_query_cta_no_data(test_client): sql_empty_result = "SELECT * FROM birth_names WHERE name='random'" - result = run_sql(sql_empty_result) + result = run_sql(test_client, sql_empty_result) assert QueryStatus.SUCCESS == result["query"]["state"] assert ([], []) == (result["data"], result["columns"]) @@ -216,18 +216,20 @@ def test_run_sync_query_cta_no_data(setup_sqllab): assert QueryStatus.SUCCESS == query.status -@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices", "login_as_admin") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( "superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) -def test_run_sync_query_cta_config(setup_sqllab, ctas_method): +def test_run_sync_query_cta_config(test_client, ctas_method): if backend() == "sqlite": # sqlite doesn't support schemas return tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.lower()}" - result = run_sql(QUERY, cta=True, ctas_method=ctas_method, tmp_table=tmp_table_name) + result = run_sql( + test_client, QUERY, cta=True, ctas_method=ctas_method, tmp_table=tmp_table_name + ) assert QueryStatus.SUCCESS == result["query"]["state"], result assert cta_result(ctas_method) == (result["data"], result["columns"]) @@ -239,24 +241,25 @@ def test_run_sync_query_cta_config(setup_sqllab, ctas_method): assert query.select_sql == get_select_star( tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME ) - results = run_sql(query.select_sql) + results = run_sql(test_client, query.select_sql) assert QueryStatus.SUCCESS == results["status"], result delete_tmp_view_or_table(f"{CTAS_SCHEMA_NAME}.{tmp_table_name}", ctas_method) -@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices", "login_as_admin") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) @mock.patch( "superset.sqllab.sqllab_execution_context.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME, ) -def test_run_async_query_cta_config(setup_sqllab, ctas_method): +def test_run_async_query_cta_config(test_client, ctas_method): if backend() in {"sqlite", "mysql"}: # sqlite doesn't support schemas, mysql is flaky return tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.lower()}" result = run_sql( + test_client, QUERY, cta=True, ctas_method=ctas_method, @@ -279,16 +282,21 @@ def test_run_async_query_cta_config(setup_sqllab, ctas_method): delete_tmp_view_or_table(f"{CTAS_SCHEMA_NAME}.{tmp_table_name}", ctas_method) -@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices", "login_as_admin") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_async_cta_query(setup_sqllab, ctas_method): +def test_run_async_cta_query(test_client, ctas_method): if backend() == "mysql": # failing return table_name = f"{TEST_ASYNC_CTA}_{ctas_method.lower()}" result = run_sql( - QUERY, cta=True, ctas_method=ctas_method, async_=True, tmp_table=table_name + test_client, + QUERY, + cta=True, + ctas_method=ctas_method, + async_=True, + tmp_table=table_name, ) query = wait_for_success(result) @@ -305,16 +313,21 @@ def test_run_async_cta_query(setup_sqllab, ctas_method): delete_tmp_view_or_table(table_name, ctas_method) -@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices", "login_as_admin") @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW]) -def test_run_async_cta_query_with_lower_limit(setup_sqllab, ctas_method): +def test_run_async_cta_query_with_lower_limit(test_client, ctas_method): if backend() == "mysql": # failing return tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.lower()}" result = run_sql( - QUERY, cta=True, ctas_method=ctas_method, async_=True, tmp_table=tmp_table + test_client, + QUERY, + cta=True, + ctas_method=ctas_method, + async_=True, + tmp_table=tmp_table, ) query = wait_for_success(result) assert QueryStatus.SUCCESS == query.status diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index a37acf6eafc3a..b935579a00416 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -933,15 +933,27 @@ def test_admin_gets_filtered_energy_slices(self): } ], "keys": ["none"], - "columns": ["slice_name"], + "columns": ["slice_name", "description", "table.table_name"], } self.login(username="admin") uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) - data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 8) + data = rv.json + assert rv.status_code == 200 + assert data["count"] > 0 + for chart in data["result"]: + print(chart) + assert ( + "energy" + in " ".join( + [ + chart["slice_name"] or "", + chart["description"] or "", + chart["table"]["table_name"] or "", + ] + ).lower() + ) @pytest.mark.usefixtures("create_certified_charts") def test_gets_certified_charts_filter(self): diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index fee13c8950aba..ea46039d8412e 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -17,19 +17,22 @@ from __future__ import annotations import functools -from typing import Any, Callable, Generator, Optional, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING from unittest.mock import patch import pytest +from flask.ctx import AppContext from sqlalchemy.engine import Engine from superset import db from superset.extensions import feature_flag_manager from superset.utils.core import json_dumps_w_dates from superset.utils.database import get_example_database, remove_database -from tests.integration_tests.test_app import app +from tests.integration_tests.test_app import app, login if TYPE_CHECKING: + from flask.testing import FlaskClient + from superset.connectors.sqla.models import Database CTAS_SCHEMA_NAME = "sqllab_test_db" @@ -38,8 +41,31 @@ @pytest.fixture def app_context(): - with app.app_context(): - yield + with app.app_context() as ctx: + yield ctx + + +@pytest.fixture +def test_client(app_context: AppContext): + with app.test_client() as client: + yield client + + +@pytest.fixture +def login_as(test_client: "FlaskClient[Any]"): + """Fixture with app context and logged in admin user.""" + + def _login_as(username: str, password: str = "general"): + login(test_client, username=username, password=password) + + yield _login_as + # no need to log out as both app_context and test_client are + # function level fixtures anyway + + +@pytest.fixture +def login_as_admin(login_as: Callable[..., None]): + yield login_as("admin") @pytest.fixture(autouse=True, scope="session") diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 276231f1ff550..86f6df7b15f90 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -83,11 +83,17 @@ logger = logging.getLogger(__name__) +@pytest.fixture(scope="module") +def cleanup(): + db.session.query(Query).delete() + db.session.query(DatasourceAccessRequest).delete() + db.session.query(models.Log).delete() + db.session.commit() + yield + + class TestCore(SupersetTestCase): def setUp(self): - db.session.query(Query).delete() - db.session.query(DatasourceAccessRequest).delete() - db.session.query(models.Log).delete() self.table_ids = { tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all()) } diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index c9bc11db98557..6cd228b9cb1d1 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -29,13 +29,12 @@ import superset.utils.database from superset.sql_parse import Table -from superset import security_manager -from tests.integration_tests.conftest import ADMIN_SCHEMA_NAME -from tests.integration_tests.test_app import app # isort:skip from superset import db from superset.models.core import Database from superset.utils import core as utils -from tests.integration_tests.base_tests import get_resp, login, SupersetTestCase +from tests.integration_tests.test_app import app, login +from tests.integration_tests.base_tests import get_resp + logger = logging.getLogger(__name__) @@ -59,30 +58,27 @@ @pytest.fixture(scope="module") -def setup_csv_upload(): - with app.app_context(): - login(test_client, username="admin") - - upload_db = superset.utils.database.get_or_create_db( - CSV_UPLOAD_DATABASE, app.config["SQLALCHEMY_EXAMPLES_URI"] - ) - extra = upload_db.get_extra() - extra["explore_database_id"] = superset.utils.database.get_example_database().id - upload_db.extra = json.dumps(extra) - upload_db.allow_file_upload = True - db.session.commit() - - yield - - upload_db = get_upload_db() - engine = upload_db.get_sqla_engine() - engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}") - engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}") - engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}") - engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}") - engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}") - db.session.delete(upload_db) - db.session.commit() +def setup_csv_upload(login_as_admin): + upload_db = superset.utils.database.get_or_create_db( + CSV_UPLOAD_DATABASE, app.config["SQLALCHEMY_EXAMPLES_URI"] + ) + extra = upload_db.get_extra() + extra["explore_database_id"] = superset.utils.database.get_example_database().id + upload_db.extra = json.dumps(extra) + upload_db.allow_file_upload = True + db.session.commit() + + yield + + upload_db = get_upload_db() + engine = upload_db.get_sqla_engine() + engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}") + engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}") + engine.execute(f"DROP TABLE IF EXISTS {PARQUET_UPLOAD_TABLE}") + engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}") + engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_EXPLORE}") + db.session.delete(upload_db) + db.session.commit() @pytest.fixture(scope="module") diff --git a/tests/integration_tests/dashboard_tests.py b/tests/integration_tests/dashboard_tests.py index 3ad9b07e29c14..3432b0fc16d88 100644 --- a/tests/integration_tests/dashboard_tests.py +++ b/tests/integration_tests/dashboard_tests.py @@ -23,7 +23,7 @@ from random import random import pytest -from flask import escape, url_for +from flask import Response, escape, url_for from sqlalchemy import func from tests.integration_tests.test_app import app @@ -125,13 +125,12 @@ def get_mock_positions(self, dash): positions[id] = d return positions - def test_dashboard(self): + def test_get_dashboard(self): self.login(username="admin") - urls = {} - for dash in db.session.query(Dashboard).all(): - urls[dash.dashboard_title] = dash.url - for title, url in urls.items(): - assert escape(title) in self.client.get(url).data.decode("utf-8") + for dash in db.session.query(Dashboard): + assert escape(dash.dashboard_title) in self.client.get(dash.url).get_data( + as_text=True + ) def test_superset_dashboard_url(self): url_for("Superset.dashboard", dashboard_id_or_slug=1) @@ -424,8 +423,9 @@ def test_public_user_dashboard_access(self): # Cleanup self.revoke_public_access_to_table(table) - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @pytest.mark.usefixtures("public_role_like_gamma") + @pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", "public_role_like_gamma" + ) def test_dashboard_with_created_by_can_be_accessed_by_public_users(self): self.logout() table = db.session.query(SqlaTable).filter_by(table_name="birth_names").one() @@ -437,8 +437,9 @@ def test_dashboard_with_created_by_can_be_accessed_by_public_users(self): db.session.merge(dash) db.session.commit() - # this asserts a non-4xx response - self.get_resp("/superset/dashboard/births/") + res: Response = self.client.get("/superset/dashboard/births/") + assert res.status_code == 200 + # Cleanup self.revoke_public_access_to_table(table) diff --git a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py index fcd2923fb8696..b5d1919dd430a 100644 --- a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py @@ -16,7 +16,9 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, Dict + +from flask.testing import FlaskClient from superset.dashboards.filter_sets.consts import ( DASHBOARD_OWNER_TYPE, @@ -27,7 +29,6 @@ OWNER_TYPE_FIELD, USER_OWNER_TYPE, ) -from tests.integration_tests.base_tests import login from tests.integration_tests.dashboards.filter_sets.consts import ( ADMIN_USERNAME_FOR_TEST, DASHBOARD_OWNER_USERNAME, @@ -38,9 +39,7 @@ get_filter_set_by_dashboard_id, get_filter_set_by_name, ) - -if TYPE_CHECKING: - from flask.testing import FlaskClient +from tests.integration_tests.test_app import login def assert_filterset_was_not_created(filter_set_data: Dict[str, Any]) -> None: diff --git a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py index 8e7e0bcb6004e..7011cb5781282 100644 --- a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py @@ -18,7 +18,6 @@ from typing import Any, Dict, List, TYPE_CHECKING -from tests.integration_tests.base_tests import login from tests.integration_tests.dashboards.filter_sets.consts import ( DASHBOARD_OWNER_USERNAME, FILTER_SET_OWNER_USERNAME, @@ -29,6 +28,7 @@ collect_all_ids, get_filter_set_by_name, ) +from tests.integration_tests.test_app import login if TYPE_CHECKING: from flask.testing import FlaskClient diff --git a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py index 7be6f367dd6b9..ad40d0e33c859 100644 --- a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py @@ -18,7 +18,6 @@ from typing import Any, Dict, List, Set, TYPE_CHECKING -from tests.integration_tests.base_tests import login from tests.integration_tests.dashboards.filter_sets.consts import ( DASHBOARD_OWNER_USERNAME, FILTER_SET_OWNER_USERNAME, @@ -28,6 +27,7 @@ call_get_filter_sets, collect_all_ids, ) +from tests.integration_tests.test_app import login if TYPE_CHECKING: from flask.testing import FlaskClient diff --git a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py index 4096e100994f8..07db98f617815 100644 --- a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py @@ -26,7 +26,6 @@ OWNER_TYPE_FIELD, PARAMS_PROPERTY, ) -from tests.integration_tests.base_tests import login from tests.integration_tests.dashboards.filter_sets.consts import ( DASHBOARD_OWNER_USERNAME, FILTER_SET_OWNER_USERNAME, @@ -37,6 +36,7 @@ collect_all_ids, get_filter_set_by_name, ) +from tests.integration_tests.test_app import login if TYPE_CHECKING: from flask.testing import FlaskClient diff --git a/tests/integration_tests/dashboards/filter_state/api_tests.py b/tests/integration_tests/dashboards/filter_state/api_tests.py index ea00f2e6714f5..1df752b230d40 100644 --- a/tests/integration_tests/dashboards/filter_state/api_tests.py +++ b/tests/integration_tests/dashboards/filter_state/api_tests.py @@ -18,6 +18,7 @@ from unittest.mock import patch import pytest +from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User from sqlalchemy.orm import Session @@ -26,8 +27,6 @@ from superset.models.dashboard import Dashboard from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.utils import cache_key -from tests.integration_tests.base_tests import login -from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, load_world_bank_data, @@ -40,19 +39,17 @@ @pytest.fixture -def dashboard_id(load_world_bank_dashboard_with_slices) -> int: - with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - dashboard = session.query(Dashboard).filter_by(slug="world_health").one() - return dashboard.id +def dashboard_id(app_context: AppContext, load_world_bank_dashboard_with_slices) -> int: + session: Session = app_context.app.appbuilder.get_session + dashboard = session.query(Dashboard).filter_by(slug="world_health").one() + return dashboard.id @pytest.fixture -def admin_id() -> int: - with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - admin = session.query(User).filter_by(username="admin").one_or_none() - return admin.id +def admin_id(app_context: AppContext) -> int: + session: Session = app_context.app.appbuilder.get_session + admin = session.query(User).filter_by(username="admin").one_or_none() + return admin.id @pytest.fixture(autouse=True) @@ -61,55 +58,62 @@ def cache(dashboard_id, admin_id): cache_manager.filter_state_cache.set(cache_key(dashboard_id, KEY), entry) -def test_post(client, dashboard_id: int): - login(client, "admin") - payload = { - "value": INITIAL_VALUE, - } - resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) +def test_post(test_client, login_as_admin, dashboard_id: int): + resp = test_client.post( + f"api/v1/dashboard/{dashboard_id}/filter_state", + json={ + "value": INITIAL_VALUE, + }, + ) assert resp.status_code == 201 -def test_post_bad_request_non_string(client, dashboard_id: int): - login(client, "admin") - payload = { - "value": 1234, - } - resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) +def test_post_bad_request_non_string(test_client, login_as_admin, dashboard_id: int): + resp = test_client.post( + f"api/v1/dashboard/{dashboard_id}/filter_state", + json={ + "value": 1234, + }, + ) assert resp.status_code == 400 -def test_post_bad_request_non_json_string(client, dashboard_id: int): - login(client, "admin") +def test_post_bad_request_non_json_string( + test_client, login_as_admin, dashboard_id: int +): payload = { "value": "foo", } - resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) + resp = test_client.post( + f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload + ) assert resp.status_code == 400 @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") -def test_post_access_denied(mock_raise_for_dashboard_access, client, dashboard_id: int): - login(client, "admin") +def test_post_access_denied( + mock_raise_for_dashboard_access, test_client, login_as_admin, dashboard_id: int +): mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() payload = { "value": INITIAL_VALUE, } - resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) + resp = test_client.post( + f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload + ) assert resp.status_code == 403 -def test_post_same_key_for_same_tab_id(client, dashboard_id: int): - login(client, "admin") +def test_post_same_key_for_same_tab_id(test_client, login_as_admin, dashboard_id: int): payload = { "value": INITIAL_VALUE, } - resp = client.post( + resp = test_client.post( f"api/v1/dashboard/{dashboard_id}/filter_state?tab_id=1", json=payload ) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.post( + resp = test_client.post( f"api/v1/dashboard/{dashboard_id}/filter_state?tab_id=1", json=payload ) data = json.loads(resp.data.decode("utf-8")) @@ -117,17 +121,18 @@ def test_post_same_key_for_same_tab_id(client, dashboard_id: int): assert first_key == second_key -def test_post_different_key_for_different_tab_id(client, dashboard_id: int): - login(client, "admin") +def test_post_different_key_for_different_tab_id( + test_client, login_as_admin, dashboard_id: int +): payload = { "value": INITIAL_VALUE, } - resp = client.post( + resp = test_client.post( f"api/v1/dashboard/{dashboard_id}/filter_state?tab_id=1", json=payload ) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.post( + resp = test_client.post( f"api/v1/dashboard/{dashboard_id}/filter_state?tab_id=2", json=payload ) data = json.loads(resp.data.decode("utf-8")) @@ -135,42 +140,45 @@ def test_post_different_key_for_different_tab_id(client, dashboard_id: int): assert first_key != second_key -def test_post_different_key_for_no_tab_id(client, dashboard_id: int): - login(client, "admin") +def test_post_different_key_for_no_tab_id( + test_client, login_as_admin, dashboard_id: int +): payload = { "value": INITIAL_VALUE, } - resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) + resp = test_client.post( + f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload + ) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload) + resp = test_client.post( + f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload + ) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key != second_key -def test_put(client, dashboard_id: int): - login(client, "admin") - payload = { - "value": UPDATED_VALUE, - } - resp = client.put( - f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload +def test_put(test_client, login_as_admin, dashboard_id: int): + resp = test_client.put( + f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", + json={ + "value": UPDATED_VALUE, + }, ) assert resp.status_code == 200 -def test_put_same_key_for_same_tab_id(client, dashboard_id: int): - login(client, "admin") +def test_put_same_key_for_same_tab_id(test_client, login_as_admin, dashboard_id: int): payload = { "value": INITIAL_VALUE, } - resp = client.put( + resp = test_client.put( f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}?tab_id=1", json=payload ) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.put( + resp = test_client.put( f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}?tab_id=1", json=payload ) data = json.loads(resp.data.decode("utf-8")) @@ -178,17 +186,18 @@ def test_put_same_key_for_same_tab_id(client, dashboard_id: int): assert first_key == second_key -def test_put_different_key_for_different_tab_id(client, dashboard_id: int): - login(client, "admin") +def test_put_different_key_for_different_tab_id( + test_client, login_as_admin, dashboard_id: int +): payload = { "value": INITIAL_VALUE, } - resp = client.put( + resp = test_client.put( f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}?tab_id=1", json=payload ) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.put( + resp = test_client.put( f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}?tab_id=2", json=payload ) data = json.loads(resp.data.decode("utf-8")) @@ -196,17 +205,18 @@ def test_put_different_key_for_different_tab_id(client, dashboard_id: int): assert first_key != second_key -def test_put_different_key_for_no_tab_id(client, dashboard_id: int): - login(client, "admin") +def test_put_different_key_for_no_tab_id( + test_client, login_as_admin, dashboard_id: int +): payload = { "value": INITIAL_VALUE, } - resp = client.put( + resp = test_client.put( f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload ) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.put( + resp = test_client.put( f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload ) data = json.loads(resp.data.decode("utf-8")) @@ -214,97 +224,94 @@ def test_put_different_key_for_no_tab_id(client, dashboard_id: int): assert first_key != second_key -def test_put_bad_request_non_string(client, dashboard_id: int): - login(client, "admin") - payload = { - "value": 1234, - } - resp = client.put( - f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload +def test_put_bad_request_non_string(test_client, login_as_admin, dashboard_id: int): + resp = test_client.put( + f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", + json={ + "value": 1234, + }, ) assert resp.status_code == 400 -def test_put_bad_request_non_json_string(client, dashboard_id: int): - login(client, "admin") - payload = { - "value": "foo", - } - resp = client.put( - f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload +def test_put_bad_request_non_json_string( + test_client, login_as_admin, dashboard_id: int +): + resp = test_client.put( + f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", + json={ + "value": "foo", + }, ) assert resp.status_code == 400 @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") -def test_put_access_denied(mock_raise_for_dashboard_access, client, dashboard_id: int): - login(client, "admin") +def test_put_access_denied( + mock_raise_for_dashboard_access, test_client, login_as_admin, dashboard_id: int +): mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() - payload = { - "value": UPDATED_VALUE, - } - resp = client.put( - f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload + resp = test_client.put( + f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", + json={ + "value": UPDATED_VALUE, + }, ) assert resp.status_code == 403 -def test_put_not_owner(client, dashboard_id: int): - login(client, "gamma") - payload = { - "value": UPDATED_VALUE, - } - resp = client.put( - f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", json=payload +def test_put_not_owner(test_client, login_as, dashboard_id: int): + login_as("gamma") + resp = test_client.put( + f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}", + json={ + "value": UPDATED_VALUE, + }, ) assert resp.status_code == 403 -def test_get_key_not_found(client, dashboard_id: int): - login(client, "admin") - resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/unknown-key/") +def test_get_key_not_found(test_client, login_as_admin, dashboard_id: int): + resp = test_client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/unknown-key/") assert resp.status_code == 404 -def test_get_dashboard_not_found(client): - login(client, "admin") - resp = client.get(f"api/v1/dashboard/{-1}/filter_state/{KEY}") +def test_get_dashboard_not_found(test_client, login_as_admin): + resp = test_client.get(f"api/v1/dashboard/{-1}/filter_state/{KEY}") assert resp.status_code == 404 -def test_get(client, dashboard_id: int): - login(client, "admin") - resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") +def test_get_dashboard_filter_state(test_client, login_as_admin, dashboard_id: int): + resp = test_client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") assert resp.status_code == 200 data = json.loads(resp.data.decode("utf-8")) assert INITIAL_VALUE == data.get("value") @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") -def test_get_access_denied(mock_raise_for_dashboard_access, client, dashboard_id): - login(client, "admin") +def test_get_access_denied( + mock_raise_for_dashboard_access, test_client, login_as_admin, dashboard_id +): mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() - resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") + resp = test_client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") assert resp.status_code == 403 -def test_delete(client, dashboard_id: int): - login(client, "admin") - resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") +def test_delete(test_client, login_as_admin, dashboard_id: int): + resp = test_client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") assert resp.status_code == 200 @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") def test_delete_access_denied( - mock_raise_for_dashboard_access, client, dashboard_id: int + mock_raise_for_dashboard_access, test_client, login_as_admin, dashboard_id: int ): - login(client, "admin") mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() - resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") + resp = test_client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") assert resp.status_code == 403 -def test_delete_not_owner(client, dashboard_id: int): - login(client, "gamma") - resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") +def test_delete_not_owner(test_client, login_as, dashboard_id: int): + login_as("gamma") + resp = test_client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{KEY}") assert resp.status_code == 403 diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py index 036b42857ace0..018e06cd49ac8 100644 --- a/tests/integration_tests/dashboards/permalink/api_tests.py +++ b/tests/integration_tests/dashboards/permalink/api_tests.py @@ -29,8 +29,6 @@ from superset.key_value.types import KeyValueResource from superset.key_value.utils import decode_permalink_id from superset.models.dashboard import Dashboard -from tests.integration_tests.base_tests import login -from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, load_world_bank_data, @@ -67,9 +65,10 @@ def permalink_salt() -> Iterator[str]: db.session.commit() -def test_post(client, dashboard_id: int, permalink_salt: str) -> None: - login(client, "admin") - resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE) +def test_post( + test_client, login_as_admin, dashboard_id: int, permalink_salt: str +) -> None: + resp = test_client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE) assert resp.status_code == 201 data = resp.json key = data["key"] @@ -79,7 +78,9 @@ def test_post(client, dashboard_id: int, permalink_salt: str) -> None: assert ( data - == client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE).json + == test_client.post( + f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE + ).json ), "Should always return the same permalink key for the same payload" db.session.query(KeyValueEntry).filter_by(id=id_).delete() @@ -87,27 +88,26 @@ def test_post(client, dashboard_id: int, permalink_salt: str) -> None: @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") -def test_post_access_denied(mock_raise_for_dashboard_access, client, dashboard_id: int): - login(client, "admin") +def test_post_access_denied( + mock_raise_for_dashboard_access, test_client, login_as_admin, dashboard_id: int +): mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() - resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE) + resp = test_client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE) assert resp.status_code == 403 -def test_post_invalid_schema(client, dashboard_id: int): - login(client, "admin") - resp = client.post( +def test_post_invalid_schema(test_client, login_as_admin, dashboard_id: int): + resp = test_client.post( f"api/v1/dashboard/{dashboard_id}/permalink", json={"foo": "bar"} ) assert resp.status_code == 400 -def test_get(client, dashboard_id: int, permalink_salt: str): - login(client, "admin") - key = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE).json[ - "key" - ] - resp = client.get(f"api/v1/dashboard/permalink/{key}") +def test_get(test_client, login_as_admin, dashboard_id: int, permalink_salt: str): + key = test_client.post( + f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE + ).json["key"] + resp = test_client.get(f"api/v1/dashboard/permalink/{key}") assert resp.status_code == 200 result = resp.json assert result["dashboardId"] == str(dashboard_id) diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index c4432e3ad1f93..07f9bfcf318dc 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -293,8 +293,8 @@ def test_calculated_column_in_order_by_base_engine_spec(self): table=table, expression=""" case - when gender=true then "male" - else "female" + when gender='boy' then 'male' + else 'female' end """, ) @@ -309,8 +309,8 @@ def test_calculated_column_in_order_by_base_engine_spec(self): sql = table.get_query_str(query_obj) assert ( """ORDER BY case - when gender=true then "male" - else "female" + when gender='boy' then 'male' + else 'female' end ASC;""" in sql ) diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 549b1109529e9..2241f74f00210 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -362,8 +362,8 @@ def test_calculated_column_in_order_by(self): table=table, expression=""" case - when gender=true then "male" - else "female" + when gender='boy' then 'male' + else 'female' end """, ) diff --git a/tests/integration_tests/explore/api_tests.py b/tests/integration_tests/explore/api_tests.py index abc85737b50b0..8fb642286a3d5 100644 --- a/tests/integration_tests/explore/api_tests.py +++ b/tests/integration_tests/explore/api_tests.py @@ -26,8 +26,6 @@ from superset.explore.form_data.commands.state import TemporaryExploreState from superset.extensions import cache_manager from superset.models.slice import Slice -from tests.integration_tests.base_tests import login -from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, load_world_bank_data, @@ -101,9 +99,8 @@ def assert_slice(result, chart_id, dataset_id): assert slice["form_data"]["viz_type"] == "big_number" -def test_no_params_provided(client): - login(client, "admin") - resp = client.get(f"api/v1/explore/") +def test_no_params_provided(test_client, login_as_admin): + resp = test_client.get(f"api/v1/explore/") assert resp.status_code == 200 data = json.loads(resp.data.decode("utf-8")) result = data.get("result") @@ -113,9 +110,8 @@ def test_no_params_provided(client): assert result["slice"] == None -def test_get_from_cache(client, dataset): - login(client, "admin") - resp = client.get( +def test_get_from_cache(test_client, login_as_admin, dataset): + resp = test_client.get( f"api/v1/explore/?form_data_key={FORM_DATA_KEY}&dataset_id={dataset.id}&dataset_type={dataset.type}" ) assert resp.status_code == 200 @@ -128,10 +124,11 @@ def test_get_from_cache(client, dataset): assert result["slice"] == None -def test_get_from_cache_unknown_key_chart_id(client, chart_id, dataset): - login(client, "admin") +def test_get_from_cache_unknown_key_chart_id( + test_client, login_as_admin, chart_id, dataset +): unknown_key = "unknown_key" - resp = client.get( + resp = test_client.get( f"api/v1/explore/?form_data_key={unknown_key}&slice_id={chart_id}" ) assert resp.status_code == 200 @@ -146,10 +143,9 @@ def test_get_from_cache_unknown_key_chart_id(client, chart_id, dataset): ) -def test_get_from_cache_unknown_key_dataset(client, dataset): - login(client, "admin") +def test_get_from_cache_unknown_key_dataset(test_client, login_as_admin, dataset): unknown_key = "unknown_key" - resp = client.get( + resp = test_client.get( f"api/v1/explore/?form_data_key={unknown_key}&dataset_id={dataset.id}&dataset_type={dataset.type}" ) assert resp.status_code == 200 @@ -164,10 +160,9 @@ def test_get_from_cache_unknown_key_dataset(client, dataset): assert result["slice"] == None -def test_get_from_cache_unknown_key_no_extra_parameters(client): - login(client, "admin") +def test_get_from_cache_unknown_key_no_extra_parameters(test_client, login_as_admin): unknown_key = "unknown_key" - resp = client.get(f"api/v1/explore/?form_data_key={unknown_key}") + resp = test_client.get(f"api/v1/explore/?form_data_key={unknown_key}") assert resp.status_code == 200 data = json.loads(resp.data.decode("utf-8")) result = data.get("result") @@ -177,17 +172,16 @@ def test_get_from_cache_unknown_key_no_extra_parameters(client): assert result["slice"] == None -def test_get_from_permalink(client, chart_id, dataset): - login(client, "admin") +def test_get_from_permalink(test_client, login_as_admin, chart_id, dataset): form_data = { "chart_id": chart_id, "datasource": f"{dataset.id}__{dataset.type}", **FORM_DATA, } - resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data}) + resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) data = json.loads(resp.data.decode("utf-8")) permalink_key = data["key"] - resp = client.get(f"api/v1/explore/?permalink_key={permalink_key}") + resp = test_client.get(f"api/v1/explore/?permalink_key={permalink_key}") assert resp.status_code == 200 data = json.loads(resp.data.decode("utf-8")) result = data.get("result") @@ -198,21 +192,21 @@ def test_get_from_permalink(client, chart_id, dataset): assert result["slice"] == None -def test_get_from_permalink_unknown_key(client): - login(client, "admin") +def test_get_from_permalink_unknown_key(test_client, login_as_admin): unknown_key = "unknown_key" - resp = client.get(f"api/v1/explore/?permalink_key={unknown_key}") + resp = test_client.get(f"api/v1/explore/?permalink_key={unknown_key}") assert resp.status_code == 404 @patch("superset.security.SupersetSecurityManager.can_access_datasource") -def test_get_dataset_access_denied(mock_can_access_datasource, client, dataset): +def test_get_dataset_access_denied( + mock_can_access_datasource, test_client, login_as_admin, dataset +): message = "Dataset access denied" mock_can_access_datasource.side_effect = DatasetAccessDeniedError( message=message, dataset_id=dataset.id, dataset_type=dataset.type ) - login(client, "admin") - resp = client.get( + resp = test_client.get( f"api/v1/explore/?form_data_key={FORM_DATA_KEY}&dataset_id={dataset.id}&dataset_type={dataset.type}" ) data = json.loads(resp.data.decode("utf-8")) @@ -223,11 +217,10 @@ def test_get_dataset_access_denied(mock_can_access_datasource, client, dataset): @patch("superset.datasource.dao.DatasourceDAO.get_datasource") -def test_wrong_endpoint(mock_get_datasource, client, dataset): +def test_wrong_endpoint(mock_get_datasource, test_client, login_as_admin, dataset): dataset.default_endpoint = "another_endpoint" mock_get_datasource.return_value = dataset - login(client, "admin") - resp = client.get( + resp = test_client.get( f"api/v1/explore/?dataset_id={dataset.id}&dataset_type={dataset.type}" ) data = json.loads(resp.data.decode("utf-8")) diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py index dae713ff7041b..fe8425e282dfb 100644 --- a/tests/integration_tests/explore/form_data/api_tests.py +++ b/tests/integration_tests/explore/form_data/api_tests.py @@ -27,8 +27,6 @@ from superset.extensions import cache_manager from superset.models.slice import Slice from superset.utils.core import DatasourceType -from tests.integration_tests.base_tests import login -from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, load_world_bank_data, @@ -80,82 +78,85 @@ def cache(chart_id, admin_id, datasource): cache_manager.explore_form_data_cache.set(KEY, entry) -def test_post(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_post(test_client, login_as_admin, chart_id: int, datasource: SqlaTable): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } - resp = client.post("api/v1/explore/form_data", json=payload) + resp = test_client.post("api/v1/explore/form_data", json=payload) assert resp.status_code == 201 -def test_post_bad_request_non_string(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_post_bad_request_non_string( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } - resp = client.post("api/v1/explore/form_data", json=payload) + resp = test_client.post("api/v1/explore/form_data", json=payload) assert resp.status_code == 400 -def test_post_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_post_bad_request_non_json_string( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": "foo", } - resp = client.post("api/v1/explore/form_data", json=payload) + resp = test_client.post("api/v1/explore/form_data", json=payload) assert resp.status_code == 400 -def test_post_access_denied(client, chart_id: int, datasource: SqlaTable): - login(client, "gamma") +def test_post_access_denied( + test_client, login_as, chart_id: int, datasource: SqlaTable +): + login_as("gamma") payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } - resp = client.post("api/v1/explore/form_data", json=payload) + resp = test_client.post("api/v1/explore/form_data", json=payload) assert resp.status_code == 404 -def test_post_same_key_for_same_context(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_post_same_key_for_same_context( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } - resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) + resp = test_client.post("api/v1/explore/form_data?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) + resp = test_client.post("api/v1/explore/form_data?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key == second_key def test_post_different_key_for_different_context( - client, chart_id: int, datasource: SqlaTable + test_client, login_as_admin, chart_id: int, datasource: SqlaTable ): - login(client, "admin") payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } - resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) + resp = test_client.post("api/v1/explore/form_data?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") payload = { @@ -163,231 +164,235 @@ def test_post_different_key_for_different_context( "datasource_type": datasource.type, "form_data": json.dumps({"test": "initial value"}), } - resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) + resp = test_client.post("api/v1/explore/form_data?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key != second_key -def test_post_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_post_same_key_for_same_tab_id( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": json.dumps({"test": "initial value"}), } - resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) + resp = test_client.post("api/v1/explore/form_data?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) + resp = test_client.post("api/v1/explore/form_data?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key == second_key def test_post_different_key_for_different_tab_id( - client, chart_id: int, datasource: SqlaTable + test_client, login_as_admin, chart_id: int, datasource: SqlaTable ): - login(client, "admin") payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": json.dumps({"test": "initial value"}), } - resp = client.post("api/v1/explore/form_data?tab_id=1", json=payload) + resp = test_client.post("api/v1/explore/form_data?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.post("api/v1/explore/form_data?tab_id=2", json=payload) + resp = test_client.post("api/v1/explore/form_data?tab_id=2", json=payload) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key != second_key -def test_post_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_post_different_key_for_no_tab_id( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } - resp = client.post("api/v1/explore/form_data", json=payload) + resp = test_client.post("api/v1/explore/form_data", json=payload) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.post("api/v1/explore/form_data", json=payload) + resp = test_client.post("api/v1/explore/form_data", json=payload) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key != second_key -def test_put(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_put(test_client, login_as_admin, chart_id: int, datasource: SqlaTable): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } - resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) assert resp.status_code == 200 -def test_put_same_key_for_same_tab_id(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_put_same_key_for_same_tab_id( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } - resp = client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key == second_key def test_put_different_key_for_different_tab_id( - client, chart_id: int, datasource: SqlaTable + test_client, login_as_admin, chart_id: int, datasource: SqlaTable ): - login(client, "admin") payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } - resp = client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}?tab_id=1", json=payload) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.put(f"api/v1/explore/form_data/{KEY}?tab_id=2", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}?tab_id=2", json=payload) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key != second_key -def test_put_different_key_for_no_tab_id(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_put_different_key_for_no_tab_id( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } - resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) data = json.loads(resp.data.decode("utf-8")) first_key = data.get("key") - resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) data = json.loads(resp.data.decode("utf-8")) second_key = data.get("key") assert first_key != second_key -def test_put_bad_request(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_put_bad_request( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } - resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) assert resp.status_code == 400 -def test_put_bad_request_non_string(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_put_bad_request_non_string( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": 1234, } - resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) assert resp.status_code == 400 -def test_put_bad_request_non_json_string(client, chart_id: int, datasource: SqlaTable): - login(client, "admin") +def test_put_bad_request_non_json_string( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable +): payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": "foo", } - resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) assert resp.status_code == 400 -def test_put_access_denied(client, chart_id: int, datasource: SqlaTable): - login(client, "gamma") +def test_put_access_denied(test_client, login_as, chart_id: int, datasource: SqlaTable): + login_as("gamma") payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } - resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) assert resp.status_code == 404 -def test_put_not_owner(client, chart_id: int, datasource: SqlaTable): - login(client, "gamma") +def test_put_not_owner(test_client, login_as, chart_id: int, datasource: SqlaTable): + login_as("gamma") payload = { "datasource_id": datasource.id, "datasource_type": datasource.type, "chart_id": chart_id, "form_data": UPDATED_FORM_DATA, } - resp = client.put(f"api/v1/explore/form_data/{KEY}", json=payload) + resp = test_client.put(f"api/v1/explore/form_data/{KEY}", json=payload) assert resp.status_code == 404 -def test_get_key_not_found(client): - login(client, "admin") - resp = client.get(f"api/v1/explore/form_data/unknown-key") +def test_get_key_not_found(test_client, login_as_admin): + resp = test_client.get(f"api/v1/explore/form_data/unknown-key") assert resp.status_code == 404 -def test_get(client): - login(client, "admin") - resp = client.get(f"api/v1/explore/form_data/{KEY}") +def test_get(test_client, login_as_admin): + resp = test_client.get(f"api/v1/explore/form_data/{KEY}") assert resp.status_code == 200 data = json.loads(resp.data.decode("utf-8")) assert INITIAL_FORM_DATA == data.get("form_data") -def test_get_access_denied(client): - login(client, "gamma") - resp = client.get(f"api/v1/explore/form_data/{KEY}") +def test_get_access_denied(test_client, login_as): + login_as("gamma") + resp = test_client.get(f"api/v1/explore/form_data/{KEY}") assert resp.status_code == 404 @patch("superset.security.SupersetSecurityManager.can_access_datasource") -def test_get_dataset_access_denied(mock_can_access_datasource, client): +def test_get_dataset_access_denied( + mock_can_access_datasource, test_client, login_as_admin +): mock_can_access_datasource.side_effect = DatasetAccessDeniedError() - login(client, "admin") - resp = client.get(f"api/v1/explore/form_data/{KEY}") + resp = test_client.get(f"api/v1/explore/form_data/{KEY}") assert resp.status_code == 403 -def test_delete(client): - login(client, "admin") - resp = client.delete(f"api/v1/explore/form_data/{KEY}") +def test_delete(test_client, login_as_admin): + resp = test_client.delete(f"api/v1/explore/form_data/{KEY}") assert resp.status_code == 200 -def test_delete_access_denied(client): - login(client, "gamma") - resp = client.delete(f"api/v1/explore/form_data/{KEY}") +def test_delete_access_denied(test_client, login_as): + login_as("gamma") + resp = test_client.delete(f"api/v1/explore/form_data/{KEY}") assert resp.status_code == 404 -def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id: int): +def test_delete_not_owner( + test_client, login_as_admin, chart_id: int, datasource: SqlaTable, admin_id: int +): another_key = "another_key" another_owner = admin_id + 1 entry: TemporaryExploreState = { @@ -398,6 +403,5 @@ def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id "form_data": INITIAL_FORM_DATA, } cache_manager.explore_form_data_cache.set(another_key, entry) - login(client, "admin") - resp = client.delete(f"api/v1/explore/form_data/{another_key}") + resp = test_client.delete(f"api/v1/explore/form_data/{another_key}") assert resp.status_code == 403 diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index b5228ab301b24..a808f0111e961 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -28,8 +28,6 @@ from superset.key_value.utils import decode_permalink_id, encode_permalink_key from superset.models.slice import Slice from superset.utils.core import DatasourceType -from tests.integration_tests.base_tests import login -from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, load_world_bank_data, @@ -38,11 +36,10 @@ @pytest.fixture -def chart(load_world_bank_dashboard_with_slices) -> Slice: - with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - chart = session.query(Slice).filter_by(slice_name="World's Population").one() - return chart +def chart(app_context, load_world_bank_dashboard_with_slices) -> Slice: + session: Session = app_context.app.appbuilder.get_session + chart = session.query(Slice).filter_by(slice_name="World's Population").one() + return chart @pytest.fixture @@ -70,9 +67,10 @@ def permalink_salt() -> Iterator[str]: db.session.commit() -def test_post(client, form_data: Dict[str, Any], permalink_salt: str): - login(client, "admin") - resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data}) +def test_post( + test_client, login_as_admin, form_data: Dict[str, Any], permalink_salt: str +): + resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) assert resp.status_code == 201 data = json.loads(resp.data.decode("utf-8")) key = data["key"] @@ -83,13 +81,15 @@ def test_post(client, form_data: Dict[str, Any], permalink_salt: str): db.session.commit() -def test_post_access_denied(client, form_data): - login(client, "gamma") - resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data}) +def test_post_access_denied(test_client, login_as, form_data): + login_as("gamma") + resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) assert resp.status_code == 404 -def test_get_missing_chart(client, chart, permalink_salt: str) -> None: +def test_get_missing_chart( + test_client, login_as_admin, chart, permalink_salt: str +) -> None: from superset.key_value.models import KeyValueEntry chart_id = 1234 @@ -110,25 +110,24 @@ def test_get_missing_chart(client, chart, permalink_salt: str) -> None: db.session.add(entry) db.session.commit() key = encode_permalink_key(entry.id, permalink_salt) - login(client, "admin") - resp = client.get(f"api/v1/explore/permalink/{key}") + resp = test_client.get(f"api/v1/explore/permalink/{key}") assert resp.status_code == 404 db.session.delete(entry) db.session.commit() -def test_post_invalid_schema(client) -> None: - login(client, "admin") - resp = client.post(f"api/v1/explore/permalink", json={"abc": 123}) +def test_post_invalid_schema(test_client, login_as_admin) -> None: + resp = test_client.post(f"api/v1/explore/permalink", json={"abc": 123}) assert resp.status_code == 400 -def test_get(client, form_data: Dict[str, Any], permalink_salt: str) -> None: - login(client, "admin") - resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data}) +def test_get( + test_client, login_as_admin, form_data: Dict[str, Any], permalink_salt: str +) -> None: + resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) data = json.loads(resp.data.decode("utf-8")) key = data["key"] - resp = client.get(f"api/v1/explore/permalink/{key}") + resp = test_client.get(f"api/v1/explore/permalink/{key}") assert resp.status_code == 200 result = json.loads(resp.data.decode("utf-8")) assert result["state"]["formData"] == form_data diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py index 0434e22295267..41fcd47919f35 100644 --- a/tests/integration_tests/fixtures/birth_names_dashboard.py +++ b/tests/integration_tests/fixtures/birth_names_dashboard.py @@ -93,24 +93,16 @@ def _create_table( return table -def _cleanup(dash_id: int, slices_ids: List[int]) -> None: +def _cleanup(dash_id: int, slice_ids: List[int]) -> None: schema = get_example_default_schema() - datasource = ( - db.session.query(SqlaTable) - .filter_by(table_name="birth_names", schema=schema) - .one() - ) - columns = [column for column in datasource.columns] - metrics = [metric for metric in datasource.metrics] - - for column in columns: - db.session.delete(column) - for metric in metrics: - db.session.delete(metric) - - dash = db.session.query(Dashboard).filter_by(id=dash_id).first() - - db.session.delete(dash) - for slice_id in slices_ids: - db.session.query(Slice).filter_by(id=slice_id).delete() + for datasource in db.session.query(SqlaTable).filter_by( + table_name="birth_names", schema=schema + ): + for col in datasource.columns + datasource.metrics: + db.session.delete(col) + + for dash in db.session.query(Dashboard).filter_by(id=dash_id): + db.session.delete(dash) + for slc in db.session.query(Slice).filter(Slice.id.in_(slice_ids)): + db.session.delete(slc) db.session.commit() diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index 81acda80185ca..5bbc985a36672 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -50,33 +50,31 @@ from .base_tests import SupersetTestCase +def delete_imports(): + with app.app_context(): + # Imported data clean up + session = db.session + for slc in session.query(Slice): + if "remote_id" in slc.params_dict: + session.delete(slc) + for dash in session.query(Dashboard): + if "remote_id" in dash.params_dict: + session.delete(dash) + for table in session.query(SqlaTable): + if "remote_id" in table.params_dict: + session.delete(table) + session.commit() + + +@pytest.fixture(autouse=True, scope="module") +def clean_imports(): + yield + delete_imports() + + class TestImportExport(SupersetTestCase): """Testing export import functionality for dashboards""" - @classmethod - def delete_imports(cls): - with app.app_context(): - # Imported data clean up - session = db.session - for slc in session.query(Slice): - if "remote_id" in slc.params_dict: - session.delete(slc) - for dash in session.query(Dashboard): - if "remote_id" in dash.params_dict: - session.delete(dash) - for table in session.query(SqlaTable): - if "remote_id" in table.params_dict: - session.delete(table) - session.commit() - - @classmethod - def setUpClass(cls): - cls.delete_imports() - - @classmethod - def tearDownClass(cls): - cls.delete_imports() - def create_slice( self, name, diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 5b811cfd15081..abd5d2be8b876 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -293,7 +293,7 @@ def test_csv_response_format(self): payload = get_query_context("birth_names") payload["result_format"] = ChartDataResultFormat.CSV.value payload["queries"][0]["row_limit"] = 10 - query_context = ChartDataQueryContextSchema().load(payload) + query_context: QueryContext = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() self.assertEqual(len(responses), 1) data = responses["queries"][0]["data"] diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 7ded842ef7afd..7bfb2ffafda1f 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -65,6 +65,7 @@ class TestSqlLab(SupersetTestCase): """Testings for Sql Lab""" + @pytest.mark.usefixtures("load_birth_names_data") def run_some_queries(self): db.session.query(Query).delete() db.session.commit() diff --git a/tests/integration_tests/test_app.py b/tests/integration_tests/test_app.py index 798f3e9cda288..c64076ec360a5 100644 --- a/tests/integration_tests/test_app.py +++ b/tests/integration_tests/test_app.py @@ -14,11 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import TYPE_CHECKING -""" -Here is where we create the app which ends up being shared across all tests.integration_tests. A future -optimization will be to create a separate app instance for each test class. -""" from superset.app import create_app +if TYPE_CHECKING: + from typing import Any + + from flask.testing import FlaskClient + app = create_app() + + +def login( + client: "FlaskClient[Any]", username: str = "admin", password: str = "general" +): + resp = client.post("/login/", data=dict(username=username, password=password)) + assert "User confirmation needed" not in resp