diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ea8db2d870..3dba5abfe83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ - Added 3 more adapter methods that the new dbt-adapter-test suite can use for testing. ([#2492](https://github.com/fishtown-analytics/dbt/issues/2492), [#2721](https://github.com/fishtown-analytics/dbt/pull/2721)) +### Fixes +- dbt now validates the require-dbt-version field before it validates the dbt_project.yml schema ([#2638](https://github.com/fishtown-analytics/dbt/issues/2638), [#2726](https://github.com/fishtown-analytics/dbt/pull/2726)) + + ## dbt 0.18.0rc1 (August 19, 2020) diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index f16941533bc..b8530cbf6b6 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -215,6 +215,11 @@ class PartialProject: metadata=dict(description='The root directory of the project'), ) project_dict: Dict[str, Any] + verify_version: bool = field( + metadata=dict(description=( + 'If True, verify the dbt version matches the required version' + )) + ) def render(self, renderer): packages_dict = package_data_from_root(self.project_root) @@ -225,6 +230,7 @@ def render(self, renderer): packages_dict, selectors_dict, renderer, + verify_version=self.verify_version, ) def render_profile_name(self, renderer) -> Optional[str]: @@ -292,6 +298,32 @@ def to_dict(self): return self.vars +def validate_version( + required: List[VersionSpecifier], + project_name: str, +) -> None: + """Ensure this package works with the installed version of dbt.""" + installed = get_installed_version() + if not versions_compatible(*required): + msg = IMPOSSIBLE_VERSION_ERROR.format( + package=project_name, + version_spec=[ + x.to_version_string() for x in required + ] + ) + raise DbtProjectError(msg) + + if not versions_compatible(installed, *required): + msg = INVALID_VERSION_ERROR.format( + package=project_name, + installed=installed.to_version_string(), + version_spec=[ + x.to_version_string() for x in required + ] + ) + raise DbtProjectError(msg) + + @dataclass class Project: project_name: str @@ -363,6 +395,7 @@ def from_project_config( project_dict: Dict[str, Any], packages_dict: Optional[Dict[str, Any]] = None, selectors_dict: Optional[Dict[str, Any]] = None, + required_dbt_version: Optional[List[VersionSpecifier]] = None, ) -> 'Project': """Create a project from its project and package configuration, as read by yaml.safe_load(). @@ -374,6 +407,11 @@ def from_project_config( the packages file exists and is invalid. :returns: The project, with defaults populated. """ + if required_dbt_version is None: + dbt_version = cls._get_required_version(project_dict) + else: + dbt_version = required_dbt_version + try: project_dict = cls._preprocess(project_dict) except RecursionException: @@ -460,18 +498,8 @@ def from_project_config( on_run_start: List[str] = value_or(cfg.on_run_start, []) on_run_end: List[str] = value_or(cfg.on_run_end, []) - # weird type handling: no value_or use - dbt_raw_version: Union[List[str], str] = '>=0.0.0' - if cfg.require_dbt_version is not None: - dbt_raw_version = cfg.require_dbt_version - query_comment = _query_comment_from_cfg(cfg.query_comment) - try: - dbt_version = _parse_versions(dbt_raw_version) - except SemverException as e: - raise DbtProjectError(str(e)) from e - try: packages = package_config_from_data(packages_dict) except ValidationError as e: @@ -583,6 +611,30 @@ def validate(self): except ValidationError as e: raise DbtProjectError(validator_error_message(e)) from e + @classmethod + def _get_required_version( + cls, rendered_project: Dict[str, Any], verify_version: bool = False + ) -> List[VersionSpecifier]: + dbt_raw_version: Union[List[str], str] = '>=0.0.0' + required = rendered_project.get('require-dbt-version') + if required is not None: + dbt_raw_version = required + + try: + dbt_version = _parse_versions(dbt_raw_version) + except SemverException as e: + raise DbtProjectError(str(e)) from e + + if verify_version: + # no name is also an error that we want to raise + if 'name' not in rendered_project: + raise DbtProjectError( + 'Required "name" field not present in project', + ) + validate_version(dbt_version, rendered_project['name']) + + return dbt_version + @classmethod def render_from_dict( cls, @@ -591,6 +643,8 @@ def render_from_dict( packages_dict: Dict[str, Any], selectors_dict: Dict[str, Any], renderer: DbtProjectYamlRenderer, + *, + verify_version: bool = False ) -> 'Project': rendered_project = renderer.render_data(project_dict) rendered_project['project-root'] = project_root @@ -598,11 +652,17 @@ def render_from_dict( rendered_packages = package_renderer.render_data(packages_dict) selectors_renderer = renderer.get_selector_renderer() rendered_selectors = selectors_renderer.render_data(selectors_dict) + try: + dbt_version = cls._get_required_version( + rendered_project, verify_version=verify_version + ) + return cls.from_project_config( rendered_project, rendered_packages, rendered_selectors, + dbt_version, ) except DbtProjectError as exc: if exc.path is None: @@ -611,7 +671,7 @@ def render_from_dict( @classmethod def partial_load( - cls, project_root: str + cls, project_root: str, *, verify_version: bool = False ) -> PartialProject: project_root = os.path.normpath(project_root) project_dict = _raw_project_from(project_root) @@ -626,41 +686,24 @@ def partial_load( project_name=project_name, project_root=project_root, project_dict=project_dict, + verify_version=verify_version, ) @classmethod def from_project_root( - cls, project_root: str, renderer: DbtProjectYamlRenderer + cls, + project_root: str, + renderer: DbtProjectYamlRenderer, + *, + verify_version: bool = False, ) -> 'Project': - partial = cls.partial_load(project_root) + partial = cls.partial_load(project_root, verify_version=verify_version) renderer.version = partial.config_version return partial.render(renderer) def hashed_name(self): return hashlib.md5(self.project_name.encode('utf-8')).hexdigest() - def validate_version(self): - """Ensure this package works with the installed version of dbt.""" - installed = get_installed_version() - if not versions_compatible(*self.dbt_version): - msg = IMPOSSIBLE_VERSION_ERROR.format( - package=self.project_name, - version_spec=[ - x.to_version_string() for x in self.dbt_version - ] - ) - raise DbtProjectError(msg) - - if not versions_compatible(installed, *self.dbt_version): - msg = INVALID_VERSION_ERROR.format( - package=self.project_name, - installed=installed.to_version_string(), - version_spec=[ - x.to_version_string() for x in self.dbt_version - ] - ) - raise DbtProjectError(msg) - def as_v1(self, all_projects: Iterable[str]): if self.config_version == 1: return self diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index 184e2732762..402bd10b4bb 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -138,7 +138,11 @@ def new_project(self, project_root: str) -> 'RuntimeConfig': # load the new project and its packages. Don't pass cli variables. renderer = DbtProjectYamlRenderer(generate_target_context(profile, {})) - project = Project.from_project_root(project_root, renderer) + project = Project.from_project_root( + project_root, + renderer, + verify_version=getattr(self.args, 'version_check', False), + ) cfg = self.from_parts( project=project, @@ -173,9 +177,6 @@ def validate(self): except ValidationError as e: raise DbtProjectError(validator_error_message(e)) from e - if getattr(self.args, 'version_check', False): - self.validate_version() - @classmethod def _get_rendered_profile( cls, @@ -193,7 +194,11 @@ def collect_parts( ) -> Tuple[Project, Profile]: # profile_name from the project project_root = args.project_dir if args.project_dir else os.getcwd() - partial = Project.partial_load(project_root) + version_check = getattr(args, 'version_check', False) + partial = Project.partial_load( + project_root, + verify_version=version_check + ) # build the profile using the base renderer and the one fact we know cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}')) diff --git a/core/dbt/main.py b/core/dbt/main.py index a372fe3fb1e..932fae13c33 100644 --- a/core/dbt/main.py +++ b/core/dbt/main.py @@ -406,6 +406,7 @@ def _build_debug_subparser(subparsers, base_subparser): If specified, DBT will show path information for this project ''' ) + _add_version_check(sub) sub.set_defaults(cls=debug_task.DebugTask, which='debug', rpc_method=None) return sub @@ -597,6 +598,18 @@ def _add_table_mutability_arguments(*subparsers): ) +def _add_version_check(sub): + sub.add_argument( + '--no-version-check', + dest='version_check', + action='store_false', + help=''' + If set, skip ensuring dbt's version matches the one specified in + the dbt_project.yml file ('require-dbt-version') + ''' + ) + + def _add_common_arguments(*subparsers): for sub in subparsers: sub.add_argument( @@ -608,15 +621,7 @@ def _add_common_arguments(*subparsers): settings in profiles.yml. ''' ) - sub.add_argument( - '--no-version-check', - dest='version_check', - action='store_false', - help=''' - If set, skip ensuring dbt's version matches the one specified in - the dbt_project.yml file ('require-dbt-version') - ''' - ) + _add_version_check(sub) def _build_seed_subparser(subparsers, base_subparser): diff --git a/core/dbt/task/debug.py b/core/dbt/task/debug.py index 84fdd420713..26ad51d00c5 100644 --- a/core/dbt/task/debug.py +++ b/core/dbt/task/debug.py @@ -143,7 +143,9 @@ def _load_project(self): try: self.project = Project.from_project_root( - self.project_dir, renderer + self.project_dir, + renderer, + verify_version=getattr(self.args, 'version_check', False), ) except dbt.exceptions.DbtConfigError as exc: self.project_fail_details = str(exc) @@ -181,7 +183,8 @@ def _choose_profile_names(self) -> Optional[List[str]]: if os.path.exists(self.project_path): try: partial = Project.partial_load( - os.path.dirname(self.project_path) + os.path.dirname(self.project_path), + verify_version=getattr(self.args, 'version_check', False), ) renderer = DbtProjectYamlRenderer( generate_base_context(self.cli_vars) diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 77e5b72567a..f76e1305329 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -35,6 +35,10 @@ def temp_cd(path): finally: os.chdir(current_path) +@contextmanager +def raises_nothing(): + yield + def empty_profile_renderer(): return dbt.config.renderer.ProfileRenderer(generate_base_context({})) @@ -179,6 +183,12 @@ def setUp(self): 'env_value_profile': 'default', } + def assertRaisesOrReturns(self, exc): + if exc is None: + return raises_nothing() + else: + return self.assertRaises(exc) + class BaseFileTest(BaseConfigTest): def setUp(self): @@ -460,7 +470,7 @@ def test_target_override(self): self.assertEqual(profile.credentials.password, 'db_pass') self.assertEqual(profile.credentials.schema, 'redshift-schema') self.assertEqual(profile.credentials.database, 'redshift-db-name') - self.assertEqual(profile, from_raw) + self.assertEqual(profile, from_raw) def test_env_vars(self): self.args.target = 'with-vars' @@ -947,8 +957,12 @@ def setUp(self): self.default_project_data['project-root'] = self.project_dir def get_project(self): + version = dbt.config.Project._get_required_version( + self.default_project_data, + verify_version=bool(self.args.version_check) + ) return dbt.config.Project.from_project_config( - self.default_project_data, None + self.default_project_data, None, required_dbt_version=version ) def get_profile(self): @@ -958,14 +972,16 @@ def get_profile(self): ) def from_parts(self, exc=None): - project = self.get_project() - profile = self.get_profile() - if exc is None: - return dbt.config.RuntimeConfig.from_parts(project, profile, self.args) + with self.assertRaisesOrReturns(exc) as err: + project = self.get_project() + profile = self.get_profile() + + result = dbt.config.RuntimeConfig.from_parts(project, profile, self.args) - with self.assertRaises(exc) as err: - dbt.config.RuntimeConfig.from_parts(project, profile, self.args) - return err + if exc is None: + return result + else: + return err def test_from_parts(self): project = self.get_project() @@ -1029,6 +1045,12 @@ def test_unsupported_version_range(self): raised = self.from_parts(dbt.exceptions.DbtProjectError) self.assertIn('This version of dbt is not supported', str(raised.exception)) + def test_unsupported_version_range_bad_config(self): + self.default_project_data['require-dbt-version'] = ['>0.0.0', '<=0.0.1'] + self.default_project_data['some-extra-field-not-allowed'] = True + raised = self.from_parts(dbt.exceptions.DbtProjectError) + self.assertIn('This version of dbt is not supported', str(raised.exception)) + def test_unsupported_version_range_no_check(self): self.default_project_data['require-dbt-version'] = ['>0.0.0', '<=0.0.1'] self.args.version_check = False @@ -1040,6 +1062,11 @@ def test_impossible_version_range(self): raised = self.from_parts(dbt.exceptions.DbtProjectError) self.assertIn('The package version requirement can never be satisfied', str(raised.exception)) + def test_unsupported_version_extra_config(self): + self.default_project_data['some-extra-field-not-allowed'] = True + raised = self.from_parts(dbt.exceptions.DbtProjectError) + self.assertIn('Additional properties are not allowed', str(raised.exception)) + def test_archive_not_allowed(self): self.default_project_data['archive'] = [{ "source_schema": 'a', @@ -1124,8 +1151,11 @@ def setUp(self): ))} def get_project(self): + version = dbt.config.Project._get_required_version( + self.default_project_data, verify_version=True + ) return dbt.config.Project.from_project_config( - self.default_project_data, None + self.default_project_data, None, required_dbt_version=version ) def get_profile(self): @@ -1135,15 +1165,16 @@ def get_profile(self): ) def from_parts(self, exc=None): - project = self.get_project() - profile = self.get_profile() - if exc is None: - return dbt.config.RuntimeConfig.from_parts(project, profile, self.args) + with self.assertRaisesOrReturns(exc) as err: + project = self.get_project() + profile = self.get_profile() - with self.assertRaises(exc) as err: - dbt.config.RuntimeConfig.from_parts(project, profile, self.args) - return err + result = dbt.config.RuntimeConfig.from_parts(project, profile, self.args) + if exc is None: + return result + else: + return err def test__get_unused_resource_config_paths(self): project = self.from_parts()