diff --git a/.changes/unreleased/Features-20231116-234049.yaml b/.changes/unreleased/Features-20231116-234049.yaml new file mode 100644 index 00000000000..786c15311a4 --- /dev/null +++ b/.changes/unreleased/Features-20231116-234049.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support --empty flag for schema-only dry runs +time: 2023-11-16T23:40:49.96651-05:00 +custom: + Author: michelleark + Issue: "8971" diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 67a50d9061f..55da847a716 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -36,6 +36,7 @@ class BaseRelation(FakeAPIObject, Hashable): include_policy: Policy = field(default_factory=lambda: Policy()) quote_policy: Policy = field(default_factory=lambda: Policy()) dbt_created: bool = False + limit: Optional[int] = None # register relation types that can be renamed for the purpose of replacing relations using stages and backups # adding a relation type here also requires defining the associated rename macro @@ -194,6 +195,15 @@ def render(self) -> str: # if there is nothing set, this will return the empty string. return ".".join(part for _, part in self._render_iterator() if part is not None) + def render_limited(self) -> str: + rendered = self.render() + if self.limit is None: + return rendered + elif self.limit == 0: + return f"(select * from {rendered} where false limit 0) _dbt_limit_subq" + else: + return f"(select * from {rendered} limit {self.limit}) _dbt_limit_subq" + def quoted(self, identifier): return "{quote_char}{identifier}{quote_char}".format( quote_char=self.quote_character, @@ -227,13 +237,11 @@ def create_ephemeral_from_node( cls: Type[Self], config: HasQuoting, node: ManifestNode, + limit: Optional[int], ) -> Self: # Note that ephemeral models are based on the name. identifier = cls.add_ephemeral_prefix(node.name) - return cls.create( - type=cls.CTE, - identifier=identifier, - ).quote(identifier=False) + return cls.create(type=cls.CTE, identifier=identifier, limit=limit).quote(identifier=False) @classmethod def create_from_node( @@ -313,7 +321,7 @@ def __hash__(self) -> int: return hash(self.render()) def __str__(self) -> str: - return self.render() + return self.render() if self.limit is None else self.render_limited() @property def database(self) -> Optional[str]: diff --git a/core/dbt/cli/main.py b/core/dbt/cli/main.py index 7d4560a7910..2454a15a564 100644 --- a/core/dbt/cli/main.py +++ b/core/dbt/cli/main.py @@ -342,6 +342,7 @@ def docs_serve(ctx, **kwargs): @p.profile @p.profiles_dir @p.project_dir +@p.empty @p.select @p.selector @p.inline @@ -599,6 +600,7 @@ def parse(ctx, **kwargs): @p.profile @p.profiles_dir @p.project_dir +@p.empty @p.select @p.selector @p.state diff --git a/core/dbt/cli/params.py b/core/dbt/cli/params.py index 1898815a724..907f485f3b3 100644 --- a/core/dbt/cli/params.py +++ b/core/dbt/cli/params.py @@ -90,6 +90,12 @@ is_flag=True, ) +empty = click.option( + "--empty/--no-empty", + envvar="DBT_EMPTY", + help="If specified, limit input refs and sources to zero rows.", + is_flag=True, +) enable_legacy_logger = click.option( "--enable-legacy-logger/--no-enable-legacy-logger", diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index febc21a546f..1bbd37af4f2 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -216,6 +216,10 @@ def current_project(self): def Relation(self): return self.db_wrapper.Relation + @property + def resolve_limit(self) -> Optional[int]: + return 0 if getattr(self.config.args, "EMPTY", False) else None + @abc.abstractmethod def __call__(self, *args: str) -> Union[str, RelationProxy, MetricReference]: pass @@ -531,9 +535,11 @@ def resolve( def create_relation(self, target_model: ManifestNode) -> RelationProxy: if target_model.is_ephemeral_model: self.model.set_cte(target_model.unique_id, None) - return self.Relation.create_ephemeral_from_node(self.config, target_model) + return self.Relation.create_ephemeral_from_node( + self.config, target_model, limit=self.resolve_limit + ) else: - return self.Relation.create_from(self.config, target_model) + return self.Relation.create_from(self.config, target_model, limit=self.resolve_limit) def validate( self, @@ -590,7 +596,7 @@ def resolve(self, source_name: str, table_name: str): target_kind="source", disabled=(isinstance(target_source, Disabled)), ) - return self.Relation.create_from_source(target_source) + return self.Relation.create_from_source(target_source, limit=self.resolve_limit) # metric` implementations diff --git a/tests/adapter/dbt/tests/adapter/empty/test_empty.py b/tests/adapter/dbt/tests/adapter/empty/test_empty.py new file mode 100644 index 00000000000..a014a640c1f --- /dev/null +++ b/tests/adapter/dbt/tests/adapter/empty/test_empty.py @@ -0,0 +1,75 @@ +import pytest +from dbt.tests.util import run_dbt, relation_from_name + + +model_input_sql = """ +select 1 as id +""" + +ephemeral_model_input_sql = """ +{{ config(materialized='ephemeral') }} +select 2 as id +""" + +raw_source_csv = """id +3 +""" + + +model_sql = """ +select * +from {{ ref('model_input') }} +union all +select * +from {{ ref('ephemeral_model_input') }} +union all +select * +from {{ source('seed_sources', 'raw_source') }} +""" + + +schema_sources_yml = """ +sources: + - name: seed_sources + schema: "{{ target.schema }}" + tables: + - name: raw_source +""" + + +class BaseTestEmpty: + @pytest.fixture(scope="class") + def seeds(self): + return { + "raw_source.csv": raw_source_csv, + } + + @pytest.fixture(scope="class") + def models(self): + return { + "model_input.sql": model_input_sql, + "ephemeral_model_input.sql": ephemeral_model_input_sql, + "model.sql": model_sql, + "sources.yml": schema_sources_yml, + } + + def assert_row_count(self, project, relation_name: str, expected_row_count: int): + relation = relation_from_name(project.adapter, relation_name) + result = project.run_sql(f"select count(*) as num_rows from {relation}", fetch="one") + assert result[0] == expected_row_count + + def test_run_with_empty(self, project): + # create source from seed + run_dbt(["seed"]) + + # run without empty - 3 expected rows in output - 1 from each input + run_dbt(["run"]) + self.assert_row_count(project, "model", 3) + + # run with empty - 0 expected rows in output + run_dbt(["run", "--empty"]) + self.assert_row_count(project, "model", 0) + + +class TestEmpty(BaseTestEmpty): + pass diff --git a/tests/unit/test_providers.py b/tests/unit/test_providers.py new file mode 100644 index 00000000000..ca90580ba3f --- /dev/null +++ b/tests/unit/test_providers.py @@ -0,0 +1,104 @@ +import pytest +from unittest import mock + +from dbt.adapters.base import BaseRelation +from dbt.context.providers import BaseResolver, RuntimeRefResolver, RuntimeSourceResolver +from dbt.contracts.graph.unparsed import Quoting + + +class TestBaseResolver: + class ResolverSubclass(BaseResolver): + def __call__(self, *args: str): + pass + + @pytest.fixture + def resolver(self): + return self.ResolverSubclass( + db_wrapper=mock.Mock(), + model=mock.Mock(), + config=mock.Mock(), + manifest=mock.Mock(), + ) + + @pytest.mark.parametrize( + "empty,expected_resolve_limit", + [(False, None), (True, 0)], + ) + def test_resolve_limit(self, resolver, empty, expected_resolve_limit): + resolver.config.args.EMPTY = empty + + assert resolver.resolve_limit == expected_resolve_limit + + +class TestRuntimeRefResolver: + @pytest.fixture + def resolver(self): + mock_db_wrapper = mock.Mock() + mock_db_wrapper.Relation = BaseRelation + + return RuntimeRefResolver( + db_wrapper=mock_db_wrapper, + model=mock.Mock(), + config=mock.Mock(), + manifest=mock.Mock(), + ) + + @pytest.mark.parametrize( + "empty,is_ephemeral_model,expected_limit", + [ + (False, False, None), + (True, False, 0), + (False, True, None), + (True, True, 0), + ], + ) + def test_create_relation_with_empty(self, resolver, empty, is_ephemeral_model, expected_limit): + # setup resolver and input node + resolver.config.args.EMPTY = empty + mock_node = mock.Mock() + mock_node.database = "test" + mock_node.schema = "test" + mock_node.identifier = "test" + mock_node.alias = "test" + mock_node.is_ephemeral_model = is_ephemeral_model + + # create limited relation + with mock.patch("dbt.adapters.base.relation.ParsedNode", new=mock.Mock): + relation = resolver.create_relation(mock_node) + assert relation.limit == expected_limit + + +class TestRuntimeSourceResolver: + @pytest.fixture + def resolver(self): + mock_db_wrapper = mock.Mock() + mock_db_wrapper.Relation = BaseRelation + + return RuntimeSourceResolver( + db_wrapper=mock_db_wrapper, + model=mock.Mock(), + config=mock.Mock(), + manifest=mock.Mock(), + ) + + @pytest.mark.parametrize( + "empty,expected_limit", + [ + (False, None), + (True, 0), + ], + ) + def test_create_relation_with_empty(self, resolver, empty, expected_limit): + # setup resolver and input source + resolver.config.args.EMPTY = empty + + mock_source = mock.Mock() + mock_source.database = "test" + mock_source.schema = "test" + mock_source.identifier = "test" + mock_source.quoting = Quoting() + resolver.manifest.resolve_source.return_value = mock_source + + # create limited relation + relation = resolver.resolve("test", "test") + assert relation.limit == expected_limit diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py index 94995958ba6..e4d90c4b504 100644 --- a/tests/unit/test_relation.py +++ b/tests/unit/test_relation.py @@ -40,3 +40,29 @@ def test_can_be_replaced(relation_type, result): def test_can_be_replaced_default(): my_relation = BaseRelation.create(type=RelationType.View) assert my_relation.can_be_replaced is False + + +@pytest.mark.parametrize( + "limit,expected_result", + [ + (None, '"test_database"."test_schema"."test_identifier"'), + ( + 0, + '(select * from "test_database"."test_schema"."test_identifier" where false limit 0) _dbt_limit_subq', + ), + ( + 1, + '(select * from "test_database"."test_schema"."test_identifier" limit 1) _dbt_limit_subq', + ), + ], +) +def test_render_limited(limit, expected_result): + my_relation = BaseRelation.create( + database="test_database", + schema="test_schema", + identifier="test_identifier", + limit=limit, + ) + actual_result = my_relation.render_limited() + assert actual_result == expected_result + assert str(my_relation) == expected_result