Skip to content

Commit

Permalink
backport #10644
Browse files Browse the repository at this point in the history
  • Loading branch information
mikealfare committed Sep 17, 2024
1 parent 88041d6 commit 8ccc7d0
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240829-135320.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add support for behavior flags
time: 2024-08-29T13:53:20.16122-04:00
custom:
Author: mikealfare
Issue: "10618"
9 changes: 4 additions & 5 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
rendered.selectors_dict["selectors"]
)
dbt_cloud = cfg.dbt_cloud
flags: Dict[str, Any] = cfg.flags

project = Project(
project_name=name,
Expand Down Expand Up @@ -535,6 +536,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
project_env_vars=project_env_vars,
restrict_access=cfg.restrict_access,
dbt_cloud=dbt_cloud,
flags=flags,
)
# sanity check - this means an internal issue
project.validate()
Expand Down Expand Up @@ -579,11 +581,6 @@ def from_project_root(
) = package_and_project_data_from_root(project_root)
selectors_dict = selector_data_from_root(project_root)

if "flags" in project_dict:
# We don't want to include "flags" in the Project,
# it goes in ProjectFlags
project_dict.pop("flags")

return cls.from_dicts(
project_root=project_root,
project_dict=project_dict,
Expand Down Expand Up @@ -656,6 +653,7 @@ class Project:
project_env_vars: Dict[str, Any]
restrict_access: bool
dbt_cloud: Dict[str, Any]
flags: Dict[str, Any]

@property
def all_source_paths(self) -> List[str]:
Expand Down Expand Up @@ -735,6 +733,7 @@ def to_project_config(self, with_packages=False):
"require-dbt-version": [v.to_version_string() for v in self.dbt_version],
"restrict-access": self.restrict_access,
"dbt-cloud": self.dbt_cloud,
"flags": self.flags,
}
)
if self.query_comment:
Expand Down
1 change: 1 addition & 0 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def from_parts(
log_cache_events=log_cache_events,
dependencies=dependencies,
dbt_cloud=project.dbt_cloud,
flags=project.flags,
)

# Called by 'load_projects' in this class
Expand Down
1 change: 1 addition & 0 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ class Project(dbtClassMixin):
query_comment: Optional[Union[QueryComment, NoValue, str]] = field(default_factory=NoValue)
restrict_access: bool = False
dbt_cloud: Optional[Dict[str, Any]] = None
flags: Dict[str, Any] = field(default_factory=dict)

class Config(dbtMashConfig):
# These tell mashumaro to use aliases for jsonschema and for "from_dict"
Expand Down
55 changes: 53 additions & 2 deletions tests/unit/config/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import unittest
import pytest
from typing import Any, Dict

from unittest import mock

Expand All @@ -11,7 +12,7 @@
import dbt.exceptions
from dbt.adapters.factory import load_plugin
from dbt.adapters.contracts.connection import QueryComment, DEFAULT_QUERY_COMMENT
from dbt.config.project import Project
from dbt.config.project import Project, _get_required_version
from dbt.contracts.project import PackageConfig, LocalPackage, GitPackage
from dbt.node_types import NodeType
from dbt_common.exceptions import DbtRuntimeError
Expand Down Expand Up @@ -45,7 +46,7 @@ def test_fixture_paths(self, project: Project):
def test__str__(self, project: Project):
assert (
str(project)
== "{'name': 'test_project', 'version': 1.0, 'project-root': 'doesnt/actually/exist', 'profile': 'test_profile', 'model-paths': ['models'], 'macro-paths': ['macros'], 'seed-paths': ['seeds'], 'test-paths': ['tests'], 'analysis-paths': ['analyses'], 'docs-paths': ['docs'], 'asset-paths': ['assets'], 'target-path': 'target', 'snapshot-paths': ['snapshots'], 'clean-targets': ['target'], 'log-path': 'path/to/project/logs', 'quoting': {'database': True, 'schema': True, 'identifier': True}, 'models': {}, 'on-run-start': [], 'on-run-end': [], 'dispatch': [{'macro_namespace': 'dbt_utils', 'search_order': ['test_project', 'dbt_utils']}], 'seeds': {}, 'snapshots': {}, 'sources': {}, 'data_tests': {}, 'unit_tests': {}, 'metrics': {}, 'semantic-models': {}, 'saved-queries': {}, 'exposures': {}, 'vars': {}, 'require-dbt-version': ['=0.0.0'], 'restrict-access': False, 'dbt-cloud': {}, 'query-comment': {'comment': \"\\n{%- set comment_dict = {} -%}\\n{%- do comment_dict.update(\\n app='dbt',\\n dbt_version=dbt_version,\\n profile_name=target.get('profile_name'),\\n target_name=target.get('target_name'),\\n) -%}\\n{%- if node is not none -%}\\n {%- do comment_dict.update(\\n node_id=node.unique_id,\\n ) -%}\\n{% else %}\\n {# in the node context, the connection name is the node_id #}\\n {%- do comment_dict.update(connection_name=connection_name) -%}\\n{%- endif -%}\\n{{ return(tojson(comment_dict)) }}\\n\", 'append': False, 'job-label': False}, 'packages': []}"
== "{'name': 'test_project', 'version': 1.0, 'project-root': 'doesnt/actually/exist', 'profile': 'test_profile', 'model-paths': ['models'], 'macro-paths': ['macros'], 'seed-paths': ['seeds'], 'test-paths': ['tests'], 'analysis-paths': ['analyses'], 'docs-paths': ['docs'], 'asset-paths': ['assets'], 'target-path': 'target', 'snapshot-paths': ['snapshots'], 'clean-targets': ['target'], 'log-path': 'path/to/project/logs', 'quoting': {}, 'models': {}, 'on-run-start': [], 'on-run-end': [], 'dispatch': [{'macro_namespace': 'dbt_utils', 'search_order': ['test_project', 'dbt_utils']}], 'seeds': {}, 'snapshots': {}, 'sources': {}, 'data_tests': {}, 'unit_tests': {}, 'metrics': {}, 'semantic-models': {}, 'saved-queries': {}, 'exposures': {}, 'vars': {}, 'require-dbt-version': ['=0.0.0'], 'restrict-access': False, 'dbt-cloud': {}, 'flags': {}, 'query-comment': {'comment': \"\\n{%- set comment_dict = {} -%}\\n{%- do comment_dict.update(\\n app='dbt',\\n dbt_version=dbt_version,\\n profile_name=target.get('profile_name'),\\n target_name=target.get('target_name'),\\n) -%}\\n{%- if node is not none -%}\\n {%- do comment_dict.update(\\n node_id=node.unique_id,\\n ) -%}\\n{% else %}\\n {# in the node context, the connection name is the node_id #}\\n {%- do comment_dict.update(connection_name=connection_name) -%}\\n{%- endif -%}\\n{{ return(tojson(comment_dict)) }}\\n\", 'append': False, 'job-label': False}, 'packages': []}"
)

