From e4904144846afb7d8d0433fdc7ca97c49b41d0b4 Mon Sep 17 00:00:00 2001 From: Craig Rueda Date: Wed, 20 Nov 2019 07:47:06 -0800 Subject: [PATCH] Flask App factory PR #1 (#8418) * First cut at app factory * Setting things back to master * Working with new FLASK_APP * Still need to refactor Celery * CLI mostly working * Working on unit tests * Moving cli stuff around a bit * Removing get in config * Defaulting test config * Adding flask-testing * flask-testing casing * resultsbackend property bug * Fixing up cli * Quick fix for KV api * Working on save slice * Fixed core_tests * Fixed utils_tests * Most tests working - still need to dig into remaining app_context issue in tests * All tests passing locally - need to update code comments * Fixing dashboard tests again * Blacking * Sorting imports * linting * removing envvar mangling * blacking * Fixing unit tests * isorting * licensing * fixing mysql tests * fixing cypress? * fixing .flaskenv * fixing test app_ctx * fixing cypress * moving manifest processor around * moving results backend manager around * Cleaning up __init__ a bit more * Addressing PR comments * Addressing PR comments * Blacking * Fixes for running celery worker * Tuning isort * Blacking --- .flaskenv | 4 +- requirements-dev.txt | 1 + setup.py | 1 + superset/__init__.py | 253 +++--------------------- superset/app.py | 260 +++++++++++++++++++++++++ superset/bin/superset | 14 +- superset/cli.py | 94 ++++++--- superset/extensions.py | 113 +++++++++++ superset/forms.py | 4 - superset/sql_lab.py | 2 +- superset/tasks/__init__.py | 1 - superset/tasks/cache.py | 2 +- superset/tasks/celery_app.py | 20 +- superset/tasks/schedules.py | 2 +- superset/utils/cache.py | 15 +- superset/utils/cache_manager.py | 56 ++++++ superset/utils/core.py | 27 --- superset/utils/feature_flag_manager.py | 39 ++++ superset/views/core.py | 7 +- tests/access_tests.py | 32 +-- tests/base_tests.py | 74 +++---- tests/celery_tests.py | 40 ++-- tests/core_tests.py | 75 +++---- tests/dashboard_tests.py | 38 ++-- tests/db_engine_specs/presto_tests.py | 36 +++- tests/dict_import_export_tests.py | 21 +- tests/druid_func_tests.py | 2 +- tests/email_tests.py | 3 +- tests/feature_flag_tests.py | 40 ++++ tests/import_export_tests.py | 151 +++++++------- tests/load_examples_test.py | 25 ++- tests/schedules_test.py | 66 ++++--- tests/security_tests.py | 1 + tests/sql_validator_tests.py | 12 +- tests/sqllab_tests.py | 8 +- tests/test_app.py | 24 +++ tests/utils_tests.py | 17 +- tox.ini | 2 +- 38 files changed, 1002 insertions(+), 580 deletions(-) create mode 100644 superset/app.py create mode 100644 superset/extensions.py create mode 100644 superset/utils/cache_manager.py create mode 100644 superset/utils/feature_flag_manager.py create mode 100644 tests/feature_flag_tests.py create mode 100644 tests/test_app.py diff --git a/.flaskenv b/.flaskenv index 33adde0c5eb85..49bf390fb3db5 100644 --- a/.flaskenv +++ b/.flaskenv @@ -14,5 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # -FLASK_APP=superset:app -FLASK_ENV=development +FLASK_APP="superset.app:create_app()" +FLASK_ENV="development" diff --git a/requirements-dev.txt b/requirements-dev.txt index 63116507b5533..9a53de611119e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,6 +17,7 @@ black==19.3b0 coverage==4.5.3 flask-cors==3.0.7 +flask-testing==0.7.1 ipdb==0.12 isort==4.3.21 mypy==0.670 diff --git a/setup.py b/setup.py index f7ac4716f47ad..0a3942c6bfed6 100644 --- a/setup.py +++ b/setup.py @@ -128,4 +128,5 @@ def get_git_sha(): "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", ], + tests_require=["flask-testing==0.7.1"], ) diff --git a/superset/__init__.py b/superset/__init__.py index 01d98f55fd621..1de8778ff94b1 100644 --- a/superset/__init__.py +++ b/superset/__init__.py @@ -14,229 +14,38 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=C,R,W """Package's main module!""" -import json -import logging -import os -from copy import deepcopy -from typing import Any, Dict +from flask import current_app, Flask +from werkzeug.local import LocalProxy -import wtforms_json -from flask import Flask, redirect -from flask_appbuilder import AppBuilder, IndexView, SQLA -from flask_appbuilder.baseviews import expose -from flask_compress import Compress -from flask_migrate import Migrate -from flask_talisman import Talisman -from flask_wtf.csrf import CSRFProtect - -from superset import config +from superset.app import create_app from superset.connectors.connector_registry import ConnectorRegistry +from superset.extensions import ( + appbuilder, + cache_manager, + db, + event_logger, + feature_flag_manager, + manifest_processor, + results_backend_manager, + security_manager, + talisman, +) from superset.security import SupersetSecurityManager -from superset.utils.core import pessimistic_connection_handling, setup_cache -from superset.utils.log import get_event_logger_from_cfg_value - -wtforms_json.init() - -APP_DIR = os.path.dirname(__file__) -CONFIG_MODULE = os.environ.get("SUPERSET_CONFIG", "superset.config") - -if not os.path.exists(config.DATA_DIR): - os.makedirs(config.DATA_DIR) - -app = Flask(__name__) -app.config.from_object(CONFIG_MODULE) # type: ignore -conf = app.config - -################################################################# -# Handling manifest file logic at app start -################################################################# -MANIFEST_FILE = APP_DIR + "/static/assets/dist/manifest.json" -manifest: Dict[Any, Any] = {} - - -def parse_manifest_json(): - global manifest - try: - with open(MANIFEST_FILE, "r") as f: - # the manifest inclues non-entry files - # we only need entries in templates - full_manifest = json.load(f) - manifest = full_manifest.get("entrypoints", {}) - except Exception: - pass - - -def get_js_manifest_files(filename): - if app.debug: - parse_manifest_json() - entry_files = manifest.get(filename, {}) - return entry_files.get("js", []) - - -def get_css_manifest_files(filename): - if app.debug: - parse_manifest_json() - entry_files = manifest.get(filename, {}) - return entry_files.get("css", []) - - -def get_unloaded_chunks(files, loaded_chunks): - filtered_files = [f for f in files if f not in loaded_chunks] - for f in filtered_files: - loaded_chunks.add(f) - return filtered_files - - -parse_manifest_json() - - -@app.context_processor -def get_manifest(): - return dict( - loaded_chunks=set(), - get_unloaded_chunks=get_unloaded_chunks, - js_manifest=get_js_manifest_files, - css_manifest=get_css_manifest_files, - ) - - -################################################################# - -for bp in conf["BLUEPRINTS"]: - try: - print("Registering blueprint: '{}'".format(bp.name)) - app.register_blueprint(bp) - except Exception as e: - print("blueprint registration failed") - logging.exception(e) - -if conf.get("SILENCE_FAB"): - logging.getLogger("flask_appbuilder").setLevel(logging.ERROR) - -db = SQLA(app) - -if conf.get("WTF_CSRF_ENABLED"): - csrf = CSRFProtect(app) - csrf_exempt_list = conf.get("WTF_CSRF_EXEMPT_LIST", []) - for ex in csrf_exempt_list: - csrf.exempt(ex) - -pessimistic_connection_handling(db.engine) - -cache = setup_cache(app, conf.get("CACHE_CONFIG")) -tables_cache = setup_cache(app, conf.get("TABLE_NAMES_CACHE_CONFIG")) - -migrate = Migrate(app, db, directory=APP_DIR + "/migrations") - -app.config["LOGGING_CONFIGURATOR"].configure_logging(app.config, app.debug) - -if app.config["ENABLE_CORS"]: - from flask_cors import CORS - - CORS(app, **app.config["CORS_OPTIONS"]) - -if app.config["ENABLE_PROXY_FIX"]: - from werkzeug.middleware.proxy_fix import ProxyFix - - app.wsgi_app = ProxyFix( # type: ignore - app.wsgi_app, **app.config["PROXY_FIX_CONFIG"] - ) - -if app.config["ENABLE_CHUNK_ENCODING"]: - - class ChunkedEncodingFix(object): - def __init__(self, app): - self.app = app - - def __call__(self, environ, start_response): - # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore - # content-length and read the stream till the end. - if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == u"chunked": - environ["wsgi.input_terminated"] = True - return self.app(environ, start_response) - - app.wsgi_app = ChunkedEncodingFix(app.wsgi_app) # type: ignore - -if app.config["UPLOAD_FOLDER"]: - try: - os.makedirs(app.config["UPLOAD_FOLDER"]) - except OSError: - pass - -for middleware in app.config["ADDITIONAL_MIDDLEWARE"]: - app.wsgi_app = middleware(app.wsgi_app) # type: ignore - - -class MyIndexView(IndexView): - @expose("/") - def index(self): - return redirect("/superset/welcome") - - -custom_sm = app.config["CUSTOM_SECURITY_MANAGER"] or SupersetSecurityManager -if not issubclass(custom_sm, SupersetSecurityManager): - raise Exception( - """Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager, - not FAB's security manager. - See [4565] in UPDATING.md""" - ) - -with app.app_context(): - appbuilder = AppBuilder( - app, - db.session, - base_template="superset/base.html", - indexview=MyIndexView, - security_manager_class=custom_sm, - update_perms=False, # Run `superset init` to update FAB's perms - ) - -security_manager = appbuilder.sm - -results_backend = app.config["RESULTS_BACKEND"] -results_backend_use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"] - -# Merge user defined feature flags with default feature flags -_feature_flags = app.config["DEFAULT_FEATURE_FLAGS"] -_feature_flags.update(app.config["FEATURE_FLAGS"]) - -# Event Logger -event_logger = get_event_logger_from_cfg_value(app.config["EVENT_LOGGER"]) - - -def get_feature_flags(): - GET_FEATURE_FLAGS_FUNC = app.config["GET_FEATURE_FLAGS_FUNC"] - if GET_FEATURE_FLAGS_FUNC: - return GET_FEATURE_FLAGS_FUNC(deepcopy(_feature_flags)) - return _feature_flags - - -def is_feature_enabled(feature): - """Utility function for checking whether a feature is turned on""" - return get_feature_flags().get(feature) - - -# Flask-Compress -if conf.get("ENABLE_FLASK_COMPRESS"): - Compress(app) - - -talisman = Talisman() - -if app.config["TALISMAN_ENABLED"]: - talisman.init_app(app, **app.config["TALISMAN_CONFIG"]) - -# Hook that provides administrators a handle on the Flask APP -# after initialization -flask_app_mutator = app.config["FLASK_APP_MUTATOR"] -if flask_app_mutator: - flask_app_mutator(app) - -from superset import views # noqa isort:skip - -# Registering sources -module_datasource_map = app.config["DEFAULT_MODULE_DS_MAP"] -module_datasource_map.update(app.config["ADDITIONAL_MODULE_DS_MAP"]) -ConnectorRegistry.register_sources(module_datasource_map) +from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value + +# All of the fields located here should be considered legacy. The correct way +# to declare "global" dependencies is to define it in extensions.py, +# then initialize it in app.create_app(). These fields will be removed +# in subsequent PRs as things are migrated towards the factory pattern +app: Flask = current_app +cache = LocalProxy(lambda: cache_manager.cache) +conf = LocalProxy(lambda: current_app.config) +get_feature_flags = feature_flag_manager.get_feature_flags +get_css_manifest_files = manifest_processor.get_css_manifest_files +is_feature_enabled = feature_flag_manager.is_feature_enabled +results_backend = LocalProxy(lambda: results_backend_manager.results_backend) +results_backend_use_msgpack = LocalProxy( + lambda: results_backend_manager.should_use_msgpack +) +tables_cache = LocalProxy(lambda: cache_manager.tables_cache) diff --git a/superset/app.py b/superset/app.py new file mode 100644 index 0000000000000..efa0622808a88 --- /dev/null +++ b/superset/app.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +import os + +import wtforms_json +from flask import Flask, redirect +from flask_appbuilder import expose, IndexView +from flask_compress import Compress +from flask_wtf import CSRFProtect + +from superset.connectors.connector_registry import ConnectorRegistry +from superset.extensions import ( + _event_logger, + APP_DIR, + appbuilder, + cache_manager, + celery_app, + db, + feature_flag_manager, + manifest_processor, + migrate, + results_backend_manager, + talisman, +) +from superset.security import SupersetSecurityManager +from superset.utils.core import pessimistic_connection_handling +from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value + +logger = logging.getLogger(__name__) + + +def create_app(): + app = Flask(__name__) + + try: + # Allow user to override our config completely + config_module = os.environ.get("SUPERSET_CONFIG", "superset.config") + app.config.from_object(config_module) + + app_initializer = app.config.get("APP_INITIALIZER", SupersetAppInitializer)(app) + app_initializer.init_app() + + return app + + # Make sure that bootstrap errors ALWAYS get logged + except Exception as ex: + logger.exception("Failed to create app") + raise ex + + +class SupersetIndexView(IndexView): + @expose("/") + def index(self): + return redirect("/superset/welcome") + + +class SupersetAppInitializer: + def __init__(self, app: Flask) -> None: + super().__init__() + + self.flask_app = app + self.config = app.config + self.manifest: dict = {} + + def pre_init(self) -> None: + """ + Called after all other init tasks are complete + """ + wtforms_json.init() + + if not os.path.exists(self.config["DATA_DIR"]): + os.makedirs(self.config["DATA_DIR"]) + + def post_init(self) -> None: + """ + Called before any other init tasks + """ + pass + + def configure_celery(self) -> None: + celery_app.config_from_object(self.config["CELERY_CONFIG"]) + celery_app.set_default() + + @staticmethod + def init_views() -> None: + # TODO - This should iterate over all views and register them with FAB... + from superset import views # noqa pylint: disable=unused-variable + + def init_app_in_ctx(self) -> None: + """ + Runs init logic in the context of the app + """ + self.configure_feature_flags() + self.configure_fab() + self.configure_data_sources() + + # Hook that provides administrators a handle on the Flask APP + # after initialization + flask_app_mutator = self.config["FLASK_APP_MUTATOR"] + if flask_app_mutator: + flask_app_mutator(self.flask_app) + + self.init_views() + + def init_app(self) -> None: + """ + Main entry point which will delegate to other methods in + order to fully init the app + """ + self.pre_init() + + self.setup_db() + + self.configure_celery() + + self.setup_event_logger() + + self.setup_bundle_manifest() + + self.register_blueprints() + + self.configure_wtf() + + self.configure_logging() + + self.configure_middlewares() + + self.configure_cache() + + with self.flask_app.app_context(): + self.init_app_in_ctx() + + self.post_init() + + def setup_event_logger(self): + _event_logger["event_logger"] = get_event_logger_from_cfg_value( + self.flask_app.config.get("EVENT_LOGGER", DBEventLogger()) + ) + + def configure_data_sources(self): + # Registering sources + module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"] + module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"]) + ConnectorRegistry.register_sources(module_datasource_map) + + def configure_cache(self): + cache_manager.init_app(self.flask_app) + results_backend_manager.init_app(self.flask_app) + + def configure_feature_flags(self): + feature_flag_manager.init_app(self.flask_app) + + def configure_fab(self): + if self.config["SILENCE_FAB"]: + logging.getLogger("flask_appbuilder").setLevel(logging.ERROR) + + custom_sm = self.config["CUSTOM_SECURITY_MANAGER"] or SupersetSecurityManager + if not issubclass(custom_sm, SupersetSecurityManager): + raise Exception( + """Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager, + not FAB's security manager. + See [4565] in UPDATING.md""" + ) + + appbuilder.indexview = SupersetIndexView + appbuilder.base_template = "superset/base.html" + appbuilder.security_manager_class = custom_sm + appbuilder.update_perms = False + appbuilder.init_app(self.flask_app, db.session) + + def configure_middlewares(self): + if self.config["ENABLE_CORS"]: + from flask_cors import CORS + + CORS(self.flask_app, **self.config["CORS_OPTIONS"]) + + if self.config["ENABLE_PROXY_FIX"]: + from werkzeug.middleware.proxy_fix import ProxyFix + + self.flask_app.wsgi_app = ProxyFix( + self.flask_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"] + ) + + if self.config["ENABLE_CHUNK_ENCODING"]: + + class ChunkedEncodingFix(object): # pylint: disable=too-few-public-methods + def __init__(self, app): + self.app = app + + def __call__(self, environ, start_response): + # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore + # content-length and read the stream till the end. + if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == "chunked": + environ["wsgi.input_terminated"] = True + return self.app(environ, start_response) + + self.flask_app.wsgi_app = ChunkedEncodingFix(self.flask_app.wsgi_app) + + if self.config["UPLOAD_FOLDER"]: + try: + os.makedirs(self.config["UPLOAD_FOLDER"]) + except OSError: + pass + + for middleware in self.config["ADDITIONAL_MIDDLEWARE"]: + self.flask_app.wsgi_app = middleware(self.flask_app.wsgi_app) + + # Flask-Compress + if self.config["ENABLE_FLASK_COMPRESS"]: + Compress(self.flask_app) + + if self.config["TALISMAN_ENABLED"]: + talisman.init_app(self.flask_app, **self.config["TALISMAN_CONFIG"]) + + def configure_logging(self): + self.config["LOGGING_CONFIGURATOR"].configure_logging( + self.config, self.flask_app.debug + ) + + def setup_db(self): + db.init_app(self.flask_app) + + with self.flask_app.app_context(): + pessimistic_connection_handling(db.engine) + + migrate.init_app(self.flask_app, db=db, directory=APP_DIR + "/migrations") + + def configure_wtf(self): + if self.config["WTF_CSRF_ENABLED"]: + csrf = CSRFProtect(self.flask_app) + csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"] + for ex in csrf_exempt_list: + csrf.exempt(ex) + + def register_blueprints(self): + for bp in self.config["BLUEPRINTS"]: + try: + logger.info(f"Registering blueprint: '{bp.name}'") + self.flask_app.register_blueprint(bp) + except Exception: # pylint: disable=broad-except + logger.exception("blueprint registration failed") + + def setup_bundle_manifest(self): + manifest_processor.init_app(self.flask_app) diff --git a/superset/bin/superset b/superset/bin/superset index 0617335e99f29..4b37e8339b2c7 100755 --- a/superset/bin/superset +++ b/superset/bin/superset @@ -15,17 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import click -from flask.cli import FlaskGroup - -from superset.cli import create_app - - -@click.group(cls=FlaskGroup, create_app=create_app) -def cli(): - """This is a management script for the Superset application.""" - pass - +from superset.cli import superset if __name__ == '__main__': - cli() + superset() diff --git a/superset/cli.py b/superset/cli.py index 8e695bbf292c6..905f36a3a9f00 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -25,27 +25,28 @@ import yaml from colorama import Fore, Style from flask import g +from flask.cli import FlaskGroup, with_appcontext from flask_appbuilder import Model from pathlib2 import Path -from superset import app, appbuilder, db, examples, security_manager -from superset.common.tags import add_favorites, add_owners, add_types -from superset.utils import core as utils, dashboard_import_export, dict_import_export +from superset import app, appbuilder, security_manager +from superset.app import create_app +from superset.extensions import celery_app, db +from superset.utils import core as utils -config = app.config -celery_app = utils.get_celery_app(config) +@click.group(cls=FlaskGroup, create_app=create_app) +@with_appcontext +def superset(): + """This is a management script for the Superset application.""" -def create_app(script_info=None): - return app + @app.shell_context_processor + def make_shell_context(): + return dict(app=app, db=db) -@app.shell_context_processor -def make_shell_context(): - return dict(app=app, db=db) - - -@app.cli.command() +@superset.command() +@with_appcontext def init(): """Inits the Superset application""" utils.get_example_database() @@ -53,7 +54,8 @@ def init(): security_manager.sync_role_definitions() -@app.cli.command() +@superset.command() +@with_appcontext @click.option("--verbose", "-v", is_flag=True, help="Show extra information") def version(verbose): """Prints the current version number""" @@ -62,7 +64,7 @@ def version(verbose): Fore.YELLOW + "Superset " + Fore.CYAN - + "{version}".format(version=config["VERSION_STRING"]) + + "{version}".format(version=app.config["VERSION_STRING"]) ) print(Fore.BLUE + "-=" * 15) if verbose: @@ -77,6 +79,8 @@ def load_examples_run(load_test_data, only_metadata=False, force=False): examples_db = utils.get_example_database() print(f"Loading examples metadata and related data into {examples_db}") + from superset import examples + examples.load_css_templates() print("Loading energy related dataset") @@ -129,7 +133,8 @@ def load_examples_run(load_test_data, only_metadata=False, force=False): examples.load_tabbed_dashboard(only_metadata) -@app.cli.command() +@with_appcontext +@superset.command() @click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data") @click.option( "--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data" @@ -142,7 +147,8 @@ def load_examples(load_test_data, only_metadata=False, force=False): load_examples_run(load_test_data, only_metadata, force) -@app.cli.command() +@with_appcontext +@superset.command() @click.option("--database_name", "-d", help="Database name to change") @click.option("--uri", "-u", help="Database URI to change") def set_database_uri(database_name, uri): @@ -150,7 +156,8 @@ def set_database_uri(database_name, uri): utils.get_or_create_db(database_name, uri) -@app.cli.command() +@superset.command() +@with_appcontext @click.option( "--datasource", "-d", @@ -180,7 +187,8 @@ def refresh_druid(datasource, merge): session.commit() -@app.cli.command() +@superset.command() +@with_appcontext @click.option( "--path", "-p", @@ -202,6 +210,8 @@ def refresh_druid(datasource, merge): ) def import_dashboards(path, recursive, username): """Import dashboards from JSON""" + from superset.utils import dashboard_import_export + p = Path(path) files = [] if p.is_file(): @@ -222,7 +232,8 @@ def import_dashboards(path, recursive, username): logging.error(e) -@app.cli.command() +@superset.command() +@with_appcontext @click.option( "--dashboard-file", "-f", default=None, help="Specify the the file to export to" ) @@ -231,6 +242,8 @@ def import_dashboards(path, recursive, username): ) def export_dashboards(print_stdout, dashboard_file): """Export dashboards to JSON""" + from superset.utils import dashboard_import_export + data = dashboard_import_export.export_dashboards(db.session) if print_stdout or not dashboard_file: print(data) @@ -240,7 +253,8 @@ def export_dashboards(print_stdout, dashboard_file): data_stream.write(data) -@app.cli.command() +@superset.command() +@with_appcontext @click.option( "--path", "-p", @@ -265,6 +279,8 @@ def export_dashboards(print_stdout, dashboard_file): ) def import_datasources(path, sync, recursive): """Import datasources from YAML""" + from superset.utils import dict_import_export + sync_array = sync.split(",") p = Path(path) files = [] @@ -288,7 +304,8 @@ def import_datasources(path, sync, recursive): logging.error(e) -@app.cli.command() +@superset.command() +@with_appcontext @click.option( "--datasource-file", "-f", default=None, help="Specify the the file to export to" ) @@ -313,6 +330,8 @@ def export_datasources( print_stdout, datasource_file, back_references, include_defaults ): """Export datasources to YAML""" + from superset.utils import dict_import_export + data = dict_import_export.export_to_dict( session=db.session, recursive=True, @@ -327,7 +346,8 @@ def export_datasources( yaml.safe_dump(data, data_stream, default_flow_style=False) -@app.cli.command() +@superset.command() +@with_appcontext @click.option( "--back-references", "-b", @@ -337,11 +357,14 @@ def export_datasources( ) def export_datasource_schema(back_references): """Export datasource YAML schema to stdout""" + from superset.utils import dict_import_export + data = dict_import_export.export_schema_to_dict(back_references=back_references) yaml.safe_dump(data, stdout, default_flow_style=False) -@app.cli.command() +@superset.command() +@with_appcontext def update_datasources_cache(): """Refresh sqllab datasources cache""" from superset.models.core import Database @@ -360,7 +383,8 @@ def update_datasources_cache(): print("{}".format(str(e))) -@app.cli.command() +@superset.command() +@with_appcontext @click.option( "--workers", "-w", type=int, help="Number of celery server workers to fire up" ) @@ -372,14 +396,17 @@ def worker(workers): ) if workers: celery_app.conf.update(CELERYD_CONCURRENCY=workers) - elif config["SUPERSET_CELERY_WORKERS"]: - celery_app.conf.update(CELERYD_CONCURRENCY=config["SUPERSET_CELERY_WORKERS"]) + elif app.config["SUPERSET_CELERY_WORKERS"]: + celery_app.conf.update( + CELERYD_CONCURRENCY=app.config["SUPERSET_CELERY_WORKERS"] + ) worker = celery_app.Worker(optimization="fair") worker.start() -@app.cli.command() +@superset.command() +@with_appcontext @click.option( "-p", "--port", default="5555", help="Port on which to start the Flower process" ) @@ -409,7 +436,8 @@ def flower(port, address): Popen(cmd, shell=True).wait() -@app.cli.command() +@superset.command() +@with_appcontext def load_test_users(): """ Loads admin, alpha, and gamma user for testing purposes @@ -426,7 +454,7 @@ def load_test_users_run(): Syncs permissions for those users/roles """ - if config["TESTING"]: + if app.config["TESTING"]: sm = security_manager @@ -463,11 +491,15 @@ def load_test_users_run(): sm.get_session.commit() -@app.cli.command() +@superset.command() +@with_appcontext def sync_tags(): """Rebuilds special tags (owner, type, favorited by).""" # pylint: disable=no-member metadata = Model.metadata + + from superset.common.tags import add_favorites, add_owners, add_types + add_types(db.engine, metadata) add_owners(db.engine, metadata) add_favorites(db.engine, metadata) diff --git a/superset/extensions.py b/superset/extensions.py new file mode 100644 index 0000000000000..2974739b71f88 --- /dev/null +++ b/superset/extensions.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import os + +import celery +from flask_appbuilder import AppBuilder, SQLA +from flask_migrate import Migrate +from flask_talisman import Talisman +from werkzeug.local import LocalProxy + +from superset.utils.cache_manager import CacheManager +from superset.utils.feature_flag_manager import FeatureFlagManager + + +class ResultsBackendManager: + def __init__(self) -> None: + super().__init__() + self._results_backend = None + self._use_msgpack = False + + def init_app(self, app): + self._results_backend = app.config.get("RESULTS_BACKEND") + self._use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK") + + @property + def results_backend(self): + return self._results_backend + + @property + def should_use_msgpack(self): + return self._use_msgpack + + +class UIManifestProcessor: + def __init__(self, app_dir: str) -> None: + super().__init__() + self.app = None + self.manifest: dict = {} + self.manifest_file = f"{app_dir}/static/assets/dist/manifest.json" + + def init_app(self, app): + self.app = app + # Preload the cache + self.parse_manifest_json() + + @app.context_processor + def get_manifest(): # pylint: disable=unused-variable + return dict( + loaded_chunks=set(), + get_unloaded_chunks=self.get_unloaded_chunks, + js_manifest=self.get_js_manifest_files, + css_manifest=self.get_css_manifest_files, + ) + + def parse_manifest_json(self): + try: + with open(self.manifest_file, "r") as f: + # the manifest includes non-entry files + # we only need entries in templates + full_manifest = json.load(f) + self.manifest = full_manifest.get("entrypoints", {}) + except Exception: # pylint: disable=broad-except + pass + + def get_js_manifest_files(self, filename): + if self.app.debug: + self.parse_manifest_json() + entry_files = self.manifest.get(filename, {}) + return entry_files.get("js", []) + + def get_css_manifest_files(self, filename): + if self.app.debug: + self.parse_manifest_json() + entry_files = self.manifest.get(filename, {}) + return entry_files.get("css", []) + + @staticmethod + def get_unloaded_chunks(files, loaded_chunks): + filtered_files = [f for f in files if f not in loaded_chunks] + for f in filtered_files: + loaded_chunks.add(f) + return filtered_files + + +APP_DIR = os.path.dirname(__file__) + +appbuilder = AppBuilder(update_perms=False) +cache_manager = CacheManager() +celery_app = celery.Celery() +db = SQLA() +_event_logger: dict = {} +event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) +feature_flag_manager = FeatureFlagManager() +manifest_processor = UIManifestProcessor(APP_DIR) +migrate = Migrate() +results_backend_manager = ResultsBackendManager() +security_manager = LocalProxy(lambda: appbuilder.sm) +talisman = Talisman() diff --git a/superset/forms.py b/superset/forms.py index c11bf7b258571..9bf781de8becb 100644 --- a/superset/forms.py +++ b/superset/forms.py @@ -19,10 +19,6 @@ from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from wtforms import Field -from superset import app - -config = app.config - class CommaSeparatedListField(Field): widget = BS3TextFieldWidget() diff --git a/superset/sql_lab.py b/superset/sql_lab.py index df4ecc53fe2ca..842f292e27874 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -42,9 +42,9 @@ ) from superset.dataframe import SupersetDataFrame from superset.db_engine_specs import BaseEngineSpec +from superset.extensions import celery_app from superset.models.sql_lab import Query from superset.sql_parse import ParsedQuery -from superset.tasks.celery_app import app as celery_app from superset.utils.core import json_iso_dttm_ser, QueryStatus, sources, zlib_compress from superset.utils.dates import now_as_float from superset.utils.decorators import stats_timing diff --git a/superset/tasks/__init__.py b/superset/tasks/__init__.py index f6fd1b2b21f9e..fd9417fe5c1e9 100644 --- a/superset/tasks/__init__.py +++ b/superset/tasks/__init__.py @@ -15,4 +15,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from . import cache, schedules diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 22da4f9a48944..09cd1f344427c 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -25,9 +25,9 @@ from sqlalchemy import and_, func from superset import app, db +from superset.extensions import celery_app from superset.models.core import Dashboard, Log, Slice from superset.models.tags import Tag, TaggedObject -from superset.tasks.celery_app import app as celery_app from superset.utils.core import parse_human_datetime logger = get_task_logger(__name__) diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py index 1c0305a5b0d6c..3724ec72da9e2 100644 --- a/superset/tasks/celery_app.py +++ b/superset/tasks/celery_app.py @@ -16,12 +16,20 @@ # under the License. # pylint: disable=C,R,W -"""Utility functions used across Superset""" +""" +This is the main entrypoint used by Celery workers. As such, +it needs to call create_app() in order to initialize things properly +""" # Superset framework imports -from superset import app -from superset.utils.core import get_celery_app +from superset import create_app +from superset.extensions import celery_app -# Globals -config = app.config -app = get_celery_app(config) +# Init the Flask app / configure everything +create_app() + +# Need to import late, as the celery_app will have been setup by "create_app()" +from . import cache, schedules # isort:skip + +# Export the celery app globally for Celery (as run on the cmd line) to find +app = celery_app diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 64d9b4dc57aae..f4b7911d8924b 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -39,13 +39,13 @@ # Superset framework imports from superset import app, db, security_manager +from superset.extensions import celery_app from superset.models.schedules import ( EmailDeliveryType, get_scheduler_model, ScheduleType, SliceEmailReportFormat, ) -from superset.tasks.celery_app import app as celery_app from superset.utils.core import get_email_address_list, send_email_smtp # Globals diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 8fba7f888f1f2..88bc8703e887c 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -15,9 +15,12 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=C,R,W -from flask import request +from typing import Optional -from superset import tables_cache +from flask import Flask, request +from flask_caching import Cache + +from superset.extensions import cache_manager def view_cache_key(*unused_args, **unused_kwargs) -> str: @@ -43,7 +46,7 @@ def memoized_func(key=view_cache_key, attribute_in_key=None): """ def wrap(f): - if tables_cache: + if cache_manager.tables_cache: def wrapped_f(self, *args, **kwargs): if not kwargs.get("cache", True): @@ -55,11 +58,13 @@ def wrapped_f(self, *args, **kwargs): ) else: cache_key = key(*args, **kwargs) - o = tables_cache.get(cache_key) + o = cache_manager.tables_cache.get(cache_key) if not kwargs.get("force") and o is not None: return o o = f(self, *args, **kwargs) - tables_cache.set(cache_key, o, timeout=kwargs.get("cache_timeout")) + cache_manager.tables_cache.set( + cache_key, o, timeout=kwargs.get("cache_timeout") + ) return o else: diff --git a/superset/utils/cache_manager.py b/superset/utils/cache_manager.py new file mode 100644 index 0000000000000..cfbeb349978f5 --- /dev/null +++ b/superset/utils/cache_manager.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from flask import Flask +from flask_caching import Cache + + +class CacheManager: + def __init__(self) -> None: + super().__init__() + + self._tables_cache = None + self._cache = None + + def init_app(self, app): + self._cache = self._setup_cache(app, app.config.get("CACHE_CONFIG")) + self._tables_cache = self._setup_cache( + app, app.config.get("TABLE_NAMES_CACHE_CONFIG") + ) + + @staticmethod + def _setup_cache(app: Flask, cache_config) -> Optional[Cache]: + """Setup the flask-cache on a flask app""" + if cache_config: + if isinstance(cache_config, dict): + if cache_config.get("CACHE_TYPE") != "null": + return Cache(app, config=cache_config) + else: + # Accepts a custom cache initialization function, + # returning an object compatible with Flask-Caching API + return cache_config(app) + + return None + + @property + def tables_cache(self): + return self._tables_cache + + @property + def cache(self): + return self._cache diff --git a/superset/utils/core.py b/superset/utils/core.py index bb9437fd6d664..8c9cba2886bd6 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -791,20 +791,6 @@ def choicify(values): return [(v, v) for v in values] -def setup_cache(app: Flask, cache_config) -> Optional[Cache]: - """Setup the flask-cache on a flask app""" - if cache_config: - if isinstance(cache_config, dict): - if cache_config["CACHE_TYPE"] != "null": - return Cache(app, config=cache_config) - else: - # Accepts a custom cache initialization function, - # returning an object compatible with Flask-Caching API - return cache_config(app) - - return None - - def zlib_compress(data): """ Compress things in a py2/3 safe fashion @@ -832,19 +818,6 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes, return decompressed.decode("utf-8") if decode else decompressed -_celery_app = None - - -def get_celery_app(config): - global _celery_app - if _celery_app: - return _celery_app - _celery_app = celery.Celery() - _celery_app.config_from_object(config["CELERY_CONFIG"]) - _celery_app.set_default() - return _celery_app - - def to_adhoc(filt, expressionType="SIMPLE", clause="where"): result = { "clause": clause.upper(), diff --git a/superset/utils/feature_flag_manager.py b/superset/utils/feature_flag_manager.py new file mode 100644 index 0000000000000..7802f65c3f6cc --- /dev/null +++ b/superset/utils/feature_flag_manager.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from copy import deepcopy + + +class FeatureFlagManager: + def __init__(self) -> None: + super().__init__() + self._get_feature_flags_func = None + self._feature_flags = None + + def init_app(self, app): + self._get_feature_flags_func = app.config.get("GET_FEATURE_FLAGS_FUNC") + self._feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {} + self._feature_flags.update(app.config.get("FEATURE_FLAGS") or {}) + + def get_feature_flags(self): + if self._get_feature_flags_func: + return self._get_feature_flags_func(deepcopy(self._feature_flags)) + + return self._feature_flags + + def is_feature_enabled(self, feature): + """Utility function for checking whether a feature is turned on""" + return self.get_feature_flags().get(feature) diff --git a/superset/views/core.py b/superset/views/core.py index 7b24afcf010ad..e98e32199697c 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -686,7 +686,9 @@ def store(self): def get_value(self, key_id): kv = None try: - kv = db.session.query(models.KeyValue).filter_by(id=key_id).one() + kv = db.session.query(models.KeyValue).filter_by(id=key_id).scalar() + if not kv: + return Response(status=404, content_type="text/plain") except Exception as e: return json_error_response(e) return Response(kv.value, status=200, content_type="text/plain") @@ -736,6 +738,8 @@ def shortner(self): class Superset(BaseSupersetView): """The base views for Superset!""" + logger = logging.getLogger(__name__) + @has_access_api @expose("/datasources/") def datasources(self): @@ -2059,6 +2063,7 @@ def warm_up_cache(self): ) obj.get_json() except Exception as e: + self.logger.exception("Failed to warm up cache") return json_error_response(utils.error_msg_from_exception(e)) return json_success( json.dumps( diff --git a/tests/access_tests.py b/tests/access_tests.py index a27000ac29dd9..7b0be43c042f4 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -14,12 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# isort:skip_file """Unit tests for Superset""" import json import unittest from unittest import mock -from superset import app, db, security_manager +from tests.test_app import app # isort:skip +from superset import db, security_manager from superset.connectors.connector_registry import ConnectorRegistry from superset.connectors.druid.models import DruidDatasource from superset.connectors.sqla.models import SqlaTable @@ -94,22 +96,24 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name): class RequestAccessTests(SupersetTestCase): @classmethod def setUpClass(cls): - security_manager.add_role("override_me") - security_manager.add_role(TEST_ROLE_1) - security_manager.add_role(TEST_ROLE_2) - security_manager.add_role(DB_ACCESS_ROLE) - security_manager.add_role(SCHEMA_ACCESS_ROLE) - db.session.commit() + with app.app_context(): + security_manager.add_role("override_me") + security_manager.add_role(TEST_ROLE_1) + security_manager.add_role(TEST_ROLE_2) + security_manager.add_role(DB_ACCESS_ROLE) + security_manager.add_role(SCHEMA_ACCESS_ROLE) + db.session.commit() @classmethod def tearDownClass(cls): - override_me = security_manager.find_role("override_me") - db.session.delete(override_me) - db.session.delete(security_manager.find_role(TEST_ROLE_1)) - db.session.delete(security_manager.find_role(TEST_ROLE_2)) - db.session.delete(security_manager.find_role(DB_ACCESS_ROLE)) - db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) - db.session.commit() + with app.app_context(): + override_me = security_manager.find_role("override_me") + db.session.delete(override_me) + db.session.delete(security_manager.find_role(TEST_ROLE_1)) + db.session.delete(security_manager.find_role(TEST_ROLE_2)) + db.session.delete(security_manager.find_role(DB_ACCESS_ROLE)) + db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) + db.session.commit() def setUp(self): self.login("admin") diff --git a/tests/base_tests.py b/tests/base_tests.py index 9e342d818395b..7399774cc8ef8 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -14,52 +14,57 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# isort:skip_file """Unit tests for Superset""" import imp import json -import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock import pandas as pd from flask_appbuilder.security.sqla import models as ab_models +from flask_testing import TestCase -from superset import app, db, is_feature_enabled, security_manager +from tests.test_app import app # isort:skip +from superset import db, security_manager from superset.connectors.druid.models import DruidCluster, DruidDatasource from superset.connectors.sqla.models import SqlaTable from superset.models import core as models from superset.models.core import Database from superset.utils.core import get_example_database -BASE_DIR = app.config["BASE_DIR"] +FAKE_DB_NAME = "fake_db_100" -class SupersetTestCase(unittest.TestCase): +class SupersetTestCase(TestCase): def __init__(self, *args, **kwargs): super(SupersetTestCase, self).__init__(*args, **kwargs) - self.client = app.test_client() self.maxDiff = None + def create_app(self): + return app + @classmethod def create_druid_test_objects(cls): # create druid cluster and druid datasources - session = db.session - cluster = ( - session.query(DruidCluster).filter_by(cluster_name="druid_test").first() - ) - if not cluster: - cluster = DruidCluster(cluster_name="druid_test") - session.add(cluster) - session.commit() - - druid_datasource1 = DruidDatasource( - datasource_name="druid_ds_1", cluster_name="druid_test" + with app.app_context(): + session = db.session + cluster = ( + session.query(DruidCluster).filter_by(cluster_name="druid_test").first() ) - session.add(druid_datasource1) - druid_datasource2 = DruidDatasource( - datasource_name="druid_ds_2", cluster_name="druid_test" - ) - session.add(druid_datasource2) - session.commit() + if not cluster: + cluster = DruidCluster(cluster_name="druid_test") + session.add(cluster) + session.commit() + + druid_datasource1 = DruidDatasource( + datasource_name="druid_ds_1", cluster_name="druid_test" + ) + session.add(druid_datasource1) + druid_datasource2 = DruidDatasource( + datasource_name="druid_ds_2", cluster_name="druid_test" + ) + session.add(druid_datasource2) + session.commit() def get_table(self, table_id): return db.session.query(SqlaTable).filter_by(id=table_id).one() @@ -210,7 +215,7 @@ def run_sql( def create_fake_db(self): self.login(username="admin") - database_name = "fake_db_100" + database_name = FAKE_DB_NAME db_id = 100 extra = """{ "schemas_allowed_for_csv_upload": @@ -225,6 +230,15 @@ def create_fake_db(self): extra=extra, ) + def delete_fake_db(self): + database = ( + db.session.query(Database) + .filter(Database.database_name == FAKE_DB_NAME) + .scalar() + ) + if database: + db.session.delete(database) + def validate_sql( self, sql, @@ -246,18 +260,6 @@ def validate_sql( raise Exception("validate_sql failed") return resp - @patch.dict("superset._feature_flags", {"FOO": True}, clear=True) - def test_existing_feature_flags(self): - self.assertTrue(is_feature_enabled("FOO")) - - @patch.dict("superset._feature_flags", {}, clear=True) - def test_nonexistent_feature_flags(self): - self.assertFalse(is_feature_enabled("FOO")) - - def test_feature_flags(self): - self.assertEqual(is_feature_enabled("foo"), "bar") - self.assertEqual(is_feature_enabled("super"), "set") - def get_dash_by_slug(self, dash_slug): sesh = db.session() return sesh.query(models.Dashboard).filter_by(slug=dash_slug).first() diff --git a/tests/celery_tests.py b/tests/celery_tests.py index a90fd957680b6..954c84d5fade4 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# isort:skip_file """Unit tests for Superset Celery worker""" import datetime import json @@ -22,7 +23,8 @@ import unittest import unittest.mock as mock -from superset import app, db, sql_lab +from tests.test_app import app # isort:skip +from superset import db, sql_lab from superset.dataframe import SupersetDataFrame from superset.db_engine_specs.base import BaseEngineSpec from superset.models.helpers import QueryStatus @@ -32,20 +34,9 @@ from .base_tests import SupersetTestCase -BASE_DIR = app.config["BASE_DIR"] CELERY_SLEEP_TIME = 5 -class CeleryConfig(object): - BROKER_URL = app.config["CELERY_CONFIG"].BROKER_URL - CELERY_IMPORTS = ("superset.sql_lab",) - CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}} - CONCURRENCY = 1 - - -app.config["CELERY_CONFIG"] = CeleryConfig - - class UtilityFunctionTests(SupersetTestCase): # TODO(bkyryliuk): support more cases in CTA function. @@ -79,10 +70,6 @@ def test_create_table_as(self): class CeleryTestCase(SupersetTestCase): - def __init__(self, *args, **kwargs): - super(CeleryTestCase, self).__init__(*args, **kwargs) - self.client = app.test_client() - def get_query_by_name(self, sql): session = db.session query = session.query(Query).filter_by(sql=sql).first() @@ -97,11 +84,22 @@ def get_query_by_id(self, id): @classmethod def setUpClass(cls): - db.session.query(Query).delete() - db.session.commit() + with app.app_context(): + + class CeleryConfig(object): + BROKER_URL = app.config["CELERY_CONFIG"].BROKER_URL + CELERY_IMPORTS = ("superset.sql_lab",) + CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}} + CONCURRENCY = 1 + + app.config["CELERY_CONFIG"] = CeleryConfig + + db.session.query(Query).delete() + db.session.commit() - worker_command = BASE_DIR + "/bin/superset worker -w 2" - subprocess.Popen(worker_command, shell=True, stdout=subprocess.PIPE) + base_dir = app.config["BASE_DIR"] + worker_command = base_dir + "/bin/superset worker -w 2" + subprocess.Popen(worker_command, shell=True, stdout=subprocess.PIPE) @classmethod def tearDownClass(cls): @@ -190,6 +188,7 @@ def test_run_async_query(self): result = self.run_sql( db_id, sql_where, "4", async_=True, tmp_table="tmp_async_1", cta=True ) + db.session.close() assert result["query"]["state"] in ( QueryStatus.PENDING, QueryStatus.RUNNING, @@ -224,6 +223,7 @@ def test_run_async_query_with_lower_limit(self): result = self.run_sql( db_id, sql_where, "5", async_=True, tmp_table=tmp_table, cta=True ) + db.session.close() assert result["query"]["state"] in ( QueryStatus.PENDING, QueryStatus.RUNNING, diff --git a/tests/core_tests.py b/tests/core_tests.py index 6b15e593f94ff..b227c58532dd2 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# isort:skip_file """Unit tests for Superset""" import cgi import csv @@ -33,7 +34,8 @@ import psycopg2 import sqlalchemy as sqla -from superset import app, dataframe, db, jinja_context, security_manager, sql_lab +from tests.test_app import app +from superset import dataframe, db, jinja_context, security_manager, sql_lab from superset.connectors.sqla.models import SqlaTable from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.mssql import MssqlEngineSpec @@ -51,16 +53,13 @@ class CoreTests(SupersetTestCase): def __init__(self, *args, **kwargs): super(CoreTests, self).__init__(*args, **kwargs) - @classmethod - def setUpClass(cls): - cls.table_ids = { - tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all()) - } - def setUp(self): db.session.query(Query).delete() db.session.query(models.DatasourceAccessRequest).delete() db.session.query(models.Log).delete() + self.table_ids = { + tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all()) + } def tearDown(self): db.session.query(Query).delete() @@ -196,12 +195,11 @@ def assert_admin_view_menus_in(role_name, assert_func): def test_save_slice(self): self.login(username="admin") - slice_name = "Energy Sankey" + slice_name = f"Energy Sankey" slice_id = self.get_slice(slice_name, db.session).id - db.session.commit() - copy_name = "Test Sankey Save" + copy_name = f"Test Sankey Save_{random.random()}" tbl_id = self.table_ids.get("energy_usage") - new_slice_name = "Test Sankey Overwirte" + new_slice_name = f"Test Sankey Overwrite_{random.random()}" url = ( "/superset/explore/table/{}/?slice_name={}&" @@ -216,13 +214,17 @@ def test_save_slice(self): "slice_id": slice_id, } # Changing name and save as a new slice - self.get_resp( + resp = self.client.post( url.format(tbl_id, copy_name, "saveas"), - {"form_data": json.dumps(form_data)}, + data={"form_data": json.dumps(form_data)}, ) - slices = db.session.query(models.Slice).filter_by(slice_name=copy_name).all() - assert len(slices) == 1 - new_slice_id = slices[0].id + db.session.expunge_all() + new_slice_id = resp.json["form_data"]["slice_id"] + slc = db.session.query(models.Slice).filter_by(id=new_slice_id).one() + + self.assertEqual(slc.slice_name, copy_name) + form_data.pop("slice_id") # We don't save the slice id when saving as + self.assertEqual(slc.viz.form_data, form_data) form_data = { "viz_type": "sankey", @@ -233,14 +235,18 @@ def test_save_slice(self): "time_range": "now", } # Setting the name back to its original name by overwriting new slice - self.get_resp( + self.client.post( url.format(tbl_id, new_slice_name, "overwrite"), - {"form_data": json.dumps(form_data)}, + data={"form_data": json.dumps(form_data)}, ) - slc = db.session.query(models.Slice).filter_by(id=new_slice_id).first() - assert slc.slice_name == new_slice_name - assert slc.viz.form_data == form_data + db.session.expunge_all() + slc = db.session.query(models.Slice).filter_by(id=new_slice_id).one() + self.assertEqual(slc.slice_name, new_slice_name) + self.assertEqual(slc.viz.form_data, form_data) + + # Cleanup db.session.delete(slc) + db.session.commit() def test_filter_endpoint(self): self.login(username="admin") @@ -406,10 +412,16 @@ def test_databaseview_edit(self, username="admin"): database = utils.get_example_database() self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted) + # Need to clean up after ourselves + database.impersonate_user = False + database.allow_dml = False + database.allow_run_async = False + db.session.commit() + def test_warm_up_cache(self): slc = self.get_slice("Girls", db.session) data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id)) - assert data == [{"slice_id": slc.id, "slice_name": slc.slice_name}] + self.assertEqual(data, [{"slice_id": slc.id, "slice_name": slc.slice_name}]) data = self.get_json_resp( "/superset/warm_up_cache?table_name=energy_usage&db_name=main" @@ -430,13 +442,10 @@ def test_shortner(self): assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8")) def test_kv(self): - self.logout() self.login(username="admin") - try: - resp = self.client.post("/kv/store/", data=dict()) - except Exception: - self.assertRaises(TypeError) + resp = self.client.get("/kv/10001/") + self.assertEqual(404, resp.status_code) value = json.dumps({"data": "this is a test"}) resp = self.client.post("/kv/store/", data=dict(data=value)) @@ -449,11 +458,6 @@ def test_kv(self): self.assertEqual(resp.status_code, 200) self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8"))) - try: - resp = self.client.get("/kv/10001/") - except Exception: - self.assertRaises(TypeError) - def test_gamma(self): self.login(username="gamma") assert "Charts" in self.get_resp("/chart/list/") @@ -808,6 +812,7 @@ def test_schemas_access_for_csv_upload_endpoint( ) ) assert data == ["this_schema_is_allowed_too"] + self.delete_fake_db() def test_select_star(self): self.login(username="admin") @@ -950,7 +955,11 @@ def test_results_msgpack_deserialization(self): self.assertDictEqual(deserialized_payload, payload) expand_data.assert_called_once() - @mock.patch.dict("superset._feature_flags", {"FOO": lambda x: 1}, clear=True) + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"FOO": lambda x: 1}, + clear=True, + ) def test_feature_flag_serialization(self): """ Functions in feature flags don't break bootstrap data serialization. diff --git a/tests/dashboard_tests.py b/tests/dashboard_tests.py index 6753b9ae7ca51..4dca9bf156635 100644 --- a/tests/dashboard_tests.py +++ b/tests/dashboard_tests.py @@ -17,6 +17,7 @@ """Unit tests for Superset""" import json import unittest +from random import random from flask import escape from sqlalchemy import func @@ -399,16 +400,19 @@ def test_users_can_view_published_dashboard(self): self.grant_public_access_to_table(table) + hidden_dash_slug = f"hidden_dash_{random()}" + published_dash_slug = f"published_dash_{random()}" + # Create a published and hidden dashboard and add them to the database published_dash = models.Dashboard() published_dash.dashboard_title = "Published Dashboard" - published_dash.slug = "published_dash" + published_dash.slug = published_dash_slug published_dash.slices = [slice] published_dash.published = True hidden_dash = models.Dashboard() hidden_dash.dashboard_title = "Hidden Dashboard" - hidden_dash.slug = "hidden_dash" + hidden_dash.slug = hidden_dash_slug hidden_dash.slices = [slice] hidden_dash.published = False @@ -417,22 +421,24 @@ def test_users_can_view_published_dashboard(self): db.session.commit() resp = self.get_resp("/dashboard/list/") - self.assertNotIn("/superset/dashboard/hidden_dash/", resp) - self.assertIn("/superset/dashboard/published_dash/", resp) + self.assertNotIn(f"/superset/dashboard/{hidden_dash_slug}/", resp) + self.assertIn(f"/superset/dashboard/{published_dash_slug}/", resp) def test_users_can_view_own_dashboard(self): user = security_manager.find_user("gamma") + my_dash_slug = f"my_dash_{random()}" + not_my_dash_slug = f"not_my_dash_{random()}" # Create one dashboard I own and another that I don't dash = models.Dashboard() dash.dashboard_title = "My Dashboard" - dash.slug = "my_dash" + dash.slug = my_dash_slug dash.owners = [user] dash.slices = [] hidden_dash = models.Dashboard() hidden_dash.dashboard_title = "Not My Dashboard" - hidden_dash.slug = "not_my_dash" + hidden_dash.slug = not_my_dash_slug hidden_dash.slices = [] hidden_dash.owners = [] @@ -443,29 +449,27 @@ def test_users_can_view_own_dashboard(self): self.login(user.username) resp = self.get_resp("/dashboard/list/") - self.assertIn("/superset/dashboard/my_dash/", resp) - self.assertNotIn("/superset/dashboard/not_my_dash/", resp) + self.assertIn(f"/superset/dashboard/{my_dash_slug}/", resp) + self.assertNotIn(f"/superset/dashboard/{not_my_dash_slug}/", resp) def test_users_can_view_favorited_dashboards(self): user = security_manager.find_user("gamma") + fav_dash_slug = f"my_favorite_dash_{random()}" + regular_dash_slug = f"regular_dash_{random()}" favorite_dash = models.Dashboard() favorite_dash.dashboard_title = "My Favorite Dashboard" - favorite_dash.slug = "my_favorite_dash" + favorite_dash.slug = fav_dash_slug regular_dash = models.Dashboard() regular_dash.dashboard_title = "A Plain Ol Dashboard" - regular_dash.slug = "regular_dash" + regular_dash.slug = regular_dash_slug db.session.merge(favorite_dash) db.session.merge(regular_dash) db.session.commit() - dash = ( - db.session.query(models.Dashboard) - .filter_by(slug="my_favorite_dash") - .first() - ) + dash = db.session.query(models.Dashboard).filter_by(slug=fav_dash_slug).first() favorites = models.FavStar() favorites.obj_id = dash.id @@ -478,12 +482,12 @@ def test_users_can_view_favorited_dashboards(self): self.login(user.username) resp = self.get_resp("/dashboard/list/") - self.assertIn("/superset/dashboard/my_favorite_dash/", resp) + self.assertIn(f"/superset/dashboard/{fav_dash_slug}/", resp) def test_user_can_not_view_unpublished_dash(self): admin_user = security_manager.find_user("admin") gamma_user = security_manager.find_user("gamma") - slug = "admin_owned_unpublished_dash" + slug = f"admin_owned_unpublished_dash_{random()}" # Create a dashboard owned by admin and unpublished dash = models.Dashboard() diff --git a/tests/db_engine_specs/presto_tests.py b/tests/db_engine_specs/presto_tests.py index 48ad18a244503..bfb032294bee8 100644 --- a/tests/db_engine_specs/presto_tests.py +++ b/tests/db_engine_specs/presto_tests.py @@ -60,7 +60,9 @@ def test_presto_get_column(self): self.verify_presto_column(presto_column, expected_results) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_get_simple_row_column(self): presto_column = ("column_name", "row(nested_obj double)", "") @@ -68,7 +70,9 @@ def test_presto_get_simple_row_column(self): self.verify_presto_column(presto_column, expected_results) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_get_simple_row_column_with_name_containing_whitespace(self): presto_column = ("column name", "row(nested_obj double)", "") @@ -76,7 +80,9 @@ def test_presto_get_simple_row_column_with_name_containing_whitespace(self): self.verify_presto_column(presto_column, expected_results) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_get_simple_row_column_with_tricky_nested_field_name(self): presto_column = ("column_name", 'row("Field Name(Tricky, Name)" double)', "") @@ -87,7 +93,9 @@ def test_presto_get_simple_row_column_with_tricky_nested_field_name(self): self.verify_presto_column(presto_column, expected_results) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_get_simple_array_column(self): presto_column = ("column_name", "array(double)", "") @@ -95,7 +103,9 @@ def test_presto_get_simple_array_column(self): self.verify_presto_column(presto_column, expected_results) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_get_row_within_array_within_row_column(self): presto_column = ( @@ -112,7 +122,9 @@ def test_presto_get_row_within_array_within_row_column(self): self.verify_presto_column(presto_column, expected_results) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_get_array_within_row_within_array_column(self): presto_column = ( @@ -147,7 +159,9 @@ def test_presto_get_fields(self): self.assertEqual(actual_result.name, expected_result["label"]) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_expand_data_with_simple_structural_columns(self): cols = [ @@ -182,7 +196,9 @@ def test_presto_expand_data_with_simple_structural_columns(self): self.assertEqual(actual_expanded_cols, expected_expanded_cols) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_expand_data_with_complex_row_columns(self): cols = [ @@ -229,7 +245,9 @@ def test_presto_expand_data_with_complex_row_columns(self): self.assertEqual(actual_expanded_cols, expected_expanded_cols) @mock.patch.dict( - "superset._feature_flags", {"PRESTO_EXPAND_DATA": True}, clear=True + "superset.extensions.feature_flag_manager._feature_flags", + {"PRESTO_EXPAND_DATA": True}, + clear=True, ) def test_presto_expand_data_with_complex_array_columns(self): cols = [ diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py index 4890c345c87e2..d30443e3c8076 100644 --- a/tests/dict_import_export_tests.py +++ b/tests/dict_import_export_tests.py @@ -14,12 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# isort:skip_file """Unit tests for Superset""" import json import unittest import yaml +from tests.test_app import app from superset import db from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn @@ -41,15 +43,16 @@ def __init__(self, *args, **kwargs): @classmethod def delete_imports(cls): - # Imported data clean up - session = db.session - for table in session.query(SqlaTable): - if DBREF in table.params_dict: - session.delete(table) - for datasource in session.query(DruidDatasource): - if DBREF in datasource.params_dict: - session.delete(datasource) - session.commit() + with app.app_context(): + # Imported data clean up + session = db.session + for table in session.query(SqlaTable): + if DBREF in table.params_dict: + session.delete(table) + for datasource in session.query(DruidDatasource): + if DBREF in datasource.params_dict: + session.delete(datasource) + session.commit() @classmethod def setUpClass(cls): diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py index 3afbfaccada49..508d45f3f7827 100644 --- a/tests/druid_func_tests.py +++ b/tests/druid_func_tests.py @@ -47,7 +47,7 @@ def emplace(metrics_dict, metric_name, is_postagg=False): # Unit tests that can be run without initializing base tests -class DruidFuncTestCase(unittest.TestCase): +class DruidFuncTestCase(SupersetTestCase): @unittest.skipUnless( SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" ) diff --git a/tests/email_tests.py b/tests/email_tests.py index ba4c2638e84f6..e8fd3578886e9 100644 --- a/tests/email_tests.py +++ b/tests/email_tests.py @@ -26,13 +26,14 @@ from superset import app from superset.utils import core as utils +from tests.base_tests import SupersetTestCase from .utils import read_fixture send_email_test = mock.Mock() -class EmailSmtpTest(unittest.TestCase): +class EmailSmtpTest(SupersetTestCase): def setUp(self): app.config["smtp_ssl"] = False diff --git a/tests/feature_flag_tests.py b/tests/feature_flag_tests.py new file mode 100644 index 0000000000000..8712c63657d44 --- /dev/null +++ b/tests/feature_flag_tests.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest.mock import patch + +from superset import is_feature_enabled +from tests.base_tests import SupersetTestCase + + +class FeatureFlagTests(SupersetTestCase): + @patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"FOO": True}, + clear=True, + ) + def test_existing_feature_flags(self): + self.assertTrue(is_feature_enabled("FOO")) + + @patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", {}, clear=True + ) + def test_nonexistent_feature_flags(self): + self.assertFalse(is_feature_enabled("FOO")) + + def test_feature_flags(self): + self.assertEqual(is_feature_enabled("foo"), "bar") + self.assertEqual(is_feature_enabled("super"), "set") diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index c4f032fccc534..2641e35adbdc9 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -14,13 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# isort:skip_file """Unit tests for Superset""" import json import unittest -from flask import Flask, g +from flask import g from sqlalchemy.orm.session import make_transient +from tests.test_app import app from superset import db, security_manager from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn @@ -35,21 +37,22 @@ class ImportExportTests(SupersetTestCase): @classmethod def delete_imports(cls): - # Imported data clean up - session = db.session - for slc in session.query(models.Slice): - if "remote_id" in slc.params_dict: - session.delete(slc) - for dash in session.query(models.Dashboard): - if "remote_id" in dash.params_dict: - session.delete(dash) - for table in session.query(SqlaTable): - if "remote_id" in table.params_dict: - session.delete(table) - for datasource in session.query(DruidDatasource): - if "remote_id" in datasource.params_dict: - session.delete(datasource) - session.commit() + with app.app_context(): + # Imported data clean up + session = db.session + for slc in session.query(models.Slice): + if "remote_id" in slc.params_dict: + session.delete(slc) + for dash in session.query(models.Dashboard): + if "remote_id" in dash.params_dict: + session.delete(dash) + for table in session.query(SqlaTable): + if "remote_id" in table.params_dict: + session.delete(table) + for datasource in session.query(DruidDatasource): + if "remote_id" in datasource.params_dict: + session.delete(datasource) + session.commit() @classmethod def setUpClass(cls): @@ -460,68 +463,64 @@ def test_import_override_dashboard_2_slices(self): ) def test_import_new_dashboard_slice_reset_ownership(self): - app = Flask("test_import_dashboard_slice_set_user") - with app.app_context(): - admin_user = security_manager.find_user(username="admin") - self.assertTrue(admin_user) - gamma_user = security_manager.find_user(username="gamma") - self.assertTrue(gamma_user) - g.user = gamma_user - - dash_with_1_slice = self._create_dashboard_for_import(id_=10200) - # set another user as an owner of importing dashboard - dash_with_1_slice.created_by = admin_user - dash_with_1_slice.changed_by = admin_user - dash_with_1_slice.owners = [admin_user] - - imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice) - imported_dash = self.get_dash(imported_dash_id) - self.assertEqual(imported_dash.created_by, gamma_user) - self.assertEqual(imported_dash.changed_by, gamma_user) - self.assertEqual(imported_dash.owners, [gamma_user]) - - imported_slc = imported_dash.slices[0] - self.assertEqual(imported_slc.created_by, gamma_user) - self.assertEqual(imported_slc.changed_by, gamma_user) - self.assertEqual(imported_slc.owners, [gamma_user]) + admin_user = security_manager.find_user(username="admin") + self.assertTrue(admin_user) + gamma_user = security_manager.find_user(username="gamma") + self.assertTrue(gamma_user) + g.user = gamma_user + + dash_with_1_slice = self._create_dashboard_for_import(id_=10200) + # set another user as an owner of importing dashboard + dash_with_1_slice.created_by = admin_user + dash_with_1_slice.changed_by = admin_user + dash_with_1_slice.owners = [admin_user] + + imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice) + imported_dash = self.get_dash(imported_dash_id) + self.assertEqual(imported_dash.created_by, gamma_user) + self.assertEqual(imported_dash.changed_by, gamma_user) + self.assertEqual(imported_dash.owners, [gamma_user]) + + imported_slc = imported_dash.slices[0] + self.assertEqual(imported_slc.created_by, gamma_user) + self.assertEqual(imported_slc.changed_by, gamma_user) + self.assertEqual(imported_slc.owners, [gamma_user]) def test_import_override_dashboard_slice_reset_ownership(self): - app = Flask("test_import_dashboard_slice_set_user") - with app.app_context(): - admin_user = security_manager.find_user(username="admin") - self.assertTrue(admin_user) - gamma_user = security_manager.find_user(username="gamma") - self.assertTrue(gamma_user) - g.user = gamma_user - - dash_with_1_slice = self._create_dashboard_for_import(id_=10300) - - imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice) - imported_dash = self.get_dash(imported_dash_id) - self.assertEqual(imported_dash.created_by, gamma_user) - self.assertEqual(imported_dash.changed_by, gamma_user) - self.assertEqual(imported_dash.owners, [gamma_user]) - - imported_slc = imported_dash.slices[0] - self.assertEqual(imported_slc.created_by, gamma_user) - self.assertEqual(imported_slc.changed_by, gamma_user) - self.assertEqual(imported_slc.owners, [gamma_user]) - - # re-import with another user shouldn't change the permissions - g.user = admin_user - - dash_with_1_slice = self._create_dashboard_for_import(id_=10300) - - imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice) - imported_dash = self.get_dash(imported_dash_id) - self.assertEqual(imported_dash.created_by, gamma_user) - self.assertEqual(imported_dash.changed_by, gamma_user) - self.assertEqual(imported_dash.owners, [gamma_user]) - - imported_slc = imported_dash.slices[0] - self.assertEqual(imported_slc.created_by, gamma_user) - self.assertEqual(imported_slc.changed_by, gamma_user) - self.assertEqual(imported_slc.owners, [gamma_user]) + admin_user = security_manager.find_user(username="admin") + self.assertTrue(admin_user) + gamma_user = security_manager.find_user(username="gamma") + self.assertTrue(gamma_user) + g.user = gamma_user + + dash_with_1_slice = self._create_dashboard_for_import(id_=10300) + + imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice) + imported_dash = self.get_dash(imported_dash_id) + self.assertEqual(imported_dash.created_by, gamma_user) + self.assertEqual(imported_dash.changed_by, gamma_user) + self.assertEqual(imported_dash.owners, [gamma_user]) + + imported_slc = imported_dash.slices[0] + self.assertEqual(imported_slc.created_by, gamma_user) + self.assertEqual(imported_slc.changed_by, gamma_user) + self.assertEqual(imported_slc.owners, [gamma_user]) + + # re-import with another user shouldn't change the permissions + g.user = admin_user + + dash_with_1_slice = self._create_dashboard_for_import(id_=10300) + + imported_dash_id = models.Dashboard.import_obj(dash_with_1_slice) + imported_dash = self.get_dash(imported_dash_id) + self.assertEqual(imported_dash.created_by, gamma_user) + self.assertEqual(imported_dash.changed_by, gamma_user) + self.assertEqual(imported_dash.owners, [gamma_user]) + + imported_slc = imported_dash.slices[0] + self.assertEqual(imported_slc.created_by, gamma_user) + self.assertEqual(imported_slc.changed_by, gamma_user) + self.assertEqual(imported_slc.owners, [gamma_user]) def _create_dashboard_for_import(self, id_=10100): slc = self.create_slice("health_slc" + str(id_), id=id_ + 1) diff --git a/tests/load_examples_test.py b/tests/load_examples_test.py index d98542f4f2c21..5b61730163d32 100644 --- a/tests/load_examples_test.py +++ b/tests/load_examples_test.py @@ -14,27 +14,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from superset import examples -from superset.cli import load_test_users_run - from .base_tests import SupersetTestCase class SupersetDataFrameTestCase(SupersetTestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.examples = None + + def setUp(self) -> None: + # Late importing here as we need an app context to be pushed... + from superset import examples + + self.examples = examples + def test_load_css_templates(self): - examples.load_css_templates() + self.examples.load_css_templates() def test_load_energy(self): - examples.load_energy() + self.examples.load_energy() def test_load_world_bank_health_n_pop(self): - examples.load_world_bank_health_n_pop() + self.examples.load_world_bank_health_n_pop() def test_load_birth_names(self): - examples.load_birth_names() + self.examples.load_birth_names() def test_load_test_users_run(self): + from superset.cli import load_test_users_run + load_test_users_run() def test_load_unicode_test_data(self): - examples.load_unicode_test_data() + self.examples.load_unicode_test_data() diff --git a/tests/schedules_test.py b/tests/schedules_test.py index e6ad92d65d626..da55f9efa287a 100644 --- a/tests/schedules_test.py +++ b/tests/schedules_test.py @@ -14,14 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import unittest +# isort:skip_file from datetime import datetime, timedelta from unittest.mock import Mock, patch, PropertyMock from flask_babel import gettext as __ from selenium.common.exceptions import WebDriverException -from superset import app, db +from tests.test_app import app +from superset import db from superset.models.core import Dashboard, Slice from superset.models.schedules import ( DashboardEmailSchedule, @@ -35,11 +36,12 @@ deliver_slice, next_schedules, ) +from tests.base_tests import SupersetTestCase from .utils import read_fixture -class SchedulesTestCase(unittest.TestCase): +class SchedulesTestCase(SupersetTestCase): RECIPIENTS = "recipient1@superset.com, recipient2@superset.com" BCC = "bcc@superset.com" @@ -47,41 +49,45 @@ class SchedulesTestCase(unittest.TestCase): @classmethod def setUpClass(cls): - cls.common_data = dict( - active=True, - crontab="* * * * *", - recipients=cls.RECIPIENTS, - deliver_as_group=True, - delivery_type=EmailDeliveryType.inline, - ) + with app.app_context(): + cls.common_data = dict( + active=True, + crontab="* * * * *", + recipients=cls.RECIPIENTS, + deliver_as_group=True, + delivery_type=EmailDeliveryType.inline, + ) - # Pick up a random slice and dashboard - slce = db.session.query(Slice).all()[0] - dashboard = db.session.query(Dashboard).all()[0] + # Pick up a random slice and dashboard + slce = db.session.query(Slice).all()[0] + dashboard = db.session.query(Dashboard).all()[0] - dashboard_schedule = DashboardEmailSchedule(**cls.common_data) - dashboard_schedule.dashboard_id = dashboard.id - dashboard_schedule.user_id = 1 - db.session.add(dashboard_schedule) + dashboard_schedule = DashboardEmailSchedule(**cls.common_data) + dashboard_schedule.dashboard_id = dashboard.id + dashboard_schedule.user_id = 1 + db.session.add(dashboard_schedule) - slice_schedule = SliceEmailSchedule(**cls.common_data) - slice_schedule.slice_id = slce.id - slice_schedule.user_id = 1 - slice_schedule.email_format = SliceEmailReportFormat.data + slice_schedule = SliceEmailSchedule(**cls.common_data) + slice_schedule.slice_id = slce.id + slice_schedule.user_id = 1 + slice_schedule.email_format = SliceEmailReportFormat.data - db.session.add(slice_schedule) - db.session.commit() + db.session.add(slice_schedule) + db.session.commit() - cls.slice_schedule = slice_schedule.id - cls.dashboard_schedule = dashboard_schedule.id + cls.slice_schedule = slice_schedule.id + cls.dashboard_schedule = dashboard_schedule.id @classmethod def tearDownClass(cls): - db.session.query(SliceEmailSchedule).filter_by(id=cls.slice_schedule).delete() - db.session.query(DashboardEmailSchedule).filter_by( - id=cls.dashboard_schedule - ).delete() - db.session.commit() + with app.app_context(): + db.session.query(SliceEmailSchedule).filter_by( + id=cls.slice_schedule + ).delete() + db.session.query(DashboardEmailSchedule).filter_by( + id=cls.dashboard_schedule + ).delete() + db.session.commit() def test_crontab_scheduler(self): crontab = "* * * * *" diff --git a/tests/security_tests.py b/tests/security_tests.py index 4478389902fc9..0391d0d51e7ba 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -310,6 +310,7 @@ def test_views_are_secured(self): ["Superset", "welcome"], ["SecurityApi", "login"], ["SecurityApi", "refresh"], + ["SupersetIndexView", "index"], ] unsecured_views = [] for view_class in appbuilder.baseviews: diff --git a/tests/sql_validator_tests.py b/tests/sql_validator_tests.py index 069bae8a6bc20..553e799039fa8 100644 --- a/tests/sql_validator_tests.py +++ b/tests/sql_validator_tests.py @@ -60,7 +60,11 @@ def test_validate_sql_endpoint_noconfig(self): self.assertIn("no SQL validator is configured", resp["error"]) @patch("superset.views.core.get_validator_by_name") - @patch.dict("superset._feature_flags", PRESTO_TEST_FEATURE_FLAGS, clear=True) + @patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + PRESTO_TEST_FEATURE_FLAGS, + clear=True, + ) def test_validate_sql_endpoint_mocked(self, get_validator_by_name): """Assert that, with a mocked validator, annotations make it back out from the validate_sql_json endpoint as a list of json dictionaries""" @@ -87,7 +91,11 @@ def test_validate_sql_endpoint_mocked(self, get_validator_by_name): self.assertIn("expected,", resp[0]["message"]) @patch("superset.views.core.get_validator_by_name") - @patch.dict("superset._feature_flags", PRESTO_TEST_FEATURE_FLAGS, clear=True) + @patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + PRESTO_TEST_FEATURE_FLAGS, + clear=True, + ) def test_validate_sql_endpoint_failure(self, get_validator_by_name): """Assert that validate_sql_json errors out when the selected validator raises an unexpected exception""" diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index dde381742f416..a37ea06d7774a 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -17,6 +17,7 @@ """Unit tests for Sql Lab""" import json from datetime import datetime, timedelta +from random import random import prison @@ -294,19 +295,19 @@ def test_sqllab_viz(self): examples_dbid = get_example_database().id payload = { "chartType": "dist_bar", - "datasourceName": "test_viz_flow_table", + "datasourceName": f"test_viz_flow_table_{random()}", "schema": "superset", "columns": [ { "is_date": False, "type": "STRING", - "name": "viz_type", + "name": f"viz_type_{random()}", "is_dim": True, }, { "is_date": False, "type": "OBJECT", - "name": "ccount", + "name": f"ccount_{random()}", "is_dim": True, "agg": "sum", }, @@ -421,3 +422,4 @@ def test_api_database(self): {"examples", "fake_db_100"}, {r.get("database_name") for r in self.get_json_resp(url)["result"]}, ) + self.delete_fake_db() diff --git a/tests/test_app.py b/tests/test_app.py new file mode 100644 index 0000000000000..f3e708b550fd3 --- /dev/null +++ b/tests/test_app.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Here is where we create the app which ends up being shared across all tests. A future +optimization will be to create a separate app instance for each test class. +""" +from superset.app import create_app + +app = create_app() diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 81adb93a0a01a..b44282cade4da 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -28,6 +28,7 @@ from superset import app, db, security_manager from superset.exceptions import SupersetException from superset.models.core import Database +from superset.utils.cache_manager import CacheManager from superset.utils.core import ( base_json_conv, convert_legacy_filters_into_adhoc, @@ -45,7 +46,6 @@ parse_human_timedelta, parse_js_uri_path_item, parse_past_timedelta, - setup_cache, split, TimeRangeEndpoint, validate_json, @@ -53,6 +53,7 @@ zlib_decompress, ) from superset.views.utils import get_time_range_endpoints +from tests.base_tests import SupersetTestCase def mock_parse_human_datetime(s): @@ -93,7 +94,7 @@ def mock_to_adhoc(filt, expressionType="SIMPLE", clause="where"): return result -class UtilsTestCase(unittest.TestCase): +class UtilsTestCase(SupersetTestCase): def test_json_int_dttm_ser(self): dttm = datetime(2020, 1, 1) ts = 1577836800000.0 @@ -809,12 +810,12 @@ def test_parse_js_uri_path_items_item_optional(self): def test_setup_cache_no_config(self): app = Flask(__name__) cache_config = None - self.assertIsNone(setup_cache(app, cache_config)) + self.assertIsNone(CacheManager._setup_cache(app, cache_config)) def test_setup_cache_null_config(self): app = Flask(__name__) cache_config = {"CACHE_TYPE": "null"} - self.assertIsNone(setup_cache(app, cache_config)) + self.assertIsNone(CacheManager._setup_cache(app, cache_config)) def test_setup_cache_standard_config(self): app = Flask(__name__) @@ -824,7 +825,7 @@ def test_setup_cache_standard_config(self): "CACHE_KEY_PREFIX": "superset_results", "CACHE_REDIS_URL": "redis://localhost:6379/0", } - assert isinstance(setup_cache(app, cache_config), Cache) is True + assert isinstance(CacheManager._setup_cache(app, cache_config), Cache) is True def test_setup_cache_custom_function(self): app = Flask(__name__) @@ -833,7 +834,9 @@ def test_setup_cache_custom_function(self): def init_cache(app): return CustomCache(app, {}) - assert isinstance(setup_cache(app, init_cache), CustomCache) is True + assert ( + isinstance(CacheManager._setup_cache(app, init_cache), CustomCache) is True + ) def test_get_stacktrace(self): with app.app_context(): @@ -879,6 +882,8 @@ def test_get_or_create_db(self): get_or_create_db("test_db", "sqlite:///changed.db") database = db.session.query(Database).filter_by(database_name="test_db").one() self.assertEqual(database.sqlalchemy_uri, "sqlite:///changed.db") + db.session.delete(database) + db.session.commit() def test_get_or_create_db_invalid_uri(self): with self.assertRaises(ArgumentError): diff --git a/tox.ini b/tox.ini index 6873d291302ab..5f2d017af1c0c 100644 --- a/tox.ini +++ b/tox.ini @@ -19,7 +19,7 @@ commands = {toxinidir}/superset/bin/superset db upgrade {toxinidir}/superset/bin/superset init nosetests tests/load_examples_test.py - nosetests -e load_examples_test {posargs} + nosetests -e load_examples_test tests {posargs} deps = -rrequirements.txt -rrequirements-dev.txt