diff --git a/superset/config.py b/superset/config.py index f163997c6ee4b..71b220333977a 100644 --- a/superset/config.py +++ b/superset/config.py @@ -45,6 +45,7 @@ ) import pkg_resources +import sshtunnel from cachelib.base import BaseCache from celery.schedules import crontab from dateutil import tz @@ -56,6 +57,7 @@ from superset.advanced_data_type.plugins.internet_port import internet_port from superset.advanced_data_type.types import AdvancedDataType from superset.constants import CHANGE_ME_SECRET_KEY +from superset.databases.utils import make_url_safe from superset.jinja_context import BaseTemplateProcessor from superset.reports.types import ReportScheduleExecutor from superset.stats_logger import DummyStatsLogger @@ -471,8 +473,42 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: "DRILL_TO_DETAIL": False, "DATAPANEL_CLOSED_BY_DEFAULT": False, "HORIZONTAL_FILTER_BAR": False, + # Allow users to enable ssh tunneling when creating a DB. + # Users must check whether the DB engine supports SSH Tunnels + # otherwise enabling this flag won't have any effect on the DB. + "SSH_TUNNELING": False, } +# ------------------------------ +# SSH Tunnel +# ------------------------------ +# Allow users to set the host used when connecting to the SSH Tunnel +# as localhost and any other alias (0.0.0.0) +# ---------------------------------------------------------------------- +# | +# -------------+ | +----------+ +# LOCAL | | | REMOTE | :22 SSH +# CLIENT | <== SSH ========> | SERVER | :8080 web service +# -------------+ | +----------+ +# | +# FIREWALL (only port 22 is open) + +# ---------------------------------------------------------------------- +class SSHManager: # pylint: disable=too-few-public-methods + local_bind_address = "127.0.0.1" + + @classmethod + def mutator(cls, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder) -> str: + # override any ssh tunnel configuration object + url = make_url_safe(sqlalchemy_url) + return url.set( + host=cls.local_bind_address, + port=server.local_bind_port, + ) + + +SSH_TUNNEL_MANAGER = SSHManager # pylint: disable=invalid-name + # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. DEFAULT_FEATURE_FLAGS.update( { @@ -1462,7 +1498,7 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument try: # pylint: disable=import-error,wildcard-import,unused-wildcard-import import superset_config - from superset_config import * # type:ignore + from superset_config import * # type: ignore print(f"Loaded your LOCAL configuration at [{superset_config.__file__}]") except Exception: diff --git a/superset/constants.py b/superset/constants.py index c0fbb7c2cd8db..7d759acf6741c 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -34,8 +34,6 @@ NO_TIME_RANGE = "No filter" -SSH_TUNNELLING_LOCAL_BIND_ADDRESS = "127.0.0.1" - class RouteMethod: # pylint: disable=too-few-public-methods """ diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index 17f4628f8070f..5f334d8c154bb 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -23,7 +23,6 @@ from sqlalchemy.orm import backref, relationship from sqlalchemy_utils import EncryptedType -from superset.constants import SSH_TUNNELLING_LOCAL_BIND_ADDRESS from superset.models.core import Database from superset.models.helpers import ( AuditMixinNullable, @@ -32,6 +31,7 @@ ) app_config = current_app.config +ssh_manager = app_config["SSH_TUNNEL_MANAGER"] class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): @@ -74,7 +74,7 @@ def parameters(self, bind_host: str, bind_port: int) -> Dict[str, Any]: "ssh_port": self.server_port, "ssh_username": self.username, "remote_bind_address": (bind_host, bind_port), - "local_bind_address": (SSH_TUNNELLING_LOCAL_BIND_ADDRESS,), + "local_bind_address": (ssh_manager.local_bind_address,), } if self.password: diff --git a/superset/models/core.py b/superset/models/core.py index 790dd9fe9ced0..51bd0ce21b6ff 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -55,7 +55,7 @@ from sqlalchemy.sql import expression, Select from superset import app, db_engine_specs -from superset.constants import PASSWORD_MASK, SSH_TUNNELLING_LOCAL_BIND_ADDRESS +from superset.constants import PASSWORD_MASK from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import MetricType, TimeGrain from superset.extensions import cache_manager, encrypted_field_factory, security_manager @@ -66,6 +66,7 @@ from superset.utils.memoized import memoized config = app.config +ssh_manager = config["SSH_TUNNEL_MANAGER"] custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] stats_logger = config["STATS_LOGGER"] log_query = config["QUERY_LOGGER"] @@ -447,13 +448,8 @@ def _get_sqla_engine( ) if ssh_tunnel_server: - # update sqlalchemy_url - url = make_url_safe(sqlalchemy_url) - sqlalchemy_url = url.set( - host=SSH_TUNNELLING_LOCAL_BIND_ADDRESS, - port=ssh_tunnel_server.local_bind_port, - ) - + # update sqlalchemy_url with ssh tunnel manager info + sqlalchemy_url = ssh_manager.mutator(sqlalchemy_url, ssh_tunnel_server) try: return create_engine(sqlalchemy_url, **params) except Exception as ex: