Skip to content

Commit

Permalink
refactor: pass all properties to validate_parameters (#21487)
Browse files Browse the repository at this point in the history
  • Loading branch information
eschutho authored Oct 4, 2022
1 parent 4417c6e commit e98943e
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 74 deletions.
4 changes: 4 additions & 0 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"allow_cvas",
"allow_dml",
"backend",
"driver",
"force_ctas_schema",
"impersonate_user",
"masked_encrypted_extra",
Expand Down Expand Up @@ -269,6 +270,9 @@ def post(self) -> Response:
if new_model.parameters:
item["parameters"] = new_model.parameters

if new_model.driver:
item["driver"] = new_model.driver

return self.response(201, id=new_model.id, result=item)
except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
Expand Down
8 changes: 3 additions & 5 deletions superset/databases/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@


class ValidateDatabaseParametersCommand(BaseCommand):
def __init__(self, parameters: Dict[str, Any]):
self._properties = parameters.copy()
def __init__(self, properties: Dict[str, Any]):
self._properties = properties.copy()
self._model: Optional[Database] = None

def run(self) -> None:
Expand All @@ -66,9 +66,7 @@ def run(self) -> None:
)

# perform initial validation
errors = engine_spec.validate_parameters( # type: ignore
self._properties.get("parameters", {})
)
errors = engine_spec.validate_parameters(self._properties) # type: ignore
if errors:
event_logger.log_with_context(action="validation_error", engine=engine)
raise InvalidParametersError(errors)
Expand Down
7 changes: 6 additions & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,10 @@ class BasicParametersType(TypedDict, total=False):
encryption: bool


class BasicPropertiesType(TypedDict):
parameters: BasicParametersType


class BasicParametersMixin:
"""
Mixin for configuring DB engine specs via a dictionary.
Expand Down Expand Up @@ -1762,7 +1766,7 @@ def get_parameters_from_uri( # pylint: disable=unused-argument

@classmethod
def validate_parameters(
cls, parameters: BasicParametersType
cls, properties: BasicPropertiesType
) -> List[SupersetError]:
"""
Validates any number of parameters, for progressive validation.
Expand All @@ -1773,6 +1777,7 @@ def validate_parameters(
errors: List[SupersetError] = []

required = {"host", "port", "username", "database"}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)

Expand Down
5 changes: 3 additions & 2 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from superset.constants import PASSWORD_MASK
from superset.databases.schemas import encrypted_field_properties, EncryptedString
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.exceptions import SupersetDBAPIDisconnectionError
from superset.errors import SupersetError, SupersetErrorType
from superset.sql_parse import Table
Expand Down Expand Up @@ -450,7 +450,8 @@ def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:

@classmethod
def validate_parameters(
cls, parameters: BigQueryParametersType # pylint: disable=unused-argument
cls,
properties: BasicPropertiesType, # pylint: disable=unused-argument
) -> List[SupersetError]:
return []

Expand Down
7 changes: 6 additions & 1 deletion superset/db_engine_specs/gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class GSheetsParametersType(TypedDict):
catalog: Dict[str, str]


class GSheetsPropertiesType(TypedDict):
parameters: GSheetsParametersType


class GSheetsEngineSpec(SqliteEngineSpec):
"""Engine for Google spreadsheets"""

Expand Down Expand Up @@ -208,9 +212,10 @@ def parameters_json_schema(cls) -> Any:
@classmethod
def validate_parameters(
cls,
parameters: GSheetsParametersType,
properties: GSheetsPropertiesType,
) -> List[SupersetError]:
errors: List[SupersetError] = []
parameters = properties.get("parameters", {})
encrypted_credentials = parameters.get("service_account_info") or "{}"

# On create the encrypted credentials are a string,
Expand Down
4 changes: 3 additions & 1 deletion superset/db_engine_specs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from typing_extensions import TypedDict

from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BasicPropertiesType
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
Expand Down Expand Up @@ -242,7 +243,7 @@ def get_parameters_from_uri(

@classmethod
def validate_parameters(
cls, parameters: SnowflakeParametersType
cls, properties: BasicPropertiesType
) -> List[SupersetError]:
errors: List[SupersetError] = []
required = {
Expand All @@ -253,6 +254,7 @@ def validate_parameters(
"role",
"password",
}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)

Expand Down
17 changes: 9 additions & 8 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,23 +244,24 @@ def url_object(self) -> URL:

@property
def backend(self) -> str:
sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted)
return sqlalchemy_url.get_backend_name()
return self.url_object.get_backend_name()

