Skip to content

Commit

Permalink
Merge pull request #2726 from fishtown-analytics/fix/require-version-…
Browse files Browse the repository at this point in the history
…validation

Validate require-dbt-version before validating dbt_project.ymls chema
  • Loading branch information
beckjake authored Aug 27, 2020
2 parents fe46138 + 0130398 commit 75faceb
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 68 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
- Added 3 more adapter methods that the new dbt-adapter-test suite can use for testing. ([#2492](https://github.com/fishtown-analytics/dbt/issues/2492), [#2721](https://github.com/fishtown-analytics/dbt/pull/2721))


### Fixes
- dbt now validates the require-dbt-version field before it validates the dbt_project.yml schema ([#2638](https://github.com/fishtown-analytics/dbt/issues/2638), [#2726](https://github.com/fishtown-analytics/dbt/pull/2726))


## dbt 0.18.0rc1 (August 19, 2020)


Expand Down
113 changes: 78 additions & 35 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ class PartialProject:
metadata=dict(description='The root directory of the project'),
)
project_dict: Dict[str, Any]
verify_version: bool = field(
metadata=dict(description=(
'If True, verify the dbt version matches the required version'
))
)

def render(self, renderer):
packages_dict = package_data_from_root(self.project_root)
Expand All @@ -225,6 +230,7 @@ def render(self, renderer):
packages_dict,
selectors_dict,
renderer,
verify_version=self.verify_version,
)

def render_profile_name(self, renderer) -> Optional[str]:
Expand Down Expand Up @@ -292,6 +298,32 @@ def to_dict(self):
return self.vars


def validate_version(
required: List[VersionSpecifier],
project_name: str,
) -> None:
"""Ensure this package works with the installed version of dbt."""
installed = get_installed_version()
if not versions_compatible(*required):
msg = IMPOSSIBLE_VERSION_ERROR.format(
package=project_name,
version_spec=[
x.to_version_string() for x in required
]
)
raise DbtProjectError(msg)

if not versions_compatible(installed, *required):
msg = INVALID_VERSION_ERROR.format(
package=project_name,
installed=installed.to_version_string(),
version_spec=[
x.to_version_string() for x in required
]
)
raise DbtProjectError(msg)


@dataclass
class Project:
project_name: str
Expand Down Expand Up @@ -363,6 +395,7 @@ def from_project_config(
project_dict: Dict[str, Any],
packages_dict: Optional[Dict[str, Any]] = None,
selectors_dict: Optional[Dict[str, Any]] = None,
required_dbt_version: Optional[List[VersionSpecifier]] = None,
) -> 'Project':
"""Create a project from its project and package configuration, as read
by yaml.safe_load().
Expand All @@ -374,6 +407,11 @@ def from_project_config(
the packages file exists and is invalid.
:returns: The project, with defaults populated.
"""
if required_dbt_version is None:
dbt_version = cls._get_required_version(project_dict)
else:
dbt_version = required_dbt_version

try:
project_dict = cls._preprocess(project_dict)
except RecursionException:
Expand Down Expand Up @@ -460,18 +498,8 @@ def from_project_config(
on_run_start: List[str] = value_or(cfg.on_run_start, [])
on_run_end: List[str] = value_or(cfg.on_run_end, [])

# weird type handling: no value_or use
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
if cfg.require_dbt_version is not None:
dbt_raw_version = cfg.require_dbt_version

query_comment = _query_comment_from_cfg(cfg.query_comment)

try:
dbt_version = _parse_versions(dbt_raw_version)
except SemverException as e:
raise DbtProjectError(str(e)) from e

try:
packages = package_config_from_data(packages_dict)
except ValidationError as e:
Expand Down Expand Up @@ -583,6 +611,30 @@ def validate(self):
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e

@classmethod
def _get_required_version(
cls, rendered_project: Dict[str, Any], verify_version: bool = False
) -> List[VersionSpecifier]:
dbt_raw_version: Union[List[str], str] = '>=0.0.0'
required = rendered_project.get('require-dbt-version')
if required is not None:
dbt_raw_version = required

try:
dbt_version = _parse_versions(dbt_raw_version)
except SemverException as e:
raise DbtProjectError(str(e)) from e

if verify_version:
# no name is also an error that we want to raise
if 'name' not in rendered_project:
raise DbtProjectError(
'Required "name" field not present in project',
)
validate_version(dbt_version, rendered_project['name'])

return dbt_version

@classmethod
def render_from_dict(
cls,
Expand All @@ -591,18 +643,26 @@ def render_from_dict(
packages_dict: Dict[str, Any],
selectors_dict: Dict[str, Any],
renderer: DbtProjectYamlRenderer,
*,
verify_version: bool = False
) -> 'Project':
rendered_project = renderer.render_data(project_dict)
rendered_project['project-root'] = project_root
package_renderer = renderer.get_package_renderer()
rendered_packages = package_renderer.render_data(packages_dict)
selectors_renderer = renderer.get_selector_renderer()
rendered_selectors = selectors_renderer.render_data(selectors_dict)

try:
dbt_version = cls._get_required_version(
rendered_project, verify_version=verify_version
)

return cls.from_project_config(
rendered_project,
rendered_packages,
rendered_selectors,
dbt_version,
)
except DbtProjectError as exc:
if exc.path is None:
Expand All @@ -611,7 +671,7 @@ def render_from_dict(

@classmethod
def partial_load(
cls, project_root: str
cls, project_root: str, *, verify_version: bool = False
) -> PartialProject:
project_root = os.path.normpath(project_root)
project_dict = _raw_project_from(project_root)
Expand All @@ -626,41 +686,24 @@ def partial_load(
project_name=project_name,
project_root=project_root,
project_dict=project_dict,
verify_version=verify_version,
)

@classmethod
def from_project_root(
cls, project_root: str, renderer: DbtProjectYamlRenderer
cls,
project_root: str,
renderer: DbtProjectYamlRenderer,
*,
verify_version: bool = False,
) -> 'Project':
partial = cls.partial_load(project_root)
partial = cls.partial_load(project_root, verify_version=verify_version)
renderer.version = partial.config_version
return partial.render(renderer)

def hashed_name(self):
return hashlib.md5(self.project_name.encode('utf-8')).hexdigest()

def validate_version(self):
"""Ensure this package works with the installed version of dbt."""
installed = get_installed_version()
if not versions_compatible(*self.dbt_version):
msg = IMPOSSIBLE_VERSION_ERROR.format(
package=self.project_name,
version_spec=[
x.to_version_string() for x in self.dbt_version
]
)
raise DbtProjectError(msg)

if not versions_compatible(installed, *self.dbt_version):
msg = INVALID_VERSION_ERROR.format(
package=self.project_name,
installed=installed.to_version_string(),
version_spec=[
x.to_version_string() for x in self.dbt_version
]
)
raise DbtProjectError(msg)

def as_v1(self, all_projects: Iterable[str]):
if self.config_version == 1:
return self
Expand Down
15 changes: 10 additions & 5 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ def new_project(self, project_root: str) -> 'RuntimeConfig':
# load the new project and its packages. Don't pass cli variables.
renderer = DbtProjectYamlRenderer(generate_target_context(profile, {}))

project = Project.from_project_root(project_root, renderer)
project = Project.from_project_root(
project_root,
renderer,
verify_version=getattr(self.args, 'version_check', False),
)

cfg = self.from_parts(
project=project,
Expand Down Expand Up @@ -173,9 +177,6 @@ def validate(self):
except ValidationError as e:
raise DbtProjectError(validator_error_message(e)) from e

if getattr(self.args, 'version_check', False):
self.validate_version()

@classmethod
def _get_rendered_profile(
cls,
Expand All @@ -193,7 +194,11 @@ def collect_parts(
) -> Tuple[Project, Profile]:
# profile_name from the project
project_root = args.project_dir if args.project_dir else os.getcwd()
partial = Project.partial_load(project_root)
version_check = getattr(args, 'version_check', False)
partial = Project.partial_load(
project_root,
verify_version=version_check
)

# build the profile using the base renderer and the one fact we know
cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}'))
Expand Down
23 changes: 14 additions & 9 deletions core/dbt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def _build_debug_subparser(subparsers, base_subparser):
If specified, DBT will show path information for this project
'''
)
_add_version_check(sub)
sub.set_defaults(cls=debug_task.DebugTask, which='debug', rpc_method=None)
return sub

Expand Down Expand Up @@ -597,6 +598,18 @@ def _add_table_mutability_arguments(*subparsers):
)


def _add_version_check(sub):
sub.add_argument(
'--no-version-check',
dest='version_check',
action='store_false',
help='''
If set, skip ensuring dbt's version matches the one specified in
the dbt_project.yml file ('require-dbt-version')
'''
)


def _add_common_arguments(*subparsers):
for sub in subparsers:
sub.add_argument(
Expand All @@ -608,15 +621,7 @@ def _add_common_arguments(*subparsers):
settings in profiles.yml.
'''
)
sub.add_argument(
'--no-version-check',
dest='version_check',
action='store_false',
help='''
If set, skip ensuring dbt's version matches the one specified in
the dbt_project.yml file ('require-dbt-version')
'''
)
_add_version_check(sub)


def _build_seed_subparser(subparsers, base_subparser):
Expand Down
7 changes: 5 additions & 2 deletions core/dbt/task/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def _load_project(self):

try:
self.project = Project.from_project_root(
self.project_dir, renderer
self.project_dir,
renderer,
verify_version=getattr(self.args, 'version_check', False),
)
except dbt.exceptions.DbtConfigError as exc:
self.project_fail_details = str(exc)
Expand Down Expand Up @@ -181,7 +183,8 @@ def _choose_profile_names(self) -> Optional[List[str]]:
if os.path.exists(self.project_path):
try:
partial = Project.partial_load(
os.path.dirname(self.project_path)
os.path.dirname(self.project_path),
verify_version=getattr(self.args, 'version_check', False),
)
renderer = DbtProjectYamlRenderer(
generate_base_context(self.cli_vars)
Expand Down
Loading

0 comments on commit 75faceb

Please sign in to comment.