diff --git a/superset/config.py b/superset/config.py index d522e10ac1ec3..a424a09d23d8c 100644 --- a/superset/config.py +++ b/superset/config.py @@ -515,6 +515,7 @@ class D3Format(TypedDict, total=False): # ---------------------------------------------------------------------- SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager" SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1" +SSH_TUNNEL_TIMEOUT_SEC = 10.0 # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. DEFAULT_FEATURE_FLAGS.update( diff --git a/superset/extensions/ssh.py b/superset/extensions/ssh.py index 6a852ea7cd692..78b0c4116b192 100644 --- a/superset/extensions/ssh.py +++ b/superset/extensions/ssh.py @@ -20,9 +20,9 @@ from io import StringIO from typing import TYPE_CHECKING +import sshtunnel from flask import Flask from paramiko import RSAKey -from sshtunnel import open_tunnel, SSHTunnelForwarder from superset.databases.utils import make_url_safe @@ -34,9 +34,10 @@ class SSHManager: def __init__(self, app: Flask) -> None: super().__init__() self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"] + sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"] def build_sqla_url( # pylint: disable=no-self-use - self, sqlalchemy_url: str, server: SSHTunnelForwarder + self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder ) -> str: # override any ssh tunnel configuration object url = make_url_safe(sqlalchemy_url) @@ -49,7 +50,7 @@ def create_tunnel( self, ssh_tunnel: "SSHTunnel", sqlalchemy_database_uri: str, - ) -> SSHTunnelForwarder: + ) -> sshtunnel.SSHTunnelForwarder: url = make_url_safe(sqlalchemy_database_uri) params = { "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port), @@ -68,7 +69,7 @@ def create_tunnel( ) params["ssh_pkey"] = private_key - return open_tunnel(**params) + return sshtunnel.open_tunnel(**params) class SSHManagerFactory: diff --git a/tests/unit_tests/extensions/ssh_test.py b/tests/unit_tests/extensions/ssh_test.py new file mode 100644 index 0000000000000..0e997729d96fe --- /dev/null +++ b/tests/unit_tests/extensions/ssh_test.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any +from unittest.mock import Mock, patch + +import pytest +import sshtunnel + +from superset.extensions.ssh import SSHManagerFactory + + +def test_ssh_tunnel_timeout_setting() -> None: + app = Mock() + app.config = { + "SSH_TUNNEL_MAX_RETRIES": 2, + "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test", + "SSH_TUNNEL_TIMEOUT_SEC": 123.0, + "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager", + } + factory = SSHManagerFactory() + factory.init_app(app) + assert sshtunnel.TUNNEL_TIMEOUT == 123.0