@property
def driver(self) -> str:
return self.url_object.get_driver_name()

@property
def masked_encrypted_extra(self) -> Optional[str]:
return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra)

@property
def parameters(self) -> Dict[str, Any]:
db_engine_spec = self.db_engine_spec

# Database parameters are a dictionary of values that are used to make up
# the sqlalchemy_uri
# When returning the parameters we should use the masked SQLAlchemy URI and the
# masked ``encrypted_extra`` to prevent exposing sensitive credentials.
masked_uri = make_url_safe(self.sqlalchemy_uri)
masked_encrypted_extra = db_engine_spec.mask_encrypted_extra(
self.encrypted_extra
)
masked_encrypted_extra = self.masked_encrypted_extra
encrypted_config = {}
if masked_encrypted_extra is not None:
try:
Expand All @@ -270,7 +271,7 @@ def parameters(self) -> Dict[str, Any]:

try:
# pylint: disable=useless-suppression
parameters = db_engine_spec.get_parameters_from_uri( # type: ignore
parameters = self.db_engine_spec.get_parameters_from_uri( # type: ignore
masked_uri,
encrypted_extra=encrypted_config,
)
Expand Down
72 changes: 40 additions & 32 deletions tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,28 +421,32 @@ def test_validate(is_port_open, is_hostname_valid):
is_hostname_valid.return_value = True
is_port_open.return_value = True

parameters = {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
properties = {
"parameters": {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
}
errors = BasicParametersMixin.validate_parameters(parameters)
errors = BasicParametersMixin.validate_parameters(properties)
assert errors == []


def test_validate_parameters_missing():
parameters = {
"host": "",
"port": None,
"username": "",
"password": "",
"database": "",
"query": {},
properties = {
"parameters": {
"host": "",
"port": None,
"username": "",
"password": "",
"database": "",
"query": {},
}
}
errors = BasicParametersMixin.validate_parameters(parameters)
errors = BasicParametersMixin.validate_parameters(properties)
assert errors == [
SupersetError(
message=(
Expand All @@ -459,15 +463,17 @@ def test_validate_parameters_missing():
def test_validate_parameters_invalid_host(is_hostname_valid):
is_hostname_valid.return_value = False

parameters = {
"host": "localhost",
"port": None,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
properties = {
"parameters": {
"host": "localhost",
"port": None,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
}
errors = BasicParametersMixin.validate_parameters(parameters)
errors = BasicParametersMixin.validate_parameters(properties)
assert errors == [
SupersetError(
message="One or more parameters are missing: port",
Expand All @@ -490,15 +496,17 @@ def test_validate_parameters_port_closed(is_port_open, is_hostname_valid):
is_hostname_valid.return_value = True
is_port_open.return_value = False

parameters = {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
properties = {
"parameters": {
"host": "localhost",
"port": 5432,
"username": "username",
"password": "password",
"database": "dbname",
"query": {"sslmode": "verify-full"},
}
}
errors = BasicParametersMixin.validate_parameters(parameters)
errors = BasicParametersMixin.validate_parameters(properties)
assert errors == [
SupersetError(
message="The port is closed.",
Expand Down
6 changes: 6 additions & 0 deletions tests/unit_tests/databases/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from uuid import UUID

import pytest
from pytest_mock import MockFixture
from sqlalchemy.orm.session import Session


Expand Down Expand Up @@ -53,6 +54,7 @@ def test_post_with_uuid(


def test_password_mask(
mocker: MockFixture,
app: Any,
session: Session,
client: Any,
Expand Down Expand Up @@ -92,6 +94,10 @@ def test_password_mask(
session.add(database)
session.commit()

# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")

response = client.get("/api/v1/database/1")
assert (
response.json["result"]["parameters"]["service_account_info"]["private_key"]
Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/databases/schema_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def test_database_parameters_schema_mixin_invalid_engine(
try:
dummy_schema.load(payload)
except ValidationError as err:
print(err.messages)
assert err.messages == {
"_schema": ['Engine "dummy_engine" is not a valid engine.']
}
Expand Down
Loading

0 comments on commit e98943e

Please sign in to comment.