From 467e612533e977d44f59b11b24e0c805904ea7b6 Mon Sep 17 00:00:00 2001
From: John Bodley <4567245+john-bodley@users.noreply.github.com>
Date: Tue, 7 May 2024 06:42:37 -0700
Subject: [PATCH] fix: Remedy logic for UpdateDatasetCommand uniqueness check
 (#28341)

---
 superset/commands/dataset/create.py           |  7 +--
 superset/commands/dataset/duplicate.py        |  4 +-
 superset/commands/dataset/exceptions.py       | 19 +++---
 superset/commands/dataset/update.py           | 14 +++--
 tests/integration_tests/datasets/api_tests.py | 17 ++++--
 .../commands/dataset/test_update.py           | 60 +++++++++++++++++++
 tests/unit_tests/dao/dataset_test.py          |  3 -
 7 files changed, 96 insertions(+), 28 deletions(-)
 create mode 100644 tests/unit_tests/commands/dataset/test_update.py

diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py
index dace92f911bcf..b72c3ff46ebb8 100644
--- a/superset/commands/dataset/create.py
+++ b/superset/commands/dataset/create.py
@@ -61,17 +61,16 @@ def run(self) -> Model:
     def validate(self) -> None:
         exceptions: list[ValidationError] = []
         database_id = self._properties["database"]
-        table_name = self._properties["table_name"]
         schema = self._properties.get("schema")
         catalog = self._properties.get("catalog")
         sql = self._properties.get("sql")
         owner_ids: Optional[list[int]] = self._properties.get("owners")
 
-        table = Table(table_name, schema, catalog)
+        table = Table(self._properties["table_name"], schema, catalog)
 
         # Validate uniqueness
         if not DatasetDAO.validate_uniqueness(database_id, table):
-            exceptions.append(DatasetExistsValidationError(table_name))
+            exceptions.append(DatasetExistsValidationError(table))
 
         # Validate/Populate database
         database = DatasetDAO.get_database_by_id(database_id)
@@ -86,7 +85,7 @@ def validate(self) -> None:
             and not sql
             and not DatasetDAO.validate_table_exists(database, table)
         ):
-            exceptions.append(TableNotFoundValidationError(table_name))
+            exceptions.append(TableNotFoundValidationError(table))
 
         if sql:
             try:
diff --git a/superset/commands/dataset/duplicate.py b/superset/commands/dataset/duplicate.py
index 850290422e1c5..efe4935e60af7 100644
--- a/superset/commands/dataset/duplicate.py
+++ b/superset/commands/dataset/duplicate.py
@@ -37,7 +37,7 @@
 from superset.exceptions import SupersetErrorException
 from superset.extensions import db
 from superset.models.core import Database
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
 
 logger = logging.getLogger(__name__)
 
@@ -124,7 +124,7 @@ def validate(self) -> None:
             exceptions.append(DatasourceTypeInvalidError())
 
         if DatasetDAO.find_one_or_none(table_name=duplicate_name):
-            exceptions.append(DatasetExistsValidationError(table_name=duplicate_name))
+            exceptions.append(DatasetExistsValidationError(table=Table(duplicate_name)))
 
         try:
             owners = self.populate_owners()
diff --git a/superset/commands/dataset/exceptions.py b/superset/commands/dataset/exceptions.py
index 4b8acaca08b8d..83b5436c233a0 100644
--- a/superset/commands/dataset/exceptions.py
+++ b/superset/commands/dataset/exceptions.py
@@ -26,10 +26,11 @@
     ImportFailedError,
     UpdateFailedError,
 )
+from superset.sql_parse import Table
 
 
-def get_dataset_exist_error_msg(full_name: str) -> str:
-    return _("Dataset %(name)s already exists", name=full_name)
+def get_dataset_exist_error_msg(table: Table) -> str:
+    return _("Dataset %(table)s already exists", table=table)
 
 
 class DatabaseNotFoundValidationError(ValidationError):
@@ -55,10 +56,8 @@ class DatasetExistsValidationError(ValidationError):
     Marshmallow validation error for dataset already exists
     """
 
-    def __init__(self, table_name: str) -> None:
-        super().__init__(
-            [get_dataset_exist_error_msg(table_name)], field_name="table_name"
-        )
+    def __init__(self, table: Table) -> None:
+        super().__init__([get_dataset_exist_error_msg(table)], field_name="table")
 
 
 class DatasetColumnNotFoundValidationError(ValidationError):
@@ -124,18 +123,18 @@ class TableNotFoundValidationError(ValidationError):
     Marshmallow validation error when a table does not exist on the database
     """
 
-    def __init__(self, table_name: str) -> None:
+    def __init__(self, table: Table) -> None:
         super().__init__(
             [
                 _(
-                    "Table [%(table_name)s] could not be found, "
+                    "Table [%(table)s] could not be found, "
                     "please double check your "
                     "database connection, schema, and "
                     "table name",
-                    table_name=table_name,
+                    table=table,
                 )
             ],
-            field_name="table_name",
+            field_name="table",
         )
 
 
diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py
index 282c778eb432e..2b521452436eb 100644
--- a/superset/commands/dataset/update.py
+++ b/superset/commands/dataset/update.py
@@ -86,15 +86,21 @@ def validate(self) -> None:
         except SupersetSecurityException as ex:
             raise DatasetForbiddenError() from ex
 
-        database_id = self._properties.get("database", None)
-        table_name = self._properties.get("table_name", None)
+        database_id = self._properties.get("database")
+
+        table = Table(
+            self._properties.get("table_name"),  # type: ignore
+            self._properties.get("schema"),
+            self._properties.get("catalog"),
+        )
+
         # Validate uniqueness
         if not DatasetDAO.validate_update_uniqueness(
             self._model.database_id,
-            Table(table_name, self._model.schema, self._model.catalog),
+            table,
             self._model_id,
         ):
-            exceptions.append(DatasetExistsValidationError(table_name))
+            exceptions.append(DatasetExistsValidationError(table))
         # Validate/Populate database not allowed to change
         if database_id and database_id != self._model:
             exceptions.append(DatabaseChangeValidationError())
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index e3258651bb871..0a11732ad160a 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -42,6 +42,7 @@
 from superset.extensions import db, security_manager
 from superset.models.core import Database
 from superset.models.slice import Slice
+from superset.sql_parse import Table
 from superset.utils.core import backend, get_example_default_schema
 from superset.utils.database import get_example_database, get_main_database
 from superset.utils.dict_import_export import export_to_dict
@@ -697,7 +698,11 @@ def test_create_dataset_validate_uniqueness(self):
         assert rv.status_code == 422
         data = json.loads(rv.data.decode("utf-8"))
         assert data == {
-            "message": {"table_name": ["Dataset energy_usage already exists"]}
+            "message": {
+                "table": [
+                    f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
+                ]
+            }
         }
 
     @pytest.mark.usefixtures("load_energy_table_with_slice")
@@ -719,7 +724,11 @@ def test_create_dataset_with_sql_validate_uniqueness(self):
         assert rv.status_code == 422
         data = json.loads(rv.data.decode("utf-8"))
         assert data == {
-            "message": {"table_name": ["Dataset energy_usage already exists"]}
+            "message": {
+                "table": [
+                    f"Dataset {Table(energy_usage_ds.table_name, schema)} already exists"
+                ]
+            }
         }
 
     @pytest.mark.usefixtures("load_energy_table_with_slice")
@@ -1465,9 +1474,7 @@ def test_update_dataset_item_uniqueness(self):
         rv = self.put_assert_metric(uri, table_data, "put")
         data = json.loads(rv.data.decode("utf-8"))
         assert rv.status_code == 422
-        expected_response = {
-            "message": {"table_name": ["Dataset ab_user already exists"]}
-        }
+        expected_response = {"message": {"table": ["Dataset ab_user already exists"]}}
         assert data == expected_response
         db.session.delete(dataset)
         db.session.delete(ab_user)
diff --git a/tests/unit_tests/commands/dataset/test_update.py b/tests/unit_tests/commands/dataset/test_update.py
new file mode 100644
index 0000000000000..59d43de6fdda1
--- /dev/null
+++ b/tests/unit_tests/commands/dataset/test_update.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest.mock import MagicMock
+
+import pytest
+from pytest_mock import MockFixture
+
+from superset import db
+from superset.commands.dataset.exceptions import DatasetInvalidError
+from superset.commands.dataset.update import UpdateDatasetCommand
+from superset.connectors.sqla.models import SqlaTable
+from superset.models.core import Database
+
+
+@pytest.mark.usefixture("session")
+def test_update_uniqueness_error(mocker: MockFixture) -> None:
+    SqlaTable.metadata.create_all(db.session.get_bind())
+    database = Database(database_name="my_db", sqlalchemy_uri="sqlite://")
+    bar = SqlaTable(table_name="bar", schema="foo", database=database)
+    baz = SqlaTable(table_name="baz", schema="qux", database=database)
+    db.session.add_all([database, bar, baz])
+    db.session.commit()
+
+    mock_g = mocker.patch("superset.security.manager.g")
+    mock_g.user = MagicMock()
+
+    mocker.patch(
+        "superset.views.base.security_manager.can_access_all_datasources",
+        return_value=True,
+    )
+
+    mocker.patch(
+        "superset.commands.dataset.update.security_manager.raise_for_ownership",
+        return_value=None,
+    )
+
+    mocker.patch.object(UpdateDatasetCommand, "compute_owners", return_value=[])
+
+    with pytest.raises(DatasetInvalidError):
+        UpdateDatasetCommand(
+            bar.id,
+            {
+                "table_name": "baz",
+                "schema": "qux",
+            },
+        ).run()
diff --git a/tests/unit_tests/dao/dataset_test.py b/tests/unit_tests/dao/dataset_test.py
index a2e2b2b39fba6..473d1e27b7660 100644
--- a/tests/unit_tests/dao/dataset_test.py
+++ b/tests/unit_tests/dao/dataset_test.py
@@ -51,7 +51,6 @@ def test_validate_update_uniqueness(session: Session) -> None:
     db.session.add_all([database, dataset1, dataset2])
     db.session.flush()
 
-    # same table name, different schema
     assert (
        DatasetDAO.validate_update_uniqueness(
             database_id=database.id,
@@ -61,7 +60,6 @@ def test_validate_update_uniqueness(session: Session) -> None:
         is True
     )
 
-    # duplicate schema and table name
    assert (
         DatasetDAO.validate_update_uniqueness(
             database_id=database.id,
@@ -71,7 +69,6 @@ def test_validate_update_uniqueness(session: Session) -> None:
         is False
     )
 
-    # no schema
     assert (
         DatasetDAO.validate_update_uniqueness(
             database_id=database.id,
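
The behavioural change carried by the superset/commands/dataset/update.py hunk: the uniqueness check now builds the Table key from the incoming payload's table_name, schema, and catalog instead of reusing the existing model's schema and catalog, so an update that moves a dataset onto a (schema, table) pair already owned by another dataset is rejected, which is the scenario the new unit test exercises. Below is a minimal, self-contained sketch of that before/after difference; the registry dict, is_unique helper, and sample ids are hypothetical illustrations, not Superset APIs.

# Hypothetical sketch (not Superset code): why the old uniqueness check missed a clash.
# A dataset is keyed here by (database_id, catalog, schema, table_name).
from typing import Optional

Key = tuple[int, Optional[str], Optional[str], str]

# Two existing datasets in database 1: id 10 is foo.bar, id 11 is qux.baz.
registry: dict[Key, int] = {
    (1, None, "foo", "bar"): 10,
    (1, None, "qux", "baz"): 11,
}


def is_unique(key: Key, dataset_id: int) -> bool:
    """True if no other dataset already occupies this key."""
    existing = registry.get(key)
    return existing is None or existing == dataset_id


# Update request: rename dataset 10 (foo.bar) to qux.baz, which dataset 11 owns.
payload = {"table_name": "baz", "schema": "qux", "catalog": None}
current = {"id": 10, "database_id": 1, "catalog": None, "schema": "foo"}

# Old check: schema/catalog came from the current model, so ("foo", "baz") was
# tested instead of the real target ("qux", "baz") and the clash went unnoticed.
old_key: Key = (1, current["catalog"], current["schema"], payload["table_name"])

# Fixed check: schema/catalog come from the incoming payload, so the clash is caught.
new_key: Key = (1, payload["catalog"], payload["schema"], payload["table_name"])

assert is_unique(old_key, current["id"]) is True   # old logic wrongly passed
assert is_unique(new_key, current["id"]) is False  # fixed logic rejects the update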