def test_get_selector(self, project: Project):
Expand Down Expand Up @@ -537,3 +538,53 @@ def setUp(self):
def test_setting_multiple_flags(self):
with pytest.raises(dbt.exceptions.DbtProjectError):
set_from_args(self.args, None)


class TestGetRequiredVersion:
@pytest.fixture
def project_dict(self) -> Dict[str, Any]:
return {
"name": "test_project",
"require-dbt-version": ">0.0.0",
}

def test_supported_version(self, project_dict: Dict[str, Any]) -> None:
specifiers = _get_required_version(project_dict=project_dict, verify_version=True)
assert set(x.to_version_string() for x in specifiers) == {">0.0.0"}

def test_unsupported_version(self, project_dict: Dict[str, Any]) -> None:
project_dict["require-dbt-version"] = ">99999.0.0"
with pytest.raises(
dbt.exceptions.DbtProjectError, match="This version of dbt is not supported"
):
_get_required_version(project_dict=project_dict, verify_version=True)

def test_unsupported_version_no_check(self, project_dict: Dict[str, Any]) -> None:
project_dict["require-dbt-version"] = ">99999.0.0"
specifiers = _get_required_version(project_dict=project_dict, verify_version=False)
assert set(x.to_version_string() for x in specifiers) == {">99999.0.0"}

def test_supported_version_range(self, project_dict: Dict[str, Any]) -> None:
project_dict["require-dbt-version"] = [">0.0.0", "<=99999.0.0"]
specifiers = _get_required_version(project_dict=project_dict, verify_version=True)
assert set(x.to_version_string() for x in specifiers) == {">0.0.0", "<=99999.0.0"}

def test_unsupported_version_range(self, project_dict: Dict[str, Any]) -> None:
project_dict["require-dbt-version"] = [">0.0.0", "<=0.0.1"]
with pytest.raises(
dbt.exceptions.DbtProjectError, match="This version of dbt is not supported"
):
_get_required_version(project_dict=project_dict, verify_version=True)

def test_unsupported_version_range_no_check(self, project_dict: Dict[str, Any]) -> None:
project_dict["require-dbt-version"] = [">0.0.0", "<=0.0.1"]
specifiers = _get_required_version(project_dict=project_dict, verify_version=False)
assert set(x.to_version_string() for x in specifiers) == {">0.0.0", "<=0.0.1"}

def test_impossible_version_range(self, project_dict: Dict[str, Any]) -> None:
project_dict["require-dbt-version"] = [">99999.0.0", "<=0.0.1"]
with pytest.raises(
dbt.exceptions.DbtProjectError,
match="The package version requirement can never be satisfied",
):
_get_required_version(project_dict=project_dict, verify_version=True)
1 change: 1 addition & 0 deletions tests/unit/utils/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,5 @@ def project(selector_config: SelectorConfig) -> Project:
project_env_vars={},
restrict_access=False,
dbt_cloud={},
flags={},
)

0 comments on commit 8ccc7d0

Please sign in to comment.