diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py index 89891f48e597a..c620ec9f2ac85 100644 --- a/superset/commands/importers/v1/__init__.py +++ b/superset/commands/importers/v1/__init__.py @@ -24,9 +24,11 @@ from superset.commands.base import BaseCommand from superset.commands.exceptions import CommandException, CommandInvalidError from superset.commands.importers.v1.utils import ( + load_configs, load_metadata, load_yaml, METADATA_FILE_NAME, + validate_metadata_type, ) from superset.dao.base import BaseDAO from superset.models.core import Database @@ -78,9 +80,13 @@ def validate(self) -> None: except ValidationError as exc: exceptions.append(exc) metadata = None + if self.dao.model_cls: + validate_metadata_type(metadata, self.dao.model_cls.__name__, exceptions) - self._validate_metadata_type(metadata, exceptions) - self._load__configs(exceptions) + # load the configs and make sure we have confirmation to overwrite existing models + self._configs = load_configs( + self.contents, self.schemas, self.passwords, exceptions + ) self._prevent_overwrite_existing_model(exceptions) if exceptions: @@ -88,49 +94,6 @@ def validate(self) -> None: exception.add_list(exceptions) raise exception - def _validate_metadata_type( - self, metadata: Optional[Dict[str, str]], exceptions: List[ValidationError] - ) -> None: - """Validate that the type declared in METADATA_FILE_NAME is correct""" - if metadata and "type" in metadata: - type_validator = validate.Equal(self.dao.model_cls.__name__) # type: ignore - try: - type_validator(metadata["type"]) - except ValidationError as exc: - exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}} - exceptions.append(exc) - - def _load__configs(self, exceptions: List[ValidationError]) -> None: - # load existing databases so we can apply the password validation - db_passwords: Dict[str, str] = { - str(uuid): password - for uuid, password in db.session.query( - Database.uuid, Database.password - ).all() - } - for file_name, content in self.contents.items(): - # skip directories - if not content: - continue - - prefix = file_name.split("/")[0] - schema = self.schemas.get(f"{prefix}/") - if schema: - try: - config = load_yaml(file_name, content) - - # populate passwords from the request or from existing DBs - if file_name in self.passwords: - config["password"] = self.passwords[file_name] - elif prefix == "databases" and config["uuid"] in db_passwords: - config["password"] = db_passwords[config["uuid"]] - - schema.load(config) - self._configs[file_name] = config - except ValidationError as exc: - exc.messages = {file_name: exc.messages} - exceptions.append(exc) - def _prevent_overwrite_existing_model( # pylint: disable=invalid-name self, exceptions: List[ValidationError] ) -> None: diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py index 15bec8278ca75..de86e3f3cc6ab 100644 --- a/superset/commands/importers/v1/utils.py +++ b/superset/commands/importers/v1/utils.py @@ -15,14 +15,16 @@ import logging from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, List, Optional from zipfile import ZipFile import yaml from marshmallow import fields, Schema, validate from marshmallow.exceptions import ValidationError +from superset import db from superset.commands.importers.exceptions import IncorrectVersionError +from superset.models.core import Database METADATA_FILE_NAME = "metadata.yaml" IMPORT_VERSION = "1.0.0" @@ -76,6 +78,58 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]: return metadata +def validate_metadata_type( + metadata: Optional[Dict[str, str]], type_: str, exceptions: List[ValidationError], +) -> None: + """Validate that the type declared in METADATA_FILE_NAME is correct""" + if metadata and "type" in metadata: + type_validator = validate.Equal(type_) + try: + type_validator(metadata["type"]) + except ValidationError as exc: + exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}} + exceptions.append(exc) + + +def load_configs( + contents: Dict[str, str], + schemas: Dict[str, Schema], + passwords: Dict[str, str], + exceptions: List[ValidationError], +) -> Dict[str, Any]: + configs: Dict[str, Any] = {} + + # load existing databases so we can apply the password validation + db_passwords: Dict[str, str] = { + str(uuid): password + for uuid, password in db.session.query(Database.uuid, Database.password).all() + } + for file_name, content in contents.items(): + # skip directories + if not content: + continue + + prefix = file_name.split("/")[0] + schema = schemas.get(f"{prefix}/") + if schema: + try: + config = load_yaml(file_name, content) + + # populate passwords from the request or from existing DBs + if file_name in passwords: + config["password"] = passwords[file_name] + elif prefix == "databases" and config["uuid"] in db_passwords: + config["password"] = db_passwords[config["uuid"]] + + schema.load(config) + configs[file_name] = config + except ValidationError as exc: + exc.messages = {file_name: exc.messages} + exceptions.append(exc) + + return configs + + def is_valid_config(file_name: str) -> bool: path = Path(file_name)