Skip to content

Commit

Permalink
chore(ssh): Allow users to set TUNNEL_TIMEOUT from config (#24202)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughhhh authored May 24, 2023
1 parent c54eedf commit 8b0c68c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
1 change: 1 addition & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ class D3Format(TypedDict, total=False):
# ----------------------------------------------------------------------
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
SSH_TUNNEL_TIMEOUT_SEC = 10.0

# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
DEFAULT_FEATURE_FLAGS.update(
Expand Down
9 changes: 5 additions & 4 deletions superset/extensions/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from io import StringIO
from typing import TYPE_CHECKING

import sshtunnel
from flask import Flask
from paramiko import RSAKey
from sshtunnel import open_tunnel, SSHTunnelForwarder

from superset.databases.utils import make_url_safe

Expand All @@ -34,9 +34,10 @@ class SSHManager:
def __init__(self, app: Flask) -> None:
super().__init__()
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]

def build_sqla_url( # pylint: disable=no-self-use
self, sqlalchemy_url: str, server: SSHTunnelForwarder
self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder
) -> str:
# override any ssh tunnel configuration object
url = make_url_safe(sqlalchemy_url)
Expand All @@ -49,7 +50,7 @@ def create_tunnel(
self,
ssh_tunnel: "SSHTunnel",
sqlalchemy_database_uri: str,
) -> SSHTunnelForwarder:
) -> sshtunnel.SSHTunnelForwarder:
url = make_url_safe(sqlalchemy_database_uri)
params = {
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
Expand All @@ -68,7 +69,7 @@ def create_tunnel(
)
params["ssh_pkey"] = private_key

return open_tunnel(**params)
return sshtunnel.open_tunnel(**params)


class SSHManagerFactory:
Expand Down
36 changes: 36 additions & 0 deletions tests/unit_tests/extensions/ssh_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from unittest.mock import Mock, patch

import pytest
import sshtunnel

from superset.extensions.ssh import SSHManagerFactory


def test_ssh_tunnel_timeout_setting() -> None:
app = Mock()
app.config = {
"SSH_TUNNEL_MAX_RETRIES": 2,
"SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test",
"SSH_TUNNEL_TIMEOUT_SEC": 123.0,
"SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
}
factory = SSHManagerFactory()
factory.init_app(app)
assert sshtunnel.TUNNEL_TIMEOUT == 123.0

0 comments on commit 8b0c68c

Please sign in to comment.